# EMNIST-OCR / backend/app.py
# Uploaded with huggingface_hub (revision e6f5f64, verified)
import torch as t
import asyncio
import os
import shutil
from backend.utils import predict_image
from backend.model import EMNIST_VGG
from pydantic import BaseModel
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from huggingface_hub import hf_hub_download
# ----------------------------
# LIFESPAN stuff (basically for startup/shutdown since we run with --reload)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan hook: code before `yield` runs at startup, after it at shutdown.

    Under `--reload` the dev server cancels the lifespan task whenever the
    code changes; we swallow that CancelledError so the console stays clean.
    Any other exception (a real startup/shutdown failure) propagates as-is,
    so the redundant `except Exception: raise` clause has been removed —
    re-raising what would propagate anyway was dead code.
    """
    try:
        yield
    except asyncio.CancelledError:
        # Suppressing annoying tracebacks on --reload
        print("Code likely edited, restarting server...")
# LIFESPAN END
# ----------------------------
app = FastAPI(lifespan=lifespan)

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
print(f"Server running on: {device}")
# ----------------------------
# MODEL DOWNLOAD
# Where the weights live on the Hub, and where they are cached next to this module.
REPO_ID = "compendious/EMNIST-OCR-WEIGHTS"
FILENAME = "EMNIST_CNN.pth"
# NOTE: If this repo ever becomes private, hf_hub_download will need an auth token.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, FILENAME)

# Fetch the weights only on first run: hf_hub_download returns the path of
# the file inside the local HF cache, which we then copy beside this module.
if not os.path.exists(MODEL_PATH):
    print(f"Model weights not found. Downloading from {REPO_ID}...")
    cached_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
    shutil.copy(cached_weights, MODEL_PATH)
    print(f"Weights secured at {MODEL_PATH}")
# MODEL DOWNLOAD END
# ----------------------------
# Instantiate the empty architecture first
# (62 output classes — presumably the EMNIST "ByClass" split: 10 digits + 52 letters; confirm against training script)
model = EMNIST_VGG(num_classes=62).to(device)
# Load the weights safely
# Preferred path: a plain state_dict loaded with weights_only=True, which avoids
# executing arbitrary pickle code.
# Note: If this fails, it means your file is still the old "full model" format.
# If so, re-run your training script to generate a clean state_dict.
try:
    model.load_state_dict(t.load(MODEL_PATH, map_location=device, weights_only=True))
except Exception as e:
    print("State dict load failed, trying legacy full-model load:", e)
    # Legacy fallback: unpickles the whole model object. weights_only=False can
    # execute arbitrary code from the checkpoint — acceptable only because the
    # weights come from our own repo. map_location already places tensors on `device`.
    model = t.load(MODEL_PATH, map_location=device, weights_only=False)
# Switch to inference mode (disables dropout, uses batch-norm running stats).
model.eval()
# Serve the static frontend assets (JS/CSS/images) from the local "frontend" directory.
app.mount("/static", StaticFiles(directory="frontend"), name="static")
@app.get("/")
async def read_index():
    """Serve the frontend entry page (frontend/index.html)."""
    return FileResponse(os.path.join("frontend", "index.html"))
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    image: list[float] # flat 28*28 array
    k: int = 10 # number of top predictions to return
@app.post("/predict")
def predict(req: PredictRequest):
    """Run OCR inference on a flat 28x28 grayscale image.

    Delegates to predict_image and returns its result for the top `req.k`
    classes (k defaults to 10, preserving the existing response shape).
    A leftover per-request debug print has been removed to keep logs clean.
    """
    return predict_image(req.image, model, device, top_k=req.k)