chinmay0805 committed on
Commit
45e0498
·
verified ·
1 Parent(s): 2b4d993

Upload 5 files

Browse files
Files changed (5) hide show
  1. DockerFile +25 -0
  2. main.py +182 -0
  3. metadata.json +20 -0
  4. requirements.txt +7 -0
  5. train.py +103 -0
DockerFile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Avoid buffering logs so they show up immediately in the container output
ENV PYTHONUNBUFFERED=1

# Workdir inside container
WORKDIR /app

# Install system deps (optional but safe).
# --no-install-recommends keeps apt from pulling in optional extras that only
# bloat the image; the apt lists are removed in the same layer for the same reason.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy all project files
COPY . .

# Expose port used by Hugging Face (must be 7860)
EXPOSE 7860

# Run FastAPI with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ import os
4
+ import json
5
+ import pickle
6
+ import numpy as np
7
+ from typing import List
8
+
9
+ from fastapi import FastAPI, Query, HTTPException
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel
12
+ from tensorflow.keras.models import load_model
13
+
14
+
15
+ # ==========================
16
+ # CONFIG
17
+ # ==========================
18
+
19
# Root directory of trained artifacts; load_artifacts joins it with the species id.
+ MODELS_BASE_DIR = "models"
20
+
21
+ # These must match folder names under models/
22
+ SPECIES_LIST = [
23
+ "mackerel",
24
+ "sardinella",
25
+ "scomber",
26
+ "skipjack",
27
+ "tuna",
28
+ ]
29
+
30
+ # Cache: species_id -> (model, scaler, meta, last_seq_scaled)
# Filled lazily by load_artifacts on first request per species; nothing in this file evicts entries.
31
+ ARTIFACT_CACHE = {}
32
+
33
+
34
def load_artifacts(species_id: str):
    """
    Return the (model, scaler, meta, last_seq_scaled) tuple for one species.

    The first call for a species reads its artifacts from
    ``<MODELS_BASE_DIR>/<species_id>/`` and stores the tuple in
    ARTIFACT_CACHE; later calls return the cached tuple directly.

    Raises:
        ValueError: species_id is not cached and not listed in SPECIES_LIST.
        FileNotFoundError: the model, scaler, or metadata file is missing.
    """
    cached = ARTIFACT_CACHE.get(species_id)
    if cached is not None:
        return cached

    if species_id not in SPECIES_LIST:
        raise ValueError(f"Unknown species '{species_id}'. Allowed: {SPECIES_LIST}")

    base_dir = os.path.join(MODELS_BASE_DIR, species_id)
    model_path, scaler_path, meta_path = (
        os.path.join(base_dir, f"{species_id}_{suffix}")
        for suffix in ("model.h5", "scaler.pkl", "metadata.json")
    )

    # All three artifacts must be present before anything is loaded.
    if not all(os.path.exists(p) for p in (model_path, scaler_path, meta_path)):
        raise FileNotFoundError(f"Artifacts not found for species '{species_id}' in {base_dir}")

    # Inference only, so the model is loaded without its training configuration.
    model = load_model(model_path, compile=False)

    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # while the models directory contains files produced by our own training run.
    with open(scaler_path, "rb") as scaler_file:
        scaler = pickle.load(scaler_file)

    with open(meta_path, "r") as meta_file:
        meta = json.load(meta_file)

    # Stored window is (sequence_length, 2); add the batch axis the model input needs.
    seq_len = int(meta["sequence_length"])
    last_seq_scaled = np.array(meta["last_sequence"]).reshape(1, seq_len, 2)

    artifacts = (model, scaler, meta, last_seq_scaled)
    ARTIFACT_CACHE[species_id] = artifacts
    return artifacts
70
+
71
+
72
+ # ==========================
73
+ # FASTAPI SETUP
74
+ # ==========================
75
+
76
# ASGI application object; the routes below are registered on it via @app.get.
+ app = FastAPI(title="Multi-Species Fish Migration LSTM API")
77
+
# CORS is fully open here: any origin, any method, any header.
78
+ app.add_middleware(
79
+ CORSMiddleware,
80
+ allow_origins=["*"], # restrict in production
81
+ allow_methods=["*"],
82
+ allow_headers=["*"],
83
+ )
84
+
85
+
86
# One forecast step: the calendar month it refers to and the predicted position.
# Built in predict_migration from the dicts produced by predict_future_months.
+ class PredictionPoint(BaseModel):
87
+ year: int
88
+ month: int
89
+ latitude: float
90
+ longitude: float
91
+
92
+
93
# Response body of GET /predict-migration: echoes the request (species, months),
# reports the input window length taken from the species metadata, and lists
# one PredictionPoint per forecast month in chronological order.
+ class PredictionResponse(BaseModel):
94
+ species: str
95
+ months_requested: int
96
+ sequence_length_used: int
97
+ points: List[PredictionPoint]
98
+
99
+
100
+ # ==========================
101
+ # CORE PREDICTION LOGIC
102
+ # ==========================
103
+
104
def predict_future_months(species_id: str, n_months: int):
    """
    Roll the species model forward n_months, one month per step.

    Starts from the scaled ``last_sequence`` window and the
    ``last_year``/``last_month`` stored in the species metadata. Each step
    predicts the next scaled position, converts it back with the scaler, and
    feeds the scaled prediction into the window for the following step.

    Returns:
        (results, seq_len): ``results`` is a list of dicts with keys
        ``year``, ``month``, ``latitude``, ``longitude`` in chronological
        order; ``seq_len`` is the window length read from the metadata.
    """
    model, scaler, meta, last_seq_scaled = load_artifacts(species_id)

    seq_len = int(meta["sequence_length"])
    year = int(meta["last_year"])
    month = int(meta["last_month"])

    # Work on a copy so the cached window is never mutated between requests.
    window = last_seq_scaled.copy()
    forecast = []

    for _step in range(n_months):
        # Next position in scaled space, then mapped back to real coordinates.
        scaled_step = model.predict(window, verbose=0)
        real_step = scaler.inverse_transform(scaled_step)[0]

        # Advance the calendar by one month, rolling over into the next year.
        year, month = (year + 1, 1) if month >= 12 else (year, month + 1)

        forecast.append(
            dict(
                year=int(year),
                month=int(month),
                latitude=float(real_step[0]),
                longitude=float(real_step[1]),
            )
        )

        # Slide the window: drop the oldest row, append the new scaled prediction.
        window = np.vstack([window[0][1:], scaled_step[0]]).reshape(1, seq_len, 2)

    return forecast, seq_len
148
+
149
+
150
+ # ==========================
151
+ # ENDPOINTS
152
+ # ==========================
153
+
154
@app.get("/predict-migration", response_model=PredictionResponse)
def predict_migration(
    species: str = Query("mackerel", description="Species ID (e.g., mackerel, sardinella)"),
    months: int = Query(6, ge=1, le=24, description="Number of future months to predict"),
):
    """
    Forecast the next ``months`` monthly positions for ``species``.

    Example:
    GET /predict-migration?species=mackerel&months=12

    Raises:
        HTTPException(400): ``species`` is not one of SPECIES_LIST.
        HTTPException(500): artifacts are missing or unreadable, or the
            prediction itself fails. These are server-side problems and were
            previously reported as 400, which blamed the caller.
    """
    # Reject unknown species up front so every failure past this point can be
    # attributed to the server rather than to the request.
    if species not in SPECIES_LIST:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown species '{species}'. Allowed: {SPECIES_LIST}",
        )

    try:
        points, seq_len_used = predict_future_months(species, months)
    except Exception as e:
        # Top-level boundary: surface the message, but with a 5xx status.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictionResponse(
        species=species,
        months_requested=months,
        sequence_length_used=seq_len_used,
        points=[PredictionPoint(**p) for p in points],
    )
174
+
175
+
176
@app.get("/")
def root():
    """Landing endpoint: a status message, the known species, and a sample query."""
    payload = dict(
        message="Multi-Species Fish Migration LSTM API is running",
        available_species=SPECIES_LIST,
        example="/predict-migration?species=mackerel&months=12",
    )
    return payload
metadata.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "species": "sardinella",
3
+ "last_year": 2012,
4
+ "last_month": 9,
5
+ "sequence_length": 3,
6
+ "last_sequence": [
7
+ [
8
+ 0.5544831090300122,
9
+ 0.3959002296999068
10
+ ],
11
+ [
12
+ 0.5473025885600646,
13
+ 0.39520740517616026
14
+ ],
15
+ [
16
+ 0.4158765115337282,
17
+ 0.3960623952468836
18
+ ]
19
+ ]
20
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+
4
+ tensorflow-cpu
5
+ numpy
6
+ pandas
7
+ scikit-learn
train.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+
3
+ import os
4
+ import json
5
+ import pickle
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from sklearn.preprocessing import MinMaxScaler
10
+ from tensorflow.keras.models import Sequential
11
+ from tensorflow.keras.layers import LSTM, Dense
12
+
13
+ # -------- CONFIG --------
14
+
15
# species id -> time-series CSV path handed to train_for_species by main().
+ SPECIES_FILES = {
16
+ "mackerel": "migration_timeseries_mackerel.csv",
17
+ "sardinella": "migration_timeseries_sardinella.csv",
18
+ "scomber": "migration_timeseries_scomber.csv",
19
+ "skipjack": "migration_timeseries_skipjack.csv",
20
+ "tuna": "migration_timeseries_tuna.csv",
21
+ }
22
+
23
+ # 🚨 This is ONLY a training hyperparameter (not exposed to frontend)
# Number of consecutive rows used as one LSTM input window; also written to each species' metadata file.
24
+ SEQUENCE_LENGTH = 3
25
+
26
+
27
def train_for_species(species_id: str, ts_csv: str):
    """
    Train an LSTM next-position model for one species and save its artifacts.

    Reads the time-series CSV, scales latitude/longitude with a MinMaxScaler,
    builds sliding windows of SEQUENCE_LENGTH rows, trains the model, and
    writes three files under ``models/<species_id>/``: the Keras model
    (``.h5``), the pickled scaler, and a metadata JSON holding the sequence
    length, the last year/month, and the last scaled window.

    Args:
        species_id: Species name; used for the output folder and file names.
        ts_csv: Path to a CSV with columns ``year``, ``month``,
            ``decimalLatitude``, ``decimalLongitude``.

    Returns:
        None. Problems with the input (missing file, missing columns, too few
        rows) are printed and the species is skipped instead of raising.
    """
    if not os.path.exists(ts_csv):
        print(f"[WARN] Timeseries CSV not found for {species_id}: {ts_csv}")
        return

    print(f"\n=== Training LSTM for {species_id} from {ts_csv} ===")

    df = pd.read_csv(ts_csv)

    # Validate the schema BEFORE sorting: sorting by a missing "year"/"month"
    # column raises KeyError and would bypass the friendly error below.
    required = {"year", "month", "decimalLatitude", "decimalLongitude"}
    missing = required - set(df.columns)
    if missing:
        print(f"[ERROR] Missing columns {missing} in {ts_csv}")
        return

    # Chronological order matters: windows and "last_*" metadata depend on it.
    df = df.sort_values(["year", "month"]).reset_index(drop=True)

    coords = df[["decimalLatitude", "decimalLongitude"]].values

    scaler = MinMaxScaler()
    coords_scaled = scaler.fit_transform(coords)

    # Each sample is SEQUENCE_LENGTH consecutive rows; its target is the next row.
    X = np.array(
        [
            coords_scaled[i - SEQUENCE_LENGTH:i]
            for i in range(SEQUENCE_LENGTH, len(coords_scaled))
        ]
    )
    y = coords_scaled[SEQUENCE_LENGTH:]

    if len(X) == 0:
        print(f"[ERROR] Not enough data to train for {species_id}")
        return

    model = Sequential()
    model.add(LSTM(64, activation="tanh", input_shape=(SEQUENCE_LENGTH, 2)))
    model.add(Dense(32, activation="relu"))
    model.add(Dense(2))
    model.compile(optimizer="adam", loss="mse")

    model.fit(X, y, epochs=50, batch_size=8, verbose=1)

    out_dir = os.path.join("models", species_id)
    os.makedirs(out_dir, exist_ok=True)

    # 🔹 Species-specific filenames
    model_path = os.path.join(out_dir, f"{species_id}_model.h5")
    scaler_path = os.path.join(out_dir, f"{species_id}_scaler.pkl")
    meta_path = os.path.join(out_dir, f"{species_id}_metadata.json")

    model.save(model_path)

    with open(scaler_path, "wb") as f:
        pickle.dump(scaler, f)

    # 👉 Store everything backend needs (no frontend involvement)
    metadata = {
        "species": species_id,
        "sequence_length": SEQUENCE_LENGTH,  # internal
        "last_year": int(df["year"].iloc[-1]),
        "last_month": int(df["month"].iloc[-1]),
        "last_sequence": coords_scaled[-SEQUENCE_LENGTH:].tolist(),  # internal
    }

    with open(meta_path, "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"[OK] Saved {model_path}, {scaler_path}, {meta_path}")
94
+
95
+
96
def main():
    """Train every species listed in SPECIES_FILES, one after another."""
    os.makedirs("models", exist_ok=True)
    for species_id in SPECIES_FILES:
        train_for_species(species_id, SPECIES_FILES[species_id])


if __name__ == "__main__":
    main()