Vishalpainjane commited on
Commit
8a01471
·
1 Parent(s): cd12c95

added files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +54 -0
  2. .gitignore +17 -22
  3. Dockerfile +39 -0
  4. LICENSE +21 -0
  5. backend/.dockerignore +32 -0
  6. backend/Dockerfile +40 -0
  7. backend/README.md +185 -0
  8. backend/app/__init__.py +1 -0
  9. backend/app/database.py +163 -0
  10. backend/app/main.py +1082 -0
  11. backend/app/model_analyzer.py +714 -0
  12. backend/requirements.txt +10 -0
  13. backend/start.bat +23 -0
  14. backend/start.sh +23 -0
  15. docker-compose.yml +60 -0
  16. docker-start.bat +43 -0
  17. docker-start.sh +38 -0
  18. exporters/python/README.md +138 -0
  19. exporters/python/nn3d_exporter/__init__.py +28 -0
  20. exporters/python/nn3d_exporter/onnx_exporter.py +371 -0
  21. exporters/python/nn3d_exporter/pytorch_exporter.py +434 -0
  22. exporters/python/nn3d_exporter/schema.py +316 -0
  23. exporters/python/pyproject.toml +61 -0
  24. files_to_commit.txt +0 -0
  25. index.html +33 -0
  26. nginx.conf +47 -0
  27. package-lock.json +0 -0
  28. package.json +48 -34
  29. public/favicon.ico +0 -0
  30. public/favicon.svg +48 -0
  31. public/index.html +0 -43
  32. public/logo192.png +0 -0
  33. public/logo512.png +0 -0
  34. public/manifest.json +0 -25
  35. public/robots.txt +0 -3
  36. samples/cnn_resnet.nn3d +327 -0
  37. samples/simple_mlp.nn3d +137 -0
  38. samples/transformer_encoder.nn3d +220 -0
  39. src/App.css +179 -22
  40. src/App.js +0 -25
  41. src/App.test.js +0 -8
  42. src/App.tsx +124 -0
  43. src/components/Scene.tsx +145 -0
  44. src/components/controls/CameraControls.tsx +117 -0
  45. src/components/controls/Interaction.tsx +138 -0
  46. src/components/controls/index.ts +2 -0
  47. src/components/edges/EdgeConnections.tsx +89 -0
  48. src/components/edges/EdgeGeometry.tsx +206 -0
  49. src/components/edges/NeuralConnection.tsx +359 -0
  50. src/components/edges/index.ts +2 -0
.dockerignore ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dependencies
2
+ node_modules/
3
+ backend/venv/
4
+ backend/__pycache__/
5
+ backend/app/__pycache__/
6
+
7
+ # Build outputs
8
+ dist/
9
+ build/
10
+ *.egg-info/
11
+
12
+ # IDE
13
+ .vscode/
14
+ .idea/
15
+ *.swp
16
+ *.swo
17
+
18
+ # Git
19
+ .git/
20
+ .gitignore
21
+
22
+ # Environment
23
+ .env
24
+ .env.local
25
+ .env.*.local
26
+ *.env
27
+
28
+ # Logs
29
+ *.log
30
+ npm-debug.log*
31
+
32
+ # OS
33
+ .DS_Store
34
+ Thumbs.db
35
+
36
+ # Test files
37
+ coverage/
38
+ .nyc_output/
39
+
40
+ # Docker
41
+ Dockerfile*
42
+ docker-compose*.yml
43
+ .dockerignore
44
+
45
+ # Database (will be created in container)
46
+ backend/models.db
47
+ *.db
48
+
49
+ # Documentation
50
+ *.md
51
+ LICENSE
52
+
53
+ # Samples (mounted separately)
54
+ # samples/
.gitignore CHANGED
@@ -1,23 +1,18 @@
1
- # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2
-
3
- # dependencies
4
- /node_modules
5
- /.pnp
6
- .pnp.js
7
-
8
- # testing
9
- /coverage
10
-
11
- # production
12
- /build
13
-
14
- # misc
15
  .DS_Store
16
- .env.local
17
- .env.development.local
18
- .env.test.local
19
- .env.production.local
20
-
21
- npm-debug.log*
22
- yarn-debug.log*
23
- yarn-error.log*
 
 
 
 
 
 
1
+ # Ignore Python virtual environments
2
+ backend/venv/
3
+ node_modules/
4
+ dist/
 
 
 
 
 
 
 
 
 
 
5
  .DS_Store
6
+ *.local
7
+ *.log
8
+ .env
9
+ .env.*
10
+ !.env.example
11
+ coverage/
12
+ .nyc_output/
13
+ *.egg-info/
14
+ __pycache__/
15
+ *.pyc
16
+ .pytest_cache/
17
+ build/
18
+ *.nn3d.bak
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Frontend Dockerfile - React/Vite NN3D Visualizer
2
+ FROM node:20-alpine AS builder
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Copy package files
8
+ COPY package.json package-lock.json ./
9
+
10
+ # Install dependencies
11
+ RUN npm ci
12
+
13
+ # Copy source code
14
+ COPY . .
15
+
16
+ # Set environment variable for API URL (nginx proxy)
17
+ ENV VITE_API_URL=/api
18
+
19
+ # Build the application
20
+ RUN npm run build
21
+
22
+ # Production stage - serve with nginx
23
+ FROM nginx:alpine
24
+
25
+ # Copy custom nginx config
26
+ COPY nginx.conf /etc/nginx/conf.d/default.conf
27
+
28
+ # Copy built assets from builder
29
+ COPY --from=builder /app/dist /usr/share/nginx/html
30
+
31
+ # Expose port
32
+ EXPOSE 80
33
+
34
+ # Health check
35
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
36
+ CMD wget --no-verbose --tries=1 --spider http://localhost:80/ || exit 1
37
+
38
+ # Start nginx
39
+ CMD ["nginx", "-g", "daemon off;"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 3D Deep Learning Model Visualizer
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
backend/.dockerignore ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Virtual environment
2
+ venv/
3
+ .venv/
4
+ env/
5
+
6
+ # Python cache
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+ *.so
11
+
12
+ # Database (will be created fresh)
13
+ models.db
14
+ *.db
15
+
16
+ # IDE
17
+ .vscode/
18
+ .idea/
19
+
20
+ # Logs
21
+ *.log
22
+
23
+ # Environment files
24
+ .env
25
+ .env.*
26
+
27
+ # Documentation
28
+ *.md
29
+
30
+ # Scripts (not needed in container)
31
+ start.bat
32
+ start.sh
backend/Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backend Dockerfile - FastAPI Model Analyzer
2
+ FROM python:3.11-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies for PyTorch and ONNX
8
+ RUN apt-get update && apt-get install -y --no-install-recommends \
9
+ build-essential \
10
+ libgomp1 \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements first for better caching
14
+ COPY requirements.txt .
15
+
16
+ # Install Python dependencies
17
+ # Using CPU-only PyTorch for smaller image size
18
+ RUN pip install --no-cache-dir --upgrade pip && \
19
+ pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
20
+ pip install --no-cache-dir -r requirements.txt
21
+
22
+ # Copy application code
23
+ COPY app/ ./app/
24
+
25
+ # Create directory for database
26
+ RUN mkdir -p /app/data
27
+
28
+ # Environment variables
29
+ ENV PYTHONUNBUFFERED=1
30
+ ENV PYTHONDONTWRITEBYTECODE=1
31
+
32
+ # Expose port
33
+ EXPOSE 8000
34
+
35
+ # Health check
36
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
37
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
38
+
39
+ # Run the application
40
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
backend/README.md ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NN3D Visualizer Backend
2
+
3
+ Python microservice for analyzing neural network model architectures.
4
+
5
+ ## Features
6
+
7
+ - **PyTorch Model Analysis**: Extracts architecture information from `.pt`, `.pth`, `.ckpt`, `.bin` files
8
+ - **ONNX Model Analysis**: Parses ONNX model graphs
9
+ - **Shape Inference**: Traces model execution to capture input/output shapes
10
+ - **Layer Type Detection**: Identifies layer types from weight names and shapes
11
+
12
+ ## Requirements
13
+
14
+ - Python 3.9+
15
+ - PyTorch 2.0+
16
+
17
+ ## Quick Start
18
+
19
+ ### Windows
20
+
21
+ ```batch
22
+ cd backend
23
+ start.bat
24
+ ```
25
+
26
+ ### Linux/Mac
27
+
28
+ ```bash
29
+ cd backend
30
+ chmod +x start.sh
31
+ ./start.sh
32
+ ```
33
+
34
+ ### Manual Setup
35
+
36
+ ```bash
37
+ cd backend
38
+ python -m venv venv
39
+ source venv/bin/activate # On Windows: venv\Scripts\activate
40
+ pip install -r requirements.txt
41
+ python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
42
+ ```
43
+
44
+ ## API Endpoints
45
+
46
+ ### Health Check
47
+
48
+ ```
49
+ GET /health
50
+ ```
51
+
52
+ Returns server status and PyTorch version.
53
+
54
+ ### Analyze PyTorch Model
55
+
56
+ ```
57
+ POST /analyze
58
+ Content-Type: multipart/form-data
59
+
60
+ file: <model_file>
61
+ input_shape: (optional) comma-separated integers, e.g., "1,3,224,224"
62
+ ```
63
+
64
+ ### Analyze ONNX Model
65
+
66
+ ```
67
+ POST /analyze/onnx
68
+ Content-Type: multipart/form-data
69
+
70
+ file: <onnx_file>
71
+ ```
72
+
73
+ ## Response Format
74
+
75
+ ```json
76
+ {
77
+ "success": true,
78
+ "model_type": "full_model|state_dict|torchscript|checkpoint",
79
+ "architecture": {
80
+ "name": "model_name",
81
+ "framework": "pytorch",
82
+ "totalParameters": 1000000,
83
+ "trainableParameters": 1000000,
84
+ "inputShape": [1, 3, 224, 224],
85
+ "outputShape": [1, 1000],
86
+ "layers": [
87
+ {
88
+ "id": "layer_0",
89
+ "name": "conv1",
90
+ "type": "Conv2d",
91
+ "category": "convolution",
92
+ "inputShape": [1, 3, 224, 224],
93
+ "outputShape": [1, 64, 112, 112],
94
+ "params": {
95
+ "in_channels": 3,
96
+ "out_channels": 64,
97
+ "kernel_size": [7, 7],
98
+ "stride": [2, 2],
99
+ "padding": [3, 3]
100
+ },
101
+ "numParameters": 9408,
102
+ "trainable": true
103
+ }
104
+ ],
105
+ "connections": [
106
+ {
107
+ "source": "layer_0",
108
+ "target": "layer_1",
109
+ "tensorShape": [1, 64, 112, 112]
110
+ }
111
+ ]
112
+ },
113
+ "message": "Successfully analyzed model"
114
+ }
115
+ ```
116
+
117
+ ## Integration with Frontend
118
+
119
+ The frontend automatically detects if the backend is available:
120
+
121
+ 1. Start the backend server (port 8000)
122
+ 2. Start the frontend dev server (port 3000)
123
+ 3. Drop a PyTorch model file - it will use the backend for analysis
124
+
125
+ If the backend is unavailable, the frontend falls back to JavaScript-based parsing.
126
+
127
+ ## Supported Layer Types
128
+
129
+ ### Convolution
130
+
131
+ - Conv1d, Conv2d, Conv3d
132
+ - ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
133
+
134
+ ### Pooling
135
+
136
+ - MaxPool1d/2d/3d, AvgPool1d/2d/3d
137
+ - AdaptiveAvgPool1d/2d/3d, AdaptiveMaxPool1d/2d/3d
138
+
139
+ ### Linear
140
+
141
+ - Linear, LazyLinear, Bilinear
142
+
143
+ ### Normalization
144
+
145
+ - BatchNorm1d/2d/3d
146
+ - LayerNorm, GroupNorm, InstanceNorm
147
+
148
+ ### Activation
149
+
150
+ - ReLU, LeakyReLU, PReLU, ELU, SELU
151
+ - GELU, Sigmoid, Tanh, Softmax, SiLU, Mish
152
+
153
+ ### Recurrent
154
+
155
+ - RNN, LSTM, GRU
156
+
157
+ ### Attention
158
+
159
+ - MultiheadAttention, Transformer, TransformerEncoder/Decoder
160
+
161
+ ### Embedding
162
+
163
+ - Embedding, EmbeddingBag
164
+
165
+ ### Regularization
166
+
167
+ - Dropout, Dropout2d/3d, AlphaDropout
168
+
169
+ ## Architecture
170
+
171
+ ```
172
+ backend/
173
+ ├── app/
174
+ │ ├── __init__.py
175
+ │ ├── main.py # FastAPI application
176
+ │ └── model_analyzer.py # PyTorch model analysis
177
+ ├── requirements.txt
178
+ ├── start.bat # Windows startup script
179
+ ├── start.sh # Linux/Mac startup script
180
+ └── README.md
181
+ ```
182
+
183
+ ## Development
184
+
185
+ API documentation is available at `http://localhost:8000/docs` when the server is running.
backend/app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # NN3D Visualizer Backend
backend/app/database.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database module for storing processed model architectures.
3
+ Uses SQLite for simple, file-based persistence.
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import sqlite3
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import List, Optional
12
+ from contextlib import contextmanager
13
+
14
+ # Database file path - use environment variable for Docker, fallback to local
15
+ DB_PATH = Path(os.environ.get("DATABASE_PATH", Path(__file__).parent.parent / "models.db"))
16
+
17
+
18
+ def get_connection():
19
+ """Get a database connection."""
20
+ conn = sqlite3.connect(str(DB_PATH))
21
+ conn.row_factory = sqlite3.Row
22
+ return conn
23
+
24
+
25
+ @contextmanager
26
+ def get_db():
27
+ """Context manager for database connections."""
28
+ conn = get_connection()
29
+ try:
30
+ yield conn
31
+ conn.commit()
32
+ except Exception:
33
+ conn.rollback()
34
+ raise
35
+ finally:
36
+ conn.close()
37
+
38
+
39
+ def init_db():
40
+ """Initialize the database tables."""
41
+ with get_db() as conn:
42
+ conn.execute("""
43
+ CREATE TABLE IF NOT EXISTS saved_models (
44
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
45
+ name TEXT NOT NULL,
46
+ framework TEXT NOT NULL,
47
+ total_parameters INTEGER DEFAULT 0,
48
+ layer_count INTEGER DEFAULT 0,
49
+ architecture_json TEXT NOT NULL,
50
+ thumbnail TEXT,
51
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
52
+ file_hash TEXT UNIQUE
53
+ )
54
+ """)
55
+ conn.execute("""
56
+ CREATE INDEX IF NOT EXISTS idx_name ON saved_models(name)
57
+ """)
58
+ conn.execute("""
59
+ CREATE INDEX IF NOT EXISTS idx_created_at ON saved_models(created_at)
60
+ """)
61
+
62
+
63
+ def save_model(
64
+ name: str,
65
+ framework: str,
66
+ total_parameters: int,
67
+ layer_count: int,
68
+ architecture: dict,
69
+ file_hash: Optional[str] = None,
70
+ thumbnail: Optional[str] = None
71
+ ) -> int:
72
+ """
73
+ Save a model architecture to the database.
74
+ Returns the saved model ID.
75
+ """
76
+ architecture_json = json.dumps(architecture)
77
+
78
+ with get_db() as conn:
79
+ # Check if model with same hash already exists
80
+ if file_hash:
81
+ existing = conn.execute(
82
+ "SELECT id FROM saved_models WHERE file_hash = ?",
83
+ (file_hash,)
84
+ ).fetchone()
85
+
86
+ if existing:
87
+ # Update existing entry
88
+ conn.execute("""
89
+ UPDATE saved_models
90
+ SET name = ?, framework = ?, total_parameters = ?,
91
+ layer_count = ?, architecture_json = ?,
92
+ thumbnail = ?, created_at = CURRENT_TIMESTAMP
93
+ WHERE file_hash = ?
94
+ """, (name, framework, total_parameters, layer_count,
95
+ architecture_json, thumbnail, file_hash))
96
+ return existing['id']
97
+
98
+ # Insert new entry
99
+ cursor = conn.execute("""
100
+ INSERT INTO saved_models
101
+ (name, framework, total_parameters, layer_count, architecture_json, file_hash, thumbnail)
102
+ VALUES (?, ?, ?, ?, ?, ?, ?)
103
+ """, (name, framework, total_parameters, layer_count,
104
+ architecture_json, file_hash, thumbnail))
105
+
106
+ return cursor.lastrowid
107
+
108
+
109
+ def get_saved_models() -> List[dict]:
110
+ """Get all saved models (metadata only, not full architecture)."""
111
+ with get_db() as conn:
112
+ rows = conn.execute("""
113
+ SELECT id, name, framework, total_parameters, layer_count,
114
+ thumbnail, created_at
115
+ FROM saved_models
116
+ ORDER BY created_at DESC
117
+ """).fetchall()
118
+
119
+ return [dict(row) for row in rows]
120
+
121
+
122
+ def get_model_by_id(model_id: int) -> Optional[dict]:
123
+ """Get a specific model with full architecture."""
124
+ with get_db() as conn:
125
+ row = conn.execute("""
126
+ SELECT id, name, framework, total_parameters, layer_count,
127
+ architecture_json, thumbnail, created_at
128
+ FROM saved_models
129
+ WHERE id = ?
130
+ """, (model_id,)).fetchone()
131
+
132
+ if row:
133
+ result = dict(row)
134
+ result['architecture'] = json.loads(result['architecture_json'])
135
+ del result['architecture_json']
136
+ return result
137
+
138
+ return None
139
+
140
+
141
+ def delete_model(model_id: int) -> bool:
142
+ """Delete a model by ID. Returns True if deleted."""
143
+ with get_db() as conn:
144
+ cursor = conn.execute(
145
+ "DELETE FROM saved_models WHERE id = ?",
146
+ (model_id,)
147
+ )
148
+ return cursor.rowcount > 0
149
+
150
+
151
+ def model_exists_by_hash(file_hash: str) -> Optional[int]:
152
+ """Check if a model with the given hash exists. Returns ID if exists."""
153
+ with get_db() as conn:
154
+ row = conn.execute(
155
+ "SELECT id FROM saved_models WHERE file_hash = ?",
156
+ (file_hash,)
157
+ ).fetchone()
158
+
159
+ return row['id'] if row else None
160
+
161
+
162
+ # Initialize database on module load
163
+ init_db()
backend/app/main.py ADDED
@@ -0,0 +1,1082 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NN3D Visualizer Backend API
3
+ FastAPI server for analyzing neural network models.
4
+ """
5
+
6
+ import os
7
+ import hashlib
8
+ import tempfile
9
+ import traceback
10
+ from typing import Optional, List
11
+ from pathlib import Path
12
+
13
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import JSONResponse
16
+ from pydantic import BaseModel
17
+
18
+ from .model_analyzer import (
19
+ load_pytorch_model,
20
+ analyze_model_structure,
21
+ analyze_state_dict,
22
+ trace_model_shapes,
23
+ architecture_to_dict
24
+ )
25
+
26
+ from . import database as db
27
+
28
+ import torch
29
+
30
+
31
+ app = FastAPI(
32
+ title="NN3D Model Analyzer",
33
+ description="Backend service for analyzing neural network model architectures",
34
+ version="1.0.0"
35
+ )
36
+
37
+ # Enable CORS for frontend
38
+ app.add_middleware(
39
+ CORSMiddleware,
40
+ allow_origins=["*"], # In production, specify exact origins
41
+ allow_credentials=True,
42
+ allow_methods=["*"],
43
+ allow_headers=["*"],
44
+ )
45
+
46
+
47
+ class AnalysisRequest(BaseModel):
48
+ """Request model for analysis with sample input shape."""
49
+ input_shape: Optional[List[int]] = None
50
+
51
+
52
+ class AnalysisResponse(BaseModel):
53
+ """Response model for analysis results."""
54
+ success: bool
55
+ model_type: str
56
+ architecture: dict
57
+ message: Optional[str] = None
58
+
59
+
60
+ class HealthResponse(BaseModel):
61
+ """Health check response."""
62
+ status: str
63
+ pytorch_version: str
64
+ cuda_available: bool
65
+
66
+
67
+ @app.get("/health", response_model=HealthResponse)
68
+ async def health_check():
69
+ """Check server health and PyTorch availability."""
70
+ return HealthResponse(
71
+ status="healthy",
72
+ pytorch_version=torch.__version__,
73
+ cuda_available=torch.cuda.is_available()
74
+ )
75
+
76
+
77
+ @app.post("/analyze", response_model=AnalysisResponse)
78
+ async def analyze_model(
79
+ file: UploadFile = File(...),
80
+ input_shape: Optional[str] = Query(None, description="Input shape as comma-separated ints, e.g., '1,3,224,224'")
81
+ ):
82
+ """
83
+ Analyze a PyTorch model file and extract architecture information.
84
+
85
+ Supports:
86
+ - Full model files (.pt, .pth)
87
+ - State dict checkpoints
88
+ - TorchScript models
89
+ - Training checkpoints with model_state_dict
90
+ """
91
+ # Validate file extension
92
+ allowed_extensions = {'.pt', '.pth', '.ckpt', '.bin', '.model'}
93
+ file_ext = Path(file.filename).suffix.lower()
94
+
95
+ if file_ext not in allowed_extensions:
96
+ raise HTTPException(
97
+ status_code=400,
98
+ detail=f"Unsupported file format. Allowed: {', '.join(allowed_extensions)}"
99
+ )
100
+
101
+ # Parse input shape if provided
102
+ sample_shape = None
103
+ if input_shape:
104
+ try:
105
+ sample_shape = [int(x.strip()) for x in input_shape.split(',')]
106
+ except ValueError:
107
+ raise HTTPException(
108
+ status_code=400,
109
+ detail="Invalid input_shape format. Use comma-separated integers, e.g., '1,3,224,224'"
110
+ )
111
+
112
+ # Save uploaded file temporarily
113
+ temp_path = None
114
+ try:
115
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
116
+ content = await file.read()
117
+ temp_file.write(content)
118
+ temp_path = temp_file.name
119
+
120
+ # Load and analyze model
121
+ model, state_dict, model_type = load_pytorch_model(temp_path)
122
+
123
+ if model is not None:
124
+ # Full model available - analyze structure
125
+ model_name = Path(file.filename).stem
126
+ architecture = analyze_model_structure(model, model_name)
127
+
128
+ # Try to trace shapes if input shape provided
129
+ if sample_shape and model_type != 'torchscript':
130
+ try:
131
+ sample_input = torch.randn(*sample_shape)
132
+ architecture = trace_model_shapes(model, sample_input, architecture)
133
+ except Exception as e:
134
+ print(f"Shape tracing failed: {e}")
135
+
136
+ return AnalysisResponse(
137
+ success=True,
138
+ model_type=model_type,
139
+ architecture=architecture_to_dict(architecture),
140
+ message=f"Successfully analyzed {model_type} model"
141
+ )
142
+
143
+ elif state_dict is not None:
144
+ # Only state dict available - infer from weights
145
+ model_name = Path(file.filename).stem
146
+ architecture = analyze_state_dict(state_dict, model_name)
147
+
148
+ return AnalysisResponse(
149
+ success=True,
150
+ model_type='state_dict',
151
+ architecture=architecture_to_dict(architecture),
152
+ message="Analyzed from state dict. Layer types inferred from weight names/shapes."
153
+ )
154
+
155
+ else:
156
+ raise HTTPException(
157
+ status_code=400,
158
+ detail="Could not parse model file. Unknown format."
159
+ )
160
+
161
+ except HTTPException:
162
+ raise
163
+ except Exception as e:
164
+ traceback.print_exc()
165
+ raise HTTPException(
166
+ status_code=500,
167
+ detail=f"Analysis failed: {str(e)}"
168
+ )
169
+ finally:
170
+ # Cleanup temp file
171
+ if temp_path and os.path.exists(temp_path):
172
+ try:
173
+ os.unlink(temp_path)
174
+ except Exception:
175
+ pass
176
+
177
+
178
+ @app.post("/analyze/onnx")
179
+ async def analyze_onnx_model(file: UploadFile = File(...)):
180
+ """
181
+ Analyze an ONNX model file.
182
+ """
183
+ if not file.filename.lower().endswith('.onnx'):
184
+ raise HTTPException(status_code=400, detail="File must be an ONNX model (.onnx)")
185
+
186
+ temp_path = None
187
+ try:
188
+ import onnx
189
+
190
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.onnx') as temp_file:
191
+ content = await file.read()
192
+ temp_file.write(content)
193
+ temp_path = temp_file.name
194
+
195
+ model = onnx.load(temp_path)
196
+ graph = model.graph
197
+
198
+ layers = []
199
+ connections = []
200
+ layer_map = {}
201
+
202
+ # Process nodes
203
+ for i, node in enumerate(graph.node):
204
+ layer_id = f"layer_{i}"
205
+ layer_map[node.name if node.name else f"node_{i}"] = layer_id
206
+
207
+ # Map output names to layer ids
208
+ for output in node.output:
209
+ layer_map[output] = layer_id
210
+
211
+ # Extract attributes
212
+ params = {}
213
+ for attr in node.attribute:
214
+ if attr.type == onnx.AttributeProto.INT:
215
+ params[attr.name] = attr.i
216
+ elif attr.type == onnx.AttributeProto.INTS:
217
+ params[attr.name] = list(attr.ints)
218
+ elif attr.type == onnx.AttributeProto.FLOAT:
219
+ params[attr.name] = attr.f
220
+ elif attr.type == onnx.AttributeProto.STRING:
221
+ params[attr.name] = attr.s.decode('utf-8')
222
+
223
+ layers.append({
224
+ 'id': layer_id,
225
+ 'name': node.name if node.name else node.op_type,
226
+ 'type': node.op_type,
227
+ 'category': infer_onnx_category(node.op_type),
228
+ 'inputShape': None,
229
+ 'outputShape': None,
230
+ 'params': params,
231
+ 'numParameters': 0,
232
+ 'trainable': True
233
+ })
234
+
235
+ # Create connections from inputs
236
+ for input_name in node.input:
237
+ if input_name in layer_map:
238
+ source_id = layer_map[input_name]
239
+ if source_id != layer_id: # Avoid self-loops
240
+ connections.append({
241
+ 'source': source_id,
242
+ 'target': layer_id,
243
+ 'tensorShape': None
244
+ })
245
+
246
+ # Get input/output shapes from graph
247
+ input_shape = None
248
+ output_shape = None
249
+
250
+ if graph.input:
251
+ for inp in graph.input:
252
+ shape = []
253
+ if inp.type.tensor_type.shape.dim:
254
+ for dim in inp.type.tensor_type.shape.dim:
255
+ shape.append(dim.dim_value if dim.dim_value else -1)
256
+ if shape:
257
+ input_shape = shape
258
+ break
259
+
260
+ if graph.output:
261
+ for out in graph.output:
262
+ shape = []
263
+ if out.type.tensor_type.shape.dim:
264
+ for dim in out.type.tensor_type.shape.dim:
265
+ shape.append(dim.dim_value if dim.dim_value else -1)
266
+ if shape:
267
+ output_shape = shape
268
+ break
269
+
270
+ architecture = {
271
+ 'name': Path(file.filename).stem,
272
+ 'framework': 'onnx',
273
+ 'totalParameters': 0,
274
+ 'trainableParameters': 0,
275
+ 'inputShape': input_shape,
276
+ 'outputShape': output_shape,
277
+ 'layers': layers,
278
+ 'connections': connections
279
+ }
280
+
281
+ return AnalysisResponse(
282
+ success=True,
283
+ model_type='onnx',
284
+ architecture=architecture,
285
+ message="Successfully analyzed ONNX model"
286
+ )
287
+
288
+ except Exception as e:
289
+ traceback.print_exc()
290
+ raise HTTPException(status_code=500, detail=f"ONNX analysis failed: {str(e)}")
291
+ finally:
292
+ if temp_path and os.path.exists(temp_path):
293
+ try:
294
+ os.unlink(temp_path)
295
+ except Exception:
296
+ pass
297
+
298
+
299
+ def infer_onnx_category(op_type: str) -> str:
300
+ """Infer category from ONNX operator type."""
301
+ op_lower = op_type.lower()
302
+
303
+ if 'conv' in op_lower:
304
+ return 'convolution'
305
+ if 'pool' in op_lower:
306
+ return 'pooling'
307
+ if 'norm' in op_lower or 'batch' in op_lower:
308
+ return 'normalization'
309
+ if 'relu' in op_lower or 'sigmoid' in op_lower or 'tanh' in op_lower or 'softmax' in op_lower:
310
+ return 'activation'
311
+ if 'gemm' in op_lower or 'matmul' in op_lower or 'linear' in op_lower:
312
+ return 'linear'
313
+ if 'lstm' in op_lower or 'gru' in op_lower or 'rnn' in op_lower:
314
+ return 'recurrent'
315
+ if 'attention' in op_lower:
316
+ return 'attention'
317
+ if 'dropout' in op_lower:
318
+ return 'regularization'
319
+ if 'reshape' in op_lower or 'flatten' in op_lower or 'squeeze' in op_lower:
320
+ return 'reshape'
321
+ if 'add' in op_lower or 'mul' in op_lower or 'sub' in op_lower:
322
+ return 'arithmetic'
323
+ if 'concat' in op_lower or 'split' in op_lower:
324
+ return 'merge'
325
+
326
+ return 'other'
327
+
328
+
329
+ # Mapping of file extensions to supported frameworks
330
+ SUPPORTED_FORMATS = {
331
+ # PyTorch formats
332
+ '.pt': 'pytorch',
333
+ '.pth': 'pytorch',
334
+ '.ckpt': 'pytorch',
335
+ '.bin': 'pytorch',
336
+ '.model': 'pytorch',
337
+ # ONNX format
338
+ '.onnx': 'onnx',
339
+ # TensorFlow/Keras formats
340
+ '.h5': 'keras',
341
+ '.hdf5': 'keras',
342
+ '.keras': 'keras',
343
+ '.pb': 'tensorflow',
344
+ # SafeTensors format
345
+ '.safetensors': 'safetensors',
346
+ }
347
+
348
+
349
@app.post("/analyze/universal")
async def analyze_universal(
    file: UploadFile = File(...),
    input_shape: Optional[str] = Query(None, description="Input shape as comma-separated ints, e.g., '1,3,224,224'")
):
    """
    Universal model analyzer - accepts any supported model format.

    Supported formats:
    - PyTorch: .pt, .pth, .ckpt, .bin, .model
    - ONNX: .onnx
    - Keras/TensorFlow: .h5, .hdf5, .keras, .pb
    - SafeTensors: .safetensors

    The upload is written to a temporary file, dispatched to the analyzer
    matching its extension, and the temp file is always removed afterwards.

    Returns detailed architecture information including:
    - Layer types and names
    - Input/output shapes for each layer
    - Parameter counts
    - Layer connections
    - Model metadata
    """
    filename = file.filename or "unknown"
    file_ext = Path(filename).suffix.lower()

    # Reject unknown extensions up front with the full list of valid ones.
    framework = SUPPORTED_FORMATS.get(file_ext)
    if framework is None:
        supported = ', '.join(sorted(SUPPORTED_FORMATS.keys()))
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format '{file_ext}'. Supported formats: {supported}"
        )

    # Optional "1,3,224,224"-style shape hint for shape tracing.
    sample_shape = None
    if input_shape:
        try:
            sample_shape = [int(part.strip()) for part in input_shape.split(',')]
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail="Invalid input_shape format. Use comma-separated integers, e.g., '1,3,224,224'"
            )

    temp_path = None
    try:
        # Persist the upload so file-based loaders can open it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as handle:
            handle.write(await file.read())
            temp_path = handle.name

        # Dispatch to the framework-specific analyzer.
        if framework == 'pytorch':
            return await _analyze_pytorch(temp_path, filename, sample_shape)
        if framework == 'onnx':
            return await _analyze_onnx(temp_path, filename)
        if framework == 'keras':
            return await _analyze_keras(temp_path, filename)
        if framework == 'tensorflow':
            return await _analyze_tensorflow(temp_path, filename)
        if framework == 'safetensors':
            return await _analyze_safetensors(temp_path, filename)
        raise HTTPException(status_code=400, detail=f"Framework '{framework}' not yet implemented")

    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Analysis failed: {str(e)}"
        )
    finally:
        # Best-effort cleanup of the temporary upload.
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except Exception:
                pass
430
+
431
+
432
async def _analyze_pytorch(file_path: str, filename: str, sample_shape: Optional[List[int]] = None) -> AnalysisResponse:
    """Analyze PyTorch model.

    Handles both outcomes of loading:
    - a loadable module -> full structural analysis (optionally shape-traced)
    - only a state dict -> layer types inferred from weight names/shapes

    Args:
        file_path: Path to the uploaded model file on disk.
        filename: Original upload filename; its stem becomes the model name.
        sample_shape: Optional input shape used for best-effort shape tracing.

    Raises:
        HTTPException(400): when the file yields neither a module nor a state dict.
    """
    model, state_dict, model_type = load_pytorch_model(file_path)
    model_name = Path(filename).stem

    if model is not None:
        architecture = analyze_model_structure(model, model_name)

        # Try to trace shapes if input shape provided.
        # TorchScript modules are excluded -- NOTE(review): presumably the
        # tracer targets eager nn.Module instances; confirm against
        # trace_model_shapes/load_pytorch_model.
        if sample_shape and model_type != 'torchscript':
            try:
                sample_input = torch.randn(*sample_shape)
                architecture = trace_model_shapes(model, sample_input, architecture)
            except Exception as e:
                # Tracing is best-effort; fall back to the untraced architecture.
                print(f"Shape tracing failed: {e}")

        return AnalysisResponse(
            success=True,
            model_type=model_type,
            architecture=architecture_to_dict(architecture),
            message=f"Successfully analyzed PyTorch {model_type}"
        )

    elif state_dict is not None:
        # No executable module: reconstruct an approximate architecture from
        # parameter names and tensor shapes alone.
        architecture = analyze_state_dict(state_dict, model_name)
        return AnalysisResponse(
            success=True,
            model_type='state_dict',
            architecture=architecture_to_dict(architecture),
            message="Analyzed from state dict. Layer types inferred from weight names/shapes."
        )

    raise HTTPException(status_code=400, detail="Could not parse PyTorch model file")
465
+
466
+
467
async def _analyze_onnx(file_path: str, filename: str) -> AnalysisResponse:
    """Analyze ONNX model.

    Walks the ONNX graph protobuf: each graph node becomes a layer entry,
    node inputs become connections, and parameter counts come from the
    graph initializers (stored weights).

    Args:
        file_path: Path to the .onnx file on disk.
        filename: Original upload filename; its stem becomes the model name.

    Raises:
        HTTPException(500): when the onnx package is not installed.
    """
    # math.prod is sufficient to count tensor elements; no need to build
    # torch tensors just to multiply a handful of ints.
    import math

    try:
        import onnx
    except ImportError:
        raise HTTPException(status_code=500, detail="ONNX library not installed. Install with: pip install onnx")

    model = onnx.load(file_path)
    graph = model.graph

    layers = []
    connections = []
    layer_map = {}
    total_params = 0

    # Process initializers (weights) for parameter counts.
    weight_shapes = {}
    for init in graph.initializer:
        dims = list(init.dims)
        weight_shapes[init.name] = dims
        total_params += math.prod(dims) if dims else 0

    # Process nodes: one layer per graph node.
    for i, node in enumerate(graph.node):
        layer_id = f"layer_{i}"
        layer_map[node.name if node.name else f"node_{i}"] = layer_id

        # Map each output tensor name back to this node so downstream
        # consumers can resolve their inputs to a source layer.
        for output in node.output:
            layer_map[output] = layer_id

        # Extract scalar/list attributes into a plain dict.
        params = {}
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.INT:
                params[attr.name] = attr.i
            elif attr.type == onnx.AttributeProto.INTS:
                params[attr.name] = list(attr.ints)
            elif attr.type == onnx.AttributeProto.FLOAT:
                params[attr.name] = round(attr.f, 6)
            elif attr.type == onnx.AttributeProto.STRING:
                params[attr.name] = attr.s.decode('utf-8')

        # Count parameters for this layer: node inputs that are initializers.
        layer_params = 0
        for inp_name in node.input:
            if inp_name in weight_shapes:
                layer_params += math.prod(weight_shapes[inp_name])

        # Per-node shapes are not resolved from value_info yet; left as None.
        input_shape = None
        output_shape = None

        layers.append({
            'id': layer_id,
            'name': node.name if node.name else f"{node.op_type}_{i}",
            'type': node.op_type,
            'category': infer_onnx_category(node.op_type),
            'inputShape': input_shape,
            'outputShape': output_shape,
            'params': params,
            'numParameters': layer_params,
            'trainable': layer_params > 0
        })

        # Create connections from tensor inputs produced by earlier nodes.
        for input_name in node.input:
            if input_name in layer_map:
                source_id = layer_map[input_name]
                if source_id != layer_id:
                    connections.append({
                        'source': source_id,
                        'target': layer_id,
                        'tensorShape': weight_shapes.get(input_name)
                    })

    # Get model input/output shapes (dynamic dims become -1).
    input_shape = None
    output_shape = None

    if graph.input:
        for inp in graph.input:
            if inp.name not in weight_shapes:  # Skip weight inputs
                shape = []
                if inp.type.tensor_type.shape.dim:
                    for dim in inp.type.tensor_type.shape.dim:
                        shape.append(dim.dim_value if dim.dim_value else -1)
                if shape:
                    input_shape = shape
                    break

    if graph.output:
        for out in graph.output:
            shape = []
            if out.type.tensor_type.shape.dim:
                for dim in out.type.tensor_type.shape.dim:
                    shape.append(dim.dim_value if dim.dim_value else -1)
            if shape:
                output_shape = shape
                break

    architecture = {
        'name': Path(filename).stem,
        'framework': 'onnx',
        'totalParameters': total_params,
        'trainableParameters': total_params,
        'inputShape': input_shape,
        'outputShape': output_shape,
        'layers': layers,
        'connections': connections
    }

    return AnalysisResponse(
        success=True,
        model_type='onnx',
        architecture=architecture,
        message=f"Successfully analyzed ONNX model with {len(layers)} layers"
    )
586
+
587
+
588
async def _analyze_keras(file_path: str, filename: str) -> AnalysisResponse:
    """Analyze Keras/HDF5 model.

    Reads the HDF5 file directly (no TensorFlow import needed): the layer
    structure comes from the 'model_config' JSON attribute and total
    parameter counts from the 'model_weights' group.

    Args:
        file_path: Path to the .h5/.hdf5/.keras file on disk.
        filename: Original upload filename; its stem is the fallback name.

    Raises:
        HTTPException(500): when h5py is not installed.
    """
    try:
        import h5py
    except ImportError:
        raise HTTPException(status_code=500, detail="h5py not installed. Install with: pip install h5py")

    import json

    layers = []
    connections = []
    total_params = 0
    # Default to the file name; replaced by the model's own name when a
    # model_config attribute is present (previously this value was computed
    # but never used).
    model_name = Path(filename).stem

    with h5py.File(file_path, 'r') as f:
        # Keras stores the full architecture as a JSON string attribute.
        if 'model_config' in f.attrs:
            config = json.loads(f.attrs['model_config'])
            model_name = config.get('config', {}).get('name', model_name)

            # Parse layers from config.
            layer_configs = config.get('config', {}).get('layers', [])

            for i, layer_cfg in enumerate(layer_configs):
                layer_id = f"layer_{i}"
                layer_class = layer_cfg.get('class_name', 'Unknown')
                layer_config = layer_cfg.get('config', {})

                # Extract the commonly interesting hyperparameters.
                params = {}
                param_keys = ['units', 'filters', 'kernel_size', 'strides', 'padding',
                              'activation', 'use_bias', 'dropout', 'rate', 'axis',
                              'epsilon', 'momentum', 'input_dim', 'output_dim']
                for key in param_keys:
                    if key in layer_config:
                        params[key] = layer_config[key]

                # Infer shapes from config (only the declared input shape is known).
                input_shape = None
                output_shape = None
                if 'batch_input_shape' in layer_config:
                    input_shape = list(layer_config['batch_input_shape'])

                layers.append({
                    'id': layer_id,
                    'name': layer_config.get('name', f"{layer_class}_{i}"),
                    'type': layer_class,
                    'category': _infer_keras_category(layer_class),
                    'inputShape': input_shape,
                    'outputShape': output_shape,
                    'params': params,
                    'numParameters': 0,
                    'trainable': layer_config.get('trainable', True)
                })

                # Assume sequential topology: link each layer to its predecessor.
                if i > 0:
                    connections.append({
                        'source': f"layer_{i-1}",
                        'target': layer_id,
                        'tensorShape': None
                    })

        # Count parameters by recursing through the stored weight groups.
        if 'model_weights' in f:
            def count_h5_params(group):
                count = 0
                for key in group.keys():
                    item = group[key]
                    if isinstance(item, h5py.Dataset):
                        count += item.size
                    elif isinstance(item, h5py.Group):
                        count += count_h5_params(item)
                return count
            total_params = count_h5_params(f['model_weights'])

    architecture = {
        'name': model_name,
        'framework': 'keras',
        'totalParameters': total_params,
        'trainableParameters': total_params,
        'inputShape': layers[0].get('inputShape') if layers else None,
        'outputShape': None,
        'layers': layers,
        'connections': connections
    }

    return AnalysisResponse(
        success=True,
        model_type='keras',
        architecture=architecture,
        message=f"Successfully analyzed Keras model with {len(layers)} layers"
    )
679
+
680
+
681
async def _analyze_tensorflow(file_path: str, filename: str) -> AnalysisResponse:
    """Analyze TensorFlow SavedModel or frozen graph.

    Parses the file as a frozen GraphDef protobuf: every graph node becomes
    one layer entry and node inputs become connections. Per-node shapes and
    parameter counts are not resolved here (left as None / 0).
    Falls back to _analyze_pb_file when TensorFlow cannot be imported.

    Raises:
        HTTPException(400): when the bytes do not parse as a GraphDef.
    """
    try:
        import tensorflow as tf
    except ImportError:
        # Fallback: parse .pb file manually
        return await _analyze_pb_file(file_path, filename)

    layers = []
    connections = []

    # Try loading as SavedModel or GraphDef
    try:
        graph_def = tf.compat.v1.GraphDef()
        with open(file_path, 'rb') as f:
            graph_def.ParseFromString(f.read())

        node_map = {}
        for i, node in enumerate(graph_def.node):
            layer_id = f"layer_{i}"
            node_map[node.name] = layer_id

            # Extract attributes: only scalar ints, floats, strings and
            # shape fields are kept, so the result stays JSON-serializable.
            # NOTE(review): attr.s is raw bytes -- decode may fail for
            # non-UTF-8 payloads; confirm against real graphs.
            params = {}
            for key, attr in node.attr.items():
                if attr.HasField('i'):
                    params[key] = attr.i
                elif attr.HasField('f'):
                    params[key] = round(attr.f, 6)
                elif attr.HasField('s'):
                    params[key] = attr.s.decode('utf-8')
                elif attr.HasField('shape'):
                    dims = [d.size for d in attr.shape.dim]
                    params[key] = dims

            layers.append({
                'id': layer_id,
                'name': node.name,
                'type': node.op,
                'category': _infer_tf_category(node.op),
                'inputShape': None,
                'outputShape': None,
                'params': params,
                'numParameters': 0,
                'trainable': True
            })

            # Create connections from inputs. Input strings may carry a
            # control-dependency marker '^' and a tensor index suffix ':n';
            # both are stripped before looking up the producing node.
            for inp in node.input:
                inp_name = inp.lstrip('^').split(':')[0]
                if inp_name in node_map:
                    connections.append({
                        'source': node_map[inp_name],
                        'target': layer_id,
                        'tensorShape': None
                    })

        architecture = {
            'name': Path(filename).stem,
            'framework': 'tensorflow',
            'totalParameters': 0,
            'trainableParameters': 0,
            'inputShape': None,
            'outputShape': None,
            'layers': layers,
            'connections': connections
        }

        return AnalysisResponse(
            success=True,
            model_type='tensorflow_pb',
            architecture=architecture,
            message=f"Successfully analyzed TensorFlow graph with {len(layers)} nodes"
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to parse TensorFlow model: {str(e)}")
757
+
758
+
759
async def _analyze_pb_file(file_path: str, filename: str) -> AnalysisResponse:
    """Fallback handler for .pb files when TensorFlow is unavailable.

    Parsing a GraphDef without TensorFlow is not implemented, so this
    always answers HTTP 501 (Not Implemented); the detail string tells
    the operator how to enable support.
    """
    raise HTTPException(status_code=501, detail="TensorFlow .pb analysis requires TensorFlow. Install with: pip install tensorflow")
765
+
766
+
767
async def _analyze_safetensors(file_path: str, filename: str) -> AnalysisResponse:
    """Analyze SafeTensors file.

    SafeTensors stores only named tensors (no graph), so the architecture is
    reconstructed heuristically: tensors sharing a dotted name prefix are
    grouped into one layer, layer types are guessed from names/shapes, and
    layers are linked sequentially in iteration order.

    Raises:
        HTTPException(500): when the safetensors package is not installed.
    """
    try:
        from safetensors import safe_open
    except ImportError:
        raise HTTPException(
            status_code=500,
            detail="safetensors not installed. Install with: pip install safetensors"
        )

    layers = []
    connections = []
    total_params = 0
    layer_groups = {}

    with safe_open(file_path, framework="pt") as f:
        tensor_names = list(f.keys())

        # Group tensors by layer
        for name in tensor_names:
            tensor = f.get_tensor(name)
            shape = list(tensor.shape)
            num_params = int(tensor.numel())
            total_params += num_params

            # Extract layer name from tensor name (e.g., "encoder.layer.0.attention.weight")
            parts = name.rsplit('.', 1)
            layer_name = parts[0] if len(parts) > 1 else name
            tensor_type = parts[1] if len(parts) > 1 else 'weight'

            if layer_name not in layer_groups:
                layer_groups[layer_name] = {
                    'tensors': {},
                    'params': {},
                    'total_params': 0
                }

            layer_groups[layer_name]['tensors'][tensor_type] = shape
            layer_groups[layer_name]['total_params'] += num_params

            # Infer params from shapes.
            # NOTE(review): assumes a 2-D-or-more 'weight' is laid out
            # [out, in, ...] as in nn.Linear/nn.Conv*; for conv kernels
            # in_features then holds the input-channel count -- confirm.
            if tensor_type == 'weight' and len(shape) >= 2:
                layer_groups[layer_name]['params']['out_features'] = shape[0]
                layer_groups[layer_name]['params']['in_features'] = shape[1]

    # Convert groups to layers
    prev_layer_id = None
    for i, (layer_name, group) in enumerate(layer_groups.items()):
        layer_id = f"layer_{i}"

        # Infer layer type from name and shapes
        layer_type = _infer_layer_type_from_name(layer_name, group['tensors'])

        # Infer shapes (-1 stands for the batch dimension)
        input_shape = None
        output_shape = None
        if 'in_features' in group['params']:
            input_shape = [-1, group['params']['in_features']]
        if 'out_features' in group['params']:
            output_shape = [-1, group['params']['out_features']]

        layers.append({
            'id': layer_id,
            'name': layer_name,
            'type': layer_type,
            'category': _infer_category_from_type(layer_type),
            'inputShape': input_shape,
            'outputShape': output_shape,
            'params': group['params'],
            'numParameters': group['total_params'],
            'trainable': True
        })

        # Create sequential connections
        if prev_layer_id:
            connections.append({
                'source': prev_layer_id,
                'target': layer_id,
                'tensorShape': None
            })
        prev_layer_id = layer_id

    architecture = {
        'name': Path(filename).stem,
        'framework': 'safetensors',
        'totalParameters': total_params,
        'trainableParameters': total_params,
        'inputShape': layers[0].get('inputShape') if layers else None,
        'outputShape': layers[-1].get('outputShape') if layers else None,
        'layers': layers,
        'connections': connections
    }

    return AnalysisResponse(
        success=True,
        model_type='safetensors',
        architecture=architecture,
        message=f"Successfully analyzed SafeTensors model with {len(layers)} layers, {total_params:,} parameters"
    )
866
+
867
+
868
+ def _infer_keras_category(class_name: str) -> str:
869
+ """Infer category from Keras layer class name."""
870
+ name = class_name.lower()
871
+ if 'conv' in name:
872
+ return 'convolution'
873
+ if 'pool' in name:
874
+ return 'pooling'
875
+ if 'dense' in name or 'linear' in name:
876
+ return 'linear'
877
+ if 'norm' in name or 'batch' in name:
878
+ return 'normalization'
879
+ if 'dropout' in name:
880
+ return 'regularization'
881
+ if 'lstm' in name or 'gru' in name or 'rnn' in name:
882
+ return 'recurrent'
883
+ if 'attention' in name:
884
+ return 'attention'
885
+ if 'activation' in name or 'relu' in name or 'sigmoid' in name:
886
+ return 'activation'
887
+ if 'embed' in name:
888
+ return 'embedding'
889
+ if 'flatten' in name or 'reshape' in name:
890
+ return 'reshape'
891
+ if 'input' in name:
892
+ return 'input'
893
+ return 'other'
894
+
895
+
896
+ def _infer_tf_category(op_type: str) -> str:
897
+ """Infer category from TensorFlow op type."""
898
+ op = op_type.lower()
899
+ if 'conv' in op:
900
+ return 'convolution'
901
+ if 'pool' in op:
902
+ return 'pooling'
903
+ if 'matmul' in op or 'dense' in op:
904
+ return 'linear'
905
+ if 'norm' in op or 'batch' in op:
906
+ return 'normalization'
907
+ if 'relu' in op or 'sigmoid' in op or 'tanh' in op or 'softmax' in op:
908
+ return 'activation'
909
+ if 'placeholder' in op or 'input' in op:
910
+ return 'input'
911
+ if 'variable' in op or 'const' in op:
912
+ return 'parameter'
913
+ return 'other'
914
+
915
+
916
+ def _infer_layer_type_from_name(name: str, tensors: dict) -> str:
917
+ """Infer layer type from name and tensor shapes."""
918
+ name_lower = name.lower()
919
+
920
+ if 'attention' in name_lower or 'attn' in name_lower:
921
+ return 'MultiHeadAttention'
922
+ if 'linear' in name_lower or 'dense' in name_lower or 'fc' in name_lower:
923
+ return 'Linear'
924
+ if 'conv' in name_lower:
925
+ if 'weight' in tensors and len(tensors['weight']) == 4:
926
+ return 'Conv2d'
927
+ return 'Conv1d'
928
+ if 'norm' in name_lower:
929
+ if 'layer' in name_lower:
930
+ return 'LayerNorm'
931
+ return 'BatchNorm'
932
+ if 'embed' in name_lower:
933
+ return 'Embedding'
934
+ if 'lstm' in name_lower:
935
+ return 'LSTM'
936
+ if 'gru' in name_lower:
937
+ return 'GRU'
938
+ if 'query' in name_lower or 'key' in name_lower or 'value' in name_lower:
939
+ return 'Linear'
940
+
941
+ # Infer from tensor shapes
942
+ if 'weight' in tensors:
943
+ shape = tensors['weight']
944
+ if len(shape) == 2:
945
+ return 'Linear'
946
+ if len(shape) == 4:
947
+ return 'Conv2d'
948
+ if len(shape) == 1:
949
+ return 'LayerNorm'
950
+
951
+ return 'Unknown'
952
+
953
+
954
+ def _infer_category_from_type(layer_type: str) -> str:
955
+ """Infer category from layer type."""
956
+ type_lower = layer_type.lower()
957
+ if 'conv' in type_lower:
958
+ return 'convolution'
959
+ if 'linear' in type_lower:
960
+ return 'linear'
961
+ if 'norm' in type_lower:
962
+ return 'normalization'
963
+ if 'attention' in type_lower:
964
+ return 'attention'
965
+ if 'embed' in type_lower:
966
+ return 'embedding'
967
+ if 'lstm' in type_lower or 'gru' in type_lower or 'rnn' in type_lower:
968
+ return 'recurrent'
969
+ return 'other'
970
+
971
+
972
+ # =============================================================================
973
+ # Saved Models API Endpoints
974
+ # =============================================================================
975
+
976
class SaveModelRequest(BaseModel):
    """Request to save a model.

    Field names use camelCase to match the JSON payload sent by the frontend.
    """
    name: str                       # display name for the saved model
    framework: str                  # source framework, e.g. 'pytorch', 'onnx'
    totalParameters: int            # total parameter count reported by analysis
    layerCount: int                 # number of layers in the parsed architecture
    architecture: dict              # full architecture graph as produced by the analyzers
    fileHash: Optional[str] = None  # optional content hash, e.g. for de-duplication
984
+
985
+
986
class SavedModelSummary(BaseModel):
    """Summary of a saved model.

    Lightweight listing record (snake_case fields mirror the storage layer);
    the full architecture is fetched separately by id.
    """
    id: int
    name: str
    framework: str
    total_parameters: int
    layer_count: int
    created_at: str  # creation timestamp serialized as a string
994
+
995
+
996
@app.get("/models/saved")
async def list_saved_models():
    """Return metadata for every saved model (no architecture payloads).

    Database failures are logged and surfaced as HTTP 500.
    """
    try:
        saved = db.get_saved_models()
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
    # Building the envelope cannot fail, so it lives outside the try block.
    return {
        "success": True,
        "models": saved
    }
1011
+
1012
+
1013
@app.get("/models/saved/{model_id}")
async def get_saved_model(model_id: int):
    """
    Get a saved model by ID with full architecture.

    Responds 404 when the id is unknown and 500 on storage errors.
    """
    try:
        model = db.get_model_by_id(model_id)
        if not model:
            raise HTTPException(status_code=404, detail="Model not found")

        return {
            "success": True,
            "model": model
        }
    except HTTPException:
        # Re-raise our own 404 untouched so it is not wrapped as a 500 below.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
1032
+
1033
+
1034
@app.post("/models/save")
async def save_model(request: SaveModelRequest):
    """Persist a processed model in the database and return its new id.

    Storage failures are logged and surfaced as HTTP 500.
    """
    try:
        # camelCase request fields map onto the snake_case storage API.
        new_id = db.save_model(
            name=request.name,
            framework=request.framework,
            total_parameters=request.totalParameters,
            layer_count=request.layerCount,
            architecture=request.architecture,
            file_hash=request.fileHash,
        )
        return {"success": True, "id": new_id, "message": "Model saved successfully"}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
1057
+
1058
+
1059
@app.delete("/models/saved/{model_id}")
async def delete_saved_model(model_id: int):
    """
    Delete a saved model by ID.

    Responds 404 when the id is unknown and 500 on storage errors.
    """
    try:
        # db.delete_model returns a falsy value when nothing was deleted.
        deleted = db.delete_model(model_id)
        if not deleted:
            raise HTTPException(status_code=404, detail="Model not found")

        return {
            "success": True,
            "message": "Model deleted successfully"
        }
    except HTTPException:
        # Re-raise our own 404 untouched so it is not wrapped as a 500 below.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
1078
+
1079
+
1080
if __name__ == "__main__":
    # Dev entry point: serve the API with uvicorn on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
backend/app/model_analyzer.py ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Model Analyzer
3
+ Extracts architecture information from PyTorch models for 3D visualization.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Dict, List, Any, Optional, Tuple
9
+ from dataclasses import dataclass, asdict
10
+ from collections import OrderedDict
11
+ import json
12
+
13
+
14
@dataclass
class LayerInfo:
    """Information about a single layer in the model."""
    id: str                            # stable identifier, e.g. "layer_3"
    name: str                          # dotted module path within the model
    type: str                          # layer class name, e.g. "Conv2d"
    category: str                      # coarse grouping (convolution, pooling, ...)
    input_shape: Optional[List[int]]   # None until filled by shape tracing
    output_shape: Optional[List[int]]  # None until filled by shape tracing
    params: Dict[str, Any]             # layer hyperparameters (kernel_size, stride, ...)
    num_parameters: int                # total parameter-element count of this layer
    trainable: bool                    # True when any parameter requires grad
26
+
27
+
28
@dataclass
class ConnectionInfo:
    """Information about connections between layers."""
    source: str                        # id of the upstream layer
    target: str                        # id of the downstream layer
    tensor_shape: Optional[List[int]]  # shape of the tensor on this edge, when known
34
+
35
+
36
@dataclass
class ModelArchitecture:
    """Complete model architecture information."""
    name: str                          # model display name
    framework: str                     # source framework, e.g. "pytorch"
    total_parameters: int              # sum of element counts over all parameters
    trainable_parameters: int          # same sum restricted to requires_grad params
    layers: List[LayerInfo]            # flattened layer list in traversal order
    connections: List[ConnectionInfo]  # directed edges between layer ids
    input_shape: Optional[List[int]]   # model-level input shape, when traced
    output_shape: Optional[List[int]]  # model-level output shape, when traced
47
+
48
+
49
# Layer category mapping.
# Maps a PyTorch layer class name to the coarse category used for
# grouping/coloring; get_layer_category() falls back to 'other' for
# anything not listed here.
LAYER_CATEGORIES = {
    # Convolution layers
    'Conv1d': 'convolution',
    'Conv2d': 'convolution',
    'Conv3d': 'convolution',
    'ConvTranspose1d': 'convolution',
    'ConvTranspose2d': 'convolution',
    'ConvTranspose3d': 'convolution',

    # Pooling layers
    'MaxPool1d': 'pooling',
    'MaxPool2d': 'pooling',
    'MaxPool3d': 'pooling',
    'AvgPool1d': 'pooling',
    'AvgPool2d': 'pooling',
    'AvgPool3d': 'pooling',
    'AdaptiveAvgPool1d': 'pooling',
    'AdaptiveAvgPool2d': 'pooling',
    'AdaptiveAvgPool3d': 'pooling',
    'AdaptiveMaxPool1d': 'pooling',
    'AdaptiveMaxPool2d': 'pooling',
    'AdaptiveMaxPool3d': 'pooling',
    'GlobalAveragePooling2D': 'pooling',

    # Linear/Dense layers
    'Linear': 'linear',
    'LazyLinear': 'linear',
    'Bilinear': 'linear',

    # Normalization layers
    'BatchNorm1d': 'normalization',
    'BatchNorm2d': 'normalization',
    'BatchNorm3d': 'normalization',
    'LayerNorm': 'normalization',
    'GroupNorm': 'normalization',
    'InstanceNorm1d': 'normalization',
    'InstanceNorm2d': 'normalization',
    'InstanceNorm3d': 'normalization',

    # Activation layers
    'ReLU': 'activation',
    'ReLU6': 'activation',
    'LeakyReLU': 'activation',
    'PReLU': 'activation',
    'ELU': 'activation',
    'SELU': 'activation',
    'GELU': 'activation',
    'Sigmoid': 'activation',
    'Tanh': 'activation',
    'Softmax': 'activation',
    'LogSoftmax': 'activation',
    'Softplus': 'activation',
    'Softsign': 'activation',
    'Hardswish': 'activation',
    'Hardsigmoid': 'activation',
    'SiLU': 'activation',
    'Mish': 'activation',

    # Dropout layers
    'Dropout': 'regularization',
    'Dropout2d': 'regularization',
    'Dropout3d': 'regularization',
    'AlphaDropout': 'regularization',

    # Recurrent layers
    'RNN': 'recurrent',
    'LSTM': 'recurrent',
    'GRU': 'recurrent',
    'RNNCell': 'recurrent',
    'LSTMCell': 'recurrent',
    'GRUCell': 'recurrent',

    # Transformer layers
    'Transformer': 'attention',
    'TransformerEncoder': 'attention',
    'TransformerDecoder': 'attention',
    'TransformerEncoderLayer': 'attention',
    'TransformerDecoderLayer': 'attention',
    'MultiheadAttention': 'attention',

    # Embedding layers
    'Embedding': 'embedding',
    'EmbeddingBag': 'embedding',

    # Reshape/View layers
    'Flatten': 'reshape',
    'Unflatten': 'reshape',

    # Container layers (flattened away during analysis)
    'Sequential': 'container',
    'ModuleList': 'container',
    'ModuleDict': 'container',
}
143
+
144
+
145
def get_layer_category(layer_type: str) -> str:
    """Look up the display category for a layer class name ('other' if unmapped)."""
    if layer_type in LAYER_CATEGORIES:
        return LAYER_CATEGORIES[layer_type]
    return 'other'
148
+
149
+
150
def count_parameters(module: nn.Module) -> Tuple[int, int]:
    """Return (total, trainable) parameter-element counts for a module.

    Recurses into submodules via module.parameters(); 'trainable' counts
    only parameters with requires_grad set.
    """
    total = 0
    trainable = 0
    # Single pass over all parameters instead of two generator expressions.
    for param in module.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return total, trainable
155
+
156
+
157
def extract_layer_params(module: nn.Module, layer_type: str) -> Dict[str, Any]:
    """Collect the interesting hyperparameters of a layer into a plain dict.

    Only attributes that actually exist on the module are recorded; tuple
    values (e.g. kernel_size) are converted to lists so the result stays
    JSON-serializable. Any unexpected error while probing attributes is
    swallowed and whatever was gathered so far is returned.
    """
    def normalized(value):
        # Tuples become lists for JSON; everything else passes through.
        return list(value) if isinstance(value, tuple) else value

    collected: Dict[str, Any] = {}
    try:
        # Plain scalar attributes, copied verbatim.
        for attr in ('in_features', 'out_features', 'in_channels', 'out_channels'):
            if hasattr(module, attr):
                collected[attr] = getattr(module, attr)

        # Geometry attributes that may be tuples.
        for attr in ('kernel_size', 'stride', 'padding', 'dilation'):
            if hasattr(module, attr):
                collected[attr] = normalized(getattr(module, attr))

        if hasattr(module, 'groups'):
            collected['groups'] = module.groups

        # bias is recorded as a presence flag, never the tensor itself.
        if getattr(module, 'bias', None) is not None:
            collected['bias'] = True

        for attr in ('num_features', 'eps'):
            if hasattr(module, attr):
                collected[attr] = getattr(module, attr)

        if getattr(module, 'momentum', None) is not None:
            collected['momentum'] = module.momentum

        if hasattr(module, 'normalized_shape'):
            collected['normalized_shape'] = list(module.normalized_shape)

        for attr in ('hidden_size', 'num_layers', 'bidirectional', 'num_heads',
                     'embed_dim', 'num_embeddings', 'embedding_dim'):
            if hasattr(module, attr):
                collected[attr] = getattr(module, attr)

        # 'p' is ambiguous outside dropout layers, so gate it on the type name.
        if layer_type.startswith('Dropout') and hasattr(module, 'p'):
            collected['p'] = module.p

        for attr in ('negative_slope', 'inplace', 'dim'):
            if hasattr(module, attr):
                collected[attr] = getattr(module, attr)
    except Exception:
        # Best effort: keep the partial result gathered before the failure.
        pass

    return collected
220
+
221
+
222
def analyze_model_structure(model: nn.Module, model_name: str = "model") -> ModelArchitecture:
    """
    Analyze a PyTorch model and extract its architecture.

    Walks the module tree recursively:
    - container modules (Sequential/ModuleList/ModuleDict) are flattened
      away and their children processed in their place;
    - unrecognized wrapper modules with children are likewise skipped in
      favor of their children;
    - every remaining module becomes one LayerInfo, connected to the layer
      it was reached from.

    Args:
        model: The PyTorch model to analyze
        model_name: Name identifier for the model

    Returns:
        ModelArchitecture with complete layer and connection information.
        Shapes are left as None; use trace_model_shapes to fill them in.
    """
    layers = []
    connections = []
    layer_index = 0
    # (removed unused 'parent_stack' local that was never read or written)

    def process_module(name: str, module: nn.Module, parent_id: Optional[str] = None):
        nonlocal layer_index

        layer_type = module.__class__.__name__

        # Skip container modules but process their children.
        if layer_type in ('Sequential', 'ModuleList', 'ModuleDict'):
            for child_name, child in module.named_children():
                full_name = f"{name}.{child_name}" if name else child_name
                process_module(full_name, child, parent_id)
            return

        # Skip modules with no own parameters and no recognized operation,
        # but keep known activation/pooling/dropout layers even though they
        # are parameter-free.
        has_params = sum(1 for _ in module.parameters(recurse=False)) > 0
        is_meaningful = layer_type in LAYER_CATEGORIES or has_params

        if not is_meaningful and len(list(module.children())) > 0:
            # Process children of non-meaningful wrapper modules.
            for child_name, child in module.named_children():
                full_name = f"{name}.{child_name}" if name else child_name
                process_module(full_name, child, parent_id)
            return

        layer_id = f"layer_{layer_index}"
        layer_index += 1

        total_params, trainable_params = count_parameters(module)
        params = extract_layer_params(module, layer_type)

        layer_info = LayerInfo(
            id=layer_id,
            name=name or layer_type,
            type=layer_type,
            category=get_layer_category(layer_type),
            input_shape=None,  # populated later by trace_model_shapes
            output_shape=None,
            params=params,
            num_parameters=total_params,
            trainable=trainable_params > 0
        )

        layers.append(layer_info)

        # Connect this layer to the layer it was reached from.
        if parent_id is not None:
            connections.append(ConnectionInfo(
                source=parent_id,
                target=layer_id,
                tensor_shape=None
            ))

        # Recurse into children, parenting them to this layer.
        children = list(module.named_children())
        if children:
            for child_name, child in children:
                full_name = f"{name}.{child_name}" if name else child_name
                process_module(full_name, child, layer_id)

        return layer_id

    # Process the model's top-level children, chaining them sequentially.
    # NOTE: the container branches of process_module return None, so the
    # chain only advances past modules that actually produced a layer.
    children = list(model.named_children())
    if children:
        prev_id = None
        for name, child in children:
            layer_id = process_module(name, child, prev_id)
            if layer_id:
                prev_id = layer_id
    else:
        # Single-layer model: analyze the model itself.
        process_module("", model, None)

    # If traversal produced no edges at all, assume a linear topology.
    if len(layers) > 1 and len(connections) == 0:
        for i in range(len(layers) - 1):
            connections.append(ConnectionInfo(
                source=layers[i].id,
                target=layers[i + 1].id,
                tensor_shape=None
            ))

    total_params, trainable_params = count_parameters(model)

    return ModelArchitecture(
        name=model_name,
        framework="pytorch",
        total_parameters=total_params,
        trainable_parameters=trainable_params,
        layers=layers,
        connections=connections,
        input_shape=None,
        output_shape=None
    )
332
+
333
+
334
def trace_model_shapes(model: nn.Module, input_tensor: torch.Tensor, arch: ModelArchitecture) -> ModelArchitecture:
    """Run one forward pass and record per-layer input/output shapes.

    Forward hooks are attached to every named submodule; after the pass the
    captured shapes are written back onto ``arch``, matched by layer name.
    The model is put in eval mode and no gradients are computed.

    Args:
        model: The PyTorch model.
        input_tensor: Sample input tensor used to drive the trace.
        arch: Existing architecture info to update in place.

    Returns:
        The same ``arch`` object, with shape information filled in where
        the trace succeeded.
    """
    captured = {}
    handles = []

    def first_tensor_shape(value):
        # Hooks may receive a bare tensor or a tuple; use the first tensor
        # found, or None when there is none.
        if isinstance(value, torch.Tensor):
            return list(value.shape)
        if isinstance(value, tuple) and value and isinstance(value[0], torch.Tensor):
            return list(value[0].shape)
        return None

    def make_recorder(module_name):
        def recorder(module, inputs, outputs):
            captured[module_name] = {
                'input': first_tensor_shape(inputs),
                'output': first_tensor_shape(outputs),
            }
        return recorder

    # Attach a hook to every named submodule (the root has an empty name).
    for module_name, submodule in model.named_modules():
        if module_name:
            handles.append(submodule.register_forward_hook(make_recorder(module_name)))

    try:
        model.eval()
        with torch.no_grad():
            result = model(input_tensor)

        # Copy captured shapes onto matching layers.
        for layer in arch.layers:
            record = captured.get(layer.name)
            if record is not None:
                layer.input_shape = record['input']
                layer.output_shape = record['output']

        # Whole-model input/output shapes.
        arch.input_shape = list(input_tensor.shape)
        if isinstance(result, torch.Tensor):
            arch.output_shape = list(result.shape)

    except Exception as e:
        print(f"Warning: Could not trace shapes: {e}")
    finally:
        # Always detach the hooks, even when the forward pass failed.
        for handle in handles:
            handle.remove()

    return arch
402
+
403
+
404
def load_pytorch_model(file_path: str) -> Tuple[Optional[nn.Module], Optional[Dict], str]:
    """Load a PyTorch model file, accepting several on-disk formats.

    Returns:
        Tuple of (model, state_dict, model_type) where model_type is one of
        'full_model', 'state_dict', 'torchscript', 'checkpoint', 'unknown'.

    Raises:
        ValueError: If the file cannot be loaded at all.
    """
    try:
        # TorchScript archives have a dedicated loader; probe for that first.
        try:
            return torch.jit.load(file_path, map_location='cpu'), None, 'torchscript'
        except Exception:
            pass

        # NOTE(review): weights_only=False unpickles arbitrary objects --
        # only load files from trusted sources.
        checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)

        if isinstance(checkpoint, nn.Module):
            return checkpoint, None, 'full_model'

        if isinstance(checkpoint, dict):
            # Common checkpoint layouts: {'model': Module|state_dict, ...}
            wrapped = checkpoint.get('model')
            if isinstance(wrapped, nn.Module):
                return wrapped, None, 'checkpoint'
            if isinstance(wrapped, dict):
                return None, wrapped, 'state_dict'

            for key in ('state_dict', 'model_state_dict'):
                if key in checkpoint:
                    return None, checkpoint[key], 'state_dict'

            # A flat mapping that contains tensors is itself a state dict.
            if any(isinstance(v, torch.Tensor) for v in checkpoint.values()):
                return None, checkpoint, 'state_dict'

        return None, None, 'unknown'

    except Exception as e:
        raise ValueError(f"Failed to load model: {str(e)}")
449
+
450
+
451
def analyze_state_dict(state_dict: Dict[str, torch.Tensor], model_name: str = "model") -> ModelArchitecture:
    """Infer a model architecture from a bare state dict.

    Groups parameter tensors by their owning module prefix, guesses each
    module's layer type from naming conventions and tensor shapes, and
    links the resulting layers sequentially in file order.

    Args:
        state_dict: Mapping of parameter name -> tensor.
        model_name: Name to record on the resulting architecture.

    Returns:
        A ModelArchitecture inferred entirely from the weights.
    """
    # Bucket every tensor under its module prefix ("conv1.weight" -> "conv1").
    grouped = OrderedDict()
    for key, tensor in state_dict.items():
        if not isinstance(tensor, torch.Tensor):
            continue

        prefix, _, suffix = key.rpartition('.')
        if not prefix:
            # No dot: treat the whole key as the layer name.
            prefix, suffix = key, 'weight'

        entry = grouped.setdefault(prefix, {'params': {}, 'shapes': {}})
        entry['params'][suffix] = True
        entry['shapes'][suffix] = list(tensor.shape)

    # Build one LayerInfo per parameter group.
    layers = []
    for index, (layer_name, info) in enumerate(grouped.items()):
        layer_type, category = infer_layer_type(layer_name, info['shapes'])
        params = extract_params_from_shapes(layer_type, info['shapes'])
        input_shape, output_shape = infer_shapes(layer_type, info['shapes'], params)

        # Total element count across all of this layer's tensors.
        param_count = 0
        for shape in info['shapes'].values():
            size = 1
            for dim in shape:
                size *= dim
            param_count += size

        layers.append(LayerInfo(
            id=f"layer_{index}",
            name=layer_name,
            type=layer_type,
            category=category,
            input_shape=input_shape,
            output_shape=output_shape,
            params=params,
            num_parameters=param_count,
            trainable=True
        ))

    # Connect consecutive layers; a bare state dict carries no real topology.
    connections = [
        ConnectionInfo(
            source=prev.id,
            target=nxt.id,
            tensor_shape=prev.output_shape
        )
        for prev, nxt in zip(layers, layers[1:])
    ]

    total = sum(layer.num_parameters for layer in layers)

    return ModelArchitecture(
        name=model_name,
        framework="pytorch",
        total_parameters=total,
        trainable_parameters=total,
        layers=layers,
        connections=connections,
        input_shape=layers[0].input_shape if layers else None,
        output_shape=layers[-1].output_shape if layers else None
    )
537
+
538
+
539
def infer_layer_type(layer_name: str, shapes: Dict[str, List[int]]) -> Tuple[str, str]:
    """Infer a layer's type and category from its name and weight shapes.

    Name-based heuristics are tried first (conv/norm/fc/rnn/...); when no
    name pattern matches, the rank of the 'weight' tensor decides.

    Args:
        layer_name: Dotted module path taken from a state dict key.
        shapes: Mapping of parameter name -> tensor shape for this layer.

    Returns:
        Tuple of (layer_type, category), e.g. ('Conv2d', 'convolution');
        ('Unknown', 'other') when nothing matches.
    """
    name_lower = layer_name.lower()

    # Check for common layer type patterns in the name.
    if 'conv' in name_lower:
        # The weight rank distinguishes 1d/2d/3d convolutions.
        weight_shape = shapes.get('weight', [])
        if len(weight_shape) == 5:
            return 'Conv3d', 'convolution'
        elif len(weight_shape) == 4:
            return 'Conv2d', 'convolution'
        elif len(weight_shape) == 3:
            return 'Conv1d', 'convolution'
        return 'Conv2d', 'convolution'

    if 'bn' in name_lower or 'batch' in name_lower or 'norm' in name_lower:
        # (Fixed: a weight-shape lookup previously computed here was dead code.)
        if 'layer' in name_lower:
            return 'LayerNorm', 'normalization'
        return 'BatchNorm2d', 'normalization'

    if 'fc' in name_lower or 'linear' in name_lower or 'dense' in name_lower or 'classifier' in name_lower:
        return 'Linear', 'linear'

    if 'lstm' in name_lower:
        return 'LSTM', 'recurrent'

    if 'gru' in name_lower:
        return 'GRU', 'recurrent'

    if 'rnn' in name_lower:
        return 'RNN', 'recurrent'

    if 'attention' in name_lower or 'attn' in name_lower:
        return 'MultiheadAttention', 'attention'

    if 'embed' in name_lower:
        return 'Embedding', 'embedding'

    if 'pool' in name_lower:
        return 'AdaptiveAvgPool2d', 'pooling'

    # Fall back to the weight tensor rank when the name tells us nothing.
    weight_shape = shapes.get('weight', [])
    if len(weight_shape) == 2:
        return 'Linear', 'linear'
    elif len(weight_shape) == 4:
        return 'Conv2d', 'convolution'
    elif len(weight_shape) == 3:
        return 'Conv1d', 'convolution'
    elif len(weight_shape) == 1:
        return 'BatchNorm2d', 'normalization'

    return 'Unknown', 'other'
593
+
594
+
595
def extract_params_from_shapes(layer_type: str, shapes: Dict[str, List[int]]) -> Dict[str, Any]:
    """Derive constructor-style hyperparameters from raw weight shapes.

    Args:
        layer_type: Inferred layer class name (e.g. 'Linear', 'Conv2d').
        shapes: Mapping of parameter name -> tensor shape.

    Returns:
        Dict of layer parameters; empty when nothing can be derived.
    """
    derived: Dict[str, Any] = {}
    weight = shapes.get('weight', [])

    if layer_type == 'Linear':
        # Linear weight is (out_features, in_features).
        if len(weight) >= 2:
            derived['out_features'] = weight[0]
            derived['in_features'] = weight[1]
            derived['bias'] = 'bias' in shapes

    elif layer_type in ('Conv1d', 'Conv2d', 'Conv3d'):
        # Conv weight is (out_channels, in_channels, *kernel_size).
        if len(weight) >= 2:
            derived['out_channels'] = weight[0]
            derived['in_channels'] = weight[1]
            if len(weight) > 2:
                derived['kernel_size'] = weight[2:]
            derived['bias'] = 'bias' in shapes

    elif layer_type in ('BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'):
        if weight:
            derived['num_features'] = weight[0]

    elif layer_type == 'LayerNorm':
        if weight:
            derived['normalized_shape'] = weight

    elif layer_type == 'Embedding':
        # Embedding weight is (num_embeddings, embedding_dim).
        if len(weight) >= 2:
            derived['num_embeddings'] = weight[0]
            derived['embedding_dim'] = weight[1]

    elif layer_type in ('LSTM', 'GRU', 'RNN'):
        # weight_ih_l0 is (gates * hidden_size, input_size):
        # 4 gates for LSTM, 3 for GRU, 1 for vanilla RNN.
        ih = shapes.get('weight_ih_l0')
        if ih and len(ih) >= 2:
            gates = {'LSTM': 4, 'GRU': 3}.get(layer_type, 1)
            derived['hidden_size'] = ih[0] // gates
            derived['input_size'] = ih[1]

    return derived
637
+
638
+
639
def infer_shapes(layer_type: str, shapes: Dict[str, List[int]], params: Dict[str, Any]) -> Tuple[Optional[List[int]], Optional[List[int]]]:
    """Build symbolic input/output shapes from layer parameters.

    A -1 marks a dimension that cannot be known statically (batch size,
    spatial extents, sequence length).

    Args:
        layer_type: Inferred layer class name.
        shapes: Raw parameter shapes (unused here; kept for API parity).
        params: Hyperparameters produced by extract_params_from_shapes().

    Returns:
        (input_shape, output_shape); either may be None when unknown.
    """
    inp: Optional[List[int]] = None
    out: Optional[List[int]] = None

    if layer_type == 'Linear':
        if 'in_features' in params:
            inp = [-1, params['in_features']]
        if 'out_features' in params:
            out = [-1, params['out_features']]

    elif layer_type == 'Conv2d':
        if 'in_channels' in params:
            inp = [-1, params['in_channels'], -1, -1]
        if 'out_channels' in params:
            out = [-1, params['out_channels'], -1, -1]

    elif layer_type == 'Conv1d':
        if 'in_channels' in params:
            inp = [-1, params['in_channels'], -1]
        if 'out_channels' in params:
            out = [-1, params['out_channels'], -1]

    elif layer_type == 'BatchNorm2d':
        # BatchNorm preserves its input shape.
        if 'num_features' in params:
            inp = [-1, params['num_features'], -1, -1]
            out = [-1, params['num_features'], -1, -1]

    elif layer_type == 'Embedding':
        # Output is (batch, seq_len, embedding_dim); the token-id input
        # carries no feature dimension, so the input stays unknown.
        if 'embedding_dim' in params:
            out = [-1, -1, params['embedding_dim']]

    elif layer_type in ('GRU', 'LSTM', 'RNN'):
        # (batch, seq_len, input_size) -> (batch, seq_len, hidden * dirs).
        if 'input_size' in params:
            inp = [-1, -1, params['input_size']]
        if 'hidden_size' in params:
            dirs = 2 if params.get('bidirectional', False) else 1
            out = [-1, -1, params['hidden_size'] * dirs]

    return inp, out
681
+
682
+
683
def architecture_to_dict(arch: ModelArchitecture) -> Dict[str, Any]:
    """Serialize a ModelArchitecture into the camelCase JSON shape the
    frontend consumes."""

    def layer_entry(layer):
        # One JSON object per layer, snake_case -> camelCase.
        return {
            'id': layer.id,
            'name': layer.name,
            'type': layer.type,
            'category': layer.category,
            'inputShape': layer.input_shape,
            'outputShape': layer.output_shape,
            'params': layer.params,
            'numParameters': layer.num_parameters,
            'trainable': layer.trainable
        }

    def connection_entry(conn):
        return {
            'source': conn.source,
            'target': conn.target,
            'tensorShape': conn.tensor_shape
        }

    return {
        'name': arch.name,
        'framework': arch.framework,
        'totalParameters': arch.total_parameters,
        'trainableParameters': arch.trainable_parameters,
        'inputShape': arch.input_shape,
        'outputShape': arch.output_shape,
        'layers': [layer_entry(layer) for layer in arch.layers],
        'connections': [connection_entry(conn) for conn in arch.connections]
    }
backend/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.109.0
2
+ uvicorn[standard]==0.27.0
3
+ python-multipart==0.0.6
4
+ torch>=2.0.0
5
+ torchvision>=0.15.0
6
+ onnx>=1.14.0
7
+ numpy>=1.24.0
8
+ pydantic>=2.0.0
9
+ h5py>=3.10.0
10
+ safetensors>=0.4.0
backend/start.bat ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
REM Bootstrap and run the NN3D Visualizer backend (FastAPI/uvicorn) on Windows.
REM Creates a local virtualenv on first run, installs requirements, then
REM starts the dev server with auto-reload on port 8000.
echo Starting NN3D Visualizer Backend...
echo.

REM Check if virtual environment exists
if not exist "venv" (
    echo Creating virtual environment...
    python -m venv venv
)

REM Activate virtual environment
call venv\Scripts\activate.bat

REM Install dependencies
echo Installing dependencies...
pip install -r requirements.txt --quiet

REM Start the server
echo.
echo Starting FastAPI server on http://localhost:8000
echo API docs available at http://localhost:8000/docs
echo.
python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
backend/start.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Bootstrap and run the NN3D Visualizer backend (FastAPI/uvicorn).
# Creates a local virtualenv on first run, installs requirements, then
# starts the dev server with auto-reload on port 8000.

# Abort on the first failing command so we never launch the server on top
# of a half-created venv or a failed dependency install.
set -e

echo "Starting NN3D Visualizer Backend..."
echo

# Check if virtual environment exists
if [ ! -d "venv" ]; then
    echo "Creating virtual environment..."
    python3 -m venv venv
fi

# Activate virtual environment
source venv/bin/activate

# Install dependencies
echo "Installing dependencies..."
pip install -r requirements.txt --quiet

# Start the server
echo
echo "Starting FastAPI server on http://localhost:8000"
echo "API docs available at http://localhost:8000/docs"
echo
python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
docker-compose.yml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Docker Compose for NN3D Visualizer
# Run: docker-compose up --build

# NOTE(review): the top-level `version` key is informational-only under
# Compose v2 and may print a deprecation warning -- confirm before removing.
version: '3.8'

services:
  # FastAPI Backend - Model Analysis
  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    container_name: nn3d-backend
    restart: unless-stopped
    volumes:
      # Persist database
      - nn3d-data:/app/data
      # Mount models directory for local model files (optional)
      - ./samples:/app/samples:ro
    environment:
      - DATABASE_PATH=/app/data/models.db
      - PYTHONUNBUFFERED=1
    ports:
      - "8000:8000"
    networks:
      - nn3d-network
    # Healthcheck uses the Python stdlib so the image needs no curl/wget.
    healthcheck:
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 10s

  # React Frontend - 3D Visualization
  frontend:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: nn3d-frontend
    restart: unless-stopped
    ports:
      - "3000:80"
    # Wait for the backend healthcheck to pass before serving the UI.
    depends_on:
      backend:
        condition: service_healthy
    networks:
      - nn3d-network
    healthcheck:
      test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:80/"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 5s

networks:
  nn3d-network:
    driver: bridge

volumes:
  nn3d-data:
    driver: local
docker-start.bat ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
REM NN3D Visualizer - Docker Startup Script (Windows)
REM Builds and launches the frontend + backend containers in detached mode.

echo.
echo ========================================
echo NN3D VISUALIZER - DOCKER SETUP
echo ========================================
echo.

REM Check if Docker is running
docker info >nul 2>&1
if errorlevel 1 (
    echo [ERROR] Docker is not running. Please start Docker Desktop.
    pause
    exit /b 1
)

echo [*] Building and starting containers...
echo.

docker-compose up --build -d

REM errorlevel here reflects the docker-compose invocation above.
if errorlevel 1 (
    echo [ERROR] Failed to start containers.
    pause
    exit /b 1
)

echo.
echo ========================================
echo CONTAINERS STARTED SUCCESSFULLY
echo ========================================
echo.
echo Frontend: http://localhost:3000
echo Backend: http://localhost:8000
echo API Docs: http://localhost:8000/docs
echo.
echo Run 'docker-compose logs -f' to view logs
echo Run 'docker-compose down' to stop
echo ========================================
echo.

pause
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# NN3D Visualizer - Docker Startup Script (Linux/Mac)
# Builds and launches the frontend + backend containers in detached mode.

echo ""
echo "========================================"
echo " NN3D VISUALIZER - DOCKER SETUP"
echo "========================================"
echo ""

# Bail out early when the Docker daemon is not reachable.
if ! docker info > /dev/null 2>&1; then
    echo "[ERROR] Docker is not running. Please start Docker."
    exit 1
fi

echo "[*] Building and starting containers..."
echo ""

# Build and start; the command's own exit status drives the error branch.
if ! docker-compose up --build -d; then
    echo "[ERROR] Failed to start containers."
    exit 1
fi

echo ""
echo "========================================"
echo " CONTAINERS STARTED SUCCESSFULLY"
echo "========================================"
echo ""
echo " Frontend: http://localhost:3000"
echo " Backend: http://localhost:8000"
echo " API Docs: http://localhost:8000/docs"
echo ""
echo " Run 'docker-compose logs -f' to view logs"
echo " Run 'docker-compose down' to stop"
echo "========================================"
echo ""
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NN3D Exporter
2
+
3
+ Python library to export neural network models to `.nn3d` format for 3D visualization.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ # Basic installation
9
+ pip install nn3d-exporter
10
+
11
+ # With PyTorch support
12
+ pip install nn3d-exporter[pytorch]
13
+
14
+ # With ONNX support
15
+ pip install nn3d-exporter[onnx]
16
+
17
+ # With all frameworks
18
+ pip install nn3d-exporter[all]
19
+ ```
20
+
21
+ ## Quick Start
22
+
23
+ ### Export PyTorch Model
24
+
25
+ ```python
26
+ import torch.nn as nn
27
+ from nn3d_exporter import export_pytorch_model
28
+
29
+ # Define your model
30
+ class SimpleCNN(nn.Module):
31
+ def __init__(self):
32
+ super().__init__()
33
+ self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
34
+ self.bn1 = nn.BatchNorm2d(64)
35
+ self.relu = nn.ReLU()
36
+ self.pool = nn.MaxPool2d(2)
37
+ self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
38
+ self.bn2 = nn.BatchNorm2d(128)
39
+ self.fc = nn.Linear(128 * 56 * 56, 10)
40
+
41
+ def forward(self, x):
42
+ x = self.pool(self.relu(self.bn1(self.conv1(x))))
43
+ x = self.pool(self.relu(self.bn2(self.conv2(x))))
44
+ x = x.view(x.size(0), -1)
45
+ return self.fc(x)
46
+
47
+ model = SimpleCNN()
48
+
49
+ # Export to .nn3d format
50
+ export_pytorch_model(
51
+ model,
52
+ output_path="simple_cnn.nn3d",
53
+ input_shape=(1, 3, 224, 224),
54
+ model_name="Simple CNN"
55
+ )
56
+ ```
57
+
58
+ ### Export ONNX Model
59
+
60
+ ```python
61
+ from nn3d_exporter import export_onnx_model
62
+
63
+ # Export an existing ONNX model
64
+ export_onnx_model(
65
+ model_path="resnet50.onnx",
66
+ output_path="resnet50.nn3d",
67
+ model_name="ResNet-50"
68
+ )
69
+ ```
70
+
71
+ ### Using the Exporter Classes
72
+
73
+ For more control, use the exporter classes directly:
74
+
75
+ ```python
76
+ from nn3d_exporter import PyTorchExporter
77
+
78
+ exporter = PyTorchExporter(
79
+ model=model,
80
+ input_shape=(1, 3, 224, 224),
81
+ model_name="My Model"
82
+ )
83
+
84
+ # Export to NN3DModel object
85
+ nn3d_model = exporter.export()
86
+
87
+ # Customize before saving
88
+ nn3d_model.visualization.theme = "blueprint"
89
+ nn3d_model.visualization.layout = "force"
90
+
91
+ # Save to file
92
+ nn3d_model.save("my_model.nn3d")
93
+
94
+ # Or get JSON string
95
+ json_str = nn3d_model.to_json()
96
+ ```
97
+
98
+ ## Supported Frameworks
99
+
100
+ ### PyTorch
101
+
102
+ - All standard `torch.nn` layers
103
+ - Custom modules (exported as "custom" type)
104
+ - Automatic shape inference via forward pass
105
+ - Parameter counting
106
+
107
+ ### ONNX
108
+
109
+ - Standard ONNX operators
110
+ - Shape extraction from model metadata
111
+ - Operator attributes preserved
112
+
113
+ ## API Reference
114
+
115
+ ### `export_pytorch_model(model, output_path, input_shape=None, model_name=None)`
116
+
117
+ Export a PyTorch model to .nn3d file.
118
+
119
+ | Parameter | Type | Description |
120
+ | ------------- | ----------- | ----------------------------- |
121
+ | `model` | `nn.Module` | PyTorch model to export |
122
+ | `output_path` | `str` | Path for output .nn3d file |
123
+ | `input_shape` | `tuple` | Input tensor shape (optional) |
124
+ | `model_name` | `str` | Model name (optional) |
125
+
126
+ ### `export_onnx_model(model_path, output_path, model_name=None)`
127
+
128
+ Export an ONNX model to .nn3d file.
129
+
130
+ | Parameter | Type | Description |
131
+ | ------------- | ----- | -------------------------- |
132
+ | `model_path` | `str` | Path to ONNX model file |
133
+ | `output_path` | `str` | Path for output .nn3d file |
134
+ | `model_name` | `str` | Model name (optional) |
135
+
136
+ ## License
137
+
138
+ MIT License
exporters/python/nn3d_exporter/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NN3D Exporter - Export neural network models to .nn3d format
3
+
4
+ This package provides utilities to export models from various
5
+ deep learning frameworks to the .nn3d visualization format.
6
+
7
+ Supported frameworks:
8
+ - PyTorch
9
+ - ONNX
10
+ - TensorFlow/Keras (planned)
11
+ """
12
+
13
+ from .pytorch_exporter import PyTorchExporter, export_pytorch_model
14
+ from .onnx_exporter import ONNXExporter, export_onnx_model
15
+ from .schema import NN3DModel, NN3DNode, NN3DEdge, NN3DGraph, NN3DMetadata
16
+
17
+ __version__ = "1.0.0"
18
+ __all__ = [
19
+ "PyTorchExporter",
20
+ "export_pytorch_model",
21
+ "ONNXExporter",
22
+ "export_onnx_model",
23
+ "NN3DModel",
24
+ "NN3DNode",
25
+ "NN3DEdge",
26
+ "NN3DGraph",
27
+ "NN3DMetadata",
28
+ ]
exporters/python/nn3d_exporter/onnx_exporter.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ONNX Model Exporter
3
+
4
+ Export ONNX models to .nn3d format for visualization.
5
+ """
6
+
7
+ from typing import Dict, List, Optional, Tuple, Any, Union
8
+ from datetime import datetime
9
+ import json
10
+
11
+ try:
12
+ import onnx
13
+ from onnx import numpy_helper
14
+ ONNX_AVAILABLE = True
15
+ except ImportError:
16
+ ONNX_AVAILABLE = False
17
+
18
+ from .schema import (
19
+ NN3DModel, NN3DGraph, NN3DNode, NN3DEdge, NN3DMetadata,
20
+ NN3DSubgraph, LayerParams, LayerType, VisualizationConfig
21
+ )
22
+
23
+
24
# Mapping from ONNX op types to NN3D layer types.
# Ops not listed here fall back to LayerType.CUSTOM (see _get_layer_type).
ONNX_TO_NN3D_TYPE: Dict[str, str] = {
    # Convolution
    'Conv': LayerType.CONV2D.value,
    'ConvTranspose': LayerType.CONV_TRANSPOSE_2D.value,

    # Linear (Gemm and MatMul are both rendered as dense layers)
    'Gemm': LayerType.LINEAR.value,
    'MatMul': LayerType.LINEAR.value,

    # Normalization
    'BatchNormalization': LayerType.BATCH_NORM_2D.value,
    'LayerNormalization': LayerType.LAYER_NORM.value,
    'InstanceNormalization': LayerType.INSTANCE_NORM.value,
    'GroupNormalization': LayerType.GROUP_NORM.value,
    'Dropout': LayerType.DROPOUT.value,

    # Activations
    'Relu': LayerType.RELU.value,
    'LeakyRelu': LayerType.LEAKY_RELU.value,
    'Sigmoid': LayerType.SIGMOID.value,
    'Tanh': LayerType.TANH.value,
    'Softmax': LayerType.SOFTMAX.value,
    'Gelu': LayerType.GELU.value,

    # Pooling (no dedicated global-max type; shown as MaxPool2d)
    'MaxPool': LayerType.MAX_POOL_2D.value,
    'AveragePool': LayerType.AVG_POOL_2D.value,
    'GlobalAveragePool': LayerType.GLOBAL_AVG_POOL.value,
    'GlobalMaxPool': LayerType.MAX_POOL_2D.value,

    # Shape operations (Transpose/Squeeze/Unsqueeze are approximated as
    # generic reshape ops for visualization purposes)
    'Flatten': LayerType.FLATTEN.value,
    'Reshape': LayerType.RESHAPE.value,
    'Transpose': LayerType.RESHAPE.value,
    'Squeeze': LayerType.RESHAPE.value,
    'Unsqueeze': LayerType.RESHAPE.value,

    # Merge operations
    'Concat': LayerType.CONCAT.value,
    'Add': LayerType.ADD.value,
    'Mul': LayerType.MULTIPLY.value,
    'Split': LayerType.SPLIT.value,

    # Attention
    'Attention': LayerType.ATTENTION.value,
    'MultiHeadAttention': LayerType.MULTI_HEAD_ATTENTION.value,

    # Recurrent
    'LSTM': LayerType.LSTM.value,
    'GRU': LayerType.GRU.value,
    'RNN': LayerType.RNN.value,

    # Resize/Upsample
    'Resize': LayerType.UPSAMPLE.value,
    'Upsample': LayerType.UPSAMPLE.value,

    # Padding
    'Pad': LayerType.PAD.value,
}
84
+
85
+
86
+ class ONNXExporter:
87
+ """
88
+ Export ONNX models to NN3D format.
89
+
90
+ Usage:
91
+ exporter = ONNXExporter("model.onnx")
92
+ nn3d_model = exporter.export()
93
+ nn3d_model.save("model.nn3d")
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ model_path: str,
99
+ model_name: Optional[str] = None,
100
+ ):
101
+ """
102
+ Initialize the exporter.
103
+
104
+ Args:
105
+ model_path: Path to ONNX model file
106
+ model_name: Name for the model (defaults to filename)
107
+ """
108
+ if not ONNX_AVAILABLE:
109
+ raise ImportError("ONNX is required. Install with: pip install onnx")
110
+
111
+ self.model_path = model_path
112
+ self.model = onnx.load(model_path)
113
+ self.model_name = model_name or model_path.split('/')[-1].replace('.onnx', '')
114
+
115
+ # Initialize graph
116
+ onnx.checker.check_model(self.model)
117
+ self.graph = self.model.graph
118
+
119
+ self.nodes: List[NN3DNode] = []
120
+ self.edges: List[NN3DEdge] = []
121
+
122
+ self._tensor_shapes: Dict[str, List[int]] = {}
123
+ self._value_info: Dict[str, Any] = {}
124
+ self._node_id_map: Dict[str, str] = {}
125
+ self._output_to_node: Dict[str, str] = {}
126
+
127
+ def _extract_shapes(self) -> None:
128
+ """Extract tensor shapes from the model"""
129
+ # Input shapes
130
+ for input_info in self.graph.input:
131
+ shape = []
132
+ if input_info.type.tensor_type.HasField('shape'):
133
+ for dim in input_info.type.tensor_type.shape.dim:
134
+ if dim.HasField('dim_value'):
135
+ shape.append(dim.dim_value)
136
+ elif dim.HasField('dim_param'):
137
+ shape.append(dim.dim_param)
138
+ else:
139
+ shape.append(-1)
140
+ self._tensor_shapes[input_info.name] = shape
141
+ self._value_info[input_info.name] = input_info
142
+
143
+ # Output shapes
144
+ for output_info in self.graph.output:
145
+ shape = []
146
+ if output_info.type.tensor_type.HasField('shape'):
147
+ for dim in output_info.type.tensor_type.shape.dim:
148
+ if dim.HasField('dim_value'):
149
+ shape.append(dim.dim_value)
150
+ elif dim.HasField('dim_param'):
151
+ shape.append(dim.dim_param)
152
+ else:
153
+ shape.append(-1)
154
+ self._tensor_shapes[output_info.name] = shape
155
+ self._value_info[output_info.name] = output_info
156
+
157
+ # Value info (intermediate tensors)
158
+ for value_info in self.graph.value_info:
159
+ shape = []
160
+ if value_info.type.tensor_type.HasField('shape'):
161
+ for dim in value_info.type.tensor_type.shape.dim:
162
+ if dim.HasField('dim_value'):
163
+ shape.append(dim.dim_value)
164
+ elif dim.HasField('dim_param'):
165
+ shape.append(dim.dim_param)
166
+ else:
167
+ shape.append(-1)
168
+ self._tensor_shapes[value_info.name] = shape
169
+ self._value_info[value_info.name] = value_info
170
+
171
+ def _get_layer_type(self, op_type: str) -> str:
172
+ """Map ONNX op type to NN3D layer type"""
173
+ return ONNX_TO_NN3D_TYPE.get(op_type, LayerType.CUSTOM.value)
174
+
175
+ def _extract_attributes(self, node) -> Dict[str, Any]:
176
+ """Extract node attributes"""
177
+ attrs = {}
178
+ for attr in node.attribute:
179
+ if attr.type == onnx.AttributeProto.INT:
180
+ attrs[attr.name] = attr.i
181
+ elif attr.type == onnx.AttributeProto.INTS:
182
+ attrs[attr.name] = list(attr.ints)
183
+ elif attr.type == onnx.AttributeProto.FLOAT:
184
+ attrs[attr.name] = attr.f
185
+ elif attr.type == onnx.AttributeProto.FLOATS:
186
+ attrs[attr.name] = list(attr.floats)
187
+ elif attr.type == onnx.AttributeProto.STRING:
188
+ attrs[attr.name] = attr.s.decode('utf-8')
189
+ elif attr.type == onnx.AttributeProto.STRINGS:
190
+ attrs[attr.name] = [s.decode('utf-8') for s in attr.strings]
191
+ return attrs
192
+
193
+ def _extract_params(self, node, attrs: Dict[str, Any]) -> LayerParams:
194
+ """Extract layer parameters from ONNX node"""
195
+ params = LayerParams()
196
+
197
+ # Convolution parameters
198
+ if 'kernel_shape' in attrs:
199
+ params.kernel_size = attrs['kernel_shape']
200
+ if 'strides' in attrs:
201
+ params.stride = attrs['strides']
202
+ if 'pads' in attrs:
203
+ params.padding = attrs['pads']
204
+ if 'dilations' in attrs:
205
+ params.dilation = attrs['dilations']
206
+ if 'group' in attrs:
207
+ params.groups = attrs['group']
208
+
209
+ # Normalization parameters
210
+ if 'epsilon' in attrs:
211
+ params.eps = attrs['epsilon']
212
+ if 'momentum' in attrs:
213
+ params.momentum = attrs['momentum']
214
+
215
+ # Dropout
216
+ if 'ratio' in attrs:
217
+ params.dropout_rate = attrs['ratio']
218
+
219
+ return params
220
+
221
+ def export(self) -> NN3DModel:
222
+ """Export the ONNX model to NN3D format"""
223
+
224
+ # Extract tensor shapes
225
+ self._extract_shapes()
226
+
227
+ # Add input nodes
228
+ for idx, input_info in enumerate(self.graph.input):
229
+ # Skip initializers (weights)
230
+ if input_info.name in [init.name for init in self.graph.initializer]:
231
+ continue
232
+
233
+ node_id = f"input_{idx}"
234
+ shape = self._tensor_shapes.get(input_info.name, [])
235
+
236
+ input_node = NN3DNode(
237
+ id=node_id,
238
+ type=LayerType.INPUT.value,
239
+ name=input_info.name,
240
+ output_shape=shape if shape else None,
241
+ depth=0
242
+ )
243
+ self.nodes.append(input_node)
244
+ self._output_to_node[input_info.name] = node_id
245
+
246
+ # Process all operator nodes
247
+ for idx, node in enumerate(self.graph.node):
248
+ node_id = f"node_{idx}"
249
+ layer_type = self._get_layer_type(node.op_type)
250
+ attrs = self._extract_attributes(node)
251
+ params = self._extract_params(node, attrs)
252
+
253
+ # Get input/output shapes
254
+ input_shapes = [self._tensor_shapes.get(inp, []) for inp in node.input
255
+ if inp not in [init.name for init in self.graph.initializer]]
256
+ output_shapes = [self._tensor_shapes.get(out, []) for out in node.output]
257
+
258
+ input_shape = input_shapes[0] if input_shapes else None
259
+ output_shape = output_shapes[0] if output_shapes else None
260
+
261
+ # Create node
262
+ nn3d_node = NN3DNode(
263
+ id=node_id,
264
+ type=layer_type,
265
+ name=node.name or f"{node.op_type}_{idx}",
266
+ params=params if any(v is not None for v in [
267
+ params.kernel_size, params.stride, params.padding,
268
+ params.eps, params.dropout_rate
269
+ ]) else None,
270
+ input_shape=input_shape if input_shape else None,
271
+ output_shape=output_shape if output_shape else None,
272
+ depth=idx + 1,
273
+ attributes={'op_type': node.op_type, **attrs} if attrs else {'op_type': node.op_type}
274
+ )
275
+ self.nodes.append(nn3d_node)
276
+
277
+ # Map outputs to this node
278
+ for output in node.output:
279
+ self._output_to_node[output] = node_id
280
+
281
+ # Create edges from inputs
282
+ for inp in node.input:
283
+ # Skip initializers (weights)
284
+ if inp in [init.name for init in self.graph.initializer]:
285
+ continue
286
+
287
+ source_id = self._output_to_node.get(inp)
288
+ if source_id:
289
+ edge = NN3DEdge(
290
+ source=source_id,
291
+ target=node_id,
292
+ tensor_shape=self._tensor_shapes.get(inp, None),
293
+ label=inp
294
+ )
295
+ self.edges.append(edge)
296
+
297
+ # Add output nodes
298
+ for idx, output_info in enumerate(self.graph.output):
299
+ node_id = f"output_{idx}"
300
+ shape = self._tensor_shapes.get(output_info.name, [])
301
+
302
+ output_node = NN3DNode(
303
+ id=node_id,
304
+ type=LayerType.OUTPUT.value,
305
+ name=output_info.name,
306
+ input_shape=shape if shape else None,
307
+ depth=len(self.graph.node) + 1
308
+ )
309
+ self.nodes.append(output_node)
310
+
311
+ # Create edge from last node producing this output
312
+ source_id = self._output_to_node.get(output_info.name)
313
+ if source_id:
314
+ edge = NN3DEdge(
315
+ source=source_id,
316
+ target=node_id,
317
+ tensor_shape=shape if shape else None
318
+ )
319
+ self.edges.append(edge)
320
+
321
+ # Get input/output shapes for metadata
322
+ input_shapes = [self._tensor_shapes.get(inp.name) for inp in self.graph.input
323
+ if inp.name not in [init.name for init in self.graph.initializer]]
324
+ output_shapes = [self._tensor_shapes.get(out.name) for out in self.graph.output]
325
+
326
+ # Create metadata
327
+ metadata = NN3DMetadata(
328
+ name=self.model_name,
329
+ framework="onnx",
330
+ created=datetime.now().isoformat(),
331
+ input_shape=input_shapes[0] if input_shapes else None,
332
+ output_shape=output_shapes[0] if output_shapes else None,
333
+ description=f"Converted from ONNX model: {self.model_path}"
334
+ )
335
+
336
+ # Create graph
337
+ graph = NN3DGraph(
338
+ nodes=self.nodes,
339
+ edges=self.edges,
340
+ )
341
+
342
+ # Create visualization config
343
+ viz_config = VisualizationConfig()
344
+
345
+ return NN3DModel(
346
+ metadata=metadata,
347
+ graph=graph,
348
+ visualization=viz_config
349
+ )
350
+
351
+
352
def export_onnx_model(
    model_path: str,
    output_path: str,
    model_name: Optional[str] = None,
) -> NN3DModel:
    """
    Convenience function to export an ONNX model to .nn3d file.

    Args:
        model_path: Path to ONNX model file
        output_path: Path to save the .nn3d file
        model_name: Name for the model

    Returns:
        The exported NN3DModel
    """
    exported = ONNXExporter(model_path, model_name).export()
    exported.save(output_path)
    return exported
exporters/python/nn3d_exporter/pytorch_exporter.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Model Exporter
3
+
4
+ Export PyTorch models to .nn3d format for visualization.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Dict, List, Optional, Tuple, Any, Union
10
+ from datetime import datetime
11
+ from collections import OrderedDict
12
+
13
+ from .schema import (
14
+ NN3DModel, NN3DGraph, NN3DNode, NN3DEdge, NN3DMetadata,
15
+ NN3DSubgraph, LayerParams, LayerType, VisualizationConfig
16
+ )
17
+
18
+
19
+ # Mapping from PyTorch module types to NN3D layer types
20
+ PYTORCH_TO_NN3D_TYPE: Dict[type, str] = {
21
+ # Convolution layers
22
+ nn.Conv1d: LayerType.CONV1D.value,
23
+ nn.Conv2d: LayerType.CONV2D.value,
24
+ nn.Conv3d: LayerType.CONV3D.value,
25
+ nn.ConvTranspose2d: LayerType.CONV_TRANSPOSE_2D.value,
26
+
27
+ # Linear layers
28
+ nn.Linear: LayerType.LINEAR.value,
29
+ nn.Embedding: LayerType.EMBEDDING.value,
30
+
31
+ # Normalization layers
32
+ nn.BatchNorm1d: LayerType.BATCH_NORM_1D.value,
33
+ nn.BatchNorm2d: LayerType.BATCH_NORM_2D.value,
34
+ nn.LayerNorm: LayerType.LAYER_NORM.value,
35
+ nn.GroupNorm: LayerType.GROUP_NORM.value,
36
+ nn.InstanceNorm2d: LayerType.INSTANCE_NORM.value,
37
+ nn.Dropout: LayerType.DROPOUT.value,
38
+ nn.Dropout2d: LayerType.DROPOUT.value,
39
+
40
+ # Activation layers
41
+ nn.ReLU: LayerType.RELU.value,
42
+ nn.LeakyReLU: LayerType.LEAKY_RELU.value,
43
+ nn.GELU: LayerType.GELU.value,
44
+ nn.SiLU: LayerType.SILU.value,
45
+ nn.Sigmoid: LayerType.SIGMOID.value,
46
+ nn.Tanh: LayerType.TANH.value,
47
+ nn.Softmax: LayerType.SOFTMAX.value,
48
+
49
+ # Pooling layers
50
+ nn.MaxPool1d: LayerType.MAX_POOL_1D.value,
51
+ nn.MaxPool2d: LayerType.MAX_POOL_2D.value,
52
+ nn.AvgPool2d: LayerType.AVG_POOL_2D.value,
53
+ nn.AdaptiveAvgPool2d: LayerType.ADAPTIVE_AVG_POOL.value,
54
+ nn.AdaptiveAvgPool1d: LayerType.ADAPTIVE_AVG_POOL.value,
55
+
56
+ # Recurrent layers
57
+ nn.LSTM: LayerType.LSTM.value,
58
+ nn.GRU: LayerType.GRU.value,
59
+ nn.RNN: LayerType.RNN.value,
60
+
61
+ # Attention layers
62
+ nn.MultiheadAttention: LayerType.MULTI_HEAD_ATTENTION.value,
63
+
64
+ # Transform layers
65
+ nn.Flatten: LayerType.FLATTEN.value,
66
+ nn.Upsample: LayerType.UPSAMPLE.value,
67
+ }
68
+
69
+
70
class PyTorchExporter:
    """
    Export PyTorch models to NN3D format.

    The exporter (optionally) traces tensor shapes with forward hooks,
    then walks the module tree recursively, emitting one NN3D node per
    leaf module (or recognized container), sequential edges between
    consecutive modules, and subgraphs for containers.

    Usage:
        exporter = PyTorchExporter(model, input_shape=(1, 3, 224, 224))
        nn3d_model = exporter.export()
        nn3d_model.save("model.nn3d")
    """

    def __init__(
        self,
        model: nn.Module,
        input_shape: Optional[Tuple[int, ...]] = None,
        model_name: Optional[str] = None,
        include_activations: bool = False,
    ):
        """
        Initialize the exporter.

        Args:
            model: PyTorch model to export
            input_shape: Input tensor shape (batch, channels, height, width);
                when None, shape tracing is skipped entirely
            model_name: Name for the model (defaults to class name)
            include_activations: Whether to trace activations (requires input_shape)
        """
        self.model = model
        self.input_shape = input_shape
        self.model_name = model_name or model.__class__.__name__
        # NOTE(review): include_activations is stored but never read anywhere
        # in this class — confirm whether activation capture is still planned.
        self.include_activations = include_activations

        # Accumulators populated during export().
        self.nodes: List[NN3DNode] = []
        self.edges: List[NN3DEdge] = []
        self.subgraphs: List[NN3DSubgraph] = []

        # Monotonic counter backing _get_node_id().
        self._node_id_counter = 0
        self._module_to_id: Dict[nn.Module, str] = {}
        # Qualified module name -> (input_shape, output_shape); either side
        # may be None when the forward hook saw no tensor for it.
        self._shapes: Dict[str, Tuple[Optional[List[int]], Optional[List[int]]]] = {}

    def _get_node_id(self) -> str:
        """Generate unique node ID (node_0, node_1, ...)."""
        node_id = f"node_{self._node_id_counter}"
        self._node_id_counter += 1
        return node_id

    def _get_layer_type(self, module: nn.Module) -> str:
        """Map PyTorch module to NN3D layer type.

        Exact-type table lookup first; otherwise a substring heuristic on
        the class name; finally the generic 'custom' type.
        """
        module_type = type(module)

        # Exact type only — subclasses do not match the table.
        if module_type in PYTORCH_TO_NN3D_TYPE:
            return PYTORCH_TO_NN3D_TYPE[module_type]

        # Check for common container patterns by class name.
        class_name = module_type.__name__.lower()

        # Order matters: e.g. "TransformerEncoder" matches 'transformer'
        # before it would match 'encoder'.
        if 'attention' in class_name:
            return LayerType.ATTENTION.value
        elif 'transformer' in class_name:
            return LayerType.TRANSFORMER.value
        elif 'encoder' in class_name:
            return LayerType.ENCODER_BLOCK.value
        elif 'decoder' in class_name:
            return LayerType.DECODER_BLOCK.value
        elif 'residual' in class_name or 'resblock' in class_name:
            return LayerType.RESIDUAL_BLOCK.value

        return LayerType.CUSTOM.value

    def _extract_params(self, module: nn.Module) -> LayerParams:
        """Extract layer parameters from a PyTorch module.

        Duck-typed via hasattr, so it works for any module exposing the
        standard attribute names; fields with no matching attribute stay None.
        """
        params = LayerParams()

        # Convolution parameters (tuples are converted to lists for JSON).
        if hasattr(module, 'in_channels'):
            params.in_channels = module.in_channels
        if hasattr(module, 'out_channels'):
            params.out_channels = module.out_channels
        if hasattr(module, 'kernel_size'):
            ks = module.kernel_size
            params.kernel_size = list(ks) if isinstance(ks, tuple) else ks
        if hasattr(module, 'stride'):
            stride = module.stride
            params.stride = list(stride) if isinstance(stride, tuple) else stride
        if hasattr(module, 'padding'):
            pad = module.padding
            params.padding = list(pad) if isinstance(pad, tuple) else pad
        if hasattr(module, 'dilation'):
            dil = module.dilation
            params.dilation = list(dil) if isinstance(dil, tuple) else dil
        if hasattr(module, 'groups'):
            params.groups = module.groups

        # Linear parameters
        if hasattr(module, 'in_features'):
            params.in_features = module.in_features
        if hasattr(module, 'out_features'):
            params.out_features = module.out_features

        # Attention parameters
        if hasattr(module, 'num_heads'):
            params.num_heads = module.num_heads
        if hasattr(module, 'embed_dim'):
            params.hidden_size = module.embed_dim

        # Normalization parameters (momentum may legitimately be None)
        if hasattr(module, 'eps'):
            params.eps = module.eps
        if hasattr(module, 'momentum') and module.momentum is not None:
            params.momentum = module.momentum
        if hasattr(module, 'affine'):
            params.affine = module.affine

        # Dropout parameters
        if hasattr(module, 'p'):
            params.dropout_rate = module.p

        # Embedding parameters
        if hasattr(module, 'num_embeddings'):
            params.num_embeddings = module.num_embeddings
        if hasattr(module, 'embedding_dim'):
            params.embedding_dim = module.embedding_dim

        # Bias: True/False when the module has a bias attribute at all;
        # left as None otherwise.
        if hasattr(module, 'bias') and module.bias is not None:
            params.bias = True
        elif hasattr(module, 'bias'):
            params.bias = False

        return params

    def _trace_shapes(self) -> None:
        """Trace tensor shapes through the model.

        Registers a forward hook on every leaf module, runs one dummy
        forward pass, and records (input_shape, output_shape) per module
        into self._shapes keyed by qualified module name.  No-op when
        input_shape was not provided.
        """
        if self.input_shape is None:
            return

        hooks = []

        def hook_fn(name: str):
            def hook(module, input, output):
                input_shape = None
                output_shape = None

                # Forward hooks receive inputs as a tuple; take the first
                # tensor on either side, ignoring non-tensor values.
                if isinstance(input, tuple) and len(input) > 0:
                    if isinstance(input[0], torch.Tensor):
                        input_shape = list(input[0].shape)
                elif isinstance(input, torch.Tensor):
                    input_shape = list(input.shape)

                if isinstance(output, torch.Tensor):
                    output_shape = list(output.shape)
                elif isinstance(output, tuple) and len(output) > 0:
                    if isinstance(output[0], torch.Tensor):
                        output_shape = list(output[0].shape)

                self._shapes[name] = (input_shape, output_shape)
            return hook

        # Register hooks on leaf modules only.
        for name, module in self.model.named_modules():
            if len(list(module.children())) == 0:  # Leaf modules only
                hooks.append(module.register_forward_hook(hook_fn(name)))

        # Run one forward pass with a float zeros tensor.
        # NOTE(review): models whose first layer expects integer inputs
        # (e.g. nn.Embedding) will raise here; the except below turns that
        # into a warning and export proceeds without shapes — confirm this
        # best-effort behavior is intended.
        try:
            self.model.eval()
            with torch.no_grad():
                dummy_input = torch.zeros(self.input_shape)
                self.model(dummy_input)
        except Exception as e:
            print(f"Warning: Could not trace shapes: {e}")
        finally:
            # Remove hooks so the model is left unmodified.
            for hook in hooks:
                hook.remove()

    def _process_module(
        self,
        module: nn.Module,
        name: str,
        parent_id: Optional[str] = None,
        depth: int = 0
    ) -> Optional[str]:
        """Process a single module and its children recursively.

        Emits a node for leaf modules and for containers whose type was
        recognized (non-'custom'); unrecognized containers are transparent
        and only their children are processed.  Children are chained
        sequentially with edges, and recognized containers additionally
        produce an NN3DSubgraph grouping themselves with their children.

        Returns:
            The id of the last node emitted in this subtree (used by the
            caller as the source of the next edge), or the incoming
            parent_id when nothing was emitted.
        """

        # Skip container modules without parameters
        children = list(module.named_children())
        is_leaf = len(children) == 0

        # Create node for leaf modules or significant containers
        layer_type = self._get_layer_type(module)

        if is_leaf or layer_type != LayerType.CUSTOM.value:
            node_id = self._get_node_id()
            self._module_to_id[module] = node_id

            # Get shapes if traced (None, None) when not available.
            input_shape, output_shape = self._shapes.get(name, (None, None))

            # Extract parameters
            params = self._extract_params(module)

            # Count this module's own parameters (children counted on their
            # own nodes via recurse=False).
            num_params = sum(p.numel() for p in module.parameters(recurse=False))

            # Attach params only when at least one identifying field is set.
            node = NN3DNode(
                id=node_id,
                type=layer_type,
                name=name or module.__class__.__name__,
                params=params if any(v is not None for v in [
                    params.in_channels, params.out_channels,
                    params.in_features, params.out_features,
                    params.kernel_size, params.num_heads
                ]) else None,
                input_shape=input_shape,
                output_shape=output_shape,
                depth=depth,
                attributes={'num_params': num_params} if num_params > 0 else None
            )

            self.nodes.append(node)

            # Create edge from parent
            if parent_id:
                edge = NN3DEdge(
                    source=parent_id,
                    target=node_id,
                    tensor_shape=input_shape,
                )
                self.edges.append(edge)

            # Process children of a recognized container, chaining them
            # sequentially starting from the container's own node.
            if children:
                subgraph_nodes = [node_id]
                prev_id = node_id

                for child_name, child in children:
                    full_name = f"{name}.{child_name}" if name else child_name
                    child_id = self._process_module(child, full_name, prev_id, depth + 1)
                    if child_id:
                        subgraph_nodes.append(child_id)
                        prev_id = child_id

                # Create subgraph for container (itself + descendants).
                if len(subgraph_nodes) > 1:
                    self.subgraphs.append(NN3DSubgraph(
                        id=f"subgraph_{node_id}",
                        name=name or module.__class__.__name__,
                        nodes=subgraph_nodes,
                        type='sequential'
                    ))

                return prev_id

            return node_id

        # Transparent container: process children in place at the same depth.
        prev_id = parent_id
        for child_name, child in children:
            full_name = f"{name}.{child_name}" if name else child_name
            child_id = self._process_module(child, full_name, prev_id, depth)
            if child_id:
                prev_id = child_id

        return prev_id

    def export(self) -> NN3DModel:
        """Export the model to NN3D format.

        Returns:
            The populated NN3DModel (input node, module nodes/edges/
            subgraphs, output node, metadata and default viz config).
        """

        # Trace shapes first (no-op without input_shape).
        self._trace_shapes()

        # Add input node
        input_id = self._get_node_id()
        input_node = NN3DNode(
            id=input_id,
            type=LayerType.INPUT.value,
            name="input",
            output_shape=list(self.input_shape) if self.input_shape else None,
            depth=0
        )
        self.nodes.append(input_node)

        # Process all modules, chained off the input node.
        last_id = self._process_module(self.model, "", input_id, 1)

        # Add output node; its shape is taken from the last node that has
        # a traced output_shape (search backwards).
        output_id = self._get_node_id()
        output_shape = None
        if last_id and self.nodes:
            # Get output shape from last processed node
            for node in reversed(self.nodes):
                if node.output_shape:
                    output_shape = node.output_shape
                    break

        output_node = NN3DNode(
            id=output_id,
            type=LayerType.OUTPUT.value,
            name="output",
            input_shape=output_shape,
            depth=len(self.nodes)
        )
        self.nodes.append(output_node)

        if last_id:
            self.edges.append(NN3DEdge(
                source=last_id,
                target=output_id,
                tensor_shape=output_shape
            ))

        # Count parameters over the whole model.
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        # Create metadata
        metadata = NN3DMetadata(
            name=self.model_name,
            framework="pytorch",
            created=datetime.now().isoformat(),
            input_shape=list(self.input_shape) if self.input_shape else None,
            output_shape=output_shape,
            total_params=total_params,
            trainable_params=trainable_params,
        )

        # Create graph
        graph = NN3DGraph(
            nodes=self.nodes,
            edges=self.edges,
            subgraphs=self.subgraphs if self.subgraphs else None
        )

        # Create visualization config
        viz_config = VisualizationConfig()

        return NN3DModel(
            metadata=metadata,
            graph=graph,
            visualization=viz_config
        )
+ )
411
+
412
+
413
def export_pytorch_model(
    model: nn.Module,
    output_path: str,
    input_shape: Optional[Tuple[int, ...]] = None,
    model_name: Optional[str] = None,
) -> NN3DModel:
    """
    Convenience function to export a PyTorch model to .nn3d file.

    Args:
        model: PyTorch model to export
        output_path: Path to save the .nn3d file
        input_shape: Input tensor shape (batch, channels, height, width)
        model_name: Name for the model

    Returns:
        The exported NN3DModel
    """
    exported = PyTorchExporter(model, input_shape, model_name).export()
    exported.save(output_path)
    return exported
+ return nn3d_model
exporters/python/nn3d_exporter/schema.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NN3D Schema definitions using Python dataclasses
3
+
4
+ These classes mirror the JSON schema and can be serialized to .nn3d format.
5
+ """
6
+
7
+ from dataclasses import dataclass, field, asdict
8
+ from typing import List, Optional, Dict, Any, Union
9
+ from enum import Enum
10
+ import json
11
+ from datetime import datetime
12
+
13
+
14
class LayerType(str, Enum):
    """Supported layer types.

    Values are the camelCase strings written into the .nn3d JSON;
    subclassing str lets members serialize directly with json.dumps.
    """
    # Graph boundary markers
    INPUT = "input"
    OUTPUT = "output"
    # Convolution
    CONV1D = "conv1d"
    CONV2D = "conv2d"
    CONV3D = "conv3d"
    CONV_TRANSPOSE_2D = "convTranspose2d"
    DEPTHWISE_CONV2D = "depthwiseConv2d"
    SEPARABLE_CONV2D = "separableConv2d"
    # Fully connected / embedding
    LINEAR = "linear"
    DENSE = "dense"
    EMBEDDING = "embedding"
    # Normalization / regularization
    BATCH_NORM_1D = "batchNorm1d"
    BATCH_NORM_2D = "batchNorm2d"
    LAYER_NORM = "layerNorm"
    GROUP_NORM = "groupNorm"
    INSTANCE_NORM = "instanceNorm"
    DROPOUT = "dropout"
    # Activations
    RELU = "relu"
    LEAKY_RELU = "leakyRelu"
    GELU = "gelu"
    SILU = "silu"
    SIGMOID = "sigmoid"
    TANH = "tanh"
    SOFTMAX = "softmax"
    # Pooling
    MAX_POOL_1D = "maxPool1d"
    MAX_POOL_2D = "maxPool2d"
    AVG_POOL_2D = "avgPool2d"
    GLOBAL_AVG_POOL = "globalAvgPool"
    ADAPTIVE_AVG_POOL = "adaptiveAvgPool"
    # Shape / tensor manipulation
    FLATTEN = "flatten"
    RESHAPE = "reshape"
    CONCAT = "concat"
    ADD = "add"
    MULTIPLY = "multiply"
    SPLIT = "split"
    # Attention
    ATTENTION = "attention"
    MULTI_HEAD_ATTENTION = "multiHeadAttention"
    SELF_ATTENTION = "selfAttention"
    CROSS_ATTENTION = "crossAttention"
    # Recurrent
    LSTM = "lstm"
    GRU = "gru"
    RNN = "rnn"
    # Composite / block types
    TRANSFORMER = "transformer"
    ENCODER_BLOCK = "encoderBlock"
    DECODER_BLOCK = "decoderBlock"
    RESIDUAL_BLOCK = "residualBlock"
    # Resampling / padding
    UPSAMPLE = "upsample"
    INTERPOLATE = "interpolate"
    PAD = "pad"
    # Fallback for unrecognized layers
    CUSTOM = "custom"
66
+
67
+
68
@dataclass
class Position3D:
    """3D position; each coordinate defaults to 0.0 (the origin)."""
    x: float = 0.0
    y: float = 0.0
    z: float = 0.0
74
+
75
+
76
@dataclass
class LayerParams:
    """Optional per-layer hyperparameters; unset fields remain None and
    are omitted from the serialized form."""
    in_channels: Optional[int] = None
    out_channels: Optional[int] = None
    in_features: Optional[int] = None
    out_features: Optional[int] = None
    kernel_size: Optional[Union[int, List[int]]] = None
    stride: Optional[Union[int, List[int]]] = None
    padding: Optional[Union[int, str, List[int]]] = None
    dilation: Optional[Union[int, List[int]]] = None
    groups: Optional[int] = None
    bias: Optional[bool] = None
    num_heads: Optional[int] = None
    hidden_size: Optional[int] = None
    dropout_rate: Optional[float] = None
    eps: Optional[float] = None
    momentum: Optional[float] = None
    affine: Optional[bool] = None
    num_embeddings: Optional[int] = None
    embedding_dim: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dict with camelCase keys, excluding None values"""
        # snake_case field -> camelCase key; fields not listed here keep
        # their name unchanged (they are already single words).
        camel = {
            'in_channels': 'inChannels',
            'out_channels': 'outChannels',
            'in_features': 'inFeatures',
            'out_features': 'outFeatures',
            'kernel_size': 'kernelSize',
            'num_heads': 'numHeads',
            'hidden_size': 'hiddenSize',
            'dropout_rate': 'dropoutRate',
            'num_embeddings': 'numEmbeddings',
            'embedding_dim': 'embeddingDim',
        }
        return {
            camel.get(name, name): value
            for name, value in asdict(self).items()
            if value is not None
        }
118
+
119
+
120
@dataclass
class NN3DNode:
    """Graph node representing a layer"""
    id: str
    type: str  # LayerType value
    name: str
    params: Optional[LayerParams] = None
    input_shape: Optional[List[Union[int, str]]] = None
    output_shape: Optional[List[Union[int, str]]] = None
    position: Optional[Position3D] = None
    attributes: Optional[Dict[str, Any]] = None
    group: Optional[str] = None
    depth: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the .nn3d JSON shape.

        id/type/name are always present; shapes/attributes/group are kept
        only when truthy, params/position when not None, and depth whenever
        it is not None (so depth 0 is serialized).
        """
        entries = [
            ('id', self.id, True),
            ('type', self.type, True),
            ('name', self.name, True),
            ('params', self.params.to_dict() if self.params is not None else None,
             self.params is not None),
            ('inputShape', self.input_shape, bool(self.input_shape)),
            ('outputShape', self.output_shape, bool(self.output_shape)),
            ('position', asdict(self.position) if self.position is not None else None,
             self.position is not None),
            ('attributes', self.attributes, bool(self.attributes)),
            ('group', self.group, bool(self.group)),
            ('depth', self.depth, self.depth is not None),
        ]
        return {key: value for key, value, keep in entries if keep}
155
+
156
+
157
@dataclass
class NN3DEdge:
    """Graph edge representing a connection"""
    source: str
    target: str
    id: Optional[str] = None
    source_port: Optional[int] = None
    target_port: Optional[int] = None
    tensor_shape: Optional[List[Union[int, str]]] = None
    dtype: Optional[str] = None
    label: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to camelCase keys; source/target always present, ports
        kept when not None (port 0 is serialized), other fields when truthy."""
        entries = [
            ('source', self.source, True),
            ('target', self.target, True),
            ('id', self.id, bool(self.id)),
            ('sourcePort', self.source_port, self.source_port is not None),
            ('targetPort', self.target_port, self.target_port is not None),
            ('tensorShape', self.tensor_shape, bool(self.tensor_shape)),
            ('dtype', self.dtype, bool(self.dtype)),
            ('label', self.label, bool(self.label)),
        ]
        return {key: value for key, value, keep in entries if keep}
187
+
188
+
189
@dataclass
class NN3DSubgraph:
    """Subgraph for grouping layers"""
    id: str
    name: str
    nodes: List[str] = field(default_factory=list)
    type: Optional[str] = None
    color: Optional[str] = None
    collapsed: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize; id/name/nodes always present, the rest only when set."""
        entries = [
            ('id', self.id, True),
            ('name', self.name, True),
            ('nodes', self.nodes, True),
            ('type', self.type, bool(self.type)),
            ('color', self.color, bool(self.color)),
            ('collapsed', self.collapsed, bool(self.collapsed)),
        ]
        return {key: value for key, value, keep in entries if keep}
212
+
213
+
214
@dataclass
class NN3DGraph:
    """Graph containing nodes and edges"""
    nodes: List[NN3DNode] = field(default_factory=list)
    edges: List[NN3DEdge] = field(default_factory=list)
    subgraphs: Optional[List[NN3DSubgraph]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Recursively serialize nodes/edges; subgraphs only when non-empty."""
        serialized: Dict[str, Any] = {
            'nodes': [node.to_dict() for node in self.nodes],
            'edges': [edge.to_dict() for edge in self.edges],
        }
        if self.subgraphs:
            serialized['subgraphs'] = [sg.to_dict() for sg in self.subgraphs]
        return serialized
229
+
230
+
231
@dataclass
class NN3DMetadata:
    """Model metadata"""
    name: str
    description: Optional[str] = None
    framework: Optional[str] = None
    author: Optional[str] = None
    created: Optional[str] = None
    tags: Optional[List[str]] = None
    input_shape: Optional[List[Union[int, str]]] = None
    output_shape: Optional[List[Union[int, str]]] = None
    total_params: Optional[int] = None
    trainable_params: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to camelCase keys; name always present, text/shape
        fields kept when truthy, parameter counts when not None (so a
        count of 0 is still serialized)."""
        entries = [
            ('name', self.name, True),
            ('description', self.description, bool(self.description)),
            ('framework', self.framework, bool(self.framework)),
            ('author', self.author, bool(self.author)),
            ('created', self.created, bool(self.created)),
            ('tags', self.tags, bool(self.tags)),
            ('inputShape', self.input_shape, bool(self.input_shape)),
            ('outputShape', self.output_shape, bool(self.output_shape)),
            ('totalParams', self.total_params, self.total_params is not None),
            ('trainableParams', self.trainable_params, self.trainable_params is not None),
        ]
        return {key: value for key, value, keep in entries if keep}
266
+
267
+
268
@dataclass
class VisualizationConfig:
    """Visualization configuration"""
    layout: str = "layered"
    theme: str = "dark"
    layer_spacing: float = 3.0
    node_scale: float = 1.0
    show_labels: bool = True
    show_edges: bool = True
    edge_style: str = "tube"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field unconditionally, with camelCase keys."""
        camel = {
            'layer_spacing': 'layerSpacing',
            'node_scale': 'nodeScale',
            'show_labels': 'showLabels',
            'show_edges': 'showEdges',
            'edge_style': 'edgeStyle',
        }
        return {camel.get(name, name): value for name, value in asdict(self).items()}
289
+
290
+
291
@dataclass
class NN3DModel:
    """Complete NN3D model: metadata + graph + optional visualization
    config.  This is the top-level object serialized to a .nn3d file."""
    metadata: NN3DMetadata
    graph: NN3DGraph
    version: str = "1.0.0"
    visualization: Optional[VisualizationConfig] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the full .nn3d document structure."""
        result = {
            'version': self.version,
            'metadata': self.metadata.to_dict(),
            'graph': self.graph.to_dict(),
        }
        if self.visualization:
            result['visualization'] = self.visualization.to_dict()
        return result

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON string"""
        return json.dumps(self.to_dict(), indent=indent)

    def save(self, filepath: str) -> None:
        """Save to .nn3d file.

        Writes UTF-8 explicitly so output does not depend on the
        platform's default encoding (e.g. cp1252 on Windows would fail
        on non-ASCII layer or model names).
        """
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(self.to_json())
exporters/python/pyproject.toml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "nn3d-exporter"
7
+ version = "1.0.0"
8
+ description = "Export neural network models to NN3D format for 3D visualization"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {text = "MIT"}
12
+ authors = [
13
+ {name = "NN3D Team"}
14
+ ]
15
+ keywords = ["neural-network", "visualization", "deep-learning", "pytorch", "onnx", "3d"]
16
+ classifiers = [
17
+ "Development Status :: 4 - Beta",
18
+ "Intended Audience :: Developers",
19
+ "Intended Audience :: Science/Research",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Programming Language :: Python :: 3",
22
+ "Programming Language :: Python :: 3.8",
23
+ "Programming Language :: Python :: 3.9",
24
+ "Programming Language :: Python :: 3.10",
25
+ "Programming Language :: Python :: 3.11",
26
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
27
+ "Topic :: Scientific/Engineering :: Visualization",
28
+ ]
29
+
30
+ dependencies = []
31
+
32
+ [project.optional-dependencies]
33
+ pytorch = ["torch>=1.9.0"]
34
+ onnx = ["onnx>=1.10.0"]
35
+ all = ["torch>=1.9.0", "onnx>=1.10.0"]
36
+ dev = ["pytest>=7.0.0", "black", "isort", "mypy"]
37
+
38
+ [project.urls]
39
+ Homepage = "https://github.com/nn3d/visualizer"
40
+ Documentation = "https://nn3d.dev/docs"
41
+ Repository = "https://github.com/nn3d/visualizer"
42
+
43
+ # NOTE: nn3d_exporter.cli does not appear among the modules added in this
+ # commit — add the cli module or drop this entry before publishing.
+ [project.scripts]
+ nn3d-export = "nn3d_exporter.cli:main"
45
+
46
+ [tool.setuptools.packages.find]
47
+ where = ["."]
48
+
49
+ [tool.black]
50
+ line-length = 100
51
+ target-version = ["py38", "py39", "py310", "py311"]
52
+
53
+ [tool.isort]
54
+ profile = "black"
55
+ line_length = 100
56
+
57
+ [tool.mypy]
58
+ python_version = "3.8"
59
+ warn_return_any = true
60
+ warn_unused_configs = true
61
+ ignore_missing_imports = true
files_to_commit.txt ADDED
Binary file (5.42 kB). View file
 
index.html ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/favicon.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>NN3D Visualizer - 3D Deep Learning Model Viewer</title>
8
+ <meta
9
+ name="description"
10
+ content="Interactive 3D visualization of neural network architectures"
11
+ />
12
+ <style>
13
+ * {
14
+ margin: 0;
15
+ padding: 0;
16
+ box-sizing: border-box;
17
+ }
18
+ html,
19
+ body,
20
+ #root {
21
+ width: 100%;
22
+ height: 100%;
23
+ overflow: hidden;
24
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
25
+ Oxygen, Ubuntu, sans-serif;
26
+ }
27
+ </style>
28
+ </head>
29
+ <body>
30
+ <div id="root"></div>
31
+ <script type="module" src="/src/main.tsx"></script>
32
+ </body>
33
+ </html>
nginx.conf ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ server {
2
+ listen 80;
3
+ server_name localhost;
4
+ root /usr/share/nginx/html;
5
+ index index.html;
6
+
7
+ # Gzip compression
8
+ gzip on;
9
+ gzip_vary on;
10
+ gzip_min_length 1024;
11
+ gzip_proxied expired no-cache no-store private auth;
12
+ gzip_types text/plain text/css text/xml text/javascript application/x-javascript application/xml application/javascript application/json;
13
+
14
+ # Frontend routes - SPA support
15
+ location / {
16
+ try_files $uri $uri/ /index.html;
17
+ }
18
+
19
+ # Proxy API requests to backend
20
+ location /api/ {
21
+ proxy_pass http://backend:8000/;
22
+ proxy_http_version 1.1;
23
+ proxy_set_header Upgrade $http_upgrade;
24
+ proxy_set_header Connection 'upgrade';
25
+ proxy_set_header Host $host;
26
+ proxy_set_header X-Real-IP $remote_addr;
27
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
28
+ proxy_set_header X-Forwarded-Proto $scheme;
29
+ proxy_cache_bypass $http_upgrade;
30
+ proxy_read_timeout 300s;
31
+ proxy_connect_timeout 75s;
32
+
33
+ # Handle large file uploads
34
+ client_max_body_size 500M;
35
+ }
36
+
37
+ # Cache static assets
38
+ location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {
39
+ expires 1y;
40
+ add_header Cache-Control "public, immutable";
41
+ }
42
+
43
+ # Security headers
44
+ add_header X-Frame-Options "SAMEORIGIN" always;
45
+ add_header X-Content-Type-Options "nosniff" always;
46
+ add_header X-XSS-Protection "1; mode=block" always;
47
+ }
package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
package.json CHANGED
@@ -1,39 +1,53 @@
1
  {
2
- "name": "react-template",
3
- "version": "0.1.0",
4
- "private": true,
5
- "dependencies": {
6
- "@testing-library/dom": "^10.4.0",
7
- "@testing-library/jest-dom": "^6.6.3",
8
- "@testing-library/react": "^16.3.0",
9
- "@testing-library/user-event": "^13.5.0",
10
- "react": "^19.1.0",
11
- "react-dom": "^19.1.0",
12
- "react-scripts": "5.0.1",
13
- "web-vitals": "^2.1.4"
14
- },
15
  "scripts": {
16
- "start": "react-scripts start",
17
- "build": "react-scripts build",
18
- "test": "react-scripts test",
19
- "eject": "react-scripts eject"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  },
21
- "eslintConfig": {
22
- "extends": [
23
- "react-app",
24
- "react-app/jest"
25
- ]
 
 
 
 
 
 
 
 
26
  },
27
- "browserslist": {
28
- "production": [
29
- ">0.2%",
30
- "not dead",
31
- "not op_mini all"
32
- ],
33
- "development": [
34
- "last 1 chrome version",
35
- "last 1 firefox version",
36
- "last 1 safari version"
37
- ]
38
- }
39
  }
 
1
  {
2
+ "name": "nn3d-visualizer",
3
+ "version": "1.0.0",
4
+ "description": "3D Deep Learning Model Visualizer - Interactive WebGL visualization of neural network architectures",
5
+ "type": "module",
 
 
 
 
 
 
 
 
 
6
  "scripts": {
7
+ "dev": "vite",
8
+ "build": "tsc && vite build",
9
+ "preview": "vite preview",
10
+ "lint": "eslint src --ext .ts,.tsx",
11
+ "test": "vitest",
12
+ "validate-schema": "ajv validate -s src/schema/nn3d.schema.json -d"
13
+ },
14
+ "dependencies": {
15
+ "@react-three/drei": "^9.92.7",
16
+ "@react-three/fiber": "^8.15.12",
17
+ "ajv": "^8.12.0",
18
+ "d3-force-3d": "^3.0.5",
19
+ "h5wasm": "^0.7.5",
20
+ "leva": "^0.9.35",
21
+ "onnxruntime-web": "^1.16.3",
22
+ "pako": "^2.1.0",
23
+ "react": "^18.2.0",
24
+ "react-dom": "^18.2.0",
25
+ "three": "^0.160.0",
26
+ "zustand": "^4.4.7"
27
  },
28
+ "devDependencies": {
29
+ "@types/node": "^25.0.3",
30
+ "@types/pako": "^2.0.4",
31
+ "@types/react": "^18.2.45",
32
+ "@types/react-dom": "^18.2.18",
33
+ "@types/three": "^0.160.0",
34
+ "@typescript-eslint/eslint-plugin": "^6.16.0",
35
+ "@typescript-eslint/parser": "^6.16.0",
36
+ "@vitejs/plugin-react": "^4.2.1",
37
+ "eslint": "^8.56.0",
38
+ "typescript": "^5.3.3",
39
+ "vite": "^5.0.10",
40
+ "vitest": "^1.1.0"
41
  },
42
+ "keywords": [
43
+ "neural-network",
44
+ "deep-learning",
45
+ "visualization",
46
+ "3d",
47
+ "threejs",
48
+ "webgl",
49
+ "machine-learning"
50
+ ],
51
+ "author": "",
52
+ "license": "MIT"
 
53
  }
public/favicon.ico DELETED
Binary file (3.87 kB)
 
public/favicon.svg ADDED
public/index.html DELETED
@@ -1,43 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="utf-8" />
5
- <link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
6
- <meta name="viewport" content="width=device-width, initial-scale=1" />
7
- <meta name="theme-color" content="#000000" />
8
- <meta
9
- name="description"
10
- content="Web site created using create-react-app"
11
- />
12
- <link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
13
- <!--
14
- manifest.json provides metadata used when your web app is installed on a
15
- user's mobile device or desktop. See https://developers.google.com/web/fundamentals/web-app-manifest/
16
- -->
17
- <link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
18
- <!--
19
- Notice the use of %PUBLIC_URL% in the tags above.
20
- It will be replaced with the URL of the `public` folder during the build.
21
- Only files inside the `public` folder can be referenced from the HTML.
22
-
23
- Unlike "/favicon.ico" or "favicon.ico", "%PUBLIC_URL%/favicon.ico" will
24
- work correctly both with client-side routing and a non-root public URL.
25
- Learn how to configure a non-root public URL by running `npm run build`.
26
- -->
27
- <title>React App</title>
28
- </head>
29
- <body>
30
- <noscript>You need to enable JavaScript to run this app.</noscript>
31
- <div id="root"></div>
32
- <!--
33
- This HTML file is a template.
34
- If you open it directly in the browser, you will see an empty page.
35
-
36
- You can add webfonts, meta tags, or analytics to this file.
37
- The build step will place the bundled scripts into the <body> tag.
38
-
39
- To begin the development, run `npm start` or `yarn start`.
40
- To create a production bundle, use `npm run build` or `yarn build`.
41
- -->
42
- </body>
43
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
public/logo192.png DELETED
Binary file (5.35 kB)
 
public/logo512.png DELETED
Binary file (9.66 kB)
 
public/manifest.json DELETED
@@ -1,25 +0,0 @@
1
- {
2
- "short_name": "React App",
3
- "name": "Create React App Sample",
4
- "icons": [
5
- {
6
- "src": "favicon.ico",
7
- "sizes": "64x64 32x32 24x24 16x16",
8
- "type": "image/x-icon"
9
- },
10
- {
11
- "src": "logo192.png",
12
- "type": "image/png",
13
- "sizes": "192x192"
14
- },
15
- {
16
- "src": "logo512.png",
17
- "type": "image/png",
18
- "sizes": "512x512"
19
- }
20
- ],
21
- "start_url": ".",
22
- "display": "standalone",
23
- "theme_color": "#000000",
24
- "background_color": "#ffffff"
25
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
public/robots.txt DELETED
@@ -1,3 +0,0 @@
1
- # https://www.robotstxt.org/robotstxt.html
2
- User-agent: *
3
- Disallow:
 
 
 
 
samples/cnn_resnet.nn3d ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0.0",
3
+ "metadata": {
4
+ "name": "CNN Image Classifier",
5
+ "description": "Convolutional Neural Network for image classification",
6
+ "framework": "pytorch",
7
+ "created": "2024-12-17T10:00:00Z",
8
+ "tags": ["cnn", "classification", "imagenet"],
9
+ "inputShape": [1, 3, 224, 224],
10
+ "outputShape": [1, 1000],
11
+ "totalParams": 11689512,
12
+ "trainableParams": 11689512
13
+ },
14
+ "graph": {
15
+ "nodes": [
16
+ {
17
+ "id": "input",
18
+ "type": "input",
19
+ "name": "Input Image",
20
+ "outputShape": [1, 3, 224, 224],
21
+ "depth": 0
22
+ },
23
+ {
24
+ "id": "conv1",
25
+ "type": "conv2d",
26
+ "name": "conv1",
27
+ "params": {
28
+ "inChannels": 3,
29
+ "outChannels": 64,
30
+ "kernelSize": [7, 7],
31
+ "stride": [2, 2],
32
+ "padding": [3, 3],
33
+ "bias": false
34
+ },
35
+ "inputShape": [1, 3, 224, 224],
36
+ "outputShape": [1, 64, 112, 112],
37
+ "depth": 1,
38
+ "group": "stem"
39
+ },
40
+ {
41
+ "id": "bn1",
42
+ "type": "batchNorm2d",
43
+ "name": "bn1",
44
+ "params": {
45
+ "eps": 1e-5,
46
+ "momentum": 0.1
47
+ },
48
+ "inputShape": [1, 64, 112, 112],
49
+ "outputShape": [1, 64, 112, 112],
50
+ "depth": 2,
51
+ "group": "stem"
52
+ },
53
+ {
54
+ "id": "relu1",
55
+ "type": "relu",
56
+ "name": "ReLU",
57
+ "inputShape": [1, 64, 112, 112],
58
+ "outputShape": [1, 64, 112, 112],
59
+ "depth": 3,
60
+ "group": "stem"
61
+ },
62
+ {
63
+ "id": "maxpool",
64
+ "type": "maxPool2d",
65
+ "name": "MaxPool",
66
+ "params": {
67
+ "kernelSize": [3, 3],
68
+ "stride": [2, 2],
69
+ "padding": [1, 1]
70
+ },
71
+ "inputShape": [1, 64, 112, 112],
72
+ "outputShape": [1, 64, 56, 56],
73
+ "depth": 4,
74
+ "group": "stem"
75
+ },
76
+ {
77
+ "id": "layer1_conv1",
78
+ "type": "conv2d",
79
+ "name": "layer1.0.conv1",
80
+ "params": {
81
+ "inChannels": 64,
82
+ "outChannels": 64,
83
+ "kernelSize": [3, 3],
84
+ "stride": [1, 1],
85
+ "padding": [1, 1],
86
+ "bias": false
87
+ },
88
+ "inputShape": [1, 64, 56, 56],
89
+ "outputShape": [1, 64, 56, 56],
90
+ "depth": 5,
91
+ "group": "layer1"
92
+ },
93
+ {
94
+ "id": "layer1_bn1",
95
+ "type": "batchNorm2d",
96
+ "name": "layer1.0.bn1",
97
+ "inputShape": [1, 64, 56, 56],
98
+ "outputShape": [1, 64, 56, 56],
99
+ "depth": 6,
100
+ "group": "layer1"
101
+ },
102
+ {
103
+ "id": "layer1_relu1",
104
+ "type": "relu",
105
+ "name": "ReLU",
106
+ "inputShape": [1, 64, 56, 56],
107
+ "outputShape": [1, 64, 56, 56],
108
+ "depth": 7,
109
+ "group": "layer1"
110
+ },
111
+ {
112
+ "id": "layer1_conv2",
113
+ "type": "conv2d",
114
+ "name": "layer1.0.conv2",
115
+ "params": {
116
+ "inChannels": 64,
117
+ "outChannels": 64,
118
+ "kernelSize": [3, 3],
119
+ "stride": [1, 1],
120
+ "padding": [1, 1],
121
+ "bias": false
122
+ },
123
+ "inputShape": [1, 64, 56, 56],
124
+ "outputShape": [1, 64, 56, 56],
125
+ "depth": 8,
126
+ "group": "layer1"
127
+ },
128
+ {
129
+ "id": "layer1_bn2",
130
+ "type": "batchNorm2d",
131
+ "name": "layer1.0.bn2",
132
+ "inputShape": [1, 64, 56, 56],
133
+ "outputShape": [1, 64, 56, 56],
134
+ "depth": 9,
135
+ "group": "layer1"
136
+ },
137
+ {
138
+ "id": "layer1_add",
139
+ "type": "add",
140
+ "name": "Residual Add",
141
+ "inputShape": [1, 64, 56, 56],
142
+ "outputShape": [1, 64, 56, 56],
143
+ "depth": 10,
144
+ "group": "layer1"
145
+ },
146
+ {
147
+ "id": "layer2_conv1",
148
+ "type": "conv2d",
149
+ "name": "layer2.0.conv1",
150
+ "params": {
151
+ "inChannels": 64,
152
+ "outChannels": 128,
153
+ "kernelSize": [3, 3],
154
+ "stride": [2, 2],
155
+ "padding": [1, 1],
156
+ "bias": false
157
+ },
158
+ "inputShape": [1, 64, 56, 56],
159
+ "outputShape": [1, 128, 28, 28],
160
+ "depth": 11,
161
+ "group": "layer2"
162
+ },
163
+ {
164
+ "id": "layer2_bn1",
165
+ "type": "batchNorm2d",
166
+ "name": "layer2.0.bn1",
167
+ "inputShape": [1, 128, 28, 28],
168
+ "outputShape": [1, 128, 28, 28],
169
+ "depth": 12,
170
+ "group": "layer2"
171
+ },
172
+ {
173
+ "id": "layer2_relu1",
174
+ "type": "relu",
175
+ "name": "ReLU",
176
+ "inputShape": [1, 128, 28, 28],
177
+ "outputShape": [1, 128, 28, 28],
178
+ "depth": 13,
179
+ "group": "layer2"
180
+ },
181
+ {
182
+ "id": "layer2_conv2",
183
+ "type": "conv2d",
184
+ "name": "layer2.0.conv2",
185
+ "params": {
186
+ "inChannels": 128,
187
+ "outChannels": 128,
188
+ "kernelSize": [3, 3],
189
+ "stride": [1, 1],
190
+ "padding": [1, 1],
191
+ "bias": false
192
+ },
193
+ "inputShape": [1, 128, 28, 28],
194
+ "outputShape": [1, 128, 28, 28],
195
+ "depth": 14,
196
+ "group": "layer2"
197
+ },
198
+ {
199
+ "id": "layer2_bn2",
200
+ "type": "batchNorm2d",
201
+ "name": "layer2.0.bn2",
202
+ "inputShape": [1, 128, 28, 28],
203
+ "outputShape": [1, 128, 28, 28],
204
+ "depth": 15,
205
+ "group": "layer2"
206
+ },
207
+ {
208
+ "id": "layer2_downsample",
209
+ "type": "conv2d",
210
+ "name": "layer2.0.downsample",
211
+ "params": {
212
+ "inChannels": 64,
213
+ "outChannels": 128,
214
+ "kernelSize": [1, 1],
215
+ "stride": [2, 2],
216
+ "bias": false
217
+ },
218
+ "inputShape": [1, 64, 56, 56],
219
+ "outputShape": [1, 128, 28, 28],
220
+ "depth": 16,
221
+ "group": "layer2"
222
+ },
223
+ {
224
+ "id": "layer2_add",
225
+ "type": "add",
226
+ "name": "Residual Add",
227
+ "inputShape": [1, 128, 28, 28],
228
+ "outputShape": [1, 128, 28, 28],
229
+ "depth": 17,
230
+ "group": "layer2"
231
+ },
232
+ {
233
+ "id": "avgpool",
234
+ "type": "globalAvgPool",
235
+ "name": "Global Average Pool",
236
+ "inputShape": [1, 128, 28, 28],
237
+ "outputShape": [1, 128, 1, 1],
238
+ "depth": 18
239
+ },
240
+ {
241
+ "id": "flatten",
242
+ "type": "flatten",
243
+ "name": "Flatten",
244
+ "inputShape": [1, 128, 1, 1],
245
+ "outputShape": [1, 128],
246
+ "depth": 19
247
+ },
248
+ {
249
+ "id": "fc",
250
+ "type": "linear",
251
+ "name": "fc",
252
+ "params": {
253
+ "inFeatures": 128,
254
+ "outFeatures": 1000,
255
+ "bias": true
256
+ },
257
+ "inputShape": [1, 128],
258
+ "outputShape": [1, 1000],
259
+ "depth": 20
260
+ },
261
+ {
262
+ "id": "output",
263
+ "type": "output",
264
+ "name": "Output",
265
+ "inputShape": [1, 1000],
266
+ "depth": 21
267
+ }
268
+ ],
269
+ "edges": [
270
+ { "source": "input", "target": "conv1" },
271
+ { "source": "conv1", "target": "bn1" },
272
+ { "source": "bn1", "target": "relu1" },
273
+ { "source": "relu1", "target": "maxpool" },
274
+ { "source": "maxpool", "target": "layer1_conv1" },
275
+ { "source": "layer1_conv1", "target": "layer1_bn1" },
276
+ { "source": "layer1_bn1", "target": "layer1_relu1" },
277
+ { "source": "layer1_relu1", "target": "layer1_conv2" },
278
+ { "source": "layer1_conv2", "target": "layer1_bn2" },
279
+ { "source": "layer1_bn2", "target": "layer1_add" },
280
+ { "source": "maxpool", "target": "layer1_add" },
281
+ { "source": "layer1_add", "target": "layer2_conv1" },
282
+ { "source": "layer2_conv1", "target": "layer2_bn1" },
283
+ { "source": "layer2_bn1", "target": "layer2_relu1" },
284
+ { "source": "layer2_relu1", "target": "layer2_conv2" },
285
+ { "source": "layer2_conv2", "target": "layer2_bn2" },
286
+ { "source": "layer1_add", "target": "layer2_downsample" },
287
+ { "source": "layer2_bn2", "target": "layer2_add" },
288
+ { "source": "layer2_downsample", "target": "layer2_add" },
289
+ { "source": "layer2_add", "target": "avgpool" },
290
+ { "source": "avgpool", "target": "flatten" },
291
+ { "source": "flatten", "target": "fc" },
292
+ { "source": "fc", "target": "output" }
293
+ ],
294
+ "subgraphs": [
295
+ {
296
+ "id": "stem_group",
297
+ "name": "Stem",
298
+ "type": "sequential",
299
+ "nodes": ["conv1", "bn1", "relu1", "maxpool"],
300
+ "color": "#4CAF50"
301
+ },
302
+ {
303
+ "id": "layer1_group",
304
+ "name": "Layer 1 (ResBlock)",
305
+ "type": "residual",
306
+ "nodes": ["layer1_conv1", "layer1_bn1", "layer1_relu1", "layer1_conv2", "layer1_bn2", "layer1_add"],
307
+ "color": "#2196F3"
308
+ },
309
+ {
310
+ "id": "layer2_group",
311
+ "name": "Layer 2 (ResBlock)",
312
+ "type": "residual",
313
+ "nodes": ["layer2_conv1", "layer2_bn1", "layer2_relu1", "layer2_conv2", "layer2_bn2", "layer2_downsample", "layer2_add"],
314
+ "color": "#9C27B0"
315
+ }
316
+ ]
317
+ },
318
+ "visualization": {
319
+ "layout": "layered",
320
+ "theme": "dark",
321
+ "layerSpacing": 2.5,
322
+ "nodeScale": 1.0,
323
+ "showLabels": true,
324
+ "showEdges": true,
325
+ "edgeStyle": "tube"
326
+ }
327
+ }
samples/simple_mlp.nn3d ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0.0",
3
+ "metadata": {
4
+ "name": "Simple MLP",
5
+ "description": "A simple Multi-Layer Perceptron for MNIST classification",
6
+ "framework": "pytorch",
7
+ "created": "2024-12-17T10:00:00Z",
8
+ "tags": ["mlp", "classification", "mnist"],
9
+ "inputShape": [1, 784],
10
+ "outputShape": [1, 10],
11
+ "totalParams": 669706,
12
+ "trainableParams": 669706
13
+ },
14
+ "graph": {
15
+ "nodes": [
16
+ {
17
+ "id": "input",
18
+ "type": "input",
19
+ "name": "Input",
20
+ "outputShape": [1, 784],
21
+ "depth": 0
22
+ },
23
+ {
24
+ "id": "fc1",
25
+ "type": "linear",
26
+ "name": "fc1",
27
+ "params": {
28
+ "inFeatures": 784,
29
+ "outFeatures": 512,
30
+ "bias": true
31
+ },
32
+ "inputShape": [1, 784],
33
+ "outputShape": [1, 512],
34
+ "depth": 1
35
+ },
36
+ {
37
+ "id": "relu1",
38
+ "type": "relu",
39
+ "name": "ReLU",
40
+ "inputShape": [1, 512],
41
+ "outputShape": [1, 512],
42
+ "depth": 2
43
+ },
44
+ {
45
+ "id": "dropout1",
46
+ "type": "dropout",
47
+ "name": "Dropout",
48
+ "params": {
49
+ "dropoutRate": 0.2
50
+ },
51
+ "inputShape": [1, 512],
52
+ "outputShape": [1, 512],
53
+ "depth": 3
54
+ },
55
+ {
56
+ "id": "fc2",
57
+ "type": "linear",
58
+ "name": "fc2",
59
+ "params": {
60
+ "inFeatures": 512,
61
+ "outFeatures": 256,
62
+ "bias": true
63
+ },
64
+ "inputShape": [1, 512],
65
+ "outputShape": [1, 256],
66
+ "depth": 4
67
+ },
68
+ {
69
+ "id": "relu2",
70
+ "type": "relu",
71
+ "name": "ReLU",
72
+ "inputShape": [1, 256],
73
+ "outputShape": [1, 256],
74
+ "depth": 5
75
+ },
76
+ {
77
+ "id": "dropout2",
78
+ "type": "dropout",
79
+ "name": "Dropout",
80
+ "params": {
81
+ "dropoutRate": 0.2
82
+ },
83
+ "inputShape": [1, 256],
84
+ "outputShape": [1, 256],
85
+ "depth": 6
86
+ },
87
+ {
88
+ "id": "fc3",
89
+ "type": "linear",
90
+ "name": "fc3",
91
+ "params": {
92
+ "inFeatures": 256,
93
+ "outFeatures": 10,
94
+ "bias": true
95
+ },
96
+ "inputShape": [1, 256],
97
+ "outputShape": [1, 10],
98
+ "depth": 7
99
+ },
100
+ {
101
+ "id": "softmax",
102
+ "type": "softmax",
103
+ "name": "Softmax",
104
+ "inputShape": [1, 10],
105
+ "outputShape": [1, 10],
106
+ "depth": 8
107
+ },
108
+ {
109
+ "id": "output",
110
+ "type": "output",
111
+ "name": "Output",
112
+ "inputShape": [1, 10],
113
+ "depth": 9
114
+ }
115
+ ],
116
+ "edges": [
117
+ { "source": "input", "target": "fc1", "tensorShape": [1, 784] },
118
+ { "source": "fc1", "target": "relu1", "tensorShape": [1, 512] },
119
+ { "source": "relu1", "target": "dropout1", "tensorShape": [1, 512] },
120
+ { "source": "dropout1", "target": "fc2", "tensorShape": [1, 512] },
121
+ { "source": "fc2", "target": "relu2", "tensorShape": [1, 256] },
122
+ { "source": "relu2", "target": "dropout2", "tensorShape": [1, 256] },
123
+ { "source": "dropout2", "target": "fc3", "tensorShape": [1, 256] },
124
+ { "source": "fc3", "target": "softmax", "tensorShape": [1, 10] },
125
+ { "source": "softmax", "target": "output", "tensorShape": [1, 10] }
126
+ ]
127
+ },
128
+ "visualization": {
129
+ "layout": "layered",
130
+ "theme": "dark",
131
+ "layerSpacing": 3.0,
132
+ "nodeScale": 1.0,
133
+ "showLabels": true,
134
+ "showEdges": true,
135
+ "edgeStyle": "tube"
136
+ }
137
+ }
samples/transformer_encoder.nn3d ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0.0",
3
+ "metadata": {
4
+ "name": "Transformer Encoder",
5
+ "description": "A Transformer encoder block for sequence modeling",
6
+ "framework": "pytorch",
7
+ "created": "2024-12-17T10:00:00Z",
8
+ "tags": ["transformer", "attention", "nlp", "encoder"],
9
+ "inputShape": [1, 512, 768],
10
+ "outputShape": [1, 512, 768],
11
+ "totalParams": 7087872,
12
+ "trainableParams": 7087872
13
+ },
14
+ "graph": {
15
+ "nodes": [
16
+ {
17
+ "id": "input",
18
+ "type": "input",
19
+ "name": "Input Embeddings",
20
+ "outputShape": [1, 512, 768],
21
+ "depth": 0
22
+ },
23
+ {
24
+ "id": "pos_embed",
25
+ "type": "embedding",
26
+ "name": "Positional Embedding",
27
+ "params": {
28
+ "numEmbeddings": 512,
29
+ "embeddingDim": 768
30
+ },
31
+ "inputShape": [1, 512],
32
+ "outputShape": [1, 512, 768],
33
+ "depth": 1
34
+ },
35
+ {
36
+ "id": "embed_add",
37
+ "type": "add",
38
+ "name": "Add Position",
39
+ "inputShape": [1, 512, 768],
40
+ "outputShape": [1, 512, 768],
41
+ "depth": 2
42
+ },
43
+ {
44
+ "id": "layer1_ln1",
45
+ "type": "layerNorm",
46
+ "name": "LayerNorm 1",
47
+ "params": {
48
+ "eps": 1e-6
49
+ },
50
+ "inputShape": [1, 512, 768],
51
+ "outputShape": [1, 512, 768],
52
+ "depth": 3,
53
+ "group": "encoder_layer_1"
54
+ },
55
+ {
56
+ "id": "layer1_mha",
57
+ "type": "multiHeadAttention",
58
+ "name": "Multi-Head Attention",
59
+ "params": {
60
+ "numHeads": 12,
61
+ "hiddenSize": 768,
62
+ "dropoutRate": 0.1
63
+ },
64
+ "inputShape": [1, 512, 768],
65
+ "outputShape": [1, 512, 768],
66
+ "depth": 4,
67
+ "group": "encoder_layer_1"
68
+ },
69
+ {
70
+ "id": "layer1_dropout1",
71
+ "type": "dropout",
72
+ "name": "Dropout",
73
+ "params": {
74
+ "dropoutRate": 0.1
75
+ },
76
+ "inputShape": [1, 512, 768],
77
+ "outputShape": [1, 512, 768],
78
+ "depth": 5,
79
+ "group": "encoder_layer_1"
80
+ },
81
+ {
82
+ "id": "layer1_add1",
83
+ "type": "add",
84
+ "name": "Residual Add 1",
85
+ "inputShape": [1, 512, 768],
86
+ "outputShape": [1, 512, 768],
87
+ "depth": 6,
88
+ "group": "encoder_layer_1"
89
+ },
90
+ {
91
+ "id": "layer1_ln2",
92
+ "type": "layerNorm",
93
+ "name": "LayerNorm 2",
94
+ "params": {
95
+ "eps": 1e-6
96
+ },
97
+ "inputShape": [1, 512, 768],
98
+ "outputShape": [1, 512, 768],
99
+ "depth": 7,
100
+ "group": "encoder_layer_1"
101
+ },
102
+ {
103
+ "id": "layer1_ff1",
104
+ "type": "linear",
105
+ "name": "Feed Forward 1",
106
+ "params": {
107
+ "inFeatures": 768,
108
+ "outFeatures": 3072,
109
+ "bias": true
110
+ },
111
+ "inputShape": [1, 512, 768],
112
+ "outputShape": [1, 512, 3072],
113
+ "depth": 8,
114
+ "group": "encoder_layer_1"
115
+ },
116
+ {
117
+ "id": "layer1_gelu",
118
+ "type": "gelu",
119
+ "name": "GELU",
120
+ "inputShape": [1, 512, 3072],
121
+ "outputShape": [1, 512, 3072],
122
+ "depth": 9,
123
+ "group": "encoder_layer_1"
124
+ },
125
+ {
126
+ "id": "layer1_ff2",
127
+ "type": "linear",
128
+ "name": "Feed Forward 2",
129
+ "params": {
130
+ "inFeatures": 3072,
131
+ "outFeatures": 768,
132
+ "bias": true
133
+ },
134
+ "inputShape": [1, 512, 3072],
135
+ "outputShape": [1, 512, 768],
136
+ "depth": 10,
137
+ "group": "encoder_layer_1"
138
+ },
139
+ {
140
+ "id": "layer1_dropout2",
141
+ "type": "dropout",
142
+ "name": "Dropout",
143
+ "params": {
144
+ "dropoutRate": 0.1
145
+ },
146
+ "inputShape": [1, 512, 768],
147
+ "outputShape": [1, 512, 768],
148
+ "depth": 11,
149
+ "group": "encoder_layer_1"
150
+ },
151
+ {
152
+ "id": "layer1_add2",
153
+ "type": "add",
154
+ "name": "Residual Add 2",
155
+ "inputShape": [1, 512, 768],
156
+ "outputShape": [1, 512, 768],
157
+ "depth": 12,
158
+ "group": "encoder_layer_1"
159
+ },
160
+ {
161
+ "id": "final_ln",
162
+ "type": "layerNorm",
163
+ "name": "Final LayerNorm",
164
+ "params": {
165
+ "eps": 1e-6
166
+ },
167
+ "inputShape": [1, 512, 768],
168
+ "outputShape": [1, 512, 768],
169
+ "depth": 13
170
+ },
171
+ {
172
+ "id": "output",
173
+ "type": "output",
174
+ "name": "Output",
175
+ "inputShape": [1, 512, 768],
176
+ "depth": 14
177
+ }
178
+ ],
179
+ "edges": [
180
+ { "source": "input", "target": "embed_add" },
181
+ { "source": "pos_embed", "target": "embed_add" },
182
+ { "source": "embed_add", "target": "layer1_ln1" },
183
+ { "source": "layer1_ln1", "target": "layer1_mha" },
184
+ { "source": "layer1_mha", "target": "layer1_dropout1" },
185
+ { "source": "layer1_dropout1", "target": "layer1_add1" },
186
+ { "source": "embed_add", "target": "layer1_add1" },
187
+ { "source": "layer1_add1", "target": "layer1_ln2" },
188
+ { "source": "layer1_ln2", "target": "layer1_ff1" },
189
+ { "source": "layer1_ff1", "target": "layer1_gelu" },
190
+ { "source": "layer1_gelu", "target": "layer1_ff2" },
191
+ { "source": "layer1_ff2", "target": "layer1_dropout2" },
192
+ { "source": "layer1_dropout2", "target": "layer1_add2" },
193
+ { "source": "layer1_add1", "target": "layer1_add2" },
194
+ { "source": "layer1_add2", "target": "final_ln" },
195
+ { "source": "final_ln", "target": "output" }
196
+ ],
197
+ "subgraphs": [
198
+ {
199
+ "id": "encoder_layer_1_group",
200
+ "name": "Encoder Layer 1",
201
+ "type": "attention",
202
+ "nodes": [
203
+ "layer1_ln1", "layer1_mha", "layer1_dropout1", "layer1_add1",
204
+ "layer1_ln2", "layer1_ff1", "layer1_gelu", "layer1_ff2",
205
+ "layer1_dropout2", "layer1_add2"
206
+ ],
207
+ "color": "#E91E63"
208
+ }
209
+ ]
210
+ },
211
+ "visualization": {
212
+ "layout": "layered",
213
+ "theme": "dark",
214
+ "layerSpacing": 2.0,
215
+ "nodeScale": 1.0,
216
+ "showLabels": true,
217
+ "showEdges": true,
218
+ "edgeStyle": "bezier"
219
+ }
220
+ }
src/App.css CHANGED
@@ -1,38 +1,195 @@
1
- .App {
2
- text-align: center;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  }
4
 
5
- .App-logo {
6
- height: 40vmin;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  pointer-events: none;
 
8
  }
9
 
10
- @media (prefers-reduced-motion: no-preference) {
11
- .App-logo {
12
- animation: App-logo-spin infinite 20s linear;
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  }
15
 
16
- .App-header {
17
- background-color: #282c34;
18
- min-height: 100vh;
19
- display: flex;
20
- flex-direction: column;
21
- align-items: center;
22
- justify-content: center;
23
- font-size: calc(10px + 2vmin);
24
- color: white;
25
  }
26
 
27
- .App-link {
28
- color: #61dafb;
 
 
 
 
 
 
 
 
 
29
  }
30
 
31
- @keyframes App-logo-spin {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  from {
33
- transform: rotate(0deg);
 
34
  }
35
  to {
36
- transform: rotate(360deg);
 
37
  }
38
  }
 
 
 
 
 
1
+ /*
2
+ * Oldskool Wireframe Design System
3
+ * Inspired by ctxdc.com and awwwards-winning dark interfaces
4
+ */
5
+
6
+ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500;600;700&display=swap');
7
+
8
+ :root {
9
+ /* Core Colors */
10
+ --bg-primary: #0a0a0a;
11
+ --bg-secondary: #0f0f0f;
12
+ --bg-tertiary: #141414;
13
+ --bg-card: rgba(15, 15, 15, 0.8);
14
+
15
+ /* Accent Colors */
16
+ --accent-primary: #b4ff39;
17
+ --accent-secondary: #7fff00;
18
+ --accent-dim: rgba(180, 255, 57, 0.15);
19
+ --accent-glow: rgba(180, 255, 57, 0.3);
20
+
21
+ /* Text Colors */
22
+ --text-primary: #ffffff;
23
+ --text-secondary: #a0a0a0;
24
+ --text-tertiary: #606060;
25
+ --text-muted: #404040;
26
+
27
+ /* Border Colors */
28
+ --border-primary: rgba(180, 255, 57, 0.3);
29
+ --border-secondary: rgba(255, 255, 255, 0.1);
30
+ --border-dim: rgba(255, 255, 255, 0.05);
31
+
32
+ /* Status Colors */
33
+ --success: #b4ff39;
34
+ --error: #ff4444;
35
+ --warning: #ffaa00;
36
+
37
+ /* Typography */
38
+ --font-mono: 'JetBrains Mono', 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
39
+
40
+ /* Spacing */
41
+ --space-xs: 4px;
42
+ --space-sm: 8px;
43
+ --space-md: 16px;
44
+ --space-lg: 24px;
45
+ --space-xl: 32px;
46
+
47
+ /* Border Radius */
48
+ --radius-sm: 2px;
49
+ --radius-md: 4px;
50
+ --radius-lg: 8px;
51
  }
52
 
53
+ * {
54
+ margin: 0;
55
+ padding: 0;
56
+ box-sizing: border-box;
57
+ }
58
+
59
+ html, body, #root {
60
+ width: 100%;
61
+ height: 100%;
62
+ overflow: hidden;
63
+ font-family: var(--font-mono);
64
+ font-size: 13px;
65
+ -webkit-font-smoothing: antialiased;
66
+ -moz-osx-font-smoothing: grayscale;
67
+ background: var(--bg-primary);
68
+ color: var(--text-primary);
69
+ }
70
+
71
+ .app {
72
+ width: 100%;
73
+ height: 100%;
74
+ position: relative;
75
+ background: var(--bg-primary);
76
+ }
77
+
78
+ /* Grid Background Pattern */
79
+ .app::before {
80
+ content: '';
81
+ position: absolute;
82
+ top: 0;
83
+ left: 0;
84
+ right: 0;
85
+ bottom: 0;
86
+ background-image:
87
+ linear-gradient(rgba(180, 255, 57, 0.03) 1px, transparent 1px),
88
+ linear-gradient(90deg, rgba(180, 255, 57, 0.03) 1px, transparent 1px);
89
+ background-size: 50px 50px;
90
  pointer-events: none;
91
+ z-index: 0;
92
  }
93
 
94
+ /* Scrollbar styling */
95
+ ::-webkit-scrollbar {
96
+ width: 6px;
97
+ height: 6px;
98
+ }
99
+
100
+ ::-webkit-scrollbar-track {
101
+ background: var(--bg-secondary);
102
+ }
103
+
104
+ ::-webkit-scrollbar-thumb {
105
+ background: var(--border-primary);
106
+ border-radius: var(--radius-sm);
107
+ }
108
+
109
+ ::-webkit-scrollbar-thumb:hover {
110
+ background: var(--accent-primary);
111
+ }
112
+
113
+ /* Selection styling */
114
+ ::selection {
115
+ background: var(--accent-dim);
116
+ color: var(--accent-primary);
117
  }
118
 
119
+ /* Focus styling */
120
+ :focus-visible {
121
+ outline: 1px solid var(--accent-primary);
122
+ outline-offset: 2px;
 
 
 
 
 
123
  }
124
 
125
+ /* Custom checkbox */
126
+ input[type="checkbox"] {
127
+ appearance: none;
128
+ width: 14px;
129
+ height: 14px;
130
+ border: 1px solid var(--border-primary);
131
+ border-radius: var(--radius-sm);
132
+ background: transparent;
133
+ cursor: pointer;
134
+ position: relative;
135
+ transition: all 0.15s ease;
136
  }
137
 
138
+ input[type="checkbox"]:checked {
139
+ background: var(--accent-primary);
140
+ border-color: var(--accent-primary);
141
+ }
142
+
143
+ input[type="checkbox"]:checked::after {
144
+ content: '';
145
+ position: absolute;
146
+ top: 2px;
147
+ left: 4px;
148
+ width: 4px;
149
+ height: 7px;
150
+ border: solid var(--bg-primary);
151
+ border-width: 0 2px 2px 0;
152
+ transform: rotate(45deg);
153
+ }
154
+
155
+ input[type="checkbox"]:hover {
156
+ border-color: var(--accent-primary);
157
+ }
158
+
159
+ /* Terminal cursor blink */
160
+ @keyframes blink {
161
+ 0%, 50% { opacity: 1; }
162
+ 51%, 100% { opacity: 0; }
163
+ }
164
+
165
+ /* Scan line effect */
166
+ @keyframes scanline {
167
+ 0% { transform: translateY(-100%); }
168
+ 100% { transform: translateY(100vh); }
169
+ }
170
+
171
+ /* Pulse glow */
172
+ @keyframes pulse-glow {
173
+ 0%, 100% {
174
+ box-shadow: 0 0 5px var(--accent-dim);
175
+ }
176
+ 50% {
177
+ box-shadow: 0 0 20px var(--accent-glow), 0 0 30px var(--accent-dim);
178
+ }
179
+ }
180
+
181
+ /* Tooltip animations */
182
+ @keyframes fadeIn {
183
  from {
184
+ opacity: 0;
185
+ transform: translateY(5px);
186
  }
187
  to {
188
+ opacity: 1;
189
+ transform: translateY(0);
190
  }
191
  }
192
+
193
+ .tooltip-enter {
194
+ animation: fadeIn 0.2s ease-out;
195
+ }
src/App.js DELETED
@@ -1,25 +0,0 @@
1
- import logo from './logo.svg';
2
- import './App.css';
3
-
4
- function App() {
5
- return (
6
- <div className="App">
7
- <header className="App-header">
8
- <img src={logo} className="App-logo" alt="logo" />
9
- <p>
10
- Edit <code>src/App.js</code> and save to reload.
11
- </p>
12
- <a
13
- className="App-link"
14
- href="https://reactjs.org"
15
- target="_blank"
16
- rel="noopener noreferrer"
17
- >
18
- Learn React
19
- </a>
20
- </header>
21
- </div>
22
- );
23
- }
24
-
25
- export default App;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/App.test.js DELETED
@@ -1,8 +0,0 @@
1
- import { render, screen } from '@testing-library/react';
2
- import App from './App';
3
-
4
- test('renders learn react link', () => {
5
- render(<App />);
6
- const linkElement = screen.getByText(/learn react/i);
7
- expect(linkElement).toBeInTheDocument();
8
- });
 
 
 
 
 
 
 
 
 
src/App.tsx ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import { useCallback, useMemo } from 'react';
import { DropZone, NeuralVisualizer } from './components';
import { useVisualizerStore } from './core/store';
import { LAYER_CATEGORIES, type LayerType } from '@/schema/types';
import type { NN3DModel } from '@/schema/types';
import './App.css';

/**
 * Main Application Component
 *
 * Integrates the 3D Neural Network Visualization System: adapts the store's
 * NN3DModel into the view-model shape NeuralVisualizer expects, and wires
 * the selection / upload / load-saved-model callbacks.
 */
function App() {
  const model = useVisualizerStore(state => state.model);
  const isLoading = useVisualizerStore(state => state.isLoading);
  const error = useVisualizerStore(state => state.error);
  const selectNode = useVisualizerStore(state => state.selectNode);
  const clearModel = useVisualizerStore(state => state.clearModel);
  const loadModel = useVisualizerStore(state => state.loadModel);

  // Build architecture data from the store model for the visualizer.
  // Memoized on `model`: previously this object literal was rebuilt on every
  // render, handing NeuralVisualizer a fresh identity each time and defeating
  // any React.memo / effect-dependency optimizations downstream.
  const architecture = useMemo(() => {
    if (!model) return null;
    return {
      name: model.metadata?.name || 'Model',
      framework: model.metadata?.framework || 'Unknown',
      totalParameters: model.metadata?.totalParams || 0,
      trainableParameters: model.metadata?.trainableParams,
      inputShape: model.metadata?.inputShape as number[] | null || null,
      outputShape: model.metadata?.outputShape as number[] | null || null,
      layers: model.graph.nodes.map(node => {
        // Get category from attributes (set by backend) or infer from layer type
        const backendCategory = node.attributes?.category as string | undefined;
        const layerType = node.type as LayerType;
        const category = backendCategory || LAYER_CATEGORIES[layerType] || 'other';

        // Extract parameter count from attributes or params.
        // params.totalParams may arrive as a comma-grouped string (e.g. "1,234").
        const numParameters =
          (node.attributes?.parameters as number) ||
          (node.params?.totalParams ? parseInt(String(node.params.totalParams).replace(/,/g, '')) : 0) ||
          0;

        return {
          id: node.id,
          name: node.name,
          type: node.type,
          category,
          inputShape: (node.inputShape as number[] | null) || null,
          outputShape: (node.outputShape as number[] | null) || null,
          params: node.params || {},
          numParameters,
          trainable: true,
        };
      }),
      connections: model.graph.edges.map(edge => ({
        source: edge.source,
        target: edge.target,
        tensorShape: (edge.tensorShape as number[] | null) || null,
      })),
    };
  }, [model]);

  // Handle layer selection
  const handleLayerSelect = useCallback((layerId: string | null) => {
    selectNode(layerId);
  }, [selectNode]);

  // Handle uploading a new model (clear current)
  const handleUploadNew = useCallback(() => {
    clearModel();
  }, [clearModel]);

  // Handle loading a saved model from the database.
  // NOTE(review): `arch` is untyped (any) because it round-trips through the
  // backend; consider declaring a shared Architecture type to close the gap.
  const handleLoadSavedModel = useCallback((arch: any) => {
    // Convert architecture back to NN3DModel format
    const savedModel: NN3DModel = {
      version: '1.0',
      metadata: {
        name: arch.name,
        framework: arch.framework,
        totalParams: arch.totalParameters,
        trainableParams: arch.trainableParameters,
        inputShape: arch.inputShape,
        outputShape: arch.outputShape,
      },
      graph: {
        nodes: arch.layers.map((layer: any) => ({
          id: layer.id,
          name: layer.name,
          type: layer.type,
          inputShape: layer.inputShape,
          outputShape: layer.outputShape,
          params: layer.params,
          attributes: {
            category: layer.category,
            parameters: layer.numParameters,
          },
        })),
        edges: arch.connections.map((conn: any, idx: number) => ({
          id: `edge-${idx}`,
          source: conn.source,
          target: conn.target,
          tensorShape: conn.tensorShape,
        })),
      },
    };

    loadModel(savedModel);
  }, [loadModel]);

  return (
    <div className="app">
      <DropZone>
        <NeuralVisualizer
          architecture={architecture}
          isLoading={isLoading}
          error={error}
          onLayerSelect={handleLayerSelect}
          onUploadNew={handleUploadNew}
          onLoadSavedModel={handleLoadSavedModel}
        />
      </DropZone>
    </div>
  );
}

export default App;
src/components/Scene.tsx ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Suspense } from 'react';
2
+ import { Canvas } from '@react-three/fiber';
3
+ import { Stars, GizmoHelper, GizmoViewport } from '@react-three/drei';
4
+ import { LayerNodes } from './layers';
5
+ import { EdgeConnections } from './edges';
6
+ import { CameraControls, useKeyboardShortcuts } from './controls';
7
+ import { useVisualizerStore } from '@/core/store';
8
+
9
+ /**
10
+ * Loading fallback component
11
+ */
12
+ function LoadingFallback() {
13
+ return (
14
+ <mesh>
15
+ <sphereGeometry args={[0.5, 32, 32]} />
16
+ <meshStandardMaterial color="#4fc3f7" wireframe />
17
+ </mesh>
18
+ );
19
+ }
20
+
21
+ /**
22
+ * Scene lighting setup
23
+ */
24
+ function SceneLighting() {
25
+ const config = useVisualizerStore(state => state.config);
26
+ const isDark = config.theme !== 'light';
27
+
28
+ return (
29
+ <>
30
+ <ambientLight intensity={isDark ? 0.4 : 0.6} />
31
+ <directionalLight
32
+ position={[10, 20, 10]}
33
+ intensity={isDark ? 0.8 : 1}
34
+ castShadow
35
+ shadow-mapSize={[2048, 2048]}
36
+ />
37
+ <directionalLight position={[-10, 10, -10]} intensity={0.3} />
38
+ <pointLight position={[0, 10, 0]} intensity={0.5} color="#4fc3f7" />
39
+ </>
40
+ );
41
+ }
42
+
43
/**
 * Scene background, switched on config.theme:
 *  - 'light':     flat light-grey clear color
 *  - 'blueprint': dark navy clear color plus a large grid plane below the model
 *  - otherwise:   dark default — near-black clear color with a slow starfield
 */
function SceneBackground() {
  const config = useVisualizerStore(state => state.config);

  if (config.theme === 'light') {
    return <color attach="background" args={['#f0f0f0']} />;
  }

  if (config.theme === 'blueprint') {
    return (
      <>
        <color attach="background" args={['#0a1929']} />
        <gridHelper args={[100, 100, '#1e3a5f', '#0d2137']} position={[0, -10, 0]} />
      </>
    );
  }

  // Dark theme (default)
  return (
    <>
      <color attach="background" args={['#0f0f1a']} />
      <Stars radius={100} depth={50} count={2000} factor={4} fade speed={0.5} />
    </>
  );
}
70
+
71
/**
 * Keyboard shortcuts handler component.
 * Exists only so the useKeyboardShortcuts hook runs inside the Canvas tree;
 * renders nothing.
 */
function KeyboardHandler() {
  useKeyboardShortcuts();
  return null;
}
78
+
79
/**
 * Grid and helper elements: a ground-reference grid below the model and an
 * axis gizmo in the bottom-left corner. Rendered only once a model is loaded.
 */
function SceneHelpers() {
  const model = useVisualizerStore(state => state.model);

  // No model loaded yet — nothing to orient against.
  if (!model) return null;

  return (
    <>
      <gridHelper
        args={[50, 50, '#333', '#222']}
        position={[0, -15, 0]}
        rotation={[0, 0, 0]}
      />
      {/* Axis colors follow the usual X=red, Y=green, Z=blue convention */}
      <GizmoHelper alignment="bottom-left" margin={[80, 80]}>
        <GizmoViewport axisColors={['#f44336', '#4caf50', '#2196f3']} labelColor="white" />
      </GizmoHelper>
    </>
  );
}
100
+
101
/**
 * Main 3D network scene: lighting, themed background, helpers, and the
 * model content (edges + layer nodes) behind a Suspense boundary so the
 * wireframe fallback shows while node/edge components load.
 */
function NetworkScene() {
  return (
    <>
      <SceneLighting />
      <SceneBackground />
      <SceneHelpers />

      <Suspense fallback={<LoadingFallback />}>
        <EdgeConnections />
        <LayerNodes />
      </Suspense>
    </>
  );
}
118
+
119
/**
 * Main 3D Canvas component.
 * Hosts the react-three-fiber Canvas with camera defaults, a capped device
 * pixel ratio (1–2) and high-performance GL hints, then mounts camera
 * controls, keyboard handling, and the network scene.
 */
export function Scene() {
  return (
    <Canvas
      camera={{
        position: [0, 5, 20],
        fov: 60,
        near: 0.1,
        far: 1000,
      }}
      dpr={[1, 2]}
      gl={{
        antialias: true,
        alpha: false, // opaque canvas; background set by SceneBackground
        powerPreference: 'high-performance',
      }}
    >
      <CameraControls />
      <KeyboardHandler />
      <NetworkScene />
    </Canvas>
  );
}

export default Scene;
src/components/controls/CameraControls.tsx ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useRef, useEffect } from 'react';
2
+ import { useThree, useFrame } from '@react-three/fiber';
3
+ import { OrbitControls as DreiOrbitControls } from '@react-three/drei';
4
+ import * as THREE from 'three';
5
+ import { useVisualizerStore } from '@/core/store';
6
+
7
+ /**
8
+ * Camera controls component with orbit, zoom, and pan
9
+ */
10
+ export function CameraControls() {
11
+ const controlsRef = useRef<any>(null);
12
+ const model = useVisualizerStore(state => state.model);
13
+ const setCameraPosition = useVisualizerStore(state => state.setCameraPosition);
14
+ const setCameraTarget = useVisualizerStore(state => state.setCameraTarget);
15
+
16
+ const { camera: threeCamera } = useThree();
17
+
18
+ // Update store when camera moves
19
+ useFrame(() => {
20
+ if (controlsRef.current) {
21
+ const pos = threeCamera.position;
22
+ const target = controlsRef.current.target;
23
+
24
+ // Debounce updates
25
+ setCameraPosition({ x: pos.x, y: pos.y, z: pos.z });
26
+ setCameraTarget({ x: target.x, y: target.y, z: target.z });
27
+ }
28
+ });
29
+
30
+ // Reset camera when model changes
31
+ useEffect(() => {
32
+ if (model && controlsRef.current) {
33
+ // Compute camera position to frame the model
34
+ const nodeCount = model.graph.nodes.length;
35
+ const distance = Math.max(nodeCount * 1.5, 15);
36
+
37
+ threeCamera.position.set(0, distance * 0.3, distance);
38
+ controlsRef.current.target.set(0, -nodeCount * 0.5, 0);
39
+ controlsRef.current.update();
40
+ }
41
+ }, [model, threeCamera]);
42
+
43
+ return (
44
+ <DreiOrbitControls
45
+ ref={controlsRef}
46
+ enablePan={true}
47
+ enableZoom={true}
48
+ enableRotate={true}
49
+ minDistance={2}
50
+ maxDistance={100}
51
+ minPolarAngle={0}
52
+ maxPolarAngle={Math.PI}
53
+ dampingFactor={0.1}
54
+ rotateSpeed={0.5}
55
+ panSpeed={0.5}
56
+ zoomSpeed={0.8}
57
+ />
58
+ );
59
+ }
60
+
61
/**
 * Camera animation for transitions.
 * Returns `animateTo(position, lookAt?)`, which moves the camera toward
 * `position` over roughly 0.5s (progress advances at delta*2 per frame)
 * while optionally keeping it aimed at `lookAt`.
 */
export function useCameraAnimation() {
  const { camera } = useThree();
  // Destination position; null means no animation in flight.
  const targetRef = useRef<THREE.Vector3 | null>(null);
  // Optional point the camera should keep facing during the animation.
  const lookAtRef = useRef<THREE.Vector3 | null>(null);
  // Normalized animation progress in [0, 1].
  const progressRef = useRef(0);

  useFrame((_, delta) => {
    if (targetRef.current && progressRef.current < 1) {
      progressRef.current += delta * 2;
      const t = Math.min(progressRef.current, 1);
      const eased = 1 - Math.pow(1 - t, 3); // Ease out cubic

      // NOTE(review): lerp is applied to the already-moved position each
      // frame, so the easing compounds frame-to-frame rather than tracing a
      // pure cubic from the original start point — confirm this is intended.
      camera.position.lerp(targetRef.current, eased);

      if (lookAtRef.current) {
        camera.lookAt(lookAtRef.current);
      }

      if (t >= 1) {
        // Animation finished; clear refs so the frame callback is a no-op.
        targetRef.current = null;
        lookAtRef.current = null;
      }
    }
  });

  // Begin a new animation toward `position`, restarting progress from zero.
  const animateTo = (position: THREE.Vector3, lookAt?: THREE.Vector3) => {
    targetRef.current = position;
    lookAtRef.current = lookAt || null;
    progressRef.current = 0;
  };

  return { animateTo };
}
97
+
98
+ /**
99
+ * Focus camera on a specific node
100
+ */
101
+ export function useFocusNode() {
102
+ const computedNodes = useVisualizerStore(state => state.computedNodes);
103
+ const { animateTo } = useCameraAnimation();
104
+
105
+ const focusNode = (nodeId: string) => {
106
+ const node = computedNodes.get(nodeId);
107
+ if (!node) return;
108
+
109
+ const { x, y, z } = node.computedPosition;
110
+ const targetPos = new THREE.Vector3(x, y + 5, z + 10);
111
+ const lookAtPos = new THREE.Vector3(x, y, z);
112
+
113
+ animateTo(targetPos, lookAtPos);
114
+ };
115
+
116
+ return { focusNode };
117
+ }
src/components/controls/Interaction.tsx ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import { useCallback, useEffect, useMemo } from 'react';
import { useThree } from '@react-three/fiber';
import * as THREE from 'three';
import { useVisualizerStore } from '@/core/store';
5
+
6
+ /**
7
+ * Hook for raycasting and object picking
8
+ */
9
+ export function useRaycast() {
10
+ const { camera, scene, gl } = useThree();
11
+ const raycaster = new THREE.Raycaster();
12
+ const pointer = new THREE.Vector2();
13
+
14
+ const getIntersections = useCallback((event: MouseEvent | PointerEvent) => {
15
+ const rect = gl.domElement.getBoundingClientRect();
16
+ pointer.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
17
+ pointer.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
18
+
19
+ raycaster.setFromCamera(pointer, camera);
20
+ return raycaster.intersectObjects(scene.children, true);
21
+ }, [camera, scene, gl]);
22
+
23
+ return { getIntersections };
24
+ }
25
+
26
+ /**
27
+ * Keyboard shortcuts handler
28
+ */
29
+ export function useKeyboardShortcuts() {
30
+ const resetCamera = useVisualizerStore(state => state.resetCamera);
31
+ const selectNode = useVisualizerStore(state => state.selectNode);
32
+ const selection = useVisualizerStore(state => state.selection);
33
+ const computedNodes = useVisualizerStore(state => state.computedNodes);
34
+ const updateConfig = useVisualizerStore(state => state.updateConfig);
35
+ const config = useVisualizerStore(state => state.config);
36
+
37
+ useEffect(() => {
38
+ const handleKeyDown = (event: KeyboardEvent) => {
39
+ switch (event.key) {
40
+ case 'Escape':
41
+ // Deselect current selection
42
+ selectNode(null);
43
+ break;
44
+
45
+ case 'r':
46
+ case 'R':
47
+ // Reset camera
48
+ if (!event.ctrlKey && !event.metaKey) {
49
+ resetCamera();
50
+ }
51
+ break;
52
+
53
+ case 'l':
54
+ case 'L':
55
+ // Toggle labels
56
+ updateConfig({ showLabels: !config.showLabels });
57
+ break;
58
+
59
+ case 'e':
60
+ case 'E':
61
+ // Toggle edges
62
+ updateConfig({ showEdges: !config.showEdges });
63
+ break;
64
+
65
+ case 'ArrowUp':
66
+ case 'ArrowDown': {
67
+ // Navigate between nodes
68
+ if (selection.selectedNodeId) {
69
+ const nodeIds = Array.from(computedNodes.keys());
70
+ const currentIndex = nodeIds.indexOf(selection.selectedNodeId);
71
+ const nextIndex = event.key === 'ArrowDown'
72
+ ? Math.min(currentIndex + 1, nodeIds.length - 1)
73
+ : Math.max(currentIndex - 1, 0);
74
+ selectNode(nodeIds[nextIndex]);
75
+ }
76
+ break;
77
+ }
78
+
79
+ default:
80
+ break;
81
+ }
82
+ };
83
+
84
+ window.addEventListener('keydown', handleKeyDown);
85
+ return () => window.removeEventListener('keydown', handleKeyDown);
86
+ }, [resetCamera, selectNode, selection, computedNodes, updateConfig, config]);
87
+ }
88
+
89
/**
 * Touch gesture handler for mobile.
 * Currently a no-op stub kept so call sites compile.
 */
export function useTouchGestures() {
  // Placeholder for touch gesture handling
  // Can be expanded for pinch-to-zoom, two-finger rotate, etc.
}
96
+
97
+ /**
98
+ * LOD (Level of Detail) manager based on camera distance
99
+ */
100
+ export function useLODManager() {
101
+ const { camera } = useThree();
102
+ const computedNodes = useVisualizerStore(state => state.computedNodes);
103
+ const updateNodeLOD = useVisualizerStore(state => state.updateNodeLOD);
104
+
105
+ // LOD thresholds
106
+ const LOD_DISTANCES = {
107
+ HIGH: 20, // LOD 0 (full detail) when closer than this
108
+ MEDIUM: 40, // LOD 1 (medium detail)
109
+ LOW: 80, // LOD 2 (low detail)
110
+ };
111
+
112
+ const updateLOD = useCallback(() => {
113
+ const lodMap = new Map<string, number>();
114
+ const cameraPos = camera.position;
115
+
116
+ computedNodes.forEach((node, id) => {
117
+ const nodePos = new THREE.Vector3(
118
+ node.computedPosition.x,
119
+ node.computedPosition.y,
120
+ node.computedPosition.z
121
+ );
122
+ const distance = cameraPos.distanceTo(nodePos);
123
+
124
+ let lod = 0;
125
+ if (distance > LOD_DISTANCES.LOW) {
126
+ lod = 2;
127
+ } else if (distance > LOD_DISTANCES.MEDIUM) {
128
+ lod = 1;
129
+ }
130
+
131
+ lodMap.set(id, lod);
132
+ });
133
+
134
+ updateNodeLOD(lodMap);
135
+ }, [camera, computedNodes, updateNodeLOD]);
136
+
137
+ return { updateLOD };
138
+ }
src/components/controls/index.ts ADDED
@@ -0,0 +1,2 @@
 
 
 
1
/** Barrel re-exports for the camera/interaction control hooks and components. */
export * from './CameraControls';
export * from './Interaction';
src/components/edges/EdgeConnections.tsx ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useMemo } from 'react';
2
+ import { useVisualizerStore } from '@/core/store';
3
+ import { getEdgeComponent, EdgeStyle } from './EdgeGeometry';
4
+ import { SmartConnection } from './NeuralConnection';
5
+
6
/**
 * Renders all edges/connections in the network.
 * Picks between two renderers: SmartConnection "bundle" edges when backend
 * analysis supplied per-layer neuron counts, otherwise the configured
 * geometric edge style (line/tube/arrow/bezier).
 */
export function EdgeConnections() {
  const computedEdges = useVisualizerStore(state => state.computedEdges);
  const computedNodes = useVisualizerStore(state => state.computedNodes);
  const config = useVisualizerStore(state => state.config);

  // Check if we should use enhanced neural visualization
  const useEnhancedEdges = useMemo(() => {
    // Check if any node has enhanced attributes (from backend analysis);
    // both snake_case and camelCase spellings are probed.
    for (const node of computedNodes.values()) {
      if (node.params && (
        node.params.out_features !== undefined ||
        node.params.outFeatures !== undefined ||
        node.params.out_channels !== undefined ||
        node.params.outChannels !== undefined
      )) {
        return true;
      }
    }
    return false;
  }, [computedNodes]);

  const EdgeComponent = useMemo(
    () => getEdgeComponent(config.edgeStyle as EdgeStyle || 'tube'),
    [config.edgeStyle]
  );

  // All hooks above run unconditionally; this early-out must stay below them
  // so the hook order is stable across renders.
  if (!config.showEdges) return null;

  // Enhanced neural connections
  if (useEnhancedEdges) {
    return (
      <group name="edge-connections">
        {computedEdges.map((edge, index) => {
          // Get source and target node info for neuron counts
          const sourceNode = computedNodes.get(edge.source);
          const targetNode = computedNodes.get(edge.target);

          // NOTE(review): `||` treats a legitimate 0 neuron count as missing
          // and falls back to 16 — presumably intentional, but confirm.
          const sourceNeurons = sourceNode?.params?.out_features ||
            sourceNode?.params?.outFeatures ||
            sourceNode?.params?.out_channels ||
            sourceNode?.params?.outChannels || 16;

          const targetNeurons = targetNode?.params?.in_features ||
            targetNode?.params?.inFeatures ||
            targetNode?.params?.in_channels ||
            targetNode?.params?.inChannels || 16;

          return (
            <SmartConnection
              key={edge.id || `edge-${index}`}
              sourcePosition={edge.sourcePosition}
              targetPosition={edge.targetPosition}
              sourceNeurons={typeof sourceNeurons === 'number' ? sourceNeurons : 16}
              targetNeurons={typeof targetNeurons === 'number' ? targetNeurons : 16}
              color={edge.color}
              highlighted={edge.highlighted}
              animated={true}
              style="bundle"
            />
          );
        })}
      </group>
    );
  }

  // Original edge rendering
  return (
    <group name="edge-connections">
      {computedEdges.map((edge, index) => (
        <EdgeComponent
          key={edge.id || `edge-${index}`}
          edge={edge}
          style={config.edgeStyle as EdgeStyle}
          animated={false}
        />
      ))}
    </group>
  );
}
88
+
89
+ export { EdgeConnections as Edges };
src/components/edges/EdgeGeometry.tsx ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import { useEffect, useMemo, useRef } from 'react';
import * as THREE from 'three';
import { useFrame } from '@react-three/fiber';
import { Line, QuadraticBezierLine } from '@react-three/drei';
import type { ComputedEdge } from '@/core/store';
import type { Position3D } from '@/schema/types';
7
+
8
+ /**
9
+ * Edge style type
10
+ */
11
+ export type EdgeStyle = 'line' | 'tube' | 'arrow' | 'bezier';
12
+
13
+ /**
14
+ * Props for edge components
15
+ */
16
+ export interface EdgeProps {
17
+ edge: ComputedEdge;
18
+ style?: EdgeStyle;
19
+ animated?: boolean;
20
+ flowSpeed?: number;
21
+ }
22
+
23
+ /**
24
+ * Calculate control points for bezier curves
25
+ */
26
+ function getBezierControlPoints(
27
+ start: Position3D,
28
+ end: Position3D
29
+ ): { mid1: Position3D; mid2: Position3D } {
30
+ const midY = (start.y + end.y) / 2;
31
+ const offsetZ = Math.abs(end.y - start.y) * 0.3;
32
+
33
+ return {
34
+ mid1: { x: start.x, y: midY, z: start.z + offsetZ },
35
+ mid2: { x: end.x, y: midY, z: end.z + offsetZ },
36
+ };
37
+ }
38
+
39
+ /**
40
+ * Simple line edge
41
+ */
42
+ export function LineEdge({ edge }: EdgeProps) {
43
+ const points = useMemo(() => [
44
+ new THREE.Vector3(edge.sourcePosition.x, edge.sourcePosition.y, edge.sourcePosition.z),
45
+ new THREE.Vector3(edge.targetPosition.x, edge.targetPosition.y, edge.targetPosition.z),
46
+ ], [edge.sourcePosition, edge.targetPosition]);
47
+
48
+ const color = edge.highlighted ? '#ffffff' : edge.color;
49
+ const lineWidth = edge.highlighted ? 3 : 1.5;
50
+
51
+ return (
52
+ <Line
53
+ points={points}
54
+ color={color}
55
+ lineWidth={lineWidth}
56
+ opacity={edge.visible ? 1 : 0.2}
57
+ transparent
58
+ />
59
+ );
60
+ }
61
+
62
/**
 * Bezier curve edge for smoother connections (quadratic, single control).
 */
export function BezierEdge({ edge }: EdgeProps) {
  // Endpoints and control point are pure functions of the two positions,
  // so one memo covers all three.
  const [start, end, control] = useMemo(() => {
    const s = new THREE.Vector3(
      edge.sourcePosition.x,
      edge.sourcePosition.y,
      edge.sourcePosition.z
    );
    const e = new THREE.Vector3(
      edge.targetPosition.x,
      edge.targetPosition.y,
      edge.targetPosition.z
    );
    const { mid1 } = getBezierControlPoints(edge.sourcePosition, edge.targetPosition);
    const c = new THREE.Vector3(mid1.x, mid1.y, mid1.z);
    return [s, e, c] as const;
  }, [edge.sourcePosition, edge.targetPosition]);

  return (
    <QuadraticBezierLine
      start={start}
      end={end}
      mid={control}
      color={edge.highlighted ? '#ffffff' : edge.color}
      lineWidth={edge.highlighted ? 3 : 1.5}
      opacity={edge.visible ? 1 : 0.2}
      transparent
    />
  );
}
96
+
97
+ /**
98
+ * Tube edge for 3D pipe-like connections
99
+ */
100
+ export function TubeEdge({ edge, animated = false, flowSpeed = 1 }: EdgeProps) {
101
+ const tubeRef = useRef<THREE.Mesh>(null);
102
+
103
+ // Create curve for tube
104
+ const curve = useMemo(() => {
105
+ const start = new THREE.Vector3(
106
+ edge.sourcePosition.x,
107
+ edge.sourcePosition.y,
108
+ edge.sourcePosition.z
109
+ );
110
+ const end = new THREE.Vector3(
111
+ edge.targetPosition.x,
112
+ edge.targetPosition.y,
113
+ edge.targetPosition.z
114
+ );
115
+
116
+ const { mid1, mid2 } = getBezierControlPoints(edge.sourcePosition, edge.targetPosition);
117
+ const control1 = new THREE.Vector3(mid1.x, mid1.y, mid1.z);
118
+ const control2 = new THREE.Vector3(mid2.x, mid2.y, mid2.z);
119
+
120
+ return new THREE.CubicBezierCurve3(start, control1, control2, end);
121
+ }, [edge.sourcePosition, edge.targetPosition]);
122
+
123
+ // Tube geometry
124
+ const geometry = useMemo(() => {
125
+ const radius = edge.highlighted ? 0.06 : 0.03;
126
+ return new THREE.TubeGeometry(curve, 32, radius, 8, false);
127
+ }, [curve, edge.highlighted]);
128
+
129
+ // Animate flow effect
130
+ useFrame(({ clock }) => {
131
+ if (animated && tubeRef.current) {
132
+ const material = tubeRef.current.material as THREE.MeshStandardMaterial;
133
+ if (material.map) {
134
+ material.map.offset.y = clock.getElapsedTime() * flowSpeed;
135
+ }
136
+ }
137
+ });
138
+
139
+ const color = edge.highlighted ? '#ffffff' : edge.color;
140
+
141
+ return (
142
+ <mesh ref={tubeRef} geometry={geometry}>
143
+ <meshStandardMaterial
144
+ color={color}
145
+ opacity={edge.visible ? 0.8 : 0.2}
146
+ transparent
147
+ metalness={0.3}
148
+ roughness={0.6}
149
+ />
150
+ </mesh>
151
+ );
152
+ }
153
+
154
+ /**
155
+ * Arrow edge with direction indicator
156
+ */
157
+ export function ArrowEdge({ edge }: EdgeProps) {
158
+ const start = new THREE.Vector3(
159
+ edge.sourcePosition.x,
160
+ edge.sourcePosition.y,
161
+ edge.sourcePosition.z
162
+ );
163
+ const end = new THREE.Vector3(
164
+ edge.targetPosition.x,
165
+ edge.targetPosition.y,
166
+ edge.targetPosition.z
167
+ );
168
+
169
+ const direction = end.clone().sub(start).normalize();
170
+ const arrowPosition = end.clone().sub(direction.clone().multiplyScalar(0.3));
171
+
172
+ const color = edge.highlighted ? '#ffffff' : edge.color;
173
+
174
+ // Calculate arrow rotation
175
+ const quaternion = new THREE.Quaternion();
176
+ quaternion.setFromUnitVectors(new THREE.Vector3(0, 1, 0), direction);
177
+
178
+ return (
179
+ <group>
180
+ <LineEdge edge={edge} />
181
+ {/* Arrow head */}
182
+ <mesh position={arrowPosition} quaternion={quaternion}>
183
+ <coneGeometry args={[0.1, 0.2, 8]} />
184
+ <meshStandardMaterial color={color} />
185
+ </mesh>
186
+ </group>
187
+ );
188
+ }
189
+
190
+ /**
191
+ * Factory function to get edge component by style
192
+ */
193
+ export function getEdgeComponent(style: EdgeStyle): React.ComponentType<EdgeProps> {
194
+ switch (style) {
195
+ case 'line':
196
+ return LineEdge;
197
+ case 'bezier':
198
+ return BezierEdge;
199
+ case 'tube':
200
+ return TubeEdge;
201
+ case 'arrow':
202
+ return ArrowEdge;
203
+ default:
204
+ return TubeEdge;
205
+ }
206
+ }
src/components/edges/NeuralConnection.tsx ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Neural Network Connection Visualization
 * Shows dense connections between layers like actual neural networks
 */

import { useMemo, useRef } from 'react';
import * as THREE from 'three';
import { useFrame } from '@react-three/fiber';
import type { Position3D } from '@/schema/types';

/** Props shared by every connection renderer in this module. */
export interface NeuralConnectionProps {
  // World-space position of the source layer's node
  sourcePosition: Position3D;
  // World-space position of the target layer's node
  targetPosition: Position3D;
  // Neuron counts used to scale density/thickness; components default to 16
  sourceNeurons?: number;
  targetNeurons?: number;
  color?: string;
  highlighted?: boolean;
  animated?: boolean;
  connectionDensity?: number; // 0-1, how many connections to show
  style?: 'single' | 'dense' | 'bundle';
}
22
+
23
+ /**
24
+ * Single connection line between layers
25
+ */
26
+ export function SingleConnection({
27
+ sourcePosition,
28
+ targetPosition,
29
+ color = '#ffffff',
30
+ highlighted = false,
31
+ }: NeuralConnectionProps) {
32
+ const geometry = useMemo(() => {
33
+ const points = [
34
+ new THREE.Vector3(sourcePosition.x, sourcePosition.y, sourcePosition.z),
35
+ new THREE.Vector3(targetPosition.x, targetPosition.y, targetPosition.z),
36
+ ];
37
+ return new THREE.BufferGeometry().setFromPoints(points);
38
+ }, [sourcePosition, targetPosition]);
39
+
40
+ return (
41
+ <primitive object={new THREE.Line(geometry, new THREE.LineBasicMaterial({
42
+ color: highlighted ? '#ffffff' : color,
43
+ transparent: true,
44
+ opacity: highlighted ? 1 : 0.5,
45
+ }))} />
46
+ );
47
+ }
48
+
49
/**
 * Dense connection bundle showing multiple lines.
 * Draws up to `maxConnections` segments between randomized points on the
 * facing sides of the two layers, batched into one lineSegments draw call.
 */
export function DenseConnection({
  sourcePosition,
  targetPosition,
  sourceNeurons = 16,
  targetNeurons = 16,
  color = '#4488ff',
  highlighted = false,
  animated = true,
  connectionDensity = 0.3,
}: NeuralConnectionProps) {
  const groupRef = useRef<THREE.Group>(null);
  const materialRef = useRef<THREE.LineBasicMaterial>(null);

  // Limit connections for performance
  const maxConnections = 100;
  const numConnections = Math.min(
    Math.floor(sourceNeurons * targetNeurons * connectionDensity),
    maxConnections
  );

  // Generate connection lines.
  // NOTE(review): endpoints use Math.random() inside useMemo, so the layout
  // re-randomizes whenever any dependency changes and is not reproducible
  // across mounts — confirm a seeded RNG isn't wanted here.
  const lines = useMemo(() => {
    const connections: { start: THREE.Vector3; end: THREE.Vector3 }[] = [];

    // Calculate source and target layer bounds (spread grows with sqrt of
    // the neuron count, capped at 0.4)
    const sourceSpread = Math.min(Math.sqrt(sourceNeurons) * 0.1, 0.4);
    const targetSpread = Math.min(Math.sqrt(targetNeurons) * 0.1, 0.4);

    for (let i = 0; i < numConnections; i++) {
      // Random positions within layer bounds
      const srcOffset = {
        y: (Math.random() - 0.5) * sourceSpread,
        z: (Math.random() - 0.5) * sourceSpread * 0.5,
      };
      const tgtOffset = {
        y: (Math.random() - 0.5) * targetSpread,
        z: (Math.random() - 0.5) * targetSpread * 0.5,
      };

      connections.push({
        start: new THREE.Vector3(
          sourcePosition.x + 0.3, // Offset from layer center
          sourcePosition.y + srcOffset.y,
          sourcePosition.z + srcOffset.z
        ),
        end: new THREE.Vector3(
          targetPosition.x - 0.3, // Offset from layer center
          targetPosition.y + tgtOffset.y,
          targetPosition.z + tgtOffset.z
        ),
      });
    }

    return connections;
  }, [sourcePosition, targetPosition, sourceNeurons, targetNeurons, numConnections]);

  // Animate opacity: slow sine pulse unless highlighted (then a steady 0.8)
  useFrame((state) => {
    if (animated && materialRef.current) {
      const pulse = Math.sin(state.clock.elapsedTime * 2) * 0.1 + 0.3;
      materialRef.current.opacity = highlighted ? 0.8 : pulse;
    }
  });

  // Create buffer geometry for all lines (two vertices per segment)
  const geometry = useMemo(() => {
    const positions: number[] = [];

    lines.forEach(line => {
      positions.push(line.start.x, line.start.y, line.start.z);
      positions.push(line.end.x, line.end.y, line.end.z);
    });

    const geo = new THREE.BufferGeometry();
    geo.setAttribute('position', new THREE.Float32BufferAttribute(positions, 3));
    return geo;
  }, [lines]);

  return (
    <group ref={groupRef}>
      <lineSegments geometry={geometry}>
        <lineBasicMaterial
          ref={materialRef}
          color={highlighted ? '#ffffff' : color}
          transparent
          opacity={0.3}
          depthWrite={false}
        />
      </lineSegments>
    </group>
  );
}
144
+
145
+ /**
146
+ * Bundled connection - shows as a tube/pipe
147
+ */
148
+ export function BundledConnection({
149
+ sourcePosition,
150
+ targetPosition,
151
+ sourceNeurons = 16,
152
+ targetNeurons = 16,
153
+ color = '#4488ff',
154
+ highlighted = false,
155
+ animated = true,
156
+ }: NeuralConnectionProps) {
157
+ const meshRef = useRef<THREE.Mesh>(null);
158
+
159
+ // Calculate bundle thickness based on connection count
160
+ const connectionStrength = Math.log2(Math.min(sourceNeurons, targetNeurons) + 1) * 0.02;
161
+ const thickness = Math.max(0.02, Math.min(connectionStrength, 0.1));
162
+
163
+ // Create tube path
164
+ const curve = useMemo(() => {
165
+ const start = new THREE.Vector3(sourcePosition.x, sourcePosition.y, sourcePosition.z);
166
+ const end = new THREE.Vector3(targetPosition.x, targetPosition.y, targetPosition.z);
167
+
168
+ // Bezier control points for smooth curve
169
+ const midX = (start.x + end.x) / 2;
170
+ const control1 = new THREE.Vector3(midX, start.y, start.z);
171
+ const control2 = new THREE.Vector3(midX, end.y, end.z);
172
+
173
+ return new THREE.CubicBezierCurve3(start, control1, control2, end);
174
+ }, [sourcePosition, targetPosition]);
175
+
176
+ const geometry = useMemo(() => {
177
+ return new THREE.TubeGeometry(curve, 20, thickness, 8, false);
178
+ }, [curve, thickness]);
179
+
180
+ // Animate flow effect
181
+ useFrame((state) => {
182
+ if (animated && meshRef.current) {
183
+ const material = meshRef.current.material as THREE.MeshStandardMaterial;
184
+ if (material.emissiveIntensity !== undefined) {
185
+ material.emissiveIntensity = Math.sin(state.clock.elapsedTime * 3) * 0.2 + 0.3;
186
+ }
187
+ }
188
+ });
189
+
190
+ const baseColor = new THREE.Color(color);
191
+
192
+ return (
193
+ <mesh ref={meshRef} geometry={geometry}>
194
+ <meshStandardMaterial
195
+ color={highlighted ? '#ffffff' : baseColor}
196
+ emissive={baseColor}
197
+ emissiveIntensity={0.3}
198
+ transparent
199
+ opacity={highlighted ? 0.9 : 0.6}
200
+ metalness={0.3}
201
+ roughness={0.7}
202
+ />
203
+ </mesh>
204
+ );
205
+ }
206
+
207
+ /**
208
+ * Flow particles along connection
209
+ */
210
+ export function FlowParticles({
211
+ sourcePosition,
212
+ targetPosition,
213
+ color = '#ffffff',
214
+ particleCount = 5,
215
+ speed = 1,
216
+ }: {
217
+ sourcePosition: Position3D;
218
+ targetPosition: Position3D;
219
+ color?: string;
220
+ particleCount?: number;
221
+ speed?: number;
222
+ }) {
223
+ const particlesRef = useRef<THREE.Points>(null);
224
+
225
+ // Create particle positions
226
+ const { positions, offsets } = useMemo(() => {
227
+ const pos = new Float32Array(particleCount * 3);
228
+ const off = new Float32Array(particleCount);
229
+
230
+ for (let i = 0; i < particleCount; i++) {
231
+ off[i] = i / particleCount; // Spread particles along path
232
+
233
+ // Initial positions will be updated in useFrame
234
+ pos[i * 3] = sourcePosition.x;
235
+ pos[i * 3 + 1] = sourcePosition.y;
236
+ pos[i * 3 + 2] = sourcePosition.z;
237
+ }
238
+
239
+ return { positions: pos, offsets: off };
240
+ }, [sourcePosition, particleCount]);
241
+
242
+ const geometry = useMemo(() => {
243
+ const geo = new THREE.BufferGeometry();
244
+ geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
245
+ return geo;
246
+ }, [positions]);
247
+
248
+ // Animate particles
249
+ useFrame((state) => {
250
+ if (particlesRef.current) {
251
+ const posAttr = particlesRef.current.geometry.getAttribute('position') as THREE.BufferAttribute;
252
+
253
+ for (let i = 0; i < particleCount; i++) {
254
+ // Calculate t along path (0 to 1)
255
+ const t = ((state.clock.elapsedTime * speed + offsets[i]) % 1);
256
+
257
+ // Lerp position
258
+ posAttr.setXYZ(
259
+ i,
260
+ sourcePosition.x + (targetPosition.x - sourcePosition.x) * t,
261
+ sourcePosition.y + (targetPosition.y - sourcePosition.y) * t,
262
+ sourcePosition.z + (targetPosition.z - sourcePosition.z) * t
263
+ );
264
+ }
265
+
266
+ posAttr.needsUpdate = true;
267
+ }
268
+ });
269
+
270
+ return (
271
+ <points ref={particlesRef} geometry={geometry}>
272
+ <pointsMaterial
273
+ color={color}
274
+ size={0.05}
275
+ transparent
276
+ opacity={0.8}
277
+ sizeAttenuation
278
+ />
279
+ </points>
280
+ );
281
+ }
282
+
283
+ /**
284
+ * Smart connection that chooses visualization based on layer sizes
285
+ */
286
+ export function SmartConnection({
287
+ sourcePosition,
288
+ targetPosition,
289
+ sourceNeurons = 16,
290
+ targetNeurons = 16,
291
+ color = '#4488ff',
292
+ highlighted = false,
293
+ animated = true,
294
+ style = 'bundle',
295
+ }: NeuralConnectionProps) {
296
+ const totalConnections = sourceNeurons * targetNeurons;
297
+
298
+ // Choose visualization based on connection count
299
+ if (style === 'single' || totalConnections < 50) {
300
+ return (
301
+ <SingleConnection
302
+ sourcePosition={sourcePosition}
303
+ targetPosition={targetPosition}
304
+ color={color}
305
+ highlighted={highlighted}
306
+ />
307
+ );
308
+ }
309
+
310
+ if (style === 'dense' || totalConnections < 500) {
311
+ return (
312
+ <>
313
+ <DenseConnection
314
+ sourcePosition={sourcePosition}
315
+ targetPosition={targetPosition}
316
+ sourceNeurons={sourceNeurons}
317
+ targetNeurons={targetNeurons}
318
+ color={color}
319
+ highlighted={highlighted}
320
+ animated={animated}
321
+ connectionDensity={0.2}
322
+ />
323
+ {animated && (
324
+ <FlowParticles
325
+ sourcePosition={sourcePosition}
326
+ targetPosition={targetPosition}
327
+ color={color}
328
+ particleCount={3}
329
+ speed={0.5}
330
+ />
331
+ )}
332
+ </>
333
+ );
334
+ }
335
+
336
+ // For very dense connections, use bundled representation
337
+ return (
338
+ <>
339
+ <BundledConnection
340
+ sourcePosition={sourcePosition}
341
+ targetPosition={targetPosition}
342
+ sourceNeurons={sourceNeurons}
343
+ targetNeurons={targetNeurons}
344
+ color={color}
345
+ highlighted={highlighted}
346
+ animated={animated}
347
+ />
348
+ {animated && (
349
+ <FlowParticles
350
+ sourcePosition={sourcePosition}
351
+ targetPosition={targetPosition}
352
+ color="#ffffff"
353
+ particleCount={5}
354
+ speed={0.8}
355
+ />
356
+ )}
357
+ </>
358
+ );
359
+ }
src/components/edges/index.ts ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ export * from './EdgeGeometry';
2
+ export * from './EdgeConnections';