Spaces:
Sleeping
Sleeping
| // src/components/TrainingDashboard.tsx | |
import { useEffect, useRef, useState } from 'react';
import apiClient from '../services/api';
/**
 * Payload of `GET /training/data/stats` as consumed by the dashboard.
 */
interface TrainingDataStats {
  /** Overall example count; the training form is only shown when this is > 0. */
  total_examples: number;
  /** Example count keyed by domain name; looked up per entry of `domains`. */
  examples_by_domain: Record<string, number>;
  /** Domain names, listed in the summary and offered in the domain selector. */
  domains: string[];
  /** Not read by this component. NOTE(review): presumably the data files backing the stats — confirm against the API. */
  file_paths: string[];
}
/**
 * Payload of `GET /training/status` as consumed by the dashboard.
 */
interface TrainingStatus {
  /** True while a run is active; drives polling and which panel is rendered. */
  is_training: boolean;
  /** Rendered as a percentage (bar width and "Progress: x%"). */
  progress: number;
  /** Shown as "Epoch: current / total". */
  current_epoch: number;
  /** Shown as "Epoch: current / total". */
  total_epochs: number;
  /** Displayed as "Current Loss" only when greater than 0. */
  loss: number;
  /** Displayed as "Model:" in the progress panel. */
  model_id: string | null;
  /** Not read by this component. NOTE(review): presumably a timestamp string — confirm format against the API. */
  start_time: string | null;
}
/**
 * One entry of the `GET /training/checkpoints` list.
 */
interface Checkpoint {
  /** Used as the React list key and as the path segment when deleting. */
  name: string;
  /** Passed to `new Date(...)` for display, so it must be a Date-parsable string. */
  created_at: string;
  /** Displayed with one decimal followed by "MB". */
  size_mb: number;
  /** Shown as "Model:" when present. */
  model_name: string | null;
}
/** How often (ms) `/training/status` is polled while a run is active. */
const STATUS_POLL_INTERVAL_MS = 2000;

/** Inclusive bounds for a numeric form field. */
interface NumericRange {
  min: number;
  max: number;
}

// Single source of truth for the numeric hyperparameter bounds: used both as
// the inputs' min/max attributes and for validation before submitting.
const EPOCHS_RANGE: NumericRange = { min: 1, max: 100 };
const BATCH_SIZE_RANGE: NumericRange = { min: 1, max: 32 };
const LORA_RANK_RANGE: NumericRange = { min: 4, max: 64 };

/**
 * Builds a displayable message from a rejected API call.
 *
 * Prefers `err.response.data.detail` when present (stringifying it if it is
 * not already a string), then `err.message`, then `String(err)`.
 */
const describeRequestError = (err: unknown): string => {
  if (typeof err === 'object' && err !== null) {
    const candidate = err as {
      response?: { data?: { detail?: unknown } };
      message?: unknown;
    };
    const detail = candidate.response?.data?.detail;
    if (typeof detail === 'string' && detail) {
      return detail;
    }
    if (detail) {
      // Structured detail (object/array) would otherwise render as "[object Object]".
      return JSON.stringify(detail);
    }
    if (typeof candidate.message === 'string') {
      return candidate.message;
    }
  }
  return String(err);
};

/** True when `value` is a whole number inside `range` (NaN is rejected). */
const isIntegerInRange = (value: number, range: NumericRange): boolean =>
  Number.isInteger(value) && value >= range.min && value <= range.max;

/**
 * Value to hand to a controlled `<input type="number">`.
 * NaN (the state after the user clears the field) is rendered as an empty string.
 */
const numberInputValue = (value: number): number | '' => (Number.isNaN(value) ? '' : value);

/**
 * Dashboard for launching, monitoring and stopping fine-tuning runs and for
 * managing the resulting checkpoints.
 *
 * On mount it loads data statistics, the current training status and the
 * checkpoint list. While the status endpoint reports `is_training`, the status
 * is re-fetched every `STATUS_POLL_INTERVAL_MS`; when training ends, polling
 * stops and the checkpoint list is refreshed.
 */
const TrainingDashboard = () => {
  // Training form state
  const [apprenticeModelName, setApprenticeModelName] = useState('microsoft/phi-2');
  const [domainId, setDomainId] = useState<string>('all');
  const [method, setMethod] = useState('peft');
  const [epochs, setEpochs] = useState(3);
  const [learningRate, setLearningRate] = useState('2e-4');
  const [batchSize, setBatchSize] = useState(4);
  const [loraR, setLoraR] = useState(8);

  // Training state
  const [trainingStatus, setTrainingStatus] = useState<TrainingStatus | null>(null);
  const [dataStats, setDataStats] = useState<TrainingDataStats | null>(null);
  const [checkpoints, setCheckpoints] = useState<Checkpoint[]>([]);

  // UI state
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState('');
  const [successMessage, setSuccessMessage] = useState('');

  // The interval handle lives in a ref rather than state: the callback given to
  // setInterval is a closure from one particular render, and it must always see
  // the *current* handle to avoid starting duplicate timers or failing to stop.
  const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
  // Set on unmount so an in-flight status request cannot restart polling.
  const unmountedRef = useRef(false);

  /** Stops status polling if it is running. */
  const stopPolling = () => {
    if (pollingTimerRef.current !== null) {
      clearInterval(pollingTimerRef.current);
      pollingTimerRef.current = null;
    }
  };

  /** Loads training data statistics; failures are logged and otherwise ignored. */
  const fetchDataStats = async () => {
    try {
      const response = await apiClient.get('/training/data/stats');
      setDataStats(response.data);
    } catch (err: unknown) {
      console.error('Failed to fetch training data stats:', err);
    }
  };

  /** Loads the checkpoint list; failures are logged and otherwise ignored. */
  const fetchCheckpoints = async () => {
    try {
      const response = await apiClient.get('/training/checkpoints');
      setCheckpoints(response.data);
    } catch (err: unknown) {
      console.error('Failed to fetch checkpoints:', err);
    }
  };

  /**
   * Loads the training status and keeps polling in sync with it:
   * starts a single poller while training is active, and stops it (then
   * refreshes checkpoints) once training is no longer active.
   */
  const fetchTrainingStatus = async () => {
    try {
      const response = await apiClient.get('/training/status');
      if (unmountedRef.current) {
        return;
      }
      const status: TrainingStatus = response.data;
      setTrainingStatus(status);

      if (status.is_training) {
        if (pollingTimerRef.current === null) {
          pollingTimerRef.current = setInterval(fetchTrainingStatus, STATUS_POLL_INTERVAL_MS);
        }
      } else if (pollingTimerRef.current !== null) {
        stopPolling();
        void fetchCheckpoints(); // Training just finished: refresh checkpoints list
      }
    } catch (err: unknown) {
      console.error('Failed to fetch training status:', err);
    }
  };

  // Initial data fetch
  useEffect(() => {
    // Reset for the mount -> cleanup -> mount cycle (e.g. StrictMode in development).
    unmountedRef.current = false;
    void fetchDataStats();
    void fetchTrainingStatus();
    void fetchCheckpoints();

    // Cleanup polling on unmount
    return () => {
      unmountedRef.current = true;
      stopPolling();
    };
    // Intentionally run once on mount; the fetchers only touch refs and state setters.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  /** Returns a validation message for the current form values, or null if they are valid. */
  const validateForm = (parsedLearningRate: number): string | null => {
    if (!apprenticeModelName.trim()) {
      return 'Apprentice model name is required.';
    }
    if (!isIntegerInRange(epochs, EPOCHS_RANGE)) {
      return `Epochs must be a whole number between ${EPOCHS_RANGE.min} and ${EPOCHS_RANGE.max}.`;
    }
    if (!Number.isFinite(parsedLearningRate) || parsedLearningRate <= 0) {
      return 'Learning rate must be a positive number (e.g. 2e-4).';
    }
    if (!isIntegerInRange(batchSize, BATCH_SIZE_RANGE)) {
      return `Batch size must be a whole number between ${BATCH_SIZE_RANGE.min} and ${BATCH_SIZE_RANGE.max}.`;
    }
    if (!isIntegerInRange(loraR, LORA_RANK_RANGE)) {
      return `LoRA rank must be a whole number between ${LORA_RANK_RANGE.min} and ${LORA_RANK_RANGE.max}.`;
    }
    return null;
  };

  // Start training
  const handleStartTraining = async () => {
    setError('');
    setSuccessMessage('');

    const parsedLearningRate = parseFloat(learningRate);
    const validationError = validateForm(parsedLearningRate);
    if (validationError !== null) {
      setError(validationError);
      return;
    }

    setIsLoading(true);
    try {
      const requestData = {
        apprentice_model_name: apprenticeModelName.trim(),
        domain_id: domainId === 'all' ? null : domainId,
        method,
        epochs,
        learning_rate: parsedLearningRate,
        batch_size: batchSize,
        lora_r: loraR,
      };
      await apiClient.post('/training/start', requestData);
      setSuccessMessage('Training started successfully! Check the progress below.');
      // Picks up the new run and starts polling if it is active
      void fetchTrainingStatus();
    } catch (err: unknown) {
      setError(`Failed to start training: ${describeRequestError(err)}`);
      console.error(err);
    } finally {
      setIsLoading(false);
    }
  };

  // Stop training
  const handleStopTraining = async () => {
    // Clear stale banners so an old error is not shown next to the new outcome.
    setError('');
    setSuccessMessage('');
    try {
      await apiClient.post('/training/stop');
      setSuccessMessage('Training stop requested. This may take a moment...');
      void fetchTrainingStatus();
    } catch (err: unknown) {
      setError(`Failed to stop training: ${describeRequestError(err)}`);
      console.error(err);
    }
  };

  // Delete checkpoint
  const handleDeleteCheckpoint = async (checkpointName: string) => {
    if (!window.confirm(`Are you sure you want to delete checkpoint "${checkpointName}"? This action cannot be undone.`)) {
      return;
    }
    setError('');
    setSuccessMessage('');
    try {
      // Encode the name so characters such as "/", "?" or "#" stay inside the path segment.
      await apiClient.delete(`/training/checkpoints/${encodeURIComponent(checkpointName)}`);
      setSuccessMessage(`Checkpoint "${checkpointName}" deleted successfully.`);
      void fetchCheckpoints();
    } catch (err: unknown) {
      setError(`Failed to delete checkpoint: ${describeRequestError(err)}`);
      console.error(err);
    }
  };

  // Keep the bar inside its track even if the reported progress is out of range.
  const progressBarPercent = Math.min(100, Math.max(0, trainingStatus?.progress ?? 0));

  return (
    <div className="bg-gray-800 text-white p-6 rounded-lg shadow-lg">
      <h2 className="text-2xl font-bold mb-4">🎓 Fine-tuning Dashboard</h2>

      {/* Error/Success Messages */}
      {error && (
        <div className="bg-red-600 text-white p-3 rounded mb-4">
          {error}
        </div>
      )}
      {successMessage && (
        <div className="bg-green-600 text-white p-3 rounded mb-4">
          {successMessage}
        </div>
      )}

      {/* Training Data Statistics */}
      {dataStats && (
        <div className="mb-6 p-4 bg-gray-700 rounded">
          <h3 className="text-lg font-semibold mb-2">📊 Training Data Available</h3>
          <div className="grid grid-cols-2 gap-4">
            <div>
              <p className="text-sm text-gray-400">Total Examples</p>
              <p className="text-2xl font-bold">{dataStats.total_examples}</p>
            </div>
            <div>
              <p className="text-sm text-gray-400">Domains</p>
              <p className="text-xl">{dataStats.domains.join(', ') || 'None'}</p>
            </div>
          </div>
          {dataStats.total_examples === 0 && (
            <p className="text-yellow-400 text-sm mt-2">
              ⚠️ No training data found. Use the Inference Panel with a Master engine to generate training data first.
            </p>
          )}
        </div>
      )}

      {/* Training Status */}
      {trainingStatus?.is_training && (
        <div className="mb-6 p-4 bg-blue-900 rounded border border-blue-500">
          <h3 className="text-lg font-semibold mb-3">🔄 Training in Progress</h3>
          <div className="space-y-2">
            <div>
              <p className="text-sm text-gray-300">Model: {trainingStatus.model_id}</p>
              <p className="text-sm text-gray-300">
                Epoch: {trainingStatus.current_epoch} / {trainingStatus.total_epochs}
              </p>
            </div>
            <div>
              <div className="w-full bg-gray-600 rounded-full h-4">
                <div
                  className="bg-blue-500 h-4 rounded-full transition-all duration-500"
                  style={{ width: `${progressBarPercent}%` }}
                ></div>
              </div>
              <p className="text-sm text-gray-300 mt-1">
                Progress: {trainingStatus.progress.toFixed(1)}%
              </p>
            </div>
            {trainingStatus.loss > 0 && (
              <p className="text-sm text-gray-300">Current Loss: {trainingStatus.loss.toFixed(4)}</p>
            )}
            <button
              onClick={handleStopTraining}
              className="mt-3 bg-red-600 hover:bg-red-700 text-white px-4 py-2 rounded"
            >
              ⛔ Stop Training
            </button>
          </div>
        </div>
      )}

      {/* Training Form */}
      {!trainingStatus?.is_training && dataStats && dataStats.total_examples > 0 && (
        <div className="mb-6 p-4 bg-gray-700 rounded">
          <h3 className="text-lg font-semibold mb-3">🚀 Start Fine-tuning</h3>
          <div className="grid grid-cols-2 gap-4">
            <div>
              <label className="block text-sm mb-1">Apprentice Model</label>
              <input
                type="text"
                value={apprenticeModelName}
                onChange={(e) => setApprenticeModelName(e.target.value)}
                placeholder="microsoft/phi-2"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
              <p className="text-xs text-gray-400 mt-1">HuggingFace model name or local path</p>
            </div>
            <div>
              <label className="block text-sm mb-1">Domain</label>
              <select
                value={domainId}
                onChange={(e) => setDomainId(e.target.value)}
                className="w-full p-2 bg-gray-600 rounded text-white"
              >
                <option value="all">All Domains</option>
                {dataStats.domains.map((domain) => (
                  <option key={domain} value={domain}>
                    {domain} ({dataStats.examples_by_domain[domain] ?? 0} examples)
                  </option>
                ))}
              </select>
            </div>
            <div>
              <label className="block text-sm mb-1">Method</label>
              <select
                value={method}
                onChange={(e) => setMethod(e.target.value)}
                className="w-full p-2 bg-gray-600 rounded text-white"
              >
                <option value="peft">PEFT (QLoRA) - Recommended</option>
                <option value="unsloth">Unsloth - Fastest</option>
                <option value="mlx">MLX - Apple Silicon (Experimental)</option>
              </select>
            </div>
            <div>
              <label className="block text-sm mb-1">Epochs</label>
              <input
                type="number"
                value={numberInputValue(epochs)}
                onChange={(e) => setEpochs(parseInt(e.target.value, 10))}
                min={EPOCHS_RANGE.min}
                max={EPOCHS_RANGE.max}
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
            <div>
              <label className="block text-sm mb-1">Learning Rate</label>
              <input
                type="text"
                value={learningRate}
                onChange={(e) => setLearningRate(e.target.value)}
                placeholder="2e-4"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
            <div>
              <label className="block text-sm mb-1">Batch Size</label>
              <input
                type="number"
                value={numberInputValue(batchSize)}
                onChange={(e) => setBatchSize(parseInt(e.target.value, 10))}
                min={BATCH_SIZE_RANGE.min}
                max={BATCH_SIZE_RANGE.max}
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
            <div>
              <label className="block text-sm mb-1">LoRA Rank (r)</label>
              <input
                type="number"
                value={numberInputValue(loraR)}
                onChange={(e) => setLoraR(parseInt(e.target.value, 10))}
                min={LORA_RANK_RANGE.min}
                max={LORA_RANK_RANGE.max}
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
          </div>
          <button
            onClick={handleStartTraining}
            disabled={isLoading}
            className="mt-4 bg-green-600 hover:bg-green-700 text-white px-6 py-3 rounded font-semibold disabled:opacity-50"
          >
            {isLoading ? 'Starting...' : '▶️ Start Fine-tuning'}
          </button>
        </div>
      )}

      {/* Checkpoints List */}
      <div className="p-4 bg-gray-700 rounded">
        <h3 className="text-lg font-semibold mb-3">💾 Training Checkpoints</h3>
        {checkpoints.length === 0 ? (
          <p className="text-gray-400">No checkpoints available yet.</p>
        ) : (
          <div className="space-y-2">
            {checkpoints.map((checkpoint) => (
              <div
                key={checkpoint.name}
                className="bg-gray-600 p-3 rounded flex justify-between items-center"
              >
                <div>
                  <p className="font-semibold">{checkpoint.name}</p>
                  {checkpoint.model_name && (
                    <p className="text-sm text-gray-400">Model: {checkpoint.model_name}</p>
                  )}
                  <p className="text-xs text-gray-500">
                    {new Date(checkpoint.created_at).toLocaleString()} • {checkpoint.size_mb.toFixed(1)} MB
                  </p>
                </div>
                <button
                  onClick={() => handleDeleteCheckpoint(checkpoint.name)}
                  className="bg-red-600 hover:bg-red-700 text-white px-3 py-1 rounded text-sm"
                >
                  🗑️ Delete
                </button>
              </div>
            ))}
          </div>
        )}
      </div>

      {/* Info Box */}
      <div className="mt-6 p-4 bg-blue-900 bg-opacity-30 rounded border border-blue-500">
        <h4 className="font-semibold mb-2">ℹ️ How It Works</h4>
        <ul className="text-sm text-gray-300 space-y-1">
          <li>• Master model outputs are automatically saved as training data (confidence ≥ 0.8)</li>
          <li>• Fine-tuning uses LoRA/QLoRA for efficient training with minimal memory</li>
          <li>• Checkpoints are saved in <code>training_data/checkpoints/</code></li>
          <li>• After training, register the checkpoint as a new model in Engine Manager</li>
        </ul>
      </div>
    </div>
  );
};

export default TrainingDashboard;