popboat1 commited on
Commit
656ae3a
·
1 Parent(s): c7da975

Fix Tensor shape parsing bug, implement SVG icons, and migrate video streaming to robust HTTP polling

Browse files
frontend/3d-visualizer/src/app/page.js CHANGED
@@ -7,6 +7,11 @@ import { InputCube } from '@/components/InputCube';
7
  import { LayerCube } from '@/components/LayerCube';
8
  import { OutputNode } from '@/components/OutputNode';
9
 
 
 
 
 
 
10
  export default function NetworkVisualizer() {
11
  const [activeTab, setActiveTab] = useState({ type: 'prediction', layerIndex: null });
12
  const [layerData, setLayerData] = useState([]);
@@ -20,42 +25,12 @@ export default function NetworkVisualizer() {
20
 
21
  const videoRef = useRef(null);
22
  const canvasRef = useRef(null);
23
- const wsRef = useRef(null);
24
  const isAwaitingResponse = useRef(false);
25
  const animationFrameId = useRef(null);
26
  const lastFrameTime = useRef(0);
27
- const isVideoStateRef = useRef(false);
28
-
29
- const connectWebSocket = () => {
30
- if (wsRef.current && (wsRef.current.readyState === WebSocket.OPEN || wsRef.current.readyState === WebSocket.CONNECTING)) {
31
- return;
32
- }
33
-
34
- const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
35
- const wsUrl = `${protocol}//${window.location.host}/ws/predict-video`;
36
-
37
- wsRef.current = new WebSocket(wsUrl);
38
-
39
- wsRef.current.onmessage = (event) => {
40
- if (!isVideoStateRef.current) return;
41
- const data = JSON.parse(event.data);
42
- setLayerData(data.layers);
43
- setPrediction(data.prediction);
44
- isAwaitingResponse.current = false;
45
- };
46
-
47
- wsRef.current.onclose = () => {
48
- console.log("WebSocket closed by server.");
49
- isAwaitingResponse.current = false;
50
- };
51
- };
52
 
53
  useEffect(() => {
54
- connectWebSocket();
55
- return () => {
56
- if (wsRef.current) wsRef.current.close();
57
- cancelAnimationFrame(animationFrameId.current);
58
- };
59
  }, []);
60
 
61
  const processVideoFrame = (timestamp) => {
@@ -64,18 +39,39 @@ export default function NetworkVisualizer() {
64
  return;
65
  }
66
 
67
- if (!wsRef.current || wsRef.current.readyState === WebSocket.CLOSED) {
68
- connectWebSocket();
69
- }
70
-
71
- if (timestamp - lastFrameTime.current >= 80) {
72
- if (!isAwaitingResponse.current && wsRef.current?.readyState === WebSocket.OPEN) {
73
  const canvas = canvasRef.current;
74
  const ctx = canvas.getContext('2d', { willReadFrequently: true });
75
  ctx.drawImage(videoRef.current, 0, 0, 227, 227);
76
- const base64Frame = canvas.toDataURL('image/jpeg', 0.4);
77
- isAwaitingResponse.current = true;
78
- wsRef.current.send(base64Frame);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  lastFrameTime.current = timestamp;
80
  }
81
  }
@@ -98,9 +94,14 @@ export default function NetworkVisualizer() {
98
  if (previewImage) currentX += 0.1 + (blockGap * 0.5);
99
 
100
  const mappedLayers = activeData.map((layer) => {
101
- const sizeY = Math.max(0.8, layer.shape[0] * 0.07);
102
- const sizeZ = Math.max(0.8, layer.shape[1] * 0.07);
103
- const sizeX = Math.max(0.8, layer.shape[2] * 0.01);
 
 
 
 
 
104
  const xPos = currentX + sizeX / 2;
105
  currentX = xPos + sizeX / 2 + blockGap;
106
  return { ...layer, size: [sizeX, sizeY, sizeZ], xPos };
@@ -145,7 +146,6 @@ export default function NetworkVisualizer() {
145
  const isVid = file.type.startsWith('video/');
146
 
147
  setIsVideo(isVid);
148
- isVideoStateRef.current = isVid;
149
  setPreviewImage(fileUrl);
150
  setActiveTab({ type: 'prediction', layerIndex: null });
151
  setZoomedFeature(null);
@@ -182,7 +182,6 @@ export default function NetworkVisualizer() {
182
  className="hidden absolute top-0 left-0 w-0 h-0"
183
  autoPlay muted loop playsInline
184
  onPlay={() => {
185
- // Start the loop if it isn't already running
186
  if (!animationFrameId.current) {
187
  animationFrameId.current = requestAnimationFrame(processVideoFrame);
188
  }
@@ -273,9 +272,9 @@ export default function NetworkVisualizer() {
273
  {!isPanelOpen && (
274
  <button
275
  onClick={() => setIsPanelOpen(true)}
276
- className="absolute top-6 right-6 z-30 bg-[#111] border border-white/10 text-[#00ffcc] font-mono font-bold px-4 py-2 rounded-lg shadow-2xl hover:bg-[#222] transition-colors text-xs md:text-sm"
277
  >
278
- DATA PANEL
279
  </button>
280
  )}
281
 
@@ -284,18 +283,18 @@ export default function NetworkVisualizer() {
284
  >
285
  <div className="p-4 md:p-5 border-b border-white/10 bg-black/40 flex justify-between items-center">
286
  <div className="flex items-center justify-between bg-[#111] border border-gray-700 rounded-lg p-1 w-48 shadow-inner">
287
- <button onClick={handlePrevTab} className="px-3 py-1 text-gray-400 hover:text-white hover:bg-gray-700 rounded transition-colors text-sm"></button>
288
  <span className="font-mono font-bold text-sm text-[#00ffcc] tracking-widest">
289
  {activeTab.type === 'prediction' ? 'OUTPUT' : `LAYER ${activeTab.layerIndex}`}
290
  </span>
291
- <button onClick={handleNextTab} className="px-3 py-1 text-gray-400 hover:text-white hover:bg-gray-700 rounded transition-colors text-sm"></button>
292
  </div>
293
  <button
294
  onClick={() => setIsPanelOpen(false)}
295
- className="text-gray-500 hover:text-white font-bold p-2 rounded-full hover:bg-white/10 transition-colors"
296
  title="Close Panel"
297
  >
298
-
299
  </button>
300
  </div>
301
 
@@ -305,7 +304,7 @@ export default function NetworkVisualizer() {
305
  <div className="mb-6">
306
  <p className="text-[10px] md:text-xs text-gray-500 uppercase tracking-wider mb-1">Tensor Shape</p>
307
  <p className="font-mono text-xs md:text-sm bg-black/50 px-3 py-2 rounded-lg border border-white/10 tracking-widest inline-block text-gray-200">
308
- {activeLayerPanelData.shape?.join(' × ')}
309
  </p>
310
  </div>
311
  <div>
 
7
  import { LayerCube } from '@/components/LayerCube';
8
  import { OutputNode } from '@/components/OutputNode';
9
 
10
+ // Clean SVG Icons
11
+ const ChevronLeft = () => <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2.5" strokeLinecap="round" strokeLinejoin="round"><polyline points="15 18 9 12 15 6"></polyline></svg>;
12
+ const ChevronRight = () => <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2.5" strokeLinecap="round" strokeLinejoin="round"><polyline points="9 18 15 12 9 6"></polyline></svg>;
13
+ const CloseIcon = () => <svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2.5" strokeLinecap="round" strokeLinejoin="round"><line x1="18" y1="6" x2="6" y2="18"></line><line x1="6" y1="6" x2="18" y2="18"></line></svg>;
14
+
15
  export default function NetworkVisualizer() {
16
  const [activeTab, setActiveTab] = useState({ type: 'prediction', layerIndex: null });
17
  const [layerData, setLayerData] = useState([]);
 
25
 
26
  const videoRef = useRef(null);
27
  const canvasRef = useRef(null);
 
28
  const isAwaitingResponse = useRef(false);
29
  const animationFrameId = useRef(null);
30
  const lastFrameTime = useRef(0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  useEffect(() => {
33
+ return () => cancelAnimationFrame(animationFrameId.current);
 
 
 
 
34
  }, []);
35
 
36
  const processVideoFrame = (timestamp) => {
 
39
  return;
40
  }
41
 
42
+ if (timestamp - lastFrameTime.current >= 150) {
43
+ if (!isAwaitingResponse.current) {
44
+ isAwaitingResponse.current = true;
 
 
 
45
  const canvas = canvasRef.current;
46
  const ctx = canvas.getContext('2d', { willReadFrequently: true });
47
  ctx.drawImage(videoRef.current, 0, 0, 227, 227);
48
+
49
+ canvas.toBlob(async (blob) => {
50
+ if (!blob) {
51
+ isAwaitingResponse.current = false;
52
+ return;
53
+ }
54
+
55
+ const formData = new FormData();
56
+ formData.append("file", blob, "frame.jpg");
57
+
58
+ try {
59
+ const response = await fetch("/predict", {
60
+ method: "POST",
61
+ body: formData,
62
+ });
63
+ const data = await response.json();
64
+ if (data && data.layers) {
65
+ setLayerData(data.layers);
66
+ setPrediction(data.prediction);
67
+ }
68
+ } catch (error) {
69
+ console.error("Frame dropped:", error);
70
+ } finally {
71
+ isAwaitingResponse.current = false;
72
+ }
73
+ }, 'image/jpeg', 0.6);
74
+
75
  lastFrameTime.current = timestamp;
76
  }
77
  }
 
94
  if (previewImage) currentX += 0.1 + (blockGap * 0.5);
95
 
96
  const mappedLayers = activeData.map((layer) => {
97
+ const s0 = layer.shape && layer.shape[0] ? layer.shape[0] : 10;
98
+ const s1 = layer.shape && layer.shape[1] ? layer.shape[1] : 10;
99
+ const s2 = layer.shape && layer.shape[2] ? layer.shape[2] : 10;
100
+
101
+ const sizeY = Math.max(0.8, s0 * 0.07);
102
+ const sizeZ = Math.max(0.8, s1 * 0.07);
103
+ const sizeX = Math.max(0.8, s2 * 0.01);
104
+
105
  const xPos = currentX + sizeX / 2;
106
  currentX = xPos + sizeX / 2 + blockGap;
107
  return { ...layer, size: [sizeX, sizeY, sizeZ], xPos };
 
146
  const isVid = file.type.startsWith('video/');
147
 
148
  setIsVideo(isVid);
 
149
  setPreviewImage(fileUrl);
150
  setActiveTab({ type: 'prediction', layerIndex: null });
151
  setZoomedFeature(null);
 
182
  className="hidden absolute top-0 left-0 w-0 h-0"
183
  autoPlay muted loop playsInline
184
  onPlay={() => {
 
185
  if (!animationFrameId.current) {
186
  animationFrameId.current = requestAnimationFrame(processVideoFrame);
187
  }
 
272
  {!isPanelOpen && (
273
  <button
274
  onClick={() => setIsPanelOpen(true)}
275
+ className="absolute top-6 right-6 z-30 bg-[#111] border border-white/10 text-[#00ffcc] font-mono font-bold px-4 py-2 rounded-lg shadow-2xl hover:bg-[#222] transition-colors text-xs md:text-sm flex items-center gap-2"
276
  >
277
+ <ChevronLeft /> DATA PANEL
278
  </button>
279
  )}
280
 
 
283
  >
284
  <div className="p-4 md:p-5 border-b border-white/10 bg-black/40 flex justify-between items-center">
285
  <div className="flex items-center justify-between bg-[#111] border border-gray-700 rounded-lg p-1 w-48 shadow-inner">
286
+ <button onClick={handlePrevTab} className="p-1.5 text-gray-400 hover:text-white hover:bg-gray-700 rounded transition-colors"><ChevronLeft /></button>
287
  <span className="font-mono font-bold text-sm text-[#00ffcc] tracking-widest">
288
  {activeTab.type === 'prediction' ? 'OUTPUT' : `LAYER ${activeTab.layerIndex}`}
289
  </span>
290
+ <button onClick={handleNextTab} className="p-1.5 text-gray-400 hover:text-white hover:bg-gray-700 rounded transition-colors"><ChevronRight /></button>
291
  </div>
292
  <button
293
  onClick={() => setIsPanelOpen(false)}
294
+ className="text-gray-500 hover:text-white p-2 rounded-full hover:bg-white/10 transition-colors"
295
  title="Close Panel"
296
  >
297
+ <CloseIcon />
298
  </button>
299
  </div>
300
 
 
304
  <div className="mb-6">
305
  <p className="text-[10px] md:text-xs text-gray-500 uppercase tracking-wider mb-1">Tensor Shape</p>
306
  <p className="font-mono text-xs md:text-sm bg-black/50 px-3 py-2 rounded-lg border border-white/10 tracking-widest inline-block text-gray-200">
307
+ {activeLayerPanelData.shape ? activeLayerPanelData.shape.join(' × ') : 'Loading...'}
308
  </p>
309
  </div>
310
  <div>
src/api/api.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.staticfiles import StaticFiles
4
  import tensorflow as tf
@@ -7,7 +7,6 @@ import cv2
7
  import base64
8
  import math
9
  import os
10
- import asyncio
11
 
12
  gpus = tf.config.list_physical_devices('GPU')
13
  if gpus:
@@ -56,7 +55,6 @@ feature_extractor = tf.keras.Model(inputs=model.inputs, outputs=[layer.output fo
56
  CIFAR10_CLASSES = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
57
  'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
58
 
59
- # YOUR ORIGINAL PNG LOGIC RESTORED
60
  def generate_feature_grid(feature_map, max_features=64):
61
  if len(feature_map.shape) == 4:
62
  feature_map = feature_map[0]
@@ -84,10 +82,8 @@ def generate_feature_grid(feature_map, max_features=64):
84
  grid_image = np.uint8(grid_image)
85
 
86
  colored_grid = cv2.applyColorMap(grid_image, cv2.COLORMAP_VIRIDIS)
87
-
88
  b_channel, g_channel, r_channel = cv2.split(colored_grid)
89
  alpha_channel = grid_image
90
-
91
  transparent_grid = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))
92
 
93
  _, buffer = cv2.imencode('.png', transparent_grid)
@@ -103,7 +99,6 @@ async def predict_image(file: UploadFile = File(...)):
103
  img_normalized = img_resized.astype(np.float32) / 255.0
104
  img_batch = np.expand_dims(img_normalized, axis=0)
105
 
106
- # DIRECT CALLS FOR IMAGE AS WELL
107
  activations = feature_extractor(img_batch, training=False)
108
  predictions = model(img_batch, training=False)
109
  class_idx = np.argmax(predictions[0].numpy())
@@ -112,9 +107,11 @@ async def predict_image(file: UploadFile = File(...)):
112
  for i, activation in enumerate(activations):
113
  b64_image = generate_feature_grid(activation.numpy())
114
 
 
 
115
  layer_data.append({
116
  "layer_index": i + 1,
117
- "shape": activation.shape[1:],
118
  "texture_b64": f"data:image/png;base64,{b64_image}"
119
  })
120
 
@@ -123,49 +120,4 @@ async def predict_image(file: UploadFile = File(...)):
123
  "layers": layer_data
124
  }
125
 
126
- @app.websocket("/ws/predict-video")
127
- async def predict_video_stream(websocket: WebSocket):
128
- await websocket.accept()
129
- print("WebSocket Connected for Video Stream")
130
-
131
- try:
132
- while True:
133
- data = await websocket.receive_text()
134
-
135
- encoded_data = data.split(',')[1]
136
- nparr = np.frombuffer(base64.b64decode(encoded_data), np.uint8)
137
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
138
-
139
- if img is None:
140
- continue
141
-
142
- img_resized = cv2.resize(img, (227, 227))
143
- img_normalized = img_resized.astype(np.float32) / 255.0
144
- img_batch = np.expand_dims(img_normalized, axis=0)
145
-
146
- activations = feature_extractor(img_batch, training=False)
147
- predictions = model(img_batch, training=False)
148
- class_idx = np.argmax(predictions[0].numpy())
149
-
150
- layer_data = []
151
- for i, activation in enumerate(activations):
152
- b64_image = generate_feature_grid(activation.numpy())
153
- layer_data.append({
154
- "layer_index": i + 1,
155
- "shape": activation.shape[1:],
156
- "texture_b64": f"data:image/png;base64,{b64_image}"
157
- })
158
-
159
- await websocket.send_json({
160
- "prediction": CIFAR10_CLASSES[class_idx],
161
- "layers": layer_data
162
- })
163
-
164
- await asyncio.sleep(0.01)
165
-
166
- except WebSocketDisconnect:
167
- print("WebSocket Disconnected")
168
- except Exception as e:
169
- print(f"WebSocket Error: {e}")
170
-
171
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
1
+ from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.staticfiles import StaticFiles
4
  import tensorflow as tf
 
7
  import base64
8
  import math
9
  import os
 
10
 
11
  gpus = tf.config.list_physical_devices('GPU')
12
  if gpus:
 
55
  CIFAR10_CLASSES = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
56
  'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
57
 
 
58
  def generate_feature_grid(feature_map, max_features=64):
59
  if len(feature_map.shape) == 4:
60
  feature_map = feature_map[0]
 
82
  grid_image = np.uint8(grid_image)
83
 
84
  colored_grid = cv2.applyColorMap(grid_image, cv2.COLORMAP_VIRIDIS)
 
85
  b_channel, g_channel, r_channel = cv2.split(colored_grid)
86
  alpha_channel = grid_image
 
87
  transparent_grid = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))
88
 
89
  _, buffer = cv2.imencode('.png', transparent_grid)
 
99
  img_normalized = img_resized.astype(np.float32) / 255.0
100
  img_batch = np.expand_dims(img_normalized, axis=0)
101
 
 
102
  activations = feature_extractor(img_batch, training=False)
103
  predictions = model(img_batch, training=False)
104
  class_idx = np.argmax(predictions[0].numpy())
 
107
  for i, activation in enumerate(activations):
108
  b64_image = generate_feature_grid(activation.numpy())
109
 
110
+ clean_shape = [int(dim) for dim in activation.shape[1:]]
111
+
112
  layer_data.append({
113
  "layer_index": i + 1,
114
+ "shape": clean_shape,
115
  "texture_b64": f"data:image/png;base64,{b64_image}"
116
  })
117
 
 
120
  "layers": layer_data
121
  }
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  app.mount("/", StaticFiles(directory="static", html=True), name="static")