File size: 7,124 Bytes
db764ae 9f87ec0 db764ae 9f87ec0 db764ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import { useState } from "react";
import { api, getErrorMessage } from "../api";
import type { TrainResponse } from "../types";
import { useCorpusLoader } from "../hooks/useCorpusLoader";
import StatusMessage from "./StatusMessage";
import MetricCard from "./MetricCard";
import Toggle from "./Toggle";
import Select from "./Select";
import LogViewer from "./LogViewer";
/** Fine-tuning strategy id; TrainingPanel dispatches to a different `api.train*` call per value. */
type Strategy =
  | "unsupervised"
  | "contrastive"
  | "keywords";

/** One selectable strategy: its id, the toggle label, and the help text shown beneath the toggle. */
interface StrategyOption {
  id: Strategy;
  label: string;
  desc: string;
}

/** Strategies offered in the "Strategy" toggle, in display order. */
const STRATEGIES: readonly StrategyOption[] = [
  {
    id: "unsupervised",
    label: "Unsupervised",
    desc: "Soft-label domain adaptation. Samples random pairs and fine-tunes using the model's own similarity scores.",
  },
  {
    id: "contrastive",
    label: "Contrastive",
    desc: "Adjacent sentences = positive pairs. Learns document structure with in-batch negatives and validation.",
  },
  {
    id: "keywords",
    label: "Keyword-supervised",
    desc: "You provide keyword→meaning map. Best if you know the code words.",
  },
];
/** Base sentence-transformer checkpoints offered in the "Base Model" select; `value` is sent as `base_model`. */
const MODELS: { value: string; label: string }[] = [
  {
    value: "all-MiniLM-L6-v2",
    label: "all-MiniLM-L6-v2 (fast)",
  },
  {
    value: "all-mpnet-base-v2",
    label: "all-mpnet-base-v2 (best quality)",
  },
];
/** Inclusive bounds for the "Epochs" field, used both for the input's min/max and on submit. */
const EPOCHS_MIN = 1;
const EPOCHS_MAX = 50;
/** Inclusive bounds for the "Batch Size" field, used both for the input's min/max and on submit. */
const BATCH_SIZE_MIN = 4;
const BATCH_SIZE_MAX = 128;

/**
 * Returns a user-facing error if `value` is not a whole number in `[min, max]`, otherwise `null`.
 *
 * Needed because a cleared `<input type="number">` becomes `0` via `+e.target.value`, and the
 * input's min/max attributes do not stop the user from typing an out-of-range value.
 */
function rangeError(label: string, value: number, min: number, max: number): string | null {
  if (Number.isInteger(value) && value >= min && value <= max) return null;
  return `${label} must be a whole number between ${min} and ${max}.`;
}

/** Outcome of {@link parseKeywordMap}: either the validated map or a message to display. */
type KeywordMapParse =
  | { ok: true; value: Record<string, string> }
  | { ok: false; error: string };

/**
 * Parses and validates the keyword → meaning JSON typed by the user.
 *
 * Accepts only a non-empty JSON object whose values are all strings; anything else
 * (invalid JSON, arrays, primitives, `null`, `{}`) yields an error message instead.
 */
function parseKeywordMap(text: string): KeywordMapParse {
  let parsed: unknown;
  try {
    parsed = JSON.parse(text);
  } catch {
    return { ok: false, error: "Invalid JSON in keyword map." };
  }
  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    return { ok: false, error: "Keyword map must be a JSON object of keyword → meaning pairs." };
  }
  const meanings = Object.values(parsed);
  if (meanings.length === 0) {
    return { ok: false, error: "Keyword map is empty." };
  }
  if (meanings.some(meaning => typeof meaning !== "string")) {
    return { ok: false, error: "Every keyword meaning must be a string." };
  }
  // Safe: the checks above established a plain object with only string values.
  return { ok: true, value: parsed as Record<string, string> };
}

/**
 * Panel for fine-tuning a sentence transformer on a user-supplied corpus.
 *
 * The corpus is typed in or loaded via `useCorpusLoader`. The user picks a strategy, a base
 * model and advanced settings. These are validated locally, then the matching `api.train*`
 * request runs and the returned metrics are shown.
 */
export default function TrainingPanel() {
  // Training configuration
  const [strategy, setStrategy] = useState<Strategy>("contrastive");
  const [baseModel, setBaseModel] = useState("all-MiniLM-L6-v2");
  const [outputPath, setOutputPath] = useState("./trained_model");
  const [epochs, setEpochs] = useState(5);
  const [batchSize, setBatchSize] = useState(16);
  const [keywordMapText, setKeywordMapText] = useState('{\n "pizza": "school",\n "pepperoni": "math class"\n}');
  const [showAdvanced, setShowAdvanced] = useState(false);
  // Request state
  const [training, setTraining] = useState(false);
  const [result, setResult] = useState<TrainResponse | null>(null);
  const { corpusText, setCorpusText, loading: corpusLoading, error, setError, parseCorpus, loadFromEngine } = useCorpusLoader();

  /** Validates the form, then sends the training request for the selected strategy. */
  async function handleTrain() {
    setError("");
    setResult(null);

    // All local validation runs before `training` is set, so an invalid form never
    // starts the spinner and a local mistake is never mixed up with a request failure.
    const corpus = parseCorpus();
    if (!corpus.length) {
      setError("Corpus is empty.");
      return;
    }
    const settingsError =
      rangeError("Epochs", epochs, EPOCHS_MIN, EPOCHS_MAX) ??
      rangeError("Batch size", batchSize, BATCH_SIZE_MIN, BATCH_SIZE_MAX);
    if (settingsError) {
      setError(settingsError);
      return;
    }
    // Parsed here rather than inside the request try-block. A blanket
    // `e instanceof SyntaxError` check there would also catch SyntaxErrors raised inside
    // the api call (e.g. while decoding a response body, if the api client does that) and
    // mislabel them as keyword-map errors.
    let keywordMeanings: Record<string, string> = {};
    if (strategy === "keywords") {
      const parsed = parseKeywordMap(keywordMapText);
      if (!parsed.ok) {
        setError(parsed.error);
        return;
      }
      keywordMeanings = parsed.value;
    }

    const base = { corpus_texts: corpus, base_model: baseModel, output_path: outputPath, epochs, batch_size: batchSize };
    setTraining(true);
    try {
      let res: TrainResponse;
      if (strategy === "unsupervised") {
        res = await api.trainUnsupervised(base);
      } else if (strategy === "contrastive") {
        res = await api.trainContrastive(base);
      } else {
        res = await api.trainKeywords({ ...base, keyword_meanings: keywordMeanings });
      }
      setResult(res);
    } catch (e) {
      setError(getErrorMessage(e));
    } finally {
      setTraining(false);
    }
  }

  // Computed once per render for the "N documents detected" hint.
  const documentCount = corpusText ? parseCorpus().length : 0;

  return (
    <div>
      {/* 1. Training (strategy + config + corpus merged) */}
      <div className="panel">
        <h2>1. Fine-tune Transformer</h2>
        <p className="panel-desc">
          Fine-tune a pre-trained sentence transformer on your corpus to improve contextual understanding.
        </p>
        <div style={{ display: "flex", gap: 8, marginBottom: 10 }}>
          <button className="btn btn-secondary" onClick={loadFromEngine}
            disabled={corpusLoading}>
            {corpusLoading ? "Loading..." : "Load from Engine"}
          </button>
          {corpusText && (
            <button className="btn btn-secondary" onClick={() => setCorpusText("")}>
              Clear
            </button>
          )}
        </div>
        <div className="form-group" style={{ marginBottom: 12 }}>
          <label>
            Corpus (separate documents with blank lines)
            {corpusText && (
              <span style={{ color: "var(--text-dim)", fontWeight: 400 }}>
                {" "} — {documentCount} documents detected
              </span>
            )}
          </label>
          {/* Expression (not a JSX string attribute) so "\n" is a real newline, not a literal backslash-n. */}
          <textarea value={corpusText} onChange={e => setCorpusText(e.target.value)} rows={8}
            placeholder={"Document 1 text...\n\nDocument 2 text..."} />
        </div>
        <label className="section-label">Strategy</label>
        <Toggle
          options={STRATEGIES.map(s => ({ value: s.id, label: s.label }))}
          value={strategy}
          onChange={(v) => setStrategy(v as Strategy)}
        />
        <p style={{ color: "var(--text-dim)", fontSize: "0.85rem", marginBottom: 12 }}>
          {STRATEGIES.find(s => s.id === strategy)?.desc}
        </p>
        {strategy === "keywords" && (
          <div className="form-group" style={{ marginBottom: 12 }}>
            <label>Keyword → Meaning Map (JSON)</label>
            <textarea value={keywordMapText} onChange={e => setKeywordMapText(e.target.value)}
              rows={4} style={{ fontFamily: "monospace", fontSize: "0.8rem" }} />
          </div>
        )}
        <div className="form-row" style={{ marginBottom: 12 }}>
          <div className="form-group">
            <label>Base Model</label>
            <Select options={MODELS} value={baseModel} onChange={setBaseModel} />
          </div>
        </div>
        <button className="advanced-toggle" aria-expanded={showAdvanced}
          onClick={() => setShowAdvanced(open => !open)}>
          {showAdvanced ? "\u25be" : "\u25b8"} Advanced Settings
        </button>
        {showAdvanced && (
          <div className="advanced-section">
            <div className="form-row">
              <div className="form-group" style={{ maxWidth: 100 }}>
                <label>Epochs</label>
                <input type="number" value={epochs} onChange={e => setEpochs(+e.target.value)}
                  min={EPOCHS_MIN} max={EPOCHS_MAX} />
              </div>
              <div className="form-group" style={{ maxWidth: 120 }}>
                <label>Batch Size</label>
                <input type="number" value={batchSize} onChange={e => setBatchSize(+e.target.value)}
                  min={BATCH_SIZE_MIN} max={BATCH_SIZE_MAX} />
              </div>
              <div className="form-group" style={{ maxWidth: 200 }}>
                <label>Output Path</label>
                <input value={outputPath} onChange={e => setOutputPath(e.target.value)} />
              </div>
            </div>
          </div>
        )}
        <button className="btn btn-primary" onClick={() => void handleTrain()}
          disabled={training || !corpusText.trim()} style={{ marginTop: 8 }}>
          {training ? <><span className="spinner" /> Training...</> : "Start Training"}
        </button>
        <LogViewer active={training} />
      </div>
      {error && <StatusMessage type="err" message={error} />}
      {result && (
        <div className="panel">
          <h2>Training Complete</h2>
          <div className="metric-grid">
            <MetricCard value={result.training_pairs} label="Training Pairs" />
            <MetricCard value={result.epochs} label="Epochs" />
            <MetricCard value={`${result.seconds}s`} label="Time" />
          </div>
          <StatusMessage type="ok"
            message={`Model saved: ${result.model_path} — use this path in the Setup tab, then go to Analysis to explore results.`} />
        </div>
      )}
    </div>
  );
}
|