Spaces:
Runtime error
Runtime error
| import os | |
| os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib" # fix matplotlib warnings | |
| import io | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from fastapi import FastAPI, File, UploadFile, Request | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import random | |
# ----------------------------
# Ensure outputs directory
# ----------------------------
# Directory for the annotated images and per-prediction CSV files; its contents
# are served under /outputs. Defaults to /tmp/outputs and can be overridden with
# the OUTPUT_DIR environment variable (an empty value falls back to the default).
OUTPUT_DIR = os.environ.get("OUTPUT_DIR") or "/tmp/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ----------------------------
# Setup
# ----------------------------
# The FastAPI application instance for this service.
app: FastAPI = FastAPI()
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Class names, indexed by the model's output position. The names are in sorted
# order, presumably matching the label order used at training time (e.g.
# torchvision ImageFolder) -- confirm before editing: reordering or renaming
# entries silently mislabels predictions. No name contains whitespace, so
# splitting the block below yields exactly one entry per breed.
breeds = """
    Alambadi Amritmahal Ayrshire Banni Bargur Bhadawari Brown_Swiss
    Dangi Deoni Gir Guernsey Hallikar Hariana Holstein_Friesian
    Jaffrabadi Jersey Kangayam Kankrej Kasargod Kenkatha Kherigarh
    Khillari Krishna_Valley Malnad_gidda Mehsana Murrah Nagori Nagpuri
    Nili_Ravi Nimari Ongole Pulikulam Rathi Red_Dane Red_Sindhi
    Sahiwal Surti Tharparkar Toda Umblachery Vechur
""".split()
num_classes = len(breeds)
# ----------------------------
# Load EfficientNetV2-S Model
# ----------------------------
# Checkpoint path, resolved relative to the current working directory unless
# absolute. Override with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH") or "best_efficientnetv2_s_cow.pth"

# Architecture only (weights=None); the trained weights come from the checkpoint.
model = models.efficientnet_v2_s(weights=None)
# Swap the final linear layer so its output width equals the number of breeds.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code. This is only the default from
# torch 2.6 onwards, so it is passed explicitly.
state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state)
model.to(device).eval()  # inference mode: disables dropout, uses BN running stats
# ----------------------------
# Preprocessing (EffNetV2-S uses 384x384)
# ----------------------------
# Standard ImageNet per-channel mean/std used for normalisation.
# NOTE(review): these steps should mirror the training-time validation
# transform -- confirm against the training script.
_INPUT_SIZE = (384, 384)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

val_transform = transforms.Compose(
    [
        # Resize to a fixed square; the aspect ratio is not preserved.
        transforms.Resize(_INPUT_SIZE),
        # PIL image -> float CHW tensor scaled to [0, 1].
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# ----------------------------
# Routes
# ----------------------------
@app.get("/")
def root():
    """Return a static message confirming that the API process is up.

    Without the route decorator, FastAPI never exposes this function.
    """
    return {"message": "GoVed AI EfficientNetV2-S Cattle Breed API is running!"}
def _safe_upload_name(filename):
    """Return a filesystem- and URL-safe version of a client-supplied file name.

    ``UploadFile.filename`` comes from the client and may be ``None``, empty,
    or contain path components (``../``, absolute paths, Windows ``\\``
    separators). Only the final component is kept. Characters outside
    ``[A-Za-z0-9._-]`` become ``_`` and leading dots are stripped, so the
    result can be joined onto OUTPUT_DIR without escaping it.
    """
    import re

    base = os.path.basename((filename or "").replace("\\", "/"))
    base = re.sub(r"[^A-Za-z0-9._-]", "_", base).lstrip(".")
    return base or "upload"


def _classify(img):
    """Classify an RGB PIL image and return ``(breed_name, confidence_percent)``.

    Uses the module-level ``model``, ``val_transform``, ``device`` and
    ``breeds``. The confidence is the top softmax probability scaled to 0-100.
    """
    input_tensor = val_transform(img).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    top_prob, top_idx = torch.max(probs, dim=0)
    return breeds[int(top_idx.item())], float(top_prob.item()) * 100.0


def _save_annotated_image(img, title, path):
    """Write *img* with *title* above it to *path* (format taken from its extension)."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
        fig.savefig(path, bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps a reference to every open figure. Close this one even
        # if saving fails, so figures don't pile up across requests.
        plt.close(fig)


def _save_prediction_csv(breed, conf, filename, path):
    """Write a one-row CSV describing the prediction for *filename* to *path*."""
    pd.DataFrame([{
        "breed": breed,
        "confidence_percent": f"{conf:.2f}%",
        "filename": filename,
    }]).to_csv(path, index=False)


# NOTE(review): endpoint path "/predict" is inferred from the function name --
# confirm against existing clients.
@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Classify the cattle breed in an uploaded image and save the outputs.

    Args:
        request: The incoming request. Its ``base_url`` is used to build
            absolute URLs for the saved outputs.
        file: The uploaded image, in any format Pillow can open.

    Returns:
        A JSONResponse with ``breed``, ``confidence`` (percent, rounded to 2
        decimals), ``annotated_image_url`` and ``csv_file_url``. If the upload
        cannot be decoded as an image, a 400 response with an ``error``
        message is returned instead.
    """
    import uuid

    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # An empty, corrupt, non-image or oversized upload is a client error,
        # not a server crash.
        return JSONResponse(
            status_code=400,
            content={"error": "Uploaded file is not a readable image."},
        )

    # NOTE: inference and plotting run synchronously on the event loop. This
    # blocks other requests while they run, but keeps pyplot (which is not
    # thread-safe) confined to one thread.
    predicted_breed, conf = _classify(img)

    # A random per-request prefix stops uploads that share a file name from
    # overwriting each other's outputs. The client name is sanitised because
    # it is joined onto OUTPUT_DIR and embedded in the returned URLs.
    stem = f"{uuid.uuid4().hex[:12]}_{_safe_upload_name(file.filename)}"
    annotated_path = os.path.join(
        OUTPUT_DIR, f"{predicted_breed}_{conf:.2f}pct_{stem}.png"
    )
    csv_path = os.path.join(OUTPUT_DIR, f"{stem}_prediction.csv")

    _save_annotated_image(img, f"{predicted_breed} ({conf:.2f}%)", annotated_path)
    _save_prediction_csv(predicted_breed, conf, file.filename, csv_path)

    # Files in OUTPUT_DIR are served by the /outputs static mount.
    base_url = str(request.base_url).rstrip("/")
    return JSONResponse(content={
        "breed": predicted_breed,
        "confidence": round(conf, 2),
        "annotated_image_url": f"{base_url}/outputs/{os.path.basename(annotated_path)}",
        "csv_file_url": f"{base_url}/outputs/{os.path.basename(csv_path)}",
    })
# ----------------------------
# Serve static outputs
# ----------------------------
# Expose every file written to OUTPUT_DIR as /outputs/<name>; the URLs built
# in predict() point here.
_outputs_files = StaticFiles(directory=OUTPUT_DIR)
app.mount("/outputs", _outputs_files, name="outputs")