import os
import torch
import torch.nn as nn
from torchvision import models
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
import torchvision.transforms as transforms
import io
import torch.nn.functional as F

# Import the class names and species info data.
# These files must be in the same directory as this script.
from class_names import CLASS_NAMES
from species_info import SPECIES_INFO_DATA

# ==============================
# FASTAPI APP INITIALIZATION
# ==============================
app = FastAPI()

# ==============================
# DEVICE CONFIGURATION
# ==============================
# Prioritize GPU (cuda) if available, otherwise use CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==============================
# MODEL PATH SETUP
# ==============================
# The model file sits next to this script. Resolve it relative to the script's
# own directory (not the process's current working directory) so it is found
# no matter where the app is launched from, locally or on Hugging Face Spaces.
MODEL_FILE = "best_fine_tuned_model.pth"
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), MODEL_FILE)

# The number of classes must match the length of your CLASS_NAMES list.
NUM_CLASSES = len(CLASS_NAMES)

# Prefix that nn.DataParallel adds to every parameter name when the wrapped
# model's state dict is saved.
_DATA_PARALLEL_PREFIX = "module."

# Maximum number of ranked predictions returned by /api/predict.
TOP_K = 5

# Placeholder sent back when SPECIES_INFO_DATA has no entry for a label.
# It is copied per request so callers can never mutate the shared template.
_UNKNOWN_SPECIES_INFO = {
    "species": "N/A",
    "kingdom": "N/A",
    "class": "N/A",
    "subclass": "N/A",
    "habitat": "N/A",
    "diet": "N/A",
    "lifespan": "N/A",
    "fact": "No additional information available.",
}


# ==============================
# LOAD MODEL FUNCTION
# ==============================
def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading 'module.' removed from each key.

    Only a *leading* prefix is stripped. Keys that merely contain 'module.'
    elsewhere (e.g. 'encoder.submodule.weight') are left intact rather than
    being corrupted.
    """
    prefix_len = len(_DATA_PARALLEL_PREFIX)
    return {
        (key[prefix_len:] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def load_model():
    """Load the fine-tuned ResNet-50 classifier from MODEL_PATH.

    Returns:
        The model on DEVICE in eval mode, or None if the weights file is
        missing or loading fails (the reason is printed). Returning None lets
        the server start anyway: /api/health then reports
        ``model_loaded: False`` and /api/predict answers 503.
    """
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model file not found. Expected path: {MODEL_PATH}")
        return None

    try:
        # Build the ResNet-50 architecture with randomly initialised weights
        # (nothing is downloaded); the fine-tuned weights are loaded below.
        model = models.resnet50(weights=None)
        # Replace the classifier head so its output size matches CLASS_NAMES.
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

        # map_location=DEVICE remaps tensors onto the current device, so a
        # checkpoint saved from a GPU still loads on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        state_dict = _strip_data_parallel_prefix(state_dict)

        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        print(f"✅ Model loaded successfully from {MODEL_PATH}")
        return model
    except Exception as e:
        print(f"❌ An error occurred during model loading: {e}")
        return None


# ==============================
# INITIALIZE MODEL ON STARTUP
# ==============================
model = load_model()

# ==============================
# IMAGE PREPROCESSING
# ==============================
# Resize to 224x224, convert to a tensor, and normalise with the per-channel
# mean/std constants below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# ==============================
# HEALTH CHECK ENDPOINT
# ==============================
@app.get("/api/health")
def health_check():
    """Report service status, whether the model loaded, and the compute device."""
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "device": DEVICE,
    }


# ==============================
# PREDICTION ENDPOINT
# ==============================
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return ranked predictions.

    Response keys:
        top_prediction: the highest-confidence ``{"label", "confidence"}`` entry.
        top_5: up to TOP_K such entries in descending confidence order (fewer
            when the model has fewer than TOP_K classes).
        heatmap_data: always None in this implementation.
        animal_info: the SPECIES_INFO_DATA entry for the top label, or an
            "N/A" placeholder when none exists.

    Raises:
        HTTPException(503): the model failed to load at startup.
        HTTPException(400): the upload could not be decoded as an image.
        HTTPException(500): any other failure while running the prediction.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded on server. Please check server logs.")

    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as e:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # unrecognised formats and OSError for truncated/corrupt data.
            # A bad upload is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        tensor = transform(image).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            outputs = model(tensor)
            # Single-image batch: keep only the first row of class probabilities.
            probabilities = F.softmax(outputs, dim=1)[0]

        # torch.topk raises when k exceeds the number of classes, so clamp it.
        k = min(TOP_K, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
        top_5_predictions = [
            {"label": CLASS_NAMES[index], "confidence": confidence}
            for confidence, index in zip(top_probs.tolist(), top_indices.tolist())
        ]
        top_prediction = top_5_predictions[0]

        # Get the detailed animal information using the top prediction label.
        animal_info = SPECIES_INFO_DATA.get(top_prediction["label"], None)
        if not animal_info:
            print(f"⚠️ Species info not found for: {top_prediction['label']}")
            animal_info = dict(_UNKNOWN_SPECIES_INFO)

        return {
            "top_prediction": top_prediction,
            "top_5": top_5_predictions,
            "heatmap_data": None,
            "animal_info": animal_info,
        }
    except HTTPException:
        # Already carries the intended status code; do not re-wrap it as a 500.
        raise
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e