Spaces:
Sleeping
Sleeping
// Module-level variable; presumably intended to hold the active comparison
// WebSocket — TODO(review): confirm against the code that uses it.
| let ws; | |
/**
 * (Re)creates the two comparison charts with empty traces.
 *
 * Calls `Plotly.newPlot` first on the element id `comparison-loss-plot` and
 * then on `comparison-accuracy-plot`. Each chart gets two empty scatter traces
 * (Model A at index 0, Model B at index 1) and an x axis titled "Iterations".
 */
function initializeComparisonCharts() {
  // Every call builds fresh x/y arrays so no two traces share storage.
  const emptyTrace = (name) => ({ name, x: [], y: [], type: 'scatter' });

  const charts = [
    {
      target: 'comparison-loss-plot',
      heading: 'Loss Comparison',
      yLabel: 'Loss',
      traceNames: ['Model A Loss', 'Model B Loss'],
    },
    {
      target: 'comparison-accuracy-plot',
      heading: 'Accuracy Comparison',
      yLabel: 'Accuracy (%)',
      traceNames: ['Model A Accuracy', 'Model B Accuracy'],
    },
  ];

  for (const chart of charts) {
    const traces = chart.traceNames.map((traceName) => emptyTrace(traceName));
    Plotly.newPlot(chart.target, traces, {
      title: chart.heading,
      xaxis: { title: 'Iterations' },
      yaxis: { title: chart.yLabel },
    });
  }
}
/**
 * Reads one model's hyperparameters from the form controls whose ids start
 * with `prefix` (e.g. `model1_kernel1`, `model1_optimizer`, ...).
 *
 * @param {string} prefix - Id prefix of the model's form controls.
 * @returns {{kernels: number[], optimizer: string, batch_size: number, epochs: number}}
 */
function readComparisonModelConfig(prefix) {
  const rawValue = (suffix) => document.getElementById(`${prefix}_${suffix}`).value;
  // Explicit radix: form values are always meant as decimal integers.
  const intValue = (suffix) => Number.parseInt(rawValue(suffix), 10);

  return {
    kernels: [intValue('kernel1'), intValue('kernel2'), intValue('kernel3')],
    optimizer: rawValue('optimizer'),
    batch_size: intValue('batch_size'),
    epochs: intValue('epochs'),
  };
}

/**
 * Collects both model configurations from the form, POSTs them as JSON to
 * `/api/train_compare`, and renders the returned results.
 *
 * Any failure — network error, non-2xx HTTP response, unparsable body, a
 * response whose `status` is not `'success'`, or an error while rendering —
 * is logged to the console and reported to the user with an alert.
 *
 * @returns {Promise<void>}
 */
async function compareModels() {
  const config = {
    model1: readComparisonModelConfig('model1'),
    model2: readComparisonModelConfig('model2'),
  };

  // Show comparison progress section
  document.getElementById('comparison-progress').classList.remove('hidden');
  initializeComparisonCharts();

  try {
    const response = await fetch('/api/train_compare', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(config),
    });

    // fetch() only rejects on network failure; HTTP errors must be checked.
    if (!response.ok) {
      throw new Error(`Comparison request failed with HTTP ${response.status}`);
    }

    const data = await response.json();
    // Previously a non-success status was dropped without any feedback.
    if (data.status !== 'success') {
      throw new Error(`Comparison did not succeed (status: ${data.status})`);
    }

    displayComparisonResults(data);
    alert('Model comparison completed successfully!');
  } catch (error) {
    console.error('Error:', error);
    alert('Error during model comparison. Please check console for details.');
  }
}
/**
 * Escapes the characters that are significant in HTML so that a value can be
 * interpolated into markup as plain text.
 */
function escapeComparisonHtml(value) {
  const entities = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  return String(value).replace(/[&<>"']/g, (ch) => entities[ch]);
}

/**
 * Formats the last entry of a metric series with `digits` decimals followed by
 * `suffix`, or returns 'N/A' when the series is missing, empty, or its last
 * entry is not a number.
 */
function formatFinalComparisonMetric(series, digits, suffix) {
  const last = Array.isArray(series) && series.length > 0 ? series[series.length - 1] : undefined;
  return typeof last === 'number' ? `${last.toFixed(digits)}${suffix}` : 'N/A';
}

/** Builds the summary markup for one model's results. */
function renderComparisonModelSummary(label, results) {
  const history = (results && results.history) || {};
  const modelName = results && results.model_name != null ? results.model_name : 'N/A';
  return [
    '<div class="comparison-model">',
    `<h4>${label}</h4>`,
    `<p>Final Loss: ${formatFinalComparisonMetric(history.train_loss, 4, '')}</p>`,
    `<p>Final Accuracy: ${formatFinalComparisonMetric(history.train_acc, 2, '%')}</p>`,
    `<p>Model Name: ${escapeComparisonHtml(modelName)}</p>`,
    '</div>',
  ].join('\n');
}

/**
 * Renders the final loss, final accuracy, and model name of both models into
 * the element with id `comparison-logs`.
 *
 * Reads `data.model1_results` / `data.model2_results`, each expected to carry
 * `history.train_loss`, `history.train_acc`, and `model_name`. Missing or empty
 * pieces are shown as 'N/A' instead of throwing, and the model name is
 * HTML-escaped because it comes from the server response.
 *
 * @param {object} data - Comparison results object.
 */
function displayComparisonResults(data) {
  const logsDiv = document.getElementById('comparison-logs');
  const results = data || {};
  logsDiv.innerHTML = [
    renderComparisonModelSummary('Model A', results.model1_results),
    renderComparisonModelSummary('Model B', results.model2_results),
  ].join('\n');
}
/**
 * Reads the hyperparameters of both models from the form and validates them.
 *
 * Model A is read from the `model1_*` controls and Model B from the
 * `model2_*` controls. Numeric fields are parsed as decimal integers.
 *
 * @returns {{model_a: object, model_b: object}} Each model object has
 *   `block1`, `block2`, `block3`, `optimizer`, `batch_size`, and `epochs`.
 * @throws {Error} When any value is null, undefined, an empty string, or NaN
 *   (e.g. a numeric field that could not be parsed). The error is logged
 *   before being rethrown unchanged.
 */
function getModelParameters() {
  const readModel = (prefix) => {
    const rawValue = (suffix) => document.getElementById(`${prefix}_${suffix}`).value;
    const intValue = (suffix) => Number.parseInt(rawValue(suffix), 10);
    return {
      block1: intValue('kernel1'),
      block2: intValue('kernel2'),
      block3: intValue('kernel3'),
      optimizer: rawValue('optimizer'),
      batch_size: intValue('batch_size'),
      epochs: intValue('epochs'),
    };
  };

  try {
    const params = {
      model_a: readModel('model1'),
      model_b: readModel('model2'),
    };

    // Validate that all values are present and valid
    for (const model of ['model_a', 'model_b']) {
      for (const [key, value] of Object.entries(params[model])) {
        // '' covers string fields such as an unselected optimizer, which the
        // null/undefined/NaN checks alone would let through.
        const isMissing = value === null || value === undefined || value === '' || Number.isNaN(value);
        if (isMissing) {
          throw new Error(`Invalid value for ${model} ${key}: ${value}`);
        }
      }
    }

    console.log('Collected and validated model parameters:', params);
    return params;
  } catch (error) {
    console.error('Error in getModelParameters:', error);
    throw error;
  }
}
/**
 * Builds the dataset parameters sent alongside the model parameters.
 *
 * The batch size is taken from Model 1's batch-size field (`model1_batch_size`)
 * and parsed as a decimal integer; shuffling is always enabled.
 *
 * @returns {{batch_size: number, shuffle: boolean}}
 */
function getDatasetParameters() {
  const batchSizeField = document.getElementById('model1_batch_size');
  return {
    // Using model1's batch size for dataset; explicit radix keeps parsing decimal-only.
    batch_size: Number.parseInt(batchSizeField.value, 10),
    shuffle: true,
  };
}
/**
 * Click handler for the element with id `startComparisonBtn`.
 *
 * Steps:
 *  1. Checks that every number input and select on the page has a value.
 *  2. Collects and validates the training parameters *before* connecting, so
 *     invalid input is reported to the user and never opens a socket.
 *  3. Reveals the progress section and resets the comparison charts.
 *  4. Closes any socket left over from an earlier click, then opens a new
 *     WebSocket to `/ws/compare` and sends a `start_training` message once
 *     the connection is open. Incoming messages go to updateTrainingProgress.
 */
document.getElementById('startComparisonBtn').addEventListener('click', function() {
  console.log('Start Comparison button clicked');

  // Validate form inputs before proceeding (selects included for the optimizer)
  const formInputs = document.querySelectorAll('input[type="number"], select');
  let isValid = true;
  const formValues = {};
  formInputs.forEach((input) => {
    console.log(`Checking input ${input.id}: ${input.value}`);
    formValues[input.id] = input.value;
    if (!input.value) {
      console.error(`Missing value for ${input.id}`);
      isValid = false;
    }
  });
  console.log('Form values:', formValues); // Log all form values

  if (!isValid) {
    alert('Please fill in all required fields');
    return;
  }

  // getModelParameters() throws on invalid values; surface that here instead
  // of letting it escape from the socket's open handler with no feedback.
  let message;
  try {
    message = {
      action: 'start_training',
      parameters: {
        model_params: getModelParameters(),
        dataset_params: getDatasetParameters(),
      },
    };
  } catch (error) {
    console.error('Error collecting training parameters:', error);
    alert('Invalid training parameters. Please check console for details.');
    return;
  }

  // Show comparison progress section
  document.getElementById('comparison-progress').classList.remove('hidden');
  initializeComparisonCharts();
  console.log('Initialized comparison charts');

  // A previous run may still be streaming; detach and close it so its
  // messages do not land in the freshly reset charts.
  if (ws && ws.readyState !== WebSocket.CLOSED) {
    ws.onmessage = null;
    ws.onerror = null;
    ws.close();
  }

  console.log('Attempting WebSocket connection...');
  // Pages served over HTTPS may only open secure (wss:) sockets.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  const socket = new WebSocket(`${scheme}://${window.location.host}/ws/compare`);
  ws = socket;

  socket.onopen = function() {
    console.log('WebSocket connection established');
    console.log('Preparing to send message:', JSON.stringify(message, null, 2));
    // The open event fires only once the socket is ready, so no delay is needed.
    try {
      socket.send(JSON.stringify(message));
      console.log('Message sent successfully');
    } catch (error) {
      console.error('Error sending message:', error);
      alert('Error sending training parameters. Please check console for details.');
    }
  };

  socket.onmessage = function(event) {
    console.log('Received WebSocket message:', event.data);
    try {
      const data = JSON.parse(event.data);
      console.log('Parsed message data:', data);
      updateTrainingProgress(data);
    } catch (error) {
      console.error('Error processing message:', error);
    }
  };

  socket.onerror = function(error) {
    console.error('WebSocket error:', error);
    alert('Connection error occurred. Please check console for details.');
  };

  socket.onclose = function(event) {
    console.log('WebSocket connection closed. Code:', event.code, 'Reason:', event.reason);
  };
});
/**
 * Appends one metric value to a trace of a comparison chart.
 *
 * When the trace carries an explicit `x` array (as the charts created by
 * initializeComparisonCharts do), the point's x-coordinate is extended too,
 * using the running point index; otherwise `x` would stay empty while `y`
 * grows and the points would have no x-coordinates.
 *
 * @param {string} plotId - Element id of the Plotly chart.
 * @param {number} traceIndex - Index of the trace to extend.
 * @param {number} value - Metric value to append.
 */
function appendComparisonPoint(plotId, traceIndex, value) {
  const plot = document.getElementById(plotId);
  const trace = plot && plot.data ? plot.data[traceIndex] : undefined;

  const update = { y: [[value]] };
  if (trace && Array.isArray(trace.x)) {
    const nextIndex = Array.isArray(trace.y) ? trace.y.length : trace.x.length;
    update.x = [[nextIndex]];
  }
  Plotly.extendTraces(plotId, update, [traceIndex]);
}

/**
 * Applies one progress message from the comparison WebSocket to the UI.
 *
 * - `status === 'training'`: appends `metrics.loss` and `metrics.accuracy` to
 *   the trace of the reported model (trace 0 when `data.model === 'A'`,
 *   trace 1 otherwise) and updates the progress text with the 1-based epoch.
 * - `status === 'complete'`: marks training complete and renders
 *   `data.metrics` through displayComparisonResults.
 * - `status === 'error'`: logs `data.message` and alerts the user.
 * Any other status is ignored.
 *
 * @param {object} data - Parsed message from the server.
 */
function updateTrainingProgress(data) {
  switch (data.status) {
    case 'training': {
      const isModelA = data.model === 'A';
      const traceIndex = isModelA ? 0 : 1;

      appendComparisonPoint('comparison-loss-plot', traceIndex, data.metrics.loss);
      appendComparisonPoint('comparison-accuracy-plot', traceIndex, data.metrics.accuracy);

      // Update progress text
      const progressText = document.getElementById('training-progress-text');
      progressText.textContent = `Training ${isModelA ? 'Model A' : 'Model B'} - Epoch ${data.epoch + 1}`;
      break;
    }
    case 'complete':
      document.getElementById('training-progress-text').textContent = 'Training Complete!';
      displayComparisonResults(data.metrics);
      break;
    case 'error':
      console.error('Training error:', data.message);
      alert(`Training error: ${data.message}`);
      break;
    default:
      break;
  }
}