esfiles / frontend /src /components /TrainingPanel.tsx
Besjon Cifliku
feat: simplify the workflow and search patterns
9f87ec0
import { useState } from "react";
import { api, getErrorMessage } from "../api";
import type { TrainResponse } from "../types";
import { useCorpusLoader } from "../hooks/useCorpusLoader";
import StatusMessage from "./StatusMessage";
import MetricCard from "./MetricCard";
import Toggle from "./Toggle";
import Select from "./Select";
import LogViewer from "./LogViewer";
/** Identifier of a fine-tuning strategy selectable in the training panel. */
type Strategy =
  | "unsupervised"
  | "contrastive"
  | "keywords";

/**
 * Strategy catalogue: one entry per {@link Strategy}, carrying the short label
 * shown on the toggle and the longer description shown beneath it.
 */
const STRATEGIES: Array<{ id: Strategy; label: string; desc: string }> = [
  {
    id: "unsupervised",
    label: "Unsupervised",
    desc: "Soft-label domain adaptation. Samples random pairs and fine-tunes using the model's own similarity scores.",
  },
  {
    id: "contrastive",
    label: "Contrastive",
    desc: "Adjacent sentences = positive pairs. Learns document structure with in-batch negatives and validation.",
  },
  {
    id: "keywords",
    label: "Keyword-supervised",
    desc: "You provide keyword→meaning map. Best if you know the code words.",
  },
];
/**
 * Base models offered in the "Base Model" dropdown.
 * `value` is the model identifier sent with the training request;
 * `label` is the text shown to the user.
 */
const MODELS = [
  {
    value: "all-MiniLM-L6-v2",
    label: "all-MiniLM-L6-v2 (fast)",
  },
  {
    value: "all-mpnet-base-v2",
    label: "all-mpnet-base-v2 (best quality)",
  },
];
/** Inclusive bounds for a numeric training setting. */
type NumericLimits = { min: number; max: number };

/** Bounds shared by the number inputs' `min`/`max` attributes and submit-time validation. */
const EPOCH_LIMITS: NumericLimits = { min: 1, max: 50 };
const BATCH_SIZE_LIMITS: NumericLimits = { min: 4, max: 128 };

/** Outcome of parsing the keyword → meaning textarea. */
type KeywordMapResult =
  | { status: "ok"; value: Record<string, string> }
  | { status: "error"; message: string };

/** Returns true when `value` is an integer inside `limits` (inclusive). */
function isWholeNumberWithin(value: number, limits: NumericLimits): boolean {
  return Number.isInteger(value) && value >= limits.min && value <= limits.max;
}

/**
 * Parses the keyword map textarea into a plain string → string object.
 *
 * Never throws: malformed JSON and JSON of the wrong shape (null, array,
 * primitive, non-string values) are both reported via `status: "error"`.
 */
function parseKeywordMap(text: string): KeywordMapResult {
  let parsed: unknown;
  try {
    parsed = JSON.parse(text);
  } catch {
    return { status: "error", message: "Invalid JSON in keyword map." };
  }

  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    return { status: "error", message: "Keyword map must be a JSON object of keyword → meaning pairs." };
  }

  const entries = Object.entries(parsed);
  // NOTE(review): assumes meanings are plain strings, as in the default example map —
  // confirm against the trainKeywords API contract.
  if (!entries.every(([, meaning]) => typeof meaning === "string")) {
    return { status: "error", message: "Every meaning in the keyword map must be a string." };
  }

  return { status: "ok", value: Object.fromEntries(entries) as Record<string, string> };
}

/**
 * Panel for fine-tuning a pre-trained sentence transformer on a user-supplied corpus.
 *
 * Collects the corpus text, training strategy, base model and advanced settings,
 * submits them through the matching `api.train*` call, and renders either the
 * resulting metrics or an error message.
 */
export default function TrainingPanel() {
  // Training
  const [strategy, setStrategy] = useState<Strategy>("contrastive");
  const [baseModel, setBaseModel] = useState("all-MiniLM-L6-v2");
  const [outputPath, setOutputPath] = useState("./trained_model");
  const [epochs, setEpochs] = useState(5);
  const [batchSize, setBatchSize] = useState(16);
  const [keywordMapText, setKeywordMapText] = useState('{\n "pizza": "school",\n "pepperoni": "math class"\n}');
  const [showAdvanced, setShowAdvanced] = useState(false);
  const [training, setTraining] = useState(false);
  const [result, setResult] = useState<TrainResponse | null>(null);
  const { corpusText, setCorpusText, loading: corpusLoading, error, setError, parseCorpus, loadFromEngine } = useCorpusLoader();

  /**
   * Validates the form, runs the selected training strategy and stores the result.
   * All failures are surfaced through `setError`; `training` is always reset on exit.
   */
  async function handleTrain() {
    setTraining(true);
    setError("");
    setResult(null);
    try {
      const corpus = parseCorpus();
      if (!corpus.length) {
        setError("Corpus is empty.");
        return; // `finally` resets the training flag
      }

      // The number inputs do not enforce min/max on typed values, and a cleared
      // field is stored as 0, so the bounds are checked here before submitting.
      if (!isWholeNumberWithin(epochs, EPOCH_LIMITS)) {
        setError(`Epochs must be a whole number between ${EPOCH_LIMITS.min} and ${EPOCH_LIMITS.max}.`);
        return;
      }
      if (!isWholeNumberWithin(batchSize, BATCH_SIZE_LIMITS)) {
        setError(`Batch size must be a whole number between ${BATCH_SIZE_LIMITS.min} and ${BATCH_SIZE_LIMITS.max}.`);
        return;
      }

      const base = { corpus_texts: corpus, base_model: baseModel, output_path: outputPath, epochs, batch_size: batchSize };
      let res: TrainResponse;
      if (strategy === "unsupervised") {
        res = await api.trainUnsupervised(base);
      } else if (strategy === "contrastive") {
        res = await api.trainContrastive(base);
      } else {
        // Parsed separately from the API call so that only a bad keyword map is
        // reported as such; a SyntaxError raised inside the request is not.
        const keywordMap = parseKeywordMap(keywordMapText);
        if (keywordMap.status === "error") {
          setError(keywordMap.message);
          return;
        }
        res = await api.trainKeywords({ ...base, keyword_meanings: keywordMap.value });
      }
      setResult(res);
    } catch (e) {
      setError(getErrorMessage(e));
    } finally {
      setTraining(false);
    }
  }

  return (
    <div>
      {/* 1. Training (strategy + config + corpus merged) */}
      <div className="panel">
        <h2>1. Fine-tune Transformer</h2>
        <p className="panel-desc">
          Fine-tune a pre-trained sentence transformer on your corpus to improve contextual understanding.
        </p>
        <div style={{ display: "flex", gap: 8, marginBottom: 10 }}>
          <button className="btn btn-secondary" onClick={loadFromEngine}
            disabled={corpusLoading}>
            {corpusLoading ? "Loading..." : "Load from Engine"}
          </button>
          {corpusText && (
            <button className="btn btn-secondary" onClick={() => setCorpusText("")}>
              Clear
            </button>
          )}
        </div>
        <div className="form-group" style={{ marginBottom: 12 }}>
          <label>
            Corpus (separate documents with blank lines)
            {corpusText && (
              <span style={{ color: "var(--text-dim)", fontWeight: 400 }}>
                {" "} — {parseCorpus().length} documents detected
              </span>
            )}
          </label>
          {/* Expression container: JSX string attributes do not interpret "\n" escapes. */}
          <textarea value={corpusText} onChange={e => setCorpusText(e.target.value)} rows={8}
            placeholder={"Document 1 text...\n\nDocument 2 text..."} />
        </div>
        <label className="section-label">Strategy</label>
        <Toggle
          options={STRATEGIES.map(s => ({ value: s.id, label: s.label }))}
          value={strategy}
          onChange={(v) => setStrategy(v as Strategy)}
        />
        <p style={{ color: "var(--text-dim)", fontSize: "0.85rem", marginBottom: 12 }}>
          {STRATEGIES.find(s => s.id === strategy)?.desc}
        </p>
        {strategy === "keywords" && (
          <div className="form-group" style={{ marginBottom: 12 }}>
            <label>Keyword → Meaning Map (JSON)</label>
            <textarea value={keywordMapText} onChange={e => setKeywordMapText(e.target.value)}
              rows={4} style={{ fontFamily: "monospace", fontSize: "0.8rem" }} />
          </div>
        )}
        <div className="form-row" style={{ marginBottom: 12 }}>
          <div className="form-group">
            <label>Base Model</label>
            <Select options={MODELS} value={baseModel} onChange={setBaseModel} />
          </div>
        </div>
        <button className="advanced-toggle" onClick={() => setShowAdvanced(open => !open)}>
          {showAdvanced ? "\u25be" : "\u25b8"} Advanced Settings
        </button>
        {showAdvanced && (
          <div className="advanced-section">
            <div className="form-row">
              <div className="form-group" style={{ maxWidth: 100 }}>
                <label>Epochs</label>
                <input type="number" value={epochs} onChange={e => setEpochs(Number(e.target.value))}
                  min={EPOCH_LIMITS.min} max={EPOCH_LIMITS.max} />
              </div>
              <div className="form-group" style={{ maxWidth: 120 }}>
                <label>Batch Size</label>
                <input type="number" value={batchSize} onChange={e => setBatchSize(Number(e.target.value))}
                  min={BATCH_SIZE_LIMITS.min} max={BATCH_SIZE_LIMITS.max} />
              </div>
              <div className="form-group" style={{ maxWidth: 200 }}>
                <label>Output Path</label>
                <input value={outputPath} onChange={e => setOutputPath(e.target.value)} />
              </div>
            </div>
          </div>
        )}
        <button className="btn btn-primary" onClick={handleTrain}
          disabled={training || !corpusText.trim()} style={{ marginTop: 8 }}>
          {training ? <><span className="spinner" /> Training...</> : "Start Training"}
        </button>
        <LogViewer active={training} />
      </div>
      {error && <StatusMessage type="err" message={error} />}
      {result && (
        <div className="panel">
          <h2>Training Complete</h2>
          <div className="metric-grid">
            <MetricCard value={result.training_pairs} label="Training Pairs" />
            <MetricCard value={result.epochs} label="Epochs" />
            <MetricCard value={`${result.seconds}s`} label="Time" />
          </div>
          <StatusMessage type="ok"
            message={`Model saved: ${result.model_path} — use this path in the Setup tab, then go to Analysis to explore results.`} />
        </div>
      )}
    </div>
  );
}