"""Cattle breed classification API backed by an EfficientNetV2-S checkpoint.

Exposes a status route, a ``/predict`` route that classifies an uploaded image
and writes an annotated PNG plus a one-row CSV into ``OUTPUT_DIR``, and a
static mount that serves those files under ``/outputs``.
"""

import os

# Both variables must be set before matplotlib is imported.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"  # fix matplotlib warnings
# Render off-screen; a server process has no display to attach a GUI backend to.
os.environ.setdefault("MPLBACKEND", "Agg")

import io
import random
import re
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

# ----------------------------
# Ensure outputs directory
# ----------------------------
OUTPUT_DIR = "/tmp/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ----------------------------
# Setup
# ----------------------------
app = FastAPI()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Class names; the position in this list is the class index used by the model head.
breeds = [
    "Alambadi", "Amritmahal", "Ayrshire", "Banni", "Bargur", "Bhadawari",
    "Brown_Swiss", "Dangi", "Deoni", "Gir", "Guernsey", "Hallikar", "Hariana",
    "Holstein_Friesian", "Jaffrabadi", "Jersey", "Kangayam", "Kankrej",
    "Kasargod", "Kenkatha", "Kherigarh", "Khillari", "Krishna_Valley",
    "Malnad_gidda", "Mehsana", "Murrah", "Nagori", "Nagpuri", "Nili_Ravi",
    "Nimari", "Ongole", "Pulikulam", "Rathi", "Red_Dane", "Red_Sindhi",
    "Sahiwal", "Surti", "Tharparkar", "Toda", "Umblachery", "Vechur",
]
num_classes = len(breeds)

# ----------------------------
# Load EfficientNetV2-S Model
# ----------------------------
model = models.efficientnet_v2_s(weights=None)
# Replace the final linear layer so the output size matches the breed list.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
# NOTE(review): torch.load unpickles the checkpoint, so only trusted files must
# be loaded here; consider weights_only=True if the installed torch supports it.
state = torch.load("best_efficientnetv2_s_cow.pth", map_location=device)
model.load_state_dict(state)
model.to(device).eval()

# ----------------------------
# Preprocessing (EffNetV2-S uses 384x384)
# ----------------------------
val_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ----------------------------
# Helpers
# ----------------------------
# Anything outside this set is replaced so names are safe in paths and URLs.
_UNSAFE_CHARS = re.compile(r"[^A-Za-z0-9._-]+")


def _safe_name(filename: Optional[str]) -> str:
    """Reduce a client-supplied filename to one path- and URL-safe component.

    Directory parts are dropped, characters outside ``[A-Za-z0-9._-]`` become
    ``_``, and leading dots are stripped. Returns ``"upload"`` when nothing
    usable remains (including when *filename* is ``None`` or empty).
    """
    # Treat backslashes as separators too, so Windows-style paths lose their
    # directory part regardless of the server's own path convention.
    leaf = os.path.basename((filename or "").replace("\\", "/"))
    cleaned = _UNSAFE_CHARS.sub("_", leaf).lstrip(".")
    return cleaned or "upload"


def _classify(img):
    """Return ``(breed_name, confidence_percent)`` for an RGB PIL image."""
    batch = val_transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        top_prob, top_idx = torch.max(probs, dim=0)
    return breeds[int(top_idx.item())], float(top_prob.item()) * 100.0


def _save_annotated_image(img, title, path):
    """Write *img* with *title* drawn above it to *path*."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
        # Save through the figure handle rather than pyplot's "current figure".
        fig.savefig(path, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)


def _save_prediction_csv(breed, conf, original_filename, path):
    """Write a one-row CSV describing the prediction to *path*."""
    row = {
        "breed": breed,
        "confidence_percent": f"{conf:.2f}%",
        "filename": original_filename,
    }
    pd.DataFrame([row]).to_csv(path, index=False)


# ----------------------------
# Routes
# ----------------------------
@app.get("/")
def root():
    """Return a static status message."""
    return {"message": "GoVed AI EfficientNetV2-S Cattle Breed API is running!"}


@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded image and persist an annotated PNG and a CSV.

    Args:
        request: Incoming request; its base URL is used to build result links.
        file: Uploaded image.

    Returns:
        JSON with ``breed``, ``confidence`` (percent, two decimals),
        ``annotated_image_url`` and ``csv_file_url``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # Pillow reports unreadable or truncated images as OSError subclasses.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    predicted_breed, conf = _classify(img)

    # The client-supplied name is untrusted: never let it steer the output path.
    safe_name = _safe_name(file.filename)

    # Save annotated image
    annotated_path = os.path.join(
        OUTPUT_DIR,
        f"{predicted_breed}_{conf:.2f}pct_{safe_name}.png",
    )
    _save_annotated_image(img, f"{predicted_breed} ({conf:.2f}%)", annotated_path)

    # Save CSV (the original filename is kept as data inside the file only)
    csv_path = os.path.join(OUTPUT_DIR, f"{safe_name}_prediction.csv")
    _save_prediction_csv(predicted_breed, conf, file.filename, csv_path)

    # Generate URLs
    base_url = str(request.base_url).rstrip("/")
    annotated_url = f"{base_url}/outputs/{os.path.basename(annotated_path)}"
    csv_url = f"{base_url}/outputs/{os.path.basename(csv_path)}"

    return JSONResponse(content={
        "breed": predicted_breed,
        "confidence": round(conf, 2),
        "annotated_image_url": annotated_url,
        "csv_file_url": csv_url,
    })


# ----------------------------
# Serve static outputs
# ----------------------------
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")