Spaces:
Sleeping
Sleeping
| # index.py | |
| import os | |
| import json | |
| import pickle | |
| import numpy as np | |
| from typing import List | |
| from fastapi import FastAPI, Query, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from tensorflow.keras.models import load_model | |
# ==========================
# CONFIG
# ==========================
# Root directory containing one artifact sub-folder per species. A relative
# path resolves against the process working directory; set the
# MODELS_BASE_DIR environment variable to point somewhere else.
MODELS_BASE_DIR = os.environ.get("MODELS_BASE_DIR", "models")

# Species IDs accepted by the API. These must match folder names under
# MODELS_BASE_DIR (load_artifacts joins MODELS_BASE_DIR/<species_id>).
SPECIES_LIST = [
    "mackerel",
    "sardinella",
    "scomber",
    "skipjack",
    "tuna",
]

# Cache: species_id -> (model, scaler, meta, last_seq_scaled).
# Filled lazily by load_artifacts so each species is read from disk once.
ARTIFACT_CACHE = {}
def load_artifacts(species_id: str):
    """
    Load the model, scaler, metadata, and last input window for a species.

    Results are memoised in ARTIFACT_CACHE, so only the first successful call
    per species touches the disk. Failed loads are not cached.

    Expected files under MODELS_BASE_DIR/<species_id>/:
        <species_id>_model.h5        Keras model (loaded with compile=False)
        <species_id>_scaler.pkl      pickled scaler object
        <species_id>_metadata.json   JSON containing at least "sequence_length"
                                     and "last_sequence"

    Returns:
        Tuple (model, scaler, meta, last_seq_scaled). last_seq_scaled is
        meta["last_sequence"] reshaped to (1, sequence_length, 2).

    Raises:
        ValueError: species_id is not in SPECIES_LIST, or "last_sequence"
            does not hold exactly sequence_length * 2 values.
        FileNotFoundError: any of the three artifact files is missing.
        KeyError: the metadata lacks "sequence_length" or "last_sequence".
    """
    if species_id in ARTIFACT_CACHE:
        return ARTIFACT_CACHE[species_id]

    # The whitelist check happens before any path is built. That keeps a
    # caller-supplied ID from steering the joins below to arbitrary locations.
    if species_id not in SPECIES_LIST:
        raise ValueError(f"Unknown species '{species_id}'. Allowed: {SPECIES_LIST}")

    base_dir = os.path.join(MODELS_BASE_DIR, species_id)
    model_path = os.path.join(base_dir, f"{species_id}_model.h5")
    scaler_path = os.path.join(base_dir, f"{species_id}_scaler.pkl")
    meta_path = os.path.join(base_dir, f"{species_id}_metadata.json")

    if not all(os.path.exists(p) for p in (model_path, scaler_path, meta_path)):
        raise FileNotFoundError(f"Artifacts not found for species '{species_id}' in {base_dir}")

    # compile=False skips restoring training configuration. Only inference is done here.
    model = load_model(model_path, compile=False)

    # NOTE(review): pickle.load can execute arbitrary code. This is acceptable
    # only because the scaler is a trusted, server-side artifact.
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)

    # JSON is UTF-8 by spec, so don't depend on the platform default encoding.
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    seq_len = int(meta["sequence_length"])
    last_seq = np.array(meta["last_sequence"])

    # The window holds 2 features per timestep (presumably the scaled lat/lon
    # pair the model predicts). Checking the size first means a mismatched
    # metadata file fails with a descriptive message instead of numpy's
    # generic reshape error.
    expected_size = seq_len * 2
    if last_seq.size != expected_size:
        raise ValueError(
            f"Metadata for species '{species_id}' has {last_seq.size} values in "
            f"'last_sequence'; expected sequence_length * 2 = {expected_size}"
        )
    last_seq_scaled = last_seq.reshape(1, seq_len, 2)

    ARTIFACT_CACHE[species_id] = (model, scaler, meta, last_seq_scaled)
    return ARTIFACT_CACHE[species_id]
# ==========================
# FASTAPI SETUP
# ==========================
app = FastAPI(title="Multi-Species Fish Migration LSTM API")

# Wide-open CORS: every origin, HTTP method, and request header is accepted,
# so any browser front-end can call the API.
# NOTE: restrict these lists in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class PredictionPoint(BaseModel):
    """One forecast step: a calendar month and the predicted position.

    Field declaration order is also the order of keys in the serialized JSON.
    """

    year: int
    month: int  # 1-12 as produced by predict_future_months' calendar roll-over
    latitude: float  # column 0 of the scaler's inverse-transformed prediction
    longitude: float  # column 1 of the scaler's inverse-transformed prediction
class PredictionResponse(BaseModel):
    """Response body for the migration-forecast endpoint.

    Field declaration order is also the order of keys in the serialized JSON.
    """

    species: str  # species ID echoed back from the request
    months_requested: int  # number of future months asked for
    sequence_length_used: int  # LSTM input window length from the species metadata
    points: List[PredictionPoint]  # one entry per predicted month, in chronological order
# ==========================
# CORE PREDICTION LOGIC
# ==========================
def predict_future_months(species_id: str, n_months: int):
    """
    Roll the species' LSTM forward autoregressively for ``n_months`` steps.

    The stored scaled window from the metadata is the starting input. Each
    step does three things:
      * feed the current window to the model,
      * un-scale the prediction into a latitude/longitude pair,
      * slide the window forward by appending the still-scaled prediction.

    Args:
        species_id: species ID, forwarded to load_artifacts.
        n_months: number of monthly steps to generate.

    Returns:
        (points, seq_len). points is a list of dicts with keys "year",
        "month", "latitude", and "longitude", one per predicted month. The
        first entry is the month after meta["last_year"]/meta["last_month"].
        seq_len is int(meta["sequence_length"]).
    """
    model, scaler, meta, last_seq_scaled = load_artifacts(species_id)
    window_len = int(meta["sequence_length"])
    cur_year, cur_month = int(meta["last_year"]), int(meta["last_month"])

    # Work on a copy so the cached window in ARTIFACT_CACHE is never touched.
    window = last_seq_scaled.copy()
    points = []

    for _ in range(n_months):
        step_scaled = model.predict(window, verbose=0)  # presumably (1, 2)
        lat_lon = scaler.inverse_transform(step_scaled)[0]

        # Advance the calendar one month, wrapping December into January.
        if cur_month + 1 > 12:
            cur_year, cur_month = cur_year + 1, 1
        else:
            cur_month += 1

        points.append(
            dict(
                year=int(cur_year),
                month=int(cur_month),
                latitude=float(lat_lon[0]),
                longitude=float(lat_lon[1]),
            )
        )

        # Drop the oldest timestep and append the new prediction. The
        # appended value stays in scaled space because that is what the model consumes.
        window = np.vstack((window[0, 1:], step_scaled[0])).reshape(1, window_len, 2)

    return points, window_len
# ==========================
# ENDPOINTS
# ==========================
@app.get("/predict-migration", response_model=PredictionResponse)
def predict_migration(
    species: str = Query("mackerel", description="Species ID (e.g., mackerel, sardinella)"),
    months: int = Query(6, ge=1, le=24, description="Number of future months to predict"),
):
    """
    Forecast a species' monthly positions for the requested horizon.

    Example:
        GET /predict-migration?species=mackerel&months=12

    Error mapping:
        400 - ValueError while preparing the forecast (e.g. an unknown species).
        404 - the species is allowed but its model artifacts are missing.
        500 - anything else. These are server-side faults and are left to
              FastAPI's default handler instead of being reported as client errors.
    """
    try:
        points, seq_len_used = predict_future_months(species, months)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    return PredictionResponse(
        species=species,
        months_requested=months,
        sequence_length_used=seq_len_used,
        points=[PredictionPoint(**p) for p in points],
    )
@app.get("/")
def root():
    """
    Landing endpoint: report that the service is running.

    Returns a JSON object with a status message, the accepted species IDs
    (SPECIES_LIST), and an example prediction URL.
    """
    return {
        "message": "Multi-Species Fish Migration LSTM API is running",
        "available_species": SPECIES_LIST,
        "example": "/predict-migration?species=mackerel&months=12",
    }