# Spaces:
# Sleeping
# Sleeping
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import threading

app = FastAPI(title="3D CNN Visualizer + MNIST", version="0.2.0")

# Prefer GPU when available; the model and all input tensors are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared training state: written by the background training thread,
# read by the request handlers.
MODEL = None            # trained SimpleCNN instance once training succeeds
TRAINING_DONE = False   # True only after a successful training run
TRAINING_ERROR = None   # error message string if training failed
class SimpleCNN(nn.Module):
    """Tiny CNN for MNIST digit classification.

    Tensor shapes through the network:
        1x28x28
        -> conv1 (4, 5x5) -> 4x24x24
        -> maxpool 2x2    -> 4x12x12
        -> conv2 (8, 5x5) -> 8x8x8
        -> maxpool 2x2    -> 8x4x4
        -> flatten        -> 128
        -> fc             -> 10
    """

    def __init__(self):
        super().__init__()
        # 5x5 kernels, no padding: spatial size 28 -> 24 and 12 -> 8.
        self.conv1 = nn.Conv2d(1, 4, kernel_size=5)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(8 * 4 * 4, 10)

    def forward(self, x, return_activations: bool = False):
        """Run a forward pass.

        When return_activations is True, also return a dict of the
        intermediate feature maps keyed by stage name, for visualization.
        """
        captured = {}

        out = F.relu(self.conv1(x))       # (N, 4, 24, 24)
        captured["conv1"] = out
        out = self.pool(out)              # (N, 4, 12, 12)
        captured["pool1"] = out

        out = F.relu(self.conv2(out))     # (N, 8, 8, 8)
        captured["conv2"] = out
        out = self.pool(out)              # (N, 8, 4, 4)
        captured["pool2"] = out

        out = out.view(out.size(0), -1)   # (N, 128)
        captured["flat"] = out
        logits = self.fc(out)             # (N, 10)

        if return_activations:
            return logits, captured
        return logits
def train_model():
    """Train SimpleCNN on an MNIST subset and publish it to module globals.

    On success, stores the trained model in MODEL and sets TRAINING_DONE.
    On failure, records the exception message in TRAINING_ERROR instead.
    Intended to run in a background thread (see the Thread start below).
    """
    global MODEL, TRAINING_DONE, TRAINING_ERROR
    try:
        dataset = datasets.MNIST(
            root="./data", train=True, download=True, transform=transforms.ToTensor()
        )
        # Cap at 10k samples so a demo training run finishes quickly.
        n_samples = min(10000, len(dataset))
        subset = torch.utils.data.Subset(dataset, list(range(n_samples)))
        loader = DataLoader(subset, batch_size=128, shuffle=True)

        net = SimpleCNN().to(DEVICE)
        opt = torch.optim.Adam(net.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        net.train()
        for _ in range(1):  # a single epoch is enough for the demo
            for batch_images, batch_labels in loader:
                batch_images = batch_images.to(DEVICE)
                batch_labels = batch_labels.to(DEVICE)
                opt.zero_grad()
                loss = loss_fn(net(batch_images), batch_labels)
                loss.backward()
                opt.step()

        net.eval()
        MODEL = net
        TRAINING_DONE = True
    except Exception as e:
        TRAINING_ERROR = str(e)
        TRAINING_DONE = False
# Start training in a background daemon thread so server startup is not blocked.
threading.Thread(target=train_model, daemon=True).start()
class PredictRequest(BaseModel):
    """Request body for the predict endpoint."""

    # Flattened row-major grayscale image; 28*28 = 784 values
    # (presumably normalized to [0, 1] by the drawing pad — only length is validated).
    pixels: list[float]  # 28*28 = 784
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page 3D visualizer UI.

    The route decorator (with response_class=HTMLResponse) is required:
    without it FastAPI never registers the route, and a plain string return
    would otherwise be JSON-encoded instead of rendered as HTML.
    """
    return HTML_PAGE
@app.get("/status")
async def status():
    """Report training progress; polled by the UI until the model is ready.

    Returns the training-done flag, any training error message, and the
    device string (cpu/cuda) the model runs on.
    """
    return {
        "training_done": TRAINING_DONE,
        "training_error": TRAINING_ERROR,
        "device": str(DEVICE),
    }
@app.post("/predict")
async def predict(req: PredictRequest):
    """Classify a drawn 28x28 digit and return per-layer activations.

    Returns the predicted class, the full softmax distribution, and every
    intermediate activation tensor (as nested lists) so the frontend can
    animate the network. Responds 500 if training failed, 503 while the
    model is still training, and 400 on a wrong-sized pixel array.
    """
    if TRAINING_ERROR:
        return JSONResponse(
            status_code=500,
            content={"error": "Training failed", "detail": TRAINING_ERROR},
        )
    if not TRAINING_DONE or MODEL is None:
        return JSONResponse(
            status_code=503,
            content={"error": "Model not ready yet. Still training."},
        )
    if len(req.pixels) != 28 * 28:
        return JSONResponse(
            status_code=400,
            content={"error": "pixels must have length 784 (28x28)"},
        )

    # Reshape the flat pixel list into a (1, 1, 28, 28) batch for the CNN.
    arr = np.array(req.pixels, dtype=np.float32).reshape(1, 1, 28, 28)
    x = torch.from_numpy(arr).to(DEVICE)

    with torch.no_grad():
        logits, acts = MODEL(x, return_activations=True)
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        # Strip the batch dim and convert to plain lists for JSON transport.
        conv1 = acts["conv1"].cpu().numpy()[0].tolist()  # [4,24,24]
        pool1 = acts["pool1"].cpu().numpy()[0].tolist()  # [4,12,12]
        conv2 = acts["conv2"].cpu().numpy()[0].tolist()  # [8,8,8]
        pool2 = acts["pool2"].cpu().numpy()[0].tolist()  # [8,4,4]
        flat = acts["flat"].cpu().numpy()[0].tolist()    # [128]

    predicted_class = int(probs.argmax())
    return {
        "predicted_class": predicted_class,
        "probabilities": probs.tolist(),
        "activations": {
            "conv1": conv1,
            "pool1": pool1,
            "conv2": conv2,
            "pool2": pool2,
            "flat": flat,
        },
    }
# Single-page UI served at "/": Tailwind + React + react-three-fiber loaded
# from CDNs, JSX transpiled in-browser by Babel standalone. Mojibake-encoded
# Korean JS comments and the garbled arrow in the FLATTEN label have been
# restored (translated to English / "→"); behavior is otherwise unchanged.
HTML_PAGE = r"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>3D CNN Visualizer</title>
<!-- Tailwind CSS -->
<script src="https://cdn.tailwindcss.com"></script>
<script>
tailwind.config = {
  theme: {
    extend: {
      colors: {
        'neon-green': '#00ff00',
        'dark-green': '#002200',
      },
      fontFamily: {
        mono: ['ui-monospace', 'SFMono-Regular', 'Menlo', 'Monaco', 'Consolas', 'monospace'],
      }
    }
  }
}
</script>
<!-- Import Map for React & Three.js ecosystem -->
<script type="importmap">
{
  "imports": {
    "react": "https://esm.sh/react@18.2.0",
    "react-dom/client": "https://esm.sh/react-dom@18.2.0/client",
    "three": "https://esm.sh/three@0.160.0",
    "@react-three/fiber": "https://esm.sh/@react-three/fiber@8.15.12?external=react,react-dom,three",
    "@react-three/drei": "https://esm.sh/@react-three/drei@9.96.1?external=react,react-dom,three,@react-three/fiber",
    "lucide-react": "https://esm.sh/lucide-react@0.303.0?external=react"
  }
}
</script>
<!-- Babel for JSX -->
<script src="https://unpkg.com/@babel/standalone/babel.min.js"></script>
<style>
body { margin: 0; background-color: #000; overflow: hidden; color: white; }
canvas { touch-action: none; }
.hud-panel {
  background: rgba(0, 10, 0, 0.85);
  border: 1px solid #004400;
  box-shadow: 0 0 20px rgba(0, 255, 0, 0.1);
  backdrop-filter: blur(10px);
}
.btn-holo {
  background: linear-gradient(180deg, rgba(0,40,0,0.8) 0%, rgba(0,20,0,0.9) 100%);
  border: 1px solid #00ff00;
  color: #00ff00;
  text-shadow: 0 0 5px rgba(0,255,0,0.5);
}
.btn-holo:hover:not(:disabled) {
  background: #00ff00;
  color: #000;
  box-shadow: 0 0 15px #00ff00;
}
.btn-holo:disabled {
  border-color: #004400;
  color: #004400;
  cursor: not-allowed;
}
.neon-text {
  text-shadow: 0 0 10px rgba(0, 255, 0, 0.6);
}
.scanlines {
  position: fixed;
  top: 0; left: 0; width: 100%; height: 100%;
  background: linear-gradient(to bottom, rgba(255,255,255,0), rgba(255,255,255,0) 50%, rgba(0,0,0,0.1) 50%, rgba(0,0,0,0.1));
  background-size: 100% 4px;
  pointer-events: none;
  z-index: 50;
  opacity: 0.3;
}
</style>
</head>
<body>
<div id="root"></div>
<div class="scanlines"></div>
<script type="text/babel" data-type="module">
import React, { useState, useEffect, useRef, useMemo, useLayoutEffect } from 'react';
import { createRoot } from 'react-dom/client';
import { Canvas, useFrame } from '@react-three/fiber';
import { Text, Stars, Environment, Grid } from '@react-three/drei';
import * as THREE from 'three';
import { Play, RotateCcw, Activity, Layers, Cpu, Scan, Zap } from 'lucide-react';
// --- 1. CONFIG ---
const ARCHITECTURE = [
  { id: 'input', type: 'input', width: 28, height: 28, depth: 1, z: 0, label: 'INPUT (28x28)' },
  { id: 'conv1', type: 'conv', width: 24, height: 24, depth: 4, z: -15, label: 'CONV LAYER 1 (4x24x24)' },
  { id: 'pool1', type: 'pool', width: 12, height: 12, depth: 4, z: -25, label: 'POOLING 1 (4x12x12)' },
  { id: 'conv2', type: 'conv', width: 8, height: 8, depth: 8, z: -35, label: 'CONV LAYER 2 (8x8x8)' },
  { id: 'pool2', type: 'pool', width: 4, height: 4, depth: 8, z: -45, label: 'POOLING 2 (8x4x4)' },
  { id: 'flat', type: 'flatten', width: 16, height: 8, depth: 1, z: -52, label: 'FLATTEN (128 → 16x8)' },
  { id: 'fc', type: 'fc', width: 1, height: 10, depth: 1, z: -62, label: 'CLASSIFICATION (10)' }
];
const createGrid = (w, h, val = 0) => Array.from({ length: h }, () => Array(w).fill(val));
// --- 2. AUDIO ---
class AudioEngine {
  constructor() {
    this.ctx = null;
    this.master = null;
  }
  init() {
    if(this.ctx) return;
    const Ctx = window.AudioContext || window.webkitAudioContext;
    this.ctx = new Ctx();
    this.master = this.ctx.createGain();
    this.master.gain.value = 0.08;
    this.master.connect(this.ctx.destination);
  }
  playTone(freq, type, dur) {
    if(!this.ctx) return;
    if(this.ctx.state === 'suspended') this.ctx.resume();
    const osc = this.ctx.createOscillator();
    const gain = this.ctx.createGain();
    osc.type = type;
    osc.frequency.setValueAtTime(freq, this.ctx.currentTime);
    gain.gain.setValueAtTime(0, this.ctx.currentTime);
    gain.gain.linearRampToValueAtTime(1, this.ctx.currentTime + 0.01);
    gain.gain.exponentialRampToValueAtTime(0.01, this.ctx.currentTime + dur);
    osc.connect(gain);
    gain.connect(this.master);
    osc.start();
    osc.stop(this.ctx.currentTime + dur);
  }
  playStep(step) {
    if(step===0) this.playTone(400, 'sine', 0.1);
    if(step===1) this.playTone(150, 'sawtooth', 0.1);
    if(step===2) this.playTone(200, 'square', 0.1);
    if(step===3) this.playTone(300, 'sawtooth', 0.1);
    if(step===4) this.playTone(600, 'square', 0.1);
    if(step===5) this.playTone(800, 'triangle', 0.1);
    if(step===6) { this.playTone(440, 'sine', 0.2); setTimeout(()=>this.playTone(660,'sine',0.2), 80); }
  }
}
const audio = new AudioEngine();
// --- 3. 3D LAYERS ---
const VoxelLayer = ({ config, data, active }) => {
  const meshRef = useRef();
  const dummy = useMemo(() => new THREE.Object3D(), []);
  const color = useMemo(() => new THREE.Color(), []);
  const count = config.type === 'fc'
    ? 10
    : (config.type === 'flatten' ? 128 : config.width * config.height * config.depth);
  useLayoutEffect(() => {
    if(!meshRef.current) return;
    let idx = 0;
    if(config.type === 'fc') {
      for(let i=0; i<10; i++) {
        const val = data[0]?.[i]?.[0] || 0;
        dummy.position.set(0, (4.5 - i) * 0.7, 0);
        dummy.scale.set(val > 0.01 ? 0.5 + val * 4 : 0.5, 0.5, 0.5);
        dummy.updateMatrix();
        meshRef.current.setMatrixAt(idx, dummy.matrix);
        if(active && val > 0.01) color.setHSL(0.33, 1, 0.5);
        else color.setHex(0x002200);
        meshRef.current.setColorAt(idx, color);
        idx++;
      }
    } else if (config.type === 'flatten') {
      const flatData = [];
      if(data.length) data.forEach(slice => slice.forEach(row => row.forEach(v => flatData.push(v))));
      for(let y=0; y<8; y++) {
        for(let x=0; x<16; x++) {
          const val = flatData[idx] || 0;
          dummy.position.set((x-8)*0.25, (y-4)*0.25, 0);
          if(val > 0.05) {
            dummy.scale.set(0.9, 0.9, 0.9);
            color.setHSL(0.35, 1, 0.2 + val*0.8);
          } else {
            dummy.scale.set(0,0,0);
          }
          dummy.updateMatrix();
          meshRef.current.setMatrixAt(idx, dummy.matrix);
          meshRef.current.setColorAt(idx, color);
          idx++;
        }
      }
    } else if (config.type === 'input') {
      const gap = 0.2;
      for(let y=0; y<28; y++) {
        for(let x=0; x<28; x++) {
          const val = data[0]?.[y]?.[x] || 0;
          dummy.position.set((x - 14) * gap, (13 - y) * gap, 0);
          if(val > 0.05) {
            dummy.scale.set(0.9,0.9,0.9);
            color.setHSL(0.35, 1, 0.2 + val*0.8);
          } else {
            dummy.scale.set(0,0,0);
          }
          dummy.updateMatrix();
          meshRef.current.setMatrixAt(idx, dummy.matrix);
          meshRef.current.setColorAt(idx, color);
          idx++;
        }
      }
    } else {
      const gap = 0.3;
      const layerW = config.width * 0.2;
      const totalW = (layerW + gap) * config.depth;
      const startX = -totalW / 2;
      for(let z=0; z<config.depth; z++) {
        const zOff = startX + z * (layerW + gap);
        for(let y=0; y<config.height; y++) {
          for(let x=0; x<config.width; x++) {
            const val = data[z]?.[config.height-1-y]?.[x] || 0;
            dummy.position.set(zOff + x*0.2, (y - config.height/2)*0.2, 0);
            if(val > 0.05) {
              dummy.scale.set(0.9,0.9,0.9);
              color.setHSL(0.35, 1, 0.2 + val*0.8);
            } else {
              dummy.scale.set(0,0,0);
            }
            dummy.updateMatrix();
            meshRef.current.setMatrixAt(idx, dummy.matrix);
            meshRef.current.setColorAt(idx, color);
            idx++;
          }
        }
      }
    }
    meshRef.current.instanceMatrix.needsUpdate = true;
    if(meshRef.current.instanceColor) meshRef.current.instanceColor.needsUpdate = true;
  }, [config, data, active, dummy, color]);
  return (
    <group position={[0, 0, config.z]}>
      <Text position={[0, config.height * 0.12 + 2, 0]} fontSize={0.6} color={active?"#fff":"#004400"}>
        {config.label}
      </Text>
      <instancedMesh ref={meshRef} args={[undefined, undefined, count]}>
        <boxGeometry args={[0.2, 0.2, 0.2]} />
        <meshStandardMaterial color="#0f0" transparent opacity={0.9} blending={THREE.AdditiveBlending} toneMapped={false} />
      </instancedMesh>
      {config.type === 'fc' && Array.from({length:10}).map((_, i) => (
        <group key={i} position={[0, (4.5 - i) * 0.7, 0]}>
          <Text position={[-1.5, 0, 0]} fontSize={0.4} color="#0f0">{i}</Text>
          <Text position={[4, 0, 0]} fontSize={0.4} color="#fff">
            {((data[0]?.[i]?.[0] || 0) * 100).toFixed(1)}%
          </Text>
        </group>
      ))}
    </group>
  );
};
const CameraController = ({ step }) => {
  useFrame((state) => {
    const targetPos = new THREE.Vector3();
    const targetLook = new THREE.Vector3();
    if(step === -1) {
      targetPos.set(25, 10, 5);
      targetLook.set(0, 0, -35);
    } else {
      const zMap = {0:0, 1:-15, 2:-25, 3:-35, 4:-45, 5:-52, 6:-62};
      const z = zMap[step] || 0;
      targetPos.set(18, 5, z + 8);
      targetLook.set(0, 0, z);
    }
    state.camera.position.lerp(targetPos, 0.05);
    const look = new THREE.Vector3(0,0,-1).applyQuaternion(state.camera.quaternion).add(state.camera.position);
    look.lerp(targetLook, 0.05);
    state.camera.lookAt(look);
  });
  return null;
};
// --- 4. DRAWING PAD ---
const DrawingPad = ({ data, onChange, disabled }) => {
  const canvasRef = useRef(null);
  const [isDrawing, setIsDrawing] = useState(false);
  const redrawGrid = () => {
    const ctx = canvasRef.current?.getContext('2d');
    if(!ctx) return;
    ctx.fillStyle = 'black';
    ctx.fillRect(0,0,280,280);
    ctx.strokeStyle = '#002200';
    ctx.lineWidth = 1;
    ctx.beginPath();
    for(let i=0; i<=280; i+=28) {
      ctx.moveTo(i,0); ctx.lineTo(i,280);
      ctx.moveTo(0,i); ctx.lineTo(280,i);
    }
    ctx.stroke();
  };
  useEffect(() => {
    redrawGrid();
  }, []);
  useEffect(() => {
    const allZero = data.every(row => row.every(v => v === 0));
    if(allZero) {
      redrawGrid();
    }
  }, [data]);
  const getPos = (e) => {
    const r = canvasRef.current.getBoundingClientRect();
    const x = (e.touches?e.touches[0].clientX:e.clientX) - r.left;
    const y = (e.touches?e.touches[0].clientY:e.clientY) - r.top;
    const scaleX = canvasRef.current.width / r.width;
    const scaleY = canvasRef.current.height / r.height;
    return { x: x*scaleX, y: y*scaleY };
  };
  const draw = (e) => {
    if(disabled || !isDrawing) return;
    const ctx = canvasRef.current.getContext('2d');
    const {x,y} = getPos(e);
    ctx.strokeStyle = '#0f0';
    ctx.lineWidth = 25;
    ctx.lineCap = 'round';
    ctx.shadowBlur = 10;
    ctx.shadowColor = '#0f0';
    ctx.lineTo(x,y);
    ctx.stroke();
    ctx.beginPath();
    ctx.moveTo(x,y);
  };
  const start = (e) => {
    if(disabled) return;
    setIsDrawing(true);
    const {x,y} = getPos(e);
    const ctx = canvasRef.current.getContext('2d');
    ctx.beginPath();
    ctx.moveTo(x,y);
  };
  const end = () => {
    if(!isDrawing) return;
    setIsDrawing(false);
    const ctx = canvasRef.current.getContext('2d');
    const temp = document.createElement('canvas');
    temp.width=28;
    temp.height=28;
    temp.getContext('2d').drawImage(canvasRef.current,0,0,28,28);
    const img = temp.getContext('2d').getImageData(0,0,28,28).data;
    const grid = createGrid(28,28);
    for(let i=0; i<28*28; i++) grid[Math.floor(i/28)][i%28] = img[i*4+1]/255;
    onChange(grid);
  };
  return (
    <canvas ref={canvasRef} width={280} height={280}
      className={`w-[220px] h-[220px] rounded border border-green-800 bg-black cursor-crosshair ${disabled ? 'opacity-50 pointer-events-none' : ''}`}
      onMouseDown={start} onMouseMove={draw} onMouseUp={end} onMouseLeave={()=>setIsDrawing(false)}
      onTouchStart={start} onTouchMove={draw} onTouchEnd={end}
    />
  );
};
// --- 5. MAIN APP ---
const App = () => {
  const [activations, setActivations] = useState({});
  const [step, setStep] = useState(-1);
  const [processing, setProcessing] = useState(false);
  const [inputData, setInputData] = useState(createGrid(28,28));
  const [statusText, setStatusText] = useState("Checking model status...");
  const [lastPrediction, setLastPrediction] = useState(null);
  const [padKey, setPadKey] = useState(0);
  useEffect(() => {
    audio.init();
    reset();
    const interval = setInterval(async () => {
      try {
        const res = await fetch("/status");
        const json = await res.json();
        if(json.training_error) {
          setStatusText("Training error: " + json.training_error);
        } else if(!json.training_done) {
          setStatusText("Model training in progress...");
        } else {
          setStatusText("Model ready on " + json.device);
          clearInterval(interval);
        }
      } catch(e) {
        setStatusText("Status check failed");
      }
    }, 3000);
    return () => clearInterval(interval);
  }, []);
  const reset = () => {
    setStep(-1);
    setActivations(ARCHITECTURE.reduce((acc,l) => ({...acc, [l.id]: []}), {}));
    setInputData(createGrid(28,28));
    setLastPrediction(null);
    setPadKey(k => k + 1);
  };
  const delay = ms => new Promise(r => setTimeout(r, ms));
  const run = async () => {
    if(processing) return;
    setProcessing(true);
    // 1) Flatten the drawn input grid
    const flat = [];
    for(let y=0; y<28; y++) for(let x=0; x<28; x++) flat.push(inputData[y][x]);
    let probs = null;
    let predClass = null;
    let acts = null;
    try {
      const res = await fetch("/predict", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ pixels: flat }),
      });
      const json = await res.json();
      if(res.ok) {
        probs = json.probabilities;
        predClass = json.predicted_class;
        acts = json.activations;
        setLastPrediction({
          cls: predClass,
          conf: probs[predClass]
        });
      } else {
        alert("Error: " + (json.error || "Unknown"));
        setProcessing(false);
        return;
      }
    } catch (e) {
      console.error(e);
      alert("Predict request failed.");
      setProcessing(false);
      return;
    }
    // 2) Convert PyTorch activations to the visualization format
    // conv1 / pool1 / conv2 / pool2 are used as-is (depth x h x w)
    const conv1 = acts.conv1 || [];
    const pool1 = acts.pool1 || [];
    const conv2 = acts.conv2 || [];
    const pool2 = acts.pool2 || [];
    const flatVec = acts.flat || []; // length 128
    // flat: reshape to 1 x 8 x 16
    const flatGrid = [];
    for(let y=0; y<8; y++) {
      const row = [];
      for(let x=0; x<16; x++) {
        const idx = y * 16 + x;
        row.push(flatVec[idx] || 0);
      }
      flatGrid.push(row);
    }
    const flatData = [flatGrid]; // depth=1
    // fc: wrap each probability as [probability] for the visualizer
    const fcData = [probs.map(p => [p])];
    // 3) Update activation state step-by-step (drives the animation)
    setActivations(prev => ({...prev, input: [inputData]}));
    setStep(0); audio.playStep(0); await delay(400);
    setActivations(prev => ({...prev, conv1}));
    setStep(1); audio.playStep(1); await delay(400);
    setActivations(prev => ({...prev, pool1}));
    setStep(2); audio.playStep(2); await delay(400);
    setActivations(prev => ({...prev, conv2}));
    setStep(3); audio.playStep(3); await delay(400);
    setActivations(prev => ({...prev, pool2}));
    setStep(4); audio.playStep(4); await delay(400);
    setActivations(prev => ({...prev, flat: flatData}));
    setStep(5); audio.playStep(5); await delay(400);
    setActivations(prev => ({...prev, fc: fcData}));
    setStep(6); audio.playStep(6);
    await delay(1500);
    setProcessing(false);
    setStep(-1);
  };
  return (
    <div className="w-full h-screen relative bg-black font-mono">
      <Canvas shadows camera={{ position: [25, 10, 5], fov: 45 }}>
        <CameraController step={step} />
        <color attach="background" args={['#000200']} />
        <fog attach="fog" args={['#000200', 20, 90]} />
        <ambientLight intensity={0.2} />
        <pointLight position={[10, 20, 10]} intensity={1.5} color="#00ff00" distance={50} />
        <group>
          {ARCHITECTURE.map((cfg, i) => (
            <VoxelLayer key={cfg.id} config={cfg} data={activations[cfg.id] || []} active={step===i} />
          ))}
          <Grid args={[200, 200]} cellSize={1} cellThickness={1} sectionSize={5} sectionThickness={1.5} fadeDistance={60} sectionColor="#004400" cellColor="#001100" position={[0, -5, -30]} />
        </group>
        <Stars radius={100} depth={50} count={3000} factor={4} saturation={0} fade speed={1} />
        <Environment preset="city" />
      </Canvas>
      <div className="absolute inset-0 pointer-events-none flex flex-col justify-between p-4 z-10">
        <div className="flex justify-between items-start">
          <div className="hud-panel p-4 rounded-br-2xl border-l-4 border-l-green-500">
            <h1 className="text-2xl md:text-4xl font-black tracking-tighter neon-text flex items-center gap-3">
              <Cpu className="text-neon-green animate-pulse" /> DEEP <span className="text-neon-green">CNN</span>
            </h1>
            <div className="text-xs text-green-400 mt-2 flex items-center gap-2">
              <Activity size={12} /> {processing ? "PROCESSING TENSORS..." : "ONLINE"}
            </div>
            <div className="text-[10px] text-green-500 mt-1">{statusText}</div>
            {lastPrediction && (
              <div className="text-xs text-green-300 mt-1">
                PREDICTED:
                <span className="font-bold ml-1 text-neon-green text-sm">
                  {lastPrediction.cls}
                </span>
                <span className="ml-1">
                  ({(lastPrediction.conf * 100).toFixed(1)}%)
                </span>
              </div>
            )}
          </div>
        </div>
        <div className="flex flex-col md:flex-row items-end gap-6 pointer-events-auto">
          <div className="hud-panel p-4 rounded-tr-2xl backdrop-blur-xl max-w-sm">
            <div className="flex justify-between items-center mb-2 text-green-400">
              <div className="text-xs font-bold tracking-widest flex items-center gap-2">
                <Scan size={14} /> INPUT SENSOR
              </div>
              <button onClick={reset} disabled={processing} className="hover:text-white transition-colors">
                <RotateCcw size={16} />
              </button>
            </div>
            <DrawingPad key={padKey} data={inputData} onChange={setInputData} disabled={processing} />
            <button onClick={run} disabled={processing} className="w-full mt-4 py-3 rounded btn-holo flex justify-center items-center gap-2 font-bold transition-all">
              {processing ? <Activity className="animate-spin" size={18} /> : <Play size={18} fill="currentColor" />}
              {processing ? 'CALCULATING...' : 'RUN INFERENCE'}
            </button>
          </div>
          <div className="hud-panel p-5 hidden md:block rounded-t-xl min-w-[260px] border-b-0">
            <div className="text-xs text-green-500 font-bold mb-3 flex items-center gap-2">
              <Layers size={14} /> PIPELINE STATUS
            </div>
            <div className="space-y-2">
              {ARCHITECTURE.map((l, i) => (
                <div key={l.id} className={`flex items-center gap-3 text-xs transition-all duration-300 ${step===i ? 'text-white translate-x-2' : 'text-green-900'}`}>
                  <div className={`w-2 h-2 rounded-sm ${step===i ? 'bg-neon-green shadow-[0_0_8px_#0f0]' : 'bg-green-900'}`} />
                  <span className={step===i ? 'font-bold' : ''}>{l.label}</span>
                  {step===i && <Zap size={10} className="ml-auto text-yellow-400 animate-pulse" />}
                </div>
              ))}
            </div>
          </div>
        </div>
      </div>
    </div>
  );
};
const root = createRoot(document.getElementById('root'));
root.render(<App />);
</script>
</body>
</html>
"""