# chinmay0805's picture
# Upload 5 files
# 45e0498 verified
# index.py
import os
import json
import pickle
import numpy as np
from typing import List
from fastapi import FastAPI, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tensorflow.keras.models import load_model
# ==========================
# CONFIG
# ==========================
# Root directory that holds one sub-folder of artifacts per species.
MODELS_BASE_DIR = "models"

# Supported species IDs; each entry must match a folder name under models/.
SPECIES_LIST = ["mackerel", "sardinella", "scomber", "skipjack", "tuna"]

# In-memory cache filled lazily by load_artifacts():
#   species_id -> (model, scaler, meta, last_seq_scaled)
ARTIFACT_CACHE = {}
def load_artifacts(species_id: str):
    """
    Return the ``(model, scaler, meta, last_seq_scaled)`` tuple for a species.

    The tuple is served from ARTIFACT_CACHE when present. On a cache miss the
    three artifact files are read from ``<MODELS_BASE_DIR>/<species_id>/``,
    the result is stored in the cache, and then returned.

    Raises:
        ValueError: if ``species_id`` is not listed in SPECIES_LIST.
        FileNotFoundError: if the model, scaler, or metadata file is missing.
    """
    cached = ARTIFACT_CACHE.get(species_id)
    if cached is not None:
        return cached

    # Validate before building any path so only known folder names are used.
    if species_id not in SPECIES_LIST:
        raise ValueError(f"Unknown species '{species_id}'. Allowed: {SPECIES_LIST}")

    base_dir = os.path.join(MODELS_BASE_DIR, species_id)
    model_path, scaler_path, meta_path = (
        os.path.join(base_dir, f"{species_id}_{suffix}")
        for suffix in ("model.h5", "scaler.pkl", "metadata.json")
    )

    # all() stops at the first missing file, in the same order as the paths above.
    required_files = (model_path, scaler_path, meta_path)
    if not all(os.path.exists(path) for path in required_files):
        raise FileNotFoundError(f"Artifacts not found for species '{species_id}' in {base_dir}")

    # compile=False: the model is loaded without its training configuration.
    model = load_model(model_path, compile=False)

    # NOTE(review): pickle executes code on load; the scaler file must be trusted.
    with open(scaler_path, "rb") as scaler_file:
        scaler = pickle.load(scaler_file)

    with open(meta_path, "r") as meta_file:
        meta = json.load(meta_file)

    # Reshape the stored window to (1, sequence_length, 2) as a single-sample batch.
    window_len = int(meta["sequence_length"])
    last_seq_scaled = np.array(meta["last_sequence"]).reshape(1, window_len, 2)

    entry = (model, scaler, meta, last_seq_scaled)
    ARTIFACT_CACHE[species_id] = entry
    return entry
# ==========================
# FASTAPI SETUP
# ==========================
# The FastAPI application object that the endpoints below register on.
app = FastAPI(title="Multi-Species Fish Migration LSTM API")

# CORS is left fully open: any origin, method, and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict in production
    allow_methods=["*"],
    allow_headers=["*"],
)
# One forecast step: a calendar month and the position predicted for it.
class PredictionPoint(BaseModel):
    # Calendar year and month the prediction applies to.
    year: int
    month: int
    # Predicted position, taken from the inverse-transformed model output.
    latitude: float
    longitude: float
# Response body of GET /predict-migration.
class PredictionResponse(BaseModel):
    # Species ID exactly as supplied in the request.
    species: str
    # Number of future months the caller asked for.
    months_requested: int
    # Length of the input window, read from the species metadata.
    sequence_length_used: int
    # One entry per forecast month, in chronological order.
    points: List[PredictionPoint]
# ==========================
# CORE PREDICTION LOGIC
# ==========================
def predict_future_months(species_id: str, n_months: int):
    """
    Roll the species model forward ``n_months`` steps, one month per step.

    Starting point, all read from the species metadata:
    - ``last_year`` / ``last_month``: the calendar position before the first forecast
    - ``last_sequence``: the scaled input window
    - ``sequence_length``: the window length

    Returns:
        ``(points, sequence_length)`` where ``points`` is a list of dicts with
        keys ``year``, ``month``, ``latitude`` and ``longitude``.
    """
    model, scaler, meta, last_seq_scaled = load_artifacts(species_id)
    seq_len = int(meta["sequence_length"])
    year, month = int(meta["last_year"]), int(meta["last_month"])

    # Work on a copy so the cached window is never modified between requests.
    window = last_seq_scaled.copy()
    forecasts = []

    for _ in range(n_months):
        # Next step in scaled space, then mapped back through the scaler.
        step_scaled = model.predict(window, verbose=0)
        step = scaler.inverse_transform(step_scaled)[0]

        # Move the calendar forward one month, rolling December into January.
        year, month = (year + 1, 1) if month >= 12 else (year, month + 1)

        forecasts.append(
            {
                "year": int(year),
                "month": int(month),
                "latitude": float(step[0]),
                "longitude": float(step[1]),
            }
        )

        # Slide the window: drop the oldest row and append the new scaled prediction.
        shifted = np.vstack([window[0][1:], step_scaled[0]])
        window = shifted.reshape(1, seq_len, 2)

    return forecasts, seq_len
# ==========================
# ENDPOINTS
# ==========================
@app.get("/predict-migration", response_model=PredictionResponse)
def predict_migration(
    species: str = Query("mackerel", description="Species ID (e.g., mackerel, sardinella)"),
    months: int = Query(6, ge=1, le=24, description="Number of future months to predict"),
):
    """
    Forecast the monthly positions of a species for the coming months.

    Example:
    GET /predict-migration?species=mackerel&months=12

    Errors:
    - 400 if `species` is not one of the supported species IDs.
    - 500 if the forecast cannot be produced on the server side.
    """
    # An unknown species is the only client error here (``months`` is already
    # range-checked by Query), so it is rejected before any work is done.
    if species not in SPECIES_LIST:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown species '{species}'. Allowed: {SPECIES_LIST}",
        )

    try:
        points, seq_len_used = predict_future_months(species, months)
    except Exception as e:
        # Past the species check, any failure (missing artifacts, malformed
        # metadata, model errors) is a server-side problem, not a bad request.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictionResponse(
        species=species,
        months_requested=months,
        sequence_length_used=seq_len_used,
        points=[PredictionPoint(**p) for p in points],
    )
# Landing endpoint: reports that the API is up, lists the supported species,
# and shows a sample prediction URL.
@app.get("/")
def root():
    payload = {
        "message": "Multi-Species Fish Migration LSTM API is running",
        "available_species": SPECIES_LIST,
        "example": "/predict-migration?species=mackerel&months=12",
    }
    return payload