File size: 1,298 Bytes
c6896d2
 
 
 
 
 
 
 
8445c89
c6896d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from fastapi import FastAPI, File, UploadFile, HTTPException
from torchvision import models, transforms
from PIL import Image
import torch
import io

app = FastAPI()

# Build the VGG16 architecture (no weights are requested here), swap the last
# classifier layer for a single-logit head, and then restore the trained
# cat/dog weights from disk. The head must be replaced *before* loading so the
# checkpoint's tensor shapes match the module.
model = models.vgg16()
num_features_in = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(num_features_in, 1)
# map_location='cpu': this file never moves the model to a GPU, so load the
# checkpoint tensors onto the CPU. Without it, a checkpoint that was saved
# from a CUDA device cannot be deserialized on a CPU-only host.
model.load_state_dict(torch.load('cat_dog_classifier.pt', map_location='cpu'))
model.eval()  # evaluation mode for inference (e.g. dropout is disabled)

def preprocess_image(image: "Image.Image") -> "torch.Tensor":
    """Turn a PIL image into a normalized, batched tensor for the model.

    The image is forced to RGB, resized to 224x224, converted to a tensor and
    normalized per channel.

    Args:
        image: PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.
    """
    # Normalize below is given exactly three per-channel statistics, so an
    # input that is not 3-channel (RGBA/CMYK, grayscale, palette) would fail
    # there. Converting up front makes every decodable image acceptable.
    if image.mode != "RGB":
        image = image.convert("RGB")

    img_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel mean/std (the commonly used ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the leading batch dimension the model expects.
    return img_transform(image).unsqueeze(0)

@app.post("/predict/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Dog" or "Cat".

    Args:
        file: The uploaded image file from the multipart request.

    Returns:
        ``{"class": "Dog"}`` when the sigmoid of the model's single logit is
        greater than 0.5, otherwise ``{"class": "Cat"}``.

    Raises:
        HTTPException: Status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Only a bad upload is the client's fault, so only decoding maps to 400.
    # Failures in preprocessing or inference are server-side problems and are
    # left to propagate (the framework answers with a 500) instead of being
    # mislabelled as a bad request with internal error text in the response.
    try:
        image = Image.open(io.BytesIO(contents))
        # convert() forces the full decode (Image.open is lazy), so truncated
        # files are rejected here; it also guarantees the 3 channels that the
        # normalization step requires.
        image = image.convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    image_tensor = preprocess_image(image)
    with torch.no_grad():  # no gradients needed for inference
        output = model(image_tensor)
        # The model emits one logit; sigmoid maps it to a probability.
        prediction = torch.sigmoid(output.squeeze()).item()
    predicted_class = "Dog" if prediction > 0.5 else "Cat"
    return {"class": predicted_class}