Spaces:
Running
Running
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Neural Network Visual Architect</title>
  <meta name="description" content="Build, train, and visualize neural networks interactively.">
  <!--
    Third-party libraries: TensorFlow.js (model building/training) and Chart.js (plots).
    `defer` lets the page render while they download. Deferred scripts still run
    before DOMContentLoaded, and the inline application script only touches the
    `tf` / `Chart` globals from that event onward or from user actions.
  -->
  <script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/4.10.0/tf.min.js" defer></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/3.9.1/chart.min.js" defer></script>
  <style>
/* ===== General styling and resets ===== */

/* Design tokens shared across the page */
:root {
  --primary-color: #6a82fb;   /* gradient start, highlights, accents */
  --secondary-color: #fc5c7d; /* gradient end */
  --bg-color: #f4f7f6;
  --panel-bg: rgba(255, 255, 255, 0.9);
  --text-color: #333;
  --shadow-light: rgba(0, 0, 0, 0.05);
  --shadow-dark: rgba(0, 0, 0, 0.1);
}

/* Zero user-agent spacing and size every box by its border edge */
* { box-sizing: border-box; margin: 0; padding: 0; }

/* Full-height page on the brand gradient; horizontal overflow is clipped */
body {
  min-height: 100vh;
  overflow-x: hidden;
  color: var(--text-color);
  background: linear-gradient(135deg, var(--primary-color) 0%, var(--secondary-color) 100%);
  font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
/* ===== Page frame ===== */
.container { max-width: 1800px; margin: 0 auto; padding: 20px; }

.header { margin-bottom: 30px; color: white; text-align: center; }
.header h1 {
  margin-bottom: 10px;
  font-size: 2.8rem;
  font-weight: 700;
  text-shadow: 0 4px 15px var(--shadow-dark);
}
.header p { opacity: 0.9; font-size: 1.2rem; }

/* Three columns (palette | canvas | training) filling the viewport under the header */
.main-layout {
  display: grid;
  grid-template-columns: 320px 1fr 420px;
  gap: 20px;
  height: calc(100vh - 150px);
}

/* Frosted card used by the side columns; content stacks vertically and scrolls */
.panel {
  display: flex;
  flex-direction: column;
  overflow-y: auto;
  padding: 25px;
  border: 1px solid rgba(255, 255, 255, 0.2);
  border-radius: 20px;
  background: var(--panel-bg);
  backdrop-filter: blur(15px);
  box-shadow: 0 15px 30px var(--shadow-dark);
}
.panel h2 {
  display: flex;
  align-items: center;
  gap: 10px;
  margin-bottom: 20px;
  color: #4a5568;
  font-size: 1.4rem;
}
/* ===== Layer palette (left column) ===== */
.layer-palette .layer-template {
  margin-bottom: 15px;
  padding: 15px;
  border-radius: 12px;
  text-align: center;
  cursor: grab;
  user-select: none;
  transition: all 0.3s ease;
}
.layer-template:hover { box-shadow: 0 8px 25px var(--shadow-dark); transform: translateY(-3px); }
.layer-template:active { transform: scale(0.95); cursor: grabbing; }

/* Colour theme per layer type (used by palette templates and canvas cards alike) */
.input-layer-bg { border: 2px solid #4dd0e1; background: linear-gradient(145deg, #e0f7fa, #b2ebf2); }
.dense-layer-bg { border: 2px solid #e57373; background: linear-gradient(145deg, #ffcdd2, #ef9a9a); }
.output-layer-bg { border: 2px solid #81c784; background: linear-gradient(145deg, #c8e6c9, #a5d6a7); }

/* ===== Selected-layer settings form ===== */
.layer-config { margin-top: 20px; padding-top: 20px; border-top: 1px solid #e0e0e0; }
.config-group { margin-bottom: 15px; }
.config-group label {
  display: block;
  margin-bottom: 8px;
  color: #4a5568;
  font-size: 0.9rem;
  font-weight: 500;
}
.config-group input,
.config-group select,
.config-group textarea {
  width: 100%;
  padding: 10px;
  border: 1px solid #ddd;
  border-radius: 8px;
  font-family: inherit;
  font-size: 0.9rem;
}
/* ===== Architecture canvas (centre column) ===== */
.architecture-canvas {
  position: relative;
  /* Create a stacking context so #connection-svg (z-index: -1) paints above this
     element's tinted background but still below the layer cards. Without it the
     negative z-index resolves against the root stacking context, and the canvas
     background is painted over the connection lines. */
  isolation: isolate;
  height: 100%;
  overflow: hidden;
  border: 2px dashed rgba(255, 255, 255, 0.4);
  border-radius: 15px;
  background: rgba(0, 0, 0, 0.1) url('data:image/svg+xml;utf8,<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20"><circle cx="1" cy="1" r="1" fill="rgba(255,255,255,0.1)"/></svg>');
}

/* Centred hint shown while the canvas is empty; pointer events pass through it */
.drop-zone-text {
  position: absolute;
  top: 50%;
  left: 50%;
  transform: translate(-50%, -50%);
  color: rgba(255, 255, 255, 0.8);
  font-size: 1.3rem;
  text-align: center;
  pointer-events: none;
}
/* ===== Layer cards placed on the canvas ===== */
.layer-instance {
  position: absolute;
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 5px;
  min-width: 80px;
  padding: 10px;
  border-radius: 12px;
  text-align: center;
  cursor: move;
  user-select: none;
  backdrop-filter: blur(10px);
  transition: box-shadow 0.2s ease, transform 0.2s ease;
}
.layer-instance.selected { box-shadow: 0 0 0 3px var(--primary-color); }
.layer-header { font-weight: bold; font-size: 0.9rem; }
.layer-details { font-size: 0.75rem; opacity: 0.8; }

/* Vertical stack of neuron dots inside a card */
.neuron-column {
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 4px;
  margin-top: 5px;
}
.neuron {
  width: 12px;
  height: 12px;
  border: 1px solid rgba(0, 0, 0, 0.2);
  border-radius: 50%;
  background-color: rgba(255, 255, 255, 0.7);
}

/* Round "×" badge in the card's top-right corner, hidden until needed */
.delete-btn {
  position: absolute;
  top: -10px;
  right: -10px;
  z-index: 10;
  display: flex;
  align-items: center;
  justify-content: center;
  width: 24px;
  height: 24px;
  border: none;
  border-radius: 50%;
  background: #e53e3e;
  color: white;
  font-size: 14px;
  cursor: pointer;
  opacity: 0;
  transition: opacity 0.2s;
}
.layer-instance:hover .delete-btn { opacity: 1; }
/* Also reveal it when reached with the keyboard; otherwise Tab lands on an
   invisible control. Kept as a separate rule so browsers without
   :focus-visible support still apply the hover rule above. */
.delete-btn:focus-visible { opacity: 1; }
/* ===== Neuron-to-neuron connection lines ===== */
/* Full-canvas overlay stacked behind the layer cards; it never receives the mouse */
#connection-svg {
  position: absolute;
  top: 0;
  left: 0;
  z-index: -1;
  width: 100%;
  height: 100%;
  pointer-events: none;
}
.connection-line { stroke-width: 1.5; stroke: rgba(255, 255, 255, 0.5); }
/* ===== Training panel (right column) ===== */
.training-panel { display: flex; flex-direction: column; }
.training-panel h3 {
  margin-top: 15px;
  margin-bottom: 10px;
  padding-bottom: 5px;
  border-bottom: 1px solid #eee;
  font-size: 1.1rem;
}

/* Common look for all action buttons; the per-button rules below override colour/size */
.train-btn,
.validate-btn,
.clear-btn,
.load-data-btn {
  margin-top: 10px;
  padding: 12px 20px;
  border: none;
  border-radius: 10px;
  color: white;
  font-size: 1rem;
  font-weight: 600;
  cursor: pointer;
  transition: all 0.3s ease;
}

.train-btn { background: linear-gradient(45deg, #4CAF50, #81C784); }
.train-btn:hover {
  transform: translateY(-2px);
  box-shadow: 0 4px 15px rgba(76, 175, 80, 0.4);
}
/* Declared after :hover (same specificity) so a disabled button never lifts */
.train-btn:disabled {
  background: #ccc;
  cursor: not-allowed;
  transform: none;
  box-shadow: none;
}

.validate-btn { background: linear-gradient(45deg, #2196F3, #64B5F6); }
.validate-btn:hover {
  transform: translateY(-2px);
  box-shadow: 0 4px 15px rgba(33, 150, 243, 0.4);
}
.validate-btn:disabled {
  background: #ccc;
  cursor: not-allowed;
  transform: none;
  box-shadow: none;
}

.clear-btn {
  padding: 10px 18px;
  background: linear-gradient(45deg, #f44336, #e57373);
}
.load-data-btn {
  padding: 10px 18px;
  background: linear-gradient(45deg, var(--primary-color), #899cfb);
  font-size: 0.9rem;
}

/* Fixed-height box; the charts use maintainAspectRatio: false and fill it */
.chart-container {
  height: 220px;
  min-height: 220px;
  margin-top: 15px;
  padding-top: 15px;
  border-top: 1px solid #eee;
}
/* ===== Dataset input-method toggle (Generate / Manual) ===== */
.input-method-selector { display: flex; gap: 5px; margin-bottom: 15px; }
.method-btn {
  flex: 1;
  padding: 8px 12px;
  border: 1px solid #e2e8f0;
  border-radius: 6px;
  background: white;
  font-size: 0.85rem;
  cursor: pointer;
  transition: all 0.2s ease;
}
.method-btn.active {
  border-color: var(--primary-color);
  background: var(--primary-color);
  color: white;
}

/* ===== Loss / R² tiles ===== */
.metrics { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; margin-top: 15px; }
.metric { padding: 10px; border-radius: 8px; background: rgba(0,0,0,0.05); text-align: center; }
.metric-value { color: var(--primary-color); font-size: 1.2rem; font-weight: 700; }
.metric-label { color: #718096; font-size: 0.8rem; }

/* ===== Status banner (hidden until a message is shown) ===== */
.status {
  display: none;
  margin-top: 10px;
  padding: 12px;
  border-radius: 8px;
  font-size: 0.9rem;
  text-align: center;
}
.status.success { background: rgba(76, 175, 80, 0.15); color: #388E3C; }
.status.error { background: rgba(244, 67, 54, 0.15); color: #D32F2F; }

/* ===== Training progress bar ===== */
.progress-bar {
  width: 100%;
  height: 8px;
  margin: 10px 0 5px 0;
  overflow: hidden;
  border-radius: 4px;
  background: #e0e0e0;
}
.progress-fill {
  width: 0%;
  height: 100%;
  background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
  transition: width 0.3s ease;
}
/* ===== Narrow screens: stack palette, canvas and training panel in one column ===== */
@media (max-width: 1200px) {
  .main-layout {
    height: auto;
    grid-template-columns: 1fr;
    grid-template-rows: auto 500px auto;
  }
}
</style>
</head>
<body>
<div class="container">
  <!-- Title banner -->
  <header class="header">
    <h1>🧠 Neural Network Visual Architect</h1>
    <p>Build, train, and visualize neural networks interactively.</p>
  </header>
  <!-- Three-column workspace: palette | canvas | data & training -->
  <div class="main-layout">
<!-- Left column: draggable layer templates plus the settings form for the selected layer -->
<div class="panel">
  <h2><span class="icon">🧩</span>Layer Palette</h2>
  <div class="layer-palette">
    <div class="layer-template input-layer-bg" draggable="true" data-type="input">
      <h4>Input Layer</h4>
      <p>Starting point</p>
    </div>
    <div class="layer-template dense-layer-bg" draggable="true" data-type="dense">
      <h4>Dense Layer</h4>
      <p>Hidden layer</p>
    </div>
    <div class="layer-template output-layer-bg" draggable="true" data-type="output">
      <h4>Output Layer</h4>
      <p>Prediction layer</p>
    </div>
  </div>
  <!-- Hidden until a layer on the canvas is selected -->
  <div class="layer-config" id="layerConfig" style="display: none;">
    <h3>Selected Layer Settings</h3>
    <div class="config-group">
      <label for="layerUnits">Neurons:</label>
      <input type="number" id="layerUnits" value="8" min="1" max="16">
    </div>
    <div class="config-group">
      <label for="layerActivation">Activation Function:</label>
      <select id="layerActivation">
        <option value="relu">ReLU</option>
        <option value="sigmoid">Sigmoid</option>
        <option value="tanh">Tanh</option>
        <option value="linear">Linear</option>
      </select>
    </div>
  </div>
  <button class="clear-btn" type="button" onclick="clearArchitecture()" style="margin-top: auto;">Clear Architecture</button>
</div>
<!-- Centre column: drop target where layers are placed and wired together -->
<div class="architecture-canvas" id="architectureCanvas">
  <!-- Connection lines between neurons are drawn into this SVG by the script -->
  <svg id="connection-svg"></svg>
  <div class="drop-zone-text">
    <p>🎯 Drag layers here to build</p>
  </div>
</div>
<!--
  Right column: datasets, training controls, metrics and charts.
  Every label is tied to its control with for/id so assistive tech announces it,
  and the status banners are live regions so messages are exposed as they change.
-->
<div class="panel training-panel">
  <h2><span class="icon">📊</span>Data &amp; Training</h2>

  <!-- Training dataset: generated from a function or entered manually -->
  <div id="data-controls">
    <h3>Training Dataset</h3>
    <div class="input-method-selector">
      <button class="method-btn active" id="functionBtn" type="button" onclick="switchInputMethod('function', 'training')">Generate</button>
      <button class="method-btn" id="manualBtn" type="button" onclick="switchInputMethod('manual', 'training')">Manual</button>
    </div>
    <div id="functionInput">
      <div class="config-group">
        <label for="functionType">Function:</label>
        <select id="functionType" onchange="generateFunctionData()">
          <option value="linear">Linear</option>
          <option value="quadratic" selected>Quadratic</option>
          <option value="sine">Sine Wave</option>
          <option value="exponential">Exponential</option>
        </select>
      </div>
      <div class="config-group">
        <label for="numSamples">Samples:</label>
        <input type="number" id="numSamples" value="100" min="10" max="500" step="10" onchange="generateFunctionData()">
      </div>
    </div>
    <div id="manualInput" style="display: none;">
      <div class="config-group">
        <label for="xValues">X Values (comma-separated):</label>
        <textarea id="xValues" rows="2" placeholder="e.g., 1, 2, 3, 4"></textarea>
      </div>
      <div class="config-group">
        <label for="yValues">Y Values (comma-separated):</label>
        <textarea id="yValues" rows="2" placeholder="e.g., 2, 4, 6, 8"></textarea>
      </div>
      <button class="load-data-btn" type="button" onclick="processManualData()">Load Data</button>
    </div>
  </div>

  <h3>Training Settings</h3>
  <div class="training-controls">
    <!-- min values keep the spinners from stepping to zero/negative settings;
         they are multiples of the step so the default values stay valid -->
    <div class="config-group">
      <label for="learningRate">Learning Rate:</label>
      <input type="number" id="learningRate" value="0.01" min="0.001" step="0.001">
    </div>
    <div class="config-group">
      <label for="epochs">Epochs:</label>
      <input type="number" id="epochs" value="100" min="10" step="10">
    </div>
    <div class="config-group">
      <label for="optimizer">Optimizer:</label>
      <select id="optimizer">
        <option value="adam">Adam</option>
        <option value="sgd">SGD</option>
        <option value="rmsprop">RMSprop</option>
      </select>
    </div>
    <button class="train-btn" id="trainBtn" type="button" onclick="trainModel()" disabled>Train Network</button>
    <div id="trainingProgress" style="display: none;">
      <div class="progress-bar"><div class="progress-fill" id="progressFill"></div></div>
      <div id="progressText" style="font-size: 0.8rem; text-align: center;"></div>
    </div>
  </div>
  <div class="metrics" id="metricsContainer" style="display: none;">
    <div class="metric"><div class="metric-value" id="lossValue">-</div><div class="metric-label">Training Loss</div></div>
    <div class="metric"><div class="metric-value" id="r2Value">-</div><div class="metric-label">Training R²</div></div>
  </div>
  <div id="dataStatus" class="status" role="status" aria-live="polite"></div>
  <div class="chart-container">
    <canvas id="chart" role="img" aria-label="Training data and model prediction chart"></canvas>
  </div>

  <!-- Validation dataset -->
  <div id="validation-data-controls" style="margin-top: 20px; padding-top: 20px; border-top: 2px solid #ddd;">
    <h3>Validation Dataset</h3>
    <div class="input-method-selector">
      <button class="method-btn active" id="valFunctionBtn" type="button" onclick="switchInputMethod('function', 'validation')">Generate</button>
      <button class="method-btn" id="valManualBtn" type="button" onclick="switchInputMethod('manual', 'validation')">Manual</button>
    </div>
    <div id="valFunctionInput">
      <div class="config-group">
        <label for="valFunctionType">Function:</label>
        <select id="valFunctionType" onchange="generateValidationData()">
          <option value="linear">Linear</option>
          <option value="quadratic">Quadratic</option>
          <option value="sine" selected>Sine Wave</option>
          <option value="exponential">Exponential</option>
        </select>
      </div>
      <div class="config-group">
        <label for="valNumSamples">Samples:</label>
        <input type="number" id="valNumSamples" value="50" min="10" max="500" step="10" onchange="generateValidationData()">
      </div>
    </div>
    <div id="valManualInput" style="display: none;">
      <div class="config-group">
        <label for="valXValues">X Values (comma-separated):</label>
        <textarea id="valXValues" rows="2" placeholder="e.g., 1.5, 2.5, 3.5"></textarea>
      </div>
      <div class="config-group">
        <label for="valYValues">Y Values (comma-separated):</label>
        <textarea id="valYValues" rows="2" placeholder="e.g., 3, 5, 7"></textarea>
      </div>
      <button class="load-data-btn" type="button" onclick="processManualValidationData()">Load Data</button>
    </div>
    <button class="validate-btn" id="validateBtn" type="button" onclick="validateModel()" disabled>Validate Model</button>
  </div>
  <div class="metrics" id="validationMetricsContainer" style="display: none;">
    <div class="metric"><div class="metric-value" id="validationLossValue">-</div><div class="metric-label">Validation Loss</div></div>
    <div class="metric"><div class="metric-value" id="validationR2Value">-</div><div class="metric-label">Validation R²</div></div>
  </div>
  <div id="validationStatus" class="status" role="status" aria-live="polite"></div>
  <div class="chart-container">
    <canvas id="validationChart" role="img" aria-label="Validation data and model prediction chart"></canvas>
  </div>
</div>
</div><!-- /.main-layout -->
</div><!-- /.container -->
<script>
// ---------------------------------------------------------------------------
// Application state
// ---------------------------------------------------------------------------
// Datasets are arrays of {x, y} points. `model` is the tf.Sequential built by
// the last training run; `chart` / `validationChart` are the Chart.js plots.
let dataset = null;
let validationDataset = null;
let model = null;
let chart = null;
let validationChart = null;
let isTraining = false;

// Layers placed on the canvas, the id of the selected one, and a counter
// that only ever increases, used to mint unique layer ids.
let layers = [];
let selectedLayerId = null;
let layerCounter = 0;

// --- CORE LOGIC: NEURAL NETWORK ARCHITECTURE ---
// DOM handles used throughout the script.
const canvas = document.getElementById('architectureCanvas');
const connectionSvg = document.getElementById('connection-svg');
/**
 * Create a layer of the given type at canvas coordinates (x, y), append it to
 * `layers`, draw it, and refresh the connections and the Train button.
 *
 * @param {string} type - 'input', 'dense' or 'output'. Any other value is ignored.
 *   The drop handler passes whatever text was dragged onto the canvas, and the
 *   type is interpolated into the card's class name and markup, so it must be
 *   whitelisted.
 * @param {number} x - Left offset in px; negative values are clamped to 0.
 * @param {number} y - Top offset in px; negative values are clamped to 0.
 */
function createLayer(type, x, y) {
  const LAYER_TYPES = ['input', 'dense', 'output'];
  if (!LAYER_TYPES.includes(type)) return;

  // The network has exactly one entry point and one exit point.
  const isEndpoint = type === 'input' || type === 'output';
  if (isEndpoint && layers.some(l => l.type === type)) {
    showStatus(`Only one ${type} layer is allowed.`, 'error', 'data');
    return;
  }

  const layer = {
    id: `layer_${layerCounter++}`,
    type,
    // A drop near the top/left edge would otherwise place the card partly off-canvas.
    x: Math.max(0, x),
    y: Math.max(0, y),
    // Input/output hold one scalar feature/target; hidden layers start at 8 neurons.
    units: isEndpoint ? 1 : 8,
    activation: type === 'output' ? 'linear' : 'relu',
  };
  layers.push(layer);
  renderLayer(layer);
  updateConnections();
  checkTrainingReady();
  document.querySelector('.drop-zone-text').style.display = 'none';
}
/**
 * Draw or redraw the card for `layer` on the canvas.
 *
 * On first render the element is created, appended to the canvas and wired
 * for dragging (mousedown) and selection (click). Every call refreshes its
 * classes, position and inner markup (title, neuron count, up to 16 dots,
 * delete button).
 */
function renderLayer(layer) {
  let card = document.getElementById(layer.id);
  if (card === null) {
    card = document.createElement('div');
    card.id = layer.id;
    canvas.appendChild(card);
    card.addEventListener('mousedown', (evt) => startDrag(evt, layer));
    card.addEventListener('click', (evt) => {
      evt.stopPropagation();
      selectLayer(layer);
    });
  }

  card.className = `layer-instance ${layer.type}-layer-bg`;
  if (layer.id === selectedLayerId) card.classList.add('selected');
  card.style.left = `${layer.x}px`;
  card.style.top = `${layer.y}px`;

  const title = layer.type.charAt(0).toUpperCase() + layer.type.slice(1);
  const activationText = layer.type === 'input' ? '' : `(${layer.activation})`;
  // At most 16 dots are drawn, however many units the layer has.
  const dotCount = Math.min(layer.units, 16);
  const neuronsHTML = Array.from({ length: dotCount }, () => '<div class="neuron"></div>').join('');

  card.innerHTML = [
    `<div class="layer-header">${title}</div>`,
    `<div class="layer-details">${layer.units} Neurons ${activationText}</div>`,
    `<div class="neuron-column">${neuronsHTML}</div>`,
    `<button class="delete-btn" onclick="deleteLayer(event, '${layer.id}')">×</button>`,
  ].join('');
}
/**
 * Remove a layer (called from its "×" button) and refresh the dependent UI:
 * the settings form, the connections, the Train button and the empty-canvas hint.
 * @param {Event} e - the button's click event.
 * @param {string} layerId - id of the layer to remove.
 */
function deleteLayer(e, layerId) {
  // Keep the click from bubbling to the card, whose click handler would select it.
  e.stopPropagation();

  layers = layers.filter(({ id }) => id !== layerId);
  document.getElementById(layerId).remove();

  const removedWasSelected = selectedLayerId === layerId;
  if (removedWasSelected) {
    selectedLayerId = null;
    document.getElementById('layerConfig').style.display = 'none';
  }

  updateConnections();
  checkTrainingReady();

  const canvasIsEmpty = layers.length === 0;
  if (canvasIsEmpty) document.querySelector('.drop-zone-text').style.display = 'block';
}
/**
 * Remove every layer from the canvas and reset the architecture-dependent UI
 * (settings form, validate button, metric tiles, connections, Train button).
 *
 * The trained model is discarded. Its TensorFlow.js weights are disposed
 * explicitly because TF.js does not garbage-collect tensor memory; simply
 * dropping the reference leaked it. Disposal is skipped while a training run
 * is in progress, because that run is still using the model.
 */
function clearArchitecture() {
  if (model && !isTraining) {
    try {
      model.dispose();
    } catch (err) {
      // Best effort: a model that failed to build may not be disposable.
      console.warn('Could not dispose previous model:', err);
    }
  }
  layers = [];
  selectedLayerId = null;
  model = null;

  canvas.querySelectorAll('.layer-instance').forEach(el => el.remove());
  document.getElementById('layerConfig').style.display = 'none';
  document.getElementById('validateBtn').disabled = true;
  document.getElementById('metricsContainer').style.display = 'none';
  document.getElementById('validationMetricsContainer').style.display = 'none';
  updateConnections();
  checkTrainingReady();
  document.querySelector('.drop-zone-text').style.display = 'block';
}
/**
 * Make `layer` the selected layer: highlight its card and load its units and
 * activation into the settings form, which is then shown.
 */
function selectLayer(layer) {
  selectedLayerId = layer.id;
  for (const card of document.querySelectorAll('.layer-instance')) card.classList.remove('selected');
  document.getElementById(layer.id).classList.add('selected');

  const { units, activation, type } = layer;
  const unitsInput = document.getElementById('layerUnits');
  const activationSelect = document.getElementById('layerActivation');
  unitsInput.value = units;
  activationSelect.value = activation;
  // Neuron count is fixed for input/output layers; the input layer has no activation.
  unitsInput.disabled = type === 'input' || type === 'output';
  activationSelect.disabled = type === 'input';

  document.getElementById('layerConfig').style.display = 'block';
}
/**
 * Redraw the fully connected wiring between horizontally adjacent layers.
 *
 * Layers are ordered left to right by x. Every neuron dot of one layer is
 * joined to every dot of the next with an SVG <line>.
 *
 * Coordinates are in the SVG's own space. The SVG is absolutely positioned at
 * top/left 0 inside the canvas, so its origin is the canvas *padding* edge:
 * the border-box origin from getBoundingClientRect(), shifted by the border
 * width (clientLeft / clientTop). Measuring from the border box offset every
 * line by the canvas border.
 */
function updateConnections() {
  connectionSvg.innerHTML = '';
  if (layers.length < 2) return;

  // Layout reads are done once per call / per dot, not once per line.
  const canvasRect = canvas.getBoundingClientRect();
  const originX = canvasRect.left + canvas.clientLeft;
  const originY = canvasRect.top + canvas.clientTop;

  // Centre points of one card's neuron dots, in SVG coordinates.
  const neuronCentres = (layerId) => {
    const card = document.getElementById(layerId);
    if (!card) return [];
    return Array.from(card.querySelectorAll('.neuron'), (dot) => {
      const r = dot.getBoundingClientRect();
      return { x: r.left - originX + r.width / 2, y: r.top - originY + r.height / 2 };
    });
  };

  const columns = [...layers].sort((a, b) => a.x - b.x).map((layer) => neuronCentres(layer.id));
  const fragment = document.createDocumentFragment();
  for (let i = 0; i + 1 < columns.length; i++) {
    for (const from of columns[i]) {
      for (const to of columns[i + 1]) {
        const line = document.createElementNS('http://www.w3.org/2000/svg', 'line');
        line.setAttribute('x1', from.x);
        line.setAttribute('y1', from.y);
        line.setAttribute('x2', to.x);
        line.setAttribute('y2', to.y);
        line.setAttribute('class', 'connection-line');
        fragment.appendChild(line);
      }
    }
  }
  connectionSvg.appendChild(fragment);
}
// --- DRAG AND DROP FUNCTIONALITY ---
// Palette → canvas: each template puts its layer type on the drag payload, and
// a drop on the canvas creates that layer near the pointer.

// The canvas must cancel dragover to be accepted as a drop target.
canvas.addEventListener('dragover', (event) => event.preventDefault());

canvas.addEventListener('drop', (event) => {
  event.preventDefault();
  const { left, top } = canvas.getBoundingClientRect();
  const layerType = event.dataTransfer.getData('text/plain');
  // Fixed offsets put the new card up and to the left of the pointer.
  createLayer(layerType, event.clientX - left - 40, event.clientY - top - 50);
});

for (const template of document.querySelectorAll('.layer-template')) {
  template.addEventListener('dragstart', (event) => {
    event.dataTransfer.setData('text/plain', template.dataset.type);
  });
}
/**
 * Begin moving a layer card with the mouse (called from the card's mousedown).
 * While the button is held, the card follows the pointer, clamped inside the
 * canvas, and connections are redrawn on every move. Listeners are removed on
 * mouseup.
 *
 * @param {MouseEvent} e - the mousedown event; `currentTarget` is the card.
 * @param {object} layer - the layer record whose x/y are updated.
 */
function startDrag(e, layer) {
  // Only the primary button drags. On some platforms the context menu opened by
  // a right-click swallows the matching mouseup, which would leave the
  // listeners below attached and the card stuck to the pointer.
  if (e.button !== 0) return;

  const layerEl = e.currentTarget;
  // Pointer position minus card position; constant for the whole drag.
  const offsetX = e.clientX - layer.x;
  const offsetY = e.clientY - layer.y;

  function onMouseMove(moveEvent) {
    const rect = canvas.getBoundingClientRect();
    // Keep the entire card within the canvas.
    layer.x = Math.max(0, Math.min(moveEvent.clientX - offsetX, rect.width - layerEl.offsetWidth));
    layer.y = Math.max(0, Math.min(moveEvent.clientY - offsetY, rect.height - layerEl.offsetHeight));
    layerEl.style.left = `${layer.x}px`;
    layerEl.style.top = `${layer.y}px`;
    updateConnections();
  }

  function onMouseUp() {
    document.removeEventListener('mousemove', onMouseMove);
    document.removeEventListener('mouseup', onMouseUp);
  }

  document.addEventListener('mousemove', onMouseMove);
  document.addEventListener('mouseup', onMouseUp);
}
// --- MODEL TRAINING & DATA HANDLING ---

/** Release a TF.js model's weight memory; failures are logged rather than thrown. */
function disposeModelQuietly(m) {
  if (!m) return;
  try {
    m.dispose();
  } catch (err) {
    console.warn('Could not dispose model:', err);
  }
}

/**
 * Read the training form fields.
 * @returns {{learningRate: number, epochs: number, optimizerType: string}}
 * @throws {Error} if the learning rate is not a positive finite number, or
 *   epochs is not an integer >= 1 (an empty field previously passed NaN to fit()).
 */
function readTrainingSettings() {
  const learningRate = parseFloat(document.getElementById('learningRate').value);
  const epochs = parseInt(document.getElementById('epochs').value, 10);
  if (!Number.isFinite(learningRate) || learningRate <= 0) {
    throw new Error('Learning rate must be a positive number.');
  }
  if (!Number.isInteger(epochs) || epochs < 1) {
    throw new Error('Epochs must be a whole number of at least 1.');
  }
  const optimizerType = document.getElementById('optimizer').value;
  return { learningRate, epochs, optimizerType };
}

/**
 * Build an uncompiled tf.Sequential from the canvas layers, already sorted
 * left to right by x.
 *
 * The Input layer adds no weights; it only fixes inputShape [1] (one scalar
 * feature per sample) on the first Dense layer. The Output layer must be
 * rightmost so it is the last layer added and the prediction has the single
 * column the targets have. The Input layer must be leftmost so the model
 * matches the drawn wiring.
 * @throws {Error} if the Input layer is not leftmost or the Output layer is not rightmost.
 */
function buildSequentialModel(sortedLayers) {
  const first = sortedLayers[0];
  const last = sortedLayers[sortedLayers.length - 1];
  if (!first || first.type !== 'input' || last.type !== 'output') {
    throw new Error('Place the Input layer leftmost and the Output layer rightmost.');
  }
  const net = tf.sequential();
  try {
    let isFirstWeightLayer = true;
    for (const layer of sortedLayers) {
      if (layer.type === 'input') continue;
      const config = { units: layer.units, activation: layer.activation };
      if (isFirstWeightLayer) config.inputShape = [1];
      isFirstWeightLayer = false;
      net.add(tf.layers.dense(config));
    }
  } catch (err) {
    disposeModelQuietly(net);
    throw err;
  }
  return net;
}

/**
 * Build a model from the canvas layers, fit it to `dataset`, then plot the
 * fitted curve and show loss and R². Progress is updated after every epoch.
 *
 * Memory: TF.js does not garbage-collect tensors, so every tensor created here
 * is disposed, as is the previous run's model once the new one is built.
 * Invalid settings or layer order leave the previous model in place. If
 * building or fitting the new model fails, it is disposed and `model` is left
 * null, so validation stays disabled.
 */
async function trainModel() {
  if (!dataset || isTraining || layers.length < 2) return;
  isTraining = true;

  const trainBtn = document.getElementById('trainBtn');
  trainBtn.disabled = true;
  document.getElementById('validateBtn').disabled = true;
  trainBtn.textContent = 'Training...';
  document.getElementById('trainingProgress').style.display = 'block';
  document.getElementById('metricsContainer').style.display = 'none';

  let finalLoss = 0;
  let inputTensor, outputTensor, predTensor;
  let newModel = null;
  let fitted = false;
  try {
    const { learningRate, epochs, optimizerType } = readTrainingSettings();
    const sortedLayers = [...layers].sort((a, b) => a.x - b.x);
    newModel = buildSequentialModel(sortedLayers);

    // The new model replaces the previous one; free the old weights now.
    if (model !== newModel) disposeModelQuietly(model);
    model = newModel;

    const xs = dataset.map(d => d.x);
    const ys = dataset.map(d => d.y);
    inputTensor = tf.tensor2d(xs, [xs.length, 1]);
    outputTensor = tf.tensor2d(ys, [ys.length, 1]);

    const optimizer = optimizerType === 'sgd' ? tf.train.sgd(learningRate)
      : optimizerType === 'rmsprop' ? tf.train.rmsprop(learningRate)
      : tf.train.adam(learningRate);
    newModel.compile({ optimizer, loss: 'meanSquaredError' });

    await newModel.fit(inputTensor, outputTensor, {
      epochs,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          finalLoss = logs.loss;
          const progress = ((epoch + 1) / epochs) * 100;
          document.getElementById('progressFill').style.width = `${progress}%`;
          document.getElementById('progressText').textContent = `Epoch ${epoch + 1}/${epochs} - Loss: ${finalLoss.toFixed(5)}`;
        },
      },
    });
    fitted = true;

    predTensor = newModel.predict(inputTensor);
    const predData = await predTensor.data();
    plotPredictions(Array.from(predData), finalLoss, 'training');
    showStatus('✓ Model trained successfully!', 'success', 'data');
  } catch (error) {
    // A model that never finished fitting must not be offered for validation.
    if (!fitted && newModel && model === newModel) model = null;
    showStatus(`Training Error: ${error.message}`, 'error', 'data');
    console.error(error);
  } finally {
    if (inputTensor) inputTensor.dispose();
    if (outputTensor) outputTensor.dispose();
    if (predTensor) predTensor.dispose();
    // Free the new model if it is no longer current: fitting failed, or the
    // architecture was cleared while training.
    if (newModel && model !== newModel) disposeModelQuietly(newModel);
    isTraining = false;
    trainBtn.textContent = 'Train Network';
    // Recompute instead of blindly re-enabling: layers or data may have changed meanwhile.
    checkTrainingReady();
    document.getElementById('validateBtn').disabled = !model;
  }
}
/**
 * Evaluate the trained model on `validationDataset`. Computes the MSE loss,
 * overlays the predictions on the validation chart and reports R². All
 * tensors created here are disposed.
 */
async function validateModel() {
  const ready = model && validationDataset;
  if (!ready) {
    showStatus('Train a model and load validation data first.', 'error', 'validation');
    return;
  }

  let valInputTensor;
  let valOutputTensor;
  let valPredTensor;
  try {
    const inputs = validationDataset.map(({ x }) => x);
    const targets = validationDataset.map(({ y }) => y);
    valInputTensor = tf.tensor2d(inputs, [inputs.length, 1]);
    valOutputTensor = tf.tensor2d(targets, [targets.length, 1]);
    valPredTensor = model.predict(valInputTensor);

    // Mean-squared error of the predictions, read back as a plain number.
    const lossTensor = tf.losses.meanSquaredError(valOutputTensor, valPredTensor);
    const [mse] = await lossTensor.data();
    lossTensor.dispose();

    const predictions = Array.from(await valPredTensor.data());
    plotPredictions(predictions, mse, 'validation');
    showStatus('✓ Validation complete!', 'success', 'validation');
  } catch (error) {
    showStatus(`Validation Error: ${error.message}`, 'error', 'validation');
    console.error(error);
  } finally {
    [valInputTensor, valOutputTensor, valPredTensor].forEach((t) => {
      if (t) t.dispose();
    });
  }
}
/**
 * Parse the manual X/Y text areas into the training dataset.
 *
 * Values may be separated by commas and/or whitespace (including newlines).
 * Empty entries, such as a trailing comma, are ignored. Each entry must be a
 * complete finite number: unlike parseFloat, "3abc" is rejected rather than
 * silently read as 3.
 */
function processManualData() {
  const xText = document.getElementById('xValues').value.trim();
  const yText = document.getElementById('yValues').value.trim();
  if (!xText || !yText) return showStatus('Please enter both X and Y values.', 'error', 'data');

  const parseValues = (text) => text.split(/[\s,]+/).filter(Boolean).map(Number);
  const xValues = parseValues(xText);
  const yValues = parseValues(yText);
  // e.g. input made only of separators such as ",,,"
  if (xValues.length === 0 || yValues.length === 0) return showStatus('Please enter both X and Y values.', 'error', 'data');
  if (xValues.length !== yValues.length) return showStatus('X and Y must have the same number of values.', 'error', 'data');
  if (!xValues.every(Number.isFinite) || !yValues.every(Number.isFinite)) return showStatus('All values must be valid numbers.', 'error', 'data');

  dataset = xValues.map((x, i) => ({ x, y: yValues[i] }));
  updateChart('training');
  checkTrainingReady();
  showStatus(`✓ Loaded ${dataset.length} training data points`, 'success', 'data');
}
/**
 * Parse the manual validation X/Y text areas into the validation dataset.
 *
 * Values may be separated by commas and/or whitespace (including newlines).
 * Empty entries, such as a trailing comma, are ignored. Each entry must be a
 * complete finite number: unlike parseFloat, "3abc" is rejected rather than
 * silently read as 3.
 */
function processManualValidationData() {
  const xText = document.getElementById('valXValues').value.trim();
  const yText = document.getElementById('valYValues').value.trim();
  if (!xText || !yText) return showStatus('Please enter both X and Y values.', 'error', 'validation');

  const parseValues = (text) => text.split(/[\s,]+/).filter(Boolean).map(Number);
  const xValues = parseValues(xText);
  const yValues = parseValues(yText);
  // e.g. input made only of separators such as ",,,"
  if (xValues.length === 0 || yValues.length === 0) return showStatus('Please enter both X and Y values.', 'error', 'validation');
  if (xValues.length !== yValues.length) return showStatus('X and Y must have the same number of values.', 'error', 'validation');
  if (!xValues.every(Number.isFinite) || !yValues.every(Number.isFinite)) return showStatus('All values must be valid numbers.', 'error', 'validation');

  validationDataset = xValues.map((x, i) => ({ x, y: yValues[i] }));
  updateChart('validation');
  showStatus(`✓ Loaded ${validationDataset.length} validation data points`, 'success', 'validation');
}
/**
 * Generate a noisy training dataset from the selected target function, with x
 * evenly spaced over [-5, 5] and uniform noise in [-1.25, 1.25) added to y.
 *
 * The sample count is clamped to the field's min/max attributes (never below
 * 2, which the even spacing needs; 500 if no max is set). An empty or
 * non-numeric entry falls back to the field's default value. The clamped
 * value is written back to the field. Previously 1 sample divided by zero
 * (x = NaN), and 0, negative or empty counts produced an empty dataset that
 * still enabled training.
 */
function generateFunctionData() {
  const type = document.getElementById('functionType').value;
  const samplesInput = document.getElementById('numSamples');

  const minSamples = Math.max(2, Number(samplesInput.min) || 2);
  const maxSamples = Math.max(minSamples, Number(samplesInput.max) || 500);
  const requested = parseInt(samplesInput.value, 10);
  const fallback = parseInt(samplesInput.defaultValue, 10) || minSamples;
  const numSamples = Math.min(maxSamples, Math.max(minSamples, Number.isNaN(requested) ? fallback : requested));
  samplesInput.value = numSamples;

  const data = Array.from({ length: numSamples }, (_, i) => {
    const x = -5 + (i * 10) / (numSamples - 1); // evenly spaced over [-5, 5]
    let y;
    switch (type) {
      case 'quadratic': y = 0.5 * x ** 2 - x - 2; break;
      case 'sine': y = 3 * Math.sin(x); break;
      case 'exponential': y = Math.exp(0.5 * x); break;
      default: y = 2 * x + 1;
    }
    return { x, y: y + (Math.random() - 0.5) * 2.5 };
  });
  dataset = data;
  updateChart('training');
  checkTrainingReady();
  showStatus(`✓ Generated ${type} training dataset`, 'success', 'data');
}
/**
 * Generate a noisy validation dataset from the selected target function, with
 * x evenly spaced over [5, 15]. This lies outside the [-5, 5] range used for
 * the generated training data, to test extrapolation. Uniform noise in
 * [-1.25, 1.25) is added to y.
 *
 * The sample count is clamped to the field's min/max attributes (never below
 * 2, which the even spacing needs; 500 if no max is set). An empty or
 * non-numeric entry falls back to the field's default value. The clamped
 * value is written back to the field. Previously 1 sample divided by zero
 * (x = NaN), and 0, negative or empty counts produced an empty dataset.
 */
function generateValidationData() {
  const type = document.getElementById('valFunctionType').value;
  const samplesInput = document.getElementById('valNumSamples');

  const minSamples = Math.max(2, Number(samplesInput.min) || 2);
  const maxSamples = Math.max(minSamples, Number(samplesInput.max) || 500);
  const requested = parseInt(samplesInput.value, 10);
  const fallback = parseInt(samplesInput.defaultValue, 10) || minSamples;
  const numSamples = Math.min(maxSamples, Math.max(minSamples, Number.isNaN(requested) ? fallback : requested));
  samplesInput.value = numSamples;

  const data = Array.from({ length: numSamples }, (_, i) => {
    const x = 5 + (i * 10) / (numSamples - 1); // evenly spaced over [5, 15]
    let y;
    switch (type) {
      case 'quadratic': y = 0.5 * x ** 2 - x - 2; break;
      case 'sine': y = 3 * Math.sin(x); break;
      case 'exponential': y = Math.exp(0.5 * x); break;
      default: y = 2 * x + 1;
    }
    return { x, y: y + (Math.random() - 0.5) * 2.5 }; // Add some noise
  });
  validationDataset = data;
  updateChart('validation');
  showStatus(`✓ Generated new ${type} validation dataset`, 'success', 'validation');
}
// --- UI & UTILITY FUNCTIONS ---

/**
 * Show the current training or validation points on their chart and clear
 * any stale prediction curve. Does nothing until both the chart and the data exist.
 * @param {string} mode - 'training'; any other value selects the validation chart.
 */
function updateChart(mode) {
  const forTraining = mode === 'training';
  const [targetChart, points] = forTraining ? [chart, dataset] : [validationChart, validationDataset];
  if (!targetChart || !points) return;

  const [pointSeries, predictionSeries] = targetChart.data.datasets;
  pointSeries.data = points;
  predictionSeries.data = [];
  targetChart.update();
}
/**
 * Draw the model's predictions as a line over the data, and show loss and R².
 *
 * @param {number[]} predictions - one value per point, in the same order as the
 *   corresponding dataset (training or validation).
 * @param {number} loss - loss value to display.
 * @param {string} mode - 'training'; any other value selects the validation chart/metrics.
 *
 * Each point is paired with its prediction by index, then the pairs are sorted
 * by x so the line is drawn left to right. The old lookup by x value gave every
 * duplicate x the prediction of its first occurrence, and it was O(n²).
 */
function plotPredictions(predictions, loss, mode) {
  const forTraining = mode === 'training';
  const targetChart = forTraining ? chart : validationChart;
  const targetDataset = forTraining ? dataset : validationDataset;

  const predPoints = targetDataset
    .map((point, i) => ({ x: point.x, y: predictions[i] }))
    .sort((a, b) => a.x - b.x);
  targetChart.data.datasets[1].data = predPoints;
  targetChart.update();

  const r2 = calculateR2(targetDataset.map(d => d.y), predictions);
  if (forTraining) {
    document.getElementById('lossValue').textContent = loss.toFixed(5);
    document.getElementById('r2Value').textContent = r2.toFixed(4);
    document.getElementById('metricsContainer').style.display = 'grid';
  } else {
    document.getElementById('validationLossValue').textContent = loss.toFixed(5);
    document.getElementById('validationR2Value').textContent = r2.toFixed(4);
    document.getElementById('validationMetricsContainer').style.display = 'grid';
  }
}
/**
 * Coefficient of determination, R² = 1 − SS_res / SS_tot.
 *
 * Edge cases, which previously showed NaN or ±Infinity:
 *  - empty input → NaN;
 *  - constant targets (SS_tot = 0, e.g. a single point) → 1 if the predictions
 *    are exact, otherwise 0. This is the same convention as scikit-learn's r2_score.
 *
 * @param {number[]} actual - observed values.
 * @param {number[]} predicted - predictions, index-aligned with `actual`.
 * @returns {number}
 */
function calculateR2(actual, predicted) {
  const n = actual.length;
  if (n === 0) return NaN;

  const mean = actual.reduce((a, b) => a + b, 0) / n;
  let totalSumSquares = 0;
  let residualSumSquares = 0;
  for (let i = 0; i < n; i++) {
    totalSumSquares += (actual[i] - mean) ** 2;
    residualSumSquares += (actual[i] - predicted[i]) ** 2;
  }
  if (totalSumSquares === 0) return residualSumSquares === 0 ? 1 : 0;
  return 1 - residualSumSquares / totalSumSquares;
}
/**
 * Show a message in the training ('data') or validation status banner.
 * Non-error messages auto-hide after 3 s; errors stay until replaced.
 *
 * @param {string} message - text to display.
 * @param {string} type - 'success', 'error', ...; used as a CSS modifier class.
 * @param {string} context - 'validation' for the validation banner, anything else for the data banner.
 *
 * Each banner remembers its pending hide timer and cancels it when a new
 * message arrives. Previously an older message's timer could hide a newer one,
 * including an error, early.
 */
function showStatus(message, type, context) {
  const statusEl = document.getElementById(context === 'validation' ? 'validationStatus' : 'dataStatus');
  clearTimeout(statusEl._statusHideTimer);
  statusEl._statusHideTimer = undefined;

  statusEl.textContent = message;
  statusEl.className = `status ${type}`;
  statusEl.style.display = 'block';

  if (type !== 'error') {
    statusEl._statusHideTimer = setTimeout(() => {
      statusEl.style.display = 'none';
      statusEl._statusHideTimer = undefined;
    }, 3000);
  }
}
/**
 * Enable the Train button only when all of these hold:
 *  - the network has both an input and an output layer;
 *  - a non-empty training dataset is loaded;
 *  - no training run is in progress. Adding a layer mid-training used to
 *    re-enable the button.
 */
function checkTrainingReady() {
  const hasInput = layers.some(l => l.type === 'input');
  const hasOutput = layers.some(l => l.type === 'output');
  const hasData = Array.isArray(dataset) && dataset.length > 0;
  const ready = !isTraining && hasInput && hasOutput && hasData && layers.length >= 2;
  document.getElementById('trainBtn').disabled = !ready;
}
/**
 * Switch one dataset's input form between "Generate" and "Manual": show the
 * matching panel and mark the matching toggle button active.
 * @param {string} method - 'function' or 'manual'.
 * @param {string} context - 'training'; any other value targets the validation controls.
 */
function switchInputMethod(method, context) {
  // The validation controls reuse the training ids, with a "val" prefix.
  const ids = context === 'training'
    ? { manualPanel: 'manualInput', functionPanel: 'functionInput', manualBtn: 'manualBtn', functionBtn: 'functionBtn' }
    : { manualPanel: 'valManualInput', functionPanel: 'valFunctionInput', manualBtn: 'valManualBtn', functionBtn: 'valFunctionBtn' };
  const isManual = method === 'manual';
  const isFunction = method === 'function';
  const byId = (id) => document.getElementById(id);

  byId(ids.manualPanel).style.display = isManual ? 'block' : 'none';
  byId(ids.functionPanel).style.display = isFunction ? 'block' : 'none';
  byId(ids.manualBtn).classList.toggle('active', isManual);
  byId(ids.functionBtn).classList.toggle('active', isFunction);
}
// Event Listeners for Layer Configuration

// Neuron count of the selected layer. Empty, non-numeric or < 1 entries (for
// example while the user is clearing the field to type a new number) are
// ignored. Previously they were stored as NaN/0, which showed "NaN Neurons"
// and produced a Dense layer that cannot be built.
document.getElementById('layerUnits').addEventListener('input', (e) => {
  if (!selectedLayerId) return;
  const layer = layers.find(l => l.id === selectedLayerId);
  const units = parseInt(e.target.value, 10);
  if (!layer || !Number.isInteger(units) || units < 1) return;
  layer.units = units;
  renderLayer(layer);
  updateConnections();
});

// Activation function of the selected layer.
document.getElementById('layerActivation').addEventListener('change', (e) => {
  if (!selectedLayerId) return;
  const layer = layers.find(l => l.id === selectedLayerId);
  if (layer) {
    layer.activation = e.target.value;
    renderLayer(layer);
  }
});
// --- INITIALIZATION ---
// When the DOM is ready, create both charts and then generate the default datasets.
document.addEventListener('DOMContentLoaded', () => {
  // Each chart is a scatter of the data points plus a smoothed line series
  // that later holds the model's predictions.
  const buildChart = (canvasId, pointLabel, pointColor, lineColor, title) =>
    new Chart(document.getElementById(canvasId).getContext('2d'), {
      type: 'scatter',
      data: {
        datasets: [
          { label: pointLabel, data: [], backgroundColor: pointColor },
          { label: 'Model Prediction', data: [], borderColor: lineColor, backgroundColor: 'transparent', type: 'line', fill: false, tension: 0.4, borderWidth: 2 },
        ],
      },
      options: { responsive: true, maintainAspectRatio: false, plugins: { title: { display: true, text: title } } },
    });

  chart = buildChart('chart', 'Training Data', 'rgba(106, 130, 251, 0.7)', 'rgba(252, 92, 125, 1)', 'Training Results');
  validationChart = buildChart('validationChart', 'Validation Data', 'rgba(33, 150, 243, 0.7)', 'rgba(255, 152, 0, 1)', 'Validation Results');

  generateFunctionData();
  generateValidationData();
});
</script>
<!-- end of application script -->
</body>
</html>