okeowo1014 commited on
Commit
c6896d2
·
verified ·
1 Parent(s): bdca202

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -0
main.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import torch
5
+ import io
6
+
7
+ app = FastAPI()
8
+
9
+ # Load the pre-trained VGG16 model
10
+ model = models.vgg16()
11
+ num_features_in = model.classifier[6].in_features
12
+ model.classifier[6] = torch.nn.Linear(num_features_in, 1)
13
+ model.load_state_dict(torch.load('cat_dog_classifier.pt'))
14
+ model.eval()
15
+
16
+ def preprocess_image(image):
17
+ img_transform = transforms.Compose([
18
+ transforms.Resize((224, 224)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
21
+ ])
22
+ img = img_transform(image).unsqueeze(0) # Add a batch dimension
23
+ return img
24
+
25
+ @app.post("/predict/")
26
+ async def predict_image(file: UploadFile = File(...)):
27
+ try:
28
+ contents = await file.read()
29
+ image = Image.open(io.BytesIO(contents))
30
+ image_tensor = preprocess_image(image)
31
+ with torch.no_grad():
32
+ output = model(image_tensor)
33
+ prediction = torch.sigmoid(output.squeeze()).item()
34
+ predicted_class = "Dog" if prediction > 0.5 else "Cat"
35
+ return {"class": predicted_class}
36
+ except Exception as e:
37
+ raise HTTPException(status_code=400, detail=str(e))