# Author: aditya-sah
# Last commit: "Update app.py" (0ffa251, verified)
import os
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib" # fix matplotlib warnings
import io
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import random
# ----------------------------
# Output directory for generated files
# ----------------------------
# Annotated images and prediction CSVs are written here by /predict and served
# back through the /outputs static mount. Create it at import time so those
# writes do not fail because the directory is missing.
OUTPUT_DIR = "/tmp/outputs"
os.makedirs(name=OUTPUT_DIR, exist_ok=True)
# ----------------------------
# Setup
# ----------------------------
app = FastAPI()

# Plain string device name (not torch.device); it is passed to torch.load's
# map_location and to Module/Tensor .to() calls. Use CUDA when available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Class labels, indexed by classifier output unit. predict() maps the argmax
# index straight into this list, so the order must match the checkpoint's
# training order. The names are in sorted order, presumably the
# ImageFolder-style class ordering used at training time; confirm before editing.
breeds = """
    Alambadi Amritmahal Ayrshire Banni Bargur Bhadawari Brown_Swiss
    Dangi Deoni Gir Guernsey Hallikar Hariana Holstein_Friesian
    Jaffrabadi Jersey Kangayam Kankrej Kasargod Kenkatha Kherigarh
    Khillari Krishna_Valley Malnad_gidda Mehsana Murrah Nagori Nagpuri
    Nili_Ravi Nimari Ongole Pulikulam Rathi Red_Dane Red_Sindhi
    Sahiwal Surti Tharparkar Toda Umblachery Vechur
""".split()

# Output width of the replacement classifier layer.
num_classes = len(breeds)
# ----------------------------
# Load EfficientNetV2-S Model
# ----------------------------
# Build the architecture without pretrained weights, then swap the last layer
# of the classifier head for one sized to our label set before loading the
# fine-tuned checkpoint (whose keys must match this modified layout).
model = models.efficientnet_v2_s(weights=None)
model.classifier[1] = nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=num_classes,
)

# Path is relative to the process working directory.
state = torch.load(
    "best_efficientnetv2_s_cow.pth",
    map_location=device,
)
model.load_state_dict(state)

# nn.Module.to() moves parameters in place; eval() switches off
# training-only behaviour (dropout, batch-norm statistic updates).
model.to(device)
model.eval()
# ----------------------------
# Preprocessing (EffNetV2-S uses 384x384)
# ----------------------------
# Per-channel RGB mean / std: the standard ImageNet normalisation statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference pipeline steps:
#   1. Resize straight to 384x384. There is no crop, so the aspect ratio is
#      not preserved.
#   2. Convert to a float tensor in [0, 1].
#   3. Normalise each channel.
val_transform = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# ----------------------------
# Routes
# ----------------------------
@app.get("/")
def root():
    """Return a fixed JSON message confirming the service is running."""
    status_message = "GoVed AI EfficientNetV2-S Cattle Breed API is running!"
    return {"message": status_message}
def _safe_filename(name: "str | None") -> str:
    """Reduce a client-supplied filename to one filesystem- and URL-safe component.

    ``UploadFile.filename`` comes straight from the client. It may be ``None``,
    contain ``/`` or ``\\`` separators, ``..`` segments, or characters that
    need escaping in URLs. The sanitising steps are:

    * drop any directory part;
    * replace every character outside ``[A-Za-z0-9._-]`` with ``_``;
    * strip leading/trailing dots and underscores;
    * cap the result at 100 characters.

    Returns ``"upload"`` if nothing usable remains.
    """
    # Normalise Windows-style separators so basename() strips them as well.
    base = os.path.basename((name or "").replace("\\", "/"))
    cleaned = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "._-") else "_"
        for ch in base
    ).strip("._")
    # Length cap keeps the full output name well under filesystem name limits.
    return cleaned[:100] or "upload"


def _classify(img):
    """Run the model on one RGB PIL image.

    Returns ``(breed, confidence)`` for the top-1 class. ``confidence`` is the
    softmax probability expressed as a percentage (0-100).
    """
    input_tensor = val_transform(img).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    top_prob, top_idx = torch.max(probs, dim=0)
    return breeds[int(top_idx.item())], float(top_prob.item()) * 100.0


def _save_annotated_image(img, title: str, path: str) -> None:
    """Draw *img* with *title* above it (axes hidden) and save it to *path*."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
        # Save this specific figure rather than pyplot's implicit "current" one.
        fig.savefig(path, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure. pyplot keeps open figures referenced
        # until closed, so a failed save would otherwise leak one per request.
        plt.close(fig)


@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded cattle image and publish the results under /outputs.

    Writes an annotated PNG and a one-row CSV into OUTPUT_DIR. Returns JSON
    with the predicted breed, confidence (percent, rounded to 2 dp), and
    absolute URLs for both files. Responds with HTTP 400 if the upload cannot
    be decoded as an image.
    """
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # Empty / non-image bytes (UnidentifiedImageError) and truncated
        # files raise OSError subclasses. Report a client error instead of
        # letting it surface as a 500.
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a readable image."},
        )

    # NOTE(review): inference, plotting and disk writes below are blocking and
    # run on the event loop inside this async handler; consider offloading if
    # request concurrency matters.
    predicted_breed, conf = _classify(img)

    # Never put the raw client filename into a filesystem path.
    safe_name = _safe_filename(file.filename)

    annotated_path = os.path.join(
        OUTPUT_DIR,
        f"{predicted_breed}_{conf:.2f}pct_{safe_name}.png",
    )
    _save_annotated_image(img, f"{predicted_breed} ({conf:.2f}%)", annotated_path)

    csv_path = os.path.join(OUTPUT_DIR, f"{safe_name}_prediction.csv")
    pd.DataFrame([{
        "breed": predicted_breed,
        "confidence_percent": f"{conf:.2f}%",
        "filename": file.filename,
    }]).to_csv(csv_path, index=False)

    # The files are reachable through the /outputs static mount.
    base_url = str(request.base_url).rstrip("/")
    annotated_url = f"{base_url}/outputs/{os.path.basename(annotated_path)}"
    csv_url = f"{base_url}/outputs/{os.path.basename(csv_path)}"
    return JSONResponse(content={
        "breed": predicted_breed,
        "confidence": round(conf, 2),
        "annotated_image_url": annotated_url,
        "csv_file_url": csv_url,
    })
# ----------------------------
# Serve static outputs
# ----------------------------
# Expose the files written to OUTPUT_DIR (annotated PNGs, prediction CSVs)
# at /outputs/<filename>; the URLs returned by /predict point here.
app.mount(path="/outputs", app=StaticFiles(directory=OUTPUT_DIR), name="outputs")