# catanddogapi/main.py
# Author: okeowo1014 -- commit 8445c89 ("Update main.py")
import asyncio
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import models, transforms
app = FastAPI()

# VGG16 architecture with its final classifier layer swapped for a
# single-logit head (binary cat/dog output). The network is constructed
# without arguments; all of its weights come from the fine-tuned checkpoint
# loaded below (load_state_dict is strict by default, so every key must match).
model = models.vgg16()
num_features_in = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(num_features_in, 1)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# host; the model itself is never moved off the CPU in this file.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
# The path is resolved relative to the process working directory.
model.load_state_dict(torch.load("cat_dog_classifier.pt", map_location="cpu"))
model.eval()  # switch to inference behaviour (e.g. dropout disabled)
# Standard ImageNet per-channel normalization statistics (RGB order).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def preprocess_image(image):
    """Turn a PIL image into a normalized single-image batch for the model.

    The image is converted to RGB, resized to exactly 224x224 (aspect ratio
    is not preserved), scaled to [0, 1] by ``ToTensor`` and normalized with
    ImageNet statistics.

    Args:
        image: A PIL ``Image`` of any mode.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    img_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    # Grayscale ("L"), palette ("P") or RGBA uploads would otherwise yield
    # 1 or 4 channels, which the 3-channel Normalize and VGG16 reject.
    rgb_image = image.convert("RGB")
    return img_transform(rgb_image).unsqueeze(0)  # add a batch dimension
def _decode_image(contents):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: If the bytes are not a readable image (Pillow's
            ``UnidentifiedImageError`` subclasses ``OSError``) or are truncated.
        Image.DecompressionBombError: If the image exceeds Pillow's pixel limit.
    """
    with Image.open(io.BytesIO(contents)) as img:
        # Image.open is lazy; convert() forces the decode so corrupt data
        # fails here, and normalizes grayscale/palette/RGBA to 3 channels.
        return img.convert("RGB")


def _predict_dog_probability(image):
    """Run the classifier on *image* and return its sigmoid score as a float."""
    image_tensor = preprocess_image(image)
    # Grad mode is thread-local, so no_grad must be entered in the thread
    # that actually runs the forward pass.
    with torch.no_grad():
        output = model(image_tensor)
    return torch.sigmoid(output.squeeze()).item()


@app.post("/predict/")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image as a cat or a dog.

    Returns:
        ``{"class": "Dog"}`` when the model's sigmoid score exceeds 0.5,
        otherwise ``{"class": "Cat"}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
            Failures during inference are not masked as client errors and
            propagate as server errors.
    """
    contents = await file.read()
    try:
        image = _decode_image(contents)
    except (OSError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # VGG16 inference is CPU-bound; run it off the event loop so other
    # requests are not blocked while the model runs.
    prediction = await asyncio.to_thread(_predict_dog_probability, image)
    predicted_class = "Dog" if prediction > 0.5 else "Cat"
    return {"class": predicted_class}