# index.py
"""Multi-species fish migration forecasting API.

Serves autoregressive LSTM forecasts of (latitude, longitude) positions.
Each species has its own artifacts under ``models/<species_id>/``:

* ``<species_id>_model.h5``      - Keras model (loaded with ``compile=False``)
* ``<species_id>_scaler.pkl``    - pickled scaler exposing ``inverse_transform``
* ``<species_id>_metadata.json`` - must contain ``sequence_length``,
  ``last_sequence`` (already scaled), ``last_year`` and ``last_month``
"""

import os
import json
import pickle
from typing import List

import numpy as np
from fastapi import FastAPI, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tensorflow.keras.models import load_model

# ==========================
# CONFIG
# ==========================
MODELS_BASE_DIR = "models"

# These must match folder names under models/
SPECIES_LIST = [
    "mackerel",
    "sardinella",
    "scomber",
    "skipjack",
    "tuna",
]

# Features per time step, in this order: (latitude, longitude).
N_FEATURES = 2

# Cache: species_id -> (model, scaler, meta, last_seq_scaled)
ARTIFACT_CACHE = {}

_REQUIRED_META_KEYS = ("sequence_length", "last_sequence", "last_year", "last_month")


def _validate_metadata(species_id: str, meta: dict) -> np.ndarray:
    """Check required metadata fields and return the scaled seed window.

    Returns:
        ``last_sequence`` reshaped to ``(1, sequence_length, N_FEATURES)``.

    Raises:
        ValueError: A required key is missing, ``last_sequence`` does not hold
            ``sequence_length * N_FEATURES`` values, or ``last_month`` is not 1-12.
    """
    missing = [key for key in _REQUIRED_META_KEYS if key not in meta]
    if missing:
        raise ValueError(f"Metadata for species '{species_id}' is missing keys: {missing}")

    seq_len = int(meta["sequence_length"])
    last_seq = np.asarray(meta["last_sequence"], dtype=float)
    if seq_len <= 0 or last_seq.size != seq_len * N_FEATURES:
        raise ValueError(
            f"Metadata for species '{species_id}': last_sequence has {last_seq.size} "
            f"values, expected sequence_length ({seq_len}) x {N_FEATURES}"
        )

    # The calendar rollover in predict_future_months assumes 1-based months.
    last_month = int(meta["last_month"])
    if not 1 <= last_month <= 12:
        raise ValueError(
            f"Metadata for species '{species_id}': last_month must be 1-12, got {last_month}"
        )

    return last_seq.reshape(1, seq_len, N_FEATURES)


def load_artifacts(species_id: str):
    """Load model, scaler, metadata, and last sequence for a given species.

    Results are stored in ``ARTIFACT_CACHE`` so subsequent calls are fast.

    Args:
        species_id: One of ``SPECIES_LIST``.

    Returns:
        Tuple ``(model, scaler, meta, last_seq_scaled)`` where
        ``last_seq_scaled`` has shape ``(1, sequence_length, N_FEATURES)``.

    Raises:
        ValueError: Unknown species, or invalid metadata (see ``_validate_metadata``).
        FileNotFoundError: One or more artifact files are missing.
    """
    if species_id in ARTIFACT_CACHE:
        return ARTIFACT_CACHE[species_id]

    # Whitelist check also prevents path traversal via species_id.
    if species_id not in SPECIES_LIST:
        raise ValueError(f"Unknown species '{species_id}'. Allowed: {SPECIES_LIST}")

    base_dir = os.path.join(MODELS_BASE_DIR, species_id)
    model_path = os.path.join(base_dir, f"{species_id}_model.h5")
    scaler_path = os.path.join(base_dir, f"{species_id}_scaler.pkl")
    meta_path = os.path.join(base_dir, f"{species_id}_metadata.json")

    missing_files = [
        os.path.basename(path)
        for path in (model_path, scaler_path, meta_path)
        if not os.path.exists(path)
    ]
    if missing_files:
        raise FileNotFoundError(
            f"Artifacts not found for species '{species_id}' in {base_dir}: "
            f"missing {missing_files}"
        )

    # Load and validate metadata first so a broken deployment fails fast,
    # before the comparatively expensive model load.
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    last_seq_scaled = _validate_metadata(species_id, meta)

    model = load_model(model_path, compile=False)

    # SECURITY: pickle.load can execute arbitrary code. Only deploy scaler
    # files from a trusted source; never load user-supplied pickles here.
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)

    ARTIFACT_CACHE[species_id] = (model, scaler, meta, last_seq_scaled)
    return ARTIFACT_CACHE[species_id]


# ==========================
# FASTAPI SETUP
# ==========================
app = FastAPI(title="Multi-Species Fish Migration LSTM API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict in production
    allow_methods=["*"],
    allow_headers=["*"],
)


class PredictionPoint(BaseModel):
    year: int
    month: int
    latitude: float
    longitude: float


class PredictionResponse(BaseModel):
    species: str
    months_requested: int
    sequence_length_used: int
    points: List[PredictionPoint]


# ==========================
# CORE PREDICTION LOGIC
# ==========================
def _next_month(year: int, month: int):
    """Return ``(year, month)`` one calendar month later; ``month`` is 1-12."""
    return year + month // 12, month % 12 + 1


def predict_future_months(species_id: str, n_months: int):
    """Autoregressively predict ``n_months`` positions for a species.

    Seeds the model with the scaled ``last_sequence`` from metadata. Each step
    predicts one month ahead, then slides the window (drop oldest, append the
    new scaled prediction). Calendar dates start after ``last_year``/``last_month``.

    Returns:
        ``(results, seq_len)`` where ``results`` is a list of dicts with keys
        ``year``, ``month``, ``latitude``, ``longitude`` (unscaled) and
        ``seq_len`` is the metadata ``sequence_length``.

    Raises:
        ValueError, FileNotFoundError: Propagated from ``load_artifacts``.
    """
    model, scaler, meta, last_seq_scaled = load_artifacts(species_id)

    seq_len = int(meta["sequence_length"])
    year = int(meta["last_year"])
    month = int(meta["last_month"])

    # float32 matches typical Keras weights; copy so the cached seed is untouched.
    window = last_seq_scaled.astype(np.float32)
    results = []

    for _ in range(n_months):
        # 1. Predict next step (scaled). Calling the model directly avoids
        #    model.predict()'s per-call pipeline overhead for a 1-sample batch.
        pred_scaled = np.asarray(model(window, training=False))  # expected (1, N_FEATURES)

        # 2. Convert back to real lat/lon.
        pred = scaler.inverse_transform(pred_scaled)[0]

        # 3. Advance calendar by one month.
        year, month = _next_month(year, month)

        results.append(
            {
                "year": int(year),
                "month": int(month),
                "latitude": float(pred[0]),
                "longitude": float(pred[1]),
            }
        )

        # 4. Slide window: drop oldest step, append the new scaled prediction.
        window = np.concatenate(
            [window[:, 1:, :], pred_scaled.reshape(1, 1, N_FEATURES).astype(np.float32)],
            axis=1,
        )

    return results, seq_len


# ==========================
# ENDPOINTS
# ==========================
@app.get("/predict-migration", response_model=PredictionResponse)
def predict_migration(
    species: str = Query("mackerel", description="Species ID (e.g., mackerel, sardinella)"),
    months: int = Query(6, ge=1, le=24, description="Number of future months to predict"),
):
    """
    Example:
        GET /predict-migration?species=mackerel&months=12

    Returns 400 for an unknown species (client error) and 500 when the
    species' server-side artifacts are missing or malformed.
    """
    if species not in SPECIES_LIST:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown species '{species}'. Allowed: {SPECIES_LIST}",
        )

    try:
        points, seq_len_used = predict_future_months(species, months)
    except (FileNotFoundError, KeyError, ValueError) as e:
        # Species was validated above, so these indicate a broken deployment,
        # not a bad request. Other unexpected errors propagate as generic 500s.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictionResponse(
        species=species,
        months_requested=months,
        sequence_length_used=seq_len_used,
        points=[PredictionPoint(**p) for p in points],
    )


@app.get("/")
def root():
    """Health/info endpoint listing available species and an example query."""
    return {
        "message": "Multi-Species Fish Migration LSTM API is running",
        "available_species": SPECIES_LIST,
        "example": "/predict-migration?species=mackerel&months=12",
    }