Spaces:
Build error
Build error
Commit ·
8a01471
1
Parent(s): cd12c95
added files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +54 -0
- .gitignore +17 -22
- Dockerfile +39 -0
- LICENSE +21 -0
- backend/.dockerignore +32 -0
- backend/Dockerfile +40 -0
- backend/README.md +185 -0
- backend/app/__init__.py +1 -0
- backend/app/database.py +163 -0
- backend/app/main.py +1082 -0
- backend/app/model_analyzer.py +714 -0
- backend/requirements.txt +10 -0
- backend/start.bat +23 -0
- backend/start.sh +23 -0
- docker-compose.yml +60 -0
- docker-start.bat +43 -0
- docker-start.sh +38 -0
- exporters/python/README.md +138 -0
- exporters/python/nn3d_exporter/__init__.py +28 -0
- exporters/python/nn3d_exporter/onnx_exporter.py +371 -0
- exporters/python/nn3d_exporter/pytorch_exporter.py +434 -0
- exporters/python/nn3d_exporter/schema.py +316 -0
- exporters/python/pyproject.toml +61 -0
- files_to_commit.txt +0 -0
- index.html +33 -0
- nginx.conf +47 -0
- package-lock.json +0 -0
- package.json +48 -34
- public/favicon.ico +0 -0
- public/favicon.svg +48 -0
- public/index.html +0 -43
- public/logo192.png +0 -0
- public/logo512.png +0 -0
- public/manifest.json +0 -25
- public/robots.txt +0 -3
- samples/cnn_resnet.nn3d +327 -0
- samples/simple_mlp.nn3d +137 -0
- samples/transformer_encoder.nn3d +220 -0
- src/App.css +179 -22
- src/App.js +0 -25
- src/App.test.js +0 -8
- src/App.tsx +124 -0
- src/components/Scene.tsx +145 -0
- src/components/controls/CameraControls.tsx +117 -0
- src/components/controls/Interaction.tsx +138 -0
- src/components/controls/index.ts +2 -0
- src/components/edges/EdgeConnections.tsx +89 -0
- src/components/edges/EdgeGeometry.tsx +206 -0
- src/components/edges/NeuralConnection.tsx +359 -0
- src/components/edges/index.ts +2 -0
.dockerignore
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dependencies
|
| 2 |
+
node_modules/
|
| 3 |
+
backend/venv/
|
| 4 |
+
backend/__pycache__/
|
| 5 |
+
backend/app/__pycache__/
|
| 6 |
+
|
| 7 |
+
# Build outputs
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
*.egg-info/
|
| 11 |
+
|
| 12 |
+
# IDE
|
| 13 |
+
.vscode/
|
| 14 |
+
.idea/
|
| 15 |
+
*.swp
|
| 16 |
+
*.swo
|
| 17 |
+
|
| 18 |
+
# Git
|
| 19 |
+
.git/
|
| 20 |
+
.gitignore
|
| 21 |
+
|
| 22 |
+
# Environment
|
| 23 |
+
.env
|
| 24 |
+
.env.local
|
| 25 |
+
.env.*.local
|
| 26 |
+
*.env
|
| 27 |
+
|
| 28 |
+
# Logs
|
| 29 |
+
*.log
|
| 30 |
+
npm-debug.log*
|
| 31 |
+
|
| 32 |
+
# OS
|
| 33 |
+
.DS_Store
|
| 34 |
+
Thumbs.db
|
| 35 |
+
|
| 36 |
+
# Test files
|
| 37 |
+
coverage/
|
| 38 |
+
.nyc_output/
|
| 39 |
+
|
| 40 |
+
# Docker
|
| 41 |
+
Dockerfile*
|
| 42 |
+
docker-compose*.yml
|
| 43 |
+
.dockerignore
|
| 44 |
+
|
| 45 |
+
# Database (will be created in container)
|
| 46 |
+
backend/models.db
|
| 47 |
+
*.db
|
| 48 |
+
|
| 49 |
+
# Documentation
|
| 50 |
+
*.md
|
| 51 |
+
LICENSE
|
| 52 |
+
|
| 53 |
+
# Samples (mounted separately)
|
| 54 |
+
# samples/
|
.gitignore
CHANGED
|
@@ -1,23 +1,18 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
/
|
| 5 |
-
/.pnp
|
| 6 |
-
.pnp.js
|
| 7 |
-
|
| 8 |
-
# testing
|
| 9 |
-
/coverage
|
| 10 |
-
|
| 11 |
-
# production
|
| 12 |
-
/build
|
| 13 |
-
|
| 14 |
-
# misc
|
| 15 |
.DS_Store
|
| 16 |
-
.
|
| 17 |
-
.
|
| 18 |
-
.env
|
| 19 |
-
.env.
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore Python virtual environments
|
| 2 |
+
backend/venv/
|
| 3 |
+
node_modules/
|
| 4 |
+
dist/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
.DS_Store
|
| 6 |
+
*.local
|
| 7 |
+
*.log
|
| 8 |
+
.env
|
| 9 |
+
.env.*
|
| 10 |
+
!.env.example
|
| 11 |
+
coverage/
|
| 12 |
+
.nyc_output/
|
| 13 |
+
*.egg-info/
|
| 14 |
+
__pycache__/
|
| 15 |
+
*.pyc
|
| 16 |
+
.pytest_cache/
|
| 17 |
+
build/
|
| 18 |
+
*.nn3d.bak
|
Dockerfile
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Frontend Dockerfile - React/Vite NN3D Visualizer
|
| 2 |
+
FROM node:20-alpine AS builder
|
| 3 |
+
|
| 4 |
+
# Set working directory
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Copy package files
|
| 8 |
+
COPY package.json package-lock.json ./
|
| 9 |
+
|
| 10 |
+
# Install dependencies
|
| 11 |
+
RUN npm ci
|
| 12 |
+
|
| 13 |
+
# Copy source code
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
# Set environment variable for API URL (nginx proxy)
|
| 17 |
+
ENV VITE_API_URL=/api
|
| 18 |
+
|
| 19 |
+
# Build the application
|
| 20 |
+
RUN npm run build
|
| 21 |
+
|
| 22 |
+
# Production stage - serve with nginx
|
| 23 |
+
FROM nginx:alpine
|
| 24 |
+
|
| 25 |
+
# Copy custom nginx config
|
| 26 |
+
COPY nginx.conf /etc/nginx/conf.d/default.conf
|
| 27 |
+
|
| 28 |
+
# Copy built assets from builder
|
| 29 |
+
COPY --from=builder /app/dist /usr/share/nginx/html
|
| 30 |
+
|
| 31 |
+
# Expose port
|
| 32 |
+
EXPOSE 80
|
| 33 |
+
|
| 34 |
+
# Health check
|
| 35 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 36 |
+
CMD wget --no-verbose --tries=1 --spider http://localhost:80/ || exit 1
|
| 37 |
+
|
| 38 |
+
# Start nginx
|
| 39 |
+
CMD ["nginx", "-g", "daemon off;"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 3D Deep Learning Model Visualizer
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
backend/.dockerignore
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Virtual environment
|
| 2 |
+
venv/
|
| 3 |
+
.venv/
|
| 4 |
+
env/
|
| 5 |
+
|
| 6 |
+
# Python cache
|
| 7 |
+
__pycache__/
|
| 8 |
+
*.py[cod]
|
| 9 |
+
*$py.class
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Database (will be created fresh)
|
| 13 |
+
models.db
|
| 14 |
+
*.db
|
| 15 |
+
|
| 16 |
+
# IDE
|
| 17 |
+
.vscode/
|
| 18 |
+
.idea/
|
| 19 |
+
|
| 20 |
+
# Logs
|
| 21 |
+
*.log
|
| 22 |
+
|
| 23 |
+
# Environment files
|
| 24 |
+
.env
|
| 25 |
+
.env.*
|
| 26 |
+
|
| 27 |
+
# Documentation
|
| 28 |
+
*.md
|
| 29 |
+
|
| 30 |
+
# Scripts (not needed in container)
|
| 31 |
+
start.bat
|
| 32 |
+
start.sh
|
backend/Dockerfile
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend Dockerfile - FastAPI Model Analyzer
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# Set working directory
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Install system dependencies for PyTorch and ONNX
|
| 8 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 9 |
+
build-essential \
|
| 10 |
+
libgomp1 \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
# Copy requirements first for better caching
|
| 14 |
+
COPY requirements.txt .
|
| 15 |
+
|
| 16 |
+
# Install Python dependencies
|
| 17 |
+
# Using CPU-only PyTorch for smaller image size
|
| 18 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 19 |
+
pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
|
| 20 |
+
pip install --no-cache-dir -r requirements.txt
|
| 21 |
+
|
| 22 |
+
# Copy application code
|
| 23 |
+
COPY app/ ./app/
|
| 24 |
+
|
| 25 |
+
# Create directory for database
|
| 26 |
+
RUN mkdir -p /app/data
|
| 27 |
+
|
| 28 |
+
# Environment variables
|
| 29 |
+
ENV PYTHONUNBUFFERED=1
|
| 30 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 31 |
+
|
| 32 |
+
# Expose port
|
| 33 |
+
EXPOSE 8000
|
| 34 |
+
|
| 35 |
+
# Health check
|
| 36 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 37 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 38 |
+
|
| 39 |
+
# Run the application
|
| 40 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
backend/README.md
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NN3D Visualizer Backend
|
| 2 |
+
|
| 3 |
+
Python microservice for analyzing neural network model architectures.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **PyTorch Model Analysis**: Extracts architecture information from `.pt`, `.pth`, `.ckpt`, `.bin` files
|
| 8 |
+
- **ONNX Model Analysis**: Parses ONNX model graphs
|
| 9 |
+
- **Shape Inference**: Traces model execution to capture input/output shapes
|
| 10 |
+
- **Layer Type Detection**: Identifies layer types from weight names and shapes
|
| 11 |
+
|
| 12 |
+
## Requirements
|
| 13 |
+
|
| 14 |
+
- Python 3.9+
|
| 15 |
+
- PyTorch 2.0+
|
| 16 |
+
|
| 17 |
+
## Quick Start
|
| 18 |
+
|
| 19 |
+
### Windows
|
| 20 |
+
|
| 21 |
+
```batch
|
| 22 |
+
cd backend
|
| 23 |
+
start.bat
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### Linux/Mac
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
cd backend
|
| 30 |
+
chmod +x start.sh
|
| 31 |
+
./start.sh
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### Manual Setup
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
cd backend
|
| 38 |
+
python -m venv venv
|
| 39 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 40 |
+
pip install -r requirements.txt
|
| 41 |
+
python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## API Endpoints
|
| 45 |
+
|
| 46 |
+
### Health Check
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
GET /health
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
Returns server status and PyTorch version.
|
| 53 |
+
|
| 54 |
+
### Analyze PyTorch Model
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
POST /analyze
|
| 58 |
+
Content-Type: multipart/form-data
|
| 59 |
+
|
| 60 |
+
file: <model_file>
|
| 61 |
+
input_shape: (optional) comma-separated integers, e.g., "1,3,224,224"
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### Analyze ONNX Model
|
| 65 |
+
|
| 66 |
+
```
|
| 67 |
+
POST /analyze/onnx
|
| 68 |
+
Content-Type: multipart/form-data
|
| 69 |
+
|
| 70 |
+
file: <onnx_file>
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Response Format
|
| 74 |
+
|
| 75 |
+
```json
|
| 76 |
+
{
|
| 77 |
+
"success": true,
|
| 78 |
+
"model_type": "full_model|state_dict|torchscript|checkpoint",
|
| 79 |
+
"architecture": {
|
| 80 |
+
"name": "model_name",
|
| 81 |
+
"framework": "pytorch",
|
| 82 |
+
"totalParameters": 1000000,
|
| 83 |
+
"trainableParameters": 1000000,
|
| 84 |
+
"inputShape": [1, 3, 224, 224],
|
| 85 |
+
"outputShape": [1, 1000],
|
| 86 |
+
"layers": [
|
| 87 |
+
{
|
| 88 |
+
"id": "layer_0",
|
| 89 |
+
"name": "conv1",
|
| 90 |
+
"type": "Conv2d",
|
| 91 |
+
"category": "convolution",
|
| 92 |
+
"inputShape": [1, 3, 224, 224],
|
| 93 |
+
"outputShape": [1, 64, 112, 112],
|
| 94 |
+
"params": {
|
| 95 |
+
"in_channels": 3,
|
| 96 |
+
"out_channels": 64,
|
| 97 |
+
"kernel_size": [7, 7],
|
| 98 |
+
"stride": [2, 2],
|
| 99 |
+
"padding": [3, 3]
|
| 100 |
+
},
|
| 101 |
+
"numParameters": 9408,
|
| 102 |
+
"trainable": true
|
| 103 |
+
}
|
| 104 |
+
],
|
| 105 |
+
"connections": [
|
| 106 |
+
{
|
| 107 |
+
"source": "layer_0",
|
| 108 |
+
"target": "layer_1",
|
| 109 |
+
"tensorShape": [1, 64, 112, 112]
|
| 110 |
+
}
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
"message": "Successfully analyzed model"
|
| 114 |
+
}
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
## Integration with Frontend
|
| 118 |
+
|
| 119 |
+
The frontend automatically detects if the backend is available:
|
| 120 |
+
|
| 121 |
+
1. Start the backend server (port 8000)
|
| 122 |
+
2. Start the frontend dev server (port 3000)
|
| 123 |
+
3. Drop a PyTorch model file - it will use the backend for analysis
|
| 124 |
+
|
| 125 |
+
If the backend is unavailable, the frontend falls back to JavaScript-based parsing.
|
| 126 |
+
|
| 127 |
+
## Supported Layer Types
|
| 128 |
+
|
| 129 |
+
### Convolution
|
| 130 |
+
|
| 131 |
+
- Conv1d, Conv2d, Conv3d
|
| 132 |
+
- ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
|
| 133 |
+
|
| 134 |
+
### Pooling
|
| 135 |
+
|
| 136 |
+
- MaxPool1d/2d/3d, AvgPool1d/2d/3d
|
| 137 |
+
- AdaptiveAvgPool1d/2d/3d, AdaptiveMaxPool1d/2d/3d
|
| 138 |
+
|
| 139 |
+
### Linear
|
| 140 |
+
|
| 141 |
+
- Linear, LazyLinear, Bilinear
|
| 142 |
+
|
| 143 |
+
### Normalization
|
| 144 |
+
|
| 145 |
+
- BatchNorm1d/2d/3d
|
| 146 |
+
- LayerNorm, GroupNorm, InstanceNorm
|
| 147 |
+
|
| 148 |
+
### Activation
|
| 149 |
+
|
| 150 |
+
- ReLU, LeakyReLU, PReLU, ELU, SELU
|
| 151 |
+
- GELU, Sigmoid, Tanh, Softmax, SiLU, Mish
|
| 152 |
+
|
| 153 |
+
### Recurrent
|
| 154 |
+
|
| 155 |
+
- RNN, LSTM, GRU
|
| 156 |
+
|
| 157 |
+
### Attention
|
| 158 |
+
|
| 159 |
+
- MultiheadAttention, Transformer, TransformerEncoder/Decoder
|
| 160 |
+
|
| 161 |
+
### Embedding
|
| 162 |
+
|
| 163 |
+
- Embedding, EmbeddingBag
|
| 164 |
+
|
| 165 |
+
### Regularization
|
| 166 |
+
|
| 167 |
+
- Dropout, Dropout2d/3d, AlphaDropout
|
| 168 |
+
|
| 169 |
+
## Architecture
|
| 170 |
+
|
| 171 |
+
```
|
| 172 |
+
backend/
|
| 173 |
+
├── app/
|
| 174 |
+
│ ├── __init__.py
|
| 175 |
+
│ ├── main.py # FastAPI application
|
| 176 |
+
│ └── model_analyzer.py # PyTorch model analysis
|
| 177 |
+
├── requirements.txt
|
| 178 |
+
├── start.bat # Windows startup script
|
| 179 |
+
├── start.sh # Linux/Mac startup script
|
| 180 |
+
└── README.md
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## Development
|
| 184 |
+
|
| 185 |
+
API documentation is available at `http://localhost:8000/docs` when the server is running.
|
backend/app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# NN3D Visualizer Backend
|
backend/app/database.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database module for storing processed model architectures.
|
| 3 |
+
Uses SQLite for simple, file-based persistence.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import sqlite3
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import List, Optional
|
| 12 |
+
from contextlib import contextmanager
|
| 13 |
+
|
| 14 |
+
# Database file path - use environment variable for Docker, fallback to local
|
| 15 |
+
DB_PATH = Path(os.environ.get("DATABASE_PATH", Path(__file__).parent.parent / "models.db"))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_connection():
|
| 19 |
+
"""Get a database connection."""
|
| 20 |
+
conn = sqlite3.connect(str(DB_PATH))
|
| 21 |
+
conn.row_factory = sqlite3.Row
|
| 22 |
+
return conn
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@contextmanager
|
| 26 |
+
def get_db():
|
| 27 |
+
"""Context manager for database connections."""
|
| 28 |
+
conn = get_connection()
|
| 29 |
+
try:
|
| 30 |
+
yield conn
|
| 31 |
+
conn.commit()
|
| 32 |
+
except Exception:
|
| 33 |
+
conn.rollback()
|
| 34 |
+
raise
|
| 35 |
+
finally:
|
| 36 |
+
conn.close()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def init_db():
|
| 40 |
+
"""Initialize the database tables."""
|
| 41 |
+
with get_db() as conn:
|
| 42 |
+
conn.execute("""
|
| 43 |
+
CREATE TABLE IF NOT EXISTS saved_models (
|
| 44 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 45 |
+
name TEXT NOT NULL,
|
| 46 |
+
framework TEXT NOT NULL,
|
| 47 |
+
total_parameters INTEGER DEFAULT 0,
|
| 48 |
+
layer_count INTEGER DEFAULT 0,
|
| 49 |
+
architecture_json TEXT NOT NULL,
|
| 50 |
+
thumbnail TEXT,
|
| 51 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 52 |
+
file_hash TEXT UNIQUE
|
| 53 |
+
)
|
| 54 |
+
""")
|
| 55 |
+
conn.execute("""
|
| 56 |
+
CREATE INDEX IF NOT EXISTS idx_name ON saved_models(name)
|
| 57 |
+
""")
|
| 58 |
+
conn.execute("""
|
| 59 |
+
CREATE INDEX IF NOT EXISTS idx_created_at ON saved_models(created_at)
|
| 60 |
+
""")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_model(
|
| 64 |
+
name: str,
|
| 65 |
+
framework: str,
|
| 66 |
+
total_parameters: int,
|
| 67 |
+
layer_count: int,
|
| 68 |
+
architecture: dict,
|
| 69 |
+
file_hash: Optional[str] = None,
|
| 70 |
+
thumbnail: Optional[str] = None
|
| 71 |
+
) -> int:
|
| 72 |
+
"""
|
| 73 |
+
Save a model architecture to the database.
|
| 74 |
+
Returns the saved model ID.
|
| 75 |
+
"""
|
| 76 |
+
architecture_json = json.dumps(architecture)
|
| 77 |
+
|
| 78 |
+
with get_db() as conn:
|
| 79 |
+
# Check if model with same hash already exists
|
| 80 |
+
if file_hash:
|
| 81 |
+
existing = conn.execute(
|
| 82 |
+
"SELECT id FROM saved_models WHERE file_hash = ?",
|
| 83 |
+
(file_hash,)
|
| 84 |
+
).fetchone()
|
| 85 |
+
|
| 86 |
+
if existing:
|
| 87 |
+
# Update existing entry
|
| 88 |
+
conn.execute("""
|
| 89 |
+
UPDATE saved_models
|
| 90 |
+
SET name = ?, framework = ?, total_parameters = ?,
|
| 91 |
+
layer_count = ?, architecture_json = ?,
|
| 92 |
+
thumbnail = ?, created_at = CURRENT_TIMESTAMP
|
| 93 |
+
WHERE file_hash = ?
|
| 94 |
+
""", (name, framework, total_parameters, layer_count,
|
| 95 |
+
architecture_json, thumbnail, file_hash))
|
| 96 |
+
return existing['id']
|
| 97 |
+
|
| 98 |
+
# Insert new entry
|
| 99 |
+
cursor = conn.execute("""
|
| 100 |
+
INSERT INTO saved_models
|
| 101 |
+
(name, framework, total_parameters, layer_count, architecture_json, file_hash, thumbnail)
|
| 102 |
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
| 103 |
+
""", (name, framework, total_parameters, layer_count,
|
| 104 |
+
architecture_json, file_hash, thumbnail))
|
| 105 |
+
|
| 106 |
+
return cursor.lastrowid
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_saved_models() -> List[dict]:
|
| 110 |
+
"""Get all saved models (metadata only, not full architecture)."""
|
| 111 |
+
with get_db() as conn:
|
| 112 |
+
rows = conn.execute("""
|
| 113 |
+
SELECT id, name, framework, total_parameters, layer_count,
|
| 114 |
+
thumbnail, created_at
|
| 115 |
+
FROM saved_models
|
| 116 |
+
ORDER BY created_at DESC
|
| 117 |
+
""").fetchall()
|
| 118 |
+
|
| 119 |
+
return [dict(row) for row in rows]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_model_by_id(model_id: int) -> Optional[dict]:
|
| 123 |
+
"""Get a specific model with full architecture."""
|
| 124 |
+
with get_db() as conn:
|
| 125 |
+
row = conn.execute("""
|
| 126 |
+
SELECT id, name, framework, total_parameters, layer_count,
|
| 127 |
+
architecture_json, thumbnail, created_at
|
| 128 |
+
FROM saved_models
|
| 129 |
+
WHERE id = ?
|
| 130 |
+
""", (model_id,)).fetchone()
|
| 131 |
+
|
| 132 |
+
if row:
|
| 133 |
+
result = dict(row)
|
| 134 |
+
result['architecture'] = json.loads(result['architecture_json'])
|
| 135 |
+
del result['architecture_json']
|
| 136 |
+
return result
|
| 137 |
+
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def delete_model(model_id: int) -> bool:
|
| 142 |
+
"""Delete a model by ID. Returns True if deleted."""
|
| 143 |
+
with get_db() as conn:
|
| 144 |
+
cursor = conn.execute(
|
| 145 |
+
"DELETE FROM saved_models WHERE id = ?",
|
| 146 |
+
(model_id,)
|
| 147 |
+
)
|
| 148 |
+
return cursor.rowcount > 0
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def model_exists_by_hash(file_hash: str) -> Optional[int]:
|
| 152 |
+
"""Check if a model with the given hash exists. Returns ID if exists."""
|
| 153 |
+
with get_db() as conn:
|
| 154 |
+
row = conn.execute(
|
| 155 |
+
"SELECT id FROM saved_models WHERE file_hash = ?",
|
| 156 |
+
(file_hash,)
|
| 157 |
+
).fetchone()
|
| 158 |
+
|
| 159 |
+
return row['id'] if row else None
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Initialize database on module load
|
| 163 |
+
init_db()
|
backend/app/main.py
ADDED
|
@@ -0,0 +1,1082 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NN3D Visualizer Backend API
|
| 3 |
+
FastAPI server for analyzing neural network models.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import hashlib
|
| 8 |
+
import tempfile
|
| 9 |
+
import traceback
|
| 10 |
+
from typing import Optional, List
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
| 14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
from fastapi.responses import JSONResponse
|
| 16 |
+
from pydantic import BaseModel
|
| 17 |
+
|
| 18 |
+
from .model_analyzer import (
|
| 19 |
+
load_pytorch_model,
|
| 20 |
+
analyze_model_structure,
|
| 21 |
+
analyze_state_dict,
|
| 22 |
+
trace_model_shapes,
|
| 23 |
+
architecture_to_dict
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from . import database as db
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
app = FastAPI(
|
| 32 |
+
title="NN3D Model Analyzer",
|
| 33 |
+
description="Backend service for analyzing neural network model architectures",
|
| 34 |
+
version="1.0.0"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Enable CORS for frontend
|
| 38 |
+
app.add_middleware(
|
| 39 |
+
CORSMiddleware,
|
| 40 |
+
allow_origins=["*"], # In production, specify exact origins
|
| 41 |
+
allow_credentials=True,
|
| 42 |
+
allow_methods=["*"],
|
| 43 |
+
allow_headers=["*"],
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AnalysisRequest(BaseModel):
|
| 48 |
+
"""Request model for analysis with sample input shape."""
|
| 49 |
+
input_shape: Optional[List[int]] = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class AnalysisResponse(BaseModel):
|
| 53 |
+
"""Response model for analysis results."""
|
| 54 |
+
success: bool
|
| 55 |
+
model_type: str
|
| 56 |
+
architecture: dict
|
| 57 |
+
message: Optional[str] = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class HealthResponse(BaseModel):
|
| 61 |
+
"""Health check response."""
|
| 62 |
+
status: str
|
| 63 |
+
pytorch_version: str
|
| 64 |
+
cuda_available: bool
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@app.get("/health", response_model=HealthResponse)
|
| 68 |
+
async def health_check():
|
| 69 |
+
"""Check server health and PyTorch availability."""
|
| 70 |
+
return HealthResponse(
|
| 71 |
+
status="healthy",
|
| 72 |
+
pytorch_version=torch.__version__,
|
| 73 |
+
cuda_available=torch.cuda.is_available()
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@app.post("/analyze", response_model=AnalysisResponse)
|
| 78 |
+
async def analyze_model(
|
| 79 |
+
file: UploadFile = File(...),
|
| 80 |
+
input_shape: Optional[str] = Query(None, description="Input shape as comma-separated ints, e.g., '1,3,224,224'")
|
| 81 |
+
):
|
| 82 |
+
"""
|
| 83 |
+
Analyze a PyTorch model file and extract architecture information.
|
| 84 |
+
|
| 85 |
+
Supports:
|
| 86 |
+
- Full model files (.pt, .pth)
|
| 87 |
+
- State dict checkpoints
|
| 88 |
+
- TorchScript models
|
| 89 |
+
- Training checkpoints with model_state_dict
|
| 90 |
+
"""
|
| 91 |
+
# Validate file extension
|
| 92 |
+
allowed_extensions = {'.pt', '.pth', '.ckpt', '.bin', '.model'}
|
| 93 |
+
file_ext = Path(file.filename).suffix.lower()
|
| 94 |
+
|
| 95 |
+
if file_ext not in allowed_extensions:
|
| 96 |
+
raise HTTPException(
|
| 97 |
+
status_code=400,
|
| 98 |
+
detail=f"Unsupported file format. Allowed: {', '.join(allowed_extensions)}"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Parse input shape if provided
|
| 102 |
+
sample_shape = None
|
| 103 |
+
if input_shape:
|
| 104 |
+
try:
|
| 105 |
+
sample_shape = [int(x.strip()) for x in input_shape.split(',')]
|
| 106 |
+
except ValueError:
|
| 107 |
+
raise HTTPException(
|
| 108 |
+
status_code=400,
|
| 109 |
+
detail="Invalid input_shape format. Use comma-separated integers, e.g., '1,3,224,224'"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Save uploaded file temporarily
|
| 113 |
+
temp_path = None
|
| 114 |
+
try:
|
| 115 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
|
| 116 |
+
content = await file.read()
|
| 117 |
+
temp_file.write(content)
|
| 118 |
+
temp_path = temp_file.name
|
| 119 |
+
|
| 120 |
+
# Load and analyze model
|
| 121 |
+
model, state_dict, model_type = load_pytorch_model(temp_path)
|
| 122 |
+
|
| 123 |
+
if model is not None:
|
| 124 |
+
# Full model available - analyze structure
|
| 125 |
+
model_name = Path(file.filename).stem
|
| 126 |
+
architecture = analyze_model_structure(model, model_name)
|
| 127 |
+
|
| 128 |
+
# Try to trace shapes if input shape provided
|
| 129 |
+
if sample_shape and model_type != 'torchscript':
|
| 130 |
+
try:
|
| 131 |
+
sample_input = torch.randn(*sample_shape)
|
| 132 |
+
architecture = trace_model_shapes(model, sample_input, architecture)
|
| 133 |
+
except Exception as e:
|
| 134 |
+
print(f"Shape tracing failed: {e}")
|
| 135 |
+
|
| 136 |
+
return AnalysisResponse(
|
| 137 |
+
success=True,
|
| 138 |
+
model_type=model_type,
|
| 139 |
+
architecture=architecture_to_dict(architecture),
|
| 140 |
+
message=f"Successfully analyzed {model_type} model"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
elif state_dict is not None:
|
| 144 |
+
# Only state dict available - infer from weights
|
| 145 |
+
model_name = Path(file.filename).stem
|
| 146 |
+
architecture = analyze_state_dict(state_dict, model_name)
|
| 147 |
+
|
| 148 |
+
return AnalysisResponse(
|
| 149 |
+
success=True,
|
| 150 |
+
model_type='state_dict',
|
| 151 |
+
architecture=architecture_to_dict(architecture),
|
| 152 |
+
message="Analyzed from state dict. Layer types inferred from weight names/shapes."
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
else:
|
| 156 |
+
raise HTTPException(
|
| 157 |
+
status_code=400,
|
| 158 |
+
detail="Could not parse model file. Unknown format."
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
except HTTPException:
|
| 162 |
+
raise
|
| 163 |
+
except Exception as e:
|
| 164 |
+
traceback.print_exc()
|
| 165 |
+
raise HTTPException(
|
| 166 |
+
status_code=500,
|
| 167 |
+
detail=f"Analysis failed: {str(e)}"
|
| 168 |
+
)
|
| 169 |
+
finally:
|
| 170 |
+
# Cleanup temp file
|
| 171 |
+
if temp_path and os.path.exists(temp_path):
|
| 172 |
+
try:
|
| 173 |
+
os.unlink(temp_path)
|
| 174 |
+
except Exception:
|
| 175 |
+
pass
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@app.post("/analyze/onnx")
|
| 179 |
+
async def analyze_onnx_model(file: UploadFile = File(...)):
|
| 180 |
+
"""
|
| 181 |
+
Analyze an ONNX model file.
|
| 182 |
+
"""
|
| 183 |
+
if not file.filename.lower().endswith('.onnx'):
|
| 184 |
+
raise HTTPException(status_code=400, detail="File must be an ONNX model (.onnx)")
|
| 185 |
+
|
| 186 |
+
temp_path = None
|
| 187 |
+
try:
|
| 188 |
+
import onnx
|
| 189 |
+
|
| 190 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.onnx') as temp_file:
|
| 191 |
+
content = await file.read()
|
| 192 |
+
temp_file.write(content)
|
| 193 |
+
temp_path = temp_file.name
|
| 194 |
+
|
| 195 |
+
model = onnx.load(temp_path)
|
| 196 |
+
graph = model.graph
|
| 197 |
+
|
| 198 |
+
layers = []
|
| 199 |
+
connections = []
|
| 200 |
+
layer_map = {}
|
| 201 |
+
|
| 202 |
+
# Process nodes
|
| 203 |
+
for i, node in enumerate(graph.node):
|
| 204 |
+
layer_id = f"layer_{i}"
|
| 205 |
+
layer_map[node.name if node.name else f"node_{i}"] = layer_id
|
| 206 |
+
|
| 207 |
+
# Map output names to layer ids
|
| 208 |
+
for output in node.output:
|
| 209 |
+
layer_map[output] = layer_id
|
| 210 |
+
|
| 211 |
+
# Extract attributes
|
| 212 |
+
params = {}
|
| 213 |
+
for attr in node.attribute:
|
| 214 |
+
if attr.type == onnx.AttributeProto.INT:
|
| 215 |
+
params[attr.name] = attr.i
|
| 216 |
+
elif attr.type == onnx.AttributeProto.INTS:
|
| 217 |
+
params[attr.name] = list(attr.ints)
|
| 218 |
+
elif attr.type == onnx.AttributeProto.FLOAT:
|
| 219 |
+
params[attr.name] = attr.f
|
| 220 |
+
elif attr.type == onnx.AttributeProto.STRING:
|
| 221 |
+
params[attr.name] = attr.s.decode('utf-8')
|
| 222 |
+
|
| 223 |
+
layers.append({
|
| 224 |
+
'id': layer_id,
|
| 225 |
+
'name': node.name if node.name else node.op_type,
|
| 226 |
+
'type': node.op_type,
|
| 227 |
+
'category': infer_onnx_category(node.op_type),
|
| 228 |
+
'inputShape': None,
|
| 229 |
+
'outputShape': None,
|
| 230 |
+
'params': params,
|
| 231 |
+
'numParameters': 0,
|
| 232 |
+
'trainable': True
|
| 233 |
+
})
|
| 234 |
+
|
| 235 |
+
# Create connections from inputs
|
| 236 |
+
for input_name in node.input:
|
| 237 |
+
if input_name in layer_map:
|
| 238 |
+
source_id = layer_map[input_name]
|
| 239 |
+
if source_id != layer_id: # Avoid self-loops
|
| 240 |
+
connections.append({
|
| 241 |
+
'source': source_id,
|
| 242 |
+
'target': layer_id,
|
| 243 |
+
'tensorShape': None
|
| 244 |
+
})
|
| 245 |
+
|
| 246 |
+
# Get input/output shapes from graph
|
| 247 |
+
input_shape = None
|
| 248 |
+
output_shape = None
|
| 249 |
+
|
| 250 |
+
if graph.input:
|
| 251 |
+
for inp in graph.input:
|
| 252 |
+
shape = []
|
| 253 |
+
if inp.type.tensor_type.shape.dim:
|
| 254 |
+
for dim in inp.type.tensor_type.shape.dim:
|
| 255 |
+
shape.append(dim.dim_value if dim.dim_value else -1)
|
| 256 |
+
if shape:
|
| 257 |
+
input_shape = shape
|
| 258 |
+
break
|
| 259 |
+
|
| 260 |
+
if graph.output:
|
| 261 |
+
for out in graph.output:
|
| 262 |
+
shape = []
|
| 263 |
+
if out.type.tensor_type.shape.dim:
|
| 264 |
+
for dim in out.type.tensor_type.shape.dim:
|
| 265 |
+
shape.append(dim.dim_value if dim.dim_value else -1)
|
| 266 |
+
if shape:
|
| 267 |
+
output_shape = shape
|
| 268 |
+
break
|
| 269 |
+
|
| 270 |
+
architecture = {
|
| 271 |
+
'name': Path(file.filename).stem,
|
| 272 |
+
'framework': 'onnx',
|
| 273 |
+
'totalParameters': 0,
|
| 274 |
+
'trainableParameters': 0,
|
| 275 |
+
'inputShape': input_shape,
|
| 276 |
+
'outputShape': output_shape,
|
| 277 |
+
'layers': layers,
|
| 278 |
+
'connections': connections
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
return AnalysisResponse(
|
| 282 |
+
success=True,
|
| 283 |
+
model_type='onnx',
|
| 284 |
+
architecture=architecture,
|
| 285 |
+
message="Successfully analyzed ONNX model"
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
except Exception as e:
|
| 289 |
+
traceback.print_exc()
|
| 290 |
+
raise HTTPException(status_code=500, detail=f"ONNX analysis failed: {str(e)}")
|
| 291 |
+
finally:
|
| 292 |
+
if temp_path and os.path.exists(temp_path):
|
| 293 |
+
try:
|
| 294 |
+
os.unlink(temp_path)
|
| 295 |
+
except Exception:
|
| 296 |
+
pass
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def infer_onnx_category(op_type: str) -> str:
|
| 300 |
+
"""Infer category from ONNX operator type."""
|
| 301 |
+
op_lower = op_type.lower()
|
| 302 |
+
|
| 303 |
+
if 'conv' in op_lower:
|
| 304 |
+
return 'convolution'
|
| 305 |
+
if 'pool' in op_lower:
|
| 306 |
+
return 'pooling'
|
| 307 |
+
if 'norm' in op_lower or 'batch' in op_lower:
|
| 308 |
+
return 'normalization'
|
| 309 |
+
if 'relu' in op_lower or 'sigmoid' in op_lower or 'tanh' in op_lower or 'softmax' in op_lower:
|
| 310 |
+
return 'activation'
|
| 311 |
+
if 'gemm' in op_lower or 'matmul' in op_lower or 'linear' in op_lower:
|
| 312 |
+
return 'linear'
|
| 313 |
+
if 'lstm' in op_lower or 'gru' in op_lower or 'rnn' in op_lower:
|
| 314 |
+
return 'recurrent'
|
| 315 |
+
if 'attention' in op_lower:
|
| 316 |
+
return 'attention'
|
| 317 |
+
if 'dropout' in op_lower:
|
| 318 |
+
return 'regularization'
|
| 319 |
+
if 'reshape' in op_lower or 'flatten' in op_lower or 'squeeze' in op_lower:
|
| 320 |
+
return 'reshape'
|
| 321 |
+
if 'add' in op_lower or 'mul' in op_lower or 'sub' in op_lower:
|
| 322 |
+
return 'arithmetic'
|
| 323 |
+
if 'concat' in op_lower or 'split' in op_lower:
|
| 324 |
+
return 'merge'
|
| 325 |
+
|
| 326 |
+
return 'other'
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# Mapping of file extensions to supported frameworks
|
| 330 |
+
SUPPORTED_FORMATS = {
|
| 331 |
+
# PyTorch formats
|
| 332 |
+
'.pt': 'pytorch',
|
| 333 |
+
'.pth': 'pytorch',
|
| 334 |
+
'.ckpt': 'pytorch',
|
| 335 |
+
'.bin': 'pytorch',
|
| 336 |
+
'.model': 'pytorch',
|
| 337 |
+
# ONNX format
|
| 338 |
+
'.onnx': 'onnx',
|
| 339 |
+
# TensorFlow/Keras formats
|
| 340 |
+
'.h5': 'keras',
|
| 341 |
+
'.hdf5': 'keras',
|
| 342 |
+
'.keras': 'keras',
|
| 343 |
+
'.pb': 'tensorflow',
|
| 344 |
+
# SafeTensors format
|
| 345 |
+
'.safetensors': 'safetensors',
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@app.post("/analyze/universal")
|
| 350 |
+
async def analyze_universal(
|
| 351 |
+
file: UploadFile = File(...),
|
| 352 |
+
input_shape: Optional[str] = Query(None, description="Input shape as comma-separated ints, e.g., '1,3,224,224'")
|
| 353 |
+
):
|
| 354 |
+
"""
|
| 355 |
+
Universal model analyzer - accepts any supported model format.
|
| 356 |
+
|
| 357 |
+
Supported formats:
|
| 358 |
+
- PyTorch: .pt, .pth, .ckpt, .bin, .model
|
| 359 |
+
- ONNX: .onnx
|
| 360 |
+
- Keras/TensorFlow: .h5, .hdf5, .keras, .pb
|
| 361 |
+
- SafeTensors: .safetensors
|
| 362 |
+
|
| 363 |
+
Returns detailed architecture information including:
|
| 364 |
+
- Layer types and names
|
| 365 |
+
- Input/output shapes for each layer
|
| 366 |
+
- Parameter counts
|
| 367 |
+
- Layer connections
|
| 368 |
+
- Model metadata
|
| 369 |
+
"""
|
| 370 |
+
filename = file.filename or "unknown"
|
| 371 |
+
file_ext = Path(filename).suffix.lower()
|
| 372 |
+
|
| 373 |
+
# Check if format is supported
|
| 374 |
+
if file_ext not in SUPPORTED_FORMATS:
|
| 375 |
+
supported = ', '.join(sorted(SUPPORTED_FORMATS.keys()))
|
| 376 |
+
raise HTTPException(
|
| 377 |
+
status_code=400,
|
| 378 |
+
detail=f"Unsupported file format '{file_ext}'. Supported formats: {supported}"
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
framework = SUPPORTED_FORMATS[file_ext]
|
| 382 |
+
|
| 383 |
+
# Parse input shape if provided
|
| 384 |
+
sample_shape = None
|
| 385 |
+
if input_shape:
|
| 386 |
+
try:
|
| 387 |
+
sample_shape = [int(x.strip()) for x in input_shape.split(',')]
|
| 388 |
+
except ValueError:
|
| 389 |
+
raise HTTPException(
|
| 390 |
+
status_code=400,
|
| 391 |
+
detail="Invalid input_shape format. Use comma-separated integers, e.g., '1,3,224,224'"
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
temp_path = None
|
| 395 |
+
try:
|
| 396 |
+
# Save uploaded file temporarily
|
| 397 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
|
| 398 |
+
content = await file.read()
|
| 399 |
+
temp_file.write(content)
|
| 400 |
+
temp_path = temp_file.name
|
| 401 |
+
|
| 402 |
+
# Route to appropriate analyzer based on framework
|
| 403 |
+
if framework == 'pytorch':
|
| 404 |
+
return await _analyze_pytorch(temp_path, filename, sample_shape)
|
| 405 |
+
elif framework == 'onnx':
|
| 406 |
+
return await _analyze_onnx(temp_path, filename)
|
| 407 |
+
elif framework == 'keras':
|
| 408 |
+
return await _analyze_keras(temp_path, filename)
|
| 409 |
+
elif framework == 'tensorflow':
|
| 410 |
+
return await _analyze_tensorflow(temp_path, filename)
|
| 411 |
+
elif framework == 'safetensors':
|
| 412 |
+
return await _analyze_safetensors(temp_path, filename)
|
| 413 |
+
else:
|
| 414 |
+
raise HTTPException(status_code=400, detail=f"Framework '{framework}' not yet implemented")
|
| 415 |
+
|
| 416 |
+
except HTTPException:
|
| 417 |
+
raise
|
| 418 |
+
except Exception as e:
|
| 419 |
+
traceback.print_exc()
|
| 420 |
+
raise HTTPException(
|
| 421 |
+
status_code=500,
|
| 422 |
+
detail=f"Analysis failed: {str(e)}"
|
| 423 |
+
)
|
| 424 |
+
finally:
|
| 425 |
+
if temp_path and os.path.exists(temp_path):
|
| 426 |
+
try:
|
| 427 |
+
os.unlink(temp_path)
|
| 428 |
+
except Exception:
|
| 429 |
+
pass
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
async def _analyze_pytorch(file_path: str, filename: str, sample_shape: Optional[List[int]] = None) -> AnalysisResponse:
|
| 433 |
+
"""Analyze PyTorch model."""
|
| 434 |
+
model, state_dict, model_type = load_pytorch_model(file_path)
|
| 435 |
+
model_name = Path(filename).stem
|
| 436 |
+
|
| 437 |
+
if model is not None:
|
| 438 |
+
architecture = analyze_model_structure(model, model_name)
|
| 439 |
+
|
| 440 |
+
# Try to trace shapes if input shape provided
|
| 441 |
+
if sample_shape and model_type != 'torchscript':
|
| 442 |
+
try:
|
| 443 |
+
sample_input = torch.randn(*sample_shape)
|
| 444 |
+
architecture = trace_model_shapes(model, sample_input, architecture)
|
| 445 |
+
except Exception as e:
|
| 446 |
+
print(f"Shape tracing failed: {e}")
|
| 447 |
+
|
| 448 |
+
return AnalysisResponse(
|
| 449 |
+
success=True,
|
| 450 |
+
model_type=model_type,
|
| 451 |
+
architecture=architecture_to_dict(architecture),
|
| 452 |
+
message=f"Successfully analyzed PyTorch {model_type}"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
elif state_dict is not None:
|
| 456 |
+
architecture = analyze_state_dict(state_dict, model_name)
|
| 457 |
+
return AnalysisResponse(
|
| 458 |
+
success=True,
|
| 459 |
+
model_type='state_dict',
|
| 460 |
+
architecture=architecture_to_dict(architecture),
|
| 461 |
+
message="Analyzed from state dict. Layer types inferred from weight names/shapes."
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
raise HTTPException(status_code=400, detail="Could not parse PyTorch model file")
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
async def _analyze_onnx(file_path: str, filename: str) -> AnalysisResponse:
|
| 468 |
+
"""Analyze ONNX model."""
|
| 469 |
+
try:
|
| 470 |
+
import onnx
|
| 471 |
+
except ImportError:
|
| 472 |
+
raise HTTPException(status_code=500, detail="ONNX library not installed. Install with: pip install onnx")
|
| 473 |
+
|
| 474 |
+
model = onnx.load(file_path)
|
| 475 |
+
graph = model.graph
|
| 476 |
+
|
| 477 |
+
layers = []
|
| 478 |
+
connections = []
|
| 479 |
+
layer_map = {}
|
| 480 |
+
total_params = 0
|
| 481 |
+
|
| 482 |
+
# Process initializers (weights) for parameter counts
|
| 483 |
+
weight_shapes = {}
|
| 484 |
+
for init in graph.initializer:
|
| 485 |
+
dims = list(init.dims)
|
| 486 |
+
weight_shapes[init.name] = dims
|
| 487 |
+
total_params += int(torch.prod(torch.tensor(dims)).item()) if dims else 0
|
| 488 |
+
|
| 489 |
+
# Process nodes
|
| 490 |
+
for i, node in enumerate(graph.node):
|
| 491 |
+
layer_id = f"layer_{i}"
|
| 492 |
+
layer_map[node.name if node.name else f"node_{i}"] = layer_id
|
| 493 |
+
|
| 494 |
+
for output in node.output:
|
| 495 |
+
layer_map[output] = layer_id
|
| 496 |
+
|
| 497 |
+
# Extract attributes
|
| 498 |
+
params = {}
|
| 499 |
+
for attr in node.attribute:
|
| 500 |
+
if attr.type == onnx.AttributeProto.INT:
|
| 501 |
+
params[attr.name] = attr.i
|
| 502 |
+
elif attr.type == onnx.AttributeProto.INTS:
|
| 503 |
+
params[attr.name] = list(attr.ints)
|
| 504 |
+
elif attr.type == onnx.AttributeProto.FLOAT:
|
| 505 |
+
params[attr.name] = round(attr.f, 6)
|
| 506 |
+
elif attr.type == onnx.AttributeProto.STRING:
|
| 507 |
+
params[attr.name] = attr.s.decode('utf-8')
|
| 508 |
+
|
| 509 |
+
# Count parameters for this layer
|
| 510 |
+
layer_params = 0
|
| 511 |
+
input_shapes = []
|
| 512 |
+
for inp_name in node.input:
|
| 513 |
+
if inp_name in weight_shapes:
|
| 514 |
+
layer_params += int(torch.prod(torch.tensor(weight_shapes[inp_name])).item())
|
| 515 |
+
input_shapes.append(weight_shapes[inp_name])
|
| 516 |
+
|
| 517 |
+
# Infer input/output shapes from value_info
|
| 518 |
+
input_shape = None
|
| 519 |
+
output_shape = None
|
| 520 |
+
|
| 521 |
+
layers.append({
|
| 522 |
+
'id': layer_id,
|
| 523 |
+
'name': node.name if node.name else f"{node.op_type}_{i}",
|
| 524 |
+
'type': node.op_type,
|
| 525 |
+
'category': infer_onnx_category(node.op_type),
|
| 526 |
+
'inputShape': input_shape,
|
| 527 |
+
'outputShape': output_shape,
|
| 528 |
+
'params': params,
|
| 529 |
+
'numParameters': layer_params,
|
| 530 |
+
'trainable': layer_params > 0
|
| 531 |
+
})
|
| 532 |
+
|
| 533 |
+
# Create connections
|
| 534 |
+
for input_name in node.input:
|
| 535 |
+
if input_name in layer_map:
|
| 536 |
+
source_id = layer_map[input_name]
|
| 537 |
+
if source_id != layer_id:
|
| 538 |
+
connections.append({
|
| 539 |
+
'source': source_id,
|
| 540 |
+
'target': layer_id,
|
| 541 |
+
'tensorShape': weight_shapes.get(input_name)
|
| 542 |
+
})
|
| 543 |
+
|
| 544 |
+
# Get model input/output shapes
|
| 545 |
+
input_shape = None
|
| 546 |
+
output_shape = None
|
| 547 |
+
|
| 548 |
+
if graph.input:
|
| 549 |
+
for inp in graph.input:
|
| 550 |
+
if inp.name not in weight_shapes: # Skip weight inputs
|
| 551 |
+
shape = []
|
| 552 |
+
if inp.type.tensor_type.shape.dim:
|
| 553 |
+
for dim in inp.type.tensor_type.shape.dim:
|
| 554 |
+
shape.append(dim.dim_value if dim.dim_value else -1)
|
| 555 |
+
if shape:
|
| 556 |
+
input_shape = shape
|
| 557 |
+
break
|
| 558 |
+
|
| 559 |
+
if graph.output:
|
| 560 |
+
for out in graph.output:
|
| 561 |
+
shape = []
|
| 562 |
+
if out.type.tensor_type.shape.dim:
|
| 563 |
+
for dim in out.type.tensor_type.shape.dim:
|
| 564 |
+
shape.append(dim.dim_value if dim.dim_value else -1)
|
| 565 |
+
if shape:
|
| 566 |
+
output_shape = shape
|
| 567 |
+
break
|
| 568 |
+
|
| 569 |
+
architecture = {
|
| 570 |
+
'name': Path(filename).stem,
|
| 571 |
+
'framework': 'onnx',
|
| 572 |
+
'totalParameters': total_params,
|
| 573 |
+
'trainableParameters': total_params,
|
| 574 |
+
'inputShape': input_shape,
|
| 575 |
+
'outputShape': output_shape,
|
| 576 |
+
'layers': layers,
|
| 577 |
+
'connections': connections
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
return AnalysisResponse(
|
| 581 |
+
success=True,
|
| 582 |
+
model_type='onnx',
|
| 583 |
+
architecture=architecture,
|
| 584 |
+
message=f"Successfully analyzed ONNX model with {len(layers)} layers"
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
async def _analyze_keras(file_path: str, filename: str) -> AnalysisResponse:
|
| 589 |
+
"""Analyze Keras/HDF5 model."""
|
| 590 |
+
try:
|
| 591 |
+
import h5py
|
| 592 |
+
except ImportError:
|
| 593 |
+
raise HTTPException(status_code=500, detail="h5py not installed. Install with: pip install h5py")
|
| 594 |
+
|
| 595 |
+
layers = []
|
| 596 |
+
connections = []
|
| 597 |
+
total_params = 0
|
| 598 |
+
|
| 599 |
+
with h5py.File(file_path, 'r') as f:
|
| 600 |
+
# Check for Keras model structure
|
| 601 |
+
if 'model_config' in f.attrs:
|
| 602 |
+
import json
|
| 603 |
+
config = json.loads(f.attrs['model_config'])
|
| 604 |
+
model_name = config.get('config', {}).get('name', Path(filename).stem)
|
| 605 |
+
|
| 606 |
+
# Parse layers from config
|
| 607 |
+
layer_configs = config.get('config', {}).get('layers', [])
|
| 608 |
+
|
| 609 |
+
for i, layer_cfg in enumerate(layer_configs):
|
| 610 |
+
layer_id = f"layer_{i}"
|
| 611 |
+
layer_class = layer_cfg.get('class_name', 'Unknown')
|
| 612 |
+
layer_config = layer_cfg.get('config', {})
|
| 613 |
+
|
| 614 |
+
# Extract parameters
|
| 615 |
+
params = {}
|
| 616 |
+
param_keys = ['units', 'filters', 'kernel_size', 'strides', 'padding',
|
| 617 |
+
'activation', 'use_bias', 'dropout', 'rate', 'axis',
|
| 618 |
+
'epsilon', 'momentum', 'input_dim', 'output_dim']
|
| 619 |
+
for key in param_keys:
|
| 620 |
+
if key in layer_config:
|
| 621 |
+
params[key] = layer_config[key]
|
| 622 |
+
|
| 623 |
+
# Infer shapes from config
|
| 624 |
+
input_shape = None
|
| 625 |
+
output_shape = None
|
| 626 |
+
if 'batch_input_shape' in layer_config:
|
| 627 |
+
input_shape = list(layer_config['batch_input_shape'])
|
| 628 |
+
|
| 629 |
+
layers.append({
|
| 630 |
+
'id': layer_id,
|
| 631 |
+
'name': layer_config.get('name', f"{layer_class}_{i}"),
|
| 632 |
+
'type': layer_class,
|
| 633 |
+
'category': _infer_keras_category(layer_class),
|
| 634 |
+
'inputShape': input_shape,
|
| 635 |
+
'outputShape': output_shape,
|
| 636 |
+
'params': params,
|
| 637 |
+
'numParameters': 0,
|
| 638 |
+
'trainable': layer_config.get('trainable', True)
|
| 639 |
+
})
|
| 640 |
+
|
| 641 |
+
# Create sequential connections
|
| 642 |
+
if i > 0:
|
| 643 |
+
connections.append({
|
| 644 |
+
'source': f"layer_{i-1}",
|
| 645 |
+
'target': layer_id,
|
| 646 |
+
'tensorShape': None
|
| 647 |
+
})
|
| 648 |
+
|
| 649 |
+
# Count parameters from model_weights
|
| 650 |
+
if 'model_weights' in f:
|
| 651 |
+
def count_h5_params(group):
|
| 652 |
+
count = 0
|
| 653 |
+
for key in group.keys():
|
| 654 |
+
item = group[key]
|
| 655 |
+
if isinstance(item, h5py.Dataset):
|
| 656 |
+
count += item.size
|
| 657 |
+
elif isinstance(item, h5py.Group):
|
| 658 |
+
count += count_h5_params(item)
|
| 659 |
+
return count
|
| 660 |
+
total_params = count_h5_params(f['model_weights'])
|
| 661 |
+
|
| 662 |
+
architecture = {
|
| 663 |
+
'name': Path(filename).stem,
|
| 664 |
+
'framework': 'keras',
|
| 665 |
+
'totalParameters': total_params,
|
| 666 |
+
'trainableParameters': total_params,
|
| 667 |
+
'inputShape': layers[0].get('inputShape') if layers else None,
|
| 668 |
+
'outputShape': None,
|
| 669 |
+
'layers': layers,
|
| 670 |
+
'connections': connections
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
return AnalysisResponse(
|
| 674 |
+
success=True,
|
| 675 |
+
model_type='keras',
|
| 676 |
+
architecture=architecture,
|
| 677 |
+
message=f"Successfully analyzed Keras model with {len(layers)} layers"
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
async def _analyze_tensorflow(file_path: str, filename: str) -> AnalysisResponse:
|
| 682 |
+
"""Analyze TensorFlow SavedModel or frozen graph."""
|
| 683 |
+
try:
|
| 684 |
+
import tensorflow as tf
|
| 685 |
+
except ImportError:
|
| 686 |
+
# Fallback: parse .pb file manually
|
| 687 |
+
return await _analyze_pb_file(file_path, filename)
|
| 688 |
+
|
| 689 |
+
layers = []
|
| 690 |
+
connections = []
|
| 691 |
+
|
| 692 |
+
# Try loading as SavedModel or GraphDef
|
| 693 |
+
try:
|
| 694 |
+
graph_def = tf.compat.v1.GraphDef()
|
| 695 |
+
with open(file_path, 'rb') as f:
|
| 696 |
+
graph_def.ParseFromString(f.read())
|
| 697 |
+
|
| 698 |
+
node_map = {}
|
| 699 |
+
for i, node in enumerate(graph_def.node):
|
| 700 |
+
layer_id = f"layer_{i}"
|
| 701 |
+
node_map[node.name] = layer_id
|
| 702 |
+
|
| 703 |
+
# Extract attributes
|
| 704 |
+
params = {}
|
| 705 |
+
for key, attr in node.attr.items():
|
| 706 |
+
if attr.HasField('i'):
|
| 707 |
+
params[key] = attr.i
|
| 708 |
+
elif attr.HasField('f'):
|
| 709 |
+
params[key] = round(attr.f, 6)
|
| 710 |
+
elif attr.HasField('s'):
|
| 711 |
+
params[key] = attr.s.decode('utf-8')
|
| 712 |
+
elif attr.HasField('shape'):
|
| 713 |
+
dims = [d.size for d in attr.shape.dim]
|
| 714 |
+
params[key] = dims
|
| 715 |
+
|
| 716 |
+
layers.append({
|
| 717 |
+
'id': layer_id,
|
| 718 |
+
'name': node.name,
|
| 719 |
+
'type': node.op,
|
| 720 |
+
'category': _infer_tf_category(node.op),
|
| 721 |
+
'inputShape': None,
|
| 722 |
+
'outputShape': None,
|
| 723 |
+
'params': params,
|
| 724 |
+
'numParameters': 0,
|
| 725 |
+
'trainable': True
|
| 726 |
+
})
|
| 727 |
+
|
| 728 |
+
# Create connections from inputs
|
| 729 |
+
for inp in node.input:
|
| 730 |
+
inp_name = inp.lstrip('^').split(':')[0]
|
| 731 |
+
if inp_name in node_map:
|
| 732 |
+
connections.append({
|
| 733 |
+
'source': node_map[inp_name],
|
| 734 |
+
'target': layer_id,
|
| 735 |
+
'tensorShape': None
|
| 736 |
+
})
|
| 737 |
+
|
| 738 |
+
architecture = {
|
| 739 |
+
'name': Path(filename).stem,
|
| 740 |
+
'framework': 'tensorflow',
|
| 741 |
+
'totalParameters': 0,
|
| 742 |
+
'trainableParameters': 0,
|
| 743 |
+
'inputShape': None,
|
| 744 |
+
'outputShape': None,
|
| 745 |
+
'layers': layers,
|
| 746 |
+
'connections': connections
|
| 747 |
+
}
|
| 748 |
+
|
| 749 |
+
return AnalysisResponse(
|
| 750 |
+
success=True,
|
| 751 |
+
model_type='tensorflow_pb',
|
| 752 |
+
architecture=architecture,
|
| 753 |
+
message=f"Successfully analyzed TensorFlow graph with {len(layers)} nodes"
|
| 754 |
+
)
|
| 755 |
+
except Exception as e:
|
| 756 |
+
raise HTTPException(status_code=400, detail=f"Failed to parse TensorFlow model: {str(e)}")
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
async def _analyze_pb_file(file_path: str, filename: str) -> AnalysisResponse:
|
| 760 |
+
"""Fallback .pb file analyzer without TensorFlow."""
|
| 761 |
+
raise HTTPException(
|
| 762 |
+
status_code=501,
|
| 763 |
+
detail="TensorFlow .pb analysis requires TensorFlow. Install with: pip install tensorflow"
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
async def _analyze_safetensors(file_path: str, filename: str) -> AnalysisResponse:
|
| 768 |
+
"""Analyze SafeTensors file."""
|
| 769 |
+
try:
|
| 770 |
+
from safetensors import safe_open
|
| 771 |
+
except ImportError:
|
| 772 |
+
raise HTTPException(
|
| 773 |
+
status_code=500,
|
| 774 |
+
detail="safetensors not installed. Install with: pip install safetensors"
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
layers = []
|
| 778 |
+
connections = []
|
| 779 |
+
total_params = 0
|
| 780 |
+
layer_groups = {}
|
| 781 |
+
|
| 782 |
+
with safe_open(file_path, framework="pt") as f:
|
| 783 |
+
tensor_names = list(f.keys())
|
| 784 |
+
|
| 785 |
+
# Group tensors by layer
|
| 786 |
+
for name in tensor_names:
|
| 787 |
+
tensor = f.get_tensor(name)
|
| 788 |
+
shape = list(tensor.shape)
|
| 789 |
+
num_params = int(tensor.numel())
|
| 790 |
+
total_params += num_params
|
| 791 |
+
|
| 792 |
+
# Extract layer name from tensor name (e.g., "encoder.layer.0.attention.weight")
|
| 793 |
+
parts = name.rsplit('.', 1)
|
| 794 |
+
layer_name = parts[0] if len(parts) > 1 else name
|
| 795 |
+
tensor_type = parts[1] if len(parts) > 1 else 'weight'
|
| 796 |
+
|
| 797 |
+
if layer_name not in layer_groups:
|
| 798 |
+
layer_groups[layer_name] = {
|
| 799 |
+
'tensors': {},
|
| 800 |
+
'params': {},
|
| 801 |
+
'total_params': 0
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
layer_groups[layer_name]['tensors'][tensor_type] = shape
|
| 805 |
+
layer_groups[layer_name]['total_params'] += num_params
|
| 806 |
+
|
| 807 |
+
# Infer params from shapes
|
| 808 |
+
if tensor_type == 'weight' and len(shape) >= 2:
|
| 809 |
+
layer_groups[layer_name]['params']['out_features'] = shape[0]
|
| 810 |
+
layer_groups[layer_name]['params']['in_features'] = shape[1]
|
| 811 |
+
|
| 812 |
+
# Convert groups to layers
|
| 813 |
+
prev_layer_id = None
|
| 814 |
+
for i, (layer_name, group) in enumerate(layer_groups.items()):
|
| 815 |
+
layer_id = f"layer_{i}"
|
| 816 |
+
|
| 817 |
+
# Infer layer type from name and shapes
|
| 818 |
+
layer_type = _infer_layer_type_from_name(layer_name, group['tensors'])
|
| 819 |
+
|
| 820 |
+
# Infer shapes
|
| 821 |
+
input_shape = None
|
| 822 |
+
output_shape = None
|
| 823 |
+
if 'in_features' in group['params']:
|
| 824 |
+
input_shape = [-1, group['params']['in_features']]
|
| 825 |
+
if 'out_features' in group['params']:
|
| 826 |
+
output_shape = [-1, group['params']['out_features']]
|
| 827 |
+
|
| 828 |
+
layers.append({
|
| 829 |
+
'id': layer_id,
|
| 830 |
+
'name': layer_name,
|
| 831 |
+
'type': layer_type,
|
| 832 |
+
'category': _infer_category_from_type(layer_type),
|
| 833 |
+
'inputShape': input_shape,
|
| 834 |
+
'outputShape': output_shape,
|
| 835 |
+
'params': group['params'],
|
| 836 |
+
'numParameters': group['total_params'],
|
| 837 |
+
'trainable': True
|
| 838 |
+
})
|
| 839 |
+
|
| 840 |
+
# Create sequential connections
|
| 841 |
+
if prev_layer_id:
|
| 842 |
+
connections.append({
|
| 843 |
+
'source': prev_layer_id,
|
| 844 |
+
'target': layer_id,
|
| 845 |
+
'tensorShape': None
|
| 846 |
+
})
|
| 847 |
+
prev_layer_id = layer_id
|
| 848 |
+
|
| 849 |
+
architecture = {
|
| 850 |
+
'name': Path(filename).stem,
|
| 851 |
+
'framework': 'safetensors',
|
| 852 |
+
'totalParameters': total_params,
|
| 853 |
+
'trainableParameters': total_params,
|
| 854 |
+
'inputShape': layers[0].get('inputShape') if layers else None,
|
| 855 |
+
'outputShape': layers[-1].get('outputShape') if layers else None,
|
| 856 |
+
'layers': layers,
|
| 857 |
+
'connections': connections
|
| 858 |
+
}
|
| 859 |
+
|
| 860 |
+
return AnalysisResponse(
|
| 861 |
+
success=True,
|
| 862 |
+
model_type='safetensors',
|
| 863 |
+
architecture=architecture,
|
| 864 |
+
message=f"Successfully analyzed SafeTensors model with {len(layers)} layers, {total_params:,} parameters"
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def _infer_keras_category(class_name: str) -> str:
|
| 869 |
+
"""Infer category from Keras layer class name."""
|
| 870 |
+
name = class_name.lower()
|
| 871 |
+
if 'conv' in name:
|
| 872 |
+
return 'convolution'
|
| 873 |
+
if 'pool' in name:
|
| 874 |
+
return 'pooling'
|
| 875 |
+
if 'dense' in name or 'linear' in name:
|
| 876 |
+
return 'linear'
|
| 877 |
+
if 'norm' in name or 'batch' in name:
|
| 878 |
+
return 'normalization'
|
| 879 |
+
if 'dropout' in name:
|
| 880 |
+
return 'regularization'
|
| 881 |
+
if 'lstm' in name or 'gru' in name or 'rnn' in name:
|
| 882 |
+
return 'recurrent'
|
| 883 |
+
if 'attention' in name:
|
| 884 |
+
return 'attention'
|
| 885 |
+
if 'activation' in name or 'relu' in name or 'sigmoid' in name:
|
| 886 |
+
return 'activation'
|
| 887 |
+
if 'embed' in name:
|
| 888 |
+
return 'embedding'
|
| 889 |
+
if 'flatten' in name or 'reshape' in name:
|
| 890 |
+
return 'reshape'
|
| 891 |
+
if 'input' in name:
|
| 892 |
+
return 'input'
|
| 893 |
+
return 'other'
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
def _infer_tf_category(op_type: str) -> str:
|
| 897 |
+
"""Infer category from TensorFlow op type."""
|
| 898 |
+
op = op_type.lower()
|
| 899 |
+
if 'conv' in op:
|
| 900 |
+
return 'convolution'
|
| 901 |
+
if 'pool' in op:
|
| 902 |
+
return 'pooling'
|
| 903 |
+
if 'matmul' in op or 'dense' in op:
|
| 904 |
+
return 'linear'
|
| 905 |
+
if 'norm' in op or 'batch' in op:
|
| 906 |
+
return 'normalization'
|
| 907 |
+
if 'relu' in op or 'sigmoid' in op or 'tanh' in op or 'softmax' in op:
|
| 908 |
+
return 'activation'
|
| 909 |
+
if 'placeholder' in op or 'input' in op:
|
| 910 |
+
return 'input'
|
| 911 |
+
if 'variable' in op or 'const' in op:
|
| 912 |
+
return 'parameter'
|
| 913 |
+
return 'other'
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
def _infer_layer_type_from_name(name: str, tensors: dict) -> str:
|
| 917 |
+
"""Infer layer type from name and tensor shapes."""
|
| 918 |
+
name_lower = name.lower()
|
| 919 |
+
|
| 920 |
+
if 'attention' in name_lower or 'attn' in name_lower:
|
| 921 |
+
return 'MultiHeadAttention'
|
| 922 |
+
if 'linear' in name_lower or 'dense' in name_lower or 'fc' in name_lower:
|
| 923 |
+
return 'Linear'
|
| 924 |
+
if 'conv' in name_lower:
|
| 925 |
+
if 'weight' in tensors and len(tensors['weight']) == 4:
|
| 926 |
+
return 'Conv2d'
|
| 927 |
+
return 'Conv1d'
|
| 928 |
+
if 'norm' in name_lower:
|
| 929 |
+
if 'layer' in name_lower:
|
| 930 |
+
return 'LayerNorm'
|
| 931 |
+
return 'BatchNorm'
|
| 932 |
+
if 'embed' in name_lower:
|
| 933 |
+
return 'Embedding'
|
| 934 |
+
if 'lstm' in name_lower:
|
| 935 |
+
return 'LSTM'
|
| 936 |
+
if 'gru' in name_lower:
|
| 937 |
+
return 'GRU'
|
| 938 |
+
if 'query' in name_lower or 'key' in name_lower or 'value' in name_lower:
|
| 939 |
+
return 'Linear'
|
| 940 |
+
|
| 941 |
+
# Infer from tensor shapes
|
| 942 |
+
if 'weight' in tensors:
|
| 943 |
+
shape = tensors['weight']
|
| 944 |
+
if len(shape) == 2:
|
| 945 |
+
return 'Linear'
|
| 946 |
+
if len(shape) == 4:
|
| 947 |
+
return 'Conv2d'
|
| 948 |
+
if len(shape) == 1:
|
| 949 |
+
return 'LayerNorm'
|
| 950 |
+
|
| 951 |
+
return 'Unknown'
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
def _infer_category_from_type(layer_type: str) -> str:
|
| 955 |
+
"""Infer category from layer type."""
|
| 956 |
+
type_lower = layer_type.lower()
|
| 957 |
+
if 'conv' in type_lower:
|
| 958 |
+
return 'convolution'
|
| 959 |
+
if 'linear' in type_lower:
|
| 960 |
+
return 'linear'
|
| 961 |
+
if 'norm' in type_lower:
|
| 962 |
+
return 'normalization'
|
| 963 |
+
if 'attention' in type_lower:
|
| 964 |
+
return 'attention'
|
| 965 |
+
if 'embed' in type_lower:
|
| 966 |
+
return 'embedding'
|
| 967 |
+
if 'lstm' in type_lower or 'gru' in type_lower or 'rnn' in type_lower:
|
| 968 |
+
return 'recurrent'
|
| 969 |
+
return 'other'
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
# =============================================================================
|
| 973 |
+
# Saved Models API Endpoints
|
| 974 |
+
# =============================================================================
|
| 975 |
+
|
| 976 |
+
class SaveModelRequest(BaseModel):
|
| 977 |
+
"""Request to save a model."""
|
| 978 |
+
name: str
|
| 979 |
+
framework: str
|
| 980 |
+
totalParameters: int
|
| 981 |
+
layerCount: int
|
| 982 |
+
architecture: dict
|
| 983 |
+
fileHash: Optional[str] = None
|
| 984 |
+
|
| 985 |
+
|
| 986 |
+
class SavedModelSummary(BaseModel):
|
| 987 |
+
"""Summary of a saved model."""
|
| 988 |
+
id: int
|
| 989 |
+
name: str
|
| 990 |
+
framework: str
|
| 991 |
+
total_parameters: int
|
| 992 |
+
layer_count: int
|
| 993 |
+
created_at: str
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
@app.get("/models/saved")
|
| 997 |
+
async def list_saved_models():
|
| 998 |
+
"""
|
| 999 |
+
Get a list of all saved models.
|
| 1000 |
+
Returns metadata only (not full architecture).
|
| 1001 |
+
"""
|
| 1002 |
+
try:
|
| 1003 |
+
models = db.get_saved_models()
|
| 1004 |
+
return {
|
| 1005 |
+
"success": True,
|
| 1006 |
+
"models": models
|
| 1007 |
+
}
|
| 1008 |
+
except Exception as e:
|
| 1009 |
+
traceback.print_exc()
|
| 1010 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
@app.get("/models/saved/{model_id}")
|
| 1014 |
+
async def get_saved_model(model_id: int):
|
| 1015 |
+
"""
|
| 1016 |
+
Get a saved model by ID with full architecture.
|
| 1017 |
+
"""
|
| 1018 |
+
try:
|
| 1019 |
+
model = db.get_model_by_id(model_id)
|
| 1020 |
+
if not model:
|
| 1021 |
+
raise HTTPException(status_code=404, detail="Model not found")
|
| 1022 |
+
|
| 1023 |
+
return {
|
| 1024 |
+
"success": True,
|
| 1025 |
+
"model": model
|
| 1026 |
+
}
|
| 1027 |
+
except HTTPException:
|
| 1028 |
+
raise
|
| 1029 |
+
except Exception as e:
|
| 1030 |
+
traceback.print_exc()
|
| 1031 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
@app.post("/models/save")
|
| 1035 |
+
async def save_model(request: SaveModelRequest):
|
| 1036 |
+
"""
|
| 1037 |
+
Save a processed model to the database.
|
| 1038 |
+
"""
|
| 1039 |
+
try:
|
| 1040 |
+
model_id = db.save_model(
|
| 1041 |
+
name=request.name,
|
| 1042 |
+
framework=request.framework,
|
| 1043 |
+
total_parameters=request.totalParameters,
|
| 1044 |
+
layer_count=request.layerCount,
|
| 1045 |
+
architecture=request.architecture,
|
| 1046 |
+
file_hash=request.fileHash
|
| 1047 |
+
)
|
| 1048 |
+
|
| 1049 |
+
return {
|
| 1050 |
+
"success": True,
|
| 1051 |
+
"id": model_id,
|
| 1052 |
+
"message": "Model saved successfully"
|
| 1053 |
+
}
|
| 1054 |
+
except Exception as e:
|
| 1055 |
+
traceback.print_exc()
|
| 1056 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
@app.delete("/models/saved/{model_id}")
|
| 1060 |
+
async def delete_saved_model(model_id: int):
|
| 1061 |
+
"""
|
| 1062 |
+
Delete a saved model by ID.
|
| 1063 |
+
"""
|
| 1064 |
+
try:
|
| 1065 |
+
deleted = db.delete_model(model_id)
|
| 1066 |
+
if not deleted:
|
| 1067 |
+
raise HTTPException(status_code=404, detail="Model not found")
|
| 1068 |
+
|
| 1069 |
+
return {
|
| 1070 |
+
"success": True,
|
| 1071 |
+
"message": "Model deleted successfully"
|
| 1072 |
+
}
|
| 1073 |
+
except HTTPException:
|
| 1074 |
+
raise
|
| 1075 |
+
except Exception as e:
|
| 1076 |
+
traceback.print_exc()
|
| 1077 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1078 |
+
|
| 1079 |
+
|
| 1080 |
+
if __name__ == "__main__":
|
| 1081 |
+
import uvicorn
|
| 1082 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
backend/app/model_analyzer.py
ADDED
|
@@ -0,0 +1,714 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch Model Analyzer
|
| 3 |
+
Extracts architecture information from PyTorch models for 3D visualization.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Dict, List, Any, Optional, Tuple
|
| 9 |
+
from dataclasses import dataclass, asdict
|
| 10 |
+
from collections import OrderedDict
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class LayerInfo:
|
| 16 |
+
"""Information about a single layer in the model."""
|
| 17 |
+
id: str
|
| 18 |
+
name: str
|
| 19 |
+
type: str
|
| 20 |
+
category: str
|
| 21 |
+
input_shape: Optional[List[int]]
|
| 22 |
+
output_shape: Optional[List[int]]
|
| 23 |
+
params: Dict[str, Any]
|
| 24 |
+
num_parameters: int
|
| 25 |
+
trainable: bool
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class ConnectionInfo:
|
| 30 |
+
"""Information about connections between layers."""
|
| 31 |
+
source: str
|
| 32 |
+
target: str
|
| 33 |
+
tensor_shape: Optional[List[int]]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class ModelArchitecture:
|
| 38 |
+
"""Complete model architecture information."""
|
| 39 |
+
name: str
|
| 40 |
+
framework: str
|
| 41 |
+
total_parameters: int
|
| 42 |
+
trainable_parameters: int
|
| 43 |
+
layers: List[LayerInfo]
|
| 44 |
+
connections: List[ConnectionInfo]
|
| 45 |
+
input_shape: Optional[List[int]]
|
| 46 |
+
output_shape: Optional[List[int]]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Layer category mapping
|
| 50 |
+
LAYER_CATEGORIES = {
|
| 51 |
+
# Convolution layers
|
| 52 |
+
'Conv1d': 'convolution',
|
| 53 |
+
'Conv2d': 'convolution',
|
| 54 |
+
'Conv3d': 'convolution',
|
| 55 |
+
'ConvTranspose1d': 'convolution',
|
| 56 |
+
'ConvTranspose2d': 'convolution',
|
| 57 |
+
'ConvTranspose3d': 'convolution',
|
| 58 |
+
|
| 59 |
+
# Pooling layers
|
| 60 |
+
'MaxPool1d': 'pooling',
|
| 61 |
+
'MaxPool2d': 'pooling',
|
| 62 |
+
'MaxPool3d': 'pooling',
|
| 63 |
+
'AvgPool1d': 'pooling',
|
| 64 |
+
'AvgPool2d': 'pooling',
|
| 65 |
+
'AvgPool3d': 'pooling',
|
| 66 |
+
'AdaptiveAvgPool1d': 'pooling',
|
| 67 |
+
'AdaptiveAvgPool2d': 'pooling',
|
| 68 |
+
'AdaptiveAvgPool3d': 'pooling',
|
| 69 |
+
'AdaptiveMaxPool1d': 'pooling',
|
| 70 |
+
'AdaptiveMaxPool2d': 'pooling',
|
| 71 |
+
'AdaptiveMaxPool3d': 'pooling',
|
| 72 |
+
'GlobalAveragePooling2D': 'pooling',
|
| 73 |
+
|
| 74 |
+
# Linear/Dense layers
|
| 75 |
+
'Linear': 'linear',
|
| 76 |
+
'LazyLinear': 'linear',
|
| 77 |
+
'Bilinear': 'linear',
|
| 78 |
+
|
| 79 |
+
# Normalization layers
|
| 80 |
+
'BatchNorm1d': 'normalization',
|
| 81 |
+
'BatchNorm2d': 'normalization',
|
| 82 |
+
'BatchNorm3d': 'normalization',
|
| 83 |
+
'LayerNorm': 'normalization',
|
| 84 |
+
'GroupNorm': 'normalization',
|
| 85 |
+
'InstanceNorm1d': 'normalization',
|
| 86 |
+
'InstanceNorm2d': 'normalization',
|
| 87 |
+
'InstanceNorm3d': 'normalization',
|
| 88 |
+
|
| 89 |
+
# Activation layers
|
| 90 |
+
'ReLU': 'activation',
|
| 91 |
+
'ReLU6': 'activation',
|
| 92 |
+
'LeakyReLU': 'activation',
|
| 93 |
+
'PReLU': 'activation',
|
| 94 |
+
'ELU': 'activation',
|
| 95 |
+
'SELU': 'activation',
|
| 96 |
+
'GELU': 'activation',
|
| 97 |
+
'Sigmoid': 'activation',
|
| 98 |
+
'Tanh': 'activation',
|
| 99 |
+
'Softmax': 'activation',
|
| 100 |
+
'LogSoftmax': 'activation',
|
| 101 |
+
'Softplus': 'activation',
|
| 102 |
+
'Softsign': 'activation',
|
| 103 |
+
'Hardswish': 'activation',
|
| 104 |
+
'Hardsigmoid': 'activation',
|
| 105 |
+
'SiLU': 'activation',
|
| 106 |
+
'Mish': 'activation',
|
| 107 |
+
|
| 108 |
+
# Dropout layers
|
| 109 |
+
'Dropout': 'regularization',
|
| 110 |
+
'Dropout2d': 'regularization',
|
| 111 |
+
'Dropout3d': 'regularization',
|
| 112 |
+
'AlphaDropout': 'regularization',
|
| 113 |
+
|
| 114 |
+
# Recurrent layers
|
| 115 |
+
'RNN': 'recurrent',
|
| 116 |
+
'LSTM': 'recurrent',
|
| 117 |
+
'GRU': 'recurrent',
|
| 118 |
+
'RNNCell': 'recurrent',
|
| 119 |
+
'LSTMCell': 'recurrent',
|
| 120 |
+
'GRUCell': 'recurrent',
|
| 121 |
+
|
| 122 |
+
# Transformer layers
|
| 123 |
+
'Transformer': 'attention',
|
| 124 |
+
'TransformerEncoder': 'attention',
|
| 125 |
+
'TransformerDecoder': 'attention',
|
| 126 |
+
'TransformerEncoderLayer': 'attention',
|
| 127 |
+
'TransformerDecoderLayer': 'attention',
|
| 128 |
+
'MultiheadAttention': 'attention',
|
| 129 |
+
|
| 130 |
+
# Embedding layers
|
| 131 |
+
'Embedding': 'embedding',
|
| 132 |
+
'EmbeddingBag': 'embedding',
|
| 133 |
+
|
| 134 |
+
# Reshape/View layers
|
| 135 |
+
'Flatten': 'reshape',
|
| 136 |
+
'Unflatten': 'reshape',
|
| 137 |
+
|
| 138 |
+
# Container layers
|
| 139 |
+
'Sequential': 'container',
|
| 140 |
+
'ModuleList': 'container',
|
| 141 |
+
'ModuleDict': 'container',
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_layer_category(layer_type: str) -> str:
|
| 146 |
+
"""Get the category for a layer type."""
|
| 147 |
+
return LAYER_CATEGORIES.get(layer_type, 'other')
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def count_parameters(module: nn.Module) -> Tuple[int, int]:
|
| 151 |
+
"""Count total and trainable parameters in a module."""
|
| 152 |
+
total = sum(p.numel() for p in module.parameters())
|
| 153 |
+
trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
| 154 |
+
return total, trainable
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def extract_layer_params(module: nn.Module, layer_type: str) -> Dict[str, Any]:
|
| 158 |
+
"""Extract relevant parameters from a layer."""
|
| 159 |
+
params = {}
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
if hasattr(module, 'in_features'):
|
| 163 |
+
params['in_features'] = module.in_features
|
| 164 |
+
if hasattr(module, 'out_features'):
|
| 165 |
+
params['out_features'] = module.out_features
|
| 166 |
+
if hasattr(module, 'in_channels'):
|
| 167 |
+
params['in_channels'] = module.in_channels
|
| 168 |
+
if hasattr(module, 'out_channels'):
|
| 169 |
+
params['out_channels'] = module.out_channels
|
| 170 |
+
if hasattr(module, 'kernel_size'):
|
| 171 |
+
ks = module.kernel_size
|
| 172 |
+
params['kernel_size'] = list(ks) if isinstance(ks, tuple) else ks
|
| 173 |
+
if hasattr(module, 'stride'):
|
| 174 |
+
s = module.stride
|
| 175 |
+
params['stride'] = list(s) if isinstance(s, tuple) else s
|
| 176 |
+
if hasattr(module, 'padding'):
|
| 177 |
+
p = module.padding
|
| 178 |
+
params['padding'] = list(p) if isinstance(p, tuple) else p
|
| 179 |
+
if hasattr(module, 'dilation'):
|
| 180 |
+
d = module.dilation
|
| 181 |
+
params['dilation'] = list(d) if isinstance(d, tuple) else d
|
| 182 |
+
if hasattr(module, 'groups'):
|
| 183 |
+
params['groups'] = module.groups
|
| 184 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
| 185 |
+
params['bias'] = True
|
| 186 |
+
if hasattr(module, 'num_features'):
|
| 187 |
+
params['num_features'] = module.num_features
|
| 188 |
+
if hasattr(module, 'eps'):
|
| 189 |
+
params['eps'] = module.eps
|
| 190 |
+
if hasattr(module, 'momentum') and module.momentum is not None:
|
| 191 |
+
params['momentum'] = module.momentum
|
| 192 |
+
if hasattr(module, 'normalized_shape'):
|
| 193 |
+
params['normalized_shape'] = list(module.normalized_shape)
|
| 194 |
+
if hasattr(module, 'hidden_size'):
|
| 195 |
+
params['hidden_size'] = module.hidden_size
|
| 196 |
+
if hasattr(module, 'num_layers'):
|
| 197 |
+
params['num_layers'] = module.num_layers
|
| 198 |
+
if hasattr(module, 'bidirectional'):
|
| 199 |
+
params['bidirectional'] = module.bidirectional
|
| 200 |
+
if hasattr(module, 'num_heads'):
|
| 201 |
+
params['num_heads'] = module.num_heads
|
| 202 |
+
if hasattr(module, 'embed_dim'):
|
| 203 |
+
params['embed_dim'] = module.embed_dim
|
| 204 |
+
if hasattr(module, 'num_embeddings'):
|
| 205 |
+
params['num_embeddings'] = module.num_embeddings
|
| 206 |
+
if hasattr(module, 'embedding_dim'):
|
| 207 |
+
params['embedding_dim'] = module.embedding_dim
|
| 208 |
+
if hasattr(module, 'p') and layer_type.startswith('Dropout'):
|
| 209 |
+
params['p'] = module.p
|
| 210 |
+
if hasattr(module, 'negative_slope'):
|
| 211 |
+
params['negative_slope'] = module.negative_slope
|
| 212 |
+
if hasattr(module, 'inplace'):
|
| 213 |
+
params['inplace'] = module.inplace
|
| 214 |
+
if hasattr(module, 'dim'):
|
| 215 |
+
params['dim'] = module.dim
|
| 216 |
+
except Exception:
|
| 217 |
+
pass
|
| 218 |
+
|
| 219 |
+
return params
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def analyze_model_structure(model: nn.Module, model_name: str = "model") -> ModelArchitecture:
|
| 223 |
+
"""
|
| 224 |
+
Analyze a PyTorch model and extract its architecture.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
model: The PyTorch model to analyze
|
| 228 |
+
model_name: Name identifier for the model
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
ModelArchitecture with complete layer and connection information
|
| 232 |
+
"""
|
| 233 |
+
layers = []
|
| 234 |
+
connections = []
|
| 235 |
+
layer_index = 0
|
| 236 |
+
parent_stack = []
|
| 237 |
+
|
| 238 |
+
def process_module(name: str, module: nn.Module, parent_id: Optional[str] = None):
|
| 239 |
+
nonlocal layer_index
|
| 240 |
+
|
| 241 |
+
layer_type = module.__class__.__name__
|
| 242 |
+
|
| 243 |
+
# Skip container modules but process their children
|
| 244 |
+
if layer_type in ('Sequential', 'ModuleList', 'ModuleDict'):
|
| 245 |
+
for child_name, child in module.named_children():
|
| 246 |
+
full_name = f"{name}.{child_name}" if name else child_name
|
| 247 |
+
process_module(full_name, child, parent_id)
|
| 248 |
+
return
|
| 249 |
+
|
| 250 |
+
# Skip modules with no parameters and no meaningful operation
|
| 251 |
+
# But include activation, pooling, dropout, etc.
|
| 252 |
+
has_params = sum(1 for _ in module.parameters(recurse=False)) > 0
|
| 253 |
+
is_meaningful = layer_type in LAYER_CATEGORIES or has_params
|
| 254 |
+
|
| 255 |
+
if not is_meaningful and len(list(module.children())) > 0:
|
| 256 |
+
# Process children of non-meaningful containers
|
| 257 |
+
for child_name, child in module.named_children():
|
| 258 |
+
full_name = f"{name}.{child_name}" if name else child_name
|
| 259 |
+
process_module(full_name, child, parent_id)
|
| 260 |
+
return
|
| 261 |
+
|
| 262 |
+
layer_id = f"layer_{layer_index}"
|
| 263 |
+
layer_index += 1
|
| 264 |
+
|
| 265 |
+
total_params, trainable_params = count_parameters(module)
|
| 266 |
+
params = extract_layer_params(module, layer_type)
|
| 267 |
+
|
| 268 |
+
layer_info = LayerInfo(
|
| 269 |
+
id=layer_id,
|
| 270 |
+
name=name or layer_type,
|
| 271 |
+
type=layer_type,
|
| 272 |
+
category=get_layer_category(layer_type),
|
| 273 |
+
input_shape=None, # Will be populated during forward pass
|
| 274 |
+
output_shape=None,
|
| 275 |
+
params=params,
|
| 276 |
+
num_parameters=total_params,
|
| 277 |
+
trainable=trainable_params > 0
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
layers.append(layer_info)
|
| 281 |
+
|
| 282 |
+
# Create connection from parent
|
| 283 |
+
if parent_id is not None:
|
| 284 |
+
connections.append(ConnectionInfo(
|
| 285 |
+
source=parent_id,
|
| 286 |
+
target=layer_id,
|
| 287 |
+
tensor_shape=None
|
| 288 |
+
))
|
| 289 |
+
|
| 290 |
+
# Process children
|
| 291 |
+
children = list(module.named_children())
|
| 292 |
+
if children:
|
| 293 |
+
for child_name, child in children:
|
| 294 |
+
full_name = f"{name}.{child_name}" if name else child_name
|
| 295 |
+
process_module(full_name, child, layer_id)
|
| 296 |
+
|
| 297 |
+
return layer_id
|
| 298 |
+
|
| 299 |
+
# Process the model
|
| 300 |
+
children = list(model.named_children())
|
| 301 |
+
if children:
|
| 302 |
+
prev_id = None
|
| 303 |
+
for name, child in children:
|
| 304 |
+
layer_id = process_module(name, child, prev_id)
|
| 305 |
+
if layer_id:
|
| 306 |
+
prev_id = layer_id
|
| 307 |
+
else:
|
| 308 |
+
# Single layer model
|
| 309 |
+
process_module("", model, None)
|
| 310 |
+
|
| 311 |
+
# If layers are sequential and no connections exist, create linear connections
|
| 312 |
+
if len(layers) > 1 and len(connections) == 0:
|
| 313 |
+
for i in range(len(layers) - 1):
|
| 314 |
+
connections.append(ConnectionInfo(
|
| 315 |
+
source=layers[i].id,
|
| 316 |
+
target=layers[i + 1].id,
|
| 317 |
+
tensor_shape=None
|
| 318 |
+
))
|
| 319 |
+
|
| 320 |
+
total_params, trainable_params = count_parameters(model)
|
| 321 |
+
|
| 322 |
+
return ModelArchitecture(
|
| 323 |
+
name=model_name,
|
| 324 |
+
framework="pytorch",
|
| 325 |
+
total_parameters=total_params,
|
| 326 |
+
trainable_parameters=trainable_params,
|
| 327 |
+
layers=layers,
|
| 328 |
+
connections=connections,
|
| 329 |
+
input_shape=None,
|
| 330 |
+
output_shape=None
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def trace_model_shapes(model: nn.Module, input_tensor: torch.Tensor, arch: ModelArchitecture) -> ModelArchitecture:
|
| 335 |
+
"""
|
| 336 |
+
Trace model execution to capture input/output shapes for each layer.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
model: The PyTorch model
|
| 340 |
+
input_tensor: Sample input tensor
|
| 341 |
+
arch: Existing architecture info to update
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
Updated ModelArchitecture with shape information
|
| 345 |
+
"""
|
| 346 |
+
shapes = {}
|
| 347 |
+
hooks = []
|
| 348 |
+
|
| 349 |
+
def make_hook(name):
|
| 350 |
+
def hook(module, input, output):
|
| 351 |
+
input_shape = None
|
| 352 |
+
output_shape = None
|
| 353 |
+
|
| 354 |
+
if isinstance(input, tuple) and len(input) > 0:
|
| 355 |
+
if isinstance(input[0], torch.Tensor):
|
| 356 |
+
input_shape = list(input[0].shape)
|
| 357 |
+
elif isinstance(input, torch.Tensor):
|
| 358 |
+
input_shape = list(input.shape)
|
| 359 |
+
|
| 360 |
+
if isinstance(output, torch.Tensor):
|
| 361 |
+
output_shape = list(output.shape)
|
| 362 |
+
elif isinstance(output, tuple) and len(output) > 0:
|
| 363 |
+
if isinstance(output[0], torch.Tensor):
|
| 364 |
+
output_shape = list(output[0].shape)
|
| 365 |
+
|
| 366 |
+
shapes[name] = {
|
| 367 |
+
'input': input_shape,
|
| 368 |
+
'output': output_shape
|
| 369 |
+
}
|
| 370 |
+
return hook
|
| 371 |
+
|
| 372 |
+
# Register hooks
|
| 373 |
+
for name, module in model.named_modules():
|
| 374 |
+
if name: # Skip root module
|
| 375 |
+
hooks.append(module.register_forward_hook(make_hook(name)))
|
| 376 |
+
|
| 377 |
+
# Run forward pass
|
| 378 |
+
try:
|
| 379 |
+
model.eval()
|
| 380 |
+
with torch.no_grad():
|
| 381 |
+
output = model(input_tensor)
|
| 382 |
+
|
| 383 |
+
# Update architecture with shapes
|
| 384 |
+
for layer in arch.layers:
|
| 385 |
+
if layer.name in shapes:
|
| 386 |
+
layer.input_shape = shapes[layer.name]['input']
|
| 387 |
+
layer.output_shape = shapes[layer.name]['output']
|
| 388 |
+
|
| 389 |
+
# Set model input/output shapes
|
| 390 |
+
arch.input_shape = list(input_tensor.shape)
|
| 391 |
+
if isinstance(output, torch.Tensor):
|
| 392 |
+
arch.output_shape = list(output.shape)
|
| 393 |
+
|
| 394 |
+
except Exception as e:
|
| 395 |
+
print(f"Warning: Could not trace shapes: {e}")
|
| 396 |
+
finally:
|
| 397 |
+
# Remove hooks
|
| 398 |
+
for hook in hooks:
|
| 399 |
+
hook.remove()
|
| 400 |
+
|
| 401 |
+
return arch
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def load_pytorch_model(file_path: str) -> Tuple[Optional[nn.Module], Optional[Dict], str]:
|
| 405 |
+
"""
|
| 406 |
+
Load a PyTorch model from file.
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
Tuple of (model, state_dict, model_type)
|
| 410 |
+
model_type can be: 'full_model', 'state_dict', 'torchscript', 'checkpoint'
|
| 411 |
+
"""
|
| 412 |
+
try:
|
| 413 |
+
# Try loading as TorchScript first
|
| 414 |
+
try:
|
| 415 |
+
model = torch.jit.load(file_path, map_location='cpu')
|
| 416 |
+
return model, None, 'torchscript'
|
| 417 |
+
except Exception:
|
| 418 |
+
pass
|
| 419 |
+
|
| 420 |
+
# Try loading as regular checkpoint
|
| 421 |
+
checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)
|
| 422 |
+
|
| 423 |
+
if isinstance(checkpoint, nn.Module):
|
| 424 |
+
return checkpoint, None, 'full_model'
|
| 425 |
+
|
| 426 |
+
if isinstance(checkpoint, dict):
|
| 427 |
+
# Check for common checkpoint formats
|
| 428 |
+
if 'model' in checkpoint:
|
| 429 |
+
if isinstance(checkpoint['model'], nn.Module):
|
| 430 |
+
return checkpoint['model'], None, 'checkpoint'
|
| 431 |
+
elif isinstance(checkpoint['model'], dict):
|
| 432 |
+
return None, checkpoint['model'], 'state_dict'
|
| 433 |
+
|
| 434 |
+
if 'state_dict' in checkpoint:
|
| 435 |
+
return None, checkpoint['state_dict'], 'state_dict'
|
| 436 |
+
|
| 437 |
+
if 'model_state_dict' in checkpoint:
|
| 438 |
+
return None, checkpoint['model_state_dict'], 'state_dict'
|
| 439 |
+
|
| 440 |
+
# Check if it's directly a state dict (contains tensor values)
|
| 441 |
+
has_tensors = any(isinstance(v, torch.Tensor) for v in checkpoint.values())
|
| 442 |
+
if has_tensors:
|
| 443 |
+
return None, checkpoint, 'state_dict'
|
| 444 |
+
|
| 445 |
+
return None, None, 'unknown'
|
| 446 |
+
|
| 447 |
+
except Exception as e:
|
| 448 |
+
raise ValueError(f"Failed to load model: {str(e)}")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def analyze_state_dict(state_dict: Dict[str, torch.Tensor], model_name: str = "model") -> ModelArchitecture:
|
| 452 |
+
"""
|
| 453 |
+
Analyze a state dict to infer model architecture.
|
| 454 |
+
|
| 455 |
+
This extracts layer information from weight tensor names and shapes.
|
| 456 |
+
"""
|
| 457 |
+
layers = []
|
| 458 |
+
layer_map = OrderedDict()
|
| 459 |
+
|
| 460 |
+
# Group parameters by layer name
|
| 461 |
+
for key, tensor in state_dict.items():
|
| 462 |
+
if not isinstance(tensor, torch.Tensor):
|
| 463 |
+
continue
|
| 464 |
+
|
| 465 |
+
# Extract layer name from parameter name
|
| 466 |
+
parts = key.rsplit('.', 1)
|
| 467 |
+
if len(parts) == 2:
|
| 468 |
+
layer_name, param_type = parts
|
| 469 |
+
else:
|
| 470 |
+
layer_name = key
|
| 471 |
+
param_type = 'weight'
|
| 472 |
+
|
| 473 |
+
if layer_name not in layer_map:
|
| 474 |
+
layer_map[layer_name] = {
|
| 475 |
+
'params': {},
|
| 476 |
+
'shapes': {}
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
layer_map[layer_name]['params'][param_type] = True
|
| 480 |
+
layer_map[layer_name]['shapes'][param_type] = list(tensor.shape)
|
| 481 |
+
|
| 482 |
+
# Create layer info from grouped parameters
|
| 483 |
+
layer_index = 0
|
| 484 |
+
for layer_name, info in layer_map.items():
|
| 485 |
+
layer_type, category = infer_layer_type(layer_name, info['shapes'])
|
| 486 |
+
|
| 487 |
+
layer_id = f"layer_{layer_index}"
|
| 488 |
+
layer_index += 1
|
| 489 |
+
|
| 490 |
+
# Compute number of parameters
|
| 491 |
+
num_params = 0
|
| 492 |
+
for param_type, shape in info['shapes'].items():
|
| 493 |
+
param_size = 1
|
| 494 |
+
for dim in shape:
|
| 495 |
+
param_size *= dim
|
| 496 |
+
num_params += param_size
|
| 497 |
+
|
| 498 |
+
# Extract layer parameters from shapes
|
| 499 |
+
params = extract_params_from_shapes(layer_type, info['shapes'])
|
| 500 |
+
|
| 501 |
+
# Infer input/output shapes
|
| 502 |
+
input_shape, output_shape = infer_shapes(layer_type, info['shapes'], params)
|
| 503 |
+
|
| 504 |
+
layers.append(LayerInfo(
|
| 505 |
+
id=layer_id,
|
| 506 |
+
name=layer_name,
|
| 507 |
+
type=layer_type,
|
| 508 |
+
category=category,
|
| 509 |
+
input_shape=input_shape,
|
| 510 |
+
output_shape=output_shape,
|
| 511 |
+
params=params,
|
| 512 |
+
num_parameters=num_params,
|
| 513 |
+
trainable=True
|
| 514 |
+
))
|
| 515 |
+
|
| 516 |
+
# Create sequential connections
|
| 517 |
+
connections = []
|
| 518 |
+
for i in range(len(layers) - 1):
|
| 519 |
+
connections.append(ConnectionInfo(
|
| 520 |
+
source=layers[i].id,
|
| 521 |
+
target=layers[i + 1].id,
|
| 522 |
+
tensor_shape=layers[i].output_shape
|
| 523 |
+
))
|
| 524 |
+
|
| 525 |
+
total_params = sum(layer.num_parameters for layer in layers)
|
| 526 |
+
|
| 527 |
+
return ModelArchitecture(
|
| 528 |
+
name=model_name,
|
| 529 |
+
framework="pytorch",
|
| 530 |
+
total_parameters=total_params,
|
| 531 |
+
trainable_parameters=total_params,
|
| 532 |
+
layers=layers,
|
| 533 |
+
connections=connections,
|
| 534 |
+
input_shape=layers[0].input_shape if layers else None,
|
| 535 |
+
output_shape=layers[-1].output_shape if layers else None
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def infer_layer_type(layer_name: str, shapes: Dict[str, List[int]]) -> Tuple[str, str]:
|
| 540 |
+
"""Infer layer type from name and weight shapes."""
|
| 541 |
+
name_lower = layer_name.lower()
|
| 542 |
+
|
| 543 |
+
# Check for common layer type patterns in name
|
| 544 |
+
if 'conv' in name_lower:
|
| 545 |
+
weight_shape = shapes.get('weight', [])
|
| 546 |
+
if len(weight_shape) == 5:
|
| 547 |
+
return 'Conv3d', 'convolution'
|
| 548 |
+
elif len(weight_shape) == 4:
|
| 549 |
+
return 'Conv2d', 'convolution'
|
| 550 |
+
elif len(weight_shape) == 3:
|
| 551 |
+
return 'Conv1d', 'convolution'
|
| 552 |
+
return 'Conv2d', 'convolution'
|
| 553 |
+
|
| 554 |
+
if 'bn' in name_lower or 'batch' in name_lower or 'norm' in name_lower:
|
| 555 |
+
weight_shape = shapes.get('weight', shapes.get('running_mean', []))
|
| 556 |
+
if 'layer' in name_lower:
|
| 557 |
+
return 'LayerNorm', 'normalization'
|
| 558 |
+
return 'BatchNorm2d', 'normalization'
|
| 559 |
+
|
| 560 |
+
if 'fc' in name_lower or 'linear' in name_lower or 'dense' in name_lower or 'classifier' in name_lower:
|
| 561 |
+
return 'Linear', 'linear'
|
| 562 |
+
|
| 563 |
+
if 'lstm' in name_lower:
|
| 564 |
+
return 'LSTM', 'recurrent'
|
| 565 |
+
|
| 566 |
+
if 'gru' in name_lower:
|
| 567 |
+
return 'GRU', 'recurrent'
|
| 568 |
+
|
| 569 |
+
if 'rnn' in name_lower:
|
| 570 |
+
return 'RNN', 'recurrent'
|
| 571 |
+
|
| 572 |
+
if 'attention' in name_lower or 'attn' in name_lower:
|
| 573 |
+
return 'MultiheadAttention', 'attention'
|
| 574 |
+
|
| 575 |
+
if 'embed' in name_lower:
|
| 576 |
+
return 'Embedding', 'embedding'
|
| 577 |
+
|
| 578 |
+
if 'pool' in name_lower:
|
| 579 |
+
return 'AdaptiveAvgPool2d', 'pooling'
|
| 580 |
+
|
| 581 |
+
# Infer from weight shape
|
| 582 |
+
weight_shape = shapes.get('weight', [])
|
| 583 |
+
if len(weight_shape) == 2:
|
| 584 |
+
return 'Linear', 'linear'
|
| 585 |
+
elif len(weight_shape) == 4:
|
| 586 |
+
return 'Conv2d', 'convolution'
|
| 587 |
+
elif len(weight_shape) == 3:
|
| 588 |
+
return 'Conv1d', 'convolution'
|
| 589 |
+
elif len(weight_shape) == 1:
|
| 590 |
+
return 'BatchNorm2d', 'normalization'
|
| 591 |
+
|
| 592 |
+
return 'Unknown', 'other'
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def extract_params_from_shapes(layer_type: str, shapes: Dict[str, List[int]]) -> Dict[str, Any]:
|
| 596 |
+
"""Extract layer parameters from weight shapes."""
|
| 597 |
+
params = {}
|
| 598 |
+
weight_shape = shapes.get('weight', [])
|
| 599 |
+
|
| 600 |
+
if layer_type in ('Linear',):
|
| 601 |
+
if len(weight_shape) >= 2:
|
| 602 |
+
params['out_features'] = weight_shape[0]
|
| 603 |
+
params['in_features'] = weight_shape[1]
|
| 604 |
+
params['bias'] = 'bias' in shapes
|
| 605 |
+
|
| 606 |
+
elif layer_type in ('Conv1d', 'Conv2d', 'Conv3d'):
|
| 607 |
+
if len(weight_shape) >= 2:
|
| 608 |
+
params['out_channels'] = weight_shape[0]
|
| 609 |
+
params['in_channels'] = weight_shape[1]
|
| 610 |
+
if len(weight_shape) > 2:
|
| 611 |
+
params['kernel_size'] = weight_shape[2:]
|
| 612 |
+
params['bias'] = 'bias' in shapes
|
| 613 |
+
|
| 614 |
+
elif layer_type in ('BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'):
|
| 615 |
+
if len(weight_shape) >= 1:
|
| 616 |
+
params['num_features'] = weight_shape[0]
|
| 617 |
+
|
| 618 |
+
elif layer_type == 'LayerNorm':
|
| 619 |
+
if len(weight_shape) >= 1:
|
| 620 |
+
params['normalized_shape'] = weight_shape
|
| 621 |
+
|
| 622 |
+
elif layer_type == 'Embedding':
|
| 623 |
+
if len(weight_shape) >= 2:
|
| 624 |
+
params['num_embeddings'] = weight_shape[0]
|
| 625 |
+
params['embedding_dim'] = weight_shape[1]
|
| 626 |
+
|
| 627 |
+
elif layer_type in ('LSTM', 'GRU', 'RNN'):
|
| 628 |
+
# weight_ih_l0 shape gives hidden_size x input_size
|
| 629 |
+
if 'weight_ih_l0' in shapes:
|
| 630 |
+
ih_shape = shapes['weight_ih_l0']
|
| 631 |
+
if len(ih_shape) >= 2:
|
| 632 |
+
multiplier = 4 if layer_type == 'LSTM' else (3 if layer_type == 'GRU' else 1)
|
| 633 |
+
params['hidden_size'] = ih_shape[0] // multiplier
|
| 634 |
+
params['input_size'] = ih_shape[1]
|
| 635 |
+
|
| 636 |
+
return params
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def infer_shapes(layer_type: str, shapes: Dict[str, List[int]], params: Dict[str, Any]) -> Tuple[Optional[List[int]], Optional[List[int]]]:
|
| 640 |
+
"""Infer input/output shapes from layer parameters."""
|
| 641 |
+
input_shape = None
|
| 642 |
+
output_shape = None
|
| 643 |
+
|
| 644 |
+
if layer_type == 'Linear':
|
| 645 |
+
if 'in_features' in params:
|
| 646 |
+
input_shape = [-1, params['in_features']]
|
| 647 |
+
if 'out_features' in params:
|
| 648 |
+
output_shape = [-1, params['out_features']]
|
| 649 |
+
|
| 650 |
+
elif layer_type in ('Conv2d',):
|
| 651 |
+
if 'in_channels' in params:
|
| 652 |
+
input_shape = [-1, params['in_channels'], -1, -1]
|
| 653 |
+
if 'out_channels' in params:
|
| 654 |
+
output_shape = [-1, params['out_channels'], -1, -1]
|
| 655 |
+
|
| 656 |
+
elif layer_type in ('Conv1d',):
|
| 657 |
+
if 'in_channels' in params:
|
| 658 |
+
input_shape = [-1, params['in_channels'], -1]
|
| 659 |
+
if 'out_channels' in params:
|
| 660 |
+
output_shape = [-1, params['out_channels'], -1]
|
| 661 |
+
|
| 662 |
+
elif layer_type in ('BatchNorm2d',):
|
| 663 |
+
if 'num_features' in params:
|
| 664 |
+
input_shape = [-1, params['num_features'], -1, -1]
|
| 665 |
+
output_shape = [-1, params['num_features'], -1, -1]
|
| 666 |
+
|
| 667 |
+
elif layer_type == 'Embedding':
|
| 668 |
+
if 'embedding_dim' in params:
|
| 669 |
+
output_shape = [-1, -1, params['embedding_dim']]
|
| 670 |
+
|
| 671 |
+
elif layer_type in ('GRU', 'LSTM', 'RNN'):
|
| 672 |
+
# For recurrent layers: input is (batch, seq_len, input_size)
|
| 673 |
+
# output is (batch, seq_len, hidden_size * num_directions)
|
| 674 |
+
if 'input_size' in params:
|
| 675 |
+
input_shape = [-1, -1, params['input_size']]
|
| 676 |
+
if 'hidden_size' in params:
|
| 677 |
+
num_directions = 2 if params.get('bidirectional', False) else 1
|
| 678 |
+
output_shape = [-1, -1, params['hidden_size'] * num_directions]
|
| 679 |
+
|
| 680 |
+
return input_shape, output_shape
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def architecture_to_dict(arch: ModelArchitecture) -> Dict[str, Any]:
|
| 684 |
+
"""Convert ModelArchitecture to JSON-serializable dict."""
|
| 685 |
+
return {
|
| 686 |
+
'name': arch.name,
|
| 687 |
+
'framework': arch.framework,
|
| 688 |
+
'totalParameters': arch.total_parameters,
|
| 689 |
+
'trainableParameters': arch.trainable_parameters,
|
| 690 |
+
'inputShape': arch.input_shape,
|
| 691 |
+
'outputShape': arch.output_shape,
|
| 692 |
+
'layers': [
|
| 693 |
+
{
|
| 694 |
+
'id': layer.id,
|
| 695 |
+
'name': layer.name,
|
| 696 |
+
'type': layer.type,
|
| 697 |
+
'category': layer.category,
|
| 698 |
+
'inputShape': layer.input_shape,
|
| 699 |
+
'outputShape': layer.output_shape,
|
| 700 |
+
'params': layer.params,
|
| 701 |
+
'numParameters': layer.num_parameters,
|
| 702 |
+
'trainable': layer.trainable
|
| 703 |
+
}
|
| 704 |
+
for layer in arch.layers
|
| 705 |
+
],
|
| 706 |
+
'connections': [
|
| 707 |
+
{
|
| 708 |
+
'source': conn.source,
|
| 709 |
+
'target': conn.target,
|
| 710 |
+
'tensorShape': conn.tensor_shape
|
| 711 |
+
}
|
| 712 |
+
for conn in arch.connections
|
| 713 |
+
]
|
| 714 |
+
}
|
backend/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.109.0
|
| 2 |
+
uvicorn[standard]==0.27.0
|
| 3 |
+
python-multipart==0.0.6
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
torchvision>=0.15.0
|
| 6 |
+
onnx>=1.14.0
|
| 7 |
+
numpy>=1.24.0
|
| 8 |
+
pydantic>=2.0.0
|
| 9 |
+
h5py>=3.10.0
|
| 10 |
+
safetensors>=0.4.0
|
backend/start.bat
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
echo Starting NN3D Visualizer Backend...
|
| 3 |
+
echo.
|
| 4 |
+
|
| 5 |
+
REM Check if virtual environment exists
|
| 6 |
+
if not exist "venv" (
|
| 7 |
+
echo Creating virtual environment...
|
| 8 |
+
python -m venv venv
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
REM Activate virtual environment
|
| 12 |
+
call venv\Scripts\activate.bat
|
| 13 |
+
|
| 14 |
+
REM Install dependencies
|
| 15 |
+
echo Installing dependencies...
|
| 16 |
+
pip install -r requirements.txt --quiet
|
| 17 |
+
|
| 18 |
+
REM Start the server
|
| 19 |
+
echo.
|
| 20 |
+
echo Starting FastAPI server on http://localhost:8000
|
| 21 |
+
echo API docs available at http://localhost:8000/docs
|
| 22 |
+
echo.
|
| 23 |
+
python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
backend/start.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
echo "Starting NN3D Visualizer Backend..."
|
| 3 |
+
echo
|
| 4 |
+
|
| 5 |
+
# Check if virtual environment exists
|
| 6 |
+
if [ ! -d "venv" ]; then
|
| 7 |
+
echo "Creating virtual environment..."
|
| 8 |
+
python3 -m venv venv
|
| 9 |
+
fi
|
| 10 |
+
|
| 11 |
+
# Activate virtual environment
|
| 12 |
+
source venv/bin/activate
|
| 13 |
+
|
| 14 |
+
# Install dependencies
|
| 15 |
+
echo "Installing dependencies..."
|
| 16 |
+
pip install -r requirements.txt --quiet
|
| 17 |
+
|
| 18 |
+
# Start the server
|
| 19 |
+
echo
|
| 20 |
+
echo "Starting FastAPI server on http://localhost:8000"
|
| 21 |
+
echo "API docs available at http://localhost:8000/docs"
|
| 22 |
+
echo
|
| 23 |
+
python -m uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker Compose for NN3D Visualizer
|
| 2 |
+
# Run: docker-compose up --build
|
| 3 |
+
|
| 4 |
+
version: '3.8'
|
| 5 |
+
|
| 6 |
+
services:
|
| 7 |
+
# FastAPI Backend - Model Analysis
|
| 8 |
+
backend:
|
| 9 |
+
build:
|
| 10 |
+
context: ./backend
|
| 11 |
+
dockerfile: Dockerfile
|
| 12 |
+
container_name: nn3d-backend
|
| 13 |
+
restart: unless-stopped
|
| 14 |
+
volumes:
|
| 15 |
+
# Persist database
|
| 16 |
+
- nn3d-data:/app/data
|
| 17 |
+
# Mount models directory for local model files (optional)
|
| 18 |
+
- ./samples:/app/samples:ro
|
| 19 |
+
environment:
|
| 20 |
+
- DATABASE_PATH=/app/data/models.db
|
| 21 |
+
- PYTHONUNBUFFERED=1
|
| 22 |
+
ports:
|
| 23 |
+
- "8000:8000"
|
| 24 |
+
networks:
|
| 25 |
+
- nn3d-network
|
| 26 |
+
healthcheck:
|
| 27 |
+
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
| 28 |
+
interval: 30s
|
| 29 |
+
timeout: 10s
|
| 30 |
+
retries: 3
|
| 31 |
+
start_period: 10s
|
| 32 |
+
|
| 33 |
+
# React Frontend - 3D Visualization
|
| 34 |
+
frontend:
|
| 35 |
+
build:
|
| 36 |
+
context: .
|
| 37 |
+
dockerfile: Dockerfile
|
| 38 |
+
container_name: nn3d-frontend
|
| 39 |
+
restart: unless-stopped
|
| 40 |
+
ports:
|
| 41 |
+
- "3000:80"
|
| 42 |
+
depends_on:
|
| 43 |
+
backend:
|
| 44 |
+
condition: service_healthy
|
| 45 |
+
networks:
|
| 46 |
+
- nn3d-network
|
| 47 |
+
healthcheck:
|
| 48 |
+
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:80/"]
|
| 49 |
+
interval: 30s
|
| 50 |
+
timeout: 10s
|
| 51 |
+
retries: 3
|
| 52 |
+
start_period: 5s
|
| 53 |
+
|
| 54 |
+
networks:
|
| 55 |
+
nn3d-network:
|
| 56 |
+
driver: bridge
|
| 57 |
+
|
| 58 |
+
volumes:
|
| 59 |
+
nn3d-data:
|
| 60 |
+
driver: local
|
docker-start.bat
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
REM NN3D Visualizer - Docker Startup Script (Windows)
|
| 3 |
+
|
| 4 |
+
echo.
|
| 5 |
+
echo ========================================
|
| 6 |
+
echo NN3D VISUALIZER - DOCKER SETUP
|
| 7 |
+
echo ========================================
|
| 8 |
+
echo.
|
| 9 |
+
|
| 10 |
+
REM Check if Docker is running
|
| 11 |
+
docker info >nul 2>&1
|
| 12 |
+
if errorlevel 1 (
|
| 13 |
+
echo [ERROR] Docker is not running. Please start Docker Desktop.
|
| 14 |
+
pause
|
| 15 |
+
exit /b 1
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
echo [*] Building and starting containers...
|
| 19 |
+
echo.
|
| 20 |
+
|
| 21 |
+
docker-compose up --build -d
|
| 22 |
+
|
| 23 |
+
if errorlevel 1 (
|
| 24 |
+
echo [ERROR] Failed to start containers.
|
| 25 |
+
pause
|
| 26 |
+
exit /b 1
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
echo.
|
| 30 |
+
echo ========================================
|
| 31 |
+
echo CONTAINERS STARTED SUCCESSFULLY
|
| 32 |
+
echo ========================================
|
| 33 |
+
echo.
|
| 34 |
+
echo Frontend: http://localhost:3000
|
| 35 |
+
echo Backend: http://localhost:8000
|
| 36 |
+
echo API Docs: http://localhost:8000/docs
|
| 37 |
+
echo.
|
| 38 |
+
echo Run 'docker-compose logs -f' to view logs
|
| 39 |
+
echo Run 'docker-compose down' to stop
|
| 40 |
+
echo ========================================
|
| 41 |
+
echo.
|
| 42 |
+
|
| 43 |
+
pause
|
docker-start.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# NN3D Visualizer - Docker Startup Script (Linux/Mac)
|
| 3 |
+
|
| 4 |
+
echo ""
|
| 5 |
+
echo "========================================"
|
| 6 |
+
echo " NN3D VISUALIZER - DOCKER SETUP"
|
| 7 |
+
echo "========================================"
|
| 8 |
+
echo ""
|
| 9 |
+
|
| 10 |
+
# Check if Docker is running
|
| 11 |
+
if ! docker info > /dev/null 2>&1; then
|
| 12 |
+
echo "[ERROR] Docker is not running. Please start Docker."
|
| 13 |
+
exit 1
|
| 14 |
+
fi
|
| 15 |
+
|
| 16 |
+
echo "[*] Building and starting containers..."
|
| 17 |
+
echo ""
|
| 18 |
+
|
| 19 |
+
docker-compose up --build -d
|
| 20 |
+
|
| 21 |
+
if [ $? -ne 0 ]; then
|
| 22 |
+
echo "[ERROR] Failed to start containers."
|
| 23 |
+
exit 1
|
| 24 |
+
fi
|
| 25 |
+
|
| 26 |
+
echo ""
|
| 27 |
+
echo "========================================"
|
| 28 |
+
echo " CONTAINERS STARTED SUCCESSFULLY"
|
| 29 |
+
echo "========================================"
|
| 30 |
+
echo ""
|
| 31 |
+
echo " Frontend: http://localhost:3000"
|
| 32 |
+
echo " Backend: http://localhost:8000"
|
| 33 |
+
echo " API Docs: http://localhost:8000/docs"
|
| 34 |
+
echo ""
|
| 35 |
+
echo " Run 'docker-compose logs -f' to view logs"
|
| 36 |
+
echo " Run 'docker-compose down' to stop"
|
| 37 |
+
echo "========================================"
|
| 38 |
+
echo ""
|
exporters/python/README.md
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NN3D Exporter
|
| 2 |
+
|
| 3 |
+
Python library to export neural network models to `.nn3d` format for 3D visualization.
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
# Basic installation
|
| 9 |
+
pip install nn3d-exporter
|
| 10 |
+
|
| 11 |
+
# With PyTorch support
|
| 12 |
+
pip install nn3d-exporter[pytorch]
|
| 13 |
+
|
| 14 |
+
# With ONNX support
|
| 15 |
+
pip install nn3d-exporter[onnx]
|
| 16 |
+
|
| 17 |
+
# With all frameworks
|
| 18 |
+
pip install nn3d-exporter[all]
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Quick Start
|
| 22 |
+
|
| 23 |
+
### Export PyTorch Model
|
| 24 |
+
|
| 25 |
+
```python
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
from nn3d_exporter import export_pytorch_model
|
| 28 |
+
|
| 29 |
+
# Define your model
|
| 30 |
+
class SimpleCNN(nn.Module):
|
| 31 |
+
def __init__(self):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
|
| 34 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 35 |
+
self.relu = nn.ReLU()
|
| 36 |
+
self.pool = nn.MaxPool2d(2)
|
| 37 |
+
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
| 38 |
+
self.bn2 = nn.BatchNorm2d(128)
|
| 39 |
+
self.fc = nn.Linear(128 * 56 * 56, 10)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
x = self.pool(self.relu(self.bn1(self.conv1(x))))
|
| 43 |
+
x = self.pool(self.relu(self.bn2(self.conv2(x))))
|
| 44 |
+
x = x.view(x.size(0), -1)
|
| 45 |
+
return self.fc(x)
|
| 46 |
+
|
| 47 |
+
model = SimpleCNN()
|
| 48 |
+
|
| 49 |
+
# Export to .nn3d format
|
| 50 |
+
export_pytorch_model(
|
| 51 |
+
model,
|
| 52 |
+
output_path="simple_cnn.nn3d",
|
| 53 |
+
input_shape=(1, 3, 224, 224),
|
| 54 |
+
model_name="Simple CNN"
|
| 55 |
+
)
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Export ONNX Model
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
from nn3d_exporter import export_onnx_model
|
| 62 |
+
|
| 63 |
+
# Export an existing ONNX model
|
| 64 |
+
export_onnx_model(
|
| 65 |
+
model_path="resnet50.onnx",
|
| 66 |
+
output_path="resnet50.nn3d",
|
| 67 |
+
model_name="ResNet-50"
|
| 68 |
+
)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### Using the Exporter Classes
|
| 72 |
+
|
| 73 |
+
For more control, use the exporter classes directly:
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
from nn3d_exporter import PyTorchExporter
|
| 77 |
+
|
| 78 |
+
exporter = PyTorchExporter(
|
| 79 |
+
model=model,
|
| 80 |
+
input_shape=(1, 3, 224, 224),
|
| 81 |
+
model_name="My Model"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Export to NN3DModel object
|
| 85 |
+
nn3d_model = exporter.export()
|
| 86 |
+
|
| 87 |
+
# Customize before saving
|
| 88 |
+
nn3d_model.visualization.theme = "blueprint"
|
| 89 |
+
nn3d_model.visualization.layout = "force"
|
| 90 |
+
|
| 91 |
+
# Save to file
|
| 92 |
+
nn3d_model.save("my_model.nn3d")
|
| 93 |
+
|
| 94 |
+
# Or get JSON string
|
| 95 |
+
json_str = nn3d_model.to_json()
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
## Supported Frameworks
|
| 99 |
+
|
| 100 |
+
### PyTorch
|
| 101 |
+
|
| 102 |
+
- All standard `torch.nn` layers
|
| 103 |
+
- Custom modules (exported as "custom" type)
|
| 104 |
+
- Automatic shape inference via forward pass
|
| 105 |
+
- Parameter counting
|
| 106 |
+
|
| 107 |
+
### ONNX
|
| 108 |
+
|
| 109 |
+
- Standard ONNX operators
|
| 110 |
+
- Shape extraction from model metadata
|
| 111 |
+
- Operator attributes preserved
|
| 112 |
+
|
| 113 |
+
## API Reference
|
| 114 |
+
|
| 115 |
+
### `export_pytorch_model(model, output_path, input_shape=None, model_name=None)`
|
| 116 |
+
|
| 117 |
+
Export a PyTorch model to .nn3d file.
|
| 118 |
+
|
| 119 |
+
| Parameter | Type | Description |
|
| 120 |
+
| ------------- | ----------- | ----------------------------- |
|
| 121 |
+
| `model` | `nn.Module` | PyTorch model to export |
|
| 122 |
+
| `output_path` | `str` | Path for output .nn3d file |
|
| 123 |
+
| `input_shape` | `tuple` | Input tensor shape (optional) |
|
| 124 |
+
| `model_name` | `str` | Model name (optional) |
|
| 125 |
+
|
| 126 |
+
### `export_onnx_model(model_path, output_path, model_name=None)`
|
| 127 |
+
|
| 128 |
+
Export an ONNX model to .nn3d file.
|
| 129 |
+
|
| 130 |
+
| Parameter | Type | Description |
|
| 131 |
+
| ------------- | ----- | -------------------------- |
|
| 132 |
+
| `model_path` | `str` | Path to ONNX model file |
|
| 133 |
+
| `output_path` | `str` | Path for output .nn3d file |
|
| 134 |
+
| `model_name` | `str` | Model name (optional) |
|
| 135 |
+
|
| 136 |
+
## License
|
| 137 |
+
|
| 138 |
+
MIT License
|
exporters/python/nn3d_exporter/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NN3D Exporter - Export neural network models to .nn3d format
|
| 3 |
+
|
| 4 |
+
This package provides utilities to export models from various
|
| 5 |
+
deep learning frameworks to the .nn3d visualization format.
|
| 6 |
+
|
| 7 |
+
Supported frameworks:
|
| 8 |
+
- PyTorch
|
| 9 |
+
- ONNX
|
| 10 |
+
- TensorFlow/Keras (planned)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from .pytorch_exporter import PyTorchExporter, export_pytorch_model
|
| 14 |
+
from .onnx_exporter import ONNXExporter, export_onnx_model
|
| 15 |
+
from .schema import NN3DModel, NN3DNode, NN3DEdge, NN3DGraph, NN3DMetadata
|
| 16 |
+
|
| 17 |
+
__version__ = "1.0.0"
|
| 18 |
+
__all__ = [
|
| 19 |
+
"PyTorchExporter",
|
| 20 |
+
"export_pytorch_model",
|
| 21 |
+
"ONNXExporter",
|
| 22 |
+
"export_onnx_model",
|
| 23 |
+
"NN3DModel",
|
| 24 |
+
"NN3DNode",
|
| 25 |
+
"NN3DEdge",
|
| 26 |
+
"NN3DGraph",
|
| 27 |
+
"NN3DMetadata",
|
| 28 |
+
]
|
exporters/python/nn3d_exporter/onnx_exporter.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ONNX Model Exporter
|
| 3 |
+
|
| 4 |
+
Export ONNX models to .nn3d format for visualization.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, List, Optional, Tuple, Any, Union
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import onnx
|
| 13 |
+
from onnx import numpy_helper
|
| 14 |
+
ONNX_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
ONNX_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
from .schema import (
|
| 19 |
+
NN3DModel, NN3DGraph, NN3DNode, NN3DEdge, NN3DMetadata,
|
| 20 |
+
NN3DSubgraph, LayerParams, LayerType, VisualizationConfig
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Mapping from ONNX op types to NN3D layer types
|
| 25 |
+
ONNX_TO_NN3D_TYPE: Dict[str, str] = {
|
| 26 |
+
# Convolution
|
| 27 |
+
'Conv': LayerType.CONV2D.value,
|
| 28 |
+
'ConvTranspose': LayerType.CONV_TRANSPOSE_2D.value,
|
| 29 |
+
|
| 30 |
+
# Linear
|
| 31 |
+
'Gemm': LayerType.LINEAR.value,
|
| 32 |
+
'MatMul': LayerType.LINEAR.value,
|
| 33 |
+
|
| 34 |
+
# Normalization
|
| 35 |
+
'BatchNormalization': LayerType.BATCH_NORM_2D.value,
|
| 36 |
+
'LayerNormalization': LayerType.LAYER_NORM.value,
|
| 37 |
+
'InstanceNormalization': LayerType.INSTANCE_NORM.value,
|
| 38 |
+
'GroupNormalization': LayerType.GROUP_NORM.value,
|
| 39 |
+
'Dropout': LayerType.DROPOUT.value,
|
| 40 |
+
|
| 41 |
+
# Activations
|
| 42 |
+
'Relu': LayerType.RELU.value,
|
| 43 |
+
'LeakyRelu': LayerType.LEAKY_RELU.value,
|
| 44 |
+
'Sigmoid': LayerType.SIGMOID.value,
|
| 45 |
+
'Tanh': LayerType.TANH.value,
|
| 46 |
+
'Softmax': LayerType.SOFTMAX.value,
|
| 47 |
+
'Gelu': LayerType.GELU.value,
|
| 48 |
+
|
| 49 |
+
# Pooling
|
| 50 |
+
'MaxPool': LayerType.MAX_POOL_2D.value,
|
| 51 |
+
'AveragePool': LayerType.AVG_POOL_2D.value,
|
| 52 |
+
'GlobalAveragePool': LayerType.GLOBAL_AVG_POOL.value,
|
| 53 |
+
'GlobalMaxPool': LayerType.MAX_POOL_2D.value,
|
| 54 |
+
|
| 55 |
+
# Shape operations
|
| 56 |
+
'Flatten': LayerType.FLATTEN.value,
|
| 57 |
+
'Reshape': LayerType.RESHAPE.value,
|
| 58 |
+
'Transpose': LayerType.RESHAPE.value,
|
| 59 |
+
'Squeeze': LayerType.RESHAPE.value,
|
| 60 |
+
'Unsqueeze': LayerType.RESHAPE.value,
|
| 61 |
+
|
| 62 |
+
# Merge operations
|
| 63 |
+
'Concat': LayerType.CONCAT.value,
|
| 64 |
+
'Add': LayerType.ADD.value,
|
| 65 |
+
'Mul': LayerType.MULTIPLY.value,
|
| 66 |
+
'Split': LayerType.SPLIT.value,
|
| 67 |
+
|
| 68 |
+
# Attention
|
| 69 |
+
'Attention': LayerType.ATTENTION.value,
|
| 70 |
+
'MultiHeadAttention': LayerType.MULTI_HEAD_ATTENTION.value,
|
| 71 |
+
|
| 72 |
+
# Recurrent
|
| 73 |
+
'LSTM': LayerType.LSTM.value,
|
| 74 |
+
'GRU': LayerType.GRU.value,
|
| 75 |
+
'RNN': LayerType.RNN.value,
|
| 76 |
+
|
| 77 |
+
# Resize/Upsample
|
| 78 |
+
'Resize': LayerType.UPSAMPLE.value,
|
| 79 |
+
'Upsample': LayerType.UPSAMPLE.value,
|
| 80 |
+
|
| 81 |
+
# Padding
|
| 82 |
+
'Pad': LayerType.PAD.value,
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ONNXExporter:
|
| 87 |
+
"""
|
| 88 |
+
Export ONNX models to NN3D format.
|
| 89 |
+
|
| 90 |
+
Usage:
|
| 91 |
+
exporter = ONNXExporter("model.onnx")
|
| 92 |
+
nn3d_model = exporter.export()
|
| 93 |
+
nn3d_model.save("model.nn3d")
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
model_path: str,
|
| 99 |
+
model_name: Optional[str] = None,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initialize the exporter.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
model_path: Path to ONNX model file
|
| 106 |
+
model_name: Name for the model (defaults to filename)
|
| 107 |
+
"""
|
| 108 |
+
if not ONNX_AVAILABLE:
|
| 109 |
+
raise ImportError("ONNX is required. Install with: pip install onnx")
|
| 110 |
+
|
| 111 |
+
self.model_path = model_path
|
| 112 |
+
self.model = onnx.load(model_path)
|
| 113 |
+
self.model_name = model_name or model_path.split('/')[-1].replace('.onnx', '')
|
| 114 |
+
|
| 115 |
+
# Initialize graph
|
| 116 |
+
onnx.checker.check_model(self.model)
|
| 117 |
+
self.graph = self.model.graph
|
| 118 |
+
|
| 119 |
+
self.nodes: List[NN3DNode] = []
|
| 120 |
+
self.edges: List[NN3DEdge] = []
|
| 121 |
+
|
| 122 |
+
self._tensor_shapes: Dict[str, List[int]] = {}
|
| 123 |
+
self._value_info: Dict[str, Any] = {}
|
| 124 |
+
self._node_id_map: Dict[str, str] = {}
|
| 125 |
+
self._output_to_node: Dict[str, str] = {}
|
| 126 |
+
|
| 127 |
+
def _extract_shapes(self) -> None:
|
| 128 |
+
"""Extract tensor shapes from the model"""
|
| 129 |
+
# Input shapes
|
| 130 |
+
for input_info in self.graph.input:
|
| 131 |
+
shape = []
|
| 132 |
+
if input_info.type.tensor_type.HasField('shape'):
|
| 133 |
+
for dim in input_info.type.tensor_type.shape.dim:
|
| 134 |
+
if dim.HasField('dim_value'):
|
| 135 |
+
shape.append(dim.dim_value)
|
| 136 |
+
elif dim.HasField('dim_param'):
|
| 137 |
+
shape.append(dim.dim_param)
|
| 138 |
+
else:
|
| 139 |
+
shape.append(-1)
|
| 140 |
+
self._tensor_shapes[input_info.name] = shape
|
| 141 |
+
self._value_info[input_info.name] = input_info
|
| 142 |
+
|
| 143 |
+
# Output shapes
|
| 144 |
+
for output_info in self.graph.output:
|
| 145 |
+
shape = []
|
| 146 |
+
if output_info.type.tensor_type.HasField('shape'):
|
| 147 |
+
for dim in output_info.type.tensor_type.shape.dim:
|
| 148 |
+
if dim.HasField('dim_value'):
|
| 149 |
+
shape.append(dim.dim_value)
|
| 150 |
+
elif dim.HasField('dim_param'):
|
| 151 |
+
shape.append(dim.dim_param)
|
| 152 |
+
else:
|
| 153 |
+
shape.append(-1)
|
| 154 |
+
self._tensor_shapes[output_info.name] = shape
|
| 155 |
+
self._value_info[output_info.name] = output_info
|
| 156 |
+
|
| 157 |
+
# Value info (intermediate tensors)
|
| 158 |
+
for value_info in self.graph.value_info:
|
| 159 |
+
shape = []
|
| 160 |
+
if value_info.type.tensor_type.HasField('shape'):
|
| 161 |
+
for dim in value_info.type.tensor_type.shape.dim:
|
| 162 |
+
if dim.HasField('dim_value'):
|
| 163 |
+
shape.append(dim.dim_value)
|
| 164 |
+
elif dim.HasField('dim_param'):
|
| 165 |
+
shape.append(dim.dim_param)
|
| 166 |
+
else:
|
| 167 |
+
shape.append(-1)
|
| 168 |
+
self._tensor_shapes[value_info.name] = shape
|
| 169 |
+
self._value_info[value_info.name] = value_info
|
| 170 |
+
|
| 171 |
+
def _get_layer_type(self, op_type: str) -> str:
|
| 172 |
+
"""Map ONNX op type to NN3D layer type"""
|
| 173 |
+
return ONNX_TO_NN3D_TYPE.get(op_type, LayerType.CUSTOM.value)
|
| 174 |
+
|
| 175 |
+
def _extract_attributes(self, node) -> Dict[str, Any]:
|
| 176 |
+
"""Extract node attributes"""
|
| 177 |
+
attrs = {}
|
| 178 |
+
for attr in node.attribute:
|
| 179 |
+
if attr.type == onnx.AttributeProto.INT:
|
| 180 |
+
attrs[attr.name] = attr.i
|
| 181 |
+
elif attr.type == onnx.AttributeProto.INTS:
|
| 182 |
+
attrs[attr.name] = list(attr.ints)
|
| 183 |
+
elif attr.type == onnx.AttributeProto.FLOAT:
|
| 184 |
+
attrs[attr.name] = attr.f
|
| 185 |
+
elif attr.type == onnx.AttributeProto.FLOATS:
|
| 186 |
+
attrs[attr.name] = list(attr.floats)
|
| 187 |
+
elif attr.type == onnx.AttributeProto.STRING:
|
| 188 |
+
attrs[attr.name] = attr.s.decode('utf-8')
|
| 189 |
+
elif attr.type == onnx.AttributeProto.STRINGS:
|
| 190 |
+
attrs[attr.name] = [s.decode('utf-8') for s in attr.strings]
|
| 191 |
+
return attrs
|
| 192 |
+
|
| 193 |
+
def _extract_params(self, node, attrs: Dict[str, Any]) -> LayerParams:
|
| 194 |
+
"""Extract layer parameters from ONNX node"""
|
| 195 |
+
params = LayerParams()
|
| 196 |
+
|
| 197 |
+
# Convolution parameters
|
| 198 |
+
if 'kernel_shape' in attrs:
|
| 199 |
+
params.kernel_size = attrs['kernel_shape']
|
| 200 |
+
if 'strides' in attrs:
|
| 201 |
+
params.stride = attrs['strides']
|
| 202 |
+
if 'pads' in attrs:
|
| 203 |
+
params.padding = attrs['pads']
|
| 204 |
+
if 'dilations' in attrs:
|
| 205 |
+
params.dilation = attrs['dilations']
|
| 206 |
+
if 'group' in attrs:
|
| 207 |
+
params.groups = attrs['group']
|
| 208 |
+
|
| 209 |
+
# Normalization parameters
|
| 210 |
+
if 'epsilon' in attrs:
|
| 211 |
+
params.eps = attrs['epsilon']
|
| 212 |
+
if 'momentum' in attrs:
|
| 213 |
+
params.momentum = attrs['momentum']
|
| 214 |
+
|
| 215 |
+
# Dropout
|
| 216 |
+
if 'ratio' in attrs:
|
| 217 |
+
params.dropout_rate = attrs['ratio']
|
| 218 |
+
|
| 219 |
+
return params
|
| 220 |
+
|
| 221 |
+
def export(self) -> NN3DModel:
|
| 222 |
+
"""Export the ONNX model to NN3D format"""
|
| 223 |
+
|
| 224 |
+
# Extract tensor shapes
|
| 225 |
+
self._extract_shapes()
|
| 226 |
+
|
| 227 |
+
# Add input nodes
|
| 228 |
+
for idx, input_info in enumerate(self.graph.input):
|
| 229 |
+
# Skip initializers (weights)
|
| 230 |
+
if input_info.name in [init.name for init in self.graph.initializer]:
|
| 231 |
+
continue
|
| 232 |
+
|
| 233 |
+
node_id = f"input_{idx}"
|
| 234 |
+
shape = self._tensor_shapes.get(input_info.name, [])
|
| 235 |
+
|
| 236 |
+
input_node = NN3DNode(
|
| 237 |
+
id=node_id,
|
| 238 |
+
type=LayerType.INPUT.value,
|
| 239 |
+
name=input_info.name,
|
| 240 |
+
output_shape=shape if shape else None,
|
| 241 |
+
depth=0
|
| 242 |
+
)
|
| 243 |
+
self.nodes.append(input_node)
|
| 244 |
+
self._output_to_node[input_info.name] = node_id
|
| 245 |
+
|
| 246 |
+
# Process all operator nodes
|
| 247 |
+
for idx, node in enumerate(self.graph.node):
|
| 248 |
+
node_id = f"node_{idx}"
|
| 249 |
+
layer_type = self._get_layer_type(node.op_type)
|
| 250 |
+
attrs = self._extract_attributes(node)
|
| 251 |
+
params = self._extract_params(node, attrs)
|
| 252 |
+
|
| 253 |
+
# Get input/output shapes
|
| 254 |
+
input_shapes = [self._tensor_shapes.get(inp, []) for inp in node.input
|
| 255 |
+
if inp not in [init.name for init in self.graph.initializer]]
|
| 256 |
+
output_shapes = [self._tensor_shapes.get(out, []) for out in node.output]
|
| 257 |
+
|
| 258 |
+
input_shape = input_shapes[0] if input_shapes else None
|
| 259 |
+
output_shape = output_shapes[0] if output_shapes else None
|
| 260 |
+
|
| 261 |
+
# Create node
|
| 262 |
+
nn3d_node = NN3DNode(
|
| 263 |
+
id=node_id,
|
| 264 |
+
type=layer_type,
|
| 265 |
+
name=node.name or f"{node.op_type}_{idx}",
|
| 266 |
+
params=params if any(v is not None for v in [
|
| 267 |
+
params.kernel_size, params.stride, params.padding,
|
| 268 |
+
params.eps, params.dropout_rate
|
| 269 |
+
]) else None,
|
| 270 |
+
input_shape=input_shape if input_shape else None,
|
| 271 |
+
output_shape=output_shape if output_shape else None,
|
| 272 |
+
depth=idx + 1,
|
| 273 |
+
attributes={'op_type': node.op_type, **attrs} if attrs else {'op_type': node.op_type}
|
| 274 |
+
)
|
| 275 |
+
self.nodes.append(nn3d_node)
|
| 276 |
+
|
| 277 |
+
# Map outputs to this node
|
| 278 |
+
for output in node.output:
|
| 279 |
+
self._output_to_node[output] = node_id
|
| 280 |
+
|
| 281 |
+
# Create edges from inputs
|
| 282 |
+
for inp in node.input:
|
| 283 |
+
# Skip initializers (weights)
|
| 284 |
+
if inp in [init.name for init in self.graph.initializer]:
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
source_id = self._output_to_node.get(inp)
|
| 288 |
+
if source_id:
|
| 289 |
+
edge = NN3DEdge(
|
| 290 |
+
source=source_id,
|
| 291 |
+
target=node_id,
|
| 292 |
+
tensor_shape=self._tensor_shapes.get(inp, None),
|
| 293 |
+
label=inp
|
| 294 |
+
)
|
| 295 |
+
self.edges.append(edge)
|
| 296 |
+
|
| 297 |
+
# Add output nodes
|
| 298 |
+
for idx, output_info in enumerate(self.graph.output):
|
| 299 |
+
node_id = f"output_{idx}"
|
| 300 |
+
shape = self._tensor_shapes.get(output_info.name, [])
|
| 301 |
+
|
| 302 |
+
output_node = NN3DNode(
|
| 303 |
+
id=node_id,
|
| 304 |
+
type=LayerType.OUTPUT.value,
|
| 305 |
+
name=output_info.name,
|
| 306 |
+
input_shape=shape if shape else None,
|
| 307 |
+
depth=len(self.graph.node) + 1
|
| 308 |
+
)
|
| 309 |
+
self.nodes.append(output_node)
|
| 310 |
+
|
| 311 |
+
# Create edge from last node producing this output
|
| 312 |
+
source_id = self._output_to_node.get(output_info.name)
|
| 313 |
+
if source_id:
|
| 314 |
+
edge = NN3DEdge(
|
| 315 |
+
source=source_id,
|
| 316 |
+
target=node_id,
|
| 317 |
+
tensor_shape=shape if shape else None
|
| 318 |
+
)
|
| 319 |
+
self.edges.append(edge)
|
| 320 |
+
|
| 321 |
+
# Get input/output shapes for metadata
|
| 322 |
+
input_shapes = [self._tensor_shapes.get(inp.name) for inp in self.graph.input
|
| 323 |
+
if inp.name not in [init.name for init in self.graph.initializer]]
|
| 324 |
+
output_shapes = [self._tensor_shapes.get(out.name) for out in self.graph.output]
|
| 325 |
+
|
| 326 |
+
# Create metadata
|
| 327 |
+
metadata = NN3DMetadata(
|
| 328 |
+
name=self.model_name,
|
| 329 |
+
framework="onnx",
|
| 330 |
+
created=datetime.now().isoformat(),
|
| 331 |
+
input_shape=input_shapes[0] if input_shapes else None,
|
| 332 |
+
output_shape=output_shapes[0] if output_shapes else None,
|
| 333 |
+
description=f"Converted from ONNX model: {self.model_path}"
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# Create graph
|
| 337 |
+
graph = NN3DGraph(
|
| 338 |
+
nodes=self.nodes,
|
| 339 |
+
edges=self.edges,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Create visualization config
|
| 343 |
+
viz_config = VisualizationConfig()
|
| 344 |
+
|
| 345 |
+
return NN3DModel(
|
| 346 |
+
metadata=metadata,
|
| 347 |
+
graph=graph,
|
| 348 |
+
visualization=viz_config
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def export_onnx_model(
|
| 353 |
+
model_path: str,
|
| 354 |
+
output_path: str,
|
| 355 |
+
model_name: Optional[str] = None,
|
| 356 |
+
) -> NN3DModel:
|
| 357 |
+
"""
|
| 358 |
+
Convenience function to export an ONNX model to .nn3d file.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
model_path: Path to ONNX model file
|
| 362 |
+
output_path: Path to save the .nn3d file
|
| 363 |
+
model_name: Name for the model
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
The exported NN3DModel
|
| 367 |
+
"""
|
| 368 |
+
exporter = ONNXExporter(model_path, model_name)
|
| 369 |
+
nn3d_model = exporter.export()
|
| 370 |
+
nn3d_model.save(output_path)
|
| 371 |
+
return nn3d_model
|
exporters/python/nn3d_exporter/pytorch_exporter.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch Model Exporter
|
| 3 |
+
|
| 4 |
+
Export PyTorch models to .nn3d format for visualization.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Any, Union
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
|
| 13 |
+
from .schema import (
|
| 14 |
+
NN3DModel, NN3DGraph, NN3DNode, NN3DEdge, NN3DMetadata,
|
| 15 |
+
NN3DSubgraph, LayerParams, LayerType, VisualizationConfig
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Mapping from PyTorch module types to NN3D layer types
|
| 20 |
+
PYTORCH_TO_NN3D_TYPE: Dict[type, str] = {
|
| 21 |
+
# Convolution layers
|
| 22 |
+
nn.Conv1d: LayerType.CONV1D.value,
|
| 23 |
+
nn.Conv2d: LayerType.CONV2D.value,
|
| 24 |
+
nn.Conv3d: LayerType.CONV3D.value,
|
| 25 |
+
nn.ConvTranspose2d: LayerType.CONV_TRANSPOSE_2D.value,
|
| 26 |
+
|
| 27 |
+
# Linear layers
|
| 28 |
+
nn.Linear: LayerType.LINEAR.value,
|
| 29 |
+
nn.Embedding: LayerType.EMBEDDING.value,
|
| 30 |
+
|
| 31 |
+
# Normalization layers
|
| 32 |
+
nn.BatchNorm1d: LayerType.BATCH_NORM_1D.value,
|
| 33 |
+
nn.BatchNorm2d: LayerType.BATCH_NORM_2D.value,
|
| 34 |
+
nn.LayerNorm: LayerType.LAYER_NORM.value,
|
| 35 |
+
nn.GroupNorm: LayerType.GROUP_NORM.value,
|
| 36 |
+
nn.InstanceNorm2d: LayerType.INSTANCE_NORM.value,
|
| 37 |
+
nn.Dropout: LayerType.DROPOUT.value,
|
| 38 |
+
nn.Dropout2d: LayerType.DROPOUT.value,
|
| 39 |
+
|
| 40 |
+
# Activation layers
|
| 41 |
+
nn.ReLU: LayerType.RELU.value,
|
| 42 |
+
nn.LeakyReLU: LayerType.LEAKY_RELU.value,
|
| 43 |
+
nn.GELU: LayerType.GELU.value,
|
| 44 |
+
nn.SiLU: LayerType.SILU.value,
|
| 45 |
+
nn.Sigmoid: LayerType.SIGMOID.value,
|
| 46 |
+
nn.Tanh: LayerType.TANH.value,
|
| 47 |
+
nn.Softmax: LayerType.SOFTMAX.value,
|
| 48 |
+
|
| 49 |
+
# Pooling layers
|
| 50 |
+
nn.MaxPool1d: LayerType.MAX_POOL_1D.value,
|
| 51 |
+
nn.MaxPool2d: LayerType.MAX_POOL_2D.value,
|
| 52 |
+
nn.AvgPool2d: LayerType.AVG_POOL_2D.value,
|
| 53 |
+
nn.AdaptiveAvgPool2d: LayerType.ADAPTIVE_AVG_POOL.value,
|
| 54 |
+
nn.AdaptiveAvgPool1d: LayerType.ADAPTIVE_AVG_POOL.value,
|
| 55 |
+
|
| 56 |
+
# Recurrent layers
|
| 57 |
+
nn.LSTM: LayerType.LSTM.value,
|
| 58 |
+
nn.GRU: LayerType.GRU.value,
|
| 59 |
+
nn.RNN: LayerType.RNN.value,
|
| 60 |
+
|
| 61 |
+
# Attention layers
|
| 62 |
+
nn.MultiheadAttention: LayerType.MULTI_HEAD_ATTENTION.value,
|
| 63 |
+
|
| 64 |
+
# Transform layers
|
| 65 |
+
nn.Flatten: LayerType.FLATTEN.value,
|
| 66 |
+
nn.Upsample: LayerType.UPSAMPLE.value,
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class PyTorchExporter:
|
| 71 |
+
"""
|
| 72 |
+
Export PyTorch models to NN3D format.
|
| 73 |
+
|
| 74 |
+
Usage:
|
| 75 |
+
exporter = PyTorchExporter(model, input_shape=(1, 3, 224, 224))
|
| 76 |
+
nn3d_model = exporter.export()
|
| 77 |
+
nn3d_model.save("model.nn3d")
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
model: nn.Module,
|
| 83 |
+
input_shape: Optional[Tuple[int, ...]] = None,
|
| 84 |
+
model_name: Optional[str] = None,
|
| 85 |
+
include_activations: bool = False,
|
| 86 |
+
):
|
| 87 |
+
"""
|
| 88 |
+
Initialize the exporter.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
model: PyTorch model to export
|
| 92 |
+
input_shape: Input tensor shape (batch, channels, height, width)
|
| 93 |
+
model_name: Name for the model (defaults to class name)
|
| 94 |
+
include_activations: Whether to trace activations (requires input_shape)
|
| 95 |
+
"""
|
| 96 |
+
self.model = model
|
| 97 |
+
self.input_shape = input_shape
|
| 98 |
+
self.model_name = model_name or model.__class__.__name__
|
| 99 |
+
self.include_activations = include_activations
|
| 100 |
+
|
| 101 |
+
self.nodes: List[NN3DNode] = []
|
| 102 |
+
self.edges: List[NN3DEdge] = []
|
| 103 |
+
self.subgraphs: List[NN3DSubgraph] = []
|
| 104 |
+
|
| 105 |
+
self._node_id_counter = 0
|
| 106 |
+
self._module_to_id: Dict[nn.Module, str] = {}
|
| 107 |
+
self._shapes: Dict[str, Tuple[List[int], List[int]]] = {}
|
| 108 |
+
|
| 109 |
+
def _get_node_id(self) -> str:
|
| 110 |
+
"""Generate unique node ID"""
|
| 111 |
+
node_id = f"node_{self._node_id_counter}"
|
| 112 |
+
self._node_id_counter += 1
|
| 113 |
+
return node_id
|
| 114 |
+
|
| 115 |
+
def _get_layer_type(self, module: nn.Module) -> str:
|
| 116 |
+
"""Map PyTorch module to NN3D layer type"""
|
| 117 |
+
module_type = type(module)
|
| 118 |
+
|
| 119 |
+
if module_type in PYTORCH_TO_NN3D_TYPE:
|
| 120 |
+
return PYTORCH_TO_NN3D_TYPE[module_type]
|
| 121 |
+
|
| 122 |
+
# Check for common container patterns
|
| 123 |
+
class_name = module_type.__name__.lower()
|
| 124 |
+
|
| 125 |
+
if 'attention' in class_name:
|
| 126 |
+
return LayerType.ATTENTION.value
|
| 127 |
+
elif 'transformer' in class_name:
|
| 128 |
+
return LayerType.TRANSFORMER.value
|
| 129 |
+
elif 'encoder' in class_name:
|
| 130 |
+
return LayerType.ENCODER_BLOCK.value
|
| 131 |
+
elif 'decoder' in class_name:
|
| 132 |
+
return LayerType.DECODER_BLOCK.value
|
| 133 |
+
elif 'residual' in class_name or 'resblock' in class_name:
|
| 134 |
+
return LayerType.RESIDUAL_BLOCK.value
|
| 135 |
+
|
| 136 |
+
return LayerType.CUSTOM.value
|
| 137 |
+
|
| 138 |
+
def _extract_params(self, module: nn.Module) -> LayerParams:
|
| 139 |
+
"""Extract layer parameters from PyTorch module"""
|
| 140 |
+
params = LayerParams()
|
| 141 |
+
|
| 142 |
+
# Convolution parameters
|
| 143 |
+
if hasattr(module, 'in_channels'):
|
| 144 |
+
params.in_channels = module.in_channels
|
| 145 |
+
if hasattr(module, 'out_channels'):
|
| 146 |
+
params.out_channels = module.out_channels
|
| 147 |
+
if hasattr(module, 'kernel_size'):
|
| 148 |
+
ks = module.kernel_size
|
| 149 |
+
params.kernel_size = list(ks) if isinstance(ks, tuple) else ks
|
| 150 |
+
if hasattr(module, 'stride'):
|
| 151 |
+
stride = module.stride
|
| 152 |
+
params.stride = list(stride) if isinstance(stride, tuple) else stride
|
| 153 |
+
if hasattr(module, 'padding'):
|
| 154 |
+
pad = module.padding
|
| 155 |
+
params.padding = list(pad) if isinstance(pad, tuple) else pad
|
| 156 |
+
if hasattr(module, 'dilation'):
|
| 157 |
+
dil = module.dilation
|
| 158 |
+
params.dilation = list(dil) if isinstance(dil, tuple) else dil
|
| 159 |
+
if hasattr(module, 'groups'):
|
| 160 |
+
params.groups = module.groups
|
| 161 |
+
|
| 162 |
+
# Linear parameters
|
| 163 |
+
if hasattr(module, 'in_features'):
|
| 164 |
+
params.in_features = module.in_features
|
| 165 |
+
if hasattr(module, 'out_features'):
|
| 166 |
+
params.out_features = module.out_features
|
| 167 |
+
|
| 168 |
+
# Attention parameters
|
| 169 |
+
if hasattr(module, 'num_heads'):
|
| 170 |
+
params.num_heads = module.num_heads
|
| 171 |
+
if hasattr(module, 'embed_dim'):
|
| 172 |
+
params.hidden_size = module.embed_dim
|
| 173 |
+
|
| 174 |
+
# Normalization parameters
|
| 175 |
+
if hasattr(module, 'eps'):
|
| 176 |
+
params.eps = module.eps
|
| 177 |
+
if hasattr(module, 'momentum') and module.momentum is not None:
|
| 178 |
+
params.momentum = module.momentum
|
| 179 |
+
if hasattr(module, 'affine'):
|
| 180 |
+
params.affine = module.affine
|
| 181 |
+
|
| 182 |
+
# Dropout parameters
|
| 183 |
+
if hasattr(module, 'p'):
|
| 184 |
+
params.dropout_rate = module.p
|
| 185 |
+
|
| 186 |
+
# Embedding parameters
|
| 187 |
+
if hasattr(module, 'num_embeddings'):
|
| 188 |
+
params.num_embeddings = module.num_embeddings
|
| 189 |
+
if hasattr(module, 'embedding_dim'):
|
| 190 |
+
params.embedding_dim = module.embedding_dim
|
| 191 |
+
|
| 192 |
+
# Bias
|
| 193 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
| 194 |
+
params.bias = True
|
| 195 |
+
elif hasattr(module, 'bias'):
|
| 196 |
+
params.bias = False
|
| 197 |
+
|
| 198 |
+
return params
|
| 199 |
+
|
| 200 |
+
def _trace_shapes(self) -> None:
|
| 201 |
+
"""Trace tensor shapes through the model"""
|
| 202 |
+
if self.input_shape is None:
|
| 203 |
+
return
|
| 204 |
+
|
| 205 |
+
hooks = []
|
| 206 |
+
|
| 207 |
+
def hook_fn(name: str):
|
| 208 |
+
def hook(module, input, output):
|
| 209 |
+
input_shape = None
|
| 210 |
+
output_shape = None
|
| 211 |
+
|
| 212 |
+
if isinstance(input, tuple) and len(input) > 0:
|
| 213 |
+
if isinstance(input[0], torch.Tensor):
|
| 214 |
+
input_shape = list(input[0].shape)
|
| 215 |
+
elif isinstance(input, torch.Tensor):
|
| 216 |
+
input_shape = list(input.shape)
|
| 217 |
+
|
| 218 |
+
if isinstance(output, torch.Tensor):
|
| 219 |
+
output_shape = list(output.shape)
|
| 220 |
+
elif isinstance(output, tuple) and len(output) > 0:
|
| 221 |
+
if isinstance(output[0], torch.Tensor):
|
| 222 |
+
output_shape = list(output[0].shape)
|
| 223 |
+
|
| 224 |
+
self._shapes[name] = (input_shape, output_shape)
|
| 225 |
+
return hook
|
| 226 |
+
|
| 227 |
+
# Register hooks
|
| 228 |
+
for name, module in self.model.named_modules():
|
| 229 |
+
if len(list(module.children())) == 0: # Leaf modules only
|
| 230 |
+
hooks.append(module.register_forward_hook(hook_fn(name)))
|
| 231 |
+
|
| 232 |
+
# Run forward pass
|
| 233 |
+
try:
|
| 234 |
+
self.model.eval()
|
| 235 |
+
with torch.no_grad():
|
| 236 |
+
dummy_input = torch.zeros(self.input_shape)
|
| 237 |
+
self.model(dummy_input)
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"Warning: Could not trace shapes: {e}")
|
| 240 |
+
finally:
|
| 241 |
+
# Remove hooks
|
| 242 |
+
for hook in hooks:
|
| 243 |
+
hook.remove()
|
| 244 |
+
|
| 245 |
+
def _process_module(
|
| 246 |
+
self,
|
| 247 |
+
module: nn.Module,
|
| 248 |
+
name: str,
|
| 249 |
+
parent_id: Optional[str] = None,
|
| 250 |
+
depth: int = 0
|
| 251 |
+
) -> Optional[str]:
|
| 252 |
+
"""Process a single module and its children"""
|
| 253 |
+
|
| 254 |
+
# Skip container modules without parameters
|
| 255 |
+
children = list(module.named_children())
|
| 256 |
+
is_leaf = len(children) == 0
|
| 257 |
+
|
| 258 |
+
# Create node for leaf modules or significant containers
|
| 259 |
+
layer_type = self._get_layer_type(module)
|
| 260 |
+
|
| 261 |
+
if is_leaf or layer_type != LayerType.CUSTOM.value:
|
| 262 |
+
node_id = self._get_node_id()
|
| 263 |
+
self._module_to_id[module] = node_id
|
| 264 |
+
|
| 265 |
+
# Get shapes if traced
|
| 266 |
+
input_shape, output_shape = self._shapes.get(name, (None, None))
|
| 267 |
+
|
| 268 |
+
# Extract parameters
|
| 269 |
+
params = self._extract_params(module)
|
| 270 |
+
|
| 271 |
+
# Count parameters
|
| 272 |
+
num_params = sum(p.numel() for p in module.parameters(recurse=False))
|
| 273 |
+
|
| 274 |
+
node = NN3DNode(
|
| 275 |
+
id=node_id,
|
| 276 |
+
type=layer_type,
|
| 277 |
+
name=name or module.__class__.__name__,
|
| 278 |
+
params=params if any(v is not None for v in [
|
| 279 |
+
params.in_channels, params.out_channels,
|
| 280 |
+
params.in_features, params.out_features,
|
| 281 |
+
params.kernel_size, params.num_heads
|
| 282 |
+
]) else None,
|
| 283 |
+
input_shape=input_shape,
|
| 284 |
+
output_shape=output_shape,
|
| 285 |
+
depth=depth,
|
| 286 |
+
attributes={'num_params': num_params} if num_params > 0 else None
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
self.nodes.append(node)
|
| 290 |
+
|
| 291 |
+
# Create edge from parent
|
| 292 |
+
if parent_id:
|
| 293 |
+
edge = NN3DEdge(
|
| 294 |
+
source=parent_id,
|
| 295 |
+
target=node_id,
|
| 296 |
+
tensor_shape=input_shape,
|
| 297 |
+
)
|
| 298 |
+
self.edges.append(edge)
|
| 299 |
+
|
| 300 |
+
# Process children
|
| 301 |
+
if children:
|
| 302 |
+
subgraph_nodes = [node_id]
|
| 303 |
+
prev_id = node_id
|
| 304 |
+
|
| 305 |
+
for child_name, child in children:
|
| 306 |
+
full_name = f"{name}.{child_name}" if name else child_name
|
| 307 |
+
child_id = self._process_module(child, full_name, prev_id, depth + 1)
|
| 308 |
+
if child_id:
|
| 309 |
+
subgraph_nodes.append(child_id)
|
| 310 |
+
prev_id = child_id
|
| 311 |
+
|
| 312 |
+
# Create subgraph for container
|
| 313 |
+
if len(subgraph_nodes) > 1:
|
| 314 |
+
self.subgraphs.append(NN3DSubgraph(
|
| 315 |
+
id=f"subgraph_{node_id}",
|
| 316 |
+
name=name or module.__class__.__name__,
|
| 317 |
+
nodes=subgraph_nodes,
|
| 318 |
+
type='sequential'
|
| 319 |
+
))
|
| 320 |
+
|
| 321 |
+
return prev_id
|
| 322 |
+
|
| 323 |
+
return node_id
|
| 324 |
+
|
| 325 |
+
# Process children of container
|
| 326 |
+
prev_id = parent_id
|
| 327 |
+
for child_name, child in children:
|
| 328 |
+
full_name = f"{name}.{child_name}" if name else child_name
|
| 329 |
+
child_id = self._process_module(child, full_name, prev_id, depth)
|
| 330 |
+
if child_id:
|
| 331 |
+
prev_id = child_id
|
| 332 |
+
|
| 333 |
+
return prev_id
|
| 334 |
+
|
| 335 |
+
def export(self) -> NN3DModel:
|
| 336 |
+
"""Export the model to NN3D format"""
|
| 337 |
+
|
| 338 |
+
# Trace shapes first
|
| 339 |
+
self._trace_shapes()
|
| 340 |
+
|
| 341 |
+
# Add input node
|
| 342 |
+
input_id = self._get_node_id()
|
| 343 |
+
input_node = NN3DNode(
|
| 344 |
+
id=input_id,
|
| 345 |
+
type=LayerType.INPUT.value,
|
| 346 |
+
name="input",
|
| 347 |
+
output_shape=list(self.input_shape) if self.input_shape else None,
|
| 348 |
+
depth=0
|
| 349 |
+
)
|
| 350 |
+
self.nodes.append(input_node)
|
| 351 |
+
|
| 352 |
+
# Process all modules
|
| 353 |
+
last_id = self._process_module(self.model, "", input_id, 1)
|
| 354 |
+
|
| 355 |
+
# Add output node
|
| 356 |
+
output_id = self._get_node_id()
|
| 357 |
+
output_shape = None
|
| 358 |
+
if last_id and self.nodes:
|
| 359 |
+
# Get output shape from last processed node
|
| 360 |
+
for node in reversed(self.nodes):
|
| 361 |
+
if node.output_shape:
|
| 362 |
+
output_shape = node.output_shape
|
| 363 |
+
break
|
| 364 |
+
|
| 365 |
+
output_node = NN3DNode(
|
| 366 |
+
id=output_id,
|
| 367 |
+
type=LayerType.OUTPUT.value,
|
| 368 |
+
name="output",
|
| 369 |
+
input_shape=output_shape,
|
| 370 |
+
depth=len(self.nodes)
|
| 371 |
+
)
|
| 372 |
+
self.nodes.append(output_node)
|
| 373 |
+
|
| 374 |
+
if last_id:
|
| 375 |
+
self.edges.append(NN3DEdge(
|
| 376 |
+
source=last_id,
|
| 377 |
+
target=output_id,
|
| 378 |
+
tensor_shape=output_shape
|
| 379 |
+
))
|
| 380 |
+
|
| 381 |
+
# Count parameters
|
| 382 |
+
total_params = sum(p.numel() for p in self.model.parameters())
|
| 383 |
+
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
| 384 |
+
|
| 385 |
+
# Create metadata
|
| 386 |
+
metadata = NN3DMetadata(
|
| 387 |
+
name=self.model_name,
|
| 388 |
+
framework="pytorch",
|
| 389 |
+
created=datetime.now().isoformat(),
|
| 390 |
+
input_shape=list(self.input_shape) if self.input_shape else None,
|
| 391 |
+
output_shape=output_shape,
|
| 392 |
+
total_params=total_params,
|
| 393 |
+
trainable_params=trainable_params,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# Create graph
|
| 397 |
+
graph = NN3DGraph(
|
| 398 |
+
nodes=self.nodes,
|
| 399 |
+
edges=self.edges,
|
| 400 |
+
subgraphs=self.subgraphs if self.subgraphs else None
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# Create visualization config
|
| 404 |
+
viz_config = VisualizationConfig()
|
| 405 |
+
|
| 406 |
+
return NN3DModel(
|
| 407 |
+
metadata=metadata,
|
| 408 |
+
graph=graph,
|
| 409 |
+
visualization=viz_config
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def export_pytorch_model(
|
| 414 |
+
model: nn.Module,
|
| 415 |
+
output_path: str,
|
| 416 |
+
input_shape: Optional[Tuple[int, ...]] = None,
|
| 417 |
+
model_name: Optional[str] = None,
|
| 418 |
+
) -> NN3DModel:
|
| 419 |
+
"""
|
| 420 |
+
Convenience function to export a PyTorch model to .nn3d file.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
model: PyTorch model to export
|
| 424 |
+
output_path: Path to save the .nn3d file
|
| 425 |
+
input_shape: Input tensor shape (batch, channels, height, width)
|
| 426 |
+
model_name: Name for the model
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
The exported NN3DModel
|
| 430 |
+
"""
|
| 431 |
+
exporter = PyTorchExporter(model, input_shape, model_name)
|
| 432 |
+
nn3d_model = exporter.export()
|
| 433 |
+
nn3d_model.save(output_path)
|
| 434 |
+
return nn3d_model
|
exporters/python/nn3d_exporter/schema.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NN3D Schema definitions using Python dataclasses
|
| 3 |
+
|
| 4 |
+
These classes mirror the JSON schema and can be serialized to .nn3d format.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field, asdict
|
| 8 |
+
from typing import List, Optional, Dict, Any, Union
|
| 9 |
+
from enum import Enum
|
| 10 |
+
import json
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LayerType(str, Enum):
|
| 15 |
+
"""Supported layer types"""
|
| 16 |
+
INPUT = "input"
|
| 17 |
+
OUTPUT = "output"
|
| 18 |
+
CONV1D = "conv1d"
|
| 19 |
+
CONV2D = "conv2d"
|
| 20 |
+
CONV3D = "conv3d"
|
| 21 |
+
CONV_TRANSPOSE_2D = "convTranspose2d"
|
| 22 |
+
DEPTHWISE_CONV2D = "depthwiseConv2d"
|
| 23 |
+
SEPARABLE_CONV2D = "separableConv2d"
|
| 24 |
+
LINEAR = "linear"
|
| 25 |
+
DENSE = "dense"
|
| 26 |
+
EMBEDDING = "embedding"
|
| 27 |
+
BATCH_NORM_1D = "batchNorm1d"
|
| 28 |
+
BATCH_NORM_2D = "batchNorm2d"
|
| 29 |
+
LAYER_NORM = "layerNorm"
|
| 30 |
+
GROUP_NORM = "groupNorm"
|
| 31 |
+
INSTANCE_NORM = "instanceNorm"
|
| 32 |
+
DROPOUT = "dropout"
|
| 33 |
+
RELU = "relu"
|
| 34 |
+
LEAKY_RELU = "leakyRelu"
|
| 35 |
+
GELU = "gelu"
|
| 36 |
+
SILU = "silu"
|
| 37 |
+
SIGMOID = "sigmoid"
|
| 38 |
+
TANH = "tanh"
|
| 39 |
+
SOFTMAX = "softmax"
|
| 40 |
+
MAX_POOL_1D = "maxPool1d"
|
| 41 |
+
MAX_POOL_2D = "maxPool2d"
|
| 42 |
+
AVG_POOL_2D = "avgPool2d"
|
| 43 |
+
GLOBAL_AVG_POOL = "globalAvgPool"
|
| 44 |
+
ADAPTIVE_AVG_POOL = "adaptiveAvgPool"
|
| 45 |
+
FLATTEN = "flatten"
|
| 46 |
+
RESHAPE = "reshape"
|
| 47 |
+
CONCAT = "concat"
|
| 48 |
+
ADD = "add"
|
| 49 |
+
MULTIPLY = "multiply"
|
| 50 |
+
SPLIT = "split"
|
| 51 |
+
ATTENTION = "attention"
|
| 52 |
+
MULTI_HEAD_ATTENTION = "multiHeadAttention"
|
| 53 |
+
SELF_ATTENTION = "selfAttention"
|
| 54 |
+
CROSS_ATTENTION = "crossAttention"
|
| 55 |
+
LSTM = "lstm"
|
| 56 |
+
GRU = "gru"
|
| 57 |
+
RNN = "rnn"
|
| 58 |
+
TRANSFORMER = "transformer"
|
| 59 |
+
ENCODER_BLOCK = "encoderBlock"
|
| 60 |
+
DECODER_BLOCK = "decoderBlock"
|
| 61 |
+
RESIDUAL_BLOCK = "residualBlock"
|
| 62 |
+
UPSAMPLE = "upsample"
|
| 63 |
+
INTERPOLATE = "interpolate"
|
| 64 |
+
PAD = "pad"
|
| 65 |
+
CUSTOM = "custom"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class Position3D:
|
| 70 |
+
"""3D position"""
|
| 71 |
+
x: float = 0.0
|
| 72 |
+
y: float = 0.0
|
| 73 |
+
z: float = 0.0
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
|
| 77 |
+
class LayerParams:
|
| 78 |
+
"""Layer parameters"""
|
| 79 |
+
in_channels: Optional[int] = None
|
| 80 |
+
out_channels: Optional[int] = None
|
| 81 |
+
in_features: Optional[int] = None
|
| 82 |
+
out_features: Optional[int] = None
|
| 83 |
+
kernel_size: Optional[Union[int, List[int]]] = None
|
| 84 |
+
stride: Optional[Union[int, List[int]]] = None
|
| 85 |
+
padding: Optional[Union[int, str, List[int]]] = None
|
| 86 |
+
dilation: Optional[Union[int, List[int]]] = None
|
| 87 |
+
groups: Optional[int] = None
|
| 88 |
+
bias: Optional[bool] = None
|
| 89 |
+
num_heads: Optional[int] = None
|
| 90 |
+
hidden_size: Optional[int] = None
|
| 91 |
+
dropout_rate: Optional[float] = None
|
| 92 |
+
eps: Optional[float] = None
|
| 93 |
+
momentum: Optional[float] = None
|
| 94 |
+
affine: Optional[bool] = None
|
| 95 |
+
num_embeddings: Optional[int] = None
|
| 96 |
+
embedding_dim: Optional[int] = None
|
| 97 |
+
|
| 98 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 99 |
+
"""Convert to dict with camelCase keys, excluding None values"""
|
| 100 |
+
key_mapping = {
|
| 101 |
+
'in_channels': 'inChannels',
|
| 102 |
+
'out_channels': 'outChannels',
|
| 103 |
+
'in_features': 'inFeatures',
|
| 104 |
+
'out_features': 'outFeatures',
|
| 105 |
+
'kernel_size': 'kernelSize',
|
| 106 |
+
'num_heads': 'numHeads',
|
| 107 |
+
'hidden_size': 'hiddenSize',
|
| 108 |
+
'dropout_rate': 'dropoutRate',
|
| 109 |
+
'num_embeddings': 'numEmbeddings',
|
| 110 |
+
'embedding_dim': 'embeddingDim',
|
| 111 |
+
}
|
| 112 |
+
result = {}
|
| 113 |
+
for key, value in asdict(self).items():
|
| 114 |
+
if value is not None:
|
| 115 |
+
camel_key = key_mapping.get(key, key)
|
| 116 |
+
result[camel_key] = value
|
| 117 |
+
return result
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass
|
| 121 |
+
class NN3DNode:
|
| 122 |
+
"""Graph node representing a layer"""
|
| 123 |
+
id: str
|
| 124 |
+
type: str # LayerType value
|
| 125 |
+
name: str
|
| 126 |
+
params: Optional[LayerParams] = None
|
| 127 |
+
input_shape: Optional[List[Union[int, str]]] = None
|
| 128 |
+
output_shape: Optional[List[Union[int, str]]] = None
|
| 129 |
+
position: Optional[Position3D] = None
|
| 130 |
+
attributes: Optional[Dict[str, Any]] = None
|
| 131 |
+
group: Optional[str] = None
|
| 132 |
+
depth: Optional[int] = None
|
| 133 |
+
|
| 134 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 135 |
+
result = {
|
| 136 |
+
'id': self.id,
|
| 137 |
+
'type': self.type,
|
| 138 |
+
'name': self.name,
|
| 139 |
+
}
|
| 140 |
+
if self.params:
|
| 141 |
+
result['params'] = self.params.to_dict()
|
| 142 |
+
if self.input_shape:
|
| 143 |
+
result['inputShape'] = self.input_shape
|
| 144 |
+
if self.output_shape:
|
| 145 |
+
result['outputShape'] = self.output_shape
|
| 146 |
+
if self.position:
|
| 147 |
+
result['position'] = asdict(self.position)
|
| 148 |
+
if self.attributes:
|
| 149 |
+
result['attributes'] = self.attributes
|
| 150 |
+
if self.group:
|
| 151 |
+
result['group'] = self.group
|
| 152 |
+
if self.depth is not None:
|
| 153 |
+
result['depth'] = self.depth
|
| 154 |
+
return result
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@dataclass
|
| 158 |
+
class NN3DEdge:
|
| 159 |
+
"""Graph edge representing a connection"""
|
| 160 |
+
source: str
|
| 161 |
+
target: str
|
| 162 |
+
id: Optional[str] = None
|
| 163 |
+
source_port: Optional[int] = None
|
| 164 |
+
target_port: Optional[int] = None
|
| 165 |
+
tensor_shape: Optional[List[Union[int, str]]] = None
|
| 166 |
+
dtype: Optional[str] = None
|
| 167 |
+
label: Optional[str] = None
|
| 168 |
+
|
| 169 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 170 |
+
result = {
|
| 171 |
+
'source': self.source,
|
| 172 |
+
'target': self.target,
|
| 173 |
+
}
|
| 174 |
+
if self.id:
|
| 175 |
+
result['id'] = self.id
|
| 176 |
+
if self.source_port is not None:
|
| 177 |
+
result['sourcePort'] = self.source_port
|
| 178 |
+
if self.target_port is not None:
|
| 179 |
+
result['targetPort'] = self.target_port
|
| 180 |
+
if self.tensor_shape:
|
| 181 |
+
result['tensorShape'] = self.tensor_shape
|
| 182 |
+
if self.dtype:
|
| 183 |
+
result['dtype'] = self.dtype
|
| 184 |
+
if self.label:
|
| 185 |
+
result['label'] = self.label
|
| 186 |
+
return result
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@dataclass
|
| 190 |
+
class NN3DSubgraph:
|
| 191 |
+
"""Subgraph for grouping layers"""
|
| 192 |
+
id: str
|
| 193 |
+
name: str
|
| 194 |
+
nodes: List[str] = field(default_factory=list)
|
| 195 |
+
type: Optional[str] = None
|
| 196 |
+
color: Optional[str] = None
|
| 197 |
+
collapsed: bool = False
|
| 198 |
+
|
| 199 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 200 |
+
result = {
|
| 201 |
+
'id': self.id,
|
| 202 |
+
'name': self.name,
|
| 203 |
+
'nodes': self.nodes,
|
| 204 |
+
}
|
| 205 |
+
if self.type:
|
| 206 |
+
result['type'] = self.type
|
| 207 |
+
if self.color:
|
| 208 |
+
result['color'] = self.color
|
| 209 |
+
if self.collapsed:
|
| 210 |
+
result['collapsed'] = self.collapsed
|
| 211 |
+
return result
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@dataclass
|
| 215 |
+
class NN3DGraph:
|
| 216 |
+
"""Graph containing nodes and edges"""
|
| 217 |
+
nodes: List[NN3DNode] = field(default_factory=list)
|
| 218 |
+
edges: List[NN3DEdge] = field(default_factory=list)
|
| 219 |
+
subgraphs: Optional[List[NN3DSubgraph]] = None
|
| 220 |
+
|
| 221 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 222 |
+
result = {
|
| 223 |
+
'nodes': [n.to_dict() for n in self.nodes],
|
| 224 |
+
'edges': [e.to_dict() for e in self.edges],
|
| 225 |
+
}
|
| 226 |
+
if self.subgraphs:
|
| 227 |
+
result['subgraphs'] = [s.to_dict() for s in self.subgraphs]
|
| 228 |
+
return result
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@dataclass
|
| 232 |
+
class NN3DMetadata:
|
| 233 |
+
"""Model metadata"""
|
| 234 |
+
name: str
|
| 235 |
+
description: Optional[str] = None
|
| 236 |
+
framework: Optional[str] = None
|
| 237 |
+
author: Optional[str] = None
|
| 238 |
+
created: Optional[str] = None
|
| 239 |
+
tags: Optional[List[str]] = None
|
| 240 |
+
input_shape: Optional[List[Union[int, str]]] = None
|
| 241 |
+
output_shape: Optional[List[Union[int, str]]] = None
|
| 242 |
+
total_params: Optional[int] = None
|
| 243 |
+
trainable_params: Optional[int] = None
|
| 244 |
+
|
| 245 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 246 |
+
result = {'name': self.name}
|
| 247 |
+
if self.description:
|
| 248 |
+
result['description'] = self.description
|
| 249 |
+
if self.framework:
|
| 250 |
+
result['framework'] = self.framework
|
| 251 |
+
if self.author:
|
| 252 |
+
result['author'] = self.author
|
| 253 |
+
if self.created:
|
| 254 |
+
result['created'] = self.created
|
| 255 |
+
if self.tags:
|
| 256 |
+
result['tags'] = self.tags
|
| 257 |
+
if self.input_shape:
|
| 258 |
+
result['inputShape'] = self.input_shape
|
| 259 |
+
if self.output_shape:
|
| 260 |
+
result['outputShape'] = self.output_shape
|
| 261 |
+
if self.total_params is not None:
|
| 262 |
+
result['totalParams'] = self.total_params
|
| 263 |
+
if self.trainable_params is not None:
|
| 264 |
+
result['trainableParams'] = self.trainable_params
|
| 265 |
+
return result
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@dataclass
|
| 269 |
+
class VisualizationConfig:
|
| 270 |
+
"""Visualization configuration"""
|
| 271 |
+
layout: str = "layered"
|
| 272 |
+
theme: str = "dark"
|
| 273 |
+
layer_spacing: float = 3.0
|
| 274 |
+
node_scale: float = 1.0
|
| 275 |
+
show_labels: bool = True
|
| 276 |
+
show_edges: bool = True
|
| 277 |
+
edge_style: str = "tube"
|
| 278 |
+
|
| 279 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 280 |
+
return {
|
| 281 |
+
'layout': self.layout,
|
| 282 |
+
'theme': self.theme,
|
| 283 |
+
'layerSpacing': self.layer_spacing,
|
| 284 |
+
'nodeScale': self.node_scale,
|
| 285 |
+
'showLabels': self.show_labels,
|
| 286 |
+
'showEdges': self.show_edges,
|
| 287 |
+
'edgeStyle': self.edge_style,
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@dataclass
|
| 292 |
+
class NN3DModel:
|
| 293 |
+
"""Complete NN3D model"""
|
| 294 |
+
metadata: NN3DMetadata
|
| 295 |
+
graph: NN3DGraph
|
| 296 |
+
version: str = "1.0.0"
|
| 297 |
+
visualization: Optional[VisualizationConfig] = None
|
| 298 |
+
|
| 299 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 300 |
+
result = {
|
| 301 |
+
'version': self.version,
|
| 302 |
+
'metadata': self.metadata.to_dict(),
|
| 303 |
+
'graph': self.graph.to_dict(),
|
| 304 |
+
}
|
| 305 |
+
if self.visualization:
|
| 306 |
+
result['visualization'] = self.visualization.to_dict()
|
| 307 |
+
return result
|
| 308 |
+
|
| 309 |
+
def to_json(self, indent: int = 2) -> str:
|
| 310 |
+
"""Serialize to JSON string"""
|
| 311 |
+
return json.dumps(self.to_dict(), indent=indent)
|
| 312 |
+
|
| 313 |
+
def save(self, filepath: str) -> None:
|
| 314 |
+
"""Save to .nn3d file"""
|
| 315 |
+
with open(filepath, 'w') as f:
|
| 316 |
+
f.write(self.to_json())
|
exporters/python/pyproject.toml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "nn3d-exporter"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Export neural network models to NN3D format for 3D visualization"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.8"
|
| 11 |
+
license = {text = "MIT"}
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "NN3D Team"}
|
| 14 |
+
]
|
| 15 |
+
keywords = ["neural-network", "visualization", "deep-learning", "pytorch", "onnx", "3d"]
|
| 16 |
+
classifiers = [
|
| 17 |
+
"Development Status :: 4 - Beta",
|
| 18 |
+
"Intended Audience :: Developers",
|
| 19 |
+
"Intended Audience :: Science/Research",
|
| 20 |
+
"License :: OSI Approved :: MIT License",
|
| 21 |
+
"Programming Language :: Python :: 3",
|
| 22 |
+
"Programming Language :: Python :: 3.8",
|
| 23 |
+
"Programming Language :: Python :: 3.9",
|
| 24 |
+
"Programming Language :: Python :: 3.10",
|
| 25 |
+
"Programming Language :: Python :: 3.11",
|
| 26 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 27 |
+
"Topic :: Scientific/Engineering :: Visualization",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
dependencies = []
|
| 31 |
+
|
| 32 |
+
[project.optional-dependencies]
|
| 33 |
+
pytorch = ["torch>=1.9.0"]
|
| 34 |
+
onnx = ["onnx>=1.10.0"]
|
| 35 |
+
all = ["torch>=1.9.0", "onnx>=1.10.0"]
|
| 36 |
+
dev = ["pytest>=7.0.0", "black", "isort", "mypy"]
|
| 37 |
+
|
| 38 |
+
[project.urls]
|
| 39 |
+
Homepage = "https://github.com/nn3d/visualizer"
|
| 40 |
+
Documentation = "https://nn3d.dev/docs"
|
| 41 |
+
Repository = "https://github.com/nn3d/visualizer"
|
| 42 |
+
|
| 43 |
+
[project.scripts]
|
| 44 |
+
nn3d-export = "nn3d_exporter.cli:main"
|
| 45 |
+
|
| 46 |
+
[tool.setuptools.packages.find]
|
| 47 |
+
where = ["."]
|
| 48 |
+
|
| 49 |
+
[tool.black]
|
| 50 |
+
line-length = 100
|
| 51 |
+
target-version = ["py38", "py39", "py310", "py311"]
|
| 52 |
+
|
| 53 |
+
[tool.isort]
|
| 54 |
+
profile = "black"
|
| 55 |
+
line_length = 100
|
| 56 |
+
|
| 57 |
+
[tool.mypy]
|
| 58 |
+
python_version = "3.8"
|
| 59 |
+
warn_return_any = true
|
| 60 |
+
warn_unused_configs = true
|
| 61 |
+
ignore_missing_imports = true
|
files_to_commit.txt
ADDED
|
Binary file (5.42 kB). View file
|
|
|
index.html
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>NN3D Visualizer - 3D Deep Learning Model Viewer</title>
|
| 8 |
+
<meta
|
| 9 |
+
name="description"
|
| 10 |
+
content="Interactive 3D visualization of neural network architectures"
|
| 11 |
+
/>
|
| 12 |
+
<style>
|
| 13 |
+
* {
|
| 14 |
+
margin: 0;
|
| 15 |
+
padding: 0;
|
| 16 |
+
box-sizing: border-box;
|
| 17 |
+
}
|
| 18 |
+
html,
|
| 19 |
+
body,
|
| 20 |
+
#root {
|
| 21 |
+
width: 100%;
|
| 22 |
+
height: 100%;
|
| 23 |
+
overflow: hidden;
|
| 24 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
|
| 25 |
+
Oxygen, Ubuntu, sans-serif;
|
| 26 |
+
}
|
| 27 |
+
</style>
|
| 28 |
+
</head>
|
| 29 |
+
<body>
|
| 30 |
+
<div id="root"></div>
|
| 31 |
+
<script type="module" src="/src/main.tsx"></script>
|
| 32 |
+
</body>
|
| 33 |
+
</html>
|
nginx.conf
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
server {
|
| 2 |
+
listen 80;
|
| 3 |
+
server_name localhost;
|
| 4 |
+
root /usr/share/nginx/html;
|
| 5 |
+
index index.html;
|
| 6 |
+
|
| 7 |
+
# Gzip compression
|
| 8 |
+
gzip on;
|
| 9 |
+
gzip_vary on;
|
| 10 |
+
gzip_min_length 1024;
|
| 11 |
+
gzip_proxied expired no-cache no-store private auth;
|
| 12 |
+
gzip_types text/plain text/css text/xml text/javascript application/x-javascript application/xml application/javascript application/json;
|
| 13 |
+
|
| 14 |
+
# Frontend routes - SPA support
|
| 15 |
+
location / {
|
| 16 |
+
try_files $uri $uri/ /index.html;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
# Proxy API requests to backend
|
| 20 |
+
location /api/ {
|
| 21 |
+
proxy_pass http://backend:8000/;
|
| 22 |
+
proxy_http_version 1.1;
|
| 23 |
+
proxy_set_header Upgrade $http_upgrade;
|
| 24 |
+
proxy_set_header Connection 'upgrade';
|
| 25 |
+
proxy_set_header Host $host;
|
| 26 |
+
proxy_set_header X-Real-IP $remote_addr;
|
| 27 |
+
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
| 28 |
+
proxy_set_header X-Forwarded-Proto $scheme;
|
| 29 |
+
proxy_cache_bypass $http_upgrade;
|
| 30 |
+
proxy_read_timeout 300s;
|
| 31 |
+
proxy_connect_timeout 75s;
|
| 32 |
+
|
| 33 |
+
# Handle large file uploads
|
| 34 |
+
client_max_body_size 500M;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
# Cache static assets
|
| 38 |
+
location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {
|
| 39 |
+
expires 1y;
|
| 40 |
+
add_header Cache-Control "public, immutable";
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# Security headers
|
| 44 |
+
add_header X-Frame-Options "SAMEORIGIN" always;
|
| 45 |
+
add_header X-Content-Type-Options "nosniff" always;
|
| 46 |
+
add_header X-XSS-Protection "1; mode=block" always;
|
| 47 |
+
}
|
package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
package.json
CHANGED
|
@@ -1,39 +1,53 @@
|
|
| 1 |
{
|
| 2 |
-
"name": "
|
| 3 |
-
"version": "
|
| 4 |
-
"
|
| 5 |
-
"
|
| 6 |
-
"@testing-library/dom": "^10.4.0",
|
| 7 |
-
"@testing-library/jest-dom": "^6.6.3",
|
| 8 |
-
"@testing-library/react": "^16.3.0",
|
| 9 |
-
"@testing-library/user-event": "^13.5.0",
|
| 10 |
-
"react": "^19.1.0",
|
| 11 |
-
"react-dom": "^19.1.0",
|
| 12 |
-
"react-scripts": "5.0.1",
|
| 13 |
-
"web-vitals": "^2.1.4"
|
| 14 |
-
},
|
| 15 |
"scripts": {
|
| 16 |
-
"
|
| 17 |
-
"build": "
|
| 18 |
-
"
|
| 19 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
},
|
| 21 |
-
"
|
| 22 |
-
"
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
},
|
| 27 |
-
"
|
| 28 |
-
"
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
}
|
| 39 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"name": "nn3d-visualizer",
|
| 3 |
+
"version": "1.0.0",
|
| 4 |
+
"description": "3D Deep Learning Model Visualizer - Interactive WebGL visualization of neural network architectures",
|
| 5 |
+
"type": "module",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"scripts": {
|
| 7 |
+
"dev": "vite",
|
| 8 |
+
"build": "tsc && vite build",
|
| 9 |
+
"preview": "vite preview",
|
| 10 |
+
"lint": "eslint src --ext .ts,.tsx",
|
| 11 |
+
"test": "vitest",
|
| 12 |
+
"validate-schema": "ajv validate -s src/schema/nn3d.schema.json -d"
|
| 13 |
+
},
|
| 14 |
+
"dependencies": {
|
| 15 |
+
"@react-three/drei": "^9.92.7",
|
| 16 |
+
"@react-three/fiber": "^8.15.12",
|
| 17 |
+
"ajv": "^8.12.0",
|
| 18 |
+
"d3-force-3d": "^3.0.5",
|
| 19 |
+
"h5wasm": "^0.7.5",
|
| 20 |
+
"leva": "^0.9.35",
|
| 21 |
+
"onnxruntime-web": "^1.16.3",
|
| 22 |
+
"pako": "^2.1.0",
|
| 23 |
+
"react": "^18.2.0",
|
| 24 |
+
"react-dom": "^18.2.0",
|
| 25 |
+
"three": "^0.160.0",
|
| 26 |
+
"zustand": "^4.4.7"
|
| 27 |
},
|
| 28 |
+
"devDependencies": {
|
| 29 |
+
"@types/node": "^25.0.3",
|
| 30 |
+
"@types/pako": "^2.0.4",
|
| 31 |
+
"@types/react": "^18.2.45",
|
| 32 |
+
"@types/react-dom": "^18.2.18",
|
| 33 |
+
"@types/three": "^0.160.0",
|
| 34 |
+
"@typescript-eslint/eslint-plugin": "^6.16.0",
|
| 35 |
+
"@typescript-eslint/parser": "^6.16.0",
|
| 36 |
+
"@vitejs/plugin-react": "^4.2.1",
|
| 37 |
+
"eslint": "^8.56.0",
|
| 38 |
+
"typescript": "^5.3.3",
|
| 39 |
+
"vite": "^5.0.10",
|
| 40 |
+
"vitest": "^1.1.0"
|
| 41 |
},
|
| 42 |
+
"keywords": [
|
| 43 |
+
"neural-network",
|
| 44 |
+
"deep-learning",
|
| 45 |
+
"visualization",
|
| 46 |
+
"3d",
|
| 47 |
+
"threejs",
|
| 48 |
+
"webgl",
|
| 49 |
+
"machine-learning"
|
| 50 |
+
],
|
| 51 |
+
"author": "",
|
| 52 |
+
"license": "MIT"
|
|
|
|
| 53 |
}
|
public/favicon.ico
DELETED
|
Binary file (3.87 kB)
|
|
|
public/favicon.svg
ADDED
|
|
public/index.html
DELETED
|
@@ -1,43 +0,0 @@
|
|
| 1 |
-
<!DOCTYPE html>
|
| 2 |
-
<html lang="en">
|
| 3 |
-
<head>
|
| 4 |
-
<meta charset="utf-8" />
|
| 5 |
-
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
|
| 6 |
-
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
| 7 |
-
<meta name="theme-color" content="#000000" />
|
| 8 |
-
<meta
|
| 9 |
-
name="description"
|
| 10 |
-
content="Web site created using create-react-app"
|
| 11 |
-
/>
|
| 12 |
-
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
|
| 13 |
-
<!--
|
| 14 |
-
manifest.json provides metadata used when your web app is installed on a
|
| 15 |
-
user's mobile device or desktop. See https://developers.google.com/web/fundamentals/web-app-manifest/
|
| 16 |
-
-->
|
| 17 |
-
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
|
| 18 |
-
<!--
|
| 19 |
-
Notice the use of %PUBLIC_URL% in the tags above.
|
| 20 |
-
It will be replaced with the URL of the `public` folder during the build.
|
| 21 |
-
Only files inside the `public` folder can be referenced from the HTML.
|
| 22 |
-
|
| 23 |
-
Unlike "/favicon.ico" or "favicon.ico", "%PUBLIC_URL%/favicon.ico" will
|
| 24 |
-
work correctly both with client-side routing and a non-root public URL.
|
| 25 |
-
Learn how to configure a non-root public URL by running `npm run build`.
|
| 26 |
-
-->
|
| 27 |
-
<title>React App</title>
|
| 28 |
-
</head>
|
| 29 |
-
<body>
|
| 30 |
-
<noscript>You need to enable JavaScript to run this app.</noscript>
|
| 31 |
-
<div id="root"></div>
|
| 32 |
-
<!--
|
| 33 |
-
This HTML file is a template.
|
| 34 |
-
If you open it directly in the browser, you will see an empty page.
|
| 35 |
-
|
| 36 |
-
You can add webfonts, meta tags, or analytics to this file.
|
| 37 |
-
The build step will place the bundled scripts into the <body> tag.
|
| 38 |
-
|
| 39 |
-
To begin the development, run `npm start` or `yarn start`.
|
| 40 |
-
To create a production bundle, use `npm run build` or `yarn build`.
|
| 41 |
-
-->
|
| 42 |
-
</body>
|
| 43 |
-
</html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public/logo192.png
DELETED
|
Binary file (5.35 kB)
|
|
|
public/logo512.png
DELETED
|
Binary file (9.66 kB)
|
|
|
public/manifest.json
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"short_name": "React App",
|
| 3 |
-
"name": "Create React App Sample",
|
| 4 |
-
"icons": [
|
| 5 |
-
{
|
| 6 |
-
"src": "favicon.ico",
|
| 7 |
-
"sizes": "64x64 32x32 24x24 16x16",
|
| 8 |
-
"type": "image/x-icon"
|
| 9 |
-
},
|
| 10 |
-
{
|
| 11 |
-
"src": "logo192.png",
|
| 12 |
-
"type": "image/png",
|
| 13 |
-
"sizes": "192x192"
|
| 14 |
-
},
|
| 15 |
-
{
|
| 16 |
-
"src": "logo512.png",
|
| 17 |
-
"type": "image/png",
|
| 18 |
-
"sizes": "512x512"
|
| 19 |
-
}
|
| 20 |
-
],
|
| 21 |
-
"start_url": ".",
|
| 22 |
-
"display": "standalone",
|
| 23 |
-
"theme_color": "#000000",
|
| 24 |
-
"background_color": "#ffffff"
|
| 25 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public/robots.txt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
# https://www.robotstxt.org/robotstxt.html
|
| 2 |
-
User-agent: *
|
| 3 |
-
Disallow:
|
|
|
|
|
|
|
|
|
|
|
|
samples/cnn_resnet.nn3d
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0.0",
|
| 3 |
+
"metadata": {
|
| 4 |
+
"name": "CNN Image Classifier",
|
| 5 |
+
"description": "Convolutional Neural Network for image classification",
|
| 6 |
+
"framework": "pytorch",
|
| 7 |
+
"created": "2024-12-17T10:00:00Z",
|
| 8 |
+
"tags": ["cnn", "classification", "imagenet"],
|
| 9 |
+
"inputShape": [1, 3, 224, 224],
|
| 10 |
+
"outputShape": [1, 1000],
|
| 11 |
+
"totalParams": 11689512,
|
| 12 |
+
"trainableParams": 11689512
|
| 13 |
+
},
|
| 14 |
+
"graph": {
|
| 15 |
+
"nodes": [
|
| 16 |
+
{
|
| 17 |
+
"id": "input",
|
| 18 |
+
"type": "input",
|
| 19 |
+
"name": "Input Image",
|
| 20 |
+
"outputShape": [1, 3, 224, 224],
|
| 21 |
+
"depth": 0
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "conv1",
|
| 25 |
+
"type": "conv2d",
|
| 26 |
+
"name": "conv1",
|
| 27 |
+
"params": {
|
| 28 |
+
"inChannels": 3,
|
| 29 |
+
"outChannels": 64,
|
| 30 |
+
"kernelSize": [7, 7],
|
| 31 |
+
"stride": [2, 2],
|
| 32 |
+
"padding": [3, 3],
|
| 33 |
+
"bias": false
|
| 34 |
+
},
|
| 35 |
+
"inputShape": [1, 3, 224, 224],
|
| 36 |
+
"outputShape": [1, 64, 112, 112],
|
| 37 |
+
"depth": 1,
|
| 38 |
+
"group": "stem"
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"id": "bn1",
|
| 42 |
+
"type": "batchNorm2d",
|
| 43 |
+
"name": "bn1",
|
| 44 |
+
"params": {
|
| 45 |
+
"eps": 1e-5,
|
| 46 |
+
"momentum": 0.1
|
| 47 |
+
},
|
| 48 |
+
"inputShape": [1, 64, 112, 112],
|
| 49 |
+
"outputShape": [1, 64, 112, 112],
|
| 50 |
+
"depth": 2,
|
| 51 |
+
"group": "stem"
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"id": "relu1",
|
| 55 |
+
"type": "relu",
|
| 56 |
+
"name": "ReLU",
|
| 57 |
+
"inputShape": [1, 64, 112, 112],
|
| 58 |
+
"outputShape": [1, 64, 112, 112],
|
| 59 |
+
"depth": 3,
|
| 60 |
+
"group": "stem"
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"id": "maxpool",
|
| 64 |
+
"type": "maxPool2d",
|
| 65 |
+
"name": "MaxPool",
|
| 66 |
+
"params": {
|
| 67 |
+
"kernelSize": [3, 3],
|
| 68 |
+
"stride": [2, 2],
|
| 69 |
+
"padding": [1, 1]
|
| 70 |
+
},
|
| 71 |
+
"inputShape": [1, 64, 112, 112],
|
| 72 |
+
"outputShape": [1, 64, 56, 56],
|
| 73 |
+
"depth": 4,
|
| 74 |
+
"group": "stem"
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"id": "layer1_conv1",
|
| 78 |
+
"type": "conv2d",
|
| 79 |
+
"name": "layer1.0.conv1",
|
| 80 |
+
"params": {
|
| 81 |
+
"inChannels": 64,
|
| 82 |
+
"outChannels": 64,
|
| 83 |
+
"kernelSize": [3, 3],
|
| 84 |
+
"stride": [1, 1],
|
| 85 |
+
"padding": [1, 1],
|
| 86 |
+
"bias": false
|
| 87 |
+
},
|
| 88 |
+
"inputShape": [1, 64, 56, 56],
|
| 89 |
+
"outputShape": [1, 64, 56, 56],
|
| 90 |
+
"depth": 5,
|
| 91 |
+
"group": "layer1"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"id": "layer1_bn1",
|
| 95 |
+
"type": "batchNorm2d",
|
| 96 |
+
"name": "layer1.0.bn1",
|
| 97 |
+
"inputShape": [1, 64, 56, 56],
|
| 98 |
+
"outputShape": [1, 64, 56, 56],
|
| 99 |
+
"depth": 6,
|
| 100 |
+
"group": "layer1"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"id": "layer1_relu1",
|
| 104 |
+
"type": "relu",
|
| 105 |
+
"name": "ReLU",
|
| 106 |
+
"inputShape": [1, 64, 56, 56],
|
| 107 |
+
"outputShape": [1, 64, 56, 56],
|
| 108 |
+
"depth": 7,
|
| 109 |
+
"group": "layer1"
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"id": "layer1_conv2",
|
| 113 |
+
"type": "conv2d",
|
| 114 |
+
"name": "layer1.0.conv2",
|
| 115 |
+
"params": {
|
| 116 |
+
"inChannels": 64,
|
| 117 |
+
"outChannels": 64,
|
| 118 |
+
"kernelSize": [3, 3],
|
| 119 |
+
"stride": [1, 1],
|
| 120 |
+
"padding": [1, 1],
|
| 121 |
+
"bias": false
|
| 122 |
+
},
|
| 123 |
+
"inputShape": [1, 64, 56, 56],
|
| 124 |
+
"outputShape": [1, 64, 56, 56],
|
| 125 |
+
"depth": 8,
|
| 126 |
+
"group": "layer1"
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"id": "layer1_bn2",
|
| 130 |
+
"type": "batchNorm2d",
|
| 131 |
+
"name": "layer1.0.bn2",
|
| 132 |
+
"inputShape": [1, 64, 56, 56],
|
| 133 |
+
"outputShape": [1, 64, 56, 56],
|
| 134 |
+
"depth": 9,
|
| 135 |
+
"group": "layer1"
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"id": "layer1_add",
|
| 139 |
+
"type": "add",
|
| 140 |
+
"name": "Residual Add",
|
| 141 |
+
"inputShape": [1, 64, 56, 56],
|
| 142 |
+
"outputShape": [1, 64, 56, 56],
|
| 143 |
+
"depth": 10,
|
| 144 |
+
"group": "layer1"
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"id": "layer2_conv1",
|
| 148 |
+
"type": "conv2d",
|
| 149 |
+
"name": "layer2.0.conv1",
|
| 150 |
+
"params": {
|
| 151 |
+
"inChannels": 64,
|
| 152 |
+
"outChannels": 128,
|
| 153 |
+
"kernelSize": [3, 3],
|
| 154 |
+
"stride": [2, 2],
|
| 155 |
+
"padding": [1, 1],
|
| 156 |
+
"bias": false
|
| 157 |
+
},
|
| 158 |
+
"inputShape": [1, 64, 56, 56],
|
| 159 |
+
"outputShape": [1, 128, 28, 28],
|
| 160 |
+
"depth": 11,
|
| 161 |
+
"group": "layer2"
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"id": "layer2_bn1",
|
| 165 |
+
"type": "batchNorm2d",
|
| 166 |
+
"name": "layer2.0.bn1",
|
| 167 |
+
"inputShape": [1, 128, 28, 28],
|
| 168 |
+
"outputShape": [1, 128, 28, 28],
|
| 169 |
+
"depth": 12,
|
| 170 |
+
"group": "layer2"
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"id": "layer2_relu1",
|
| 174 |
+
"type": "relu",
|
| 175 |
+
"name": "ReLU",
|
| 176 |
+
"inputShape": [1, 128, 28, 28],
|
| 177 |
+
"outputShape": [1, 128, 28, 28],
|
| 178 |
+
"depth": 13,
|
| 179 |
+
"group": "layer2"
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"id": "layer2_conv2",
|
| 183 |
+
"type": "conv2d",
|
| 184 |
+
"name": "layer2.0.conv2",
|
| 185 |
+
"params": {
|
| 186 |
+
"inChannels": 128,
|
| 187 |
+
"outChannels": 128,
|
| 188 |
+
"kernelSize": [3, 3],
|
| 189 |
+
"stride": [1, 1],
|
| 190 |
+
"padding": [1, 1],
|
| 191 |
+
"bias": false
|
| 192 |
+
},
|
| 193 |
+
"inputShape": [1, 128, 28, 28],
|
| 194 |
+
"outputShape": [1, 128, 28, 28],
|
| 195 |
+
"depth": 14,
|
| 196 |
+
"group": "layer2"
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"id": "layer2_bn2",
|
| 200 |
+
"type": "batchNorm2d",
|
| 201 |
+
"name": "layer2.0.bn2",
|
| 202 |
+
"inputShape": [1, 128, 28, 28],
|
| 203 |
+
"outputShape": [1, 128, 28, 28],
|
| 204 |
+
"depth": 15,
|
| 205 |
+
"group": "layer2"
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"id": "layer2_downsample",
|
| 209 |
+
"type": "conv2d",
|
| 210 |
+
"name": "layer2.0.downsample",
|
| 211 |
+
"params": {
|
| 212 |
+
"inChannels": 64,
|
| 213 |
+
"outChannels": 128,
|
| 214 |
+
"kernelSize": [1, 1],
|
| 215 |
+
"stride": [2, 2],
|
| 216 |
+
"bias": false
|
| 217 |
+
},
|
| 218 |
+
"inputShape": [1, 64, 56, 56],
|
| 219 |
+
"outputShape": [1, 128, 28, 28],
|
| 220 |
+
"depth": 16,
|
| 221 |
+
"group": "layer2"
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"id": "layer2_add",
|
| 225 |
+
"type": "add",
|
| 226 |
+
"name": "Residual Add",
|
| 227 |
+
"inputShape": [1, 128, 28, 28],
|
| 228 |
+
"outputShape": [1, 128, 28, 28],
|
| 229 |
+
"depth": 17,
|
| 230 |
+
"group": "layer2"
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"id": "avgpool",
|
| 234 |
+
"type": "globalAvgPool",
|
| 235 |
+
"name": "Global Average Pool",
|
| 236 |
+
"inputShape": [1, 128, 28, 28],
|
| 237 |
+
"outputShape": [1, 128, 1, 1],
|
| 238 |
+
"depth": 18
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"id": "flatten",
|
| 242 |
+
"type": "flatten",
|
| 243 |
+
"name": "Flatten",
|
| 244 |
+
"inputShape": [1, 128, 1, 1],
|
| 245 |
+
"outputShape": [1, 128],
|
| 246 |
+
"depth": 19
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"id": "fc",
|
| 250 |
+
"type": "linear",
|
| 251 |
+
"name": "fc",
|
| 252 |
+
"params": {
|
| 253 |
+
"inFeatures": 128,
|
| 254 |
+
"outFeatures": 1000,
|
| 255 |
+
"bias": true
|
| 256 |
+
},
|
| 257 |
+
"inputShape": [1, 128],
|
| 258 |
+
"outputShape": [1, 1000],
|
| 259 |
+
"depth": 20
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"id": "output",
|
| 263 |
+
"type": "output",
|
| 264 |
+
"name": "Output",
|
| 265 |
+
"inputShape": [1, 1000],
|
| 266 |
+
"depth": 21
|
| 267 |
+
}
|
| 268 |
+
],
|
| 269 |
+
"edges": [
|
| 270 |
+
{ "source": "input", "target": "conv1" },
|
| 271 |
+
{ "source": "conv1", "target": "bn1" },
|
| 272 |
+
{ "source": "bn1", "target": "relu1" },
|
| 273 |
+
{ "source": "relu1", "target": "maxpool" },
|
| 274 |
+
{ "source": "maxpool", "target": "layer1_conv1" },
|
| 275 |
+
{ "source": "layer1_conv1", "target": "layer1_bn1" },
|
| 276 |
+
{ "source": "layer1_bn1", "target": "layer1_relu1" },
|
| 277 |
+
{ "source": "layer1_relu1", "target": "layer1_conv2" },
|
| 278 |
+
{ "source": "layer1_conv2", "target": "layer1_bn2" },
|
| 279 |
+
{ "source": "layer1_bn2", "target": "layer1_add" },
|
| 280 |
+
{ "source": "maxpool", "target": "layer1_add" },
|
| 281 |
+
{ "source": "layer1_add", "target": "layer2_conv1" },
|
| 282 |
+
{ "source": "layer2_conv1", "target": "layer2_bn1" },
|
| 283 |
+
{ "source": "layer2_bn1", "target": "layer2_relu1" },
|
| 284 |
+
{ "source": "layer2_relu1", "target": "layer2_conv2" },
|
| 285 |
+
{ "source": "layer2_conv2", "target": "layer2_bn2" },
|
| 286 |
+
{ "source": "layer1_add", "target": "layer2_downsample" },
|
| 287 |
+
{ "source": "layer2_bn2", "target": "layer2_add" },
|
| 288 |
+
{ "source": "layer2_downsample", "target": "layer2_add" },
|
| 289 |
+
{ "source": "layer2_add", "target": "avgpool" },
|
| 290 |
+
{ "source": "avgpool", "target": "flatten" },
|
| 291 |
+
{ "source": "flatten", "target": "fc" },
|
| 292 |
+
{ "source": "fc", "target": "output" }
|
| 293 |
+
],
|
| 294 |
+
"subgraphs": [
|
| 295 |
+
{
|
| 296 |
+
"id": "stem_group",
|
| 297 |
+
"name": "Stem",
|
| 298 |
+
"type": "sequential",
|
| 299 |
+
"nodes": ["conv1", "bn1", "relu1", "maxpool"],
|
| 300 |
+
"color": "#4CAF50"
|
| 301 |
+
},
|
| 302 |
+
{
|
| 303 |
+
"id": "layer1_group",
|
| 304 |
+
"name": "Layer 1 (ResBlock)",
|
| 305 |
+
"type": "residual",
|
| 306 |
+
"nodes": ["layer1_conv1", "layer1_bn1", "layer1_relu1", "layer1_conv2", "layer1_bn2", "layer1_add"],
|
| 307 |
+
"color": "#2196F3"
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
"id": "layer2_group",
|
| 311 |
+
"name": "Layer 2 (ResBlock)",
|
| 312 |
+
"type": "residual",
|
| 313 |
+
"nodes": ["layer2_conv1", "layer2_bn1", "layer2_relu1", "layer2_conv2", "layer2_bn2", "layer2_downsample", "layer2_add"],
|
| 314 |
+
"color": "#9C27B0"
|
| 315 |
+
}
|
| 316 |
+
]
|
| 317 |
+
},
|
| 318 |
+
"visualization": {
|
| 319 |
+
"layout": "layered",
|
| 320 |
+
"theme": "dark",
|
| 321 |
+
"layerSpacing": 2.5,
|
| 322 |
+
"nodeScale": 1.0,
|
| 323 |
+
"showLabels": true,
|
| 324 |
+
"showEdges": true,
|
| 325 |
+
"edgeStyle": "tube"
|
| 326 |
+
}
|
| 327 |
+
}
|
samples/simple_mlp.nn3d
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0.0",
|
| 3 |
+
"metadata": {
|
| 4 |
+
"name": "Simple MLP",
|
| 5 |
+
"description": "A simple Multi-Layer Perceptron for MNIST classification",
|
| 6 |
+
"framework": "pytorch",
|
| 7 |
+
"created": "2024-12-17T10:00:00Z",
|
| 8 |
+
"tags": ["mlp", "classification", "mnist"],
|
| 9 |
+
"inputShape": [1, 784],
|
| 10 |
+
"outputShape": [1, 10],
|
| 11 |
+
"totalParams": 669706,
|
| 12 |
+
"trainableParams": 669706
|
| 13 |
+
},
|
| 14 |
+
"graph": {
|
| 15 |
+
"nodes": [
|
| 16 |
+
{
|
| 17 |
+
"id": "input",
|
| 18 |
+
"type": "input",
|
| 19 |
+
"name": "Input",
|
| 20 |
+
"outputShape": [1, 784],
|
| 21 |
+
"depth": 0
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "fc1",
|
| 25 |
+
"type": "linear",
|
| 26 |
+
"name": "fc1",
|
| 27 |
+
"params": {
|
| 28 |
+
"inFeatures": 784,
|
| 29 |
+
"outFeatures": 512,
|
| 30 |
+
"bias": true
|
| 31 |
+
},
|
| 32 |
+
"inputShape": [1, 784],
|
| 33 |
+
"outputShape": [1, 512],
|
| 34 |
+
"depth": 1
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"id": "relu1",
|
| 38 |
+
"type": "relu",
|
| 39 |
+
"name": "ReLU",
|
| 40 |
+
"inputShape": [1, 512],
|
| 41 |
+
"outputShape": [1, 512],
|
| 42 |
+
"depth": 2
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"id": "dropout1",
|
| 46 |
+
"type": "dropout",
|
| 47 |
+
"name": "Dropout",
|
| 48 |
+
"params": {
|
| 49 |
+
"dropoutRate": 0.2
|
| 50 |
+
},
|
| 51 |
+
"inputShape": [1, 512],
|
| 52 |
+
"outputShape": [1, 512],
|
| 53 |
+
"depth": 3
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"id": "fc2",
|
| 57 |
+
"type": "linear",
|
| 58 |
+
"name": "fc2",
|
| 59 |
+
"params": {
|
| 60 |
+
"inFeatures": 512,
|
| 61 |
+
"outFeatures": 256,
|
| 62 |
+
"bias": true
|
| 63 |
+
},
|
| 64 |
+
"inputShape": [1, 512],
|
| 65 |
+
"outputShape": [1, 256],
|
| 66 |
+
"depth": 4
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"id": "relu2",
|
| 70 |
+
"type": "relu",
|
| 71 |
+
"name": "ReLU",
|
| 72 |
+
"inputShape": [1, 256],
|
| 73 |
+
"outputShape": [1, 256],
|
| 74 |
+
"depth": 5
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"id": "dropout2",
|
| 78 |
+
"type": "dropout",
|
| 79 |
+
"name": "Dropout",
|
| 80 |
+
"params": {
|
| 81 |
+
"dropoutRate": 0.2
|
| 82 |
+
},
|
| 83 |
+
"inputShape": [1, 256],
|
| 84 |
+
"outputShape": [1, 256],
|
| 85 |
+
"depth": 6
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"id": "fc3",
|
| 89 |
+
"type": "linear",
|
| 90 |
+
"name": "fc3",
|
| 91 |
+
"params": {
|
| 92 |
+
"inFeatures": 256,
|
| 93 |
+
"outFeatures": 10,
|
| 94 |
+
"bias": true
|
| 95 |
+
},
|
| 96 |
+
"inputShape": [1, 256],
|
| 97 |
+
"outputShape": [1, 10],
|
| 98 |
+
"depth": 7
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"id": "softmax",
|
| 102 |
+
"type": "softmax",
|
| 103 |
+
"name": "Softmax",
|
| 104 |
+
"inputShape": [1, 10],
|
| 105 |
+
"outputShape": [1, 10],
|
| 106 |
+
"depth": 8
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"id": "output",
|
| 110 |
+
"type": "output",
|
| 111 |
+
"name": "Output",
|
| 112 |
+
"inputShape": [1, 10],
|
| 113 |
+
"depth": 9
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
+
"edges": [
|
| 117 |
+
{ "source": "input", "target": "fc1", "tensorShape": [1, 784] },
|
| 118 |
+
{ "source": "fc1", "target": "relu1", "tensorShape": [1, 512] },
|
| 119 |
+
{ "source": "relu1", "target": "dropout1", "tensorShape": [1, 512] },
|
| 120 |
+
{ "source": "dropout1", "target": "fc2", "tensorShape": [1, 512] },
|
| 121 |
+
{ "source": "fc2", "target": "relu2", "tensorShape": [1, 256] },
|
| 122 |
+
{ "source": "relu2", "target": "dropout2", "tensorShape": [1, 256] },
|
| 123 |
+
{ "source": "dropout2", "target": "fc3", "tensorShape": [1, 256] },
|
| 124 |
+
{ "source": "fc3", "target": "softmax", "tensorShape": [1, 10] },
|
| 125 |
+
{ "source": "softmax", "target": "output", "tensorShape": [1, 10] }
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
"visualization": {
|
| 129 |
+
"layout": "layered",
|
| 130 |
+
"theme": "dark",
|
| 131 |
+
"layerSpacing": 3.0,
|
| 132 |
+
"nodeScale": 1.0,
|
| 133 |
+
"showLabels": true,
|
| 134 |
+
"showEdges": true,
|
| 135 |
+
"edgeStyle": "tube"
|
| 136 |
+
}
|
| 137 |
+
}
|
samples/transformer_encoder.nn3d
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0.0",
|
| 3 |
+
"metadata": {
|
| 4 |
+
"name": "Transformer Encoder",
|
| 5 |
+
"description": "A Transformer encoder block for sequence modeling",
|
| 6 |
+
"framework": "pytorch",
|
| 7 |
+
"created": "2024-12-17T10:00:00Z",
|
| 8 |
+
"tags": ["transformer", "attention", "nlp", "encoder"],
|
| 9 |
+
"inputShape": [1, 512, 768],
|
| 10 |
+
"outputShape": [1, 512, 768],
|
| 11 |
+
"totalParams": 7087872,
|
| 12 |
+
"trainableParams": 7087872
|
| 13 |
+
},
|
| 14 |
+
"graph": {
|
| 15 |
+
"nodes": [
|
| 16 |
+
{
|
| 17 |
+
"id": "input",
|
| 18 |
+
"type": "input",
|
| 19 |
+
"name": "Input Embeddings",
|
| 20 |
+
"outputShape": [1, 512, 768],
|
| 21 |
+
"depth": 0
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "pos_embed",
|
| 25 |
+
"type": "embedding",
|
| 26 |
+
"name": "Positional Embedding",
|
| 27 |
+
"params": {
|
| 28 |
+
"numEmbeddings": 512,
|
| 29 |
+
"embeddingDim": 768
|
| 30 |
+
},
|
| 31 |
+
"inputShape": [1, 512],
|
| 32 |
+
"outputShape": [1, 512, 768],
|
| 33 |
+
"depth": 1
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"id": "embed_add",
|
| 37 |
+
"type": "add",
|
| 38 |
+
"name": "Add Position",
|
| 39 |
+
"inputShape": [1, 512, 768],
|
| 40 |
+
"outputShape": [1, 512, 768],
|
| 41 |
+
"depth": 2
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"id": "layer1_ln1",
|
| 45 |
+
"type": "layerNorm",
|
| 46 |
+
"name": "LayerNorm 1",
|
| 47 |
+
"params": {
|
| 48 |
+
"eps": 1e-6
|
| 49 |
+
},
|
| 50 |
+
"inputShape": [1, 512, 768],
|
| 51 |
+
"outputShape": [1, 512, 768],
|
| 52 |
+
"depth": 3,
|
| 53 |
+
"group": "encoder_layer_1"
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"id": "layer1_mha",
|
| 57 |
+
"type": "multiHeadAttention",
|
| 58 |
+
"name": "Multi-Head Attention",
|
| 59 |
+
"params": {
|
| 60 |
+
"numHeads": 12,
|
| 61 |
+
"hiddenSize": 768,
|
| 62 |
+
"dropoutRate": 0.1
|
| 63 |
+
},
|
| 64 |
+
"inputShape": [1, 512, 768],
|
| 65 |
+
"outputShape": [1, 512, 768],
|
| 66 |
+
"depth": 4,
|
| 67 |
+
"group": "encoder_layer_1"
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"id": "layer1_dropout1",
|
| 71 |
+
"type": "dropout",
|
| 72 |
+
"name": "Dropout",
|
| 73 |
+
"params": {
|
| 74 |
+
"dropoutRate": 0.1
|
| 75 |
+
},
|
| 76 |
+
"inputShape": [1, 512, 768],
|
| 77 |
+
"outputShape": [1, 512, 768],
|
| 78 |
+
"depth": 5,
|
| 79 |
+
"group": "encoder_layer_1"
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"id": "layer1_add1",
|
| 83 |
+
"type": "add",
|
| 84 |
+
"name": "Residual Add 1",
|
| 85 |
+
"inputShape": [1, 512, 768],
|
| 86 |
+
"outputShape": [1, 512, 768],
|
| 87 |
+
"depth": 6,
|
| 88 |
+
"group": "encoder_layer_1"
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"id": "layer1_ln2",
|
| 92 |
+
"type": "layerNorm",
|
| 93 |
+
"name": "LayerNorm 2",
|
| 94 |
+
"params": {
|
| 95 |
+
"eps": 1e-6
|
| 96 |
+
},
|
| 97 |
+
"inputShape": [1, 512, 768],
|
| 98 |
+
"outputShape": [1, 512, 768],
|
| 99 |
+
"depth": 7,
|
| 100 |
+
"group": "encoder_layer_1"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"id": "layer1_ff1",
|
| 104 |
+
"type": "linear",
|
| 105 |
+
"name": "Feed Forward 1",
|
| 106 |
+
"params": {
|
| 107 |
+
"inFeatures": 768,
|
| 108 |
+
"outFeatures": 3072,
|
| 109 |
+
"bias": true
|
| 110 |
+
},
|
| 111 |
+
"inputShape": [1, 512, 768],
|
| 112 |
+
"outputShape": [1, 512, 3072],
|
| 113 |
+
"depth": 8,
|
| 114 |
+
"group": "encoder_layer_1"
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"id": "layer1_gelu",
|
| 118 |
+
"type": "gelu",
|
| 119 |
+
"name": "GELU",
|
| 120 |
+
"inputShape": [1, 512, 3072],
|
| 121 |
+
"outputShape": [1, 512, 3072],
|
| 122 |
+
"depth": 9,
|
| 123 |
+
"group": "encoder_layer_1"
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"id": "layer1_ff2",
|
| 127 |
+
"type": "linear",
|
| 128 |
+
"name": "Feed Forward 2",
|
| 129 |
+
"params": {
|
| 130 |
+
"inFeatures": 3072,
|
| 131 |
+
"outFeatures": 768,
|
| 132 |
+
"bias": true
|
| 133 |
+
},
|
| 134 |
+
"inputShape": [1, 512, 3072],
|
| 135 |
+
"outputShape": [1, 512, 768],
|
| 136 |
+
"depth": 10,
|
| 137 |
+
"group": "encoder_layer_1"
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"id": "layer1_dropout2",
|
| 141 |
+
"type": "dropout",
|
| 142 |
+
"name": "Dropout",
|
| 143 |
+
"params": {
|
| 144 |
+
"dropoutRate": 0.1
|
| 145 |
+
},
|
| 146 |
+
"inputShape": [1, 512, 768],
|
| 147 |
+
"outputShape": [1, 512, 768],
|
| 148 |
+
"depth": 11,
|
| 149 |
+
"group": "encoder_layer_1"
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"id": "layer1_add2",
|
| 153 |
+
"type": "add",
|
| 154 |
+
"name": "Residual Add 2",
|
| 155 |
+
"inputShape": [1, 512, 768],
|
| 156 |
+
"outputShape": [1, 512, 768],
|
| 157 |
+
"depth": 12,
|
| 158 |
+
"group": "encoder_layer_1"
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"id": "final_ln",
|
| 162 |
+
"type": "layerNorm",
|
| 163 |
+
"name": "Final LayerNorm",
|
| 164 |
+
"params": {
|
| 165 |
+
"eps": 1e-6
|
| 166 |
+
},
|
| 167 |
+
"inputShape": [1, 512, 768],
|
| 168 |
+
"outputShape": [1, 512, 768],
|
| 169 |
+
"depth": 13
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"id": "output",
|
| 173 |
+
"type": "output",
|
| 174 |
+
"name": "Output",
|
| 175 |
+
"inputShape": [1, 512, 768],
|
| 176 |
+
"depth": 14
|
| 177 |
+
}
|
| 178 |
+
],
|
| 179 |
+
"edges": [
|
| 180 |
+
{ "source": "input", "target": "embed_add" },
|
| 181 |
+
{ "source": "pos_embed", "target": "embed_add" },
|
| 182 |
+
{ "source": "embed_add", "target": "layer1_ln1" },
|
| 183 |
+
{ "source": "layer1_ln1", "target": "layer1_mha" },
|
| 184 |
+
{ "source": "layer1_mha", "target": "layer1_dropout1" },
|
| 185 |
+
{ "source": "layer1_dropout1", "target": "layer1_add1" },
|
| 186 |
+
{ "source": "embed_add", "target": "layer1_add1" },
|
| 187 |
+
{ "source": "layer1_add1", "target": "layer1_ln2" },
|
| 188 |
+
{ "source": "layer1_ln2", "target": "layer1_ff1" },
|
| 189 |
+
{ "source": "layer1_ff1", "target": "layer1_gelu" },
|
| 190 |
+
{ "source": "layer1_gelu", "target": "layer1_ff2" },
|
| 191 |
+
{ "source": "layer1_ff2", "target": "layer1_dropout2" },
|
| 192 |
+
{ "source": "layer1_dropout2", "target": "layer1_add2" },
|
| 193 |
+
{ "source": "layer1_add1", "target": "layer1_add2" },
|
| 194 |
+
{ "source": "layer1_add2", "target": "final_ln" },
|
| 195 |
+
{ "source": "final_ln", "target": "output" }
|
| 196 |
+
],
|
| 197 |
+
"subgraphs": [
|
| 198 |
+
{
|
| 199 |
+
"id": "encoder_layer_1_group",
|
| 200 |
+
"name": "Encoder Layer 1",
|
| 201 |
+
"type": "attention",
|
| 202 |
+
"nodes": [
|
| 203 |
+
"layer1_ln1", "layer1_mha", "layer1_dropout1", "layer1_add1",
|
| 204 |
+
"layer1_ln2", "layer1_ff1", "layer1_gelu", "layer1_ff2",
|
| 205 |
+
"layer1_dropout2", "layer1_add2"
|
| 206 |
+
],
|
| 207 |
+
"color": "#E91E63"
|
| 208 |
+
}
|
| 209 |
+
]
|
| 210 |
+
},
|
| 211 |
+
"visualization": {
|
| 212 |
+
"layout": "layered",
|
| 213 |
+
"theme": "dark",
|
| 214 |
+
"layerSpacing": 2.0,
|
| 215 |
+
"nodeScale": 1.0,
|
| 216 |
+
"showLabels": true,
|
| 217 |
+
"showEdges": true,
|
| 218 |
+
"edgeStyle": "bezier"
|
| 219 |
+
}
|
| 220 |
+
}
|
src/App.css
CHANGED
|
@@ -1,38 +1,195 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
}
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
pointer-events: none;
|
|
|
|
| 8 |
}
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
}
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
flex-direction: column;
|
| 21 |
-
align-items: center;
|
| 22 |
-
justify-content: center;
|
| 23 |
-
font-size: calc(10px + 2vmin);
|
| 24 |
-
color: white;
|
| 25 |
}
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
}
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
from {
|
| 33 |
-
|
|
|
|
| 34 |
}
|
| 35 |
to {
|
| 36 |
-
|
|
|
|
| 37 |
}
|
| 38 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Oldskool Wireframe Design System
|
| 3 |
+
* Inspired by ctxdc.com and awwwards-winning dark interfaces
|
| 4 |
+
*/
|
| 5 |
+
|
| 6 |
+
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500;600;700&display=swap');
|
| 7 |
+
|
| 8 |
+
:root {
|
| 9 |
+
/* Core Colors */
|
| 10 |
+
--bg-primary: #0a0a0a;
|
| 11 |
+
--bg-secondary: #0f0f0f;
|
| 12 |
+
--bg-tertiary: #141414;
|
| 13 |
+
--bg-card: rgba(15, 15, 15, 0.8);
|
| 14 |
+
|
| 15 |
+
/* Accent Colors */
|
| 16 |
+
--accent-primary: #b4ff39;
|
| 17 |
+
--accent-secondary: #7fff00;
|
| 18 |
+
--accent-dim: rgba(180, 255, 57, 0.15);
|
| 19 |
+
--accent-glow: rgba(180, 255, 57, 0.3);
|
| 20 |
+
|
| 21 |
+
/* Text Colors */
|
| 22 |
+
--text-primary: #ffffff;
|
| 23 |
+
--text-secondary: #a0a0a0;
|
| 24 |
+
--text-tertiary: #606060;
|
| 25 |
+
--text-muted: #404040;
|
| 26 |
+
|
| 27 |
+
/* Border Colors */
|
| 28 |
+
--border-primary: rgba(180, 255, 57, 0.3);
|
| 29 |
+
--border-secondary: rgba(255, 255, 255, 0.1);
|
| 30 |
+
--border-dim: rgba(255, 255, 255, 0.05);
|
| 31 |
+
|
| 32 |
+
/* Status Colors */
|
| 33 |
+
--success: #b4ff39;
|
| 34 |
+
--error: #ff4444;
|
| 35 |
+
--warning: #ffaa00;
|
| 36 |
+
|
| 37 |
+
/* Typography */
|
| 38 |
+
--font-mono: 'JetBrains Mono', 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
|
| 39 |
+
|
| 40 |
+
/* Spacing */
|
| 41 |
+
--space-xs: 4px;
|
| 42 |
+
--space-sm: 8px;
|
| 43 |
+
--space-md: 16px;
|
| 44 |
+
--space-lg: 24px;
|
| 45 |
+
--space-xl: 32px;
|
| 46 |
+
|
| 47 |
+
/* Border Radius */
|
| 48 |
+
--radius-sm: 2px;
|
| 49 |
+
--radius-md: 4px;
|
| 50 |
+
--radius-lg: 8px;
|
| 51 |
}
|
| 52 |
|
| 53 |
+
* {
|
| 54 |
+
margin: 0;
|
| 55 |
+
padding: 0;
|
| 56 |
+
box-sizing: border-box;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
html, body, #root {
|
| 60 |
+
width: 100%;
|
| 61 |
+
height: 100%;
|
| 62 |
+
overflow: hidden;
|
| 63 |
+
font-family: var(--font-mono);
|
| 64 |
+
font-size: 13px;
|
| 65 |
+
-webkit-font-smoothing: antialiased;
|
| 66 |
+
-moz-osx-font-smoothing: grayscale;
|
| 67 |
+
background: var(--bg-primary);
|
| 68 |
+
color: var(--text-primary);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.app {
|
| 72 |
+
width: 100%;
|
| 73 |
+
height: 100%;
|
| 74 |
+
position: relative;
|
| 75 |
+
background: var(--bg-primary);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
/* Grid Background Pattern */
|
| 79 |
+
.app::before {
|
| 80 |
+
content: '';
|
| 81 |
+
position: absolute;
|
| 82 |
+
top: 0;
|
| 83 |
+
left: 0;
|
| 84 |
+
right: 0;
|
| 85 |
+
bottom: 0;
|
| 86 |
+
background-image:
|
| 87 |
+
linear-gradient(rgba(180, 255, 57, 0.03) 1px, transparent 1px),
|
| 88 |
+
linear-gradient(90deg, rgba(180, 255, 57, 0.03) 1px, transparent 1px);
|
| 89 |
+
background-size: 50px 50px;
|
| 90 |
pointer-events: none;
|
| 91 |
+
z-index: 0;
|
| 92 |
}
|
| 93 |
|
| 94 |
+
/* Scrollbar styling */
|
| 95 |
+
::-webkit-scrollbar {
|
| 96 |
+
width: 6px;
|
| 97 |
+
height: 6px;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
::-webkit-scrollbar-track {
|
| 101 |
+
background: var(--bg-secondary);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
::-webkit-scrollbar-thumb {
|
| 105 |
+
background: var(--border-primary);
|
| 106 |
+
border-radius: var(--radius-sm);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
::-webkit-scrollbar-thumb:hover {
|
| 110 |
+
background: var(--accent-primary);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
/* Selection styling */
|
| 114 |
+
::selection {
|
| 115 |
+
background: var(--accent-dim);
|
| 116 |
+
color: var(--accent-primary);
|
| 117 |
}
|
| 118 |
|
| 119 |
+
/* Focus styling */
|
| 120 |
+
:focus-visible {
|
| 121 |
+
outline: 1px solid var(--accent-primary);
|
| 122 |
+
outline-offset: 2px;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
+
/* Custom checkbox */
|
| 126 |
+
input[type="checkbox"] {
|
| 127 |
+
appearance: none;
|
| 128 |
+
width: 14px;
|
| 129 |
+
height: 14px;
|
| 130 |
+
border: 1px solid var(--border-primary);
|
| 131 |
+
border-radius: var(--radius-sm);
|
| 132 |
+
background: transparent;
|
| 133 |
+
cursor: pointer;
|
| 134 |
+
position: relative;
|
| 135 |
+
transition: all 0.15s ease;
|
| 136 |
}
|
| 137 |
|
| 138 |
+
input[type="checkbox"]:checked {
|
| 139 |
+
background: var(--accent-primary);
|
| 140 |
+
border-color: var(--accent-primary);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
input[type="checkbox"]:checked::after {
|
| 144 |
+
content: '';
|
| 145 |
+
position: absolute;
|
| 146 |
+
top: 2px;
|
| 147 |
+
left: 4px;
|
| 148 |
+
width: 4px;
|
| 149 |
+
height: 7px;
|
| 150 |
+
border: solid var(--bg-primary);
|
| 151 |
+
border-width: 0 2px 2px 0;
|
| 152 |
+
transform: rotate(45deg);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
input[type="checkbox"]:hover {
|
| 156 |
+
border-color: var(--accent-primary);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/* Terminal cursor blink */
|
| 160 |
+
@keyframes blink {
|
| 161 |
+
0%, 50% { opacity: 1; }
|
| 162 |
+
51%, 100% { opacity: 0; }
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
/* Scan line effect */
|
| 166 |
+
@keyframes scanline {
|
| 167 |
+
0% { transform: translateY(-100%); }
|
| 168 |
+
100% { transform: translateY(100vh); }
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
/* Pulse glow */
|
| 172 |
+
@keyframes pulse-glow {
|
| 173 |
+
0%, 100% {
|
| 174 |
+
box-shadow: 0 0 5px var(--accent-dim);
|
| 175 |
+
}
|
| 176 |
+
50% {
|
| 177 |
+
box-shadow: 0 0 20px var(--accent-glow), 0 0 30px var(--accent-dim);
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
/* Tooltip animations */
|
| 182 |
+
@keyframes fadeIn {
|
| 183 |
from {
|
| 184 |
+
opacity: 0;
|
| 185 |
+
transform: translateY(5px);
|
| 186 |
}
|
| 187 |
to {
|
| 188 |
+
opacity: 1;
|
| 189 |
+
transform: translateY(0);
|
| 190 |
}
|
| 191 |
}
|
| 192 |
+
|
| 193 |
+
.tooltip-enter {
|
| 194 |
+
animation: fadeIn 0.2s ease-out;
|
| 195 |
+
}
|
src/App.js
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
import logo from './logo.svg';
|
| 2 |
-
import './App.css';
|
| 3 |
-
|
| 4 |
-
function App() {
|
| 5 |
-
return (
|
| 6 |
-
<div className="App">
|
| 7 |
-
<header className="App-header">
|
| 8 |
-
<img src={logo} className="App-logo" alt="logo" />
|
| 9 |
-
<p>
|
| 10 |
-
Edit <code>src/App.js</code> and save to reload.
|
| 11 |
-
</p>
|
| 12 |
-
<a
|
| 13 |
-
className="App-link"
|
| 14 |
-
href="https://reactjs.org"
|
| 15 |
-
target="_blank"
|
| 16 |
-
rel="noopener noreferrer"
|
| 17 |
-
>
|
| 18 |
-
Learn React
|
| 19 |
-
</a>
|
| 20 |
-
</header>
|
| 21 |
-
</div>
|
| 22 |
-
);
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
export default App;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/App.test.js
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
import { render, screen } from '@testing-library/react';
|
| 2 |
-
import App from './App';
|
| 3 |
-
|
| 4 |
-
test('renders learn react link', () => {
|
| 5 |
-
render(<App />);
|
| 6 |
-
const linkElement = screen.getByText(/learn react/i);
|
| 7 |
-
expect(linkElement).toBeInTheDocument();
|
| 8 |
-
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/App.tsx
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useCallback } from 'react';
|
| 2 |
+
import { DropZone, NeuralVisualizer } from './components';
|
| 3 |
+
import { useVisualizerStore } from './core/store';
|
| 4 |
+
import { LAYER_CATEGORIES, type LayerType } from '@/schema/types';
|
| 5 |
+
import type { NN3DModel } from '@/schema/types';
|
| 6 |
+
import './App.css';
|
| 7 |
+
|
| 8 |
+
/**
|
| 9 |
+
* Main Application Component
|
| 10 |
+
*
|
| 11 |
+
* Integrates the new 3D Neural Network Visualization System
|
| 12 |
+
*/
|
| 13 |
+
function App() {
|
| 14 |
+
const model = useVisualizerStore(state => state.model);
|
| 15 |
+
const isLoading = useVisualizerStore(state => state.isLoading);
|
| 16 |
+
const error = useVisualizerStore(state => state.error);
|
| 17 |
+
const selectNode = useVisualizerStore(state => state.selectNode);
|
| 18 |
+
const clearModel = useVisualizerStore(state => state.clearModel);
|
| 19 |
+
const loadModel = useVisualizerStore(state => state.loadModel);
|
| 20 |
+
|
| 21 |
+
// Build architecture data from store model for the visualizer
|
| 22 |
+
const architecture = model ? {
|
| 23 |
+
name: model.metadata?.name || 'Model',
|
| 24 |
+
framework: model.metadata?.framework || 'Unknown',
|
| 25 |
+
totalParameters: model.metadata?.totalParams || 0,
|
| 26 |
+
trainableParameters: model.metadata?.trainableParams,
|
| 27 |
+
inputShape: model.metadata?.inputShape as number[] | null || null,
|
| 28 |
+
outputShape: model.metadata?.outputShape as number[] | null || null,
|
| 29 |
+
layers: model.graph.nodes.map(node => {
|
| 30 |
+
// Get category from attributes (set by backend) or infer from layer type
|
| 31 |
+
const backendCategory = node.attributes?.category as string | undefined;
|
| 32 |
+
const layerType = node.type as LayerType;
|
| 33 |
+
const category = backendCategory || LAYER_CATEGORIES[layerType] || 'other';
|
| 34 |
+
|
| 35 |
+
// Extract num parameters from attributes or params
|
| 36 |
+
const numParameters =
|
| 37 |
+
(node.attributes?.parameters as number) ||
|
| 38 |
+
(node.params?.totalParams ? parseInt(String(node.params.totalParams).replace(/,/g, '')) : 0) ||
|
| 39 |
+
0;
|
| 40 |
+
|
| 41 |
+
return {
|
| 42 |
+
id: node.id,
|
| 43 |
+
name: node.name,
|
| 44 |
+
type: node.type,
|
| 45 |
+
category,
|
| 46 |
+
inputShape: (node.inputShape as number[] | null) || null,
|
| 47 |
+
outputShape: (node.outputShape as number[] | null) || null,
|
| 48 |
+
params: node.params || {},
|
| 49 |
+
numParameters,
|
| 50 |
+
trainable: true,
|
| 51 |
+
};
|
| 52 |
+
}),
|
| 53 |
+
connections: model.graph.edges.map(edge => ({
|
| 54 |
+
source: edge.source,
|
| 55 |
+
target: edge.target,
|
| 56 |
+
tensorShape: (edge.tensorShape as number[] | null) || null,
|
| 57 |
+
})),
|
| 58 |
+
} : null;
|
| 59 |
+
|
| 60 |
+
// Handle layer selection
|
| 61 |
+
const handleLayerSelect = useCallback((layerId: string | null) => {
|
| 62 |
+
selectNode(layerId);
|
| 63 |
+
}, [selectNode]);
|
| 64 |
+
|
| 65 |
+
// Handle uploading a new model (clear current)
|
| 66 |
+
const handleUploadNew = useCallback(() => {
|
| 67 |
+
clearModel();
|
| 68 |
+
}, [clearModel]);
|
| 69 |
+
|
| 70 |
+
// Handle loading a saved model from database
|
| 71 |
+
const handleLoadSavedModel = useCallback((arch: any) => {
|
| 72 |
+
// Convert architecture back to NN3DModel format
|
| 73 |
+
const savedModel: NN3DModel = {
|
| 74 |
+
version: '1.0',
|
| 75 |
+
metadata: {
|
| 76 |
+
name: arch.name,
|
| 77 |
+
framework: arch.framework,
|
| 78 |
+
totalParams: arch.totalParameters,
|
| 79 |
+
trainableParams: arch.trainableParameters,
|
| 80 |
+
inputShape: arch.inputShape,
|
| 81 |
+
outputShape: arch.outputShape,
|
| 82 |
+
},
|
| 83 |
+
graph: {
|
| 84 |
+
nodes: arch.layers.map((layer: any) => ({
|
| 85 |
+
id: layer.id,
|
| 86 |
+
name: layer.name,
|
| 87 |
+
type: layer.type,
|
| 88 |
+
inputShape: layer.inputShape,
|
| 89 |
+
outputShape: layer.outputShape,
|
| 90 |
+
params: layer.params,
|
| 91 |
+
attributes: {
|
| 92 |
+
category: layer.category,
|
| 93 |
+
parameters: layer.numParameters,
|
| 94 |
+
},
|
| 95 |
+
})),
|
| 96 |
+
edges: arch.connections.map((conn: any, idx: number) => ({
|
| 97 |
+
id: `edge-${idx}`,
|
| 98 |
+
source: conn.source,
|
| 99 |
+
target: conn.target,
|
| 100 |
+
tensorShape: conn.tensorShape,
|
| 101 |
+
})),
|
| 102 |
+
},
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
loadModel(savedModel);
|
| 106 |
+
}, [loadModel]);
|
| 107 |
+
|
| 108 |
+
return (
|
| 109 |
+
<div className="app">
|
| 110 |
+
<DropZone>
|
| 111 |
+
<NeuralVisualizer
|
| 112 |
+
architecture={architecture}
|
| 113 |
+
isLoading={isLoading}
|
| 114 |
+
error={error}
|
| 115 |
+
onLayerSelect={handleLayerSelect}
|
| 116 |
+
onUploadNew={handleUploadNew}
|
| 117 |
+
onLoadSavedModel={handleLoadSavedModel}
|
| 118 |
+
/>
|
| 119 |
+
</DropZone>
|
| 120 |
+
</div>
|
| 121 |
+
);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
export default App;
|
src/components/Scene.tsx
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { Suspense } from 'react';
|
| 2 |
+
import { Canvas } from '@react-three/fiber';
|
| 3 |
+
import { Stars, GizmoHelper, GizmoViewport } from '@react-three/drei';
|
| 4 |
+
import { LayerNodes } from './layers';
|
| 5 |
+
import { EdgeConnections } from './edges';
|
| 6 |
+
import { CameraControls, useKeyboardShortcuts } from './controls';
|
| 7 |
+
import { useVisualizerStore } from '@/core/store';
|
| 8 |
+
|
| 9 |
+
/**
|
| 10 |
+
* Loading fallback component
|
| 11 |
+
*/
|
| 12 |
+
function LoadingFallback() {
|
| 13 |
+
return (
|
| 14 |
+
<mesh>
|
| 15 |
+
<sphereGeometry args={[0.5, 32, 32]} />
|
| 16 |
+
<meshStandardMaterial color="#4fc3f7" wireframe />
|
| 17 |
+
</mesh>
|
| 18 |
+
);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
/**
|
| 22 |
+
* Scene lighting setup
|
| 23 |
+
*/
|
| 24 |
+
function SceneLighting() {
|
| 25 |
+
const config = useVisualizerStore(state => state.config);
|
| 26 |
+
const isDark = config.theme !== 'light';
|
| 27 |
+
|
| 28 |
+
return (
|
| 29 |
+
<>
|
| 30 |
+
<ambientLight intensity={isDark ? 0.4 : 0.6} />
|
| 31 |
+
<directionalLight
|
| 32 |
+
position={[10, 20, 10]}
|
| 33 |
+
intensity={isDark ? 0.8 : 1}
|
| 34 |
+
castShadow
|
| 35 |
+
shadow-mapSize={[2048, 2048]}
|
| 36 |
+
/>
|
| 37 |
+
<directionalLight position={[-10, 10, -10]} intensity={0.3} />
|
| 38 |
+
<pointLight position={[0, 10, 0]} intensity={0.5} color="#4fc3f7" />
|
| 39 |
+
</>
|
| 40 |
+
);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
/**
|
| 44 |
+
* Scene background
|
| 45 |
+
*/
|
| 46 |
+
function SceneBackground() {
|
| 47 |
+
const config = useVisualizerStore(state => state.config);
|
| 48 |
+
|
| 49 |
+
if (config.theme === 'light') {
|
| 50 |
+
return <color attach="background" args={['#f0f0f0']} />;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
if (config.theme === 'blueprint') {
|
| 54 |
+
return (
|
| 55 |
+
<>
|
| 56 |
+
<color attach="background" args={['#0a1929']} />
|
| 57 |
+
<gridHelper args={[100, 100, '#1e3a5f', '#0d2137']} position={[0, -10, 0]} />
|
| 58 |
+
</>
|
| 59 |
+
);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// Dark theme (default)
|
| 63 |
+
return (
|
| 64 |
+
<>
|
| 65 |
+
<color attach="background" args={['#0f0f1a']} />
|
| 66 |
+
<Stars radius={100} depth={50} count={2000} factor={4} fade speed={0.5} />
|
| 67 |
+
</>
|
| 68 |
+
);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
/**
|
| 72 |
+
* Keyboard shortcuts handler component
|
| 73 |
+
*/
|
| 74 |
+
function KeyboardHandler() {
|
| 75 |
+
useKeyboardShortcuts();
|
| 76 |
+
return null;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
/**
|
| 80 |
+
* Grid and helper elements
|
| 81 |
+
*/
|
| 82 |
+
function SceneHelpers() {
|
| 83 |
+
const model = useVisualizerStore(state => state.model);
|
| 84 |
+
|
| 85 |
+
if (!model) return null;
|
| 86 |
+
|
| 87 |
+
return (
|
| 88 |
+
<>
|
| 89 |
+
<gridHelper
|
| 90 |
+
args={[50, 50, '#333', '#222']}
|
| 91 |
+
position={[0, -15, 0]}
|
| 92 |
+
rotation={[0, 0, 0]}
|
| 93 |
+
/>
|
| 94 |
+
<GizmoHelper alignment="bottom-left" margin={[80, 80]}>
|
| 95 |
+
<GizmoViewport axisColors={['#f44336', '#4caf50', '#2196f3']} labelColor="white" />
|
| 96 |
+
</GizmoHelper>
|
| 97 |
+
</>
|
| 98 |
+
);
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
/**
|
| 102 |
+
* Main 3D network scene
|
| 103 |
+
*/
|
| 104 |
+
function NetworkScene() {
|
| 105 |
+
return (
|
| 106 |
+
<>
|
| 107 |
+
<SceneLighting />
|
| 108 |
+
<SceneBackground />
|
| 109 |
+
<SceneHelpers />
|
| 110 |
+
|
| 111 |
+
<Suspense fallback={<LoadingFallback />}>
|
| 112 |
+
<EdgeConnections />
|
| 113 |
+
<LayerNodes />
|
| 114 |
+
</Suspense>
|
| 115 |
+
</>
|
| 116 |
+
);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
/**
|
| 120 |
+
* Main 3D Canvas component
|
| 121 |
+
*/
|
| 122 |
+
export function Scene() {
|
| 123 |
+
return (
|
| 124 |
+
<Canvas
|
| 125 |
+
camera={{
|
| 126 |
+
position: [0, 5, 20],
|
| 127 |
+
fov: 60,
|
| 128 |
+
near: 0.1,
|
| 129 |
+
far: 1000,
|
| 130 |
+
}}
|
| 131 |
+
dpr={[1, 2]}
|
| 132 |
+
gl={{
|
| 133 |
+
antialias: true,
|
| 134 |
+
alpha: false,
|
| 135 |
+
powerPreference: 'high-performance',
|
| 136 |
+
}}
|
| 137 |
+
>
|
| 138 |
+
<CameraControls />
|
| 139 |
+
<KeyboardHandler />
|
| 140 |
+
<NetworkScene />
|
| 141 |
+
</Canvas>
|
| 142 |
+
);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
export default Scene;
|
src/components/controls/CameraControls.tsx
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useRef, useEffect } from 'react';
|
| 2 |
+
import { useThree, useFrame } from '@react-three/fiber';
|
| 3 |
+
import { OrbitControls as DreiOrbitControls } from '@react-three/drei';
|
| 4 |
+
import * as THREE from 'three';
|
| 5 |
+
import { useVisualizerStore } from '@/core/store';
|
| 6 |
+
|
| 7 |
+
/**
|
| 8 |
+
* Camera controls component with orbit, zoom, and pan
|
| 9 |
+
*/
|
| 10 |
+
export function CameraControls() {
|
| 11 |
+
const controlsRef = useRef<any>(null);
|
| 12 |
+
const model = useVisualizerStore(state => state.model);
|
| 13 |
+
const setCameraPosition = useVisualizerStore(state => state.setCameraPosition);
|
| 14 |
+
const setCameraTarget = useVisualizerStore(state => state.setCameraTarget);
|
| 15 |
+
|
| 16 |
+
const { camera: threeCamera } = useThree();
|
| 17 |
+
|
| 18 |
+
// Update store when camera moves
|
| 19 |
+
useFrame(() => {
|
| 20 |
+
if (controlsRef.current) {
|
| 21 |
+
const pos = threeCamera.position;
|
| 22 |
+
const target = controlsRef.current.target;
|
| 23 |
+
|
| 24 |
+
// Debounce updates
|
| 25 |
+
setCameraPosition({ x: pos.x, y: pos.y, z: pos.z });
|
| 26 |
+
setCameraTarget({ x: target.x, y: target.y, z: target.z });
|
| 27 |
+
}
|
| 28 |
+
});
|
| 29 |
+
|
| 30 |
+
// Reset camera when model changes
|
| 31 |
+
useEffect(() => {
|
| 32 |
+
if (model && controlsRef.current) {
|
| 33 |
+
// Compute camera position to frame the model
|
| 34 |
+
const nodeCount = model.graph.nodes.length;
|
| 35 |
+
const distance = Math.max(nodeCount * 1.5, 15);
|
| 36 |
+
|
| 37 |
+
threeCamera.position.set(0, distance * 0.3, distance);
|
| 38 |
+
controlsRef.current.target.set(0, -nodeCount * 0.5, 0);
|
| 39 |
+
controlsRef.current.update();
|
| 40 |
+
}
|
| 41 |
+
}, [model, threeCamera]);
|
| 42 |
+
|
| 43 |
+
return (
|
| 44 |
+
<DreiOrbitControls
|
| 45 |
+
ref={controlsRef}
|
| 46 |
+
enablePan={true}
|
| 47 |
+
enableZoom={true}
|
| 48 |
+
enableRotate={true}
|
| 49 |
+
minDistance={2}
|
| 50 |
+
maxDistance={100}
|
| 51 |
+
minPolarAngle={0}
|
| 52 |
+
maxPolarAngle={Math.PI}
|
| 53 |
+
dampingFactor={0.1}
|
| 54 |
+
rotateSpeed={0.5}
|
| 55 |
+
panSpeed={0.5}
|
| 56 |
+
zoomSpeed={0.8}
|
| 57 |
+
/>
|
| 58 |
+
);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
/**
|
| 62 |
+
* Camera animation for transitions
|
| 63 |
+
*/
|
| 64 |
+
export function useCameraAnimation() {
|
| 65 |
+
const { camera } = useThree();
|
| 66 |
+
const targetRef = useRef<THREE.Vector3 | null>(null);
|
| 67 |
+
const lookAtRef = useRef<THREE.Vector3 | null>(null);
|
| 68 |
+
const progressRef = useRef(0);
|
| 69 |
+
|
| 70 |
+
useFrame((_, delta) => {
|
| 71 |
+
if (targetRef.current && progressRef.current < 1) {
|
| 72 |
+
progressRef.current += delta * 2;
|
| 73 |
+
const t = Math.min(progressRef.current, 1);
|
| 74 |
+
const eased = 1 - Math.pow(1 - t, 3); // Ease out cubic
|
| 75 |
+
|
| 76 |
+
camera.position.lerp(targetRef.current, eased);
|
| 77 |
+
|
| 78 |
+
if (lookAtRef.current) {
|
| 79 |
+
camera.lookAt(lookAtRef.current);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
if (t >= 1) {
|
| 83 |
+
targetRef.current = null;
|
| 84 |
+
lookAtRef.current = null;
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
});
|
| 88 |
+
|
| 89 |
+
const animateTo = (position: THREE.Vector3, lookAt?: THREE.Vector3) => {
|
| 90 |
+
targetRef.current = position;
|
| 91 |
+
lookAtRef.current = lookAt || null;
|
| 92 |
+
progressRef.current = 0;
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
return { animateTo };
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
/**
|
| 99 |
+
* Focus camera on a specific node
|
| 100 |
+
*/
|
| 101 |
+
export function useFocusNode() {
|
| 102 |
+
const computedNodes = useVisualizerStore(state => state.computedNodes);
|
| 103 |
+
const { animateTo } = useCameraAnimation();
|
| 104 |
+
|
| 105 |
+
const focusNode = (nodeId: string) => {
|
| 106 |
+
const node = computedNodes.get(nodeId);
|
| 107 |
+
if (!node) return;
|
| 108 |
+
|
| 109 |
+
const { x, y, z } = node.computedPosition;
|
| 110 |
+
const targetPos = new THREE.Vector3(x, y + 5, z + 10);
|
| 111 |
+
const lookAtPos = new THREE.Vector3(x, y, z);
|
| 112 |
+
|
| 113 |
+
animateTo(targetPos, lookAtPos);
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
return { focusNode };
|
| 117 |
+
}
|
src/components/controls/Interaction.tsx
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useCallback, useEffect } from 'react';
|
| 2 |
+
import { useThree } from '@react-three/fiber';
|
| 3 |
+
import * as THREE from 'three';
|
| 4 |
+
import { useVisualizerStore } from '@/core/store';
|
| 5 |
+
|
| 6 |
+
/**
|
| 7 |
+
* Hook for raycasting and object picking
|
| 8 |
+
*/
|
| 9 |
+
export function useRaycast() {
|
| 10 |
+
const { camera, scene, gl } = useThree();
|
| 11 |
+
const raycaster = new THREE.Raycaster();
|
| 12 |
+
const pointer = new THREE.Vector2();
|
| 13 |
+
|
| 14 |
+
const getIntersections = useCallback((event: MouseEvent | PointerEvent) => {
|
| 15 |
+
const rect = gl.domElement.getBoundingClientRect();
|
| 16 |
+
pointer.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
|
| 17 |
+
pointer.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
|
| 18 |
+
|
| 19 |
+
raycaster.setFromCamera(pointer, camera);
|
| 20 |
+
return raycaster.intersectObjects(scene.children, true);
|
| 21 |
+
}, [camera, scene, gl]);
|
| 22 |
+
|
| 23 |
+
return { getIntersections };
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
/**
|
| 27 |
+
* Keyboard shortcuts handler
|
| 28 |
+
*/
|
| 29 |
+
export function useKeyboardShortcuts() {
|
| 30 |
+
const resetCamera = useVisualizerStore(state => state.resetCamera);
|
| 31 |
+
const selectNode = useVisualizerStore(state => state.selectNode);
|
| 32 |
+
const selection = useVisualizerStore(state => state.selection);
|
| 33 |
+
const computedNodes = useVisualizerStore(state => state.computedNodes);
|
| 34 |
+
const updateConfig = useVisualizerStore(state => state.updateConfig);
|
| 35 |
+
const config = useVisualizerStore(state => state.config);
|
| 36 |
+
|
| 37 |
+
useEffect(() => {
|
| 38 |
+
const handleKeyDown = (event: KeyboardEvent) => {
|
| 39 |
+
switch (event.key) {
|
| 40 |
+
case 'Escape':
|
| 41 |
+
// Deselect current selection
|
| 42 |
+
selectNode(null);
|
| 43 |
+
break;
|
| 44 |
+
|
| 45 |
+
case 'r':
|
| 46 |
+
case 'R':
|
| 47 |
+
// Reset camera
|
| 48 |
+
if (!event.ctrlKey && !event.metaKey) {
|
| 49 |
+
resetCamera();
|
| 50 |
+
}
|
| 51 |
+
break;
|
| 52 |
+
|
| 53 |
+
case 'l':
|
| 54 |
+
case 'L':
|
| 55 |
+
// Toggle labels
|
| 56 |
+
updateConfig({ showLabels: !config.showLabels });
|
| 57 |
+
break;
|
| 58 |
+
|
| 59 |
+
case 'e':
|
| 60 |
+
case 'E':
|
| 61 |
+
// Toggle edges
|
| 62 |
+
updateConfig({ showEdges: !config.showEdges });
|
| 63 |
+
break;
|
| 64 |
+
|
| 65 |
+
case 'ArrowUp':
|
| 66 |
+
case 'ArrowDown': {
|
| 67 |
+
// Navigate between nodes
|
| 68 |
+
if (selection.selectedNodeId) {
|
| 69 |
+
const nodeIds = Array.from(computedNodes.keys());
|
| 70 |
+
const currentIndex = nodeIds.indexOf(selection.selectedNodeId);
|
| 71 |
+
const nextIndex = event.key === 'ArrowDown'
|
| 72 |
+
? Math.min(currentIndex + 1, nodeIds.length - 1)
|
| 73 |
+
: Math.max(currentIndex - 1, 0);
|
| 74 |
+
selectNode(nodeIds[nextIndex]);
|
| 75 |
+
}
|
| 76 |
+
break;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
default:
|
| 80 |
+
break;
|
| 81 |
+
}
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
window.addEventListener('keydown', handleKeyDown);
|
| 85 |
+
return () => window.removeEventListener('keydown', handleKeyDown);
|
| 86 |
+
}, [resetCamera, selectNode, selection, computedNodes, updateConfig, config]);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
/**
|
| 90 |
+
* Touch gesture handler for mobile
|
| 91 |
+
*/
|
| 92 |
+
export function useTouchGestures() {
|
| 93 |
+
// Placeholder for touch gesture handling
|
| 94 |
+
// Can be expanded for pinch-to-zoom, two-finger rotate, etc.
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/**
|
| 98 |
+
* LOD (Level of Detail) manager based on camera distance
|
| 99 |
+
*/
|
| 100 |
+
export function useLODManager() {
|
| 101 |
+
const { camera } = useThree();
|
| 102 |
+
const computedNodes = useVisualizerStore(state => state.computedNodes);
|
| 103 |
+
const updateNodeLOD = useVisualizerStore(state => state.updateNodeLOD);
|
| 104 |
+
|
| 105 |
+
// LOD thresholds
|
| 106 |
+
const LOD_DISTANCES = {
|
| 107 |
+
HIGH: 20, // LOD 0 (full detail) when closer than this
|
| 108 |
+
MEDIUM: 40, // LOD 1 (medium detail)
|
| 109 |
+
LOW: 80, // LOD 2 (low detail)
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
const updateLOD = useCallback(() => {
|
| 113 |
+
const lodMap = new Map<string, number>();
|
| 114 |
+
const cameraPos = camera.position;
|
| 115 |
+
|
| 116 |
+
computedNodes.forEach((node, id) => {
|
| 117 |
+
const nodePos = new THREE.Vector3(
|
| 118 |
+
node.computedPosition.x,
|
| 119 |
+
node.computedPosition.y,
|
| 120 |
+
node.computedPosition.z
|
| 121 |
+
);
|
| 122 |
+
const distance = cameraPos.distanceTo(nodePos);
|
| 123 |
+
|
| 124 |
+
let lod = 0;
|
| 125 |
+
if (distance > LOD_DISTANCES.LOW) {
|
| 126 |
+
lod = 2;
|
| 127 |
+
} else if (distance > LOD_DISTANCES.MEDIUM) {
|
| 128 |
+
lod = 1;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
lodMap.set(id, lod);
|
| 132 |
+
});
|
| 133 |
+
|
| 134 |
+
updateNodeLOD(lodMap);
|
| 135 |
+
}, [camera, computedNodes, updateNodeLOD]);
|
| 136 |
+
|
| 137 |
+
return { updateLOD };
|
| 138 |
+
}
|
src/components/controls/index.ts
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export * from './CameraControls';
|
| 2 |
+
export * from './Interaction';
|
src/components/edges/EdgeConnections.tsx
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useMemo } from 'react';
|
| 2 |
+
import { useVisualizerStore } from '@/core/store';
|
| 3 |
+
import { getEdgeComponent, EdgeStyle } from './EdgeGeometry';
|
| 4 |
+
import { SmartConnection } from './NeuralConnection';
|
| 5 |
+
|
| 6 |
+
/**
|
| 7 |
+
* Renders all edges/connections in the network
|
| 8 |
+
*/
|
| 9 |
+
export function EdgeConnections() {
|
| 10 |
+
const computedEdges = useVisualizerStore(state => state.computedEdges);
|
| 11 |
+
const computedNodes = useVisualizerStore(state => state.computedNodes);
|
| 12 |
+
const config = useVisualizerStore(state => state.config);
|
| 13 |
+
|
| 14 |
+
// Check if we should use enhanced neural visualization
|
| 15 |
+
const useEnhancedEdges = useMemo(() => {
|
| 16 |
+
// Check if any node has enhanced attributes (from backend analysis)
|
| 17 |
+
for (const node of computedNodes.values()) {
|
| 18 |
+
if (node.params && (
|
| 19 |
+
node.params.out_features !== undefined ||
|
| 20 |
+
node.params.outFeatures !== undefined ||
|
| 21 |
+
node.params.out_channels !== undefined ||
|
| 22 |
+
node.params.outChannels !== undefined
|
| 23 |
+
)) {
|
| 24 |
+
return true;
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
return false;
|
| 28 |
+
}, [computedNodes]);
|
| 29 |
+
|
| 30 |
+
const EdgeComponent = useMemo(
|
| 31 |
+
() => getEdgeComponent(config.edgeStyle as EdgeStyle || 'tube'),
|
| 32 |
+
[config.edgeStyle]
|
| 33 |
+
);
|
| 34 |
+
|
| 35 |
+
if (!config.showEdges) return null;
|
| 36 |
+
|
| 37 |
+
// Enhanced neural connections
|
| 38 |
+
if (useEnhancedEdges) {
|
| 39 |
+
return (
|
| 40 |
+
<group name="edge-connections">
|
| 41 |
+
{computedEdges.map((edge, index) => {
|
| 42 |
+
// Get source and target node info for neuron counts
|
| 43 |
+
const sourceNode = computedNodes.get(edge.source);
|
| 44 |
+
const targetNode = computedNodes.get(edge.target);
|
| 45 |
+
|
| 46 |
+
const sourceNeurons = sourceNode?.params?.out_features ||
|
| 47 |
+
sourceNode?.params?.outFeatures ||
|
| 48 |
+
sourceNode?.params?.out_channels ||
|
| 49 |
+
sourceNode?.params?.outChannels || 16;
|
| 50 |
+
|
| 51 |
+
const targetNeurons = targetNode?.params?.in_features ||
|
| 52 |
+
targetNode?.params?.inFeatures ||
|
| 53 |
+
targetNode?.params?.in_channels ||
|
| 54 |
+
targetNode?.params?.inChannels || 16;
|
| 55 |
+
|
| 56 |
+
return (
|
| 57 |
+
<SmartConnection
|
| 58 |
+
key={edge.id || `edge-${index}`}
|
| 59 |
+
sourcePosition={edge.sourcePosition}
|
| 60 |
+
targetPosition={edge.targetPosition}
|
| 61 |
+
sourceNeurons={typeof sourceNeurons === 'number' ? sourceNeurons : 16}
|
| 62 |
+
targetNeurons={typeof targetNeurons === 'number' ? targetNeurons : 16}
|
| 63 |
+
color={edge.color}
|
| 64 |
+
highlighted={edge.highlighted}
|
| 65 |
+
animated={true}
|
| 66 |
+
style="bundle"
|
| 67 |
+
/>
|
| 68 |
+
);
|
| 69 |
+
})}
|
| 70 |
+
</group>
|
| 71 |
+
);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Original edge rendering
|
| 75 |
+
return (
|
| 76 |
+
<group name="edge-connections">
|
| 77 |
+
{computedEdges.map((edge, index) => (
|
| 78 |
+
<EdgeComponent
|
| 79 |
+
key={edge.id || `edge-${index}`}
|
| 80 |
+
edge={edge}
|
| 81 |
+
style={config.edgeStyle as EdgeStyle}
|
| 82 |
+
animated={false}
|
| 83 |
+
/>
|
| 84 |
+
))}
|
| 85 |
+
</group>
|
| 86 |
+
);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
export { EdgeConnections as Edges };
|
src/components/edges/EdgeGeometry.tsx
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useMemo, useRef } from 'react';
|
| 2 |
+
import * as THREE from 'three';
|
| 3 |
+
import { useFrame } from '@react-three/fiber';
|
| 4 |
+
import { Line, QuadraticBezierLine } from '@react-three/drei';
|
| 5 |
+
import type { ComputedEdge } from '@/core/store';
|
| 6 |
+
import type { Position3D } from '@/schema/types';
|
| 7 |
+
|
| 8 |
+
/**
|
| 9 |
+
* Edge style type
|
| 10 |
+
*/
|
| 11 |
+
export type EdgeStyle = 'line' | 'tube' | 'arrow' | 'bezier';
|
| 12 |
+
|
| 13 |
+
/**
|
| 14 |
+
* Props for edge components
|
| 15 |
+
*/
|
| 16 |
+
export interface EdgeProps {
|
| 17 |
+
edge: ComputedEdge;
|
| 18 |
+
style?: EdgeStyle;
|
| 19 |
+
animated?: boolean;
|
| 20 |
+
flowSpeed?: number;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
/**
|
| 24 |
+
* Calculate control points for bezier curves
|
| 25 |
+
*/
|
| 26 |
+
function getBezierControlPoints(
|
| 27 |
+
start: Position3D,
|
| 28 |
+
end: Position3D
|
| 29 |
+
): { mid1: Position3D; mid2: Position3D } {
|
| 30 |
+
const midY = (start.y + end.y) / 2;
|
| 31 |
+
const offsetZ = Math.abs(end.y - start.y) * 0.3;
|
| 32 |
+
|
| 33 |
+
return {
|
| 34 |
+
mid1: { x: start.x, y: midY, z: start.z + offsetZ },
|
| 35 |
+
mid2: { x: end.x, y: midY, z: end.z + offsetZ },
|
| 36 |
+
};
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/**
|
| 40 |
+
* Simple line edge
|
| 41 |
+
*/
|
| 42 |
+
export function LineEdge({ edge }: EdgeProps) {
|
| 43 |
+
const points = useMemo(() => [
|
| 44 |
+
new THREE.Vector3(edge.sourcePosition.x, edge.sourcePosition.y, edge.sourcePosition.z),
|
| 45 |
+
new THREE.Vector3(edge.targetPosition.x, edge.targetPosition.y, edge.targetPosition.z),
|
| 46 |
+
], [edge.sourcePosition, edge.targetPosition]);
|
| 47 |
+
|
| 48 |
+
const color = edge.highlighted ? '#ffffff' : edge.color;
|
| 49 |
+
const lineWidth = edge.highlighted ? 3 : 1.5;
|
| 50 |
+
|
| 51 |
+
return (
|
| 52 |
+
<Line
|
| 53 |
+
points={points}
|
| 54 |
+
color={color}
|
| 55 |
+
lineWidth={lineWidth}
|
| 56 |
+
opacity={edge.visible ? 1 : 0.2}
|
| 57 |
+
transparent
|
| 58 |
+
/>
|
| 59 |
+
);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/**
|
| 63 |
+
* Bezier curve edge for smoother connections
|
| 64 |
+
*/
|
| 65 |
+
export function BezierEdge({ edge }: EdgeProps) {
|
| 66 |
+
const start = useMemo(() =>
|
| 67 |
+
new THREE.Vector3(edge.sourcePosition.x, edge.sourcePosition.y, edge.sourcePosition.z),
|
| 68 |
+
[edge.sourcePosition]
|
| 69 |
+
);
|
| 70 |
+
|
| 71 |
+
const end = useMemo(() =>
|
| 72 |
+
new THREE.Vector3(edge.targetPosition.x, edge.targetPosition.y, edge.targetPosition.z),
|
| 73 |
+
[edge.targetPosition]
|
| 74 |
+
);
|
| 75 |
+
|
| 76 |
+
const control = useMemo(() => {
|
| 77 |
+
const { mid1 } = getBezierControlPoints(edge.sourcePosition, edge.targetPosition);
|
| 78 |
+
return new THREE.Vector3(mid1.x, mid1.y, mid1.z);
|
| 79 |
+
}, [edge.sourcePosition, edge.targetPosition]);
|
| 80 |
+
|
| 81 |
+
const color = edge.highlighted ? '#ffffff' : edge.color;
|
| 82 |
+
const lineWidth = edge.highlighted ? 3 : 1.5;
|
| 83 |
+
|
| 84 |
+
return (
|
| 85 |
+
<QuadraticBezierLine
|
| 86 |
+
start={start}
|
| 87 |
+
end={end}
|
| 88 |
+
mid={control}
|
| 89 |
+
color={color}
|
| 90 |
+
lineWidth={lineWidth}
|
| 91 |
+
opacity={edge.visible ? 1 : 0.2}
|
| 92 |
+
transparent
|
| 93 |
+
/>
|
| 94 |
+
);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/**
|
| 98 |
+
* Tube edge for 3D pipe-like connections
|
| 99 |
+
*/
|
| 100 |
+
export function TubeEdge({ edge, animated = false, flowSpeed = 1 }: EdgeProps) {
|
| 101 |
+
const tubeRef = useRef<THREE.Mesh>(null);
|
| 102 |
+
|
| 103 |
+
// Create curve for tube
|
| 104 |
+
const curve = useMemo(() => {
|
| 105 |
+
const start = new THREE.Vector3(
|
| 106 |
+
edge.sourcePosition.x,
|
| 107 |
+
edge.sourcePosition.y,
|
| 108 |
+
edge.sourcePosition.z
|
| 109 |
+
);
|
| 110 |
+
const end = new THREE.Vector3(
|
| 111 |
+
edge.targetPosition.x,
|
| 112 |
+
edge.targetPosition.y,
|
| 113 |
+
edge.targetPosition.z
|
| 114 |
+
);
|
| 115 |
+
|
| 116 |
+
const { mid1, mid2 } = getBezierControlPoints(edge.sourcePosition, edge.targetPosition);
|
| 117 |
+
const control1 = new THREE.Vector3(mid1.x, mid1.y, mid1.z);
|
| 118 |
+
const control2 = new THREE.Vector3(mid2.x, mid2.y, mid2.z);
|
| 119 |
+
|
| 120 |
+
return new THREE.CubicBezierCurve3(start, control1, control2, end);
|
| 121 |
+
}, [edge.sourcePosition, edge.targetPosition]);
|
| 122 |
+
|
| 123 |
+
// Tube geometry
|
| 124 |
+
const geometry = useMemo(() => {
|
| 125 |
+
const radius = edge.highlighted ? 0.06 : 0.03;
|
| 126 |
+
return new THREE.TubeGeometry(curve, 32, radius, 8, false);
|
| 127 |
+
}, [curve, edge.highlighted]);
|
| 128 |
+
|
| 129 |
+
// Animate flow effect
|
| 130 |
+
useFrame(({ clock }) => {
|
| 131 |
+
if (animated && tubeRef.current) {
|
| 132 |
+
const material = tubeRef.current.material as THREE.MeshStandardMaterial;
|
| 133 |
+
if (material.map) {
|
| 134 |
+
material.map.offset.y = clock.getElapsedTime() * flowSpeed;
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
});
|
| 138 |
+
|
| 139 |
+
const color = edge.highlighted ? '#ffffff' : edge.color;
|
| 140 |
+
|
| 141 |
+
return (
|
| 142 |
+
<mesh ref={tubeRef} geometry={geometry}>
|
| 143 |
+
<meshStandardMaterial
|
| 144 |
+
color={color}
|
| 145 |
+
opacity={edge.visible ? 0.8 : 0.2}
|
| 146 |
+
transparent
|
| 147 |
+
metalness={0.3}
|
| 148 |
+
roughness={0.6}
|
| 149 |
+
/>
|
| 150 |
+
</mesh>
|
| 151 |
+
);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/**
|
| 155 |
+
* Arrow edge with direction indicator
|
| 156 |
+
*/
|
| 157 |
+
export function ArrowEdge({ edge }: EdgeProps) {
|
| 158 |
+
const start = new THREE.Vector3(
|
| 159 |
+
edge.sourcePosition.x,
|
| 160 |
+
edge.sourcePosition.y,
|
| 161 |
+
edge.sourcePosition.z
|
| 162 |
+
);
|
| 163 |
+
const end = new THREE.Vector3(
|
| 164 |
+
edge.targetPosition.x,
|
| 165 |
+
edge.targetPosition.y,
|
| 166 |
+
edge.targetPosition.z
|
| 167 |
+
);
|
| 168 |
+
|
| 169 |
+
const direction = end.clone().sub(start).normalize();
|
| 170 |
+
const arrowPosition = end.clone().sub(direction.clone().multiplyScalar(0.3));
|
| 171 |
+
|
| 172 |
+
const color = edge.highlighted ? '#ffffff' : edge.color;
|
| 173 |
+
|
| 174 |
+
// Calculate arrow rotation
|
| 175 |
+
const quaternion = new THREE.Quaternion();
|
| 176 |
+
quaternion.setFromUnitVectors(new THREE.Vector3(0, 1, 0), direction);
|
| 177 |
+
|
| 178 |
+
return (
|
| 179 |
+
<group>
|
| 180 |
+
<LineEdge edge={edge} />
|
| 181 |
+
{/* Arrow head */}
|
| 182 |
+
<mesh position={arrowPosition} quaternion={quaternion}>
|
| 183 |
+
<coneGeometry args={[0.1, 0.2, 8]} />
|
| 184 |
+
<meshStandardMaterial color={color} />
|
| 185 |
+
</mesh>
|
| 186 |
+
</group>
|
| 187 |
+
);
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
/**
|
| 191 |
+
* Factory function to get edge component by style
|
| 192 |
+
*/
|
| 193 |
+
export function getEdgeComponent(style: EdgeStyle): React.ComponentType<EdgeProps> {
|
| 194 |
+
switch (style) {
|
| 195 |
+
case 'line':
|
| 196 |
+
return LineEdge;
|
| 197 |
+
case 'bezier':
|
| 198 |
+
return BezierEdge;
|
| 199 |
+
case 'tube':
|
| 200 |
+
return TubeEdge;
|
| 201 |
+
case 'arrow':
|
| 202 |
+
return ArrowEdge;
|
| 203 |
+
default:
|
| 204 |
+
return TubeEdge;
|
| 205 |
+
}
|
| 206 |
+
}
|
src/components/edges/NeuralConnection.tsx
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Neural Network Connection Visualization
|
| 3 |
+
* Shows dense connections between layers like actual neural networks
|
| 4 |
+
*/
|
| 5 |
+
|
| 6 |
+
import { useMemo, useRef } from 'react';
|
| 7 |
+
import * as THREE from 'three';
|
| 8 |
+
import { useFrame } from '@react-three/fiber';
|
| 9 |
+
import type { Position3D } from '@/schema/types';
|
| 10 |
+
|
| 11 |
+
export interface NeuralConnectionProps {
|
| 12 |
+
sourcePosition: Position3D;
|
| 13 |
+
targetPosition: Position3D;
|
| 14 |
+
sourceNeurons?: number;
|
| 15 |
+
targetNeurons?: number;
|
| 16 |
+
color?: string;
|
| 17 |
+
highlighted?: boolean;
|
| 18 |
+
animated?: boolean;
|
| 19 |
+
connectionDensity?: number; // 0-1, how many connections to show
|
| 20 |
+
style?: 'single' | 'dense' | 'bundle';
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
/**
|
| 24 |
+
* Single connection line between layers
|
| 25 |
+
*/
|
| 26 |
+
export function SingleConnection({
|
| 27 |
+
sourcePosition,
|
| 28 |
+
targetPosition,
|
| 29 |
+
color = '#ffffff',
|
| 30 |
+
highlighted = false,
|
| 31 |
+
}: NeuralConnectionProps) {
|
| 32 |
+
const geometry = useMemo(() => {
|
| 33 |
+
const points = [
|
| 34 |
+
new THREE.Vector3(sourcePosition.x, sourcePosition.y, sourcePosition.z),
|
| 35 |
+
new THREE.Vector3(targetPosition.x, targetPosition.y, targetPosition.z),
|
| 36 |
+
];
|
| 37 |
+
return new THREE.BufferGeometry().setFromPoints(points);
|
| 38 |
+
}, [sourcePosition, targetPosition]);
|
| 39 |
+
|
| 40 |
+
return (
|
| 41 |
+
<primitive object={new THREE.Line(geometry, new THREE.LineBasicMaterial({
|
| 42 |
+
color: highlighted ? '#ffffff' : color,
|
| 43 |
+
transparent: true,
|
| 44 |
+
opacity: highlighted ? 1 : 0.5,
|
| 45 |
+
}))} />
|
| 46 |
+
);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
/**
|
| 50 |
+
* Dense connection bundle showing multiple lines
|
| 51 |
+
*/
|
| 52 |
+
export function DenseConnection({
|
| 53 |
+
sourcePosition,
|
| 54 |
+
targetPosition,
|
| 55 |
+
sourceNeurons = 16,
|
| 56 |
+
targetNeurons = 16,
|
| 57 |
+
color = '#4488ff',
|
| 58 |
+
highlighted = false,
|
| 59 |
+
animated = true,
|
| 60 |
+
connectionDensity = 0.3,
|
| 61 |
+
}: NeuralConnectionProps) {
|
| 62 |
+
const groupRef = useRef<THREE.Group>(null);
|
| 63 |
+
const materialRef = useRef<THREE.LineBasicMaterial>(null);
|
| 64 |
+
|
| 65 |
+
// Limit connections for performance
|
| 66 |
+
const maxConnections = 100;
|
| 67 |
+
const numConnections = Math.min(
|
| 68 |
+
Math.floor(sourceNeurons * targetNeurons * connectionDensity),
|
| 69 |
+
maxConnections
|
| 70 |
+
);
|
| 71 |
+
|
| 72 |
+
// Generate connection lines
|
| 73 |
+
const lines = useMemo(() => {
|
| 74 |
+
const connections: { start: THREE.Vector3; end: THREE.Vector3 }[] = [];
|
| 75 |
+
|
| 76 |
+
// Calculate source and target layer bounds
|
| 77 |
+
const sourceSpread = Math.min(Math.sqrt(sourceNeurons) * 0.1, 0.4);
|
| 78 |
+
const targetSpread = Math.min(Math.sqrt(targetNeurons) * 0.1, 0.4);
|
| 79 |
+
|
| 80 |
+
for (let i = 0; i < numConnections; i++) {
|
| 81 |
+
// Random positions within layer bounds
|
| 82 |
+
const srcOffset = {
|
| 83 |
+
y: (Math.random() - 0.5) * sourceSpread,
|
| 84 |
+
z: (Math.random() - 0.5) * sourceSpread * 0.5,
|
| 85 |
+
};
|
| 86 |
+
const tgtOffset = {
|
| 87 |
+
y: (Math.random() - 0.5) * targetSpread,
|
| 88 |
+
z: (Math.random() - 0.5) * targetSpread * 0.5,
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
connections.push({
|
| 92 |
+
start: new THREE.Vector3(
|
| 93 |
+
sourcePosition.x + 0.3, // Offset from layer center
|
| 94 |
+
sourcePosition.y + srcOffset.y,
|
| 95 |
+
sourcePosition.z + srcOffset.z
|
| 96 |
+
),
|
| 97 |
+
end: new THREE.Vector3(
|
| 98 |
+
targetPosition.x - 0.3, // Offset from layer center
|
| 99 |
+
targetPosition.y + tgtOffset.y,
|
| 100 |
+
targetPosition.z + tgtOffset.z
|
| 101 |
+
),
|
| 102 |
+
});
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
return connections;
|
| 106 |
+
}, [sourcePosition, targetPosition, sourceNeurons, targetNeurons, numConnections]);
|
| 107 |
+
|
| 108 |
+
// Animate opacity
|
| 109 |
+
useFrame((state) => {
|
| 110 |
+
if (animated && materialRef.current) {
|
| 111 |
+
const pulse = Math.sin(state.clock.elapsedTime * 2) * 0.1 + 0.3;
|
| 112 |
+
materialRef.current.opacity = highlighted ? 0.8 : pulse;
|
| 113 |
+
}
|
| 114 |
+
});
|
| 115 |
+
|
| 116 |
+
// Create buffer geometry for all lines
|
| 117 |
+
const geometry = useMemo(() => {
|
| 118 |
+
const positions: number[] = [];
|
| 119 |
+
|
| 120 |
+
lines.forEach(line => {
|
| 121 |
+
positions.push(line.start.x, line.start.y, line.start.z);
|
| 122 |
+
positions.push(line.end.x, line.end.y, line.end.z);
|
| 123 |
+
});
|
| 124 |
+
|
| 125 |
+
const geo = new THREE.BufferGeometry();
|
| 126 |
+
geo.setAttribute('position', new THREE.Float32BufferAttribute(positions, 3));
|
| 127 |
+
return geo;
|
| 128 |
+
}, [lines]);
|
| 129 |
+
|
| 130 |
+
return (
|
| 131 |
+
<group ref={groupRef}>
|
| 132 |
+
<lineSegments geometry={geometry}>
|
| 133 |
+
<lineBasicMaterial
|
| 134 |
+
ref={materialRef}
|
| 135 |
+
color={highlighted ? '#ffffff' : color}
|
| 136 |
+
transparent
|
| 137 |
+
opacity={0.3}
|
| 138 |
+
depthWrite={false}
|
| 139 |
+
/>
|
| 140 |
+
</lineSegments>
|
| 141 |
+
</group>
|
| 142 |
+
);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
/**
|
| 146 |
+
* Bundled connection - shows as a tube/pipe
|
| 147 |
+
*/
|
| 148 |
+
export function BundledConnection({
|
| 149 |
+
sourcePosition,
|
| 150 |
+
targetPosition,
|
| 151 |
+
sourceNeurons = 16,
|
| 152 |
+
targetNeurons = 16,
|
| 153 |
+
color = '#4488ff',
|
| 154 |
+
highlighted = false,
|
| 155 |
+
animated = true,
|
| 156 |
+
}: NeuralConnectionProps) {
|
| 157 |
+
const meshRef = useRef<THREE.Mesh>(null);
|
| 158 |
+
|
| 159 |
+
// Calculate bundle thickness based on connection count
|
| 160 |
+
const connectionStrength = Math.log2(Math.min(sourceNeurons, targetNeurons) + 1) * 0.02;
|
| 161 |
+
const thickness = Math.max(0.02, Math.min(connectionStrength, 0.1));
|
| 162 |
+
|
| 163 |
+
// Create tube path
|
| 164 |
+
const curve = useMemo(() => {
|
| 165 |
+
const start = new THREE.Vector3(sourcePosition.x, sourcePosition.y, sourcePosition.z);
|
| 166 |
+
const end = new THREE.Vector3(targetPosition.x, targetPosition.y, targetPosition.z);
|
| 167 |
+
|
| 168 |
+
// Bezier control points for smooth curve
|
| 169 |
+
const midX = (start.x + end.x) / 2;
|
| 170 |
+
const control1 = new THREE.Vector3(midX, start.y, start.z);
|
| 171 |
+
const control2 = new THREE.Vector3(midX, end.y, end.z);
|
| 172 |
+
|
| 173 |
+
return new THREE.CubicBezierCurve3(start, control1, control2, end);
|
| 174 |
+
}, [sourcePosition, targetPosition]);
|
| 175 |
+
|
| 176 |
+
const geometry = useMemo(() => {
|
| 177 |
+
return new THREE.TubeGeometry(curve, 20, thickness, 8, false);
|
| 178 |
+
}, [curve, thickness]);
|
| 179 |
+
|
| 180 |
+
// Animate flow effect
|
| 181 |
+
useFrame((state) => {
|
| 182 |
+
if (animated && meshRef.current) {
|
| 183 |
+
const material = meshRef.current.material as THREE.MeshStandardMaterial;
|
| 184 |
+
if (material.emissiveIntensity !== undefined) {
|
| 185 |
+
material.emissiveIntensity = Math.sin(state.clock.elapsedTime * 3) * 0.2 + 0.3;
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
});
|
| 189 |
+
|
| 190 |
+
const baseColor = new THREE.Color(color);
|
| 191 |
+
|
| 192 |
+
return (
|
| 193 |
+
<mesh ref={meshRef} geometry={geometry}>
|
| 194 |
+
<meshStandardMaterial
|
| 195 |
+
color={highlighted ? '#ffffff' : baseColor}
|
| 196 |
+
emissive={baseColor}
|
| 197 |
+
emissiveIntensity={0.3}
|
| 198 |
+
transparent
|
| 199 |
+
opacity={highlighted ? 0.9 : 0.6}
|
| 200 |
+
metalness={0.3}
|
| 201 |
+
roughness={0.7}
|
| 202 |
+
/>
|
| 203 |
+
</mesh>
|
| 204 |
+
);
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/**
|
| 208 |
+
* Flow particles along connection
|
| 209 |
+
*/
|
| 210 |
+
export function FlowParticles({
|
| 211 |
+
sourcePosition,
|
| 212 |
+
targetPosition,
|
| 213 |
+
color = '#ffffff',
|
| 214 |
+
particleCount = 5,
|
| 215 |
+
speed = 1,
|
| 216 |
+
}: {
|
| 217 |
+
sourcePosition: Position3D;
|
| 218 |
+
targetPosition: Position3D;
|
| 219 |
+
color?: string;
|
| 220 |
+
particleCount?: number;
|
| 221 |
+
speed?: number;
|
| 222 |
+
}) {
|
| 223 |
+
const particlesRef = useRef<THREE.Points>(null);
|
| 224 |
+
|
| 225 |
+
// Create particle positions
|
| 226 |
+
const { positions, offsets } = useMemo(() => {
|
| 227 |
+
const pos = new Float32Array(particleCount * 3);
|
| 228 |
+
const off = new Float32Array(particleCount);
|
| 229 |
+
|
| 230 |
+
for (let i = 0; i < particleCount; i++) {
|
| 231 |
+
off[i] = i / particleCount; // Spread particles along path
|
| 232 |
+
|
| 233 |
+
// Initial positions will be updated in useFrame
|
| 234 |
+
pos[i * 3] = sourcePosition.x;
|
| 235 |
+
pos[i * 3 + 1] = sourcePosition.y;
|
| 236 |
+
pos[i * 3 + 2] = sourcePosition.z;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
return { positions: pos, offsets: off };
|
| 240 |
+
}, [sourcePosition, particleCount]);
|
| 241 |
+
|
| 242 |
+
const geometry = useMemo(() => {
|
| 243 |
+
const geo = new THREE.BufferGeometry();
|
| 244 |
+
geo.setAttribute('position', new THREE.BufferAttribute(positions, 3));
|
| 245 |
+
return geo;
|
| 246 |
+
}, [positions]);
|
| 247 |
+
|
| 248 |
+
// Animate particles
|
| 249 |
+
useFrame((state) => {
|
| 250 |
+
if (particlesRef.current) {
|
| 251 |
+
const posAttr = particlesRef.current.geometry.getAttribute('position') as THREE.BufferAttribute;
|
| 252 |
+
|
| 253 |
+
for (let i = 0; i < particleCount; i++) {
|
| 254 |
+
// Calculate t along path (0 to 1)
|
| 255 |
+
const t = ((state.clock.elapsedTime * speed + offsets[i]) % 1);
|
| 256 |
+
|
| 257 |
+
// Lerp position
|
| 258 |
+
posAttr.setXYZ(
|
| 259 |
+
i,
|
| 260 |
+
sourcePosition.x + (targetPosition.x - sourcePosition.x) * t,
|
| 261 |
+
sourcePosition.y + (targetPosition.y - sourcePosition.y) * t,
|
| 262 |
+
sourcePosition.z + (targetPosition.z - sourcePosition.z) * t
|
| 263 |
+
);
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
posAttr.needsUpdate = true;
|
| 267 |
+
}
|
| 268 |
+
});
|
| 269 |
+
|
| 270 |
+
return (
|
| 271 |
+
<points ref={particlesRef} geometry={geometry}>
|
| 272 |
+
<pointsMaterial
|
| 273 |
+
color={color}
|
| 274 |
+
size={0.05}
|
| 275 |
+
transparent
|
| 276 |
+
opacity={0.8}
|
| 277 |
+
sizeAttenuation
|
| 278 |
+
/>
|
| 279 |
+
</points>
|
| 280 |
+
);
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
/**
|
| 284 |
+
* Smart connection that chooses visualization based on layer sizes
|
| 285 |
+
*/
|
| 286 |
+
export function SmartConnection({
|
| 287 |
+
sourcePosition,
|
| 288 |
+
targetPosition,
|
| 289 |
+
sourceNeurons = 16,
|
| 290 |
+
targetNeurons = 16,
|
| 291 |
+
color = '#4488ff',
|
| 292 |
+
highlighted = false,
|
| 293 |
+
animated = true,
|
| 294 |
+
style = 'bundle',
|
| 295 |
+
}: NeuralConnectionProps) {
|
| 296 |
+
const totalConnections = sourceNeurons * targetNeurons;
|
| 297 |
+
|
| 298 |
+
// Choose visualization based on connection count
|
| 299 |
+
if (style === 'single' || totalConnections < 50) {
|
| 300 |
+
return (
|
| 301 |
+
<SingleConnection
|
| 302 |
+
sourcePosition={sourcePosition}
|
| 303 |
+
targetPosition={targetPosition}
|
| 304 |
+
color={color}
|
| 305 |
+
highlighted={highlighted}
|
| 306 |
+
/>
|
| 307 |
+
);
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
if (style === 'dense' || totalConnections < 500) {
|
| 311 |
+
return (
|
| 312 |
+
<>
|
| 313 |
+
<DenseConnection
|
| 314 |
+
sourcePosition={sourcePosition}
|
| 315 |
+
targetPosition={targetPosition}
|
| 316 |
+
sourceNeurons={sourceNeurons}
|
| 317 |
+
targetNeurons={targetNeurons}
|
| 318 |
+
color={color}
|
| 319 |
+
highlighted={highlighted}
|
| 320 |
+
animated={animated}
|
| 321 |
+
connectionDensity={0.2}
|
| 322 |
+
/>
|
| 323 |
+
{animated && (
|
| 324 |
+
<FlowParticles
|
| 325 |
+
sourcePosition={sourcePosition}
|
| 326 |
+
targetPosition={targetPosition}
|
| 327 |
+
color={color}
|
| 328 |
+
particleCount={3}
|
| 329 |
+
speed={0.5}
|
| 330 |
+
/>
|
| 331 |
+
)}
|
| 332 |
+
</>
|
| 333 |
+
);
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
// For very dense connections, use bundled representation
|
| 337 |
+
return (
|
| 338 |
+
<>
|
| 339 |
+
<BundledConnection
|
| 340 |
+
sourcePosition={sourcePosition}
|
| 341 |
+
targetPosition={targetPosition}
|
| 342 |
+
sourceNeurons={sourceNeurons}
|
| 343 |
+
targetNeurons={targetNeurons}
|
| 344 |
+
color={color}
|
| 345 |
+
highlighted={highlighted}
|
| 346 |
+
animated={animated}
|
| 347 |
+
/>
|
| 348 |
+
{animated && (
|
| 349 |
+
<FlowParticles
|
| 350 |
+
sourcePosition={sourcePosition}
|
| 351 |
+
targetPosition={targetPosition}
|
| 352 |
+
color="#ffffff"
|
| 353 |
+
particleCount={5}
|
| 354 |
+
speed={0.8}
|
| 355 |
+
/>
|
| 356 |
+
)}
|
| 357 |
+
</>
|
| 358 |
+
);
|
| 359 |
+
}
|
src/components/edges/index.ts
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export * from './EdgeGeometry';
|
| 2 |
+
export * from './EdgeConnections';
|