Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| import uvicorn | |
| import torch | |
| from scripts.model import Net | |
| from scripts.training.train import train, start_comparison_training | |
| from pathlib import Path | |
| from fastapi import BackgroundTasks | |
| import warnings | |
| import asyncio | |
| import json | |
| import numpy as np | |
# Silence UserWarnings originating from torchvision.transforms for this app.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="torchvision.transforms",
)

app = FastAPI()

# Serve the ./static directory under the /static URL prefix; the mount is
# registered under the name "static".
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

templates = Jinja2Templates(directory="templates")
# Model configurations
class TrainingConfig(BaseModel):
    # Request schema (pydantic-validated) for one model's training settings.
    # block1..block3 are used as Net(kernels=[block1, block2, block3]).
    block1: int
    block2: int
    block3: int
    # Optimizer name; forwarded to train() as part of this config. The set of
    # accepted names is defined outside this module — TODO confirm.
    optimizer: str
    # Batch size; forwarded to train() as part of this config.
    batch_size: int
    # Number of training epochs; defaults to 1 when the client omits it.
    epochs: int = 1
class ComparisonConfig(BaseModel):
    # Request schema pairing two TrainingConfig entries so two models can be
    # trained and compared (consumed by train_compare_models).
    model1: TrainingConfig
    model2: TrainingConfig
def get_available_models(models_dir=Path("scripts/training/models")):
    """Return the names of saved model checkpoints, sorted alphabetically.

    Args:
        models_dir: Directory holding ``*.pth`` checkpoints (str or Path).
            Defaults to ``scripts/training/models`` relative to the working
            directory. It is created (with parents) if it does not exist, so a
            fresh checkout yields an empty list instead of an error.

    Returns:
        List of checkpoint file stems (file name without ``.pth``). The list is
        sorted because ``Path.glob`` makes no ordering guarantee, and callers
        (e.g. a model dropdown) should see a stable order.
    """
    models_dir = Path(models_dir)
    # Guard with exists() rather than relying on exist_ok: mkdir(exist_ok=True)
    # still raises if the path exists as a regular file.
    if not models_dir.exists():
        models_dir.mkdir(parents=True, exist_ok=True)
    return sorted(checkpoint.stem for checkpoint in models_dir.glob("*.pth"))
# Module-level slot intended to hold a running training task.
# NOTE(review): no function visible in this module declares
# `global training_task`, so this appears to always stay None; confirm
# whether it is still needed.
training_task: Optional[asyncio.Task] = None
async def home(request: Request):
    """Render the landing page from ``index.html``."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
async def train_page(request: Request):
    """Render the training page from ``train.html``."""
    page_context = {"request": request}
    return templates.TemplateResponse("train.html", page_context)
async def inference_page(request: Request):
    """Render ``inference.html`` with the list of saved checkpoints to choose from."""
    page_context = {
        "request": request,
        "available_models": get_available_models(),
    }
    return templates.TemplateResponse("inference.html", page_context)
async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks):
    """Accept a single-model training configuration.

    The handler builds a ``Net`` from the three block sizes purely as a check:
    any exception raised while constructing it is reported as HTTP 500. It does
    NOT start training. ``background_tasks`` is accepted for interface
    compatibility but is not used, and the optimizer/batch-size settings are not
    consumed here.

    Returns:
        ``{"status": "success", "message": "Training configuration received"}``.

    Raises:
        HTTPException: 500 carrying the underlying error message if ``Net``
            construction fails.
    """
    try:
        # Constructed only to validate the architecture; the instance is discarded.
        Net(kernels=[config.block1, config.block2, config.block3])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "message": "Training configuration received"}
async def websocket_endpoint(websocket: WebSocket):
    """Run single-model training over a WebSocket connection.

    After accepting the connection, the handler reads one JSON message carrying
    ``block1``, ``block2``, ``block3``, ``optimizer``, ``batch_size`` and,
    optionally, ``epochs`` (the ``TrainingConfig`` default applies when it is
    omitted). It then awaits ``train`` with the same socket, which presumably
    streams progress to the client — TODO confirm in ``scripts.training.train``.

    Failures are reported to the client as
    ``{"type": "training_error", "data": {"message": ...}}`` when the socket is
    still usable; a client disconnect is only logged.
    """
    await websocket.accept()
    try:
        print("WebSocket connection accepted for single model training")
        config_data = await websocket.receive_json()
        print(f"Received config data: {config_data}")
        # Validate through TrainingConfig first so an omitted "epochs" falls back
        # to the model default instead of raising KeyError, and Net is built from
        # the validated values rather than the raw payload.
        fields = ("block1", "block2", "block3", "optimizer", "batch_size", "epochs")
        config = TrainingConfig(
            **{key: config_data[key] for key in fields if key in config_data}
        )
        model = Net(kernels=[config.block1, config.block2, config.block3])
        print(f"Starting training with config: {config_data}")
        try:
            await train(model, config, websocket, model_type="single")
        except WebSocketDisconnect:
            # The client is gone; sending an error on a closed socket would fail.
            raise
        except Exception as e:
            print(f"Training error: {str(e)}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Training failed: {str(e)}"
                }
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
        try:
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"WebSocket error: {str(e)}"
                }
            })
        except Exception as send_error:
            # Best effort: the socket may already be unusable, so just log.
            print(f"Could not report error to client: {send_error}")
    finally:
        print("WebSocket connection closed")
async def websocket_endpoint(websocket: WebSocket):
    """Run a two-model comparison training session over a WebSocket.

    The first message must be a JSON object with an ``action`` key. For
    ``action == "start_training"`` it must also carry ``parameters``, which are
    passed unchanged to ``start_comparison_training`` together with the socket.
    Protocol errors (missing action, missing parameters, unknown action, a
    training failure) are answered with ``{"status": "error", "message": ...}``.
    Disconnects and undecodable JSON are only logged.
    """
    print("\n=== New WebSocket Connection ===")
    print("New WebSocket connection attempt")
    try:
        await websocket.accept()
        print("WebSocket connection accepted")
        print("Waiting for initial message...")
        data = await websocket.receive_json()
        print(f"Received initial message: {data}")
        # A non-object payload (list, string, number) cannot carry an action either.
        if not isinstance(data, dict) or 'action' not in data:
            print("Error: Missing 'action' in message")
            await websocket.send_json({
                'status': 'error',
                'message': 'Missing action in request'
            })
            return
        action = data['action']
        if action != 'start_training':
            # Reply so the client is not left waiting for a response.
            print(f"Unknown action received: {action}")
            await websocket.send_json({
                'status': 'error',
                'message': f'Unknown action: {action}'
            })
            return
        if 'parameters' not in data:
            print("Error: Missing 'parameters' in message")
            await websocket.send_json({
                'status': 'error',
                'message': 'Missing parameters in request'
            })
            return
        print("Starting training task")
        try:
            # Awaited directly: wrapping the coroutine in a task only to await it
            # immediately adds nothing.
            await start_comparison_training(websocket, data['parameters'])
            print("Training task completed")
        except WebSocketDisconnect:
            # The client is gone; sending an error on a closed socket would fail.
            raise
        except Exception as e:
            print(f"Error during training task: {str(e)}")
            await websocket.send_json({
                'status': 'error',
                'message': f'Training error: {str(e)}'
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except json.JSONDecodeError as e:
        print(f"JSON decode error: {str(e)}")
    except Exception as e:
        print(f"Unexpected error in websocket handler: {str(e)}")
    finally:
        print("=== WebSocket Connection Closed ===\n")
| # @app.post("/api/train_single") | |
| # async def train_single_model(config: TrainingConfig): | |
| # try: | |
| # model = Net(kernels=config.kernels) | |
| # # Start training without passing the websocket | |
| # await train(model, config) | |
| # return {"status": "success"} | |
| # except Exception as e: | |
| # # Log the error for debugging | |
| # print(f"Error during training: {str(e)}") | |
| # # Return a JSON response with the error message | |
| # raise HTTPException(status_code=500, detail=f"Error during training: {str(e)}") | |
async def train_compare_models(config: ComparisonConfig):
    """Train the two models described by ``config``, one after the other.

    Each ``TrainingConfig`` supplies the ``Net`` kernel sizes through its
    ``block1``..``block3`` fields and is passed to ``train`` as-is.

    Returns:
        ``{"status": "success", "model1_results": ..., "model2_results": ...}``,
        where each result is whatever ``train`` returns.

    Raises:
        HTTPException: 500 wrapping any error from model construction or training.
    """
    try:
        # TrainingConfig has no `kernels` attribute; build the list from its blocks.
        model1 = Net(kernels=[config.model1.block1, config.model1.block2, config.model1.block3])
        model2 = Net(kernels=[config.model2.block1, config.model2.block2, config.model2.block3])
        # `train` is awaited by the WebSocket handler in this module, so it returns
        # an awaitable; without `await` these would be un-run coroutine objects.
        # NOTE(review): called without a websocket argument, as before — confirm
        # train() supports that.
        results1 = await train(model1, config.model1)
        results2 = await train(model2, config.model2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "model1_results": results1,
        "model2_results": results2
    }
def parse_model_filename(filename):
    """Extract the model configuration encoded in a saved-model filename.

    Expected layout (underscore-separated tokens; a trailing ``.pth`` is
    tolerated)::

        single_arch_32_64_128_opt_adam_batch_64_20240322_123456

    The three tokens after ``arch`` are the block sizes, the token after ``opt``
    is the optimizer name, and the token after ``batch`` is the batch size.

    Returns:
        ``{'block1', 'block2', 'block3', 'optimizer', 'batch_size'}`` as a dict,
        or ``None`` if a marker token is missing, a value token is missing, or a
        numeric field is not an integer.
    """
    try:
        name = str(filename)
        if name.endswith(".pth"):
            name = name[: -len(".pth")]
        parts = name.split("_")
        arch_at = parts.index("arch")
        block1, block2, block3 = (int(parts[arch_at + offset]) for offset in (1, 2, 3))
        optimizer = parts[parts.index("opt") + 1]
        batch_size = int(parts[parts.index("batch") + 1])
    except (ValueError, IndexError) as e:
        # ValueError: marker not found or non-integer value; IndexError: value missing.
        print(f"Error parsing model filename: {e}")
        return None
    return {
        'block1': block1,
        'block2': block2,
        'block3': block3,
        'optimizer': optimizer,
        'batch_size': batch_size
    }
def _preprocess_drawing(image_data):
    """Decode a base64 image into a normalized (1, 1, 28, 28) float tensor.

    ``image_data`` may be a data URL (``data:image/png;base64,...``) or a bare
    base64 string.
    """
    import base64
    import io
    from PIL import Image
    import torchvision.transforms as transforms

    # Strip an optional data-URL prefix ("data:...;base64,").
    _, sep, payload = image_data.partition(',')
    image_bytes = base64.b64decode(payload if sep else image_data)
    image = Image.open(io.BytesIO(image_bytes)).convert('L')  # grayscale
    image = image.resize((28, 28), Image.LANCZOS)
    # Invert grayscale (presumably so drawn strokes match the training images'
    # light-on-dark convention — confirm).
    image = Image.fromarray(255 - np.array(image))
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Presumably the training-set mean/std (MNIST values) — confirm.
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # ToTensor gives (1, 28, 28) for a grayscale image; add the batch dimension.
    return transform(image).unsqueeze(0)


async def perform_inference(data: dict):
    """Classify an image with a saved checkpoint.

    Args:
        data: Request payload with ``model_name`` (checkpoint stem inside
            ``scripts/training/models``) and ``image`` (base64 image, optionally
            as a data URL).

    Returns:
        ``{"prediction": int, "model_config": {...}}``, where ``model_config``
        echoes the architecture, optimizer and batch size parsed from the
        checkpoint name.

    Raises:
        HTTPException:
            - 400 for a missing or invalid model name, or missing image data.
            - 404 if the checkpoint file does not exist.
            - 500 if the name cannot be parsed, or if loading or inference fails.
    """
    try:
        model_name = data.get("model_name")
        if not model_name:
            raise HTTPException(status_code=400, detail="No model selected")
        # model_name comes from the client: accept only a bare file stem so it
        # cannot point outside the models directory (e.g. "../x" or "/abs/x").
        if not isinstance(model_name, str) or Path(model_name).name != model_name:
            raise HTTPException(status_code=400, detail="Invalid model name")
        model_path = Path("scripts/training/models") / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")
        # Architecture is recovered from the filename, not stored in the checkpoint.
        config = parse_model_filename(model_name)
        if not config:
            raise HTTPException(status_code=500, detail="Could not parse model configuration")
        # Check the input before paying for a model load.
        image_data = data.get("image")
        if not image_data:
            raise HTTPException(status_code=400, detail="No image data provided")
        model = Net(kernels=[config['block1'], config['block2'], config['block3']])
        model.load_state_dict(
            torch.load(str(model_path), map_location=torch.device('cpu'), weights_only=True)
        )
        model.eval()
        try:
            image_tensor = _preprocess_drawing(image_data)
            with torch.no_grad():
                prediction = model(image_tensor).argmax(dim=1).item()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
    except HTTPException:
        # Already carries the intended status; re-wrapping would turn 400/404 into 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "prediction": prediction,
        "model_config": {
            "architecture": f"{config['block1']}-{config['block2']}-{config['block3']}",
            "optimizer": config['optimizer'],
            "batch_size": config['batch_size']
        }
    }
async def train_single_page(request: Request):
    """Render the single-model training page from ``train_single.html``."""
    page_context = {"request": request}
    return templates.TemplateResponse("train_single.html", page_context)
async def train_compare_page(request: Request):
    """Render the model-comparison training page from ``train_compare.html``."""
    page_context = {"request": request}
    return templates.TemplateResponse("train_compare.html", page_context)
if __name__ == "__main__":
    # When run as a script, serve the app on all interfaces at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")