# index.py
# Source listing captured from a hosted Space page
# (status: Sleeping; file size: 5,066 bytes; revision 45e0498).
import os
import json
import pickle
import numpy as np
from typing import List
from fastapi import FastAPI, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tensorflow.keras.models import load_model
# ==========================
# CONFIG
# ==========================

# Root directory with one sub-folder of trained artifacts per species.
# Being relative, it resolves against the process's working directory.
MODELS_BASE_DIR: str = "models"

# Allow-listed species IDs. Every entry must equal a folder name under
# MODELS_BASE_DIR; load_artifacts rejects IDs that are not listed here.
SPECIES_LIST: List[str] = ["mackerel", "sardinella", "scomber", "skipjack", "tuna"]

# Lazily filled in-process memo written by load_artifacts:
#   species_id -> (model, scaler, meta, last_seq_scaled)
ARTIFACT_CACHE: dict = {}
def load_artifacts(species_id: str):
    """
    Load (and memoize) the trained artifacts for one species.

    Reads three files from ``MODELS_BASE_DIR/<species_id>/``:
      - ``<species_id>_model.h5``       Keras model, loaded with compile=False
      - ``<species_id>_scaler.pkl``     pickled scaler object
      - ``<species_id>_metadata.json``  must provide "sequence_length" and
                                        "last_sequence"

    The result is stored in ARTIFACT_CACHE, so only the first call for a
    given species touches the disk.

    Args:
        species_id: One of SPECIES_LIST.

    Returns:
        Tuple ``(model, scaler, meta, last_seq_scaled)`` where ``meta`` is the
        parsed metadata JSON and ``last_seq_scaled`` is a NumPy array of shape
        ``(1, sequence_length, 2)``.

    Raises:
        ValueError: ``species_id`` is not in SPECIES_LIST, or "last_sequence"
            cannot be reshaped to ``(1, sequence_length, 2)``.
        FileNotFoundError: one or more artifact files do not exist; the
            message names the missing files.
        KeyError: the metadata lacks "sequence_length" or "last_sequence".
    """
    if species_id in ARTIFACT_CACHE:
        return ARTIFACT_CACHE[species_id]

    # The allow-list check also guards the filesystem: species_id is joined
    # into a path below, so arbitrary values must never reach os.path.join.
    if species_id not in SPECIES_LIST:
        raise ValueError(f"Unknown species '{species_id}'. Allowed: {SPECIES_LIST}")

    base_dir = os.path.join(MODELS_BASE_DIR, species_id)
    model_path = os.path.join(base_dir, f"{species_id}_model.h5")
    scaler_path = os.path.join(base_dir, f"{species_id}_scaler.pkl")
    meta_path = os.path.join(base_dir, f"{species_id}_metadata.json")

    # Check every artifact up front so a half-deployed species fails fast and
    # the error says exactly which file(s) are absent.
    missing = [p for p in (model_path, scaler_path, meta_path) if not os.path.exists(p)]
    if missing:
        missing_names = ", ".join(os.path.basename(p) for p in missing)
        raise FileNotFoundError(
            f"Artifacts not found for species '{species_id}' in {base_dir}: "
            f"missing {missing_names}"
        )

    # Inference only: skip restoring optimizer/loss configuration.
    model = load_model(model_path, compile=False)

    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. That is acceptable here only because the path is built from an
    # allow-listed species ID under a server-controlled directory; never
    # point this at user-supplied files.
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)

    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # default encoding (which differs, e.g., on Windows).
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    seq_len = int(meta["sequence_length"])
    # reshape accepts any nesting holding seq_len * 2 values and yields the
    # (batch=1, steps, 2) layout the model is called with.
    last_seq_scaled = np.array(meta["last_sequence"]).reshape(1, seq_len, 2)

    ARTIFACT_CACHE[species_id] = (model, scaler, meta, last_seq_scaled)
    return ARTIFACT_CACHE[species_id]
# ==========================
# FASTAPI SETUP
# ==========================
app = FastAPI(title="Multi-Species Fish Migration LSTM API")

# Allowed CORS origins, read from the CORS_ALLOW_ORIGINS environment variable
# as a comma-separated list (e.g. "https://a.example,https://b.example").
# When the variable is unset the value is "*" (any origin), which keeps the
# API open for development; set it explicitly in production. Blank entries
# are ignored, so an empty value allows no cross-origin callers.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PredictionPoint(BaseModel):
    """A single forecast step: the calendar month and the predicted position."""

    # Calendar year of this forecast step.
    year: int
    # Calendar month of this step; the forecast loop keeps it within 1..12.
    month: int
    # Predicted latitude, inverse-transformed out of the model's scaled space.
    latitude: float
    # Predicted longitude, inverse-transformed out of the model's scaled space.
    longitude: float
class PredictionResponse(BaseModel):
    """Response body returned by ``GET /predict-migration``."""

    # Species ID that was forecast, echoed back from the request.
    species: str
    # How many future months the client asked for.
    months_requested: int
    # Length of the input window the model was conditioned on.
    sequence_length_used: int
    # Forecast steps in chronological order, one per requested month.
    points: List[PredictionPoint]
# ==========================
# CORE PREDICTION LOGIC
# ==========================
def predict_future_months(species_id: str, n_months: int):
    """
    Autoregressively forecast ``n_months`` monthly positions for a species.

    Starting from the scaled window stored in the species metadata, the model
    predicts one step at a time. Each prediction is inverse-scaled to real
    coordinates for the output. Its *scaled* form is appended to the window
    (dropping the oldest step) to drive the next prediction.

    Reads from the species metadata:
      - ``last_year`` / ``last_month``: calendar month of the final known step
      - ``sequence_length``: window length the model expects
      - ``last_sequence`` (via load_artifacts): the scaled starting window

    Args:
        species_id: Species to forecast; validated by load_artifacts.
        n_months: Number of future months to produce. Values <= 0 yield an
            empty list.

    Returns:
        ``(results, seq_len)``. ``results`` is a list of dicts with keys
        ``year``, ``month``, ``latitude`` and ``longitude``, one per month in
        chronological order. ``seq_len`` is the window length used.

    Raises:
        Whatever load_artifacts raises (ValueError, FileNotFoundError,
        KeyError), plus KeyError if the metadata lacks ``last_year`` or
        ``last_month``.
    """
    model, scaler, meta, last_seq_scaled = load_artifacts(species_id)
    seq_len = int(meta["sequence_length"])
    year = int(meta["last_year"])
    month = int(meta["last_month"])

    # Private copy so the window cached in ARTIFACT_CACHE is never touched.
    seq = last_seq_scaled.copy()
    results = []

    for _ in range(n_months):
        # 1. Predict the next step in scaled space. Calling the model directly
        #    in inference mode avoids model.predict()'s per-call overhead;
        #    Keras documents predict() as meant for batched inputs, not for
        #    use inside loops over small inputs.
        pred_scaled = np.asarray(model(seq, training=False))  # (1, n_features)

        # 2. Map the prediction back to real latitude/longitude.
        pred = scaler.inverse_transform(pred_scaled)[0]  # (n_features,)

        # 3. Advance the calendar by one month, rolling over into next year.
        month += 1
        if month > 12:
            month = 1
            year += 1

        results.append(
            {
                "year": int(year),
                "month": int(month),
                "latitude": float(pred[0]),
                "longitude": float(pred[1]),
            }
        )

        # 4. Slide the window: drop the oldest step and append the new
        #    *scaled* prediction, keeping shape (1, seq_len, n_features)
        #    without hard-coding the feature count.
        seq = np.concatenate([seq[:, 1:, :], pred_scaled[:, np.newaxis, :]], axis=1)

    return results, seq_len
# ==========================
# ENDPOINTS
# ==========================
@app.get("/predict-migration", response_model=PredictionResponse)
def predict_migration(
    species: str = Query("mackerel", description="Species ID (e.g., mackerel, sardinella)"),
    months: int = Query(6, ge=1, le=24, description="Number of future months to predict"),
):
    """
    Forecast a species' monthly positions for the requested number of months.

    Example:
        GET /predict-migration?species=mackerel&months=12
    """
    try:
        points, seq_len_used = predict_future_months(species, months)
    except ValueError as e:
        # load_artifacts raises ValueError for a species ID outside its
        # allow-list, which is a client error. Note that a malformed
        # "last_sequence" that fails to reshape also surfaces as ValueError.
        raise HTTPException(status_code=400, detail=str(e)) from e
    # Any other failure (missing artifacts, model/scaler errors, absent
    # metadata keys) is a server-side fault. It is deliberately not caught
    # here, so FastAPI answers 500 and logs the traceback instead of
    # reporting it as a bad request and echoing internal details to the client.

    return PredictionResponse(
        species=species,
        months_requested=months,
        sequence_length_used=seq_len_used,
        points=[PredictionPoint(**p) for p in points],
    )
@app.get("/")
def root():
    """Service banner: confirms the API is up and lists the servable species."""
    info = dict(
        message="Multi-Species Fish Migration LSTM API is running",
        available_species=SPECIES_LIST,
        example="/predict-migration?species=mackerel&months=12",
    )
    return info
# (end of index.py listing)