File size: 5,107 Bytes
e7accb7
 
 
 
 
 
 
 
 
 
5df8466
 
 
e7accb7
 
 
 
 
 
 
 
 
 
5df8466
e7accb7
 
 
 
 
5df8466
 
 
e7accb7
5df8466
e7accb7
5df8466
e7accb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5df8466
e7accb7
5df8466
 
 
 
 
e7accb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5df8466
e7accb7
 
 
 
d6dc30c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
import torch.nn as nn
from torchvision import models
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
import torchvision.transforms as transforms
import io
import torch.nn.functional as F

# Import the class names and species info data.
# These files must be in the same directory as this script.
from class_names import CLASS_NAMES
from species_info import SPECIES_INFO_DATA

# ==============================
# FASTAPI APP INITIALIZATION
# ==============================
app = FastAPI()

# ==============================
# DEVICE CONFIGURATION
# ==============================
# Use CUDA whenever PyTorch reports a usable GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ==============================
# MODEL PATH SETUP
# ==============================
# The weights file lives next to this script. The path is resolved against this
# file's own directory (not the current working directory), so the lookup works
# regardless of where the server process is started from.
MODEL_FILE = "best_fine_tuned_model.pth"
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    MODEL_FILE,
)

# Size of the classifier head: one output per entry in CLASS_NAMES.
NUM_CLASSES = len(CLASS_NAMES)

# ==============================
# LOAD MODEL FUNCTION
# ==============================
def load_model():
    """Build a ResNet-50 classifier and load the fine-tuned weights from MODEL_PATH.

    The final fully connected layer is replaced by one sized to NUM_CLASSES before
    the checkpoint is applied, so the checkpoint must come from a model with the
    same head.

    Returns:
        The model in eval mode on DEVICE, or None if the weights file is missing
        or any step of loading fails (the error is printed, not raised).
    """
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model file not found. Expected path: {MODEL_PATH}")
        return None

    try:
        # Architecture only; the trained weights come from the checkpoint below.
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, NUM_CLASSES)

        # Deserialize onto the CPU first so a checkpoint saved on a GPU also loads
        # on CPU-only hosts; the whole model is moved to DEVICE afterwards.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(MODEL_PATH, map_location="cpu")

        # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
        # "module.". Strip only that leading prefix; a plain str.replace would also
        # mangle any key that contains "module." somewhere else.
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        net.load_state_dict(state_dict)
        net.to(DEVICE)
        net.eval()
        print(f"✅ Model loaded successfully from {MODEL_PATH}")
        return net
    except Exception as e:
        # Deliberately best-effort: a failed load returns None instead of raising.
        print(f"❌ An error occurred during model loading: {e}")
        return None

# ==============================
# INITIALIZE MODEL ON STARTUP
# ==============================
# Loaded once at import time. It is None when loading failed; the endpoints
# check for that.
model = load_model()

# ==============================
# IMAGE PREPROCESSING
# ==============================
# Per-channel RGB normalization statistics. These match the commonly used
# ImageNet values; presumably the same ones used during fine-tuning (confirm).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to a fixed 224x224, convert to a float tensor, then normalize per channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# ==============================
# HEALTH CHECK ENDPOINT
# ==============================
@app.get("/api/health")
def health_check():
    """Report liveness, whether the classifier loaded at startup, and the compute device."""
    model_ready = model is not None
    return dict(status="ok", model_loaded=model_ready, device=DEVICE)

# ==============================
# PREDICTION ENDPOINT
# ==============================
def _decode_image(contents):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as e:
        # PIL reports unreadable or truncated data with OSError (its
        # UnidentifiedImageError is a subclass). That is a bad request from
        # the client, not a server fault.
        print(f"Invalid image upload: {e}")
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from e


@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return ranked predictions plus species info.

    Returns:
        A dict with "top_prediction", "top_5" (up to five {"label", "confidence"}
        entries, highest confidence first), "heatmap_data" (always None here) and
        "animal_info" (looked up by the top label, or an "N/A" placeholder).

    Raises:
        HTTPException: 503 if the model did not load, 400 if the upload is not a
            decodable image, 500 for any other failure during prediction.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded on server. Please check server logs.")

    try:
        contents = await file.read()
        image = _decode_image(contents)
        tensor = transform(image).unsqueeze(0).to(DEVICE)

        # NOTE(review): inference runs synchronously inside an async endpoint,
        # so it blocks the event loop while it runs.
        with torch.no_grad():
            probabilities = F.softmax(model(tensor), dim=1)
            # topk raises if k exceeds the number of classes, so cap it.
            k = min(5, probabilities.size(1))
            top_probs, top_indices = torch.topk(probabilities, k)

        # Batch size is 1, so row 0 holds this image's ranked results.
        top_5_predictions = [
            {"label": CLASS_NAMES[idx], "confidence": prob}
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
        ]
        top_prediction = top_5_predictions[0]

        # Detailed species information for the top label, if available.
        animal_info = SPECIES_INFO_DATA.get(top_prediction['label'], None)
        if not animal_info:
            print(f"⚠️ Species info not found for: {top_prediction['label']}")
            animal_info = {
                "species": "N/A", "kingdom": "N/A", "class": "N/A", "subclass": "N/A",
                "habitat": "N/A", "diet": "N/A", "lifespan": "N/A", "fact": "No additional information available."
            }

        return {
            "top_prediction": top_prediction,
            "top_5": top_5_predictions,
            "heatmap_data": None,
            "animal_info": animal_info
        }

    except HTTPException:
        # Client-facing errors raised above (e.g. an undecodable upload) pass
        # through unchanged instead of being rewrapped as a 500.
        raise
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e