// Origin: nullai-knowledge-system — frontend/src/components/TrainingDashboard.tsx
// Commit 075a2b6 (verified) by kofdai: "Deploy NullAI Knowledge System to Spaces"
// src/components/TrainingDashboard.tsx
import { useState, useEffect } from 'react';
import apiClient from '../services/api';
/** Snapshot of the training examples available on the backend (`GET /training/data/stats`). */
interface TrainingDataStats {
  /** Total number of training examples. */
  total_examples: number;
  /** Domain identifiers; offered as choices in the domain selector. */
  domains: Array<string>;
  /** Example count keyed by domain identifier. */
  examples_by_domain: Record<string, number>;
  /** File paths reported by the backend (not displayed by the dashboard). */
  file_paths: Array<string>;
}

/** Current fine-tuning run state (`GET /training/status`). */
interface TrainingStatus {
  /** Whether a run is in progress; drives the progress panel and status polling. */
  is_training: boolean;
  /** Model identifier shown in the progress panel, or null when none is reported. */
  model_id: string | null;
  /** Progress rendered as a percentage (progress-bar width and "Progress: N%"). */
  progress: number;
  /** Epoch currently being trained. */
  current_epoch: number;
  /** Total number of epochs in the run. */
  total_epochs: number;
  /** Latest loss value; the dashboard only shows it when it is greater than 0. */
  loss: number;
  /** Run start time, or null. Presumably an ISO-8601 string — TODO confirm with the backend. */
  start_time: string | null;
}

/** A saved training checkpoint (`GET /training/checkpoints`). */
interface Checkpoint {
  /** Checkpoint name; also used as its identifier in the delete endpoint path. */
  name: string;
  /** Name of the model the checkpoint belongs to, or null when unknown. */
  model_name: string | null;
  /** Creation timestamp, parsed with `new Date(...)` for display. */
  created_at: string;
  /** On-disk size in megabytes. */
  size_mb: number;
}
/** Polling period (ms) for `GET /training/status` while a run is in progress. */
const STATUS_POLL_INTERVAL_MS = 2000;

/**
 * Build a user-facing message from a failed API call.
 *
 * Uses the server's `response.data.detail` when present. A non-string detail
 * (for example a list of validation errors) is JSON-encoded so it does not
 * render as "[object Object]". Otherwise falls back to the error's `message`.
 */
const describeApiError = (err: unknown): string => {
  if (typeof err !== 'object' || err === null) {
    return String(err);
  }
  const { response, message } = err as {
    response?: { data?: { detail?: unknown } };
    message?: unknown;
  };
  const detail = response?.data?.detail;
  if (typeof detail === 'string' && detail !== '') {
    return detail;
  }
  if (typeof detail === 'object' && detail !== null) {
    return JSON.stringify(detail);
  }
  return typeof message === 'string' ? message : String(err);
};

/**
 * Parse an integer form field.
 *
 * An empty or non-numeric field yields NaN. NaN is kept in state so the user
 * can clear the field while typing, and it is rejected by validation when a
 * run is started.
 */
const parseIntegerField = (raw: string): number => Number.parseInt(raw, 10);

/** Value to bind to a numeric <input>: renders an empty field instead of NaN. */
const toInputValue = (value: number): number | string => (Number.isNaN(value) ? '' : value);

/**
 * Dashboard for fine-tuning an apprentice model on collected training data.
 *
 * The dashboard:
 * - shows training-data statistics;
 * - lets the user configure and start a run;
 * - displays live progress, polling `/training/status` while a run is active;
 * - lists and deletes saved checkpoints.
 */
const TrainingDashboard = () => {
  // Training form state
  const [apprenticeModelName, setApprenticeModelName] = useState('microsoft/phi-2');
  const [domainId, setDomainId] = useState<string>('all');
  const [method, setMethod] = useState('peft');
  const [epochs, setEpochs] = useState(3);
  const [learningRate, setLearningRate] = useState('2e-4');
  const [batchSize, setBatchSize] = useState(4);
  const [loraR, setLoraR] = useState(8);

  // Training state
  const [trainingStatus, setTrainingStatus] = useState<TrainingStatus | null>(null);
  const [dataStats, setDataStats] = useState<TrainingDataStats | null>(null);
  const [checkpoints, setCheckpoints] = useState<Checkpoint[]>([]);

  // UI state
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState('');
  const [successMessage, setSuccessMessage] = useState('');

  const isTraining = trainingStatus?.is_training ?? false;

  // Fetch training data statistics
  const fetchDataStats = async () => {
    try {
      const response = await apiClient.get('/training/data/stats');
      setDataStats(response.data);
    } catch (err: unknown) {
      console.error('Failed to fetch training data stats:', err);
    }
  };

  // Fetch the current training status. This function only updates state.
  // Starting and stopping the poll is handled by the effect keyed on `isTraining`.
  const fetchTrainingStatus = async () => {
    try {
      const response = await apiClient.get('/training/status');
      setTrainingStatus(response.data);
    } catch (err: unknown) {
      console.error('Failed to fetch training status:', err);
    }
  };

  // Fetch checkpoints
  const fetchCheckpoints = async () => {
    try {
      const response = await apiClient.get('/training/checkpoints');
      setCheckpoints(response.data);
    } catch (err: unknown) {
      console.error('Failed to fetch checkpoints:', err);
    }
  };

  // Initial data fetch. Checkpoints are loaded by the polling effect below,
  // which runs on mount while `isTraining` is still false.
  useEffect(() => {
    void fetchDataStats();
    void fetchTrainingStatus();
  }, []);

  // While a run is active, poll its status. Whenever no run is active (on
  // mount and each time a run finishes), refresh the checkpoint list.
  // The interval is keyed on the `isTraining` boolean instead of being stored
  // in state and read from a stale closure. This keeps at most one interval
  // alive, and React clears it when training stops or the component unmounts.
  // The fetch helpers only call state setters, so omitting them from the deps is safe.
  useEffect(() => {
    if (!isTraining) {
      void fetchCheckpoints();
      return undefined;
    }
    const timer = setInterval(() => {
      void fetchTrainingStatus();
    }, STATUS_POLL_INTERVAL_MS);
    return () => clearInterval(timer);
  }, [isTraining]);

  // Validate the form and start a training run.
  const handleStartTraining = async () => {
    setError('');
    setSuccessMessage('');

    // Reject invalid input client-side; otherwise NaN would be serialized as null.
    const modelName = apprenticeModelName.trim();
    if (modelName === '') {
      setError('Apprentice model name is required.');
      return;
    }
    const learningRateValue = Number.parseFloat(learningRate);
    if (!Number.isFinite(learningRateValue) || learningRateValue <= 0) {
      setError('Learning rate must be a positive number (e.g. 2e-4).');
      return;
    }
    const integerFields = [
      ['Epochs', epochs],
      ['Batch size', batchSize],
      ['LoRA rank', loraR],
    ] as const;
    for (const [label, value] of integerFields) {
      if (!Number.isInteger(value) || value < 1) {
        setError(`${label} must be a positive whole number.`);
        return;
      }
    }

    setIsLoading(true);
    try {
      const requestData = {
        apprentice_model_name: modelName,
        domain_id: domainId === 'all' ? null : domainId,
        method,
        epochs,
        learning_rate: learningRateValue,
        batch_size: batchSize,
        lora_r: loraR,
      };
      await apiClient.post('/training/start', requestData);
      setSuccessMessage('Training started successfully! Check the progress below.');
      // Once the status reports is_training === true, the polling effect starts.
      await fetchTrainingStatus();
    } catch (err: unknown) {
      setError(`Failed to start training: ${describeApiError(err)}`);
      console.error(err);
    } finally {
      setIsLoading(false);
    }
  };

  // Ask the backend to stop the current run.
  const handleStopTraining = async () => {
    setError('');
    setSuccessMessage('');
    try {
      await apiClient.post('/training/stop');
      setSuccessMessage('Training stop requested. This may take a moment...');
      await fetchTrainingStatus();
    } catch (err: unknown) {
      setError(`Failed to stop training: ${describeApiError(err)}`);
      console.error(err);
    }
  };

  // Delete a checkpoint after user confirmation.
  const handleDeleteCheckpoint = async (checkpointName: string) => {
    if (!window.confirm(`Are you sure you want to delete checkpoint "${checkpointName}"? This action cannot be undone.`)) {
      return;
    }
    setError('');
    setSuccessMessage('');
    try {
      // Encode the name so characters such as '?', '#' or '/' cannot alter the request path.
      await apiClient.delete(`/training/checkpoints/${encodeURIComponent(checkpointName)}`);
      setSuccessMessage(`Checkpoint "${checkpointName}" deleted successfully.`);
      await fetchCheckpoints();
    } catch (err: unknown) {
      setError(`Failed to delete checkpoint: ${describeApiError(err)}`);
      console.error(err);
    }
  };

  return (
    <div className="bg-gray-800 text-white p-6 rounded-lg shadow-lg">
      <h2 className="text-2xl font-bold mb-4">🎓 Fine-tuning Dashboard</h2>

      {/* Error/Success Messages */}
      {error && (
        <div className="bg-red-600 text-white p-3 rounded mb-4">
          {error}
        </div>
      )}
      {successMessage && (
        <div className="bg-green-600 text-white p-3 rounded mb-4">
          {successMessage}
        </div>
      )}

      {/* Training Data Statistics */}
      {dataStats && (
        <div className="mb-6 p-4 bg-gray-700 rounded">
          <h3 className="text-lg font-semibold mb-2">📊 Training Data Available</h3>
          <div className="grid grid-cols-2 gap-4">
            <div>
              <p className="text-sm text-gray-400">Total Examples</p>
              <p className="text-2xl font-bold">{dataStats.total_examples}</p>
            </div>
            <div>
              <p className="text-sm text-gray-400">Domains</p>
              <p className="text-xl">{dataStats.domains.join(', ') || 'None'}</p>
            </div>
          </div>
          {dataStats.total_examples === 0 && (
            <p className="text-yellow-400 text-sm mt-2">
              ⚠️ No training data found. Use the Inference Panel with a Master engine to generate training data first.
            </p>
          )}
        </div>
      )}

      {/* Training Status */}
      {trainingStatus?.is_training && (
        <div className="mb-6 p-4 bg-blue-900 rounded border border-blue-500">
          <h3 className="text-lg font-semibold mb-3">🔄 Training in Progress</h3>
          <div className="space-y-2">
            <div>
              <p className="text-sm text-gray-300">Model: {trainingStatus.model_id}</p>
              <p className="text-sm text-gray-300">
                Epoch: {trainingStatus.current_epoch} / {trainingStatus.total_epochs}
              </p>
            </div>
            <div>
              <div className="w-full bg-gray-600 rounded-full h-4">
                <div
                  className="bg-blue-500 h-4 rounded-full transition-all duration-500"
                  // Clamp so an out-of-range value cannot overflow the track.
                  style={{ width: `${Math.min(100, Math.max(0, trainingStatus.progress))}%` }}
                ></div>
              </div>
              <p className="text-sm text-gray-300 mt-1">
                Progress: {trainingStatus.progress.toFixed(1)}%
              </p>
            </div>
            {trainingStatus.loss > 0 && (
              <p className="text-sm text-gray-300">Current Loss: {trainingStatus.loss.toFixed(4)}</p>
            )}
            <button
              onClick={handleStopTraining}
              className="mt-3 bg-red-600 hover:bg-red-700 text-white px-4 py-2 rounded"
            >
              ⛔ Stop Training
            </button>
          </div>
        </div>
      )}

      {/* Training Form */}
      {!isTraining && dataStats && dataStats.total_examples > 0 && (
        <div className="mb-6 p-4 bg-gray-700 rounded">
          <h3 className="text-lg font-semibold mb-3">🚀 Start Fine-tuning</h3>
          <div className="grid grid-cols-2 gap-4">
            <div>
              <label className="block text-sm mb-1">Apprentice Model</label>
              <input
                type="text"
                value={apprenticeModelName}
                onChange={(e) => setApprenticeModelName(e.target.value)}
                placeholder="microsoft/phi-2"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
              <p className="text-xs text-gray-400 mt-1">HuggingFace model name or local path</p>
            </div>
            <div>
              <label className="block text-sm mb-1">Domain</label>
              <select
                value={domainId}
                onChange={(e) => setDomainId(e.target.value)}
                className="w-full p-2 bg-gray-600 rounded text-white"
              >
                <option value="all">All Domains</option>
                {dataStats.domains.map((domain) => (
                  <option key={domain} value={domain}>
                    {domain} ({dataStats.examples_by_domain[domain] ?? 0} examples)
                  </option>
                ))}
              </select>
            </div>
            <div>
              <label className="block text-sm mb-1">Method</label>
              <select
                value={method}
                onChange={(e) => setMethod(e.target.value)}
                className="w-full p-2 bg-gray-600 rounded text-white"
              >
                <option value="peft">PEFT (QLoRA) - Recommended</option>
                <option value="unsloth">Unsloth - Fastest</option>
                <option value="mlx">MLX - Apple Silicon (Experimental)</option>
              </select>
            </div>
            <div>
              <label className="block text-sm mb-1">Epochs</label>
              <input
                type="number"
                value={toInputValue(epochs)}
                onChange={(e) => setEpochs(parseIntegerField(e.target.value))}
                min="1"
                max="100"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
            <div>
              <label className="block text-sm mb-1">Learning Rate</label>
              <input
                type="text"
                value={learningRate}
                onChange={(e) => setLearningRate(e.target.value)}
                placeholder="2e-4"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
            <div>
              <label className="block text-sm mb-1">Batch Size</label>
              <input
                type="number"
                value={toInputValue(batchSize)}
                onChange={(e) => setBatchSize(parseIntegerField(e.target.value))}
                min="1"
                max="32"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
            <div>
              <label className="block text-sm mb-1">LoRA Rank (r)</label>
              <input
                type="number"
                value={toInputValue(loraR)}
                onChange={(e) => setLoraR(parseIntegerField(e.target.value))}
                min="4"
                max="64"
                className="w-full p-2 bg-gray-600 rounded text-white"
              />
            </div>
          </div>
          <button
            onClick={handleStartTraining}
            disabled={isLoading}
            className="mt-4 bg-green-600 hover:bg-green-700 text-white px-6 py-3 rounded font-semibold disabled:opacity-50"
          >
            {isLoading ? 'Starting...' : '▶️ Start Fine-tuning'}
          </button>
        </div>
      )}

      {/* Checkpoints List */}
      <div className="p-4 bg-gray-700 rounded">
        <h3 className="text-lg font-semibold mb-3">💾 Training Checkpoints</h3>
        {checkpoints.length === 0 ? (
          <p className="text-gray-400">No checkpoints available yet.</p>
        ) : (
          <div className="space-y-2">
            {checkpoints.map((checkpoint) => (
              <div
                key={checkpoint.name}
                className="bg-gray-600 p-3 rounded flex justify-between items-center"
              >
                <div>
                  <p className="font-semibold">{checkpoint.name}</p>
                  {checkpoint.model_name && (
                    <p className="text-sm text-gray-400">Model: {checkpoint.model_name}</p>
                  )}
                  <p className="text-xs text-gray-500">
                    {new Date(checkpoint.created_at).toLocaleString()} • {checkpoint.size_mb.toFixed(1)} MB
                  </p>
                </div>
                <button
                  onClick={() => handleDeleteCheckpoint(checkpoint.name)}
                  className="bg-red-600 hover:bg-red-700 text-white px-3 py-1 rounded text-sm"
                >
                  🗑️ Delete
                </button>
              </div>
            ))}
          </div>
        )}
      </div>

      {/* Info Box */}
      <div className="mt-6 p-4 bg-blue-900 bg-opacity-30 rounded border border-blue-500">
        <h4 className="font-semibold mb-2">ℹ️ How It Works</h4>
        <ul className="text-sm text-gray-300 space-y-1">
          <li>• Master model outputs are automatically saved as training data (confidence ≥ 0.8)</li>
          <li>• Fine-tuning uses LoRA/QLoRA for efficient training with minimal memory</li>
          <li>• Checkpoints are saved in <code>training_data/checkpoints/</code></li>
          <li>• After training, register the checkpoint as a new model in Engine Manager</li>
        </ul>
      </div>
    </div>
  );
};

export default TrainingDashboard;