Abdul Rafay commited on
Commit
b19084f
·
1 Parent(s): f83b93d

uploaded model to huggingface

Browse files
Files changed (1) hide show
  1. main.py +27 -18
main.py CHANGED
@@ -2,14 +2,10 @@ import io
2
  import torch
3
  import torchvision.transforms as transforms
4
  from fastapi import FastAPI, UploadFile, File
5
- from fastapi.responses import JSONResponse, FileResponse
6
- from fastapi.staticfiles import StaticFiles
7
- from PIL import Image, ImageOps
8
  from model import Model
9
- import os
10
-
11
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12
- MODEL_PATH = os.path.join(BASE_DIR, "model.pt")
13
 
14
  transform = transforms.Compose([
15
  transforms.Resize((28, 28)),
@@ -30,20 +26,33 @@ app.add_middleware(
30
  allow_headers=["*"],
31
  )
32
 
33
- @app.post("/predict")
34
- async def predict(file: UploadFile = File(...)):
 
 
 
 
 
 
 
35
 
36
- # -----------------------
37
- # Load Model
38
- # -----------------------
39
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
- try:
41
- model = Model()
42
- model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
43
- except RuntimeError:
44
- model = torch.load(MODEL_PATH, map_location=device)
 
45
  model.eval()
46
 
 
 
 
 
 
47
  image_bytes = await file.read()
48
  image = Image.open(io.BytesIO(image_bytes)).convert("L")
49
  image = transform(image).unsqueeze(0)
 
2
  import torch
3
  import torchvision.transforms as transforms
4
  from fastapi import FastAPI, UploadFile, File
5
+ from fastapi.responses import JSONResponse
6
+ from PIL import Image
 
7
  from model import Model
8
+ from huggingface_hub import hf_hub_download
 
 
 
9
 
10
  transform = transforms.Compose([
11
  transforms.Resize((28, 28)),
 
26
  allow_headers=["*"],
27
  )
28
 
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ model = None
32
+
33
+ @app.on_event("startup")
34
+ def load_model():
35
+ global model
36
+
37
+ print("Downloading model from Hugging Face...")
38
 
39
+ model_path = hf_hub_download(
40
+ repo_id="abdurafay19/Digit-Classifier",
41
+ filename="model.pt"
42
+ )
43
+
44
+ print("Loading model...")
45
+
46
+ model = Model()
47
+ model.load_state_dict(torch.load(model_path, map_location=device))
48
+ model.to(device)
49
  model.eval()
50
 
51
+ print("Model loaded successfully!")
52
+
53
+ @app.post("/predict")
54
+ async def predict(file: UploadFile = File(...)):
55
+
56
  image_bytes = await file.read()
57
  image = Image.open(io.BytesIO(image_bytes)).convert("L")
58
  image = transform(image).unsqueeze(0)