Spaces:
Running
Running
| """ | |
| FastAPI backend for Model Inspector. | |
| Provides endpoints to inspect model architectures from HuggingFace. | |
| """ | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| from fastapi import FastAPI, HTTPException, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, HTMLResponse | |
| from pydantic import BaseModel | |
| from architecture_parser import parse_model, parse_config, format_params | |
# Directories are derived from this source file's location: the backend is
# the folder containing this module, the frontend is its sibling "frontend".
BACKEND_DIR: Path = Path(__file__).parent
FRONTEND_DIR: Path = BACKEND_DIR.parent.joinpath("frontend")
app = FastAPI(
    title="Model Inspector API",
    description="Inspect transformer model architectures without downloading weights",
    version="1.0.0",
)

# CORS for local development.
# None of the endpoints in this module read cookies or auth headers, so
# credentialed cross-origin requests are not needed. Combining
# allow_origins=["*"] with allow_credentials=True makes Starlette's
# CORSMiddleware echo back the caller's Origin, which would let any website
# make credentialed requests; FastAPI's docs disallow that combination.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class InspectRequest(BaseModel):
    """Request body for inspecting a model architecture.

    Callers supply ``model_id`` or ``config``. This model does not itself
    enforce that at least one of them is present; both default to None.
    """

    # Pydantic v2 reserves the "model_" prefix for its own attributes;
    # clearing protected_namespaces permits a field named ``model_id``
    # without a warning.
    model_config = {"protected_namespaces": ()}

    # HuggingFace model ID, e.g. "meta-llama/Llama-2-7b-hf".
    model_id: Optional[str] = None
    # Contents of a config.json, given directly as a JSON object.
    config: Optional[Dict[str, Any]] = None
class ModelMetadata(BaseModel):
    """Summary fields returned alongside the parsed pipeline."""

    # Permit the ``model_id`` field name despite pydantic's protected
    # "model_" namespace.
    model_config = {"protected_namespaces": ()}

    # HuggingFace model ID; the endpoints set it to None when the pipeline
    # was built from a raw config. It has no default, so it must always be
    # passed explicitly (None is an accepted value).
    model_id: Optional[str]
    # Architecture family reported by the parser; endpoints fall back to "unknown".
    model_type: str
    # Total parameter count, taken from the pipeline's "params" key.
    total_params: int
    # Display string for the parameter count, taken from "formatted_params".
    formatted_params: str
class InspectResponse(BaseModel):
    """Response body returned by the inspection endpoints."""

    # Architecture description exactly as returned by parse_model /
    # parse_config; its schema is defined in architecture_parser (not visible here).
    pipeline: Dict[str, Any]
    # Summary fields extracted from ``pipeline``.
    metadata: ModelMetadata
async def inspect_model(request: InspectRequest):
    """
    Inspect a model architecture.

    Provide either:
    - model_id: HuggingFace model ID (e.g., "meta-llama/Llama-2-7b-hf")
    - config: Direct config.json object

    If both are given, ``model_id`` takes precedence. A blank or
    whitespace-only ``model_id`` is treated as missing, and surrounding
    whitespace is stripped.

    Raises:
        HTTPException(400): neither a usable model_id nor a config was given.
        HTTPException(500): parsing or building the response failed.
    """
    # Import kept local so this endpoint is self-contained; fastapi is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # None when absent, empty, or whitespace-only; otherwise the trimmed ID.
    model_id = (request.model_id or "").strip() or None

    if model_id is None and request.config is None:
        raise HTTPException(
            status_code=400,
            detail="Must provide either model_id or config",
        )

    try:
        if model_id is not None:
            # parse_model is synchronous and presumably fetches the config
            # from the HuggingFace Hub; running it in the threadpool keeps a
            # slow network call from blocking the event loop (and every
            # other request, including health checks).
            pipeline = await run_in_threadpool(parse_model, model_id)
        else:
            # Parse directly from the supplied config dict.
            pipeline = parse_config(request.config)

        metadata = ModelMetadata(
            model_id=model_id,
            model_type=pipeline.get("model_type", "unknown"),
            total_params=pipeline.get("params", 0),
            formatted_params=pipeline.get("formatted_params", "0"),
        )
        return InspectResponse(pipeline=pipeline, metadata=metadata)
    except Exception as e:
        # Boundary handler: report any parser failure as a 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(
            status_code=500,
            detail=f"Error inspecting model: {str(e)}",
        ) from e
async def upload_config(file: UploadFile = File(...)):
    """
    Upload a config.json file for inspection.

    The file name must end in ``.json`` (case-insensitive), and the file must
    contain a JSON object.

    Raises:
        HTTPException(400): wrong extension, bytes that are not valid JSON
            text, or a top-level JSON value that is not an object.
        HTTPException(500): parse_config or response construction failed.
    """
    # UploadFile.filename is Optional; a missing name must not crash with
    # AttributeError (which would surface as a 500).
    filename = file.filename or ""
    if not filename.lower().endswith(".json"):
        raise HTTPException(
            status_code=400,
            detail="File must be a JSON file",
        )

    content = await file.read()
    try:
        # json.loads on bytes auto-detects UTF-8/16/32; bytes that are
        # invalid in the detected encoding raise UnicodeDecodeError rather
        # than JSONDecodeError, so both count as "invalid JSON".
        config_dict = json.loads(content)
    except (json.JSONDecodeError, UnicodeDecodeError):
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON file",
        ) from None

    # A config must be an object, matching InspectRequest.config's dict type;
    # lists/scalars would otherwise fail deep inside the parser as a 500.
    if not isinstance(config_dict, dict):
        raise HTTPException(
            status_code=400,
            detail="Config file must contain a JSON object",
        )

    try:
        pipeline = parse_config(config_dict)
        metadata = ModelMetadata(
            model_id=None,
            model_type=pipeline.get("model_type", "unknown"),
            total_params=pipeline.get("params", 0),
            formatted_params=pipeline.get("formatted_params", "0"),
        )
        return InspectResponse(pipeline=pipeline, metadata=metadata)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error parsing config: {str(e)}",
        ) from e
async def health_check():
    """Liveness probe: always reports a static healthy status."""
    state = "healthy"
    return dict(status=state)
# Serve frontend static files
async def serve_index():
    """Serve the frontend's main index.html.

    Raises:
        HTTPException(404): index.html is not a file under FRONTEND_DIR.
    """
    index_path = FRONTEND_DIR / "index.html"
    # Starlette's FileResponse only notices a missing file while sending
    # (it raises RuntimeError), which the client sees as a 500. Check up
    # front so the client gets a 404, like the CSS/JS handlers.
    if not index_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(index_path)
def resolve_static_file(base: Path, relative: str) -> Optional[Path]:
    """Return the regular file at ``base / relative`` if it stays inside ``base``.

    Both paths are fully resolved (``..`` components and symlinks) before the
    containment check, so a request such as ``../../etc/passwd``, or an
    absolute path, cannot escape ``base``.

    Args:
        base: Directory that served files must live under.
        relative: Untrusted, client-supplied path relative to ``base``.

    Returns:
        The resolved path of an existing regular file inside ``base``, or
        None when the target escapes ``base``, does not exist, is not a
        regular file (e.g. a directory), or is not a valid path at all.
    """
    root = base.resolve()
    try:
        candidate = (root / relative).resolve()
        # Raises ValueError when candidate lies outside root.
        candidate.relative_to(root)
        is_file = candidate.is_file()
    except (OSError, ValueError):
        # ValueError also covers malformed input such as embedded NUL bytes.
        return None
    return candidate if is_file else None


async def serve_css(path: str):
    """Serve a stylesheet from FRONTEND_DIR/css.

    Raises:
        HTTPException(404): the file is missing, is not a regular file, or
            lies outside the css directory.
    """
    file_path = resolve_static_file(FRONTEND_DIR / "css", path)
    if file_path is None:
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="text/css")


async def serve_js(path: str):
    """Serve a JavaScript file from FRONTEND_DIR/js.

    Raises:
        HTTPException(404): the file is missing, is not a regular file, or
            lies outside the js directory.
    """
    file_path = resolve_static_file(FRONTEND_DIR / "js", path)
    if file_path is None:
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="application/javascript")
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 7860.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")