Spaces:
Sleeping
Sleeping
// Module-level state shared by the functions in this script.

// WebSocket opened by trainSingleModel() for live progress messages.
let ws;
// NOTE(review): the two chart handles below are never assigned or read in the
// code visible here; another script sharing the global scope may use them —
// confirm before removing.
let accuracyChart;
let lossChart;
/**
 * Show exactly one of the two training forms.
 *
 * @param {string} type - 'single' reveals the single-model form and hides the
 *   comparison form; any other value does the opposite.
 */
function showTrainingForm(type) {
  const showSingle = type === 'single';
  // toggle(token, force): force === true adds the class, false removes it.
  document.getElementById('single-model-form').classList.toggle('hidden', !showSingle);
  document.getElementById('compare-models-form').classList.toggle('hidden', showSingle);
}
/**
 * Create (or recreate) the two live charts for a single-model training run.
 *
 * Each chart gets a pair of empty scatter traces — training at index 0,
 * validation at index 1 — that later progress updates are appended to.
 */
function initializeCharts() {
  // Every trace gets its own fresh, empty x/y arrays.
  const emptyScatter = (name) => ({ name, x: [], y: [], type: 'scatter' });
  const layoutFor = (title, yTitle) => ({
    title,
    xaxis: { title: 'Iterations' },
    yaxis: { title: yTitle },
  });

  Plotly.newPlot(
    'loss-plot',
    [emptyScatter('Training Loss'), emptyScatter('Validation Loss')],
    layoutFor('Training and Validation Loss', 'Loss'),
  );
  Plotly.newPlot(
    'accuracy-plot',
    [emptyScatter('Training Accuracy'), emptyScatter('Validation Accuracy')],
    layoutFor('Training and Validation Accuracy', 'Accuracy (%)'),
  );
}
/**
 * Append one progress update to the live charts and refresh the text log.
 *
 * @param {Object} data - Progress message (parsed JSON) from the training
 *   socket. Fields read here: epoch, train_loss, val_loss, train_acc, val_acc.
 *   A metric that is missing or not a number is shown as "N/A" in the log
 *   instead of throwing.
 */
function updateCharts(data) {
  // x position = number of points already on the loss chart, so successive
  // updates land at 0, 1, 2, ... and a fresh Plotly.newPlot() on 'loss-plot'
  // starts the count over. The previous `epoch * batch` formula put every
  // epoch-0 point at x = 0 and produced duplicate / out-of-order x values
  // (epoch 1 batch 2 and epoch 2 batch 1 both mapped to 2).
  // Plotly keeps the plotted traces on the graph div as `.data`.
  // NOTE(review): this equals the global iteration only if the server sends
  // one message per batch — confirm against the backend.
  const lossPlot = document.getElementById('loss-plot');
  const plottedX = lossPlot && lossPlot.data && lossPlot.data[0] && lossPlot.data[0].x;
  const iteration = plottedX ? plottedX.length : 0;

  Plotly.extendTraces('loss-plot', {
    x: [[iteration], [iteration]],
    y: [[data.train_loss], [data.val_loss]],
  }, [0, 1]);
  Plotly.extendTraces('accuracy-plot', {
    x: [[iteration], [iteration]],
    y: [[data.train_acc], [data.val_acc]],
  }, [0, 1]);

  // toFixed() throws on undefined/null, so format each metric defensively.
  const formatMetric = (value, digits, suffix = '') =>
    (typeof value === 'number' ? `${value.toFixed(digits)}${suffix}` : 'N/A');

  document.getElementById('training-logs').innerHTML = `
    <p>Epoch: ${data.epoch + 1}</p>
    <p>Training Loss: ${formatMetric(data.train_loss, 4)}</p>
    <p>Training Accuracy: ${formatMetric(data.train_acc, 2, '%')}</p>
    <p>Validation Loss: ${formatMetric(data.val_loss, 4)}</p>
    <p>Validation Accuracy: ${formatMetric(data.val_acc, 2, '%')}</p>
  `;
}
/**
 * Train one model from the single-model form with live progress charts.
 *
 * Reads the form fields, resets the progress charts, opens the training
 * WebSocket (progress messages are forwarded to updateCharts()), then POSTs
 * the config as JSON to /api/train_single. A response whose `status` is
 * 'success' shows a completion alert. An HTTP error status or a
 * network/parse failure is logged and shows an error alert.
 */
async function trainSingleModel() {
  const readInt = (id) => parseInt(document.getElementById(id).value, 10);
  const config = {
    kernels: [readInt('kernel1'), readInt('kernel2'), readInt('kernel3')],
    optimizer: document.getElementById('optimizer').value,
    batch_size: readInt('batch_size'),
    epochs: readInt('epochs'),
  };

  // Show progress section and initialize charts
  document.getElementById('training-progress').classList.remove('hidden');
  initializeCharts();

  // Close the socket left over from a previous run so its messages stop
  // landing on the freshly reset charts (close() on a closed socket is a no-op).
  if (ws) {
    ws.close();
  }
  // Pages served over HTTPS may only open secure sockets; a hard-coded ws://
  // URL is blocked there as mixed content.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  ws = new WebSocket(`${scheme}://${window.location.host}/ws/train`);
  ws.onmessage = function (event) {
    let data;
    try {
      data = JSON.parse(event.data);
    } catch (error) {
      console.error('Ignoring malformed progress message:', error);
      return;
    }
    updateCharts(data);
  };

  try {
    const response = await fetch('/api/train_single', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(config),
    });
    // fetch() only rejects on network failure; HTTP errors must be checked.
    if (!response.ok) {
      throw new Error(`Server responded with HTTP ${response.status}`);
    }
    const data = await response.json();
    if (data.status === 'success') {
      alert('Training completed successfully!');
    } else {
      console.warn('Training request finished with unexpected status:', data.status);
    }
  } catch (error) {
    console.error('Error:', error);
    alert('Error during training. Please check console for details.');
  }
}
/**
 * Train two models side by side from the comparison form.
 *
 * Reads the `model1_*` and `model2_*` form fields, shows the comparison
 * section, resets its charts, and POSTs `{ model1, model2 }` as JSON to
 * /api/train_compare. On a response whose `status` is 'success' the results
 * are rendered via displayComparisonResults() and a completion alert is shown.
 * An HTTP error status or a network/parse failure is logged and shows an
 * error alert.
 */
async function compareModels() {
  const readInt = (id) => parseInt(document.getElementById(id).value, 10);
  // Both models share one field layout, distinguished only by the id prefix.
  const readModelConfig = (prefix) => ({
    kernels: [
      readInt(`${prefix}_kernel1`),
      readInt(`${prefix}_kernel2`),
      readInt(`${prefix}_kernel3`),
    ],
    optimizer: document.getElementById(`${prefix}_optimizer`).value,
    batch_size: readInt(`${prefix}_batch_size`),
    epochs: readInt(`${prefix}_epochs`),
  });
  const config = {
    model1: readModelConfig('model1'),
    model2: readModelConfig('model2'),
  };

  // Show comparison progress section
  document.getElementById('comparison-progress').classList.remove('hidden');
  // NOTE(review): in the code visible here nothing ever extends the comparison
  // charts (no socket, no history plotting), so they stay empty — confirm intent.
  initializeComparisonCharts();

  try {
    const response = await fetch('/api/train_compare', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(config),
    });
    // fetch() only rejects on network failure; HTTP errors must be checked.
    if (!response.ok) {
      throw new Error(`Server responded with HTTP ${response.status}`);
    }
    const data = await response.json();
    if (data.status === 'success') {
      displayComparisonResults(data);
      alert('Model comparison completed successfully!');
    } else {
      console.warn('Model comparison finished with unexpected status:', data.status);
    }
  } catch (error) {
    console.error('Error:', error);
    alert('Error during model comparison. Please check console for details.');
  }
}
/**
 * Create (or recreate) the side-by-side comparison charts.
 *
 * Each chart starts with two empty scatter traces: Model A at index 0 and
 * Model B at index 1.
 */
function initializeComparisonCharts() {
  const chartSpecs = [
    {
      target: 'comparison-loss-plot',
      traceNames: ['Model A Loss', 'Model B Loss'],
      title: 'Loss Comparison',
      yTitle: 'Loss',
    },
    {
      target: 'comparison-accuracy-plot',
      traceNames: ['Model A Accuracy', 'Model B Accuracy'],
      title: 'Accuracy Comparison',
      yTitle: 'Accuracy (%)',
    },
  ];

  for (const { target, traceNames, title, yTitle } of chartSpecs) {
    const traces = traceNames.map((name) => ({ name, x: [], y: [], type: 'scatter' }));
    Plotly.newPlot(target, traces, {
      title,
      xaxis: { title: 'Iterations' },
      yaxis: { title: yTitle },
    });
  }
}
/**
 * Render the final metrics of both compared models into #comparison-logs.
 *
 * @param {Object} data - /api/train_compare response. Reads
 *   `model1_results` / `model2_results`, each with `model_name` and
 *   `history.train_loss` / `history.train_acc` arrays; the last element of
 *   each array is shown. An empty history or a non-numeric value is shown as
 *   "N/A" instead of throwing.
 */
function displayComparisonResults(data) {
  const HTML_ESCAPES = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  // model_name comes from the server response and is inserted via innerHTML,
  // so escape it rather than letting it inject markup.
  const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);

  const lastValue = (series, digits, suffix = '') => {
    if (!Array.isArray(series) || series.length === 0) {
      return 'N/A';
    }
    const value = series[series.length - 1];
    return typeof value === 'number' ? `${value.toFixed(digits)}${suffix}` : 'N/A';
  };

  const renderModel = (label, results) => `
    <div class="comparison-model">
      <h4>${label}</h4>
      <p>Final Loss: ${lastValue(results.history.train_loss, 4)}</p>
      <p>Final Accuracy: ${lastValue(results.history.train_acc, 2, '%')}</p>
      <p>Model Name: ${escapeHtml(results.model_name)}</p>
    </div>`;

  const logsDiv = document.getElementById('comparison-logs');
  logsDiv.innerHTML = `${renderModel('Model A', data.model1_results)}${renderModel('Model B', data.model2_results)}
  `;
}
/**
 * Placeholder for rendering final single-model results into #training-results.
 *
 * NOTE(review): not implemented and not called by any code visible here. The
 * previous body only looked up #training-results into an unused local, which
 * had no effect, so that dead lookup was removed.
 *
 * @param {Object} data - Presumably the /api/train_single response — TODO confirm.
 */
function displayResults(data) {
  // TODO: render `data` into #training-results once the response shape is settled.
}