AI Agent committed on
Commit
a0098d0
·
0 Parent(s):

Deploy to Spaces

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +15 -0
  2. Dockerfile +69 -0
  3. README.md +59 -0
  4. app.py +16 -0
  5. backend/__init__.py +1 -0
  6. backend/__pycache__/__init__.cpython-312.pyc +0 -0
  7. backend/api/__init__.py +6 -0
  8. backend/api/__pycache__/__init__.cpython-312.pyc +0 -0
  9. backend/api/__pycache__/main.cpython-312.pyc +0 -0
  10. backend/api/main.py +61 -0
  11. backend/api/routes/__init__.py +1 -0
  12. backend/api/routes/__pycache__/__init__.cpython-312.pyc +0 -0
  13. backend/api/routes/__pycache__/analysis.cpython-312.pyc +0 -0
  14. backend/api/routes/__pycache__/models.cpython-312.pyc +0 -0
  15. backend/api/routes/__pycache__/quantization.cpython-312.pyc +0 -0
  16. backend/api/routes/__pycache__/system.cpython-312.pyc +0 -0
  17. backend/api/routes/analysis.py +249 -0
  18. backend/api/routes/models.py +411 -0
  19. backend/api/routes/quantization.py +366 -0
  20. backend/api/routes/system.py +64 -0
  21. backend/core/__init__.py +6 -0
  22. backend/core/__pycache__/__init__.cpython-312.pyc +0 -0
  23. backend/core/__pycache__/model_loader.cpython-312.pyc +0 -0
  24. backend/core/__pycache__/model_manager.cpython-312.pyc +0 -0
  25. backend/core/__pycache__/quantizer.cpython-312.pyc +0 -0
  26. backend/core/__pycache__/system_checker.cpython-312.pyc +0 -0
  27. backend/core/__pycache__/visualization.cpython-312.pyc +0 -0
  28. backend/core/model_loader.py +411 -0
  29. backend/core/model_manager.py +247 -0
  30. backend/core/quantizer.py +605 -0
  31. backend/core/system_checker.py +299 -0
  32. backend/core/visualization.py +277 -0
  33. backend/requirements.txt +11 -0
  34. docker-compose.yml +78 -0
  35. frontend/.gitignore +24 -0
  36. frontend/README.md +16 -0
  37. frontend/eslint.config.js +29 -0
  38. frontend/index.html +13 -0
  39. frontend/package-lock.json +0 -0
  40. frontend/package.json +35 -0
  41. frontend/public/vite.svg +1 -0
  42. frontend/src/App.css +42 -0
  43. frontend/src/App.jsx +82 -0
  44. frontend/src/assets/react.svg +1 -0
  45. frontend/src/components/Layout.jsx +297 -0
  46. frontend/src/index.css +751 -0
  47. frontend/src/main.jsx +10 -0
  48. frontend/src/pages/Analysis.jsx +483 -0
  49. frontend/src/pages/Dashboard.jsx +412 -0
  50. frontend/src/pages/ModelLoader.jsx +775 -0
.dockerignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ node_modules
2
+ dist
3
+ build
4
+ .git
5
+ .gitignore
6
+ venv
7
+ env
8
+ __pycache__
9
+ *.pyc
10
+ *.pyo
11
+ *.pyd
12
+ .DS_Store
13
+ .env
14
+ site-packages
15
+ .gemini
Dockerfile ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Multi-stage Dockerfile for Neural Network Quantizer
# Stage 1 builds the React frontend; Stage 2 serves it with the FastAPI backend.

# ============================================
# Stage 1: Build Frontend
# ============================================
FROM node:20-alpine AS frontend-build

WORKDIR /app/frontend

# Copy package manifests first so the npm install layer is cached
COPY frontend/package*.json ./

# Install dependencies reproducibly from the lockfile
RUN npm ci

# Copy frontend source
COPY frontend/ ./

# Build production bundle (output lands in /app/frontend/dist)
RUN npm run build

# ============================================
# Stage 2: Python Backend + Frontend
# ============================================
FROM python:3.11-slim

# Set environment variables.
# (Removed leftover GRADIO_SERVER_NAME/GRADIO_SERVER_PORT: the app is served by
# uvicorn/FastAPI — see CMD below — so the Gradio variables were dead config.)
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# Install system dependencies (curl is required by HEALTHCHECK below)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy backend requirements
COPY backend/requirements.txt ./requirements.txt

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy backend code
COPY backend/ ./backend/

# Copy frontend build from stage 1
COPY --from=frontend-build /app/frontend/dist ./frontend/dist

# Copy HuggingFace Spaces entry point
COPY app.py ./

# Create non-root user (uid 1000 is the HF Spaces convention)
RUN useradd -m -u 1000 user
USER user

# Expose port (must match README app_port and the uvicorn --port below)
EXPOSE 7860

# Health check against the API's own health endpoint
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:7860/api/health || exit 1

# Start the application
CMD ["python", "-m", "uvicorn", "backend.api.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Neural Network Quantizer
3
+ emoji: ⚡
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ app_port: 7860
10
+ ---
11
+
12
+ # Neural Network Weight Quantizer
13
+
14
+ Quantize neural network weights to lower precision formats (INT8, INT4, NF4) with interactive visualizations.
15
+
16
+ ## Features
17
+
18
+ - 🔢 Multi-bit quantization (4-bit, 8-bit)
19
+ - 📊 Interactive weight visualizations
20
+ - 🤗 HuggingFace model support (optional)
21
+ - ⚡ GPU acceleration (when available)
22
+ - 📈 Quantization error analysis
23
+ - 🔄 Method comparison (INT8 vs INT4 vs NF4)
24
+
25
+ ## Quick Start
26
+
27
+ 1. Use the **Quantizer** tab to test on random weights
28
+ 2. Compare different methods in the **Analysis** tab
29
+ 3. Optionally load a HuggingFace model in the **Models** tab
30
+
31
+ ## API
32
+
33
+ The backend exposes a REST API at `/api`:
34
+
35
+ - `GET /api/system/info` - System capabilities
36
+ - `POST /api/quantize/weights` - Quantize custom weights
37
+ - `POST /api/models/load` - Load HuggingFace model
38
+ - `POST /api/analysis/compare` - Compare methods
39
+
40
+ ## 🚀 Deployment
41
+
42
+ ### Hugging Face Spaces
43
+ This project is configured for **Hugging Face Spaces** using the Docker SDK.
44
+
45
+ 1. Create a new Space on [Hugging Face](https://huggingface.co/new-space).
46
+ 2. Select **Docker** as the SDK.
47
+ 3. Push this repository to your Space:
48
+ ```bash
49
+ git remote add space https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
50
+ git push space main
51
+ ```
52
+
53
+ ### Docker
54
+ Run locally with Docker:
55
+ ```bash
56
+ docker build -t quantizer .
57
+ docker run -p 7860:7860 quantizer
58
+ ```
59
+ Open `http://localhost:7860`.
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
HuggingFace Spaces Entry Point
This file serves as the entry point for HuggingFace Spaces deployment.
It starts the FastAPI application which serves both the API and the React frontend.
"""

import uvicorn
from backend.api.main import app

if __name__ == "__main__":
    # Spaces expects the server on 0.0.0.0:7860 (matches Dockerfile EXPOSE / README app_port)
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Backend package init"""
backend/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (204 Bytes). View file
 
backend/api/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ API Package Init
3
+ """
4
+ from .main import app
5
+
6
+ __all__ = ["app"]
backend/api/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (275 Bytes). View file
 
backend/api/__pycache__/main.cpython-312.pyc ADDED
Binary file (2.82 kB). View file
 
backend/api/main.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
FastAPI Main Application
Neural Network Weight Quantizer API

Wires up the API routers, CORS, a health endpoint, and (when a production
frontend build exists) static serving of the React SPA.
"""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pathlib import Path

from .routes import quantization, models, analysis, system

# Create FastAPI app
app = FastAPI(
    title="Neural Network Quantizer API",
    description="API for quantizing neural network weights to lower precision formats",
    version="1.0.0",
    docs_url="/api/docs",
    openapi_url="/api/openapi.json"
)

# CORS configuration.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is the
# wildcard-credentials combination browsers refuse; restrict origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
app.include_router(system.router, prefix="/api/system", tags=["System"])
app.include_router(models.router, prefix="/api/models", tags=["Models"])
app.include_router(quantization.router, prefix="/api/quantize", tags=["Quantization"])
app.include_router(analysis.router, prefix="/api/analysis", tags=["Analysis"])


# Health check (also used by the Dockerfile HEALTHCHECK)
@app.get("/api/health")
async def health_check():
    return {"status": "healthy", "service": "quantizer-api"}


# Serve frontend in production (dist/ only exists after `npm run build`)
FRONTEND_DIR = Path(__file__).parent.parent.parent / "frontend" / "dist"

if FRONTEND_DIR.exists():
    app.mount("/assets", StaticFiles(directory=FRONTEND_DIR / "assets"), name="assets")

    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve a real static file when it exists; otherwise fall back to
        index.html so client-side SPA routing works."""
        # Resolve and confine to FRONTEND_DIR so a crafted path (e.g. "../..")
        # cannot escape the frontend build directory.
        file_path = (FRONTEND_DIR / full_path).resolve()
        if str(file_path).startswith(str(FRONTEND_DIR.resolve())) and file_path.is_file():
            return FileResponse(file_path)
        return FileResponse(FRONTEND_DIR / "index.html")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
backend/api/routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Routes package"""
backend/api/routes/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (209 Bytes). View file
 
backend/api/routes/__pycache__/analysis.cpython-312.pyc ADDED
Binary file (11.5 kB). View file
 
backend/api/routes/__pycache__/models.cpython-312.pyc ADDED
Binary file (17.9 kB). View file
 
backend/api/routes/__pycache__/quantization.cpython-312.pyc ADDED
Binary file (15.9 kB). View file
 
backend/api/routes/__pycache__/system.cpython-312.pyc ADDED
Binary file (2.77 kB). View file
 
backend/api/routes/analysis.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analysis Routes
3
+ Weight analysis and visualization endpoints
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException
7
+ from pydantic import BaseModel
8
+ from typing import Optional, Dict, Any, List
9
+ import torch
10
+
11
+ from backend.core.model_loader import model_loader
12
+ from backend.core.visualization import visualizer
13
+ from backend.core.quantizer import (
14
+ QuantizationConfig, QuantizationMethod, QuantizationMode,
15
+ get_quantizer
16
+ )
17
+
18
+ router = APIRouter()
19
+
20
+
class AnalyzeLayerRequest(BaseModel):
    """Request to analyze a specific layer"""
    layer_name: str


class CompareQuantizationRequest(BaseModel):
    """Compare different quantization methods on same weights"""
    layer_name: Optional[str] = None  # when unset, random weights are generated
    in_features: int = 64
    out_features: int = 128
    methods: List[str] = ["int8", "int4", "nf4"]
32
+
33
+
@router.get("/weights/{layer_name}")
async def get_weight_analysis(layer_name: str) -> Dict[str, Any]:
    """Return summary statistics plus heatmap/histogram visualizations for one
    layer's weight tensor. 404s when no model is loaded or the layer is unknown."""
    if model_loader is None or model_loader.get_model() is None:
        raise HTTPException(status_code=404, detail="No model loaded")

    weights = model_loader.get_layer_weights(layer_name)
    if weights is None:
        raise HTTPException(status_code=404, detail=f"Layer not found: {layer_name}")

    # Work on a 1-D view for order statistics; torch.quantile needs float input.
    flat = weights.flatten()
    flat_f = flat.float()

    quantile_levels = {
        "1%": 0.01, "5%": 0.05, "25%": 0.25, "50%": 0.50,
        "75%": 0.75, "95%": 0.95, "99%": 0.99,
    }

    stats = {
        "shape": list(weights.shape),
        "dtype": str(weights.dtype),
        "num_params": int(weights.numel()),
        "memory_mb": weights.numel() * weights.element_size() / (1024 * 1024),
        "min": float(weights.min()),
        "max": float(weights.max()),
        "mean": float(weights.mean()),
        "std": float(weights.std()),
        "median": float(torch.median(flat)),
        "sparsity": float((weights == 0).sum() / weights.numel()),
        "abs_mean": float(weights.abs().mean()),
        "percentiles": {
            label: float(torch.quantile(flat_f, q))
            for label, q in quantile_levels.items()
        },
    }

    return {
        "layer_name": layer_name,
        "stats": stats,
        "visualizations": {
            "heatmap": visualizer.to_dict(
                visualizer.weight_heatmap(weights, f"Weights: {layer_name}")
            ),
            "histogram": visualizer.to_dict(
                visualizer.weight_histogram(weights, "Weight Distribution")
            ),
        },
    }
89
+
90
+
@router.post("/compare")
async def compare_quantization_methods(request: CompareQuantizationRequest) -> Dict[str, Any]:
    """Quantize one weight tensor with several methods and report their
    error/memory trade-offs side by side."""
    # Source the weights: a named layer of the loaded model, else random data.
    if request.layer_name and model_loader and model_loader.get_model():
        weights = model_loader.get_layer_weights(request.layer_name)
        if weights is None:
            raise HTTPException(status_code=404, detail=f"Layer not found: {request.layer_name}")
        source = f"layer:{request.layer_name}"
    else:
        weights = torch.randn(request.out_features, request.in_features)
        source = "random"

    # Quantizers expect a 2-D matrix: lift 1-D, flatten trailing dims of N-D.
    if weights.dim() == 1:
        weights = weights.unsqueeze(0)
    elif weights.dim() > 2:
        weights = weights.reshape(weights.shape[0], -1)

    method_map = {
        "int8": QuantizationMethod.INT8,
        "int4": QuantizationMethod.INT4,
        "nf4": QuantizationMethod.NF4,
    }

    comparison = []
    for method_name in request.methods:
        method = method_map.get(method_name)
        if method is None:
            continue  # silently skip unrecognized method names

        config = QuantizationConfig(
            bits=8 if method_name == "int8" else 4,
            method=method,
            group_size=128 if method_name in ["int4", "nf4"] else None
        )

        try:
            result = get_quantizer(config).quantize(weights)
            comparison.append({
                "method": method_name,
                "bits": config.bits,
                "max_error": result.max_error,
                "mean_error": result.mean_error,
                "memory_savings_percent": result.memory_savings_percent,
                "histogram": visualizer.to_dict(
                    visualizer.weight_histogram(
                        result.quantized_weights.float(),
                        f"{method_name.upper()} Distribution"
                    )
                )
            })
        except Exception as e:
            # Record per-method failures instead of failing the whole comparison.
            comparison.append({"method": method_name, "error": str(e)})

    return {
        "source": source,
        "original_shape": list(weights.shape),
        "original_stats": {
            "min": float(weights.min()),
            "max": float(weights.max()),
            "mean": float(weights.mean()),
            "std": float(weights.std())
        },
        "comparison": comparison
    }
165
+
166
+
@router.get("/model-summary")
async def get_model_summary() -> Dict[str, Any]:
    """Aggregate per-layer parameter counts for the loaded model and estimate
    memory footprints at FP32 / INT8 / INT4 precision."""
    if model_loader is None or model_loader.get_model() is None:
        raise HTTPException(status_code=404, detail="No model loaded")

    model_info = model_loader.get_model_info()
    if model_info is None:
        raise HTTPException(status_code=500, detail="Failed to get model info")

    # Single pass: accumulate totals while building the per-layer table.
    layer_stats = []
    total_params = 0
    quantizable_params = 0
    for layer in model_info.layers:
        total_params += layer.num_params
        if layer.is_quantizable:
            quantizable_params += layer.num_params
        layer_stats.append({
            "name": layer.name,
            "type": layer.module_type,
            "params": layer.num_params,
            "params_mb": layer.num_params * 4 / (1024 * 1024),  # Assuming FP32
            "quantizable": layer.is_quantizable
        })

    # Largest layers first
    layer_stats.sort(key=lambda entry: entry["params"], reverse=True)

    # Non-quantizable layers stay at 4 bytes/param in the estimates below.
    non_quant = total_params - quantizable_params
    return {
        "model_name": model_info.name,
        "architecture": model_info.architecture,
        "total_params": total_params,
        "total_params_billions": total_params / 1e9,
        "quantizable_params": quantizable_params,
        "quantizable_percent": quantizable_params / total_params * 100 if total_params > 0 else 0,
        "memory_fp32_gb": total_params * 4 / (1024**3),
        "memory_int8_estimate_gb": quantizable_params * 1 / (1024**3) + non_quant * 4 / (1024**3),
        "memory_int4_estimate_gb": quantizable_params * 0.5 / (1024**3) + non_quant * 4 / (1024**3),
        "top_layers": layer_stats[:20]  # Top 20 largest layers
    }
212
+
213
+
@router.get("/outliers/{layer_name}")
async def detect_outliers(layer_name: str, threshold: float = 3.0) -> Dict[str, Any]:
    """Report weights lying more than `threshold` standard deviations from the
    layer mean — candidates for clipping before quantization."""
    if model_loader is None or model_loader.get_model() is None:
        raise HTTPException(status_code=404, detail="No model loaded")

    weights = model_loader.get_layer_weights(layer_name)
    if weights is None:
        raise HTTPException(status_code=404, detail=f"Layer not found: {layer_name}")

    flat = weights.flatten()
    mean = flat.mean()
    std = flat.std()

    # Outlier = |w - mean| > threshold * std
    outlier_mask = (flat - mean).abs() > threshold * std
    num_outliers = outlier_mask.sum().item()
    outlier_values = flat[outlier_mask].tolist()[:100]  # Limit to 100

    # >1% outliers is treated as problematic for plain quantization.
    heavy_tail = num_outliers > flat.numel() * 0.01
    return {
        "layer_name": layer_name,
        "threshold": threshold,
        "total_weights": int(flat.numel()),
        "num_outliers": num_outliers,
        "outlier_percent": num_outliers / flat.numel() * 100,
        "mean": float(mean),
        "std": float(std),
        "outlier_range": {
            "below": float(mean - threshold * std),
            "above": float(mean + threshold * std)
        },
        "sample_outliers": outlier_values,
        "recommendation": "Consider clipping or mixed-precision for this layer" if heavy_tail else "Layer is suitable for quantization"
    }
backend/api/routes/models.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Routes with Download Progress Streaming
3
+ Supports HuggingFace Spaces with proper cache management
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, BackgroundTasks
7
+ from fastapi.responses import StreamingResponse
8
+ from pydantic import BaseModel
9
+ from typing import Optional, Dict, Any, List
10
+ import torch
11
+ import asyncio
12
+ import json
13
+ import traceback
14
+ import time
15
+ from backend.core.model_loader import model_loader
16
+
17
+ from backend.core.model_manager import (
18
+ get_download_progress, set_download_progress, clear_download_progress,
19
+ get_cached_models, cleanup_old_models, delete_model_cache,
20
+ get_cache_stats, ensure_sample_models, start_cleanup_scheduler,
21
+ SAMPLE_MODELS
22
+ )
23
+
24
+ router = APIRouter()
25
+
26
+
class LoadModelRequest(BaseModel):
    """Request to load a model"""
    model_name: str
    dtype: str = "auto"
    device: str = "auto"
    # NOTE(review): defaulting trust_remote_code to True executes arbitrary code
    # shipped in the model repo — consider False for untrusted model IDs.
    trust_remote_code: bool = True


class DeleteModelRequest(BaseModel):
    """Request to delete a cached model"""
    model_name: str


# Module-level state for the currently loaded model/tokenizer
_loaded_model = None
_loaded_tokenizer = None
_model_name = None

# Kick off the periodic cache cleanup as soon as this routes module is imported
start_cleanup_scheduler()
47
+
48
+
49
+ def _get_device():
50
+ """Get best available device"""
51
+ if torch.cuda.is_available():
52
+ return "cuda"
53
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
54
+ return "mps"
55
+ return "cpu"
56
+
57
+
58
+ def _get_torch_dtype(dtype_str: str, device: str):
59
+ """Convert dtype string to torch dtype"""
60
+ if dtype_str == "auto":
61
+ if device == "cuda":
62
+ return torch.float16
63
+ return torch.float32
64
+
65
+ dtype_map = {
66
+ "fp32": torch.float32,
67
+ "float32": torch.float32,
68
+ "fp16": torch.float16,
69
+ "float16": torch.float16,
70
+ "bf16": torch.bfloat16,
71
+ "bfloat16": torch.bfloat16,
72
+ }
73
+ return dtype_map.get(dtype_str, torch.float32)
74
+
75
+
async def _load_model_with_progress(model_name: str, dtype: str, device: str, trust_remote_code: bool):
    """Async generator: load a HuggingFace model while yielding progress dicts.

    Yields events of "type" "progress" (with phase/percent/message), "error"
    (with error text), or a final "complete" carrying model_info. Stores the
    loaded model/tokenizer in module state and syncs the global model_loader.
    """
    global _loaded_model, _loaded_tokenizer, _model_name

    try:
        from transformers import AutoModel, AutoTokenizer, AutoConfig
    except ImportError:
        yield {"type": "error", "error": "transformers library not installed"}
        return

    try:
        # Phase 1: fetch config — also validates that the model ID exists.
        yield {"type": "progress", "phase": "config", "percent": 5, "message": "Fetching model configuration..."}

        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        except Exception as e:
            yield {"type": "error", "error": f"Model not found: {str(e)}", "suggestion": "Check the model ID is correct"}
            return

        # Phase 2: resolve target device and dtype.
        actual_device = device if device != "auto" else _get_device()
        torch_dtype = _get_torch_dtype(dtype, actual_device)

        yield {"type": "progress", "phase": "download", "percent": 10, "message": f"Downloading model to {actual_device}..."}

        # Record state so the polling endpoint (/progress/{model_name}) can report it.
        set_download_progress(model_name, {
            "status": "downloading",
            "percent": 10,
            "message": "Downloading model files..."
        })

        # Phase 3: download/load the weights; retry without low_cpu_mem_usage
        # since that flag needs accelerate and can fail on some models.
        try:
            model = AutoModel.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
                low_cpu_mem_usage=True
            )
            yield {"type": "progress", "phase": "download", "percent": 70, "message": "Model downloaded successfully"}
        except Exception:
            try:
                model = AutoModel.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=trust_remote_code
                )
                yield {"type": "progress", "phase": "download", "percent": 70, "message": "Model downloaded (fallback mode)"}
            except Exception as e2:
                yield {"type": "error", "error": f"Failed to load model: {str(e2)}"}
                clear_download_progress(model_name)
                return

        # Phase 4: move to target device; models dispatched via hf_device_map
        # are already placed. A failed move degrades gracefully to CPU.
        yield {"type": "progress", "phase": "device", "percent": 80, "message": f"Moving model to {actual_device}..."}

        if actual_device != "cpu" and not hasattr(model, 'hf_device_map'):
            try:
                model = model.to(actual_device)
            except Exception:
                actual_device = "cpu"
                model = model.to("cpu")

        model.eval()

        # Phase 5: tokenizer is best-effort — some checkpoints ship without one.
        yield {"type": "progress", "phase": "tokenizer", "percent": 90, "message": "Loading tokenizer..."}

        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        except Exception:
            tokenizer = None

        # Persist in module state.
        _loaded_model = model
        _loaded_tokenizer = tokenizer
        _model_name = model_name

        # Keep the global model loader in sync for the analysis routes.
        if model_loader:
            model_loader.register_model(model, model_name, tokenizer)

        # Summarize the loaded model.
        num_params = sum(p.numel() for p in model.parameters())
        memory_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

        quantizable_layers = [
            name for name, module in model.named_modules()
            if any(t in module.__class__.__name__ for t in ["Linear", "Conv1d", "Conv2d"])
        ]

        # Phase 6: done — clear polling state and emit the final event.
        clear_download_progress(model_name)

        yield {
            "type": "complete",
            "percent": 100,
            "model_info": {
                "name": model_name,
                "architecture": model.config.architectures[0] if hasattr(model.config, 'architectures') and model.config.architectures else "Unknown",
                "num_params": num_params,
                "num_params_millions": round(num_params / 1e6, 2),
                "memory_mb": round(memory_mb, 2),
                "device": str(next(model.parameters()).device),
                "dtype": str(next(model.parameters()).dtype),
                "num_quantizable_layers": len(quantizable_layers),
                "has_tokenizer": tokenizer is not None,
                "is_sample": model_name in SAMPLE_MODELS
            }
        }

    except Exception as e:
        clear_download_progress(model_name)
        yield {"type": "error", "error": str(e), "traceback": traceback.format_exc()}
193
+
194
+
@router.post("/load")
async def load_model(request: LoadModelRequest) -> Dict[str, Any]:
    """Load a model (non-streaming version for simple requests)"""
    # Drain the progress generator; only the terminal event matters here.
    final = None
    async for update in _load_model_with_progress(
        request.model_name, request.dtype, request.device, request.trust_remote_code
    ):
        final = update

    if final and final.get("type") == "complete":
        return {"success": True, "model_info": final["model_info"]}
    if final and final.get("type") == "error":
        return {"success": False, "error": final.get("error"), "suggestion": final.get("suggestion")}
    return {"success": False, "error": "Unknown error"}


@router.post("/load/stream")
async def load_model_stream(request: LoadModelRequest):
    """Load a model with Server-Sent Events for progress updates"""

    async def event_generator():
        # Each progress dict becomes one SSE "data:" frame.
        async for update in _load_model_with_progress(
            request.model_name, request.dtype, request.device, request.trust_remote_code
        ):
            yield f"data: {json.dumps(update)}\n\n"
            await asyncio.sleep(0.1)  # Small delay between events

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )


@router.get("/progress/{model_name}")
async def get_model_progress(model_name: str) -> Dict[str, Any]:
    """Get download progress for a model (polling endpoint)"""
    progress = get_download_progress(model_name)
    if not progress:
        return {"downloading": False}
    return {"downloading": True, **progress}


@router.get("/status")
async def get_loading_status() -> Dict[str, Any]:
    """Get current model loading status"""
    return {
        "model_loaded": _loaded_model is not None,
        "model_name": _model_name,
        "has_tokenizer": _loaded_tokenizer is not None
    }
250
+
251
+
@router.get("/info")
async def get_model_info() -> Dict[str, Any]:
    """Get information about the currently loaded model"""
    if _loaded_model is None:
        return {"loaded": False, "message": "No model loaded"}

    num_params = sum(p.numel() for p in _loaded_model.parameters())
    memory_mb = sum(p.numel() * p.element_size() for p in _loaded_model.parameters()) / (1024 * 1024)

    return {
        "loaded": True,
        "name": _model_name,
        "num_params": num_params,
        "num_params_millions": round(num_params / 1e6, 2),
        "memory_mb": round(memory_mb, 2),
        "device": str(next(_loaded_model.parameters()).device),
        "dtype": str(next(_loaded_model.parameters()).dtype)
    }


@router.get("/layers")
async def get_layers() -> Dict[str, Any]:
    """Get list of layers in the loaded model"""
    if _loaded_model is None:
        return {"error": "No model loaded", "layers": []}

    layers = []
    quantizable_names = []

    for name, module in _loaded_model.named_modules():
        # Skip the root module, which is named "".
        if not name:
            continue

        module_type = module.__class__.__name__
        is_quantizable = any(t in module_type for t in ["Linear", "Conv1d", "Conv2d", "Embedding"])

        # Only modules with a weight tensor are reported.
        shape = None
        num_params = 0
        if hasattr(module, 'weight') and module.weight is not None:
            shape = list(module.weight.shape)
            num_params = module.weight.numel()

        if num_params > 0:
            layers.append({
                "name": name,
                "type": module_type,
                "shape": shape,
                "params": num_params,
                "quantizable": is_quantizable
            })

        if is_quantizable:
            quantizable_names.append(name)

    return {
        "total_layers": len(layers),
        "quantizable_count": len(quantizable_names),
        "quantizable_layers": quantizable_names,
        "layers": layers
    }


@router.post("/unload")
async def unload_model() -> Dict[str, Any]:
    """Unload the current model and free memory"""
    global _loaded_model, _loaded_tokenizer, _model_name

    # Drop all module-level references so the model can be collected.
    _loaded_model = None
    _loaded_tokenizer = None
    _model_name = None

    # Sync with global module loader
    if model_loader:
        model_loader.unload()

    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {"success": True, "message": "Model unloaded"}
339
+
340
+
341
+ # ============================================
342
+ # Cache Management Endpoints
343
+ # ============================================
344
+
@router.get("/cache")
async def get_cache_info() -> Dict[str, Any]:
    """Get information about cached models"""
    return get_cache_stats()


@router.post("/cache/cleanup")
async def trigger_cleanup(hours: float = 4.0) -> Dict[str, Any]:
    """Manually trigger cache cleanup"""
    result = cleanup_old_models(hours)
    return {
        "success": True,
        "deleted_count": len(result["deleted"]),
        "kept_count": len(result["kept"]),
        **result
    }


@router.delete("/cache/{model_name:path}")
async def delete_cached_model(model_name: str) -> Dict[str, Any]:
    """Delete a specific model from cache"""
    # Sample models are protected — they must stay available for quick testing.
    if model_name in SAMPLE_MODELS:
        return {"success": False, "error": "Cannot delete sample models"}

    return {"success": delete_model_cache(model_name), "model_name": model_name}


# ============================================
# Example Models
# ============================================

@router.get("/examples")
async def get_example_models() -> Dict[str, Any]:
    """Get list of example models for testing"""
    sample_entries = [
        {"id": model, "is_default": True, "description": "Pre-cached for quick testing"}
        for model in SAMPLE_MODELS
    ]
    return {
        "sample_models": sample_entries,
        "small_models": [
            {"id": "gpt2", "size": "124M", "description": "GPT-2 base model"},
            {"id": "distilbert-base-uncased", "size": "66M", "description": "DistilBERT for NLP"},
            {"id": "prajjwal1/bert-tiny", "size": "4.4M", "description": "Tiny BERT for testing"},
            {"id": "microsoft/DialoGPT-small", "size": "124M", "description": "Small conversational model"},
        ],
        "medium_models": [
            {"id": "gpt2-medium", "size": "355M", "description": "GPT-2 medium"},
            {"id": "bert-base-uncased", "size": "110M", "description": "BERT base model"},
        ],
        "cleanup_policy": "Non-sample models are deleted after 4 hours of inactivity",
        "note": "Sample models are always available for quick testing"
    }


# Helper functions for other routes
def get_loaded_model():
    """Return the module-level loaded model (or None)."""
    return _loaded_model


def get_layer_weights_tensor(layer_name: str):
    """Return a detached clone of the named layer's weight tensor, or None
    when no model is loaded or the layer has no weight."""
    if _loaded_model is None:
        return None
    for name, module in _loaded_model.named_modules():
        if name == layer_name and hasattr(module, 'weight'):
            return module.weight.data.clone()
    return None
backend/api/routes/quantization.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quantization Routes
3
+ Core quantization API endpoints
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
7
+ from pydantic import BaseModel
8
+ from typing import Optional, Dict, Any, List
9
+ import torch
10
+ import asyncio
11
+ import json
12
+
13
+ from backend.core.quantizer import (
14
+ QuantizationConfig, QuantizationMethod, QuantizationMode,
15
+ INT8Quantizer, INT4Quantizer, NF4Quantizer, get_quantizer
16
+ )
17
+ from backend.core.model_loader import model_loader
18
+ from backend.core.visualization import visualizer
19
+
20
+ router = APIRouter()
21
+
22
+
23
class QuantizeWeightsRequest(BaseModel):
    """Request to quantize custom weights.

    Describes a synthetic (out_features x in_features) matrix to generate
    server-side plus the quantization settings to apply to it. No real
    model is required for this request.
    """
    in_features: int = 64
    out_features: int = 128
    bits: int = 8  # 4 or 8
    method: str = "int8"  # int8, int4, nf4
    mode: str = "symmetric"  # symmetric, asymmetric
    group_size: Optional[int] = None  # None = quantizer default grouping (see QuantizationConfig)
    weight_pattern: str = "random"  # random, eye, ones, alternating, gradient
    dtype: str = "float32"  # float32, float16, bfloat16; unknown values fall back to float32
33
+
34
+
35
class QuantizeLayerRequest(BaseModel):
    """Request to quantize a specific layer from loaded model.

    ``layer_name`` is the dotted module path as reported by the model
    info endpoint; quantization settings mirror QuantizeWeightsRequest.
    """
    layer_name: str
    bits: int = 8  # 4 or 8
    method: str = "int8"  # int8, int4, nf4
    mode: str = "symmetric"  # symmetric, asymmetric
    group_size: Optional[int] = None  # None = quantizer default grouping
42
+
43
+
44
class QuantizeModelRequest(BaseModel):
    """Request to quantize entire model.

    Layer selection: ``layers_to_include`` (when given) restricts the set,
    otherwise all quantizable layers are used; ``layers_to_skip`` is then
    subtracted from that set.
    """
    bits: int = 8  # 4 or 8
    method: str = "int8"  # int8, int4, nf4
    mode: str = "symmetric"  # symmetric, asymmetric
    group_size: Optional[int] = None
    layers_to_skip: List[str] = []  # safe: pydantic copies mutable defaults per instance
    layers_to_include: Optional[List[str]] = None  # None = all quantizable
52
+
53
+
54
+ def _generate_weights(pattern: str, out_features: int, in_features: int,
55
+ dtype: torch.dtype) -> torch.Tensor:
56
+ """Generate weights based on pattern"""
57
+ if pattern == "random":
58
+ return torch.randn((out_features, in_features), dtype=dtype)
59
+ elif pattern == "eye":
60
+ weights = torch.zeros((out_features, in_features), dtype=dtype)
61
+ min_dim = min(out_features, in_features)
62
+ weights[:min_dim, :min_dim] = torch.eye(min_dim, dtype=dtype)
63
+ return weights
64
+ elif pattern == "ones":
65
+ return torch.ones((out_features, in_features), dtype=dtype)
66
+ elif pattern == "alternating":
67
+ weights = torch.ones((out_features, in_features), dtype=dtype)
68
+ for i in range(out_features):
69
+ for j in range(in_features):
70
+ if (i + j) % 2 == 1:
71
+ weights[i, j] = -1.0
72
+ return weights
73
+ elif pattern == "gradient":
74
+ x = torch.linspace(-1, 1, in_features)
75
+ y = torch.linspace(-1, 1, out_features)
76
+ xx, yy = torch.meshgrid(x, y, indexing='ij')
77
+ return (xx + yy).t().to(dtype)
78
+ else:
79
+ return torch.randn((out_features, in_features), dtype=dtype)
80
+
81
+
82
def _get_quantizer_from_config(request) -> tuple:
    """Translate request strings into a QuantizationConfig and its quantizer.

    Unrecognized ``method``/``mode`` strings fall back to INT8 / symmetric.
    Returns (quantizer, config).
    """
    methods = {
        "int8": QuantizationMethod.INT8,
        "int4": QuantizationMethod.INT4,
        "nf4": QuantizationMethod.NF4,
    }
    modes = {
        "symmetric": QuantizationMode.SYMMETRIC,
        "asymmetric": QuantizationMode.ASYMMETRIC,
    }
    config = QuantizationConfig(
        bits=request.bits,
        method=methods.get(request.method, QuantizationMethod.INT8),
        mode=modes.get(request.mode, QuantizationMode.SYMMETRIC),
        group_size=request.group_size,
    )
    return get_quantizer(config), config
103
+
104
+
105
@router.post("/weights")
async def quantize_custom_weights(request: QuantizeWeightsRequest) -> Dict[str, Any]:
    """
    Quantize custom generated weights.
    This endpoint works without loading a real model.

    Generates a weight matrix from the requested pattern, quantizes and
    dequantizes it, and returns error statistics plus heatmap/histogram
    visualizations of each stage.
    """
    # Resolve the requested floating-point dtype (unknown names -> float32).
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(request.dtype, torch.float32)

    # Build the synthetic weight matrix.
    weights = _generate_weights(
        request.weight_pattern, request.out_features, request.in_features, dtype
    )

    # Quantize, then round-trip back to floats for comparison plots.
    quantizer, config = _get_quantizer_from_config(request)
    result = quantizer.quantize(weights)
    dequantized = quantizer.dequantize(result)
    quantized_f32 = result.quantized_weights.float()

    def heatmap(tensor, title):
        return visualizer.to_dict(visualizer.weight_heatmap(tensor, title))

    def histogram(tensor, title):
        return visualizer.to_dict(visualizer.weight_histogram(tensor, title))

    return {
        "success": True,
        "config": config.to_dict(),
        "stats": {
            "original_shape": list(weights.shape),
            "quantized_shape": list(result.quantized_weights.shape),
            "scales_shape": list(result.scales.shape),
            "max_error": result.max_error,
            "mean_error": result.mean_error,
            "memory_savings_percent": result.memory_savings_percent,
            "original_dtype": str(weights.dtype),
            "quantized_dtype": str(result.quantized_weights.dtype),
        },
        "visualizations": {
            "original_heatmap": heatmap(weights, "Original Weights"),
            "quantized_heatmap": heatmap(quantized_f32, f"Quantized Weights ({request.bits}-bit)"),
            "dequantized_heatmap": heatmap(dequantized, "Dequantized Weights"),
            "error_heatmap": heatmap((weights - dequantized).abs(), "Quantization Error"),
            "original_histogram": histogram(weights, "Original Distribution"),
            "quantized_histogram": histogram(quantized_f32, "Quantized Distribution"),
            "scales_histogram": visualizer.to_dict(visualizer.scales_histogram(result.scales)),
        },
    }
182
+
183
+
184
@router.post("/layer")
async def quantize_layer(request: QuantizeLayerRequest) -> Dict[str, Any]:
    """
    Quantize a specific layer from the loaded model.
    Requires a model to be loaded first.

    Returns 400 when no model is loaded, 404 when the layer name does not
    match, otherwise the same stats/visualizations payload shape as the
    /weights endpoint (plus "layer_name").
    """
    if model_loader is None or model_loader.get_model() is None:
        raise HTTPException(
            status_code=400,
            detail="No model loaded. Load a model first or use /quantize/weights for custom weights."
        )

    # Get layer weights
    weights = model_loader.get_layer_weights(request.layer_name)
    if weights is None:
        raise HTTPException(status_code=404, detail=f"Layer not found: {request.layer_name}")

    # Ensure 2D: 1-D tensors become a single row; higher-rank tensors are
    # flattened to (dim0, rest). Heatmaps therefore show the flattened view.
    original_shape = weights.shape
    if len(weights.shape) == 1:
        weights = weights.unsqueeze(0)
    elif len(weights.shape) > 2:
        weights = weights.reshape(weights.shape[0], -1)

    # Get quantizer
    quantizer, config = _get_quantizer_from_config(request)

    # Quantize, then round-trip for error visualization
    result = quantizer.quantize(weights)
    dequantized = quantizer.dequantize(result)

    # Generate Visualizations
    original_hist = visualizer.to_dict(visualizer.weight_histogram(weights, "Original Distribution"))
    quantized_hist = visualizer.to_dict(visualizer.weight_histogram(result.quantized_weights.float(), "Quantized Distribution"))
    scales_hist = visualizer.to_dict(visualizer.scales_histogram(result.scales))

    return {
        "success": True,
        "layer_name": request.layer_name,
        "config": config.to_dict(),
        "stats": {
            # original_shape is the pre-flattening shape; dtype is unchanged
            # by the reshape, so str(weights.dtype) still reports the original.
            "original_shape": list(original_shape),
            "quantized_shape": list(result.quantized_weights.shape),
            "scales_shape": list(result.scales.shape),
            "max_error": result.max_error,
            "mean_error": result.mean_error,
            "memory_savings_percent": result.memory_savings_percent,
            "original_dtype": str(weights.dtype),
            "quantized_dtype": str(result.quantized_weights.dtype)
        },
        "visualizations": {
            "original_heatmap": visualizer.to_dict(
                visualizer.weight_heatmap(weights, f"Original: {request.layer_name}")
            ),
            "quantized_heatmap": visualizer.to_dict(
                visualizer.weight_heatmap(result.quantized_weights.float(), f"Quantized ({request.bits}-bit)")
            ),
            "dequantized_heatmap": visualizer.to_dict(
                visualizer.weight_heatmap(dequantized, "Dequantized Weights")
            ),
            "error_heatmap": visualizer.to_dict(
                visualizer.weight_heatmap((weights - dequantized).abs(), "Error")
            ),
            "original_histogram": original_hist,
            "quantized_histogram": quantized_hist,
            "scales_histogram": scales_hist
        }
    }
252
+
253
+
254
@router.post("/model")
async def quantize_model(request: QuantizeModelRequest) -> Dict[str, Any]:
    """
    Quantize all quantizable layers in the loaded model.
    Returns summary statistics for all layers.

    Layer selection: ``layers_to_include`` (when given) restricts the set,
    otherwise every quantizable layer is used; ``layers_to_skip`` is then
    removed. Per-layer failures are recorded inline ("error" key) instead
    of aborting the whole request.
    """
    if model_loader is None or model_loader.get_model() is None:
        raise HTTPException(
            status_code=400,
            detail="No model loaded. This feature requires a loaded model."
        )

    model_info = model_loader.get_model_info()
    if model_info is None:
        raise HTTPException(status_code=500, detail="Failed to get model info")

    # Determine layers to quantize
    if request.layers_to_include:
        layers_to_quantize = request.layers_to_include
    else:
        layers_to_quantize = model_info.quantizable_layers

    # Remove skipped layers
    layers_to_quantize = [l for l in layers_to_quantize if l not in request.layers_to_skip]

    # Get quantizer
    quantizer, config = _get_quantizer_from_config(request)

    # Quantize each layer, accumulating byte totals for the summary
    results = []
    total_memory_saved = 0
    total_original_size = 0

    for layer_name in layers_to_quantize:
        weights = model_loader.get_layer_weights(layer_name)
        if weights is None:
            # Layer has no weight tensor (or vanished); silently skipped.
            continue

        # Handle non-2D weights: 1-D -> single row, rank>2 -> (dim0, rest)
        original_shape = weights.shape
        if len(weights.shape) == 1:
            weights = weights.unsqueeze(0)
        elif len(weights.shape) > 2:
            weights = weights.reshape(weights.shape[0], -1)

        try:
            result = quantizer.quantize(weights)

            # Savings derived from the quantizer-reported percentage applied
            # to the pre-quantization byte size of this layer.
            original_bytes = weights.numel() * weights.element_size()
            total_original_size += original_bytes
            total_memory_saved += original_bytes * (result.memory_savings_percent / 100)

            results.append({
                "layer": layer_name,
                "shape": list(original_shape),
                "max_error": result.max_error,
                "mean_error": result.mean_error,
                "memory_savings_percent": result.memory_savings_percent
            })
        except Exception as e:
            # Record the failure for this layer and keep going.
            results.append({
                "layer": layer_name,
                "error": str(e)
            })

    return {
        "success": True,
        "config": config.to_dict(),
        "summary": {
            "layers_quantized": len([r for r in results if "error" not in r]),
            "layers_failed": len([r for r in results if "error" in r]),
            "total_memory_saved_mb": total_memory_saved / (1024 * 1024),
            "average_memory_savings_percent": (total_memory_saved / total_original_size * 100) if total_original_size > 0 else 0
        },
        "layers": results
    }
330
+
331
+
332
+ # WebSocket for real-time progress
333
@router.websocket("/stream")
async def quantization_stream(websocket: WebSocket):
    """WebSocket endpoint for streaming quantization progress.

    Protocol: client sends a JSON request; server replies with a series of
    {"type": "progress", "progress": 0-100, "message": ...} frames followed
    by a final {"type": "complete", ...} frame, then waits for the next
    request on the same connection.

    NOTE(review): the progress loop below is simulated — no quantization is
    performed here yet, and the parsed request is currently unused.
    """
    await websocket.accept()

    try:
        while True:
            # Receive quantization request. json.loads doubles as payload
            # validation: a malformed frame raises and closes the socket.
            data = await websocket.receive_text()
            request_data = json.loads(data)

            # Process and send updates
            await websocket.send_json({
                "type": "progress",
                "progress": 0,
                "message": "Starting quantization..."
            })

            # Simulate progress (in real implementation, this would be actual quantization)
            for i in range(0, 101, 10):
                await asyncio.sleep(0.1)
                await websocket.send_json({
                    "type": "progress",
                    "progress": i,
                    "message": f"Processing... {i}%"
                })

            await websocket.send_json({
                "type": "complete",
                "message": "Quantization complete"
            })

    except WebSocketDisconnect:
        # Client hung up; nothing to clean up.
        pass
backend/api/routes/system.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ System Routes
3
+ Hardware detection and system information
4
+ """
5
+
6
+ from fastapi import APIRouter
7
+ from typing import Dict, Any
8
+
9
+ from backend.core.system_checker import system_checker, check_model_requirements
10
+
11
+ router = APIRouter()
12
+
13
+
14
@router.get("/info")
async def get_system_info() -> Dict[str, Any]:
    """
    Get complete system information including GPU, RAM, and capabilities.

    Thin wrapper over system_checker.to_dict(); presumably serves cached
    detection results — use GET /refresh to force re-detection.
    """
    return system_checker.to_dict()
20
+
21
+
22
@router.get("/capabilities")
async def get_capabilities() -> Dict[str, Any]:
    """
    Get system capabilities for quantization tasks: capability tier,
    recommended sizes, and available accelerators.
    """
    info = system_checker.check()
    gpu_summaries = [
        {"name": gpu.name, "memory_gb": gpu.total_memory_gb}
        for gpu in info.gpus
    ]
    return {
        "capability": info.capability.value,
        "recommended_batch_size": info.recommended_batch_size,
        "max_model_size": info.max_model_size,
        "cuda_available": info.cuda_available,
        "mps_available": info.mps_available,
        "gpus": gpu_summaries,
    }
42
+
43
+
44
@router.post("/check-model")
async def check_model_requirements_endpoint(
    model_params_billions: float,
    dtype: str = "fp16"
) -> Dict[str, Any]:
    """
    Check if system can handle a model of specified size.

    Pure delegation to core.system_checker.check_model_requirements.
    (Scalar parameters on a POST route are read from the query string
    by FastAPI, not from a JSON body.)

    Args:
        model_params_billions: Model size in billions of parameters
        dtype: Data type (fp32, fp16, int8, int4)
    """
    return check_model_requirements(model_params_billions, dtype)
57
+
58
+
59
@router.get("/refresh")
async def refresh_system_info() -> Dict[str, Any]:
    """
    Force refresh system information.

    Returns the same serialized schema as GET /info. The previous
    implementation returned the raw dataclass ``__dict__``, which leaks
    nested dataclass/Enum objects and diverges from the /info payload.
    """
    # Re-run detection, then serialize through the same path as /info.
    # NOTE(review): assumes to_dict() reflects the result of the preceding
    # check() call — confirm system_checker caches its last check.
    system_checker.check(force_refresh=True)
    return system_checker.to_dict()
backend/core/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Neural Network Quantizer - Backend Core Package
3
+ Multi-bit quantization engine supporting 4-bit, 8-bit, NF4, and GPTQ methods.
4
+ """
5
+
6
+ __version__ = "1.0.0"
backend/core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (357 Bytes). View file
 
backend/core/__pycache__/model_loader.cpython-312.pyc ADDED
Binary file (16.1 kB). View file
 
backend/core/__pycache__/model_manager.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
backend/core/__pycache__/quantizer.cpython-312.pyc ADDED
Binary file (30.2 kB). View file
 
backend/core/__pycache__/system_checker.cpython-312.pyc ADDED
Binary file (11.2 kB). View file
 
backend/core/__pycache__/visualization.cpython-312.pyc ADDED
Binary file (13.1 kB). View file
 
backend/core/model_loader.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Model Loader
3
+ Loads models from HuggingFace Hub or local files with memory-efficient options.
4
+ """
5
+
6
+ import torch
7
+ import gc
8
+ from pathlib import Path
9
+ from typing import Optional, Dict, Any, List, Tuple, Union, TYPE_CHECKING
10
+ from dataclasses import dataclass
11
+ from enum import Enum
12
+
13
+ try:
14
+ from transformers import (
15
+ AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification,
16
+ AutoTokenizer, AutoConfig
17
+ )
18
+ HAS_TRANSFORMERS = True
19
+ except ImportError:
20
+ HAS_TRANSFORMERS = False
21
+
22
+ if TYPE_CHECKING:
23
+ from transformers import PreTrainedModel
24
+
25
+ from .system_checker import system_checker, check_model_requirements
26
+
27
+
28
class ModelType(Enum):
    """Supported model types (selects the AutoModel* class used in ModelLoader.load)"""
    CAUSAL_LM = "causal_lm"                              # AutoModelForCausalLM
    SEQUENCE_CLASSIFICATION = "sequence_classification"  # AutoModelForSequenceClassification
    GENERIC = "generic"                                  # AutoModel
33
+
34
+
35
@dataclass
class LayerInfo:
    """Information about a model layer (one entry per module with parameters)"""
    name: str                         # dotted module path from named_modules()
    module_type: str                  # module class name, e.g. "Linear"
    shape: Optional[Tuple[int, ...]]  # weight shape; None if the module has no weight
    num_params: int                   # weight + bias element count
    dtype: str                        # weight dtype as a string, or "N/A"
    is_quantizable: bool              # module_type matched ModelLoader.QUANTIZABLE_TYPES
44
+
45
+
46
@dataclass
class ModelInfo:
    """Complete model information, produced by ModelLoader._analyze_model"""
    name: str                      # model id or local path the model was loaded from
    model_type: ModelType
    architecture: str              # config.architectures[0], or "Unknown"
    num_params: int                # summed over modules that carry parameters
    num_params_billions: float
    hidden_size: int               # from config (default 768 if absent)
    num_layers: int                # from config (default 12 if absent)
    vocab_size: Optional[int]
    dtype: str                     # dtype of the model's first parameter
    memory_footprint_gb: float     # parameter bytes total, rounded to 2 decimals
    layers: List[LayerInfo]
    quantizable_layers: List[str]  # names of layers eligible for quantization
61
+
62
+
63
class ModelLoader:
    """
    Load and inspect HuggingFace models with memory-efficient options.
    Provides layer-by-layer analysis for selective quantization.

    Holds at most one model at a time; loading a new model unloads the
    previous one first.
    """

    # Layer types that can be quantized. Matched by case-sensitive substring
    # against the module's class name in _analyze_model.
    QUANTIZABLE_TYPES = (
        "Linear",
        "Conv1d",
        "Conv2d",
        "Embedding"
    )

    def __init__(self):
        # Fail fast when transformers is missing; the module-level singleton
        # below is set to None in that case instead of raising at import time.
        if not HAS_TRANSFORMERS:
            raise ImportError(
                "transformers library not installed. "
                "Install with: pip install transformers"
            )
        self._loaded_model = None  # Optional[PreTrainedModel]
        self._model_info: Optional[ModelInfo] = None
        self._tokenizer = None

    def check_requirements(self, model_name: str, dtype: str = "fp16") -> Dict[str, Any]:
        """Check if system can load the model before attempting.

        Fetches only the config (not the weights); estimates the parameter
        count from hidden size / layers / vocab when the config does not
        state it. Errors are reported in the return dict, not raised.
        """
        try:
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

            # Estimate parameters
            if hasattr(config, 'num_parameters'):
                num_params = config.num_parameters
            else:
                # Estimate from config
                hidden = getattr(config, 'hidden_size', 768)
                layers = getattr(config, 'num_hidden_layers', 12)
                vocab = getattr(config, 'vocab_size', 30000)
                num_params = self._estimate_params(hidden, layers, vocab)

            params_billions = num_params / 1e9
            return check_model_requirements(params_billions, dtype)

        except Exception as e:
            return {
                "can_load": False,
                "error": str(e),
                "warnings": [f"Failed to fetch model config: {str(e)}"]
            }

    def _estimate_params(self, hidden: int, layers: int, vocab: int) -> int:
        """Estimate parameter count from config"""
        # Rough estimate: embeddings + transformer layers
        embedding_params = vocab * hidden
        # Each layer: attention (4 * hidden^2) + FFN (8 * hidden^2)
        layer_params = layers * (12 * hidden * hidden)
        return embedding_params + layer_params

    def load(self, model_name: str,
             model_type: ModelType = ModelType.GENERIC,
             dtype: str = "auto",
             device: str = "auto",
             trust_remote_code: bool = True,
             low_memory: bool = False) -> Tuple[Any, Optional[Any]]:
        """
        Load a model from HuggingFace Hub or local path.

        Args:
            model_name: HuggingFace model ID or local path
            model_type: Type of model to load
            dtype: Data type ("auto", "fp32", "fp16", "bf16")
            device: Device to load to ("auto", "cuda", "cpu", "mps")
            trust_remote_code: Allow custom code from model repos
            low_memory: Use memory-efficient loading

        Returns:
            Tuple of (model, tokenizer) -- tokenizer is None when its load fails.
        """
        # Clear previous model
        self.unload()

        # Determine device: prefer CUDA, then Apple MPS, then CPU.
        if device == "auto":
            sys_info = system_checker.check()
            if sys_info.cuda_available:
                device = "cuda"
            elif sys_info.mps_available:
                device = "mps"
            else:
                device = "cpu"

        # Determine dtype
        if dtype == "auto":
            if device == "cuda":
                dtype = "fp16"
            elif device == "mps":
                dtype = "fp32"  # MPS has limited bf16 support
            else:
                dtype = "fp32"

        torch_dtype = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16
        }.get(dtype, torch.float32)

        # Load config first
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)

        # Select model class
        if model_type == ModelType.CAUSAL_LM:
            model_class = AutoModelForCausalLM
        elif model_type == ModelType.SEQUENCE_CLASSIFICATION:
            model_class = AutoModelForSequenceClassification
        else:
            model_class = AutoModel

        # Load model
        load_kwargs = {
            "pretrained_model_name_or_path": model_name,
            "torch_dtype": torch_dtype,
            "trust_remote_code": trust_remote_code,
        }

        if low_memory:
            load_kwargs["low_cpu_mem_usage"] = True
            if device == "cuda":
                # device_map="auto" lets accelerate place weights; in that
                # case we must not .to(device) the model again below.
                load_kwargs["device_map"] = "auto"

        model = model_class.from_pretrained(**load_kwargs)

        if not low_memory and device != "cpu":
            model = model.to(device)

        model.eval()

        # Load tokenizer; some checkpoints ship without one, so failure is
        # tolerated and the tokenizer stays None.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=trust_remote_code
            )
        except Exception:
            tokenizer = None

        self._loaded_model = model
        self._tokenizer = tokenizer
        self._model_info = self._analyze_model(model, model_name, model_type)

        return model, tokenizer

    def load_weights_only(self, model_name: str) -> Dict[str, torch.Tensor]:
        """
        Load only the state dict without instantiating the model.
        More memory efficient for inspection.

        Tries the single-file safetensors checkpoint first, then falls back
        to pytorch_model.bin. Sharded checkpoints are not handled here.
        """
        from safetensors import safe_open
        from huggingface_hub import hf_hub_download

        try:
            # Try safetensors first
            path = hf_hub_download(model_name, "model.safetensors")
            weights = {}
            with safe_open(path, framework="pt") as f:
                for key in f.keys():
                    weights[key] = f.get_tensor(key)
            return weights
        except Exception:
            # Fallback to torch
            # NOTE(review): torch.load on a downloaded .bin unpickles
            # arbitrary data from the hub; consider weights_only=True where
            # the installed torch supports it.
            try:
                path = hf_hub_download(model_name, "pytorch_model.bin")
                return torch.load(path, map_location="cpu")
            except Exception as e:
                raise RuntimeError(f"Failed to load weights: {str(e)}")

    def _analyze_model(self, model: Any, name: str,
                       model_type: ModelType) -> ModelInfo:
        """Analyze model structure and extract layer information.

        Walks named_modules(), recording every module that directly owns a
        weight tensor. NOTE(review): parameters tied across two modules
        (e.g. shared input/output embeddings) are counted once per module
        here, while the memory figure below uses model.parameters() which
        de-duplicates -- the two totals can disagree for tied models.
        """
        layers = []
        quantizable_layers = []
        total_params = 0

        for layer_name, module in model.named_modules():
            if not layer_name:
                # Skip the root module itself (empty name).
                continue

            # Get module info
            module_type = module.__class__.__name__

            # Check if quantizable
            is_quantizable = any(
                qt in module_type for qt in self.QUANTIZABLE_TYPES
            )

            # Get shape and params for leaf modules
            shape = None
            num_params = 0
            dtype = "N/A"

            if hasattr(module, 'weight') and module.weight is not None:
                shape = tuple(module.weight.shape)
                num_params = module.weight.numel()
                dtype = str(module.weight.dtype)
                if hasattr(module, 'bias') and module.bias is not None:
                    num_params += module.bias.numel()

            if num_params > 0:
                total_params += num_params
                layers.append(LayerInfo(
                    name=layer_name,
                    module_type=module_type,
                    shape=shape,
                    num_params=num_params,
                    dtype=dtype,
                    is_quantizable=is_quantizable
                ))

                if is_quantizable:
                    quantizable_layers.append(layer_name)

        # Get config info
        config = model.config
        hidden_size = getattr(config, 'hidden_size', 768)
        num_layers = getattr(config, 'num_hidden_layers', 12)
        vocab_size = getattr(config, 'vocab_size', None)

        # Calculate memory
        memory_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)

        return ModelInfo(
            name=name,
            model_type=model_type,
            architecture=config.architectures[0] if hasattr(config, 'architectures') and config.architectures else "Unknown",
            num_params=total_params,
            num_params_billions=total_params / 1e9,
            hidden_size=hidden_size,
            num_layers=num_layers,
            vocab_size=vocab_size,
            dtype=str(next(model.parameters()).dtype),
            memory_footprint_gb=round(memory_gb, 2),
            layers=layers,
            quantizable_layers=quantizable_layers
        )

    def register_model(self, model: Any, name: str, tokenizer: Any = None):
        """Register an externally loaded model (replaces any current model info)"""
        self._loaded_model = model
        self._tokenizer = tokenizer
        self._model_info = self._analyze_model(model, name, ModelType.GENERIC)

    def get_layer_weights(self, layer_name: str) -> Optional[torch.Tensor]:
        """Get weights from a specific layer.

        Returns a clone of the weight tensor, or None when the name does not
        match any module. Raises RuntimeError if no model is loaded.
        """
        if self._loaded_model is None:
            raise RuntimeError("No model loaded")

        for name, module in self._loaded_model.named_modules():
            if name == layer_name:
                if hasattr(module, 'weight'):
                    return module.weight.data.clone()
        return None

    def set_layer_weights(self, layer_name: str, weights: torch.Tensor):
        """Set weights for a specific layer.

        Moves the tensor to the layer's current device. Raises RuntimeError
        if no model is loaded, ValueError if the layer name does not match.
        """
        if self._loaded_model is None:
            raise RuntimeError("No model loaded")

        for name, module in self._loaded_model.named_modules():
            if name == layer_name:
                if hasattr(module, 'weight'):
                    module.weight.data = weights.to(module.weight.device)
                return
        raise ValueError(f"Layer not found: {layer_name}")

    def get_model_info(self) -> Optional[ModelInfo]:
        """Get current model information"""
        return self._model_info

    def get_model(self) -> Optional[Any]:
        """Get loaded model"""
        return self._loaded_model

    def get_tokenizer(self):
        """Get loaded tokenizer"""
        return self._tokenizer

    def unload(self):
        """Unload model and free memory"""
        if self._loaded_model is not None:
            del self._loaded_model
            self._loaded_model = None

        if self._tokenizer is not None:
            del self._tokenizer
            self._tokenizer = None

        self._model_info = None

        # Force garbage collection, then release cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def to_dict(self) -> Optional[Dict[str, Any]]:
        """Convert model info to dictionary (None when no model is loaded)"""
        if self._model_info is None:
            return None

        info = self._model_info
        return {
            "name": info.name,
            "model_type": info.model_type.value,
            "architecture": info.architecture,
            "num_params": info.num_params,
            "num_params_billions": round(info.num_params_billions, 3),
            "hidden_size": info.hidden_size,
            "num_layers": info.num_layers,
            "vocab_size": info.vocab_size,
            "dtype": info.dtype,
            "memory_footprint_gb": info.memory_footprint_gb,
            "num_quantizable_layers": len(info.quantizable_layers),
            "quantizable_layers": info.quantizable_layers,
            "layers": [
                {
                    "name": layer.name,
                    "module_type": layer.module_type,
                    "shape": layer.shape,
                    "num_params": layer.num_params,
                    "dtype": layer.dtype,
                    "is_quantizable": layer.is_quantizable
                }
                for layer in info.layers
            ]
        }
394
+
395
+
396
# Global instance
# None when transformers is not installed, so importers can feature-detect
# with "model_loader is None".
model_loader = ModelLoader() if HAS_TRANSFORMERS else None
398
+
399
+
400
def load_model(model_name: str, **kwargs) -> Tuple[Any, Any]:
    """Convenience function to load a model.

    Delegates to the global ModelLoader; raises ImportError when the
    transformers package (and thus the loader singleton) is unavailable.
    """
    if model_loader is None:
        raise ImportError("transformers not available")
    return model_loader.load(model_name, **kwargs)
405
+
406
+
407
def get_model_info() -> Optional[Dict[str, Any]]:
    """Get current model information.

    Returns the loaded model's info dict, or None when transformers is
    unavailable or no model has been loaded yet.
    """
    if model_loader is None:
        return None
    return model_loader.to_dict()
backend/core/model_manager.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Manager with Download Progress, Caching, and Auto-Cleanup
3
+ Designed to work with HuggingFace Spaces disk storage
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import shutil
9
+ import asyncio
10
+ import threading
11
+ from pathlib import Path
12
+ from typing import Optional, Dict, Any, Callable
13
+ from dataclasses import dataclass
14
+ from datetime import datetime, timedelta
15
+
16
# HuggingFace cache directory - works on Spaces
# (HF_HOME may be a str from the environment or a Path from the fallback;
# Path() below accepts either.)
HF_CACHE_DIR = os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")
MODEL_CACHE_DIR = Path(HF_CACHE_DIR) / "hub"

# Sample models that should always be available (tiny models for quick testing).
# These are exempt from auto-cleanup.
SAMPLE_MODELS = [
    "prajjwal1/bert-tiny",  # 4.4MB - Perfect for testing
]

# Auto-cleanup interval (4 hours)
CLEANUP_INTERVAL_HOURS = 4

# Track download progress
# _download_progress: model name -> latest progress snapshot (dict with "timestamp")
# _cleanup_thread: background cleanup thread handle, once started
_download_progress: Dict[str, Dict[str, Any]] = {}
_cleanup_thread: Optional[threading.Thread] = None
31
+
32
+
33
@dataclass
class DownloadProgress:
    """Track download progress for a model"""
    model_name: str
    status: str  # "pending", "downloading", "extracting", "complete", "error"
    current_file: str     # file currently being transferred
    files_completed: int
    total_files: int
    bytes_downloaded: int
    total_bytes: int
    speed_mbps: float     # transfer rate; exact unit (Mb vs MB per s) set by the producer
    eta_seconds: int
    error: Optional[str] = None  # populated when status == "error"
46
+
47
+
48
def get_download_progress(model_name: str) -> Optional[Dict[str, Any]]:
    """Get current download progress for a model.

    Returns the latest snapshot stored by set_download_progress (including
    its "timestamp" key), or None when no download is tracked for the model.
    """
    return _download_progress.get(model_name)
51
+
52
+
53
def set_download_progress(model_name: str, progress: Dict[str, Any]):
    """Record the latest progress snapshot for a model, stamped with the current time."""
    entry = dict(progress)
    entry["timestamp"] = time.time()
    _download_progress[model_name] = entry
59
+
60
+
61
def clear_download_progress(model_name: str):
    """Drop the progress record for a model; a no-op when none is tracked."""
    _download_progress.pop(model_name, None)
65
+
66
+
67
def get_cached_models() -> list:
    """Get list of models currently in cache.

    Scans the HuggingFace hub cache for "models--{org}--{name}" directories
    and returns, per model: name, path, size in MB, last-access time (ISO
    format), and whether it is a protected sample model.
    """
    cached = []

    if not MODEL_CACHE_DIR.exists():
        return cached

    for item in MODEL_CACHE_DIR.iterdir():
        if item.is_dir() and item.name.startswith("models--"):
            # Parse "models--org--name" back into "org/name"
            parts = item.name.replace("models--", "").split("--")
            if len(parts) >= 2:
                model_name = f"{parts[0]}/{parts[1]}"
            else:
                model_name = parts[0]

            # Total on-disk size of all files under the model directory
            size_mb = sum(f.stat().st_size for f in item.rglob("*") if f.is_file()) / (1024 * 1024)

            # Last access time, falling back to "now" if stat fails.
            # Narrowed from a bare "except:" so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            try:
                last_access = item.stat().st_atime
            except OSError:
                last_access = time.time()

            cached.append({
                "name": model_name,
                "path": str(item),
                "size_mb": round(size_mb, 2),
                "last_access": datetime.fromtimestamp(last_access).isoformat(),
                "is_sample": model_name in SAMPLE_MODELS
            })

    return cached
101
+
102
+
103
def cleanup_old_models(max_age_hours: float = CLEANUP_INTERVAL_HOURS):
    """
    Delete cached models whose last access is older than ``max_age_hours``.
    Sample models are always preserved.

    Returns:
        Dict with "deleted" and "kept" lists of model names.
    """
    if not MODEL_CACHE_DIR.exists():
        return {"deleted": [], "kept": []}

    deleted, kept = [], []
    cutoff_time = time.time() - max_age_hours * 3600

    for item in MODEL_CACHE_DIR.iterdir():
        if not (item.is_dir() and item.name.startswith("models--")):
            continue

        # Recover the "org/name" repo id from the cache directory name.
        parts = item.name.replace("models--", "").split("--")
        model_name = parts[0] if len(parts) < 2 else f"{parts[0]}/{parts[1]}"

        # The bundled sample models must survive every cleanup pass.
        if model_name in SAMPLE_MODELS:
            kept.append(model_name)
            continue

        try:
            # NOTE(review): relies on atime, which some filesystems mount
            # with noatime -- confirm this tracks real usage in deployment.
            if item.stat().st_atime < cutoff_time:
                shutil.rmtree(item)
                deleted.append(model_name)
            else:
                kept.append(model_name)
        except Exception as e:
            # Stat/delete failures are recorded but never abort the sweep.
            kept.append(f"{model_name} (error: {str(e)})")

    return {"deleted": deleted, "kept": kept}
141
+
142
+
143
def delete_model_cache(model_name: str) -> bool:
    """Remove a single model's directory from the HF hub cache.

    Args:
        model_name: HuggingFace repo id, e.g. "org/model".

    Returns:
        True if the cache directory existed and was deleted; False if the
        model is a protected sample model, is not cached, or deletion failed.
    """
    if model_name in SAMPLE_MODELS:
        return False  # Don't delete sample models

    # Hub cache layout stores "org/model" as "models--org--model".
    cache_name = f"models--{model_name.replace('/', '--')}"
    cache_path = MODEL_CACHE_DIR / cache_name

    if cache_path.exists():
        try:
            shutil.rmtree(cache_path)
            return True
        except OSError:
            # Narrowed from a bare except: shutil.rmtree raises OSError on
            # filesystem failures; KeyboardInterrupt/SystemExit must propagate.
            return False
    return False
159
+
160
+
161
def ensure_sample_models():
    """
    Ensure sample models are downloaded.
    Called on startup to pre-download tiny test models.

    Only the config is fetched (fast, tiny file); a missing transformers
    install is tolerated so the rest of the service can still run.
    """
    try:
        # Fix: AutoModel was imported here but never used; loading only
        # AutoConfig is all that's needed to confirm availability.
        from transformers import AutoConfig

        for model_name in SAMPLE_MODELS:
            try:
                # Just load config first (fast); this also triggers the
                # download into the cache when the model is missing.
                AutoConfig.from_pretrained(model_name)
                print(f"[ModelManager] Sample model '{model_name}' is available")
            except Exception as e:
                print(f"[ModelManager] Sample model '{model_name}' not cached: {e}")
    except ImportError:
        print("[ModelManager] transformers not installed, skipping sample model check")
178
+
179
+
180
def start_cleanup_scheduler():
    """Start the background daemon thread that periodically purges stale models."""
    global _cleanup_thread

    # Idempotent: never spawn a second scheduler while one is alive.
    if _cleanup_thread is not None and _cleanup_thread.is_alive():
        return

    def _run_forever():
        while True:
            # Sleep first so startup is never delayed by a cleanup pass.
            time.sleep(CLEANUP_INTERVAL_HOURS * 3600)
            try:
                outcome = cleanup_old_models()
                if outcome["deleted"]:
                    print(f"[ModelManager] Cleaned up models: {outcome['deleted']}")
            except Exception as e:
                # The scheduler must outlive any single failed sweep.
                print(f"[ModelManager] Cleanup error: {e}")

    _cleanup_thread = threading.Thread(target=_run_forever, daemon=True)
    _cleanup_thread.start()
    print(f"[ModelManager] Cleanup scheduler started (every {CLEANUP_INTERVAL_HOURS} hours)")
200
+
201
+
202
def get_cache_stats() -> Dict[str, Any]:
    """Summarize the model cache: counts, total size, and per-model entries."""
    models = get_cached_models()
    sample_count = sum(1 for entry in models if entry["is_sample"])
    total_size = sum(entry["size_mb"] for entry in models)

    stats = {
        "cache_dir": str(MODEL_CACHE_DIR),
        "total_models": len(models),
        "sample_models": sample_count,
        "total_size_mb": round(total_size, 2),
        "cleanup_interval_hours": CLEANUP_INTERVAL_HOURS,
        "models": models,
    }
    return stats
216
+
217
+
218
+ # Progress callback for HuggingFace downloads
219
class DownloadProgressCallback:
    """Progress hook for HuggingFace downloads.

    Instances are callable with (bytes_so_far, total_bytes, filename) and
    publish a throttled progress dict through set_download_progress().
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.start_time = time.time()
        self.last_update = 0

    def __call__(self, current: int, total: int, filename: str = ""):
        now = time.time()

        # Publish at most twice per second to keep polling traffic low.
        if now - self.last_update < 0.5:
            return
        self.last_update = now

        elapsed = now - self.start_time
        if elapsed > 0:
            speed = current / elapsed  # bytes/sec averaged since start
        else:
            speed = 0
        eta = int((total - current) / speed) if speed > 0 else 0
        percent = round(100 * current / total, 1) if total > 0 else 0

        # NOTE(review): the "speed_mbps" value is actually MiB/s (bytes / 2^20),
        # not megabits -- key name kept for API compatibility.
        set_download_progress(self.model_name, {
            "status": "downloading",
            "current_file": filename,
            "bytes_downloaded": current,
            "total_bytes": total,
            "percent": percent,
            "speed_mbps": round(speed / (1024 * 1024), 2),
            "eta_seconds": eta,
        })
backend/core/quantizer.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-bit Weight Quantization Engine
3
+ Supports INT8, INT4, NF4, and GPTQ-style quantization methods.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from typing import Optional, Tuple, Dict, Any, Literal
11
+ from dataclasses import dataclass
12
+ from enum import Enum
13
+
14
+
15
class QuantizationMethod(Enum):
    """Supported quantization methods.

    The value strings are the serialized form used by
    QuantizationConfig.to_dict() and API payloads.
    """
    INT8 = "int8"  # 8-bit integer quantization (per-channel or grouped)
    INT4 = "int4"  # 4-bit integer quantization (grouped, packed two per byte)
    NF4 = "nf4"    # Normal Float 4-bit (QLoRA style codebook)
    GPTQ = "gptq"  # GPTQ reconstruction-based (no quantizer class in this module; get_quantizer raises for it)
21
+
22
+
23
class QuantizationMode(Enum):
    """How the quantization range is anchored."""
    SYMMETRIC = "symmetric"    # Range: [-max, max]; no zero point needed
    ASYMMETRIC = "asymmetric"  # Range: [min, max]; uses a zero point
27
+
28
+
29
@dataclass
class QuantizationConfig:
    """Configuration for quantization process.

    Note: ``bits`` is informational; the effective bit width is set by
    ``method`` (INT8 -> 8, INT4/NF4 -> 4).
    """
    bits: int = 8                                        # nominal bit width
    method: QuantizationMethod = QuantizationMethod.INT8
    mode: QuantizationMode = QuantizationMode.SYMMETRIC
    group_size: Optional[int] = None  # None = per-channel, else group quantization
    use_double_quant: bool = False    # Double quantization for scales (not consumed by the quantizers in this module)
    compute_dtype: torch.dtype = torch.float32

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums and dtype become strings)."""
        return {
            "bits": self.bits,
            "method": self.method.value,
            "mode": self.mode.value,
            "group_size": self.group_size,
            "use_double_quant": self.use_double_quant,
            "compute_dtype": str(self.compute_dtype)
        }
48
+
49
+
50
@dataclass
class QuantizationResult:
    """Result of quantization operation."""
    quantized_weights: torch.Tensor      # int8 storage; for INT4/NF4 two 4-bit values are packed per byte
    scales: torch.Tensor                 # per-channel or per-group scale factors
    zero_points: Optional[torch.Tensor]  # None for symmetric and NF4 quantization
    original_shape: Tuple[int, ...]      # weight shape before padding/packing
    config: QuantizationConfig           # configuration used to produce this result
    max_error: float                     # max abs reconstruction error vs. the original weights
    mean_error: float                    # mean abs reconstruction error
    memory_savings_percent: float        # size reduction including scale overhead
61
+
62
+
63
class BaseQuantizer:
    """Common interface plus shared error/size metrics for all quantizers."""

    def __init__(self, config: QuantizationConfig):
        self.config = config

    def quantize(self, weights: torch.Tensor) -> QuantizationResult:
        """Quantize a 2-D weight tensor; implemented by subclasses."""
        raise NotImplementedError

    def dequantize(self, result: QuantizationResult) -> torch.Tensor:
        """Reconstruct floating-point weights from a result; implemented by subclasses."""
        raise NotImplementedError

    def _calculate_error(self, original: torch.Tensor, dequantized: torch.Tensor) -> Tuple[float, float]:
        """Return (max abs error, mean abs error) between original and reconstruction."""
        diff = torch.abs(original - dequantized)
        return float(diff.max()), float(diff.mean())

    def _calculate_memory_savings(self, original: torch.Tensor, quantized: torch.Tensor,
                                  scales: torch.Tensor) -> float:
        """Percent size reduction of (quantized + scales) relative to the original tensor."""
        def nbytes(t: torch.Tensor) -> int:
            return t.numel() * t.element_size()

        compressed = nbytes(quantized) + nbytes(scales)
        return 100 * (1 - compressed / nbytes(original))
90
+
91
+
92
class INT8Quantizer(BaseQuantizer):
    """8-bit integer quantization (W8A16).

    Supports per-channel symmetric/asymmetric quantization and group-wise
    symmetric quantization (when config.group_size is set).
    """

    def __init__(self, config: Optional[QuantizationConfig] = None):
        if config is None:
            config = QuantizationConfig(bits=8, method=QuantizationMethod.INT8)
        super().__init__(config)

    def quantize(self, weights: torch.Tensor) -> QuantizationResult:
        """
        Quantize weights to INT8 precision.

        Args:
            weights: Tensor of shape (out_features, in_features)

        Returns:
            QuantizationResult with int8 weights and scales
        """
        original_shape = weights.shape
        w_fp32 = weights.clone().to(torch.float32)

        if self.config.group_size is not None:
            # Group quantization (symmetric, per-group scales)
            return self._quantize_grouped(w_fp32, original_shape)

        # Per-channel quantization
        if self.config.mode == QuantizationMode.SYMMETRIC:
            # Symmetric: scale = max(|w|) / 127, no zero point
            scales = w_fp32.abs().max(dim=-1).values / 127
            scales = scales.clamp(min=1e-8)  # Avoid division by zero
            zero_points = None

            int8_weights = torch.round(w_fp32 / scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)
        else:
            # Asymmetric: map [w_min, w_max] onto the signed range [-128, 127].
            #
            # Fix: the previous implementation computed q = round(w/s + zp)
            # with zp in [0, 255], which lands in [0, 255] and was then
            # clamped to [-128, 127] -- saturating the upper half of the
            # range and corrupting dequantization. The range is now extended
            # to include 0 (standard affine-quantization nudging) and the
            # zero point is kept signed, so q always fits in int8 and the
            # dequant formula (q - zp) * s is exact.
            w_min = w_fp32.min(dim=-1).values.clamp(max=0.0)
            w_max = w_fp32.max(dim=-1).values.clamp(min=0.0)

            scales = (w_max - w_min) / 255
            scales = scales.clamp(min=1e-8)
            # Signed zero point in [-128, 127]: q = w/s + zp, w = (q - zp) * s
            zero_points = (torch.round(-w_min / scales) - 128).clamp(-128, 127).to(torch.int32)

            int8_weights = torch.round(w_fp32 / scales.unsqueeze(1) + zero_points.unsqueeze(1).float())
            int8_weights = int8_weights.clamp(-128, 127).to(torch.int8)

        # Round-trip to measure quantization error
        dequantized = self.dequantize_weights(int8_weights, scales, zero_points)
        max_error, mean_error = self._calculate_error(weights, dequantized)
        memory_savings = self._calculate_memory_savings(weights, int8_weights, scales)

        return QuantizationResult(
            quantized_weights=int8_weights,
            scales=scales,
            zero_points=zero_points,
            original_shape=original_shape,
            config=self.config,
            max_error=max_error,
            mean_error=mean_error,
            memory_savings_percent=memory_savings
        )

    def _quantize_grouped(self, weights: torch.Tensor, original_shape: Tuple[int, ...]) -> QuantizationResult:
        """Quantize with group-wise (symmetric) scaling.

        Args:
            weights: float32 tensor (out_features, in_features).
            original_shape: shape before padding, used to trim the result.
        """
        out_features, in_features = weights.shape
        group_size = self.config.group_size

        # Pad the input dimension up to a multiple of group_size
        if in_features % group_size != 0:
            pad_size = group_size - (in_features % group_size)
            weights = F.pad(weights, (0, pad_size))
            in_features = weights.shape[1]

        # Reshape so each group gets its own scale
        num_groups = in_features // group_size
        weights_grouped = weights.reshape(out_features, num_groups, group_size)

        scales = weights_grouped.abs().max(dim=-1).values / 127
        scales = scales.clamp(min=1e-8)

        int8_weights = torch.round(weights_grouped / scales.unsqueeze(-1))
        int8_weights = int8_weights.clamp(-128, 127).to(torch.int8)
        int8_weights = int8_weights.reshape(out_features, in_features)

        # Trim padding back off
        int8_weights = int8_weights[:, :original_shape[1]]
        scales = scales.reshape(out_features, num_groups)

        # Round-trip to measure quantization error
        dequantized = self.dequantize_weights(int8_weights, scales, None, group_size)
        max_error, mean_error = self._calculate_error(
            weights[:, :original_shape[1]], dequantized
        )
        memory_savings = self._calculate_memory_savings(
            weights[:, :original_shape[1]], int8_weights, scales
        )

        return QuantizationResult(
            quantized_weights=int8_weights,
            scales=scales,
            zero_points=None,
            original_shape=original_shape,
            config=self.config,
            max_error=max_error,
            mean_error=mean_error,
            memory_savings_percent=memory_savings
        )

    def dequantize_weights(self, int8_weights: torch.Tensor, scales: torch.Tensor,
                           zero_points: Optional[torch.Tensor] = None,
                           group_size: Optional[int] = None) -> torch.Tensor:
        """Dequantize INT8 weights back to floating point.

        Args:
            int8_weights: (out_features, in_features) int8 tensor.
            scales: per-channel (out,) or per-group (out, num_groups) scales.
            zero_points: signed per-channel zero points for asymmetric mode.
            group_size: columns per group when group quantization was used.
        """
        if group_size is not None:
            # Group dequantization: broadcast each group's scale across its columns
            out_features, in_features = int8_weights.shape
            num_groups = scales.shape[1]

            scales_expanded = scales.unsqueeze(-1).expand(-1, -1, group_size)
            scales_expanded = scales_expanded.reshape(out_features, -1)[:, :in_features]

            return int8_weights.float() * scales_expanded

        if zero_points is not None:
            # Asymmetric: w = (q - zp) * s
            return (int8_weights.float() - zero_points.unsqueeze(1).float()) * scales.unsqueeze(1)

        # Symmetric: w = q * s
        return int8_weights.float() * scales.unsqueeze(1)

    def dequantize(self, result: QuantizationResult) -> torch.Tensor:
        """Dequantize from a QuantizationResult."""
        return self.dequantize_weights(
            result.quantized_weights,
            result.scales,
            result.zero_points,
            result.config.group_size
        )
234
+
235
+
236
class INT4Quantizer(BaseQuantizer):
    """4-bit integer quantization (W4A16) with group-wise scaling.

    Symmetric mode stores signed nibbles in [-8, 7]; asymmetric mode stores
    unsigned nibbles in [0, 15] plus per-group zero points. Two nibbles are
    packed into each int8 byte for storage.
    """

    def __init__(self, config: Optional[QuantizationConfig] = None):
        if config is None:
            config = QuantizationConfig(bits=4, method=QuantizationMethod.INT4, group_size=128)
        super().__init__(config)

    def quantize(self, weights: torch.Tensor) -> QuantizationResult:
        """
        Quantize weights to INT4 precision.
        Uses group quantization for better accuracy.

        Args:
            weights: Tensor of shape (out_features, in_features)

        Returns:
            QuantizationResult with packed int4 weights and scales
        """
        original_shape = weights.shape
        w_fp32 = weights.clone().to(torch.float32)
        out_features, in_features = w_fp32.shape

        group_size = self.config.group_size or 128

        # Pad the input dimension to a multiple of the group size
        if in_features % group_size != 0:
            pad_size = group_size - (in_features % group_size)
            w_fp32 = F.pad(w_fp32, (0, pad_size))
            in_features = w_fp32.shape[1]

        num_groups = in_features // group_size
        weights_grouped = w_fp32.reshape(out_features, num_groups, group_size)

        if self.config.mode == QuantizationMode.SYMMETRIC:
            # Signed nibbles: range [-8, 7], scale anchored at |w|max / 7
            scales = weights_grouped.abs().max(dim=-1).values / 7
            scales = scales.clamp(min=1e-8)
            zero_points = None

            int4_weights = torch.round(weights_grouped / scales.unsqueeze(-1))
            int4_weights = int4_weights.clamp(-8, 7).to(torch.int8)  # Store as int8
        else:
            # Unsigned nibbles: range [0, 15] with a per-group zero point
            w_min = weights_grouped.min(dim=-1).values
            w_max = weights_grouped.max(dim=-1).values

            scales = (w_max - w_min) / 15
            scales = scales.clamp(min=1e-8)
            zero_points = torch.round(-w_min / scales).clamp(0, 15).to(torch.int8)

            int4_weights = torch.round(weights_grouped / scales.unsqueeze(-1) + zero_points.unsqueeze(-1))
            int4_weights = int4_weights.clamp(0, 15).to(torch.int8)

        # Undo the grouping and drop any padding columns
        int4_weights = int4_weights.reshape(out_features, in_features)
        int4_weights = int4_weights[:, :original_shape[1]]

        # Pack two nibbles per byte for memory efficiency
        packed_weights = self._pack_int4(int4_weights)

        # Round-trip (pre-packing) to measure quantization error
        dequantized = self.dequantize_weights(int4_weights, scales, zero_points, group_size)
        dequantized = dequantized[:, :original_shape[1]]

        max_error, mean_error = self._calculate_error(weights, dequantized)

        # Memory savings: packed nibbles are half the int8 footprint
        original_bytes = weights.numel() * weights.element_size()
        packed_bytes = packed_weights.numel() * packed_weights.element_size()
        scales_bytes = scales.numel() * scales.element_size()
        memory_savings = 100 * (1 - (packed_bytes + scales_bytes) / original_bytes)

        return QuantizationResult(
            quantized_weights=packed_weights,
            scales=scales.reshape(out_features, num_groups),
            zero_points=zero_points.reshape(out_features, num_groups) if zero_points is not None else None,
            original_shape=original_shape,
            config=self.config,
            max_error=max_error,
            mean_error=mean_error,
            memory_savings_percent=memory_savings
        )

    def _pack_int4(self, int4_weights: torch.Tensor) -> torch.Tensor:
        """Pack two 4-bit values into each int8 byte (low nibble first)."""
        out_features, in_features = int4_weights.shape

        # Pad to an even column count so every byte holds two nibbles
        if in_features % 2 != 0:
            int4_weights = F.pad(int4_weights, (0, 1))
            in_features += 1

        reshaped = int4_weights.reshape(out_features, in_features // 2, 2)
        # Masking with 0x0F keeps only the nibble; this also encodes
        # negative (two's-complement) values correctly for symmetric mode.
        packed = (reshaped[:, :, 0] & 0x0F) | ((reshaped[:, :, 1] & 0x0F) << 4)
        return packed.to(torch.int8)

    def _unpack_int4(self, packed_weights: torch.Tensor, original_in_features: int,
                     signed: bool = True) -> torch.Tensor:
        """Unpack int8 bytes back into 4-bit values.

        Args:
            packed_weights: (out_features, ceil(in/2)) packed tensor.
            original_in_features: column count before packing/padding.
            signed: when True, nibbles > 7 are sign-extended to [-8, -1]
                (symmetric storage). Asymmetric storage is unsigned [0, 15]
                and must pass signed=False.
        """
        out_features = packed_weights.shape[0]

        low = packed_weights & 0x0F
        high = (packed_weights >> 4) & 0x0F

        if signed:
            # Sign-extend nibbles for the symmetric [-8, 7] representation
            low = torch.where(low > 7, low - 16, low)
            high = torch.where(high > 7, high - 16, high)

        # Interleave low/high nibbles back into column order
        unpacked = torch.stack([low, high], dim=-1).reshape(out_features, -1)
        return unpacked[:, :original_in_features]

    def dequantize_weights(self, int4_weights: torch.Tensor, scales: torch.Tensor,
                           zero_points: Optional[torch.Tensor] = None,
                           group_size: Optional[int] = None) -> torch.Tensor:
        """Dequantize (already unpacked) INT4 values back to floating point."""
        out_features, in_features = int4_weights.shape
        group_size = group_size or self.config.group_size or 128
        num_groups = scales.shape[1] if scales.dim() > 1 else 1

        # Broadcast per-group scales across their group's columns
        scales_flat = scales.reshape(out_features, num_groups)
        scales_expanded = scales_flat.unsqueeze(-1).expand(-1, -1, group_size)
        scales_expanded = scales_expanded.reshape(out_features, -1)[:, :in_features]

        if zero_points is not None:
            zp_flat = zero_points.reshape(out_features, num_groups)
            zp_expanded = zp_flat.unsqueeze(-1).expand(-1, -1, group_size)
            zp_expanded = zp_expanded.reshape(out_features, -1)[:, :in_features]
            return (int4_weights.float() - zp_expanded.float()) * scales_expanded

        return int4_weights.float() * scales_expanded

    def dequantize(self, result: QuantizationResult) -> torch.Tensor:
        """Dequantize from a QuantizationResult (handles packed weights).

        Fix: nibbles are sign-extended only for symmetric results. The
        previous code always sign-extended, corrupting every asymmetric
        value in [8, 15] (zero_points present implies unsigned storage).
        """
        unpacked = self._unpack_int4(
            result.quantized_weights,
            result.original_shape[1],
            signed=result.zero_points is None,
        )
        return self.dequantize_weights(
            unpacked,
            result.scales,
            result.zero_points,
            result.config.group_size
        )
383
+
384
+
385
class NF4Quantizer(BaseQuantizer):
    """
    Normal Float 4-bit quantization (NF4).
    Uses a fixed codebook optimized for normally distributed weights: each
    weight becomes a 4-bit index into the codebook plus a per-group absmax
    scale.
    """

    # NF4 codebook: values optimized for a normal distribution
    # (the 16 levels used by QLoRA / bitsandbytes).
    NF4_CODEBOOK = torch.tensor([
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611008348274231,
        0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
        0.7229568362236023, 1.0
    ])

    def __init__(self, config: Optional[QuantizationConfig] = None):
        if config is None:
            config = QuantizationConfig(bits=4, method=QuantizationMethod.NF4, group_size=64)
        super().__init__(config)

        # Fix: use the full-precision class codebook. Previously a second,
        # 4-decimal truncated copy was defined here, duplicating
        # NF4_CODEBOOK and losing precision on every dequantization.
        self.codebook = self.NF4_CODEBOOK

    def quantize(self, weights: torch.Tensor) -> QuantizationResult:
        """Quantize weights using the NF4 codebook.

        Args:
            weights: Tensor of shape (out_features, in_features)

        Returns:
            QuantizationResult with packed 4-bit codebook indices and
            per-group absmax scales (zero_points is always None for NF4).
        """
        original_shape = weights.shape
        w_fp32 = weights.clone().to(torch.float32)
        out_features, in_features = w_fp32.shape

        group_size = self.config.group_size or 64

        # Pad the input dimension to a multiple of the group size
        if in_features % group_size != 0:
            pad_size = group_size - (in_features % group_size)
            w_fp32 = F.pad(w_fp32, (0, pad_size))
            in_features = w_fp32.shape[1]

        num_groups = in_features // group_size
        weights_grouped = w_fp32.reshape(out_features, num_groups, group_size)

        # Absmax scale per group normalizes each group into [-1, 1],
        # the codebook's domain
        scales = weights_grouped.abs().max(dim=-1).values
        scales = scales.clamp(min=1e-8)
        normalized = weights_grouped / scales.unsqueeze(-1)

        # Nearest-neighbor lookup against the 16 codebook entries
        # (broadcasts to an extra size-16 dimension)
        codebook = self.codebook.to(weights.device)
        distances = torch.abs(normalized.unsqueeze(-1) - codebook)
        indices = distances.argmin(dim=-1).to(torch.int8)

        # Undo grouping and drop padding columns
        indices = indices.reshape(out_features, in_features)[:, :original_shape[1]]

        # Pack two 4-bit indices per byte
        packed = self._pack_int4(indices)

        # Round-trip (pre-packing) to measure quantization error
        dequantized = self.dequantize_weights(indices, scales.reshape(out_features, num_groups), group_size)
        dequantized = dequantized[:, :original_shape[1]]

        max_error, mean_error = self._calculate_error(weights, dequantized)

        original_bytes = weights.numel() * weights.element_size()
        packed_bytes = packed.numel() * packed.element_size()
        scales_bytes = scales.numel() * scales.element_size()
        memory_savings = 100 * (1 - (packed_bytes + scales_bytes) / original_bytes)

        return QuantizationResult(
            quantized_weights=packed,
            scales=scales.reshape(out_features, num_groups),
            zero_points=None,
            original_shape=original_shape,
            config=self.config,
            max_error=max_error,
            mean_error=mean_error,
            memory_savings_percent=memory_savings
        )

    def _pack_int4(self, indices: torch.Tensor) -> torch.Tensor:
        """Pack two 4-bit codebook indices into each int8 byte."""
        out_features, in_features = indices.shape
        if in_features % 2 != 0:
            indices = F.pad(indices, (0, 1))
            in_features += 1

        reshaped = indices.reshape(out_features, in_features // 2, 2)
        packed = (reshaped[:, :, 0] & 0x0F) | ((reshaped[:, :, 1] & 0x0F) << 4)
        return packed.to(torch.int8)

    def _unpack_int4(self, packed: torch.Tensor, original_in_features: int) -> torch.Tensor:
        """Unpack bytes back into unsigned 4-bit indices (no sign extension:
        NF4 indices are always in [0, 15])."""
        out_features = packed.shape[0]
        low = packed & 0x0F
        high = (packed >> 4) & 0x0F
        unpacked = torch.stack([low, high], dim=-1).reshape(out_features, -1)
        return unpacked[:, :original_in_features]

    def dequantize_weights(self, indices: torch.Tensor, scales: torch.Tensor,
                           group_size: Optional[int] = None) -> torch.Tensor:
        """Dequantize NF4 indices back to floating point: codebook lookup
        followed by per-group rescaling."""
        codebook = self.codebook.to(indices.device)

        # Look up codebook values for every index
        dequantized = codebook[indices.long()]

        # Broadcast per-group scales across their group's columns
        out_features, in_features = indices.shape
        group_size = group_size or self.config.group_size or 64
        num_groups = scales.shape[1]

        scales_expanded = scales.unsqueeze(-1).expand(-1, -1, group_size)
        scales_expanded = scales_expanded.reshape(out_features, -1)[:, :in_features]

        return dequantized * scales_expanded

    def dequantize(self, result: QuantizationResult) -> torch.Tensor:
        """Dequantize from a QuantizationResult (handles packed indices)."""
        unpacked = self._unpack_int4(result.quantized_weights, result.original_shape[1])
        return self.dequantize_weights(
            unpacked,
            result.scales,
            result.config.group_size
        )
515
+
516
+
517
class QuantizedLinear(nn.Module):
    """
    Quantized Linear layer supporting multiple quantization methods.
    Compatible with W8A16, W4A16, NF4, and GPTQ quantization.

    Weights are stored in quantized form and dequantized on the fly in
    forward(), trading extra compute per call for reduced memory.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.config = config or QuantizationConfig()

        # Initialize quantizer based on config (raises for unsupported methods)
        self.quantizer = self._get_quantizer()

        # Buffers for quantized weights; registered as None placeholders and
        # filled in by quantize_weights(), so they move with the module.
        self.register_buffer("quantized_weights", None)
        self.register_buffer("scales", None)
        self.register_buffer("zero_points", None)

        # Bias (kept in full precision)
        if bias:
            self.register_buffer("bias", torch.zeros(out_features))
        else:
            self.bias = None

        # Guards forward(): flipped to True once quantize_weights() runs
        self._quantized = False

    def _get_quantizer(self) -> BaseQuantizer:
        """Get appropriate quantizer based on config.

        Raises:
            ValueError: for methods without an implementation here (e.g. GPTQ).
        """
        if self.config.method == QuantizationMethod.INT8:
            return INT8Quantizer(self.config)
        elif self.config.method == QuantizationMethod.INT4:
            return INT4Quantizer(self.config)
        elif self.config.method == QuantizationMethod.NF4:
            return NF4Quantizer(self.config)
        else:
            raise ValueError(f"Unsupported quantization method: {self.config.method}")

    def quantize_weights(self, weights: torch.Tensor, bias: Optional[torch.Tensor] = None) -> QuantizationResult:
        """Quantize ``weights`` and store the quantized tensors on this layer.

        Args:
            weights: full-precision weight matrix, shape (out_features, in_features).
            bias: optional full-precision bias to store alongside.

        Returns:
            The QuantizationResult (includes error and memory-savings metrics).
        """
        result = self.quantizer.quantize(weights)

        self.quantized_weights = result.quantized_weights
        self.scales = result.scales
        self.zero_points = result.zero_points

        if bias is not None:
            self.bias = bias.clone()

        self._quantized = True
        return result

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with dequantization on-the-fly.

        Raises:
            RuntimeError: if quantize_weights() has not been called yet.
        """
        if not self._quantized:
            raise RuntimeError("Layer has not been quantized. Call quantize_weights first.")

        # Rebuild a QuantizationResult so the quantizer can dequantize;
        # the error/savings metric fields are irrelevant here and set to 0.
        weights = self.quantizer.dequantize(QuantizationResult(
            quantized_weights=self.quantized_weights,
            scales=self.scales,
            zero_points=self.zero_points,
            original_shape=(self.out_features, self.in_features),
            config=self.config,
            max_error=0, mean_error=0, memory_savings_percent=0
        ))

        # Linear operation (weights cast to the activation dtype)
        output = F.linear(x, weights.to(x.dtype))

        if self.bias is not None:
            output = output + self.bias.to(x.dtype)

        return output
594
+
595
+
596
def get_quantizer(config: QuantizationConfig) -> BaseQuantizer:
    """Factory: build the quantizer implementation matching config.method.

    Raises:
        ValueError: for methods with no implementation (e.g. GPTQ).
    """
    registry = {
        QuantizationMethod.INT8: INT8Quantizer,
        QuantizationMethod.INT4: INT4Quantizer,
        QuantizationMethod.NF4: NF4Quantizer,
    }
    quantizer_cls = registry.get(config.method)
    if quantizer_cls is None:
        raise ValueError(f"Unsupported method: {config.method}")
    return quantizer_cls(config)
backend/core/system_checker.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ System Requirements Checker
3
+ Detects GPU availability, memory, and provides hardware recommendations.
4
+ """
5
+
6
+ import torch
7
+ import psutil
8
+ import platform
9
+ from dataclasses import dataclass
10
+ from typing import Optional, List, Dict, Any
11
+ from enum import Enum
12
+
13
+
14
+ class HardwareCapability(Enum):
15
+ """Hardware capability levels"""
16
+ FULL_GPU = "full_gpu" # CUDA GPU with sufficient VRAM
17
+ LIMITED_GPU = "limited_gpu" # CUDA GPU with limited VRAM
18
+ CPU_ONLY = "cpu_only" # No GPU available
19
+ APPLE_SILICON = "apple_silicon" # M1/M2/M3 with MPS
20
+
21
+
22
+ @dataclass
23
+ class GPUInfo:
24
+ """Information about a GPU device"""
25
+ index: int
26
+ name: str
27
+ total_memory_gb: float
28
+ free_memory_gb: float
29
+ compute_capability: Optional[str] = None
30
+
31
+
32
+ @dataclass
33
+ class SystemInfo:
34
+ """Complete system information"""
35
+ platform: str
36
+ python_version: str
37
+ torch_version: str
38
+ cuda_available: bool
39
+ cuda_version: Optional[str]
40
+ mps_available: bool
41
+ cpu_cores: int
42
+ ram_total_gb: float
43
+ ram_available_gb: float
44
+ gpus: List[GPUInfo]
45
+ capability: HardwareCapability
46
+ recommended_batch_size: int
47
+ max_model_size: str
48
+ warnings: List[str]
49
+
50
+
51
+ class SystemChecker:
52
+ """Check system capabilities for quantization tasks"""
53
+
54
+ # Model size thresholds (in billions of parameters)
55
+ MODEL_SIZES = {
56
+ "tiny": 0.1, # ~100M params
57
+ "small": 0.5, # ~500M params
58
+ "medium": 1.0, # ~1B params
59
+ "large": 7.0, # ~7B params
60
+ "xlarge": 13.0, # ~13B params
61
+ "xxlarge": 70.0 # ~70B params
62
+ }
63
+
64
+ # Memory requirements per billion parameters (GB)
65
+ MEMORY_PER_BILLION_PARAMS = {
66
+ "fp32": 4.0,
67
+ "fp16": 2.0,
68
+ "int8": 1.0,
69
+ "int4": 0.5
70
+ }
71
+
72
+ def __init__(self):
73
+ self._system_info: Optional[SystemInfo] = None
74
+
75
+ def check(self, force_refresh: bool = False) -> SystemInfo:
76
+ """Perform full system check"""
77
+ if self._system_info is not None and not force_refresh:
78
+ return self._system_info
79
+
80
+ warnings = []
81
+ gpus = []
82
+
83
+ # Basic info
84
+ cuda_available = torch.cuda.is_available()
85
+ mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
86
+
87
+ # CUDA version
88
+ cuda_version = None
89
+ if cuda_available:
90
+ cuda_version = torch.version.cuda
91
+
92
+ # GPU detection
93
+ if cuda_available:
94
+ try:
95
+ for i in range(torch.cuda.device_count()):
96
+ props = torch.cuda.get_device_properties(i)
97
+ total_mem = props.total_memory / (1024**3)
98
+ free_mem = (props.total_memory - torch.cuda.memory_reserved(i)) / (1024**3)
99
+
100
+ gpus.append(GPUInfo(
101
+ index=i,
102
+ name=props.name,
103
+ total_memory_gb=round(total_mem, 2),
104
+ free_memory_gb=round(free_mem, 2),
105
+ compute_capability=f"{props.major}.{props.minor}"
106
+ ))
107
+ except Exception as e:
108
+ warnings.append(f"Error detecting GPU: {str(e)}")
109
+
110
+ # RAM info
111
+ ram = psutil.virtual_memory()
112
+ ram_total_gb = ram.total / (1024**3)
113
+ ram_available_gb = ram.available / (1024**3)
114
+
115
+ # Determine capability
116
+ capability = self._determine_capability(gpus, mps_available, ram_total_gb)
117
+
118
+ # Recommendations
119
+ recommended_batch_size = self._get_recommended_batch_size(capability, gpus)
120
+ max_model_size = self._get_max_model_size(capability, gpus, ram_total_gb)
121
+
122
+ # Add warnings
123
+ if not cuda_available and not mps_available:
124
+ warnings.append("No GPU detected. Quantization will run on CPU (slower).")
125
+
126
+ if ram_available_gb < 8:
127
+ warnings.append(f"Low RAM available ({ram_available_gb:.1f}GB). Large models may fail.")
128
+
129
+ if gpus and gpus[0].free_memory_gb < 4:
130
+ warnings.append(f"Low GPU memory ({gpus[0].free_memory_gb:.1f}GB free). Consider smaller models.")
131
+
132
+ self._system_info = SystemInfo(
133
+ platform=platform.system(),
134
+ python_version=platform.python_version(),
135
+ torch_version=torch.__version__,
136
+ cuda_available=cuda_available,
137
+ cuda_version=cuda_version,
138
+ mps_available=mps_available,
139
+ cpu_cores=psutil.cpu_count(logical=False) or 1,
140
+ ram_total_gb=round(ram_total_gb, 2),
141
+ ram_available_gb=round(ram_available_gb, 2),
142
+ gpus=gpus,
143
+ capability=capability,
144
+ recommended_batch_size=recommended_batch_size,
145
+ max_model_size=max_model_size,
146
+ warnings=warnings
147
+ )
148
+
149
+ return self._system_info
150
+
151
+ def _determine_capability(self, gpus: List[GPUInfo], mps_available: bool,
152
+ ram_total_gb: float) -> HardwareCapability:
153
+ """Determine hardware capability level"""
154
+ if mps_available:
155
+ return HardwareCapability.APPLE_SILICON
156
+
157
+ if not gpus:
158
+ return HardwareCapability.CPU_ONLY
159
+
160
+ # Check if any GPU has >= 8GB VRAM
161
+ max_vram = max(gpu.total_memory_gb for gpu in gpus)
162
+
163
+ if max_vram >= 8:
164
+ return HardwareCapability.FULL_GPU
165
+ else:
166
+ return HardwareCapability.LIMITED_GPU
167
+
168
+ def _get_recommended_batch_size(self, capability: HardwareCapability,
169
+ gpus: List[GPUInfo]) -> int:
170
+ """Get recommended batch size based on hardware"""
171
+ if capability == HardwareCapability.CPU_ONLY:
172
+ return 1
173
+ elif capability == HardwareCapability.LIMITED_GPU:
174
+ return 4
175
+ elif capability == HardwareCapability.APPLE_SILICON:
176
+ return 8
177
+ else:
178
+ # Full GPU - scale with VRAM
179
+ if gpus:
180
+ vram = gpus[0].total_memory_gb
181
+ if vram >= 24:
182
+ return 32
183
+ elif vram >= 16:
184
+ return 16
185
+ elif vram >= 8:
186
+ return 8
187
+ return 8
188
+
189
+ def _get_max_model_size(self, capability: HardwareCapability,
190
+ gpus: List[GPUInfo], ram_gb: float) -> str:
191
+ """Get maximum recommended model size"""
192
+ if capability == HardwareCapability.CPU_ONLY:
193
+ # CPU-only: limited by RAM, very slow for large models
194
+ if ram_gb >= 32:
195
+ return "medium (1B)"
196
+ elif ram_gb >= 16:
197
+ return "small (500M)"
198
+ else:
199
+ return "tiny (100M)"
200
+
201
+ elif capability == HardwareCapability.LIMITED_GPU:
202
+ return "small (500M)"
203
+
204
+ elif capability == HardwareCapability.APPLE_SILICON:
205
+ # Apple Silicon: depends on unified memory
206
+ if ram_gb >= 32:
207
+ return "large (7B)"
208
+ elif ram_gb >= 16:
209
+ return "medium (1B)"
210
+ else:
211
+ return "small (500M)"
212
+
213
+ else: # FULL_GPU
214
+ if gpus:
215
+ vram = gpus[0].total_memory_gb
216
+ if vram >= 48:
217
+ return "xxlarge (70B)"
218
+ elif vram >= 24:
219
+ return "xlarge (13B)"
220
+ elif vram >= 16:
221
+ return "large (7B)"
222
+ elif vram >= 8:
223
+ return "medium (1B)"
224
+ return "medium (1B)"
225
+
226
+ def can_load_model(self, model_params_billions: float,
227
+ dtype: str = "fp16") -> Dict[str, Any]:
228
+ """Check if a specific model can be loaded"""
229
+ info = self.check()
230
+
231
+ memory_required = model_params_billions * self.MEMORY_PER_BILLION_PARAMS.get(dtype, 2.0)
232
+ memory_required *= 1.3 # 30% overhead for activations, optimizer, etc.
233
+
234
+ # Check GPU memory
235
+ gpu_ok = False
236
+ gpu_memory = 0
237
+ if info.gpus:
238
+ gpu_memory = info.gpus[0].free_memory_gb
239
+ gpu_ok = gpu_memory >= memory_required
240
+
241
+ # Check RAM
242
+ ram_ok = info.ram_available_gb >= memory_required
243
+
244
+ can_load = gpu_ok or (info.capability == HardwareCapability.CPU_ONLY and ram_ok)
245
+
246
+ return {
247
+ "can_load": can_load,
248
+ "memory_required_gb": round(memory_required, 2),
249
+ "gpu_available_gb": round(gpu_memory, 2) if info.gpus else 0,
250
+ "ram_available_gb": round(info.ram_available_gb, 2),
251
+ "recommended_device": "cuda" if gpu_ok else ("mps" if info.mps_available else "cpu"),
252
+ "warnings": [] if can_load else [
253
+ f"Model requires ~{memory_required:.1f}GB memory. " +
254
+ f"Available: GPU={gpu_memory:.1f}GB, RAM={info.ram_available_gb:.1f}GB"
255
+ ]
256
+ }
257
+
258
+ def to_dict(self) -> Dict[str, Any]:
259
+ """Convert system info to dictionary"""
260
+ info = self.check()
261
+ return {
262
+ "platform": info.platform,
263
+ "python_version": info.python_version,
264
+ "torch_version": info.torch_version,
265
+ "cuda_available": info.cuda_available,
266
+ "cuda_version": info.cuda_version,
267
+ "mps_available": info.mps_available,
268
+ "cpu_cores": info.cpu_cores,
269
+ "ram_total_gb": info.ram_total_gb,
270
+ "ram_available_gb": info.ram_available_gb,
271
+ "gpus": [
272
+ {
273
+ "index": gpu.index,
274
+ "name": gpu.name,
275
+ "total_memory_gb": gpu.total_memory_gb,
276
+ "free_memory_gb": gpu.free_memory_gb,
277
+ "compute_capability": gpu.compute_capability
278
+ }
279
+ for gpu in info.gpus
280
+ ],
281
+ "capability": info.capability.value,
282
+ "recommended_batch_size": info.recommended_batch_size,
283
+ "max_model_size": info.max_model_size,
284
+ "warnings": info.warnings
285
+ }
286
+
287
+
288
+ # Global instance
289
+ system_checker = SystemChecker()
290
+
291
+
292
+ def get_system_info() -> Dict[str, Any]:
293
+ """Get system information as dictionary"""
294
+ return system_checker.to_dict()
295
+
296
+
297
+ def check_model_requirements(model_params_billions: float, dtype: str = "fp16") -> Dict[str, Any]:
298
+ """Check if system can handle a specific model"""
299
+ return system_checker.can_load_model(model_params_billions, dtype)
backend/core/visualization.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization utilities for weight matrices and quantization analysis.
3
+ Generates chart data for frontend consumption.
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from typing import Dict, Any, List, Tuple, Optional
9
+ from dataclasses import dataclass
10
+ import base64
11
+ import io
12
+
13
+ # Import matplotlib with non-interactive backend
14
+ import matplotlib
15
+ matplotlib.use('Agg')
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib.colors import TwoSlopeNorm
18
+
19
+
20
+ @dataclass
21
+ class ChartData:
22
+ """Data structure for chart rendering"""
23
+ chart_type: str
24
+ data: Dict[str, Any]
25
+ layout: Dict[str, Any]
26
+
27
+
28
+ class Visualizer:
29
+ """Generate visualization data for weight matrices and quantization analysis"""
30
+
31
+ def __init__(self, max_display_size: int = 128):
32
+ """
33
+ Args:
34
+ max_display_size: Maximum dimension for heatmap display (downsampled if larger)
35
+ """
36
+ self.max_display_size = max_display_size
37
+
38
+ def weight_heatmap(self, weights: torch.Tensor, title: str = "Weight Matrix",
39
+ downsample: bool = True) -> ChartData:
40
+ """
41
+ Generate heatmap data for weight matrix visualization.
42
+ Returns Plotly-compatible data structure.
43
+ """
44
+ w = weights.detach().cpu().float().numpy()
45
+
46
+ # Downsample if too large
47
+ if downsample and (w.shape[0] > self.max_display_size or w.shape[1] > self.max_display_size):
48
+ w = self._downsample_2d(w, self.max_display_size)
49
+
50
+ # Calculate symmetric colorscale bounds - convert to Python float for JSON
51
+ vmax = float(max(abs(w.min()), abs(w.max())))
52
+
53
+ return ChartData(
54
+ chart_type="heatmap",
55
+ data={
56
+ "z": w.tolist(),
57
+ "colorscale": "RdBu_r",
58
+ "zmin": -vmax,
59
+ "zmax": vmax,
60
+ "zmid": 0
61
+ },
62
+ layout={
63
+ "title": title,
64
+ "xaxis": {"title": "Input Features"},
65
+ "yaxis": {"title": "Output Features"}
66
+ }
67
+ )
68
+
69
+ def weight_histogram(self, weights: torch.Tensor, title: str = "Weight Distribution",
70
+ bins: int = 50) -> ChartData:
71
+ """Generate histogram data for weight distribution"""
72
+ w = weights.detach().cpu().float().numpy().flatten()
73
+
74
+ hist, bin_edges = np.histogram(w, bins=bins)
75
+ bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
76
+
77
+ return ChartData(
78
+ chart_type="bar",
79
+ data={
80
+ "x": bin_centers.tolist(),
81
+ "y": hist.tolist(),
82
+ "type": "bar"
83
+ },
84
+ layout={
85
+ "title": title,
86
+ "xaxis": {"title": "Weight Value"},
87
+ "yaxis": {"title": "Frequency"},
88
+ "bargap": 0.05
89
+ }
90
+ )
91
+
92
+ def error_heatmap(self, original: torch.Tensor, quantized: torch.Tensor,
93
+ scales: torch.Tensor, title: str = "Quantization Error") -> ChartData:
94
+ """Generate error heatmap between original and dequantized weights"""
95
+ orig = original.detach().cpu().float()
96
+ quant = quantized.detach().cpu().float()
97
+ sc = scales.detach().cpu().float()
98
+
99
+ # Dequantize
100
+ if sc.dim() == 1:
101
+ dequant = quant * sc.unsqueeze(1)
102
+ else:
103
+ # Group quantization - expand scales
104
+ dequant = quant * sc.unsqueeze(-1)
105
+ dequant = dequant.reshape(orig.shape)
106
+
107
+ error = (orig - dequant).abs().numpy()
108
+
109
+ # Downsample if needed
110
+ if error.shape[0] > self.max_display_size or error.shape[1] > self.max_display_size:
111
+ error = self._downsample_2d(error, self.max_display_size)
112
+
113
+ return ChartData(
114
+ chart_type="heatmap",
115
+ data={
116
+ "z": error.tolist(),
117
+ "colorscale": "Reds",
118
+ "zmin": 0
119
+ },
120
+ layout={
121
+ "title": title,
122
+ "xaxis": {"title": "Input Features"},
123
+ "yaxis": {"title": "Output Features"}
124
+ }
125
+ )
126
+
127
+ def comparison_overlay(self, original: torch.Tensor, dequantized: torch.Tensor,
128
+ sample_size: int = 1000) -> ChartData:
129
+ """Generate scatter plot comparing original vs dequantized values"""
130
+ orig = original.detach().cpu().float().numpy().flatten()
131
+ deq = dequantized.detach().cpu().float().numpy().flatten()
132
+
133
+ # Sample if too large
134
+ if len(orig) > sample_size:
135
+ indices = np.random.choice(len(orig), sample_size, replace=False)
136
+ orig = orig[indices]
137
+ deq = deq[indices]
138
+
139
+ return ChartData(
140
+ chart_type="scatter",
141
+ data={
142
+ "x": orig.tolist(),
143
+ "y": deq.tolist(),
144
+ "mode": "markers",
145
+ "marker": {"size": 3, "opacity": 0.5}
146
+ },
147
+ layout={
148
+ "title": "Original vs Dequantized Weights",
149
+ "xaxis": {"title": "Original Value"},
150
+ "yaxis": {"title": "Dequantized Value"},
151
+ "shapes": [{
152
+ "type": "line",
153
+ "x0": float(orig.min()),
154
+ "x1": float(orig.max()),
155
+ "y0": float(orig.min()),
156
+ "y1": float(orig.max()),
157
+ "line": {"color": "red", "dash": "dash"}
158
+ }]
159
+ }
160
+ )
161
+
162
+ def scales_histogram(self, scales: torch.Tensor,
163
+ title: str = "Quantization Scales Distribution") -> ChartData:
164
+ """Generate histogram of quantization scales"""
165
+ s = scales.detach().cpu().float().numpy().flatten()
166
+
167
+ hist, bin_edges = np.histogram(s, bins=30)
168
+ bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
169
+
170
+ return ChartData(
171
+ chart_type="bar",
172
+ data={
173
+ "x": bin_centers.tolist(),
174
+ "y": hist.tolist(),
175
+ "marker": {"color": "green"}
176
+ },
177
+ layout={
178
+ "title": title,
179
+ "xaxis": {"title": "Scale Value"},
180
+ "yaxis": {"title": "Frequency"}
181
+ }
182
+ )
183
+
184
+ def layer_error_bar(self, layer_errors: Dict[str, float],
185
+ title: str = "Quantization Error by Layer") -> ChartData:
186
+ """Generate bar chart of errors per layer"""
187
+ layers = list(layer_errors.keys())
188
+ errors = list(layer_errors.values())
189
+
190
+ return ChartData(
191
+ chart_type="bar",
192
+ data={
193
+ "x": layers,
194
+ "y": errors,
195
+ "marker": {"color": "coral"}
196
+ },
197
+ layout={
198
+ "title": title,
199
+ "xaxis": {"title": "Layer", "tickangle": 45},
200
+ "yaxis": {"title": "Mean Absolute Error"}
201
+ }
202
+ )
203
+
204
+ def memory_comparison(self, original_mb: float, quantized_mb: float,
205
+ overhead_mb: float = 0) -> ChartData:
206
+ """Generate memory comparison chart"""
207
+ return ChartData(
208
+ chart_type="bar",
209
+ data={
210
+ "x": ["Original (FP32)", "Quantized + Scales", "Savings"],
211
+ "y": [original_mb, quantized_mb + overhead_mb, original_mb - quantized_mb - overhead_mb],
212
+ "marker": {"color": ["#3498db", "#2ecc71", "#e74c3c"]}
213
+ },
214
+ layout={
215
+ "title": "Memory Usage Comparison",
216
+ "yaxis": {"title": "Memory (MB)"}
217
+ }
218
+ )
219
+
220
+ def _downsample_2d(self, arr: np.ndarray, max_size: int) -> np.ndarray:
221
+ """Downsample 2D array to max_size x max_size"""
222
+ h, w = arr.shape
223
+
224
+ if h > max_size:
225
+ step_h = h // max_size
226
+ arr = arr[::step_h, :][:max_size, :]
227
+
228
+ if w > max_size:
229
+ step_w = w // max_size
230
+ arr = arr[:, ::step_w][:, :max_size]
231
+
232
+ return arr
233
+
234
+ def generate_png(self, weights: torch.Tensor, title: str = "Weights") -> bytes:
235
+ """Generate PNG image bytes (for backward compatibility)"""
236
+ w = weights.detach().cpu().float().numpy()
237
+
238
+ if w.shape[0] > self.max_display_size or w.shape[1] > self.max_display_size:
239
+ w = self._downsample_2d(w, self.max_display_size)
240
+
241
+ fig, ax = plt.subplots(figsize=(10, 8))
242
+
243
+ vmax = max(abs(w.min()), abs(w.max()))
244
+ norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
245
+
246
+ im = ax.imshow(w, cmap='RdBu_r', norm=norm)
247
+ plt.colorbar(im, label='Weight Value')
248
+ ax.set_title(title)
249
+
250
+ buf = io.BytesIO()
251
+ plt.savefig(buf, format='png', bbox_inches='tight')
252
+ plt.close(fig)
253
+ buf.seek(0)
254
+
255
+ return buf.getvalue()
256
+
257
+ def to_dict(self, chart: ChartData) -> Dict[str, Any]:
258
+ """Convert ChartData to dictionary"""
259
+ return {
260
+ "type": chart.chart_type,
261
+ "data": chart.data,
262
+ "layout": chart.layout
263
+ }
264
+
265
+
266
+ # Global instance
267
+ visualizer = Visualizer()
268
+
269
+
270
+ def get_weight_heatmap(weights: torch.Tensor, title: str = "Weights") -> Dict[str, Any]:
271
+ """Generate weight heatmap data"""
272
+ return visualizer.to_dict(visualizer.weight_heatmap(weights, title))
273
+
274
+
275
+ def get_weight_histogram(weights: torch.Tensor, title: str = "Distribution") -> Dict[str, Any]:
276
+ """Generate weight histogram data"""
277
+ return visualizer.to_dict(visualizer.weight_histogram(weights, title))
backend/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.100.0
2
+ uvicorn>=0.23.0
3
+ python-multipart>=0.0.6
4
+ torch>=2.0.0
5
+ transformers>=4.31.0
6
+ accelerate>=0.21.0
7
+ bitsandbytes>=0.40.0
8
+ scipy>=1.11.0
9
+ numpy>=1.24.0
10
+ pydantic>=2.0.0
11
+ jinja2>=3.1.2
docker-compose.yml ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker Compose for local development with GPU support
2
+
3
+ services:
4
+ # Full application (frontend + backend)
5
+ app:
6
+ build:
7
+ context: .
8
+ dockerfile: Dockerfile
9
+ ports:
10
+ - "7860:7860"
11
+ environment:
12
+ - CUDA_VISIBLE_DEVICES=0
13
+ volumes:
14
+ - ./models:/app/models
15
+ - ./cache:/app/cache
16
+ deploy:
17
+ resources:
18
+ reservations:
19
+ devices:
20
+ - driver: nvidia
21
+ count: 1
22
+ capabilities: [ gpu ]
23
+ restart: unless-stopped
24
+
25
+ # Development mode: separate frontend and backend
26
+ backend-dev:
27
+ build:
28
+ context: .
29
+ dockerfile: Dockerfile
30
+ target: python-base
31
+ command: python -m uvicorn backend.api.main:app --host 0.0.0.0 --port 8000 --reload
32
+ ports:
33
+ - "8000:8000"
34
+ volumes:
35
+ - ./backend:/app/backend
36
+ - ./models:/app/models
37
+ environment:
38
+ - CUDA_VISIBLE_DEVICES=0
39
+ profiles:
40
+ - dev
41
+ deploy:
42
+ resources:
43
+ reservations:
44
+ devices:
45
+ - driver: nvidia
46
+ count: 1
47
+ capabilities: [ gpu ]
48
+
49
+ frontend-dev:
50
+ image: node:20-alpine
51
+ working_dir: /app
52
+ command: sh -c "npm install && npm run dev -- --host"
53
+ ports:
54
+ - "5173:5173"
55
+ volumes:
56
+ - ./frontend:/app
57
+ - /app/node_modules
58
+ environment:
59
+ - VITE_API_URL=http://localhost:8000/api
60
+ profiles:
61
+ - dev
62
+
63
+ # CPU-only version (no GPU)
64
+ app-cpu:
65
+ build:
66
+ context: .
67
+ dockerfile: Dockerfile
68
+ ports:
69
+ - "7860:7860"
70
+ volumes:
71
+ - ./models:/app/models
72
+ profiles:
73
+ - cpu
74
+ restart: unless-stopped
75
+
76
+ networks:
77
+ default:
78
+ name: quantizer-network
frontend/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
frontend/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # React + Vite
2
+
3
+ This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
+
5
+ Currently, two official plugins are available:
6
+
7
+ - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
8
+ - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9
+
10
+ ## React Compiler
11
+
12
+ The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
13
+
14
+ ## Expanding the ESLint configuration
15
+
16
+ If you are developing a production application, we recommend using TypeScript with type-aware lint rules enabled. Check out the [TS template](https://github.com/vitejs/vite/tree/main/packages/create-vite/template-react-ts) for information on how to integrate TypeScript and [`typescript-eslint`](https://typescript-eslint.io) in your project.
frontend/eslint.config.js ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import { defineConfig, globalIgnores } from 'eslint/config'
6
+
7
+ export default defineConfig([
8
+ globalIgnores(['dist']),
9
+ {
10
+ files: ['**/*.{js,jsx}'],
11
+ extends: [
12
+ js.configs.recommended,
13
+ reactHooks.configs.flat.recommended,
14
+ reactRefresh.configs.vite,
15
+ ],
16
+ languageOptions: {
17
+ ecmaVersion: 2020,
18
+ globals: globals.browser,
19
+ parserOptions: {
20
+ ecmaVersion: 'latest',
21
+ ecmaFeatures: { jsx: true },
22
+ sourceType: 'module',
23
+ },
24
+ },
25
+ rules: {
26
+ 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
+ },
28
+ },
29
+ ])
frontend/index.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>frontend</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.jsx"></script>
12
+ </body>
13
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "frontend",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "axios": "^1.13.2",
14
+ "framer-motion": "^12.26.1",
15
+ "lucide-react": "^0.562.0",
16
+ "react": "^19.2.0",
17
+ "react-dom": "^19.2.0",
18
+ "react-hot-toast": "^2.6.0",
19
+ "react-router-dom": "^7.12.0",
20
+ "recharts": "^3.6.0",
21
+ "zustand": "^5.0.10"
22
+ },
23
+ "devDependencies": {
24
+ "@eslint/js": "^9.39.1",
25
+ "@types/react": "^19.2.5",
26
+ "@types/react-dom": "^19.2.3",
27
+ "@vitejs/plugin-react": "^5.1.1",
28
+ "buffer": "^6.0.3",
29
+ "eslint": "^9.39.1",
30
+ "eslint-plugin-react-hooks": "^7.0.1",
31
+ "eslint-plugin-react-refresh": "^0.4.24",
32
+ "globals": "^16.5.0",
33
+ "vite": "^7.2.4"
34
+ }
35
+ }
frontend/public/vite.svg ADDED
frontend/src/App.css ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #root {
2
+ max-width: 1280px;
3
+ margin: 0 auto;
4
+ padding: 2rem;
5
+ text-align: center;
6
+ }
7
+
8
+ .logo {
9
+ height: 6em;
10
+ padding: 1.5em;
11
+ will-change: filter;
12
+ transition: filter 300ms;
13
+ }
14
+ .logo:hover {
15
+ filter: drop-shadow(0 0 2em #646cffaa);
16
+ }
17
+ .logo.react:hover {
18
+ filter: drop-shadow(0 0 2em #61dafbaa);
19
+ }
20
+
21
+ @keyframes logo-spin {
22
+ from {
23
+ transform: rotate(0deg);
24
+ }
25
+ to {
26
+ transform: rotate(360deg);
27
+ }
28
+ }
29
+
30
+ @media (prefers-reduced-motion: no-preference) {
31
+ a:nth-of-type(2) .logo {
32
+ animation: logo-spin infinite 20s linear;
33
+ }
34
+ }
35
+
36
+ .card {
37
+ padding: 2em;
38
+ }
39
+
40
+ .read-the-docs {
41
+ color: #888;
42
+ }
frontend/src/App.jsx ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom';
2
+ import { Toaster, toast } from 'react-hot-toast';
3
+ import { useEffect } from 'react';
4
+ import Layout from './components/Layout';
5
+ import Dashboard from './pages/Dashboard';
6
+ import Quantizer from './pages/Quantizer';
7
+ import Analysis from './pages/Analysis';
8
+ import ModelLoader from './pages/ModelLoader';
9
+ import { useSystemStore } from './store';
10
+ import './index.css';
11
+
12
+ function App() {
13
+ const fetchSystemInfo = useSystemStore((state) => state.fetchSystemInfo);
14
+
15
+ useEffect(() => {
16
+ // Fetch system info on app load
17
+ fetchSystemInfo();
18
+
19
+ const handleOffline = () => toast.error("Internet connection lost");
20
+ const handleOnline = () => toast.success("Internet connection restored");
21
+
22
+ window.addEventListener('offline', handleOffline);
23
+ window.addEventListener('online', handleOnline);
24
+
25
+ return () => {
26
+ window.removeEventListener('offline', handleOffline);
27
+ window.removeEventListener('online', handleOnline);
28
+ };
29
+ }, [fetchSystemInfo]);
30
+
31
+ return (
32
+ <BrowserRouter>
33
+ <Routes>
34
+ <Route path="/" element={<Layout />}>
35
+ <Route index element={<Navigate to="/dashboard" replace />} />
36
+ <Route path="dashboard" element={<Dashboard />} />
37
+ <Route path="quantize" element={<Quantizer />} />
38
+ <Route path="analysis" element={<Analysis />} />
39
+ <Route path="models" element={<ModelLoader />} />
40
+ </Route>
41
+ </Routes>
42
+ <Toaster
43
+ position="top-right"
44
+ toastOptions={{
45
+ duration: 4000,
46
+ style: {
47
+ background: 'rgba(15, 23, 42, 0.8)',
48
+ color: '#e2e8f0',
49
+ backdropFilter: 'blur(12px)',
50
+ border: '1px solid rgba(255, 255, 255, 0.1)',
51
+ padding: '12px 24px',
52
+ borderRadius: '12px',
53
+ boxShadow: '0 8px 32px rgba(0, 0, 0, 0.2)',
54
+ fontSize: '0.95rem'
55
+ },
56
+ success: {
57
+ iconTheme: {
58
+ primary: '#6366f1',
59
+ secondary: '#fff',
60
+ },
61
+ style: {
62
+ border: '1px solid rgba(99, 102, 241, 0.2)',
63
+ background: 'rgba(99, 102, 241, 0.1)',
64
+ }
65
+ },
66
+ error: {
67
+ iconTheme: {
68
+ primary: '#ef4444',
69
+ secondary: '#fff',
70
+ },
71
+ style: {
72
+ border: '1px solid rgba(239, 68, 68, 0.2)',
73
+ background: 'rgba(239, 68, 68, 0.1)',
74
+ }
75
+ }
76
+ }}
77
+ />
78
+ </BrowserRouter>
79
+ );
80
+ }
81
+
82
+ export default App;
frontend/src/assets/react.svg ADDED
frontend/src/components/Layout.jsx ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Outlet, NavLink, useLocation } from 'react-router-dom';
2
+ import { useEffect } from 'react';
3
+ import {
4
+ LayoutDashboard,
5
+ Layers,
6
+ BarChart3,
7
+ Settings,
8
+ Cpu,
9
+ HardDrive,
10
+ Zap,
11
+ Github,
12
+ Menu,
13
+ X
14
+ } from 'lucide-react';
15
+ import { useSystemStore, useUIStore, useModelStore } from '../store';
16
+ import { motion, AnimatePresence } from 'framer-motion';
17
+
18
+ /**
19
+ * Main application layout with sidebar navigation
20
+ */
21
+ export default function Layout() {
22
+ const { sidebarOpen, toggleSidebar } = useUIStore();
23
+ const systemInfo = useSystemStore((state) => state.systemInfo);
24
+ const checkLoadedModel = useModelStore((state) => state.checkLoadedModel);
25
+ const location = useLocation();
26
+
27
+ // Sync model state on mount
28
+ useEffect(() => {
29
+ checkLoadedModel();
30
+ }, []);
31
+
32
+ const navItems = [
33
+ { path: '/dashboard', label: 'Dashboard', icon: LayoutDashboard },
34
+ { path: '/quantize', label: 'Quantizer', icon: Layers },
35
+ { path: '/analysis', label: 'Analysis', icon: BarChart3 },
36
+ { path: '/models', label: 'Models', icon: HardDrive },
37
+ ];
38
+
39
+ return (
40
+ <div className="app-layout">
41
+ {/* Sidebar */}
42
+ <aside className={`sidebar ${sidebarOpen ? '' : 'closed'}`}>
43
+ {/* Logo */}
44
+ <div className="sidebar-header">
45
+ <div className="logo">
46
+ <div className="logo-icon">
47
+ <Zap size={24} />
48
+ </div>
49
+ <div className="logo-text">
50
+ <span className="logo-title">Quantizer</span>
51
+ <span className="logo-subtitle">Neural Network</span>
52
+ </div>
53
+ </div>
54
+ <button className="btn btn-ghost btn-icon mobile-menu" onClick={toggleSidebar}>
55
+ <X size={20} />
56
+ </button>
57
+ </div>
58
+
59
+ {/* Navigation */}
60
+ <nav className="sidebar-nav">
61
+ {navItems.map((item) => (
62
+ <NavLink
63
+ key={item.path}
64
+ to={item.path}
65
+ className={({ isActive }) => `nav-item ${isActive ? 'active' : ''}`}
66
+ >
67
+ <item.icon size={20} />
68
+ <span>{item.label}</span>
69
+ </NavLink>
70
+ ))}
71
+ </nav>
72
+
73
+ {/* System Status */}
74
+ <div className="sidebar-footer">
75
+ <div className="system-status glass-card no-hover">
76
+ <div className="status-header">
77
+ <Cpu size={16} />
78
+ <span>System Status</span>
79
+ </div>
80
+ {systemInfo ? (
81
+ <div className="status-details">
82
+ <div className="status-item">
83
+ <span className="status-label">GPU</span>
84
+ <span className={`badge ${systemInfo.cuda_available ? 'badge-success' : 'badge-warning'}`}>
85
+ {systemInfo.cuda_available ? 'CUDA' : systemInfo.mps_available ? 'MPS' : 'CPU'}
86
+ </span>
87
+ </div>
88
+ {systemInfo.gpus?.length > 0 && (
89
+ <div className="status-item">
90
+ <span className="status-label">{systemInfo.gpus[0].name}</span>
91
+ <span className="text-xs text-muted">{systemInfo.gpus[0].total_memory_gb}GB</span>
92
+ </div>
93
+ )}
94
+ <div className="status-item">
95
+ <span className="status-label">RAM</span>
96
+ <span className="text-xs text-muted">
97
+ {systemInfo.ram_available_gb?.toFixed(1)}GB / {systemInfo.ram_total_gb?.toFixed(1)}GB
98
+ </span>
99
+ </div>
100
+ </div>
101
+ ) : (
102
+ <div className="status-loading">
103
+ <div className="spinner"></div>
104
+ <span className="text-xs text-muted">Detecting...</span>
105
+ </div>
106
+ )}
107
+ </div>
108
+
109
+ <a
110
+ href="https://github.com"
111
+ target="_blank"
112
+ rel="noopener noreferrer"
113
+ className="nav-item github-link"
114
+ >
115
+ <Github size={20} />
116
+ <span>GitHub</span>
117
+ </a>
118
+ </div>
119
+ </aside>
120
+
121
+ {/* Mobile menu button */}
122
+ <button className="mobile-menu-btn btn btn-secondary btn-icon" onClick={toggleSidebar}>
123
+ <Menu size={20} />
124
+ </button>
125
+
126
+ {/* Main Content */}
127
+ <main className="main-content">
128
+ <AnimatePresence mode="wait">
129
+ <motion.div
130
+ key={location.pathname}
131
+ initial={{ opacity: 0, y: 10 }}
132
+ animate={{ opacity: 1, y: 0 }}
133
+ exit={{ opacity: 0, y: -10 }}
134
+ transition={{ duration: 0.2 }}
135
+ >
136
+ <Outlet />
137
+ </motion.div>
138
+ </AnimatePresence>
139
+ </main>
140
+
141
+ <style>{`
142
+ .sidebar {
143
+ display: flex;
144
+ flex-direction: column;
145
+ }
146
+
147
+ .sidebar-header {
148
+ display: flex;
149
+ align-items: center;
150
+ justify-content: space-between;
151
+ margin-bottom: var(--space-xl);
152
+ }
153
+
154
+ .logo {
155
+ display: flex;
156
+ align-items: center;
157
+ gap: var(--space-md);
158
+ }
159
+
160
+ .logo-icon {
161
+ width: 40px;
162
+ height: 40px;
163
+ display: flex;
164
+ align-items: center;
165
+ justify-content: center;
166
+ background: var(--gradient-primary);
167
+ border-radius: var(--radius-lg);
168
+ color: white;
169
+ }
170
+
171
+ .logo-text {
172
+ display: flex;
173
+ flex-direction: column;
174
+ }
175
+
176
+ .logo-title {
177
+ font-size: var(--text-lg);
178
+ font-weight: 700;
179
+ color: var(--text-primary);
180
+ line-height: 1.2;
181
+ }
182
+
183
+ .logo-subtitle {
184
+ font-size: var(--text-xs);
185
+ color: var(--text-tertiary);
186
+ }
187
+
188
+ .mobile-menu {
189
+ display: none;
190
+ }
191
+
192
+ .mobile-menu-btn {
193
+ display: none;
194
+ position: fixed;
195
+ top: var(--space-md);
196
+ left: var(--space-md);
197
+ z-index: 99;
198
+ }
199
+
200
+ .sidebar-nav {
201
+ flex: 1;
202
+ display: flex;
203
+ flex-direction: column;
204
+ gap: var(--space-xs);
205
+ }
206
+
207
+ .nav-item {
208
+ display: flex;
209
+ align-items: center;
210
+ gap: var(--space-md);
211
+ padding: var(--space-sm) var(--space-md);
212
+ border-radius: var(--radius-lg);
213
+ color: var(--text-secondary);
214
+ text-decoration: none;
215
+ transition: all var(--transition-fast);
216
+ }
217
+
218
+ .nav-item:hover {
219
+ background: var(--glass-bg);
220
+ color: var(--text-primary);
221
+ }
222
+
223
+ .nav-item.active {
224
+ background: var(--gradient-primary);
225
+ color: white;
226
+ box-shadow: var(--shadow-md);
227
+ }
228
+
229
+ .sidebar-footer {
230
+ margin-top: auto;
231
+ display: flex;
232
+ flex-direction: column;
233
+ gap: var(--space-md);
234
+ }
235
+
236
+ .system-status {
237
+ padding: var(--space-md);
238
+ }
239
+
240
+ .status-header {
241
+ display: flex;
242
+ align-items: center;
243
+ gap: var(--space-sm);
244
+ font-size: var(--text-sm);
245
+ font-weight: 500;
246
+ color: var(--text-primary);
247
+ margin-bottom: var(--space-sm);
248
+ }
249
+
250
+ .status-details {
251
+ display: flex;
252
+ flex-direction: column;
253
+ gap: var(--space-xs);
254
+ }
255
+
256
+ .status-item {
257
+ display: flex;
258
+ align-items: center;
259
+ justify-content: space-between;
260
+ font-size: var(--text-xs);
261
+ }
262
+
263
+ .status-label {
264
+ color: var(--text-secondary);
265
+ }
266
+
267
+ .status-loading {
268
+ display: flex;
269
+ align-items: center;
270
+ gap: var(--space-sm);
271
+ }
272
+
273
+ .github-link {
274
+ opacity: 0.7;
275
+ }
276
+
277
+ .github-link:hover {
278
+ opacity: 1;
279
+ }
280
+
281
+ @media (max-width: 768px) {
282
+ .mobile-menu {
283
+ display: flex;
284
+ }
285
+
286
+ .mobile-menu-btn {
287
+ display: flex;
288
+ }
289
+
290
+ .sidebar.closed {
291
+ transform: translateX(-100%);
292
+ }
293
+ }
294
+ `}</style>
295
+ </div>
296
+ );
297
+ }
frontend/src/index.css ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Neural Network Quantizer - Design System
3
+ * Premium glassmorphism dark theme with smooth animations
4
+ */
5
+
6
+ /* ============================================
7
+ CSS Variables - Design Tokens
8
+ ============================================ */
9
+ :root {
10
+ /* Colors - Dark Theme */
11
+ --color-bg-primary: #0a0a0f;
12
+ --color-bg-secondary: #12121a;
13
+ --color-bg-tertiary: #1a1a25;
14
+ --color-bg-elevated: #22222f;
15
+
16
+ /* Glass effect backgrounds */
17
+ --glass-bg: rgba(255, 255, 255, 0.03);
18
+ --glass-bg-hover: rgba(255, 255, 255, 0.06);
19
+ --glass-border: rgba(255, 255, 255, 0.08);
20
+ --glass-border-hover: rgba(255, 255, 255, 0.15);
21
+
22
+ /* Accent colors */
23
+ --color-accent-primary: #6366f1;
24
+ --color-accent-secondary: #8b5cf6;
25
+ --color-accent-tertiary: #a855f7;
26
+ --color-accent-glow: rgba(99, 102, 241, 0.3);
27
+
28
+ /* Status colors */
29
+ --color-success: #10b981;
30
+ --color-success-bg: rgba(16, 185, 129, 0.1);
31
+ --color-warning: #f59e0b;
32
+ --color-warning-bg: rgba(245, 158, 11, 0.1);
33
+ --color-error: #ef4444;
34
+ --color-error-bg: rgba(239, 68, 68, 0.1);
35
+ --color-info: #06b6d4;
36
+ --color-info-bg: rgba(6, 182, 212, 0.1);
37
+
38
+ /* Text colors */
39
+ --text-primary: #f8fafc;
40
+ --text-secondary: #94a3b8;
41
+ --text-tertiary: #64748b;
42
+ --text-muted: #475569;
43
+
44
+ /* Gradients */
45
+ --gradient-primary: linear-gradient(135deg, var(--color-accent-primary) 0%, var(--color-accent-secondary) 100%);
46
+ --gradient-secondary: linear-gradient(135deg, var(--color-accent-secondary) 0%, var(--color-accent-tertiary) 100%);
47
+ --gradient-glow: radial-gradient(ellipse at center, var(--color-accent-glow) 0%, transparent 70%);
48
+ --gradient-mesh: radial-gradient(at 40% 20%, hsla(228,100%,74%,0.15) 0px, transparent 50%),
49
+ radial-gradient(at 80% 0%, hsla(189,100%,56%,0.1) 0px, transparent 50%),
50
+ radial-gradient(at 0% 50%, hsla(355,100%,93%,0.05) 0px, transparent 50%),
51
+ radial-gradient(at 80% 50%, hsla(340,100%,76%,0.1) 0px, transparent 50%);
52
+
53
+ /* Spacing */
54
+ --space-xs: 0.25rem;
55
+ --space-sm: 0.5rem;
56
+ --space-md: 1rem;
57
+ --space-lg: 1.5rem;
58
+ --space-xl: 2rem;
59
+ --space-2xl: 3rem;
60
+ --space-3xl: 4rem;
61
+
62
+ /* Border radius */
63
+ --radius-sm: 0.375rem;
64
+ --radius-md: 0.5rem;
65
+ --radius-lg: 0.75rem;
66
+ --radius-xl: 1rem;
67
+ --radius-2xl: 1.5rem;
68
+ --radius-full: 9999px;
69
+
70
+ /* Shadows */
71
+ --shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.3);
72
+ --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.3), 0 2px 4px -1px rgba(0, 0, 0, 0.2);
73
+ --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.4), 0 4px 6px -2px rgba(0, 0, 0, 0.3);
74
+ --shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.5), 0 10px 10px -5px rgba(0, 0, 0, 0.3);
75
+ --shadow-glow: 0 0 40px var(--color-accent-glow);
76
+
77
+ /* Typography */
78
+ --font-sans: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
79
+ --font-mono: 'JetBrains Mono', 'Fira Code', Consolas, monospace;
80
+
81
+ --text-xs: 0.75rem;
82
+ --text-sm: 0.875rem;
83
+ --text-base: 1rem;
84
+ --text-lg: 1.125rem;
85
+ --text-xl: 1.25rem;
86
+ --text-2xl: 1.5rem;
87
+ --text-3xl: 1.875rem;
88
+ --text-4xl: 2.25rem;
89
+
90
+ /* Transitions */
91
+ --transition-fast: 150ms cubic-bezier(0.4, 0, 0.2, 1);
92
+ --transition-base: 200ms cubic-bezier(0.4, 0, 0.2, 1);
93
+ --transition-slow: 300ms cubic-bezier(0.4, 0, 0.2, 1);
94
+ --transition-spring: 500ms cubic-bezier(0.34, 1.56, 0.64, 1);
95
+
96
+ /* Layout */
97
+ --sidebar-width: 280px;
98
+ --header-height: 64px;
99
+ --max-content-width: 1400px;
100
+ }
101
+
102
+ /* ============================================
103
+ Base Styles
104
+ ============================================ */
105
+ *, *::before, *::after {
106
+ box-sizing: border-box;
107
+ margin: 0;
108
+ padding: 0;
109
+ }
110
+
111
+ html {
112
+ font-size: 16px;
113
+ -webkit-font-smoothing: antialiased;
114
+ -moz-osx-font-smoothing: grayscale;
115
+ }
116
+
117
+ body {
118
+ font-family: var(--font-sans);
119
+ background: var(--color-bg-primary);
120
+ color: var(--text-primary);
121
+ line-height: 1.6;
122
+ min-height: 100vh;
123
+ overflow-x: hidden;
124
+ }
125
+
126
+ /* Background mesh gradient */
127
+ body::before {
128
+ content: '';
129
+ position: fixed;
130
+ top: 0;
131
+ left: 0;
132
+ right: 0;
133
+ bottom: 0;
134
+ background: var(--gradient-mesh);
135
+ pointer-events: none;
136
+ z-index: -1;
137
+ }
138
+
139
+ #root {
140
+ min-height: 100vh;
141
+ display: flex;
142
+ flex-direction: column;
143
+ }
144
+
145
+ /* ============================================
146
+ Typography
147
+ ============================================ */
148
+ h1, h2, h3, h4, h5, h6 {
149
+ font-weight: 600;
150
+ line-height: 1.3;
151
+ color: var(--text-primary);
152
+ }
153
+
154
+ h1 { font-size: var(--text-4xl); }
155
+ h2 { font-size: var(--text-3xl); }
156
+ h3 { font-size: var(--text-2xl); }
157
+ h4 { font-size: var(--text-xl); }
158
+ h5 { font-size: var(--text-lg); }
159
+ h6 { font-size: var(--text-base); }
160
+
161
+ p {
162
+ color: var(--text-secondary);
163
+ margin-bottom: var(--space-md);
164
+ }
165
+
166
+ a {
167
+ color: var(--color-accent-primary);
168
+ text-decoration: none;
169
+ transition: color var(--transition-fast);
170
+ }
171
+
172
+ a:hover {
173
+ color: var(--color-accent-secondary);
174
+ }
175
+
176
+ code {
177
+ font-family: var(--font-mono);
178
+ background: var(--glass-bg);
179
+ padding: 0.2em 0.4em;
180
+ border-radius: var(--radius-sm);
181
+ font-size: 0.9em;
182
+ }
183
+
184
+ /* ============================================
185
+ Glass Card Component
186
+ ============================================ */
187
+ .glass-card {
188
+ background: var(--glass-bg);
189
+ border: 1px solid var(--glass-border);
190
+ border-radius: var(--radius-xl);
191
+ padding: var(--space-lg);
192
+ backdrop-filter: blur(20px);
193
+ -webkit-backdrop-filter: blur(20px);
194
+ transition: all var(--transition-base);
195
+ }
196
+
197
+ .glass-card:hover {
198
+ background: var(--glass-bg-hover);
199
+ border-color: var(--glass-border-hover);
200
+ transform: translateY(-2px);
201
+ box-shadow: var(--shadow-lg);
202
+ }
203
+
204
+ .glass-card.no-hover:hover {
205
+ transform: none;
206
+ box-shadow: none;
207
+ }
208
+
209
+ /* ============================================
210
+ Button Styles
211
+ ============================================ */
212
+ .btn {
213
+ display: inline-flex;
214
+ align-items: center;
215
+ justify-content: center;
216
+ gap: var(--space-sm);
217
+ padding: var(--space-sm) var(--space-lg);
218
+ border: none;
219
+ border-radius: var(--radius-lg);
220
+ font-family: var(--font-sans);
221
+ font-size: var(--text-sm);
222
+ font-weight: 500;
223
+ cursor: pointer;
224
+ transition: all var(--transition-base);
225
+ white-space: nowrap;
226
+ }
227
+
228
+ .btn:disabled {
229
+ opacity: 0.5;
230
+ cursor: not-allowed;
231
+ }
232
+
233
+ .btn-primary {
234
+ background: var(--gradient-primary);
235
+ color: white;
236
+ box-shadow: var(--shadow-md), 0 0 20px var(--color-accent-glow);
237
+ }
238
+
239
+ .btn-primary:hover:not(:disabled) {
240
+ transform: translateY(-2px);
241
+ box-shadow: var(--shadow-lg), 0 0 30px var(--color-accent-glow);
242
+ }
243
+
244
+ .btn-secondary {
245
+ background: var(--glass-bg);
246
+ color: var(--text-primary);
247
+ border: 1px solid var(--glass-border);
248
+ backdrop-filter: blur(10px);
249
+ }
250
+
251
+ .btn-secondary:hover:not(:disabled) {
252
+ background: var(--glass-bg-hover);
253
+ border-color: var(--glass-border-hover);
254
+ }
255
+
256
+ .btn-ghost {
257
+ background: transparent;
258
+ color: var(--text-secondary);
259
+ }
260
+
261
+ .btn-ghost:hover:not(:disabled) {
262
+ background: var(--glass-bg);
263
+ color: var(--text-primary);
264
+ }
265
+
266
+ .btn-success {
267
+ background: var(--color-success);
268
+ color: white;
269
+ }
270
+
271
+ .btn-danger {
272
+ background: var(--color-error);
273
+ color: white;
274
+ }
275
+
276
+ .btn-lg {
277
+ padding: var(--space-md) var(--space-xl);
278
+ font-size: var(--text-base);
279
+ }
280
+
281
+ .btn-sm {
282
+ padding: var(--space-xs) var(--space-md);
283
+ font-size: var(--text-xs);
284
+ }
285
+
286
+ .btn-icon {
287
+ padding: var(--space-sm);
288
+ aspect-ratio: 1;
289
+ }
290
+
291
+ /* ============================================
292
+ Input Styles
293
+ ============================================ */
294
+ .input-group {
295
+ display: flex;
296
+ flex-direction: column;
297
+ gap: var(--space-xs);
298
+ }
299
+
300
+ .input-label {
301
+ font-size: var(--text-sm);
302
+ font-weight: 500;
303
+ color: var(--text-secondary);
304
+ }
305
+
306
+ .input {
307
+ width: 100%;
308
+ padding: var(--space-sm) var(--space-md);
309
+ background: var(--glass-bg);
310
+ border: 1px solid var(--glass-border);
311
+ border-radius: var(--radius-md);
312
+ color: var(--text-primary);
313
+ font-family: var(--font-sans);
314
+ font-size: var(--text-sm);
315
+ transition: all var(--transition-fast);
316
+ }
317
+
318
+ .input:focus {
319
+ outline: none;
320
+ border-color: var(--color-accent-primary);
321
+ box-shadow: 0 0 0 3px var(--color-accent-glow);
322
+ }
323
+
324
+ .input::placeholder {
325
+ color: var(--text-muted);
326
+ }
327
+
328
+ .input-error {
329
+ border-color: var(--color-error);
330
+ }
331
+
332
+ /* Select dropdown */
333
+ .select {
334
+ appearance: none;
335
+ background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='%2394a3b8' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpolyline points='6 9 12 15 18 9'%3E%3C/polyline%3E%3C/svg%3E");
336
+ background-repeat: no-repeat;
337
+ background-position: right var(--space-sm) center;
338
+ background-size: 16px;
339
+ padding-right: var(--space-xl);
340
+ }
341
+
342
+ /* Slider */
343
+ .slider {
344
+ width: 100%;
345
+ height: 6px;
346
+ background: var(--glass-bg);
347
+ border-radius: var(--radius-full);
348
+ appearance: none;
349
+ cursor: pointer;
350
+ }
351
+
352
+ .slider::-webkit-slider-thumb {
353
+ appearance: none;
354
+ width: 18px;
355
+ height: 18px;
356
+ background: var(--gradient-primary);
357
+ border-radius: 50%;
358
+ cursor: pointer;
359
+ box-shadow: var(--shadow-md);
360
+ transition: transform var(--transition-fast);
361
+ }
362
+
363
+ .slider::-webkit-slider-thumb:hover {
364
+ transform: scale(1.2);
365
+ }
366
+
367
+ /* ============================================
368
+ Status Badges
369
+ ============================================ */
370
+ .badge {
371
+ display: inline-flex;
372
+ align-items: center;
373
+ gap: var(--space-xs);
374
+ padding: var(--space-xs) var(--space-sm);
375
+ border-radius: var(--radius-full);
376
+ font-size: var(--text-xs);
377
+ font-weight: 500;
378
+ }
379
+
380
+ .badge-success {
381
+ background: var(--color-success-bg);
382
+ color: var(--color-success);
383
+ }
384
+
385
+ .badge-warning {
386
+ background: var(--color-warning-bg);
387
+ color: var(--color-warning);
388
+ }
389
+
390
+ .badge-error {
391
+ background: var(--color-error-bg);
392
+ color: var(--color-error);
393
+ }
394
+
395
+ .badge-info {
396
+ background: var(--color-info-bg);
397
+ color: var(--color-info);
398
+ }
399
+
400
+ /* ============================================
401
+ Layout Components
402
+ ============================================ */
403
+ .app-layout {
404
+ display: flex;
405
+ min-height: 100vh;
406
+ }
407
+
408
+ .sidebar {
409
+ width: var(--sidebar-width);
410
+ background: var(--color-bg-secondary);
411
+ border-right: 1px solid var(--glass-border);
412
+ padding: var(--space-lg);
413
+ display: flex;
414
+ flex-direction: column;
415
+ position: fixed;
416
+ top: 0;
417
+ left: 0;
418
+ bottom: 0;
419
+ z-index: 100;
420
+ }
421
+
422
+ .main-content {
423
+ flex: 1;
424
+ margin-left: var(--sidebar-width);
425
+ padding: var(--space-xl);
426
+ max-width: calc(100vw - var(--sidebar-width));
427
+ }
428
+
429
+ .page-header {
430
+ margin-bottom: var(--space-xl);
431
+ }
432
+
433
+ .page-title {
434
+ font-size: var(--text-3xl);
435
+ font-weight: 700;
436
+ margin-bottom: var(--space-xs);
437
+ }
438
+
439
+ .page-subtitle {
440
+ color: var(--text-secondary);
441
+ font-size: var(--text-base);
442
+ }
443
+
444
+ /* Grid layout */
445
+ .grid {
446
+ display: grid;
447
+ gap: var(--space-lg);
448
+ }
449
+
450
+ .grid-2 { grid-template-columns: repeat(2, 1fr); }
451
+ .grid-3 { grid-template-columns: repeat(3, 1fr); }
452
+ .grid-4 { grid-template-columns: repeat(4, 1fr); }
453
+
454
+ @media (max-width: 1024px) {
455
+ .grid-3, .grid-4 { grid-template-columns: repeat(2, 1fr); }
456
+ }
457
+
458
+ @media (max-width: 768px) {
459
+ .grid-2, .grid-3, .grid-4 { grid-template-columns: 1fr; }
460
+
461
+ .sidebar {
462
+ transform: translateX(-100%);
463
+ transition: transform var(--transition-base);
464
+ }
465
+
466
+ .sidebar.open {
467
+ transform: translateX(0);
468
+ }
469
+
470
+ .main-content {
471
+ margin-left: 0;
472
+ max-width: 100vw;
473
+ }
474
+ }
475
+
476
+ /* ============================================
477
+ Stats Card
478
+ ============================================ */
479
+ .stat-card {
480
+ display: flex;
481
+ flex-direction: column;
482
+ gap: var(--space-sm);
483
+ }
484
+
485
+ .stat-value {
486
+ font-size: var(--text-3xl);
487
+ font-weight: 700;
488
+ background: var(--gradient-primary);
489
+ -webkit-background-clip: text;
490
+ -webkit-text-fill-color: transparent;
491
+ background-clip: text;
492
+ }
493
+
494
+ .stat-label {
495
+ font-size: var(--text-sm);
496
+ color: var(--text-secondary);
497
+ }
498
+
499
+ .stat-change {
500
+ font-size: var(--text-xs);
501
+ display: flex;
502
+ align-items: center;
503
+ gap: var(--space-xs);
504
+ }
505
+
506
+ .stat-change.positive { color: var(--color-success); }
507
+ .stat-change.negative { color: var(--color-error); }
508
+
509
+ /* ============================================
510
+ Progress Bar
511
+ ============================================ */
512
+ .progress-bar {
513
+ width: 100%;
514
+ height: 8px;
515
+ background: var(--glass-bg);
516
+ border-radius: var(--radius-full);
517
+ overflow: hidden;
518
+ }
519
+
520
+ .progress-fill {
521
+ height: 100%;
522
+ background: var(--gradient-primary);
523
+ border-radius: var(--radius-full);
524
+ transition: width var(--transition-slow);
525
+ }
526
+
527
+ /* ============================================
528
+ Tabs
529
+ ============================================ */
530
+ .tabs {
531
+ display: flex;
532
+ gap: var(--space-xs);
533
+ padding: var(--space-xs);
534
+ background: var(--glass-bg);
535
+ border-radius: var(--radius-lg);
536
+ margin-bottom: var(--space-lg);
537
+ }
538
+
539
+ .tab {
540
+ flex: 1;
541
+ padding: var(--space-sm) var(--space-md);
542
+ background: transparent;
543
+ border: none;
544
+ border-radius: var(--radius-md);
545
+ color: var(--text-secondary);
546
+ font-size: var(--text-sm);
547
+ font-weight: 500;
548
+ cursor: pointer;
549
+ transition: all var(--transition-fast);
550
+ }
551
+
552
+ .tab:hover {
553
+ color: var(--text-primary);
554
+ background: var(--glass-bg-hover);
555
+ }
556
+
557
+ .tab.active {
558
+ background: var(--gradient-primary);
559
+ color: white;
560
+ }
561
+
562
+ /* ============================================
563
+ Chart Container
564
+ ============================================ */
565
+ .chart-container {
566
+ background: var(--glass-bg);
567
+ border: 1px solid var(--glass-border);
568
+ border-radius: var(--radius-xl);
569
+ padding: var(--space-md);
570
+ min-height: 300px;
571
+ }
572
+
573
+ .chart-title {
574
+ font-size: var(--text-sm);
575
+ font-weight: 600;
576
+ color: var(--text-primary);
577
+ margin-bottom: var(--space-md);
578
+ }
579
+
580
+ /* ============================================
581
+ Loading States
582
+ ============================================ */
583
+ .skeleton {
584
+ background: linear-gradient(
585
+ 90deg,
586
+ var(--glass-bg) 25%,
587
+ var(--glass-bg-hover) 50%,
588
+ var(--glass-bg) 75%
589
+ );
590
+ background-size: 200% 100%;
591
+ animation: shimmer 1.5s infinite;
592
+ border-radius: var(--radius-md);
593
+ }
594
+
595
+ @keyframes shimmer {
596
+ 0% { background-position: 200% 0; }
597
+ 100% { background-position: -200% 0; }
598
+ }
599
+
600
+ .spinner {
601
+ width: 24px;
602
+ height: 24px;
603
+ border: 2px solid var(--glass-border);
604
+ border-top-color: var(--color-accent-primary);
605
+ border-radius: 50%;
606
+ animation: spin 0.8s linear infinite;
607
+ }
608
+
609
+ @keyframes spin {
610
+ to { transform: rotate(360deg); }
611
+ }
612
+
613
+ /* ============================================
614
+ Tooltips
615
+ ============================================ */
616
+ .tooltip {
617
+ position: relative;
618
+ }
619
+
620
+ .tooltip::after {
621
+ content: attr(data-tooltip);
622
+ position: absolute;
623
+ bottom: 100%;
624
+ left: 50%;
625
+ transform: translateX(-50%);
626
+ padding: var(--space-xs) var(--space-sm);
627
+ background: var(--color-bg-elevated);
628
+ border: 1px solid var(--glass-border);
629
+ border-radius: var(--radius-md);
630
+ font-size: var(--text-xs);
631
+ white-space: nowrap;
632
+ opacity: 0;
633
+ visibility: hidden;
634
+ transition: all var(--transition-fast);
635
+ }
636
+
637
+ .tooltip:hover::after {
638
+ opacity: 1;
639
+ visibility: visible;
640
+ }
641
+
642
+ /* ============================================
643
+ Animations
644
+ ============================================ */
645
+ @keyframes fadeIn {
646
+ from { opacity: 0; }
647
+ to { opacity: 1; }
648
+ }
649
+
650
+ @keyframes slideUp {
651
+ from {
652
+ opacity: 0;
653
+ transform: translateY(20px);
654
+ }
655
+ to {
656
+ opacity: 1;
657
+ transform: translateY(0);
658
+ }
659
+ }
660
+
661
+ @keyframes scaleIn {
662
+ from {
663
+ opacity: 0;
664
+ transform: scale(0.95);
665
+ }
666
+ to {
667
+ opacity: 1;
668
+ transform: scale(1);
669
+ }
670
+ }
671
+
672
+ .animate-fade-in { animation: fadeIn var(--transition-slow) ease-out; }
673
+ .animate-slide-up { animation: slideUp var(--transition-slow) ease-out; }
674
+ .animate-scale-in { animation: scaleIn var(--transition-spring) ease-out; }
675
+
676
+ /* Staggered animations */
677
+ .stagger > * {
678
+ animation: slideUp var(--transition-slow) ease-out forwards;
679
+ opacity: 0;
680
+ }
681
+
682
+ .stagger > *:nth-child(1) { animation-delay: 0ms; }
683
+ .stagger > *:nth-child(2) { animation-delay: 50ms; }
684
+ .stagger > *:nth-child(3) { animation-delay: 100ms; }
685
+ .stagger > *:nth-child(4) { animation-delay: 150ms; }
686
+ .stagger > *:nth-child(5) { animation-delay: 200ms; }
687
+ .stagger > *:nth-child(6) { animation-delay: 250ms; }
688
+
689
+ /* ============================================
690
+ Scrollbar Styles
691
+ ============================================ */
692
+ ::-webkit-scrollbar {
693
+ width: 8px;
694
+ height: 8px;
695
+ }
696
+
697
+ ::-webkit-scrollbar-track {
698
+ background: var(--color-bg-secondary);
699
+ }
700
+
701
+ ::-webkit-scrollbar-thumb {
702
+ background: var(--glass-border);
703
+ border-radius: var(--radius-full);
704
+ }
705
+
706
+ ::-webkit-scrollbar-thumb:hover {
707
+ background: var(--glass-border-hover);
708
+ }
709
+
710
+ /* ============================================
711
+ Utility Classes
712
+ ============================================ */
713
+ .text-center { text-align: center; }
714
+ .text-right { text-align: right; }
715
+ .text-sm { font-size: var(--text-sm); }
716
+ .text-xs { font-size: var(--text-xs); }
717
+ .text-muted { color: var(--text-secondary); }
718
+ .text-accent { color: var(--color-accent-primary); }
719
+
720
+ .flex { display: flex; }
721
+ .flex-col { flex-direction: column; }
722
+ .items-center { align-items: center; }
723
+ .justify-between { justify-content: space-between; }
724
+ .justify-center { justify-content: center; }
725
+ .gap-sm { gap: var(--space-sm); }
726
+ .gap-md { gap: var(--space-md); }
727
+ .gap-lg { gap: var(--space-lg); }
728
+
729
+ .mt-sm { margin-top: var(--space-sm); }
730
+ .mt-md { margin-top: var(--space-md); }
731
+ .mt-lg { margin-top: var(--space-lg); }
732
+ .mb-sm { margin-bottom: var(--space-sm); }
733
+ .mb-md { margin-bottom: var(--space-md); }
734
+ .mb-lg { margin-bottom: var(--space-lg); }
735
+
736
+ .w-full { width: 100%; }
737
+ .h-full { height: 100%; }
738
+
739
+ .overflow-hidden { overflow: hidden; }
740
+ .overflow-auto { overflow: auto; }
741
+
742
+ .relative { position: relative; }
743
+ .absolute { position: absolute; }
744
+
745
+ .rounded { border-radius: var(--radius-md); }
746
+ .rounded-lg { border-radius: var(--radius-lg); }
747
+ .rounded-xl { border-radius: var(--radius-xl); }
748
+
749
+ .shadow { box-shadow: var(--shadow-md); }
750
+ .shadow-lg { box-shadow: var(--shadow-lg); }
751
+ .shadow-glow { box-shadow: var(--shadow-glow); }
frontend/src/main.jsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import { StrictMode } from 'react'
2
+ import { createRoot } from 'react-dom/client'
3
+ import './index.css'
4
+ import App from './App.jsx'
5
+
6
+ createRoot(document.getElementById('root')).render(
7
+ <StrictMode>
8
+ <App />
9
+ </StrictMode>,
10
+ )
frontend/src/pages/Analysis.jsx ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect } from 'react';
2
+ import {
3
+ BarChart3,
4
+ Layers,
5
+ TrendingUp,
6
+ RefreshCw,
7
+ AlertTriangle
8
+ } from 'lucide-react';
9
+ import { useQuantizationStore, useModelStore } from '../store';
10
+ import { motion } from 'framer-motion';
11
+ import {
12
+ BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip,
13
+ ResponsiveContainer, Cell, Legend
14
+ } from 'recharts';
15
+
16
+ /**
17
+ * Analysis page - compare quantization methods and analyze weights
18
+ */
19
+ export default function Analysis() {
20
+ const { compareMethod } = useQuantizationStore();
21
+ const { modelInfo, layers, fetchLayers } = useModelStore();
22
+
23
+ const [comparison, setComparison] = useState(null);
24
+ const [isLoading, setIsLoading] = useState(false);
25
+ const [selectedMethods, setSelectedMethods] = useState(['int8', 'int4', 'nf4']);
26
+ const [source, setSource] = useState('random'); // 'random' | 'layer'
27
+ const [selectedLayer, setSelectedLayer] = useState('');
28
+
29
+ // Switch to layer mode if model is loaded
30
+ // Switch to layer mode if model is loaded
31
+ useEffect(() => {
32
+ if (modelInfo) {
33
+ setSource('layer');
34
+ if (layers.length === 0) fetchLayers();
35
+ }
36
+ }, [modelInfo]);
37
+
38
+ const runComparison = async () => {
39
+ setIsLoading(true);
40
+ const layerToCompare = source === 'layer' ? selectedLayer : null;
41
+ const result = await compareMethod(selectedMethods, layerToCompare);
42
+ setComparison(result);
43
+ setIsLoading(false);
44
+ };
45
+
46
+ const toggleMethod = (method) => {
47
+ setSelectedMethods((prev) =>
48
+ prev.includes(method)
49
+ ? prev.filter(m => m !== method)
50
+ : [...prev, method]
51
+ );
52
+ };
53
+
54
+ // Prepare chart data
55
+ const getComparisonData = () => {
56
+ if (!comparison?.comparison) return [];
57
+ return comparison.comparison
58
+ .filter(c => !c.error)
59
+ .map(c => ({
60
+ method: c.method.toUpperCase(),
61
+ meanError: c.mean_error,
62
+ maxError: c.max_error,
63
+ memorySavings: c.memory_savings_percent
64
+ }));
65
+ };
66
+
67
+ const COLORS = ['#6366f1', '#8b5cf6', '#a855f7'];
68
+
69
+ return (
70
+ <div className="analysis">
71
+ {/* Header */}
72
+ <div className="page-header">
73
+ <h1 className="page-title">Analysis</h1>
74
+ <p className="page-subtitle">
75
+ Compare quantization methods and analyze weight distributions
76
+ </p>
77
+ {modelInfo && (
78
+ <div className="model-badge" style={{ marginTop: '0.5rem', display: 'inline-flex', alignItems: 'center', gap: '0.5rem', padding: '4px 12px', background: 'var(--glass-bg)', border: '1px solid var(--glass-border)', color: 'var(--color-accent-primary)', borderRadius: 'var(--radius-full)', fontSize: '0.875rem' }}>
79
+ <span style={{ opacity: 0.7 }}>Active Model:</span>
80
+ <strong>{modelInfo.name}</strong>
81
+ </div>
82
+ )}
83
+ </div>
84
+
85
+ {/* Method Comparison */}
86
+ <section className="section">
87
+ <div className="section-header">
88
+ <h2 className="section-title">
89
+ <BarChart3 size={20} />
90
+ Method Comparison
91
+ {comparison && (
92
+ <span className="source-badge">
93
+ Source: {comparison.source.startsWith('layer:') ? comparison.source.replace('layer:', '') : 'Random Weights'}
94
+ </span>
95
+ )}
96
+ </h2>
97
+ <button
98
+ className="btn btn-primary"
99
+ onClick={runComparison}
100
+ disabled={isLoading || selectedMethods.length === 0}
101
+ >
102
+ {isLoading ? (
103
+ <>
104
+ <RefreshCw size={16} className="spinning" />
105
+ Comparing...
106
+ </>
107
+ ) : (
108
+ <>
109
+ <TrendingUp size={16} />
110
+ Run Comparison
111
+ </>
112
+ )}
113
+ </button>
114
+ </div>
115
+
116
+ {/* Data Source Selection */}
117
+ <div className="glass-card mb-lg">
118
+ <p className="text-sm text-muted mb-md">Select data source:</p>
119
+
120
+ <div className="source-selection mb-md">
121
+ <div className="btn-group">
122
+ {modelInfo && (
123
+ <button
124
+ className={`btn ${source === 'layer' ? 'btn-primary' : 'btn-secondary'}`}
125
+ onClick={() => setSource('layer')}
126
+ >
127
+ Loaded Model Layer
128
+ </button>
129
+ )}
130
+ <button
131
+ className={`btn ${source === 'random' ? 'btn-primary' : 'btn-secondary'}`}
132
+ onClick={() => setSource('random')}
133
+ >
134
+ Random Weights
135
+ </button>
136
+ </div>
137
+ </div>
138
+
139
+ {source === 'layer' && (
140
+ <div className="layer-selection">
141
+ <select
142
+ className="input select"
143
+ value={selectedLayer}
144
+ onChange={(e) => setSelectedLayer(e.target.value)}
145
+ >
146
+ <option value="">Select a layer...</option>
147
+ {layers.map((layer) => (
148
+ <option key={layer} value={layer}>
149
+ {layer}
150
+ </option>
151
+ ))}
152
+ </select>
153
+ </div>
154
+ )}
155
+ </div>
156
+
157
+ {/* Method Selection */}
158
+ <div className="glass-card">
159
+ <p className="text-sm text-muted mb-md">Select methods to compare:</p>
160
+ <div className="method-selection">
161
+ {['int8', 'int4', 'nf4'].map((method) => (
162
+ <button
163
+ key={method}
164
+ className={`method-btn ${selectedMethods.includes(method) ? 'active' : ''}`}
165
+ onClick={() => toggleMethod(method)}
166
+ >
167
+ <div className="method-check">
168
+ {selectedMethods.includes(method) && '✓'}
169
+ </div>
170
+ <div className="method-info">
171
+ <span className="method-name">{method.toUpperCase()}</span>
172
+ <span className="method-desc">
173
+ {method === 'int8' && '8-bit integer quantization'}
174
+ {method === 'int4' && '4-bit integer with grouping'}
175
+ {method === 'nf4' && 'Normal Float 4-bit (QLoRA)'}
176
+ </span>
177
+ </div>
178
+ </button>
179
+ ))}
180
+ </div>
181
+ </div>
182
+
183
+ {/* Comparison Results */}
184
+ {comparison && (
185
+ <motion.div
186
+ className="comparison-results mt-lg"
187
+ initial={{ opacity: 0, y: 20 }}
188
+ animate={{ opacity: 1, y: 0 }}
189
+ >
190
+ <div className="grid grid-2">
191
+ {/* Error Chart */}
192
+ <div className="glass-card chart-card">
193
+ <h4 className="chart-title">Quantization Error by Method</h4>
194
+ <ResponsiveContainer width="100%" height={300}>
195
+ <BarChart data={getComparisonData()}>
196
+ <CartesianGrid strokeDasharray="3 3" stroke="rgba(255,255,255,0.1)" />
197
+ <XAxis dataKey="method" tick={{ fill: '#94a3b8' }} />
198
+ <YAxis tick={{ fill: '#94a3b8' }} />
199
+ <Tooltip
200
+ contentStyle={{
201
+ backgroundColor: '#1a1a25',
202
+ border: '1px solid rgba(255,255,255,0.1)',
203
+ borderRadius: '8px'
204
+ }}
205
+ />
206
+ <Bar dataKey="meanError" name="Mean Error" radius={[4, 4, 0, 0]}>
207
+ {getComparisonData().map((entry, index) => (
208
+ <Cell key={`cell-${index}`} fill={COLORS[index % COLORS.length]} />
209
+ ))}
210
+ </Bar>
211
+ </BarChart>
212
+ </ResponsiveContainer>
213
+ </div>
214
+
215
+ {/* Memory Savings Chart */}
216
+ <div className="glass-card chart-card">
217
+ <h4 className="chart-title">Memory Savings by Method</h4>
218
+ <ResponsiveContainer width="100%" height={300}>
219
+ <BarChart data={getComparisonData()}>
220
+ <CartesianGrid strokeDasharray="3 3" stroke="rgba(255,255,255,0.1)" />
221
+ <XAxis dataKey="method" tick={{ fill: '#94a3b8' }} />
222
+ <YAxis tick={{ fill: '#94a3b8' }} unit="%" />
223
+ <Tooltip
224
+ contentStyle={{
225
+ backgroundColor: '#1a1a25',
226
+ border: '1px solid rgba(255,255,255,0.1)',
227
+ borderRadius: '8px'
228
+ }}
229
+ formatter={(value) => [`${value.toFixed(1)}%`, 'Savings']}
230
+ />
231
+ <Bar dataKey="memorySavings" name="Memory Savings" radius={[4, 4, 0, 0]}>
232
+ {getComparisonData().map((entry, index) => (
233
+ <Cell key={`cell-${index}`} fill={COLORS[index % COLORS.length]} />
234
+ ))}
235
+ </Bar>
236
+ </BarChart>
237
+ </ResponsiveContainer>
238
+ </div>
239
+ </div>
240
+
241
+ {/* Results Table */}
242
+ <div className="glass-card mt-lg">
243
+ <table className="results-table">
244
+ <thead>
245
+ <tr>
246
+ <th>Method</th>
247
+ <th>Bits</th>
248
+ <th>Max Error</th>
249
+ <th>Mean Error</th>
250
+ <th>Memory Savings</th>
251
+ </tr>
252
+ </thead>
253
+ <tbody>
254
+ {comparison.comparison?.filter(c => !c.error).map((result) => (
255
+ <tr key={result.method}>
256
+ <td><strong>{result.method.toUpperCase()}</strong></td>
257
+ <td>{result.bits}</td>
258
+ <td>{result.max_error?.toFixed(6)}</td>
259
+ <td>{result.mean_error?.toFixed(6)}</td>
260
+ <td>
261
+ <span className="badge badge-success">
262
+ {result.memory_savings_percent?.toFixed(1)}%
263
+ </span>
264
+ </td>
265
+ </tr>
266
+ ))}
267
+ </tbody>
268
+ </table>
269
+ </div>
270
+ </motion.div>
271
+ )}
272
+ </section>
273
+
274
+ {/* Model Analysis (if model loaded) */}
275
+ {modelInfo && (
276
+ <section className="section">
277
+ <h2 className="section-title">
278
+ <Layers size={20} />
279
+ Model Analysis
280
+ </h2>
281
+
282
+ <div className="glass-card">
283
+ <p>
284
+ Model <strong>{modelInfo.name}</strong> is loaded with{' '}
285
+ <strong>{modelInfo.num_quantizable_layers}</strong> quantizable layers.
286
+ </p>
287
+ <p className="text-sm text-muted mt-md">
288
+ Use the Models page to analyze individual layer weights and detect outliers.
289
+ </p>
290
+ </div>
291
+ </section>
292
+ )}
293
+
294
+ {/* Info Section */}
295
+ <section className="section">
296
+ <div className="glass-card info-card">
297
+ <AlertTriangle size={24} className="text-warning" />
298
+ <div>
299
+ <h3>Understanding Quantization Trade-offs</h3>
300
+ <p>
301
+ Lower bit precision (4-bit) provides better memory savings but introduces more error.
302
+ 8-bit quantization offers a good balance between compression and accuracy for most models.
303
+ NF4 uses a codebook optimized for normally distributed weights, ideal for LLMs.
304
+ </p>
305
+ </div>
306
+ </div>
307
+ </section>
308
+
309
+ <style>{`
310
+ .section {
311
+ margin-top: var(--space-2xl);
312
+ }
313
+
314
+ .section-header {
315
+ display: flex;
316
+ align-items: center;
317
+ justify-content: space-between;
318
+ margin-bottom: var(--space-lg);
319
+ }
320
+
321
+ .section-title {
322
+ display: flex;
323
+ align-items: center;
324
+ gap: var(--space-sm);
325
+ font-size: var(--text-xl);
326
+ font-weight: 600;
327
+ margin: 0;
328
+ }
329
+
330
+ .method-selection {
331
+ display: grid;
332
+ grid-template-columns: repeat(3, 1fr);
333
+ gap: var(--space-md);
334
+ }
335
+
336
+ .method-btn {
337
+ display: flex;
338
+ align-items: flex-start;
339
+ gap: var(--space-md);
340
+ padding: var(--space-md);
341
+ background: var(--glass-bg);
342
+ border: 2px solid var(--glass-border);
343
+ border-radius: var(--radius-lg);
344
+ cursor: pointer;
345
+ transition: all var(--transition-fast);
346
+ text-align: left;
347
+ }
348
+
349
+ .method-btn:hover {
350
+ border-color: var(--glass-border-hover);
351
+ }
352
+
353
+ .method-btn.active {
354
+ border-color: var(--color-accent-primary);
355
+ background: rgba(99, 102, 241, 0.1);
356
+ }
357
+
358
+ .method-check {
359
+ width: 24px;
360
+ height: 24px;
361
+ display: flex;
362
+ align-items: center;
363
+ justify-content: center;
364
+ border: 2px solid var(--glass-border);
365
+ border-radius: var(--radius-md);
366
+ font-size: var(--text-sm);
367
+ color: var(--color-accent-primary);
368
+ flex-shrink: 0;
369
+ }
370
+
371
+ .method-btn.active .method-check {
372
+ background: var(--color-accent-primary);
373
+ border-color: var(--color-accent-primary);
374
+ color: white;
375
+ }
376
+
377
+ .method-info {
378
+ display: flex;
379
+ flex-direction: column;
380
+ }
381
+
382
+ .method-name {
383
+ font-weight: 600;
384
+ color: var(--text-primary);
385
+ }
386
+
387
+ .method-desc {
388
+ font-size: var(--text-xs);
389
+ color: var(--text-secondary);
390
+ }
391
+
392
+ .chart-card {
393
+ padding: var(--space-lg);
394
+ }
395
+
396
+ .chart-title {
397
+ font-size: var(--text-sm);
398
+ font-weight: 600;
399
+ color: var(--text-primary);
400
+ margin-bottom: var(--space-md);
401
+ }
402
+
403
+ .results-table {
404
+ width: 100%;
405
+ border-collapse: collapse;
406
+ }
407
+
408
+ .results-table th,
409
+ .results-table td {
410
+ padding: var(--space-sm) var(--space-md);
411
+ text-align: left;
412
+ border-bottom: 1px solid var(--glass-border);
413
+ }
414
+
415
+ .results-table th {
416
+ font-size: var(--text-xs);
417
+ font-weight: 600;
418
+ color: var(--text-secondary);
419
+ text-transform: uppercase;
420
+ }
421
+
422
+ .results-table td {
423
+ font-size: var(--text-sm);
424
+ color: var(--text-primary);
425
+ }
426
+
427
+ .info-card {
428
+ display: flex;
429
+ gap: var(--space-lg);
430
+ padding: var(--space-lg);
431
+ }
432
+
433
+ .info-card h3 {
434
+ font-size: var(--text-base);
435
+ margin-bottom: var(--space-sm);
436
+ }
437
+
438
+ .info-card p {
439
+ margin: 0;
440
+ font-size: var(--text-sm);
441
+ }
442
+
443
+ .text-warning {
444
+ color: var(--color-warning);
445
+ flex-shrink: 0;
446
+ }
447
+
448
+ .spinning {
449
+ animation: spin 1s linear infinite;
450
+ }
451
+
452
+ .method-selection {
453
+ grid-template-columns: 1fr;
454
+ }
455
+ }
456
+
457
+ .btn-group {
458
+ display: flex;
459
+ gap: var(--space-xs);
460
+ max-width: 400px;
461
+ }
462
+
463
+ .source-badge {
464
+ font-size: var(--text-xs);
465
+ font-weight: 500;
466
+ padding: 4px 8px;
467
+ background: var(--glass-bg);
468
+ border: 1px solid var(--glass-border);
469
+ border-radius: var(--radius-full);
470
+ color: var(--text-secondary);
471
+ margin-left: var(--space-md);
472
+ }
473
+
474
+ .btn-group .btn {
475
+ flex: 1;
476
+ }
477
+
478
+ .mb-lg { margin-bottom: var(--space-lg); }
479
+ .mb-md { margin-bottom: var(--space-md); }
480
+ `}</style>
481
+ </div>
482
+ );
483
+ }
frontend/src/pages/Dashboard.jsx ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useState } from 'react';
2
+ import { Link } from 'react-router-dom';
3
+ import {
4
+ Zap,
5
+ Cpu,
6
+ HardDrive,
7
+ TrendingUp,
8
+ ArrowRight,
9
+ Layers,
10
+ Activity,
11
+ MemoryStick
12
+ } from 'lucide-react';
13
+ import { useSystemStore, useQuantizationStore, useModelStore } from '../store';
14
+ import { motion } from 'framer-motion';
15
+
16
+ /**
17
+ * Dashboard page - overview of system and recent activity
18
+ */
19
+ export default function Dashboard() {
20
+ const systemInfo = useSystemStore((state) => state.systemInfo);
21
+ const fetchSystemInfo = useSystemStore((state) => state.fetchSystemInfo);
22
+ const quantizationHistory = useQuantizationStore((state) => state.history);
23
+ const modelInfo = useModelStore((state) => state.modelInfo);
24
+
25
+ useEffect(() => {
26
+ if (!systemInfo) {
27
+ fetchSystemInfo();
28
+ }
29
+ }, [systemInfo, fetchSystemInfo]);
30
+
31
+ const stats = [
32
+ {
33
+ label: 'GPU Status',
34
+ value: systemInfo?.cuda_available ? 'CUDA Ready' : systemInfo?.mps_available ? 'MPS Ready' : 'CPU Only',
35
+ icon: Cpu,
36
+ color: systemInfo?.cuda_available ? 'success' : 'warning',
37
+ detail: systemInfo?.gpus?.[0]?.name || 'No GPU detected'
38
+ },
39
+ {
40
+ label: 'Available RAM',
41
+ value: `${systemInfo?.ram_available_gb?.toFixed(1) || '?'}GB`,
42
+ icon: MemoryStick,
43
+ color: 'info',
44
+ detail: `of ${systemInfo?.ram_total_gb?.toFixed(1) || '?'}GB total`
45
+ },
46
+ {
47
+ label: 'Max Model Size',
48
+ value: systemInfo?.max_model_size || 'Unknown',
49
+ icon: Layers,
50
+ color: 'accent',
51
+ detail: 'Recommended limit'
52
+ },
53
+ {
54
+ label: 'Quantizations',
55
+ value: quantizationHistory.length,
56
+ icon: Activity,
57
+ color: 'success',
58
+ detail: 'This session'
59
+ }
60
+ ];
61
+
62
+ const quickActions = [
63
+ {
64
+ title: 'Quick Quantize',
65
+ description: 'Test quantization on random weights',
66
+ path: '/quantize',
67
+ icon: Zap,
68
+ gradient: 'var(--gradient-primary)'
69
+ },
70
+ {
71
+ title: 'Load Model',
72
+ description: 'Load a HuggingFace model',
73
+ path: '/models',
74
+ icon: HardDrive,
75
+ gradient: 'var(--gradient-secondary)'
76
+ },
77
+ {
78
+ title: 'Analyze Weights',
79
+ description: 'Deep dive into weight distributions',
80
+ path: '/analysis',
81
+ icon: TrendingUp,
82
+ gradient: 'linear-gradient(135deg, #10b981 0%, #06b6d4 100%)'
83
+ }
84
+ ];
85
+
86
+ return (
87
+ <div className="dashboard">
88
+ {/* Header */}
89
+ <div className="page-header">
90
+ <h1 className="page-title">Dashboard</h1>
91
+ <p className="page-subtitle">
92
+ Neural Network Weight Quantization Tool
93
+ </p>
94
+ </div>
95
+
96
+ {/* Stats Grid */}
97
+ <div className="grid grid-4 stagger">
98
+ {stats.map((stat, index) => (
99
+ <motion.div
100
+ key={stat.label}
101
+ className="glass-card stat-card"
102
+ initial={{ opacity: 0, y: 20 }}
103
+ animate={{ opacity: 1, y: 0 }}
104
+ transition={{ delay: index * 0.1 }}
105
+ >
106
+ <div className={`stat-icon ${stat.color}`}>
107
+ <stat.icon size={20} />
108
+ </div>
109
+ <div className="stat-content">
110
+ <div className="stat-value">{stat.value}</div>
111
+ <div className="stat-label">{stat.label}</div>
112
+ <div className="stat-detail">{stat.detail}</div>
113
+ </div>
114
+ </motion.div>
115
+ ))}
116
+ </div>
117
+
118
+ {/* Quick Actions */}
119
+ <section className="section">
120
+ <h2 className="section-title">Quick Actions</h2>
121
+ <div className="grid grid-3">
122
+ {quickActions.map((action, index) => (
123
+ <motion.div
124
+ key={action.path}
125
+ initial={{ opacity: 0, y: 20 }}
126
+ animate={{ opacity: 1, y: 0 }}
127
+ transition={{ delay: 0.4 + index * 0.1 }}
128
+ >
129
+ <Link to={action.path} className="action-card glass-card">
130
+ <div className="action-icon" style={{ background: action.gradient }}>
131
+ <action.icon size={24} />
132
+ </div>
133
+ <div className="action-content">
134
+ <h3 className="action-title">{action.title}</h3>
135
+ <p className="action-description">{action.description}</p>
136
+ </div>
137
+ <ArrowRight size={20} className="action-arrow" />
138
+ </Link>
139
+ </motion.div>
140
+ ))}
141
+ </div>
142
+ </section>
143
+
144
+ {/* Current Model */}
145
+ {modelInfo && (
146
+ <section className="section">
147
+ <h2 className="section-title">Loaded Model</h2>
148
+ <div className="glass-card model-info">
149
+ <div className="model-header">
150
+ <HardDrive size={24} />
151
+ <div>
152
+ <h3 className="model-name">{modelInfo.name}</h3>
153
+ <p className="model-arch">{modelInfo.architecture}</p>
154
+ </div>
155
+ </div>
156
+ <div className="model-stats">
157
+ <div className="model-stat">
158
+ <span className="stat-value">{modelInfo.num_params_billions?.toFixed(2)}B</span>
159
+ <span className="stat-label">Parameters</span>
160
+ </div>
161
+ <div className="model-stat">
162
+ <span className="stat-value">{modelInfo.num_quantizable_layers}</span>
163
+ <span className="stat-label">Quantizable Layers</span>
164
+ </div>
165
+ <div className="model-stat">
166
+ <span className="stat-value">{modelInfo.memory_footprint_gb}GB</span>
167
+ <span className="stat-label">Memory</span>
168
+ </div>
169
+ </div>
170
+ </div>
171
+ </section>
172
+ )}
173
+
174
+ {/* Getting Started */}
175
+ {!modelInfo && quantizationHistory.length === 0 && (
176
+ <section className="section">
177
+ <div className="glass-card getting-started">
178
+ <div className="getting-started-content">
179
+ <Zap size={48} className="getting-started-icon" />
180
+ <h2>Get Started</h2>
181
+ <p>
182
+ Welcome to the Neural Network Quantizer! You can either test quantization
183
+ on random weights or load a real HuggingFace model for production use.
184
+ </p>
185
+ <div className="getting-started-actions">
186
+ <Link to="/quantize" className="btn btn-primary btn-lg">
187
+ <Layers size={20} />
188
+ Try Quantization
189
+ </Link>
190
+ <Link to="/models" className="btn btn-secondary btn-lg">
191
+ <HardDrive size={20} />
192
+ Load Model
193
+ </Link>
194
+ </div>
195
+ </div>
196
+ </div>
197
+ </section>
198
+ )}
199
+
200
+ {/* System Warnings */}
201
+ {systemInfo?.warnings?.length > 0 && (
202
+ <section className="section">
203
+ <h2 className="section-title">System Warnings</h2>
204
+ <div className="warnings-list">
205
+ {systemInfo.warnings.map((warning, index) => (
206
+ <div key={index} className="warning-item glass-card">
207
+ <span className="badge badge-warning">Warning</span>
208
+ <span>{warning}</span>
209
+ </div>
210
+ ))}
211
+ </div>
212
+ </section>
213
+ )}
214
+
215
+ <style>{`
216
+ .dashboard {
217
+ max-width: 1400px;
218
+ }
219
+
220
+ .section {
221
+ margin-top: var(--space-2xl);
222
+ }
223
+
224
+ .section-title {
225
+ font-size: var(--text-xl);
226
+ font-weight: 600;
227
+ margin-bottom: var(--space-lg);
228
+ color: var(--text-primary);
229
+ }
230
+
231
+ .stat-card {
232
+ display: flex;
233
+ align-items: flex-start;
234
+ gap: var(--space-md);
235
+ }
236
+
237
+ .stat-icon {
238
+ width: 44px;
239
+ height: 44px;
240
+ display: flex;
241
+ align-items: center;
242
+ justify-content: center;
243
+ border-radius: var(--radius-lg);
244
+ flex-shrink: 0;
245
+ }
246
+
247
+ .stat-icon.success {
248
+ background: var(--color-success-bg);
249
+ color: var(--color-success);
250
+ }
251
+
252
+ .stat-icon.warning {
253
+ background: var(--color-warning-bg);
254
+ color: var(--color-warning);
255
+ }
256
+
257
+ .stat-icon.info {
258
+ background: var(--color-info-bg);
259
+ color: var(--color-info);
260
+ }
261
+
262
+ .stat-icon.accent {
263
+ background: rgba(99, 102, 241, 0.1);
264
+ color: var(--color-accent-primary);
265
+ }
266
+
267
+ .stat-content {
268
+ flex: 1;
269
+ }
270
+
271
+ .stat-card .stat-value {
272
+ font-size: var(--text-xl);
273
+ font-weight: 700;
274
+ color: var(--text-primary);
275
+ line-height: 1.2;
276
+ }
277
+
278
+ .stat-card .stat-label {
279
+ font-size: var(--text-sm);
280
+ color: var(--text-secondary);
281
+ }
282
+
283
+ .stat-detail {
284
+ font-size: var(--text-xs);
285
+ color: var(--text-tertiary);
286
+ margin-top: var(--space-xs);
287
+ }
288
+
289
+ .action-card {
290
+ display: flex;
291
+ align-items: center;
292
+ gap: var(--space-md);
293
+ text-decoration: none;
294
+ transition: all var(--transition-base);
295
+ }
296
+
297
+ .action-card:hover {
298
+ transform: translateY(-4px);
299
+ }
300
+
301
+ .action-card:hover .action-arrow {
302
+ transform: translateX(4px);
303
+ }
304
+
305
+ .action-icon {
306
+ width: 48px;
307
+ height: 48px;
308
+ display: flex;
309
+ align-items: center;
310
+ justify-content: center;
311
+ border-radius: var(--radius-lg);
312
+ color: white;
313
+ flex-shrink: 0;
314
+ }
315
+
316
+ .action-content {
317
+ flex: 1;
318
+ }
319
+
320
+ .action-title {
321
+ font-size: var(--text-base);
322
+ font-weight: 600;
323
+ color: var(--text-primary);
324
+ margin-bottom: var(--space-xs);
325
+ }
326
+
327
+ .action-description {
328
+ font-size: var(--text-sm);
329
+ color: var(--text-secondary);
330
+ margin: 0;
331
+ }
332
+
333
+ .action-arrow {
334
+ color: var(--text-tertiary);
335
+ transition: transform var(--transition-fast);
336
+ }
337
+
338
+ .model-info {
339
+ padding: var(--space-xl);
340
+ }
341
+
342
+ .model-header {
343
+ display: flex;
344
+ align-items: center;
345
+ gap: var(--space-md);
346
+ margin-bottom: var(--space-lg);
347
+ color: var(--color-accent-primary);
348
+ }
349
+
350
+ .model-name {
351
+ font-size: var(--text-lg);
352
+ font-weight: 600;
353
+ color: var(--text-primary);
354
+ }
355
+
356
+ .model-arch {
357
+ font-size: var(--text-sm);
358
+ color: var(--text-secondary);
359
+ margin: 0;
360
+ }
361
+
362
+ .model-stats {
363
+ display: flex;
364
+ gap: var(--space-2xl);
365
+ }
366
+
367
+ .model-stat {
368
+ display: flex;
369
+ flex-direction: column;
370
+ }
371
+
372
+ .getting-started {
373
+ text-align: center;
374
+ padding: var(--space-3xl);
375
+ }
376
+
377
+ .getting-started-icon {
378
+ color: var(--color-accent-primary);
379
+ margin-bottom: var(--space-lg);
380
+ }
381
+
382
+ .getting-started h2 {
383
+ margin-bottom: var(--space-md);
384
+ }
385
+
386
+ .getting-started p {
387
+ max-width: 500px;
388
+ margin: 0 auto var(--space-xl);
389
+ }
390
+
391
+ .getting-started-actions {
392
+ display: flex;
393
+ gap: var(--space-md);
394
+ justify-content: center;
395
+ }
396
+
397
+ .warnings-list {
398
+ display: flex;
399
+ flex-direction: column;
400
+ gap: var(--space-sm);
401
+ }
402
+
403
+ .warning-item {
404
+ display: flex;
405
+ align-items: center;
406
+ gap: var(--space-md);
407
+ padding: var(--space-md);
408
+ }
409
+ `}</style>
410
+ </div>
411
+ );
412
+ }
frontend/src/pages/ModelLoader.jsx ADDED
@@ -0,0 +1,775 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useEffect, useRef } from 'react';
2
+ import {
3
+ Upload,
4
+ Cpu,
5
+ HardDrive,
6
+ Database,
7
+ CheckCircle,
8
+ AlertCircle,
9
+ Loader2,
10
+ Package,
11
+ Trash2,
12
+ Sparkles,
13
+ Clock,
14
+ Download
15
+ } from 'lucide-react';
16
+ import { useSystemStore } from '../store';
17
+ import { motion, AnimatePresence } from 'framer-motion';
18
+
19
+ /**
20
+ * ModelLoader page - load HuggingFace models with progress tracking
21
+ */
22
+ export default function ModelLoader() {
23
+ const systemInfo = useSystemStore((state) => state.systemInfo);
24
+
25
+ const [modelName, setModelName] = useState('');
26
+ const [exampleModels, setExampleModels] = useState(null);
27
+ const [loadResult, setLoadResult] = useState(null);
28
+ const [isLoading, setIsLoading] = useState(false);
29
+ const [progress, setProgress] = useState(null);
30
+ const [cachedModels, setCachedModels] = useState([]);
31
+ const [modelInfo, setModelInfo] = useState(null);
32
+
33
+ const progressPollRef = useRef(null);
34
+
35
+ // Fetch example models and cache info on mount
36
+ useEffect(() => {
37
+ // Optimistic load from cache
38
+ const cachedExamples = localStorage.getItem('example_models');
39
+ if (cachedExamples) {
40
+ try {
41
+ setExampleModels(JSON.parse(cachedExamples));
42
+ } catch (e) { }
43
+ }
44
+
45
+ fetch('/api/models/examples')
46
+ .then(res => res.json())
47
+ .then(data => {
48
+ setExampleModels(data);
49
+ localStorage.setItem('example_models', JSON.stringify(data));
50
+ })
51
+ .catch(() => { });
52
+
53
+ fetchCacheInfo();
54
+ fetchModelInfo();
55
+ }, []);
56
+
57
+ const fetchCacheInfo = async () => {
58
+ try {
59
+ const res = await fetch('/api/models/cache');
60
+ const data = await res.json();
61
+ setCachedModels(data.models || []);
62
+ } catch (e) { }
63
+ };
64
+
65
+ const fetchModelInfo = async () => {
66
+ try {
67
+ const res = await fetch('/api/models/info');
68
+ const data = await res.json();
69
+ if (data.loaded) {
70
+ setModelInfo(data);
71
+ }
72
+ } catch (e) { }
73
+ };
74
+
75
+ const pollProgress = (name) => {
76
+ if (progressPollRef.current) {
77
+ clearInterval(progressPollRef.current);
78
+ }
79
+
80
+ progressPollRef.current = setInterval(async () => {
81
+ try {
82
+ const res = await fetch(`/api/models/progress/${encodeURIComponent(name)}`);
83
+ const data = await res.json();
84
+ if (data.downloading) {
85
+ setProgress(data);
86
+ }
87
+ } catch (e) { }
88
+ }, 500);
89
+ };
90
+
91
+ const stopPolling = () => {
92
+ if (progressPollRef.current) {
93
+ clearInterval(progressPollRef.current);
94
+ progressPollRef.current = null;
95
+ }
96
+ };
97
+
98
+ const handleLoadModel = async () => {
99
+ if (!modelName.trim() || isLoading) return;
100
+
101
+ setIsLoading(true);
102
+ setLoadResult(null);
103
+ setProgress({ status: 'starting', percent: 0, message: 'Starting download...' });
104
+
105
+ // Start polling for progress
106
+ pollProgress(modelName.trim());
107
+
108
+ try {
109
+ const response = await fetch('/api/models/load', {
110
+ method: 'POST',
111
+ headers: { 'Content-Type': 'application/json' },
112
+ body: JSON.stringify({
113
+ model_name: modelName.trim(),
114
+ dtype: 'auto',
115
+ device: 'auto',
116
+ trust_remote_code: true
117
+ })
118
+ });
119
+
120
+ const data = await response.json();
121
+ setLoadResult(data);
122
+
123
+ if (data.success) {
124
+ setModelInfo(data.model_info);
125
+ setProgress({ status: 'complete', percent: 100, message: 'Model loaded!' });
126
+ fetchCacheInfo();
127
+ } else {
128
+ setProgress(null);
129
+ }
130
+ } catch (err) {
131
+ setLoadResult({ success: false, error: err.message });
132
+ setProgress(null);
133
+ } finally {
134
+ setIsLoading(false);
135
+ stopPolling();
136
+ }
137
+ };
138
+
139
+ const handleQuickLoad = (modelId) => {
140
+ setModelName(modelId);
141
+ };
142
+
143
+ const handleUnload = async () => {
144
+ try {
145
+ await fetch('/api/models/unload', { method: 'POST' });
146
+ setModelInfo(null);
147
+ setLoadResult(null);
148
+ setProgress(null);
149
+ } catch (e) { }
150
+ };
151
+
152
+ const handleDeleteFromCache = async (name) => {
153
+ try {
154
+ await fetch(`/api/models/cache/${encodeURIComponent(name)}`, { method: 'DELETE' });
155
+ fetchCacheInfo();
156
+ } catch (e) { }
157
+ };
158
+
159
+ const handleCleanup = async () => {
160
+ try {
161
+ const res = await fetch('/api/models/cache/cleanup', { method: 'POST' });
162
+ const data = await res.json();
163
+ fetchCacheInfo();
164
+ alert(`Cleaned up ${data.deleted_count} models`);
165
+ } catch (e) { }
166
+ };
167
+
168
+ return (
169
+ <div className="model-loader">
170
+ {/* Header */}
171
+ <div className="page-header">
172
+ <h1 className="page-title">Load HuggingFace Model</h1>
173
+ <p className="page-subtitle">
174
+ Download and analyze models directly from HuggingFace Hub
175
+ </p>
176
+ </div>
177
+
178
+ {/* Main Content */}
179
+ <div className="loader-grid">
180
+ {/* Load Model Card */}
181
+ <motion.div
182
+ className="glass-card load-card"
183
+ initial={{ opacity: 0, y: 20 }}
184
+ animate={{ opacity: 1, y: 0 }}
185
+ >
186
+ <div className="card-header">
187
+ <Package size={24} />
188
+ <h2>Load Model</h2>
189
+ </div>
190
+
191
+ <div className="input-section">
192
+ <label className="input-label">Model ID</label>
193
+ <input
194
+ type="text"
195
+ className="input"
196
+ placeholder="e.g. gpt2, bert-base-uncased, prajjwal1/bert-tiny"
197
+ value={modelName}
198
+ onChange={(e) => setModelName(e.target.value)}
199
+ onKeyDown={(e) => e.key === 'Enter' && handleLoadModel()}
200
+ disabled={isLoading}
201
+ />
202
+ <p className="input-hint">
203
+ Enter the HuggingFace model identifier (organization/model-name)
204
+ </p>
205
+ </div>
206
+
207
+ <button
208
+ className="btn btn-primary btn-lg w-full"
209
+ onClick={handleLoadModel}
210
+ disabled={isLoading || !modelName.trim()}
211
+ >
212
+ {isLoading ? (
213
+ <>
214
+ <Loader2 size={20} className="spinning" />
215
+ Loading...
216
+ </>
217
+ ) : (
218
+ <>
219
+ <Download size={20} />
220
+ Download & Load Model
221
+ </>
222
+ )}
223
+ </button>
224
+
225
+ {/* Progress Bar */}
226
+ <AnimatePresence>
227
+ {progress && (
228
+ <motion.div
229
+ className="progress-container"
230
+ initial={{ opacity: 0, height: 0 }}
231
+ animate={{ opacity: 1, height: 'auto' }}
232
+ exit={{ opacity: 0, height: 0 }}
233
+ >
234
+ <div className="progress-header">
235
+ <span className="progress-status">{progress.message || progress.status}</span>
236
+ <span className="progress-percent">{progress.percent || 0}%</span>
237
+ </div>
238
+ <div className="progress-bar">
239
+ <motion.div
240
+ className="progress-fill"
241
+ initial={{ width: 0 }}
242
+ animate={{ width: `${progress.percent || 0}%` }}
243
+ transition={{ duration: 0.3 }}
244
+ />
245
+ </div>
246
+ {progress.speed_mbps && (
247
+ <div className="progress-details">
248
+ <span>{progress.speed_mbps} MB/s</span>
249
+ {progress.eta_seconds && <span>ETA: {progress.eta_seconds}s</span>}
250
+ </div>
251
+ )}
252
+ </motion.div>
253
+ )}
254
+ </AnimatePresence>
255
+
256
+ {/* Result Message */}
257
+ <AnimatePresence>
258
+ {loadResult && !isLoading && (
259
+ <motion.div
260
+ className={`result-message ${loadResult.success ? 'success' : 'error'}`}
261
+ initial={{ opacity: 0, height: 0 }}
262
+ animate={{ opacity: 1, height: 'auto' }}
263
+ exit={{ opacity: 0, height: 0 }}
264
+ >
265
+ {loadResult.success ? (
266
+ <>
267
+ <CheckCircle size={20} />
268
+ <div>
269
+ <strong>Model loaded successfully!</strong>
270
+ <p>{loadResult.model_info?.architecture} - {loadResult.model_info?.num_params_millions}M params</p>
271
+ </div>
272
+ </>
273
+ ) : (
274
+ <>
275
+ <AlertCircle size={20} />
276
+ <div>
277
+ <strong>Failed to load model</strong>
278
+ <p>{loadResult.error}</p>
279
+ {loadResult.suggestion && <p className="suggestion">{loadResult.suggestion}</p>}
280
+ </div>
281
+ </>
282
+ )}
283
+ </motion.div>
284
+ )}
285
+ </AnimatePresence>
286
+ </motion.div>
287
+
288
+ {/* Currently Loaded Model */}
289
+ {modelInfo && (
290
+ <motion.div
291
+ className="glass-card loaded-model-card"
292
+ initial={{ opacity: 0, scale: 0.95 }}
293
+ animate={{ opacity: 1, scale: 1 }}
294
+ >
295
+ <div className="card-header">
296
+ <CheckCircle size={24} className="text-success" />
297
+ <h2>Loaded Model</h2>
298
+ <button className="btn btn-ghost btn-sm ml-auto" onClick={handleUnload}>
299
+ <Trash2 size={16} />
300
+ Unload
301
+ </button>
302
+ </div>
303
+
304
+ <div className="model-details">
305
+ <div className="detail-item">
306
+ <span className="label">Name</span>
307
+ <span className="value">{modelInfo.name}</span>
308
+ </div>
309
+ <div className="detail-item">
310
+ <span className="label">Parameters</span>
311
+ <span className="value">{modelInfo.num_params_millions}M</span>
312
+ </div>
313
+ <div className="detail-item">
314
+ <span className="label">Memory</span>
315
+ <span className="value">{modelInfo.memory_mb?.toFixed(1)} MB</span>
316
+ </div>
317
+ <div className="detail-item">
318
+ <span className="label">Device</span>
319
+ <span className="value">{modelInfo.device}</span>
320
+ </div>
321
+ <div className="detail-item">
322
+ <span className="label">Quantizable Layers</span>
323
+ <span className="value highlight">{modelInfo.num_quantizable_layers}</span>
324
+ </div>
325
+ </div>
326
+ </motion.div>
327
+ )}
328
+
329
+ {/* Quick Start */}
330
+ <motion.div
331
+ className="glass-card"
332
+ initial={{ opacity: 0, y: 20 }}
333
+ animate={{ opacity: 1, y: 0 }}
334
+ transition={{ delay: 0.1 }}
335
+ >
336
+ <div className="card-header">
337
+ <Sparkles size={24} />
338
+ <h2>Quick Start</h2>
339
+ </div>
340
+
341
+ <p className="text-sm text-muted mb-md">Click to select a model:</p>
342
+
343
+ {exampleModels ? (
344
+ <>
345
+ {exampleModels.sample_models?.length > 0 && (
346
+ <div className="model-group">
347
+ <h4 className="group-title">⭐ Sample Models (Pre-cached)</h4>
348
+ <div className="model-list">
349
+ {exampleModels.sample_models.map((model) => (
350
+ <button
351
+ key={model.id}
352
+ className={`model-chip sample ${modelName === model.id ? 'selected' : ''}`}
353
+ onClick={() => handleQuickLoad(model.id)}
354
+ >
355
+ <span className="model-id">{model.id}</span>
356
+ <span className="model-desc">Instant load</span>
357
+ </button>
358
+ ))}
359
+ </div>
360
+ </div>
361
+ )}
362
+
363
+ <div className="model-group">
364
+ <h4 className="group-title">Small Models</h4>
365
+ <div className="model-list">
366
+ {exampleModels.small_models?.map((model) => (
367
+ <button
368
+ key={model.id}
369
+ className={`model-chip ${modelName === model.id ? 'selected' : ''}`}
370
+ onClick={() => handleQuickLoad(model.id)}
371
+ >
372
+ <span className="model-id">{model.id}</span>
373
+ <span className="model-size">{model.size}</span>
374
+ </button>
375
+ ))}
376
+ </div>
377
+ </div>
378
+ </>
379
+ ) : (
380
+ <div className="loading-placeholder">
381
+ <Loader2 size={20} className="spinning" />
382
+ <span>Loading examples...</span>
383
+ </div>
384
+ )}
385
+ </motion.div>
386
+
387
+ {/* System Status */}
388
+ <motion.div
389
+ className="glass-card"
390
+ initial={{ opacity: 0, y: 20 }}
391
+ animate={{ opacity: 1, y: 0 }}
392
+ transition={{ delay: 0.2 }}
393
+ >
394
+ <div className="card-header">
395
+ <Cpu size={24} />
396
+ <h2>System</h2>
397
+ </div>
398
+
399
+ {systemInfo ? (
400
+ <div className="status-list">
401
+ <div className="status-item">
402
+ <span className="status-label">Device</span>
403
+ <span className="status-value">
404
+ {systemInfo.cuda_available ? '🟢 CUDA GPU' :
405
+ systemInfo.mps_available ? '🟢 Apple MPS' : '🟡 CPU'}
406
+ </span>
407
+ </div>
408
+
409
+ {systemInfo.gpus?.length > 0 && (
410
+ <div className="status-item">
411
+ <span className="status-label">GPU</span>
412
+ <span className="status-value">{systemInfo.gpus[0].name}</span>
413
+ </div>
414
+ )}
415
+
416
+ <div className="status-item">
417
+ <span className="status-label">RAM</span>
418
+ <span className="status-value">{systemInfo.ram_available_gb?.toFixed(1)} GB</span>
419
+ </div>
420
+ </div>
421
+ ) : (
422
+ <p className="text-muted">Loading...</p>
423
+ )}
424
+ </motion.div>
425
+
426
+ {/* Cached Models */}
427
+ <motion.div
428
+ className="glass-card cache-card"
429
+ initial={{ opacity: 0, y: 20 }}
430
+ animate={{ opacity: 1, y: 0 }}
431
+ transition={{ delay: 0.3 }}
432
+ >
433
+ <div className="card-header">
434
+ <Database size={24} />
435
+ <h2>Model Cache</h2>
436
+ <button className="btn btn-ghost btn-sm ml-auto" onClick={handleCleanup}>
437
+ <Clock size={16} />
438
+ Cleanup
439
+ </button>
440
+ </div>
441
+
442
+ <p className="text-xs text-muted mb-sm">
443
+ Models auto-delete after 4 hours (except samples)
444
+ </p>
445
+
446
+ {cachedModels.length > 0 ? (
447
+ <div className="cache-list">
448
+ {cachedModels.map((model) => (
449
+ <div key={model.name} className={`cache-item ${model.is_sample ? 'sample' : ''}`}>
450
+ <div className="cache-info">
451
+ <span className="cache-name">
452
+ {model.is_sample && '⭐ '}
453
+ {model.name}
454
+ </span>
455
+ <span className="cache-size">{model.size_mb} MB</span>
456
+ </div>
457
+ {!model.is_sample && (
458
+ <button
459
+ className="btn btn-ghost btn-xs"
460
+ onClick={() => handleDeleteFromCache(model.name)}
461
+ >
462
+ <Trash2 size={14} />
463
+ </button>
464
+ )}
465
+ </div>
466
+ ))}
467
+ </div>
468
+ ) : (
469
+ <p className="text-muted text-sm">No models cached</p>
470
+ )}
471
+ </motion.div>
472
+ </div>
473
+
474
+ <style>{`
475
+ .loader-grid {
476
+ display: grid;
477
+ grid-template-columns: 1fr 1fr;
478
+ gap: var(--space-lg);
479
+ }
480
+
481
+ @media (max-width: 1024px) {
482
+ .loader-grid {
483
+ grid-template-columns: 1fr;
484
+ }
485
+ }
486
+
487
+ .load-card {
488
+ grid-column: span 2;
489
+ }
490
+
491
+ @media (max-width: 1024px) {
492
+ .load-card {
493
+ grid-column: span 1;
494
+ }
495
+ }
496
+
497
+ .loaded-model-card {
498
+ grid-column: span 2;
499
+ background: rgba(16, 185, 129, 0.05);
500
+ border-color: rgba(16, 185, 129, 0.3);
501
+ }
502
+
503
+ .cache-card {
504
+ grid-column: span 2;
505
+ }
506
+
507
+ .card-header {
508
+ display: flex;
509
+ align-items: center;
510
+ gap: var(--space-sm);
511
+ margin-bottom: var(--space-lg);
512
+ color: var(--text-primary);
513
+ }
514
+
515
+ .card-header h2 {
516
+ font-size: var(--text-lg);
517
+ font-weight: 600;
518
+ margin: 0;
519
+ }
520
+
521
+ .input-section {
522
+ margin-bottom: var(--space-lg);
523
+ }
524
+
525
+ .input-hint {
526
+ font-size: var(--text-xs);
527
+ color: var(--text-tertiary);
528
+ margin-top: var(--space-xs);
529
+ }
530
+
531
+ /* Progress Bar */
532
+ .progress-container {
533
+ margin-top: var(--space-lg);
534
+ padding: var(--space-md);
535
+ background: var(--glass-bg);
536
+ border-radius: var(--radius-md);
537
+ }
538
+
539
+ .progress-header {
540
+ display: flex;
541
+ justify-content: space-between;
542
+ margin-bottom: var(--space-sm);
543
+ font-size: var(--text-sm);
544
+ }
545
+
546
+ .progress-status {
547
+ color: var(--text-secondary);
548
+ }
549
+
550
+ .progress-percent {
551
+ color: var(--color-accent-primary);
552
+ font-weight: 600;
553
+ }
554
+
555
+ .progress-bar {
556
+ height: 8px;
557
+ background: rgba(255, 255, 255, 0.1);
558
+ border-radius: 4px;
559
+ overflow: hidden;
560
+ }
561
+
562
+ .progress-fill {
563
+ height: 100%;
564
+ background: linear-gradient(90deg, var(--color-accent-primary), var(--color-accent-secondary));
565
+ border-radius: 4px;
566
+ }
567
+
568
+ .progress-details {
569
+ display: flex;
570
+ justify-content: space-between;
571
+ margin-top: var(--space-xs);
572
+ font-size: var(--text-xs);
573
+ color: var(--text-tertiary);
574
+ }
575
+
576
+ .result-message {
577
+ display: flex;
578
+ align-items: flex-start;
579
+ gap: var(--space-md);
580
+ padding: var(--space-md);
581
+ border-radius: var(--radius-md);
582
+ margin-top: var(--space-md);
583
+ }
584
+
585
+ .result-message.success {
586
+ background: rgba(16, 185, 129, 0.1);
587
+ border: 1px solid rgba(16, 185, 129, 0.3);
588
+ color: var(--color-success);
589
+ }
590
+
591
+ .result-message.error {
592
+ background: rgba(239, 68, 68, 0.1);
593
+ border: 1px solid rgba(239, 68, 68, 0.3);
594
+ color: var(--color-error);
595
+ }
596
+
597
+ .result-message strong {
598
+ display: block;
599
+ }
600
+
601
+ .result-message p {
602
+ margin: var(--space-xs) 0 0 0;
603
+ font-size: var(--text-sm);
604
+ opacity: 0.9;
605
+ }
606
+
607
+ .model-details {
608
+ display: grid;
609
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
610
+ gap: var(--space-sm);
611
+ }
612
+
613
+ .detail-item {
614
+ display: flex;
615
+ flex-direction: column;
616
+ padding: var(--space-sm);
617
+ background: var(--glass-bg);
618
+ border-radius: var(--radius-md);
619
+ }
620
+
621
+ .detail-item .label {
622
+ font-size: var(--text-xs);
623
+ color: var(--text-tertiary);
624
+ }
625
+
626
+ .detail-item .value {
627
+ font-size: var(--text-base);
628
+ font-weight: 500;
629
+ color: var(--text-primary);
630
+ }
631
+
632
+ .detail-item .value.highlight {
633
+ color: var(--color-accent-primary);
634
+ }
635
+
636
+ .model-group {
637
+ margin-bottom: var(--space-lg);
638
+ }
639
+
640
+ .group-title {
641
+ font-size: var(--text-xs);
642
+ font-weight: 600;
643
+ color: var(--text-secondary);
644
+ text-transform: uppercase;
645
+ margin-bottom: var(--space-sm);
646
+ }
647
+
648
+ .model-list {
649
+ display: flex;
650
+ flex-wrap: wrap;
651
+ gap: var(--space-sm);
652
+ }
653
+
654
+ .model-chip {
655
+ display: flex;
656
+ flex-direction: column;
657
+ padding: var(--space-sm) var(--space-md);
658
+ background: var(--glass-bg);
659
+ border: 1px solid var(--glass-border);
660
+ border-radius: var(--radius-md);
661
+ cursor: pointer;
662
+ transition: all var(--transition-fast);
663
+ text-align: left;
664
+ }
665
+
666
+ .model-chip:hover {
667
+ border-color: var(--glass-border-hover);
668
+ transform: translateY(-1px);
669
+ }
670
+
671
+ .model-chip.selected {
672
+ border-color: var(--color-accent-primary);
673
+ background: rgba(99, 102, 241, 0.1);
674
+ }
675
+
676
+ .model-chip.sample {
677
+ border-color: rgba(16, 185, 129, 0.4);
678
+ background: rgba(16, 185, 129, 0.1);
679
+ }
680
+
681
+ .model-id {
682
+ font-size: var(--text-sm);
683
+ font-weight: 500;
684
+ color: var(--text-primary);
685
+ }
686
+
687
+ .model-size, .model-desc {
688
+ font-size: var(--text-xs);
689
+ color: var(--text-tertiary);
690
+ }
691
+
692
+ .status-list {
693
+ display: flex;
694
+ flex-direction: column;
695
+ gap: var(--space-xs);
696
+ }
697
+
698
+ .status-item {
699
+ display: flex;
700
+ justify-content: space-between;
701
+ padding: var(--space-xs) 0;
702
+ border-bottom: 1px solid var(--glass-border);
703
+ }
704
+
705
+ .status-item:last-child {
706
+ border-bottom: none;
707
+ }
708
+
709
+ .status-label {
710
+ font-size: var(--text-sm);
711
+ color: var(--text-secondary);
712
+ }
713
+
714
+ .status-value {
715
+ font-size: var(--text-sm);
716
+ font-weight: 500;
717
+ color: var(--text-primary);
718
+ }
719
+
720
+ .cache-list {
721
+ display: flex;
722
+ flex-direction: column;
723
+ gap: var(--space-xs);
724
+ }
725
+
726
+ .cache-item {
727
+ display: flex;
728
+ align-items: center;
729
+ justify-content: space-between;
730
+ padding: var(--space-sm);
731
+ background: var(--glass-bg);
732
+ border-radius: var(--radius-md);
733
+ }
734
+
735
+ .cache-item.sample {
736
+ background: rgba(16, 185, 129, 0.05);
737
+ }
738
+
739
+ .cache-info {
740
+ display: flex;
741
+ flex-direction: column;
742
+ }
743
+
744
+ .cache-name {
745
+ font-size: var(--text-sm);
746
+ color: var(--text-primary);
747
+ }
748
+
749
+ .cache-size {
750
+ font-size: var(--text-xs);
751
+ color: var(--text-tertiary);
752
+ }
753
+
754
+ .ml-auto {
755
+ margin-left: auto;
756
+ }
757
+
758
+ .text-success {
759
+ color: var(--color-success);
760
+ }
761
+
762
+ .spinning {
763
+ animation: spin 1s linear infinite;
764
+ }
765
+
766
+ .loading-placeholder {
767
+ display: flex;
768
+ align-items: center;
769
+ gap: var(--space-sm);
770
+ color: var(--text-tertiary);
771
+ }
772
+ `}</style>
773
+ </div>
774
+ );
775
+ }