import asyncio
import os
import shutil
from contextlib import asynccontextmanager

import torch as t
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

from backend.model import EMNIST_VGG
from backend.utils import predict_image
# ----------------------------
# LIFESPAN (startup/shutdown hooks; mainly here to quiet noisy tracebacks under --reload)
@asynccontextmanager  # FastAPI expects the lifespan callable to be an async context manager
async def lifespan(app: FastAPI):
    try:
        yield
    except asyncio.CancelledError:
        print("Code likely edited, restarting server...")
        return  # suppress the traceback --reload would otherwise print
    except Exception:
        # real startup/shutdown failure: let it propagate
        raise
# LIFESPAN END
# ----------------------------
app = FastAPI(lifespan=lifespan)
device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(f"Server running on: {device}")
# ----------------------------
# MODEL DOWNLOAD
# 1. Setup paths
REPO_ID = "compendious/EMNIST-OCR-WEIGHTS"
FILENAME = "EMNIST_CNN.pth"
# NOTE: If I ever make this repo private, I need to pass an auth token to hf_hub_download calls.
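# A minimal sketch of that authenticated variant, assuming the token is exported as
# HF_TOKEN (a hypothetical variable name; nothing in this app reads it today):
#   hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=os.environ["HF_TOKEN"])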
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, FILENAME)
if not os.path.exists(MODEL_PATH):
    print(f"Model weights not found. Downloading from {REPO_ID}...")
    # hf_hub_download fetches the file and returns its local (cache) path
    downloaded_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
    shutil.copy(downloaded_path, MODEL_PATH)
    print(f"Weights secured at {MODEL_PATH}")
# MODEL DOWNLOAD END
# ----------------------------
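# Note: hf_hub_download keeps its own copy in the Hugging Face cache (~/.cache/huggingface
# by default); the shutil.copy above just pins a stable path next to this script.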
# Instantiate the empty architecture first
model = EMNIST_VGG(num_classes=62).to(device)
# Load the weights safely (weights_only=True refuses to unpickle arbitrary objects)
# Note: if this fails, the file is still in the old "full model" format.
# If so, re-run the training script to generate a clean state_dict.
try:
    model.load_state_dict(t.load(MODEL_PATH, map_location=device, weights_only=True))
except Exception as e:
    print("State dict load failed, trying legacy full-model load:", e)
    model = t.load(MODEL_PATH, map_location=device, weights_only=False)
model.eval()
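# eval() disables dropout and puts batch-norm layers in inference mode; without it,
# predictions would vary between identical requests.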
# Static assets resolve relative to the working directory, so run from the repo root
app.mount("/static", StaticFiles(directory="frontend"), name="static")
@app.get("/")
async def read_index():
    path = os.path.join("frontend", "index.html")
    return FileResponse(path)
class PredictRequest(BaseModel):
    image: list[float]  # flat 28*28 pixel array
    k: int = 10  # number of top predictions to return
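# Example request body (784 floats total):
#   {"image": [0.0, 0.0, ..., 0.13, 0.87], "k": 5}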
@app.post("/predict")
def predict(req: PredictRequest):
    print(f"Predicting top-{req.k}...")
    # k defaults to 10, so existing callers keep the old top-10 behavior
    return predict_image(req.image, model, device, top_k=req.k)
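# ----------------------------
# Local usage sketch (assumes this file lives at backend/main.py, matching the
# backend.* imports above; adjust the module path if the layout differs):
#   uvicorn backend.main:app --reload
#   curl -X POST http://127.0.0.1:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"image": [0.0, 0.0, ...], "k": 3}'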