"""FastAPI service that classifies fish species from uploaded images.

At import time this module loads the class-name list and the trained model
weights (both from paths relative to the current working directory), then
exposes two endpoints:

* ``GET /``        -- returns a static welcome message.
* ``POST /predict`` -- accepts an uploaded image and returns the result of
  ``model_utils.predict``.
"""

import json

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware

from model_utils import create_model, predict

# --- App Initialization ---
app = FastAPI(title="Fish Species Classification API")

# --- CORS Configuration ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",  # Local frontend development server
        "https://aqua-ai-omega.vercel.app",  # Deployed Vercel frontend
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Model Loading ---
CLASS_NAMES_PATH = "class_names.json"
MODEL_PATH = "models/fish_resnet18_best.pth"


def _load_class_names(path):
    """Load and return the parsed contents of the class-names JSON file.

    The result is passed to ``predict`` and its length is used as the number
    of model output classes. Presumably it is a list of class names in model
    output order -- TODO confirm against the training script.

    Raises:
        RuntimeError: if the file does not exist or contains no entries.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            class_names = json.load(f)
    except FileNotFoundError as err:
        raise RuntimeError(
            "class_names.json not found. Please run the training script to generate it."
        ) from err
    # An empty list would silently build a model with zero output classes.
    if not class_names:
        raise RuntimeError(f"{path} contains no class names.")
    return class_names


def _load_model(path, num_classes):
    """Build the model, load CPU-mapped weights from *path*, and set eval mode.

    Raises:
        RuntimeError: if the checkpoint file does not exist.
    """
    net = create_model(num_classes)
    try:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. This is only acceptable
        # because the checkpoint is a trusted local artifact. If it is a plain
        # state_dict, prefer weights_only=True.
        state_dict = torch.load(
            path, map_location=torch.device("cpu"), weights_only=False
        )
    except FileNotFoundError as err:
        raise RuntimeError(
            f"Model file not found at {path}. Please ensure it's in the correct directory."
        ) from err
    net.load_state_dict(state_dict)
    print("✅ Model loaded successfully.")
    net.eval()
    return net


CLASS_NAMES = _load_class_names(CLASS_NAMES_PATH)
NUM_CLASSES = len(CLASS_NAMES)
model = _load_model(MODEL_PATH, NUM_CLASSES)


# --- API Endpoints ---
@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Fish Species Classification API"}


@app.post("/predict")
async def predict_species(file: UploadFile = File(...)):
    """Classify an uploaded image and return the result of ``predict``.

    Raises:
        HTTPException(400): if the upload is not declared as an ``image/*`` type.
        HTTPException(500): if ``predict`` raises any exception.
    """
    # content_type is None when the client omits the part's Content-Type.
    # Treat that as "not an image" rather than crashing with AttributeError (500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")

    image_bytes = await file.read()
    try:
        # NOTE(review): predict is called synchronously here, so it runs on the
        # event-loop thread. If it is CPU-heavy, consider
        # fastapi.concurrency.run_in_threadpool.
        return predict(model, image_bytes, CLASS_NAMES)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error during prediction: {str(e)}"
        ) from e