File size: 2,781 Bytes
a5a80dd
 
 
e6f5f64
a5a80dd
 
 
 
 
 
 
 
 
 
 
72ebe85
 
 
 
a5a80dd
 
 
 
 
 
 
 
 
 
 
 
72ebe85
 
a5a80dd
 
 
 
 
 
72ebe85
 
 
 
 
6042735
72ebe85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5a80dd
 
 
 
 
 
 
72ebe85
a5a80dd
72ebe85
 
a5a80dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch as t
import asyncio
import os
import shutil

from backend.utils import predict_image
from backend.model import EMNIST_VGG

from pydantic import BaseModel

from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager

from huggingface_hub import hf_hub_download

# ----------------------------
# LIFESPAN stuff (basically for startup/shutdown since we run with --reload)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler whose only job is to keep ``--reload`` restarts quiet.

    No startup or shutdown work is done here; the handler just yields.
    An ``asyncio.CancelledError`` raised at the yield point is reported with
    a one-line message and swallowed instead of surfacing as a traceback.
    Any other exception propagates to the caller unchanged.
    """
    try:
        yield
    except asyncio.CancelledError:
        # Letting the generator finish here is what suppresses the cancellation
        # (and with it the annoying traceback on --reload).
        print("Code likely edited, restarting server...")

# LIFESPAN END
# ----------------------------

# Create the FastAPI application, wiring in the `lifespan` context manager.
app = FastAPI(lifespan=lifespan)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
print(f"Server running on: {device}")


# ----------------------------
# MODEL DOWNLOAD

# Where the weights live on the Hugging Face Hub.
REPO_ID = "compendious/EMNIST-OCR-WEIGHTS"
FILENAME = "EMNIST_CNN.pth"

# NOTE: If I ever make this repo private, I need to add authentication tokens to hf_hub_download calls.

# The weights are kept next to this file so later startups can skip the download.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, FILENAME)

if not os.path.exists(MODEL_PATH):
    print(f"Model weights not found. Downloading from {REPO_ID}...")
    # This downloads the file and returns the local path
    downloaded_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

    # Copy to a temporary sibling and rename it into place. An interrupted copy
    # must never leave a truncated file at MODEL_PATH: the existence check
    # above would then skip the download on every later start.
    _partial_path = MODEL_PATH + ".part"
    try:
        shutil.copy(downloaded_path, _partial_path)
        os.replace(_partial_path, MODEL_PATH)
    finally:
        # Only present if the copy or the rename failed part-way.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
    print(f"Weights secured at {MODEL_PATH}")

# MODEL DOWNLOAD END
# ----------------------------


# Instantiate the empty architecture first
model = EMNIST_VGG(num_classes=62).to(device)

# Read the checkpoint on its own, so that only a failure to *deserialize* the
# file triggers the legacy fallback. A key/shape mismatch in load_state_dict
# must not end up there: the fallback would hand back the same state dict,
# and calling .eval() on a dict fails with an unrelated AttributeError.
# Note: If the safe load fails, the file is most likely still in the old
# "full model" format. If so, re-run your training script to generate a clean state_dict.
try:
    checkpoint = t.load(MODEL_PATH, map_location=device, weights_only=True)
except Exception as e:
    print("State dict load failed, trying legacy full-model load:", e)
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file. Acceptable only while the weights file is trusted.
    checkpoint = t.load(MODEL_PATH, map_location=device, weights_only=False)

if isinstance(checkpoint, t.nn.Module):
    # Legacy format: the file holds the whole pickled model.
    model = checkpoint
else:
    # Current format: a state dict for the architecture built above. A mismatch
    # raises here with PyTorch's own missing/unexpected-key message.
    model.load_state_dict(checkpoint)

model.eval()

# Expose the contents of the "frontend" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="frontend"), name="static")

@app.get("/")
async def read_index():
    # Landing page: GET / answers with the file frontend/index.html.
    return FileResponse(os.path.join("frontend", "index.html"))

# Request body accepted by the POST /predict endpoint.
# NOTE(review): nothing here enforces the 28*28 length of `image` or bounds `k`;
# presumably predict_image copes with out-of-range values — confirm.
class PredictRequest(BaseModel):
    image: list[float]  # flat 28*28 array
    k: int = 10  # number of top predictions to return

@app.post("/predict")
def predict(req: PredictRequest):
    # Passes the submitted pixels to predict_image and returns whatever it produces.
    # The number of predictions comes from the request itself (k defaults to 10).
    # NOTE(review): the "+2" suffix in the log line looks like leftover debug output.
    print(f"Predicting... +{1+1}")
    pixels, top_k = req.image, req.k
    return predict_image(pixels, model, device, top_k=top_k)