| import { useState } from "react"; |
| import { api, getErrorMessage } from "../api"; |
| import type { TrainResponse } from "../types"; |
| import { useCorpusLoader } from "../hooks/useCorpusLoader"; |
| import StatusMessage from "./StatusMessage"; |
| import MetricCard from "./MetricCard"; |
| import Toggle from "./Toggle"; |
| import Select from "./Select"; |
| import LogViewer from "./LogViewer"; |
|
|
/** Fine-tuning approaches offered in the UI; each selects a different training endpoint. */
type Strategy = "unsupervised" | "contrastive" | "keywords";

/** One entry of the strategy picker. */
interface StrategyOption {
  /** Value kept in component state and used to choose the training call. */
  id: Strategy;
  /** Short name shown on the toggle. */
  label: string;
  /** One-line explanation rendered beneath the toggle for the active strategy. */
  desc: string;
}

/** Strategy choices, listed in the order they appear in the toggle. */
const STRATEGIES: readonly StrategyOption[] = [
  {
    id: "unsupervised",
    label: "Unsupervised",
    desc: "Soft-label domain adaptation. Samples random pairs and fine-tunes using the model's own similarity scores.",
  },
  {
    id: "contrastive",
    label: "Contrastive",
    desc: "Adjacent sentences = positive pairs. Learns document structure with in-batch negatives and validation.",
  },
  {
    id: "keywords",
    label: "Keyword-supervised",
    desc: "You provide keyword→meaning map. Best if you know the code words.",
  },
];
|
|
/** A selectable base checkpoint: `value` is the model identifier, `label` is the display text. */
type ModelOption = { value: string; label: string };

/** Base sentence-transformer checkpoints passed as options to the base-model <Select>. */
const MODELS: ModelOption[] = [
  {
    value: "all-MiniLM-L6-v2",
    label: "all-MiniLM-L6-v2 (fast)",
  },
  {
    value: "all-mpnet-base-v2",
    label: "all-mpnet-base-v2 (best quality)",
  },
];
|
|
/** Outcome of validating the keyword → meaning JSON typed by the user. */
type KeywordMapResult =
  | { ok: true; map: Record<string, string> }
  | { ok: false; error: string };

/**
 * Parse and validate the keyword → meaning map entered in the JSON textarea.
 *
 * This runs before any network call, so only a genuine parse failure of the
 * user's text is reported as "Invalid JSON in keyword map." (A SyntaxError
 * raised while talking to the server is no longer misattributed.)
 *
 * @param text Raw textarea contents.
 * @returns `{ ok: true, map }` for a JSON object whose values are all strings,
 *          otherwise `{ ok: false, error }` with a user-facing message.
 */
function parseKeywordMap(text: string): KeywordMapResult {
  let parsed: unknown;
  try {
    parsed = JSON.parse(text);
  } catch {
    return { ok: false, error: "Invalid JSON in keyword map." };
  }
  // JSON.parse happily returns arrays, null, numbers, strings, booleans.
  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    return { ok: false, error: "Keyword map must be a JSON object of keyword → meaning pairs." };
  }
  const map: Record<string, string> = {};
  for (const [keyword, meaning] of Object.entries(parsed)) {
    if (typeof meaning !== "string") {
      return { ok: false, error: `Meaning for keyword "${keyword}" must be a string.` };
    }
    map[keyword] = meaning;
  }
  return { ok: true, map };
}

/**
 * True when `n` is a whole number >= 1. Used for epochs / batch size, because
 * `+e.target.value` turns a cleared number input into 0 and the inputs'
 * min/max attributes do not stop other values from being typed.
 */
function isPositiveInt(n: number): boolean {
  return Number.isInteger(n) && n >= 1;
}

/**
 * Panel for fine-tuning a pre-trained sentence transformer on a user corpus.
 *
 * The corpus is typed/pasted or loaded via `useCorpusLoader().loadFromEngine`.
 * The user picks one of the STRATEGIES, a base model from MODELS and optional
 * advanced settings; "Start Training" calls the matching `api.train*` endpoint
 * and, on success, shows the returned metrics and saved model path.
 */
export default function TrainingPanel() {
  const [strategy, setStrategy] = useState<Strategy>("contrastive");
  const [baseModel, setBaseModel] = useState("all-MiniLM-L6-v2");
  const [outputPath, setOutputPath] = useState("./trained_model");
  const [epochs, setEpochs] = useState(5);
  const [batchSize, setBatchSize] = useState(16);
  const [keywordMapText, setKeywordMapText] = useState('{\n "pizza": "school",\n "pepperoni": "math class"\n}');
  const [showAdvanced, setShowAdvanced] = useState(false);
  const [training, setTraining] = useState(false);
  const [result, setResult] = useState<TrainResponse | null>(null);

  const { corpusText, setCorpusText, loading: corpusLoading, error, setError, parseCorpus, loadFromEngine } =
    useCorpusLoader();

  /**
   * Validate the inputs, call the endpoint for the selected strategy and store
   * the response. Any validation failure or request error is surfaced through
   * the shared `error` state from useCorpusLoader.
   */
  async function handleTrain() {
    setTraining(true);
    setError("");
    setResult(null);
    try {
      // Early `return`s below still run `finally`, which clears `training`.
      const corpus = parseCorpus();
      if (!corpus.length) {
        setError("Corpus is empty.");
        return;
      }
      if (!isPositiveInt(epochs)) {
        setError("Epochs must be a whole number of at least 1.");
        return;
      }
      if (!isPositiveInt(batchSize)) {
        setError("Batch size must be a whole number of at least 1.");
        return;
      }

      const base = { corpus_texts: corpus, base_model: baseModel, output_path: outputPath, epochs, batch_size: batchSize };
      let res: TrainResponse;

      if (strategy === "unsupervised") {
        res = await api.trainUnsupervised(base);
      } else if (strategy === "contrastive") {
        res = await api.trainContrastive(base);
      } else {
        // Validate the keyword map up front so its errors are reported precisely
        // and never confused with errors from the API call itself.
        const keywords = parseKeywordMap(keywordMapText);
        if (!keywords.ok) {
          setError(keywords.error);
          return;
        }
        res = await api.trainKeywords({ ...base, keyword_meanings: keywords.map });
      }
      setResult(res);
    } catch (e) {
      setError(getErrorMessage(e));
    } finally {
      setTraining(false);
    }
  }

  return (
    <div>
      {/* 1. Training (strategy + config + corpus merged) */}
      <div className="panel">
        <h2>1. Fine-tune Transformer</h2>
        <p className="panel-desc">
          Fine-tune a pre-trained sentence transformer on your corpus to improve contextual understanding.
        </p>

        <div style={{ display: "flex", gap: 8, marginBottom: 10 }}>
          <button className="btn btn-secondary" onClick={loadFromEngine}
            disabled={corpusLoading}>
            {corpusLoading ? "Loading..." : "Load from Engine"}
          </button>
          {corpusText && (
            <button className="btn btn-secondary" onClick={() => setCorpusText("")}>
              Clear
            </button>
          )}
        </div>
        <div className="form-group" style={{ marginBottom: 12 }}>
          <label>
            Corpus (separate documents with blank lines)
            {corpusText && (
              <span style={{ color: "var(--text-dim)", fontWeight: 400 }}>
                {" "} — {parseCorpus().length} documents detected
              </span>
            )}
          </label>
          {/* Braces make this a JS string so "\n" becomes a real newline;
              a plain JSX attribute string would show the backslashes literally. */}
          <textarea value={corpusText} onChange={e => setCorpusText(e.target.value)} rows={8}
            placeholder={"Document 1 text...\n\nDocument 2 text..."} />
        </div>

        <label className="section-label">Strategy</label>
        <Toggle
          options={STRATEGIES.map(s => ({ value: s.id, label: s.label }))}
          value={strategy}
          onChange={(v) => setStrategy(v as Strategy)}
        />
        <p style={{ color: "var(--text-dim)", fontSize: "0.85rem", marginBottom: 12 }}>
          {STRATEGIES.find(s => s.id === strategy)?.desc}
        </p>

        {strategy === "keywords" && (
          <div className="form-group" style={{ marginBottom: 12 }}>
            <label>Keyword → Meaning Map (JSON)</label>
            <textarea value={keywordMapText} onChange={e => setKeywordMapText(e.target.value)}
              rows={4} style={{ fontFamily: "monospace", fontSize: "0.8rem" }} />
          </div>
        )}

        <div className="form-row" style={{ marginBottom: 12 }}>
          <div className="form-group">
            <label>Base Model</label>
            <Select options={MODELS} value={baseModel} onChange={setBaseModel} />
          </div>
        </div>

        <button className="advanced-toggle" onClick={() => setShowAdvanced(open => !open)}>
          {showAdvanced ? "\u25be" : "\u25b8"} Advanced Settings
        </button>

        {showAdvanced && (
          <div className="advanced-section">
            <div className="form-row">
              <div className="form-group" style={{ maxWidth: 100 }}>
                <label>Epochs</label>
                <input type="number" value={epochs} onChange={e => setEpochs(+e.target.value)} min={1} max={50} />
              </div>
              <div className="form-group" style={{ maxWidth: 120 }}>
                <label>Batch Size</label>
                <input type="number" value={batchSize} onChange={e => setBatchSize(+e.target.value)} min={4} max={128} />
              </div>
              <div className="form-group" style={{ maxWidth: 200 }}>
                <label>Output Path</label>
                <input value={outputPath} onChange={e => setOutputPath(e.target.value)} />
              </div>
            </div>
          </div>
        )}

        <button className="btn btn-primary" onClick={handleTrain}
          disabled={training || !corpusText.trim()} style={{ marginTop: 8 }}>
          {training ? <><span className="spinner" /> Training...</> : "Start Training"}
        </button>

        <LogViewer active={training} />
      </div>

      {error && <StatusMessage type="err" message={error} />}

      {result && (
        <div className="panel">
          <h2>Training Complete</h2>
          <div className="metric-grid">
            <MetricCard value={result.training_pairs} label="Training Pairs" />
            <MetricCard value={result.epochs} label="Epochs" />
            <MetricCard value={`${result.seconds}s`} label="Time" />
          </div>
          <StatusMessage type="ok"
            message={`Model saved: ${result.model_path} — use this path in the Setup tab, then go to Analysis to explore results.`} />
        </div>
      )}
    </div>
  );
}
|
|