# NOTE(review): removed web-scrape residue that was pasted above this module
# (Hugging Face Spaces status banner, file-size line, commit hashes, and a
# line-number gutter). It was not Python and broke the file at import time.
import os
import torch
import torch.nn as nn
from torchvision import models
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
import torchvision.transforms as transforms
import io
import torch.nn.functional as F
# Import the class names and species info data.
# These files must be in the same directory as this script.
from class_names import CLASS_NAMES
from species_info import SPECIES_INFO_DATA
# ==============================
# FASTAPI APP INITIALIZATION
# ==============================
app = FastAPI()
# ==============================
# DEVICE CONFIGURATION
# ==============================
# Prioritize GPU (cuda) if available, otherwise use CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==============================
# MODEL PATH SETUP
# ==============================
# CORRECTED PATH: The model file is in the root directory,
# so we only need its filename. This ensures the app can find it
# whether it's run locally or on Hugging Face Spaces.
MODEL_FILE = "best_fine_tuned_model.pth"
# Resolve relative to this script's own directory so loading works
# regardless of the process's current working directory.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), MODEL_FILE)
# The number of classes must match the length of your CLASS_NAMES list.
NUM_CLASSES = len(CLASS_NAMES)
# ==============================
# LOAD MODEL FUNCTION
# ==============================
def load_model():
    """Load the fine-tuned ResNet-50 classifier from MODEL_PATH.

    Returns:
        The model in eval mode on DEVICE, or None when the weights file is
        missing or loading fails (callers must handle the None case).
    """
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model file not found. Expected path: {MODEL_PATH}")
        return None
    try:
        # Architecture only — the trained weights come from the checkpoint below.
        net = models.resnet50(weights=None)
        # Swap the ImageNet head for one sized to our own label set.
        net.fc = nn.Linear(net.fc.in_features, NUM_CLASSES)
        # Mapping onto DEVICE keeps CPU-only hosts working
        # (e.g. the Hugging Face Spaces free tier).
        weights = torch.load(MODEL_PATH, map_location=DEVICE)
        # Strip the 'module.' prefix that DataParallel checkpoints carry.
        weights = {key.replace("module.", ""): tensor for key, tensor in weights.items()}
        net.load_state_dict(weights)
        net.to(DEVICE)
        net.eval()
        print(f"✅ Model loaded successfully from {MODEL_PATH}")
        return net
    except Exception as err:
        # Best-effort: report the failure and let the app start without a model.
        print(f"❌ An error occurred during model loading: {err}")
        return None
# ==============================
# INITIALIZE MODEL ON STARTUP
# ==============================
# Loaded once at import time; stays None on failure — endpoints check for this.
model = load_model()
# ==============================
# IMAGE PREPROCESSING
# ==============================
# Standard ImageNet pipeline: ResNet-50's 224x224 input size plus the
# ImageNet channel mean/std the backbone was pretrained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# ==============================
# HEALTH CHECK ENDPOINT
# ==============================
@app.get("/api/health")
def health_check():
    """Report service readiness: whether the model loaded and which device is in use."""
    return {"status": "ok", "model_loaded": model is not None, "device": DEVICE}
# ==============================
# PREDICTION ENDPOINT
# ==============================
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top predictions plus species info.

    Args:
        file: The uploaded image file (any format PIL can decode).

    Returns:
        dict with "top_prediction", "top_5", "heatmap_data" (always None here),
        and "animal_info" looked up from SPECIES_INFO_DATA.

    Raises:
        HTTPException: 503 if the model failed to load, 400 for an
        undecodable upload, 500 for unexpected inference errors.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded on server. Please check server logs.")
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            # A corrupt or non-image upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
        tensor = transform(image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            outputs = model(tensor)
            probabilities = F.softmax(outputs, dim=1)
        # Guard against models with fewer than 5 classes — topk(k) requires k <= classes.
        k = min(5, probabilities.size(1))
        topk_prob, topk_indices = torch.topk(probabilities, k)
        # Convert to a list of {"label", "confidence"} dicts, best first.
        top_5_predictions = [
            {
                "label": CLASS_NAMES[topk_indices[0, i].item()],
                "confidence": topk_prob[0, i].item(),
            }
            for i in range(k)
        ]
        top_prediction = top_5_predictions[0]
        # Get the detailed animal information using the top prediction label
        animal_info = SPECIES_INFO_DATA.get(top_prediction['label'], None)
        # Handle cases where species info isn't found
        if not animal_info:
            print(f"⚠️ Species info not found for: {top_prediction['label']}")
            animal_info = {
                "species": "N/A", "kingdom": "N/A", "class": "N/A", "subclass": "N/A",
                "habitat": "N/A", "diet": "N/A", "lifespan": "N/A", "fact": "No additional information available."
            }
        return {
            "top_prediction": top_prediction,
            "top_5": top_5_predictions,
            "heatmap_data": None,
            "animal_info": animal_info
        }
    except HTTPException:
        # Don't let the blanket handler below rewrite deliberate HTTP errors as 500s.
        raise
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")