NaseefNazrul commited on
Commit
5ee657c
Β·
verified Β·
1 Parent(s): 9fa0552

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +523 -131
app.py CHANGED
@@ -1,159 +1,551 @@
 
1
  import os
 
 
2
  import joblib
 
 
3
  import numpy as np
 
4
  from fastapi import FastAPI, HTTPException
5
- from pydantic import BaseModel
6
- import logging
7
- import sys
 
 
 
 
8
 
9
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Bloom Prediction ML API")

# ML Model artifacts (upload these to your Hugging Face Space)
MODEL_PATH = "mil_bloom_model.joblib"
SCALER_PATH = "mil_scaler.joblib"
FEATURES_PATH = "mil_features.joblib"

# Global variables for ML model
# All three are populated by load_ml_model() at startup; None until then.
ML_MODEL = None          # trained classifier (joblib artifact)
SCALER = None            # feature scaler paired with the classifier
FEATURE_COLUMNS = None   # ordered feature names the model expects
 
 
24
 
25
class PredictionRequest(BaseModel):
    """Request body for /predict: raw feature values plus optional options."""
    # Expected keys: ndvi, ndwi, evi, lst, cloud_cover, month, day_of_year
    features: dict
    # Currently unused extras; reserved for future tuning options.
    parameters: dict = {}

class PredictionResponse(BaseModel):
    """Response body for /predict."""
    success: bool
    # Percentage in [0, 100] (predict_bloom scales its 0-1 probability).
    bloom_probability: float
    # "BLOOM" or "NO_BLOOM"
    prediction: str
    # "HIGH" / "MEDIUM" / "LOW"
    confidence: str
    message: str = ""
35
-
36
def load_ml_model():
    """Load the ML model and artifacts"""
    # Populates module globals from the joblib files; any failure is logged
    # and re-raised so startup fails loudly instead of serving a broken API.
    global ML_MODEL, SCALER, FEATURE_COLUMNS

    try:
        ML_MODEL = joblib.load(MODEL_PATH)
        SCALER = joblib.load(SCALER_PATH)
        FEATURE_COLUMNS = joblib.load(FEATURES_PATH)
        logger.info("βœ… ML model loaded successfully in Hugging Face Space")
        logger.info(f"βœ… Features: {FEATURE_COLUMNS}")
    except Exception as e:
        logger.error(f"❌ Failed to load ML model: {e}")
        raise
49
-
50
def predict_bloom(features_dict: dict):
    """
    ML prediction logic - same as your original but now runs on Hugging Face
    """
    # Raises ValueError when the model was not loaded, KeyError/ValueError on
    # missing or non-numeric features; callers map these to HTTP errors.
    if ML_MODEL is None:
        raise ValueError("ML model not loaded")

    # Extract features in correct order
    try:
        # Column order must match what the scaler/model were trained on.
        features_array = np.array([[
            float(features_dict['ndvi']),
            float(features_dict['ndwi']),
            float(features_dict['evi']),
            float(features_dict['lst']),
            float(features_dict['cloud_cover']),
            float(features_dict['month']),
            float(features_dict['day_of_year'])
        ]])

        # Scale features
        features_scaled = SCALER.transform(features_array)

        # Get prediction
        probabilities = ML_MODEL.predict_proba(features_scaled)

        # Binary classifiers expose P(bloom) in column 1; a single-column
        # output falls back to column 0.
        if probabilities.shape[1] == 2:
            bloom_probability = probabilities[0, 1]
        else:
            bloom_probability = probabilities[0, 0]

        prediction = ML_MODEL.predict(features_scaled)[0]

        # Apply your business logic (winter adjustments, etc.)
        ndvi = features_dict['ndvi']
        evi = features_dict['evi']
        month = features_dict['month']

        # Winter adjustment: halve probability for low-vegetation winter scenes.
        if month in [11, 12, 1, 2] and evi < 0.8 and ndvi < 0.3:
            bloom_probability = bloom_probability * 0.5
            logger.info("❄️ Applied winter adjustment")

        # Confidence: extreme probabilities (near 0 or 1) are more trustworthy.
        # Thresholds here are on the 0-1 scale; the return value is percent.
        if bloom_probability > 0.75 or bloom_probability < 0.25:
            confidence = 'HIGH'
        elif bloom_probability > 0.6 or bloom_probability < 0.4:
            confidence = 'MEDIUM'
        else:
            confidence = 'LOW'

        return {
            'bloom_probability': round(float(bloom_probability * 100), 2),
            'prediction': 'BLOOM' if prediction == 1 else 'NO_BLOOM',
            'confidence': confidence,
        }
    except Exception as e:
        logger.error(f"❌ Prediction error: {e}")
        raise
109
 
110
@app.on_event("startup")
async def startup_event():
    """Load ML model when the app starts"""
    # NOTE(review): on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers; kept as-is here.
    load_ml_model()
 
 
 
 
 
 
114
 
115
@app.get("/")
async def root():
    """Service banner reporting API status and whether the model is loaded."""
    return {
        "message": "Bloom Prediction ML API",
        "status": "active",
        "model_loaded": ML_MODEL is not None
    }
 
122
 
123
@app.get("/health")
async def health():
    """Liveness probe; also reports whether the ML model is loaded."""
    return {
        "status": "healthy",
        "model_loaded": ML_MODEL is not None
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
@app.post("/predict")
async def predict(request: PredictionRequest):
    """
    Main prediction endpoint called by the backend
    """
    # Any failure (missing feature keys, model errors) surfaces as HTTP 500
    # with the underlying message in the detail field.
    try:
        logger.info(f"πŸ“Š Received prediction request with features: {request.features}")

        # Perform ML prediction
        prediction_result = predict_bloom(request.features)

        response = PredictionResponse(
            success=True,
            bloom_probability=prediction_result['bloom_probability'],
            prediction=prediction_result['prediction'],
            confidence=prediction_result['confidence'],
            message="Prediction completed successfully"
        )

        logger.info(f"βœ… Prediction completed: {prediction_result['bloom_probability']}%")
        return response
    except Exception as e:
        logger.error(f"❌ Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
# For Hugging Face Spaces deployment
if __name__ == "__main__":
    import uvicorn
    # HF Spaces routes traffic to port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ # app.py
2
  import os
3
+ import time
4
+ import math
5
  import joblib
6
+ import ee
7
+ import pandas as pd
8
  import numpy as np
9
+ from datetime import datetime, date, timedelta
10
  from fastapi import FastAPI, HTTPException
11
+ from pydantic import BaseModel, Field
12
+ from contextlib import asynccontextmanager
13
+ from pathlib import Path
14
+ from typing import Optional, List, Dict, Tuple
15
+ # google oauth helpers
16
+ from google.oauth2.credentials import Credentials
17
+ from google.auth.transport.requests import Request as GoogleRequest
18
 
19
# ------------------------------
# CONFIG / FILENAMES
# ------------------------------
# Model artifacts produced offline and uploaded alongside the Space.
MODEL_FILE = Path("mil_bloom_model.joblib")
SCALER_FILE = Path("mil_scaler.joblib")
FEATURES_FILE = Path("mil_features.joblib")
# Raw phenology observations; used to build species stats only when the
# precomputed species_stats.csv is absent.
PHENO_FILE = Path("phenologythingy.csv")
SPECIES_STATS_FILE = Path("species_stats.csv")

ELEV_IMAGE_ID = "USGS/SRTMGL1_003"  # SRTM 30 m DEM asset in Earth Engine
# Tunables, overridable via environment variables.
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))   # sampling radius around the point
MAX_DAYS_BACK = int(os.environ.get("MAX_DAYS_BACK", 30))    # satellite-date fallback window (days)
MIN_COUNT_FOR_SPECIES = int(os.environ.get("MIN_COUNT_FOR_SPECIES", 20))  # rarer species pool into "OTHER"
TOP_K_SPECIES = int(os.environ.get("TOP_K_SPECIES", 5))
DOY_BINS = 366    # day-of-year histogram resolution
DOY_SMOOTH = 15   # moving-average window (days) for the histograms
EPS_STD = 1.0     # floor for per-species elevation std-dev

# EE OAuth env vars expected to be set in HF Space secrets
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
REFRESH_TOKEN = os.environ.get("REFRESH_TOKEN")
EE_PROJECT = os.environ.get("PROJECT") or os.environ.get("EE_PROJECT") or None

EE_SCOPES = [
    "https://www.googleapis.com/auth/earthengine",
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/drive",
    "https://www.googleapis.com/auth/devstorage.full_control",
]
49
+
50
# ------------------------------
# Pydantic models
# ------------------------------
class BloomPredictionRequest(BaseModel):
    """Request body for /predict: a coordinate and a target date."""
    lat: float = Field(..., ge=-90, le=90)
    lon: float = Field(..., ge=-180, le=180)
    date: str = Field(..., description="YYYY-MM-DD")

class MonthlyResult(BaseModel):
    """Per-month analysis result; fields stay None when data is unavailable."""
    month: int
    # The 15th of the month, YYYY-MM-DD.
    sample_date: str
    # Percent, 0-100.
    ml_bloom_probability: Optional[float] = None
    # "BLOOM" / "NO_BLOOM".
    ml_prediction: Optional[str] = None
    # "HIGH" / "MEDIUM" / "LOW".
    ml_confidence: Optional[str] = None
    # Top-k (species, percent) pairs.
    species_top: Optional[List[Tuple[str, float]]] = None
    # Full posterior over known species, percent per species.
    species_probs: Optional[Dict[str, float]] = None
    elevation_m: Optional[float] = None
    # satellite / cloud_cover / days_offset / buffer_radius_m.
    data_quality: Optional[dict] = None
    satellite: Optional[str] = None
    # Human-readable warnings (missing data, species errors).
    note: Optional[str] = None

class BloomPredictionResponse(BaseModel):
    """Top-level /predict response: 12 MonthlyResult entries plus timing."""
    success: bool
    # UTC timestamp when the analysis ran.
    analysis_date: str
    requested_date: str
    monthly_results: List[MonthlyResult]
    # Seconds, rounded to 2 dp.
    processing_time: float
77
+
78
# ------------------------------
# Globals
# ------------------------------
# Populated during app lifespan startup; None/empty until then.
ML_MODEL = None          # trained classifier (joblib artifact)
SCALER = None            # feature scaler paired with the classifier
FEATURE_COLUMNS = None   # ordered feature names the model expects
SPECIES_STATS_DF = None  # per-species elevation stats + priors
DOY_HIST_MAP: Dict[str, np.ndarray] = {}  # species -> day-of-year histogram
86
 
87
+ # ------------------------------
88
+ # Helpers
89
+ # ------------------------------
90
def gaussian_pdf(x, mean, std):
    """Evaluate the normal(mean, std) density at x (std floored at 1e-6)."""
    sigma = std if std > 1e-6 else 1e-6
    z = (x - mean) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2 * math.pi))
95
 
96
def circular_histogram(doys, bins=DOY_BINS, smooth_window=DOY_SMOOTH):
    """Return a smoothed, normalized day-of-year histogram of length `bins`.

    doys: integer day-of-year observations in [1, bins].
    Smoothing is a circular moving average so late-December and early-January
    observations reinforce each other across the year boundary.
    Returns a uniform distribution when there is no usable data.
    """
    if len(doys) == 0:
        return np.ones(bins) / bins
    # bincount index 0 is unused (days are 1-based); drop it.
    counts = np.bincount(doys.astype(int), minlength=bins + 1)[1:]
    window = np.ones(smooth_window) / smooth_window
    # Tile three copies and keep the middle one so BOTH edges of the year wrap
    # correctly. (The previous two-copy tiling with mode='same' and a [:bins]
    # slice left the start of the year zero-padded instead of wrapped.)
    tiled = np.concatenate([counts, counts, counts])
    smoothed = np.convolve(tiled, window, mode='same')[bins:2 * bins]
    total = smoothed.sum()
    if total <= 0:
        return np.ones(bins) / bins
    return smoothed / total
107
+
108
# ------------------------------
# Earth Engine init (OAuth refresh-token or fallback)
# ------------------------------
def initialize_ee_from_env():
    """Initialize the Earth Engine client.

    Prefers OAuth refresh-token credentials from env vars (CLIENT_ID,
    CLIENT_SECRET, REFRESH_TOKEN); otherwise falls back to ee's default
    credential discovery. Returns True on success, False on any failure.
    """
    try:
        if CLIENT_ID and CLIENT_SECRET and REFRESH_TOKEN:
            creds = Credentials(
                token=None,  # no access token yet; obtained by refresh() below
                refresh_token=REFRESH_TOKEN,
                client_id=CLIENT_ID,
                client_secret=CLIENT_SECRET,
                token_uri="https://oauth2.googleapis.com/token",
                scopes=EE_SCOPES
            )
            request = GoogleRequest()
            creds.refresh(request)  # exchange refresh token for an access token
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("βœ… Earth Engine initialized with OAuth credentials")
            return True
        else:
            # Default credential chain (e.g. a cached `earthengine authenticate`).
            ee.Initialize(project=EE_PROJECT) if EE_PROJECT else ee.Initialize()
            print("βœ… Earth Engine initialized (default)")
            return True
    except Exception as e:
        print("❌ Earth Engine initialization failed:", e)
        return False
134
+
135
def get_elevation_from_ee(lat, lon):
    """Sample SRTM elevation (meters) at a point via Earth Engine.

    Returns the elevation as float, or None when it cannot be determined.
    Never raises: all errors are logged and mapped to None.
    """
    try:
        img = ee.Image(ELEV_IMAGE_ID)
        pt = ee.Geometry.Point([float(lon), float(lat)])
        # First pixel value at the DEM's native 30 m resolution.
        rr = img.reduceRegion(ee.Reducer.first(), pt, scale=30, maxPixels=1e6)
        if rr is None:
            return None
        try:
            val = rr.get("elevation").getInfo()
            return float(val) if val is not None else None
        except Exception:
            # Band name may differ; fall back to the first numeric value found.
            keys = rr.keys().getInfo()
            for k in keys:
                v = rr.get(k).getInfo()
                if isinstance(v, (int, float)):
                    return float(v)
            return None
    except Exception as e:
        print("❌ get_elevation_from_ee error:", e)
        return None
155
+
156
# ------------------------------
# Satellite retrieval (Landsat L2)
# ------------------------------
def get_single_date_satellite_data(lat, lon, date_str, satellite, buffer_meters, area):
    """Fetch mean vegetation indices over `area` for one calendar day.

    satellite: "Landsat-9" or "Landsat-8" (any other value selects Landsat-8).
    Returns a feature dict (ndvi/ndwi/evi/lst/cloud_cover/month/day_of_year/
    satellite/date/buffer_size) or None when no acceptable scene exists.
    `lat`/`lon` are unused here; the pre-built `area` geometry drives the query.
    """
    collection_id = "LANDSAT/LC09/C02/T1_L2" if satellite == "Landsat-9" else "LANDSAT/LC08/C02/T1_L2"
    try:
        # Least-cloudy scene intersecting the area on that day.
        filtered = (ee.ImageCollection(collection_id)
                    .filterBounds(area)
                    .filterDate(date_str, f"{date_str}T23:59:59")
                    .sort("CLOUD_COVER")
                    .limit(1))
        size = int(filtered.size().getInfo())
        if size == 0:
            return None
        image = ee.Image(filtered.first())
        info = image.getInfo().get("properties", {})
        cloud_cover = float(info.get("CLOUD_COVER", 100.0))
        # Reject scenes too cloudy to trust the indices.
        if cloud_cover > 80:
            return None

        # Landsat 8/9 surface-reflectance band math
        # (SR_B5 = NIR, SR_B4 = red, SR_B3 = green, SR_B2 = blue).
        ndvi = image.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
        ndwi = image.normalizedDifference(["SR_B3", "SR_B5"]).rename("NDWI")
        evi = image.expression(
            "2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))",
            {"NIR": image.select("SR_B5"), "RED": image.select("SR_B4"), "BLUE": image.select("SR_B2")},
        ).rename("EVI")
        # Collection-2 thermal band: scale/offset to Kelvin, then to Celsius.
        lst = image.select("ST_B10").multiply(0.00341802).add(149.0).subtract(273.15).rename("LST")

        composite = ndvi.addBands([ndwi, evi, lst])
        stats = composite.reduceRegion(
            reducer=ee.Reducer.mean(), geometry=area, scale=30, maxPixels=1e6, bestEffort=True
        ).getInfo()
        ndvi_val = stats.get("NDVI")
        # NDVI is mandatory; its absence means the area had no valid pixels.
        if ndvi_val is None:
            return None
        ndwi_val = stats.get("NDWI")
        evi_val = stats.get("EVI")
        lst_val = stats.get("LST")
        current_dt = datetime.strptime(date_str, "%Y-%m-%d")
        return {
            "ndvi": float(ndvi_val),
            "ndwi": float(ndwi_val) if ndwi_val is not None else None,
            "evi": float(evi_val) if evi_val is not None else None,
            "lst": float(lst_val) if lst_val is not None else None,
            "cloud_cover": float(cloud_cover),
            "month": current_dt.month,
            "day_of_year": current_dt.timetuple().tm_yday,
            "satellite": satellite,
            "date": date_str,
            "buffer_size": buffer_meters,
        }
    except Exception as e:
        print("❌ get_single_date_satellite_data error:", e)
        return None
210
 
211
def get_satellite_data_with_fallback(lat, lon, target_dt, satellite, buffer_meters, area, max_days_back=MAX_DAYS_BACK):
    """Walk backwards from target_dt, one day at a time (up to max_days_back),
    until a usable scene with a valid NDVI is found; None if none exists.

    The returned dict is annotated with the originally requested date, the
    date actually used, and the offset between them in days.
    """
    for offset in range(max_days_back + 1):
        probe_date = (target_dt - timedelta(days=offset)).strftime("%Y-%m-%d")
        data = get_single_date_satellite_data(lat, lon, probe_date, satellite, buffer_meters, area)
        if not data or data.get("ndvi") is None:
            continue
        data["original_request_date"] = target_dt.strftime("%Y-%m-%d")
        data["actual_data_date"] = probe_date
        data["days_offset"] = offset
        return data
    return None
221
 
222
def get_essential_vegetation_data(lat, lon, target_date, buffer_meters=BUFFER_METERS, max_days_back=MAX_DAYS_BACK):
    """Fetch vegetation features near target_date (YYYY-MM-DD), preferring
    Landsat-9 and falling back to Landsat-8; None when neither has data."""
    area = ee.Geometry.Point([float(lon), float(lat)]).buffer(buffer_meters)
    target_dt = datetime.strptime(target_date, "%Y-%m-%d")
    for sat in ("Landsat-9", "Landsat-8"):
        data = get_satellite_data_with_fallback(lat, lon, target_dt, sat, buffer_meters, area, max_days_back)
        if data:
            return data
    return None
230
 
231
# ------------------------------
# ML prediction wrapper
# ------------------------------
def predict_bloom_with_ml(features_dict):
    """Predict bloom likelihood from satellite features.

    Applies cheap vegetation sanity checks first, then the trained model and
    scaler when both are loaded, and finally the heuristic fallback when the
    model is absent or errors. Returns a dict with bloom_probability (percent),
    prediction ("BLOOM"/"NO_BLOOM") and confidence ("HIGH"/"MEDIUM"/"LOW").
    """
    ndvi = features_dict.get("ndvi", 0.0) or 0.0
    evi = features_dict.get("evi", 0.0) or 0.0

    # Barren ground: no point consulting the model.
    if ndvi < 0.05:
        return {"bloom_probability": 8.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
    if evi < 0.1 and ndvi < 0.1:
        return {"bloom_probability": 10.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}

    if ML_MODEL is not None and SCALER is not None:
        try:
            # Feature order must match the training columns.
            row = [
                float(features_dict.get("ndvi", 0.0)),
                float(features_dict.get("ndwi", 0.0) or 0.0),
                float(features_dict.get("evi", 0.0) or 0.0),
                float(features_dict.get("lst", 0.0) or 0.0),
                float(features_dict.get("cloud_cover", 0.0) or 0.0),
                float(features_dict.get("month", 0) or 0),
                float(features_dict.get("day_of_year", 0) or 0),
            ]
            scaled = SCALER.transform(np.array([row], dtype=np.float64))
            proba = ML_MODEL.predict_proba(scaled)
            # Binary classifiers expose P(bloom) in column 1.
            if proba.shape[1] == 2:
                bloom_prob = proba[0, 1]
            else:
                bloom_prob = proba[0, 0]
            label = ML_MODEL.predict(scaled)[0]
            pct = round(float(bloom_prob * 100.0), 2)
            # Extreme probabilities are more trustworthy.
            if pct > 75 or pct < 25:
                conf = "HIGH"
            elif pct > 60 or pct < 40:
                conf = "MEDIUM"
            else:
                conf = "LOW"
            return {
                "bloom_probability": pct,
                "prediction": "BLOOM" if label == 1 else "NO_BLOOM",
                "confidence": conf,
            }
        except Exception as e:
            print("❌ ML model error:", e)

    # Model missing or errored: heuristic fallback.
    return predict_bloom_fallback(features_dict)
273
+
274
def predict_bloom_fallback(features_dict):
    """Heuristic bloom score used when the ML model is unavailable or errors.

    Accumulates additive evidence from vegetation indices, temperature and
    season, clamps to [8, 90] percent, and derives a label + confidence.
    """
    ndvi = float(features_dict.get("ndvi") or 0.0)
    ndwi = float(features_dict.get("ndwi") or 0.0)
    evi = float(features_dict.get("evi") or 0.0)
    lst = float(features_dict.get("lst") or 0.0)
    month = int(features_dict.get("month") or 1)

    score = 0.0
    # EVI dominates the score (canopy greenness).
    if evi > 0.7:
        score += 50
    elif evi > 0.5:
        score += 35
    elif evi > 0.3:
        score += 20
    # NDVI support.
    if ndvi > 0.5:
        score += 25
    elif ndvi > 0.3:
        score += 15
    # A slightly-dry-to-neutral moisture band favors flowering.
    if -0.2 < ndwi < 0.05:
        score += 15
    # Mild surface temperatures.
    if 12 < lst < 32:
        score += 12
    # Seasonality: spring boost, small winter penalty.
    if month in (3, 4, 5):
        score += 15
    if month in (11, 12, 1, 2):
        score -= 3

    prob = min(90, max(8, score))
    if prob > 52:
        pred, conf = "BLOOM", ("MEDIUM" if prob > 65 else "LOW")
    else:
        pred, conf = "NO_BLOOM", ("MEDIUM" if prob < 25 else "LOW")
    return {"bloom_probability": round(prob, 2), "prediction": pred, "confidence": conf}
307
+
308
# ------------------------------
# Species stats builder / predictor
# ------------------------------
def load_or_build_species_stats():
    """Return (species_df, doy_hist_map) used for species ranking.

    species_df columns: species, count, mean_elev, std_elev, prior.
    doy_hist_map maps species -> normalized day-of-year histogram (length DOY_BINS).

    Prefers the precomputed SPECIES_STATS_FILE (histograms then default to
    uniform, since the CSV stores no day-of-year data); otherwise builds both
    from the raw phenology CSV; otherwise returns empty structures.
    """
    global PHENO_FILE, SPECIES_STATS_FILE
    if SPECIES_STATS_FILE.exists():
        df = pd.read_csv(SPECIES_STATS_FILE)
        doy_map = {}
        for s in df["species"].tolist():
            doy_map[s] = np.ones(DOY_BINS) / DOY_BINS
        return df, doy_map
    if PHENO_FILE.exists():
        ph = pd.read_csv(PHENO_FILE, low_memory=False)
        # Keep only "yes" phenophase rows (actual bloom sightings).
        if "phenophaseStatus" in ph.columns:
            ph["phenophaseStatus"] = ph["phenophaseStatus"].astype(str).str.strip().str.lower()
            ph_yes = ph[ph["phenophaseStatus"].str.startswith("y")].copy()
        else:
            ph_yes = ph.copy()
        ph_yes = ph_yes.dropna(subset=["elevation"])
        if "dayOfYear" in ph_yes.columns:
            # NOTE(review): the .dropna() here drops rows from the assigned
            # Series only — rows with non-numeric dayOfYear keep NaN in the
            # column after index alignment, and to_numpy(dtype=int) below
            # would raise on such a group. Confirm the CSV is fully numeric.
            ph_yes["dayOfYear"] = pd.to_numeric(ph_yes["dayOfYear"], errors="coerce").dropna().astype(int).clip(1, 366)
        rows = []
        doy_map = {}
        grouped = ph_yes.groupby("scientificName")
        for name, g in grouped:
            cnt = len(g)
            mean_elev = float(g["elevation"].dropna().mean()) if cnt > 0 else np.nan
            std_elev = float(g["elevation"].dropna().std(ddof=0)) if cnt > 0 else EPS_STD
            # Floor the std so the Gaussian likelihood never degenerates.
            std_elev = max(std_elev if not np.isnan(std_elev) else 0.0, EPS_STD)
            rows.append({"species": name, "count": cnt, "mean_elev": mean_elev, "std_elev": std_elev})
            if "dayOfYear" in g.columns:
                doy_map[name] = circular_histogram(g["dayOfYear"].to_numpy(dtype=int))
            else:
                doy_map[name] = np.ones(DOY_BINS) / DOY_BINS
        species_df = pd.DataFrame(rows)
        total = species_df["count"].sum()
        species_df["prior"] = species_df["count"] / total if total > 0 else 1.0 / max(1, len(species_df))
        # Pool species with too few observations into a single "OTHER" bucket.
        rare = species_df[species_df["count"] < MIN_COUNT_FOR_SPECIES]
        frequent = species_df[species_df["count"] >= MIN_COUNT_FOR_SPECIES]
        final_rows = frequent.to_dict("records")
        if len(rare) > 0:
            rare_names = rare["species"].tolist()
            rare_obs = ph_yes[ph_yes["scientificName"].isin(rare_names)]
            total_rare = len(rare_obs)
            if total_rare > 0:
                mean_other = float(rare_obs["elevation"].dropna().mean())
                std_other = float(rare_obs["elevation"].dropna().std(ddof=0)) if total_rare > 1 else EPS_STD
                std_other = max(std_other if not np.isnan(std_other) else 0.0, EPS_STD)
                final_rows.append(
                    {
                        "species": "OTHER",
                        "count": int(total_rare),
                        "mean_elev": mean_other,
                        "std_elev": std_other,
                        "prior": int(total_rare) / total if total > 0 else int(total_rare),
                    }
                )
                doy_map["OTHER"] = circular_histogram(rare_obs["dayOfYear"].to_numpy(dtype=int)) if "dayOfYear" in rare_obs.columns else np.ones(DOY_BINS) / DOY_BINS
        final_df = pd.DataFrame(final_rows).fillna(0)
        if "prior" not in final_df.columns:
            t2 = final_df["count"].sum()
            final_df["prior"] = final_df["count"] / t2 if t2 > 0 else 1.0 / len(final_df)
        return final_df, doy_map
    # Neither file present: species ranking is disabled.
    return pd.DataFrame(columns=["species", "count", "mean_elev", "std_elev", "prior"]), {}
372
+
373
def predict_species_by_elevation(elevation, doy=None, top_k=TOP_K_SPECIES):
    """Rank likely species for a site.

    Combines per-species priors with a Gaussian elevation likelihood and,
    when `doy` (day of year, 1-366) is given, a day-of-year histogram term.

    Returns up to `top_k` (species, probability) pairs with probabilities in
    [0, 1], or [] when species stats are not loaded.
    """
    global SPECIES_STATS_DF, DOY_HIST_MAP
    if SPECIES_STATS_DF is None or SPECIES_STATS_DF.empty:
        return []
    species = SPECIES_STATS_DF["species"].tolist()
    priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
    means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
    stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)

    # Elevation likelihood. gaussian_pdf returns a plain float, so indexing
    # its result (the previous `gaussian_pdf(x, ...)[0]` form) raised
    # TypeError on every call and silently disabled species ranking. With an
    # unknown elevation we use a flat likelihood so the priors alone decide.
    if elevation is None:
        like = np.ones(len(species))
    else:
        elev = float(elevation)
        like = np.array([gaussian_pdf(elev, means[i], stds[i]) for i in range(len(species))])

    post = priors * like
    if post.sum() == 0:
        post = np.ones(len(species)) / len(species)
    else:
        post = post / post.sum()

    # Optional day-of-year refinement.
    if doy is not None and not np.isnan(doy):
        doy_idx = int(doy) - 1
        doy_probs = np.array([DOY_HIST_MAP.get(s, np.ones(DOY_BINS) / DOY_BINS)[doy_idx] for s in species])
        combined = post * doy_probs
        if combined.sum() > 0:
            post = combined / combined.sum()

    order = np.argsort(-post)
    return [(species[i], float(post[i])) for i in order[:top_k]]
400
+
401
# ------------------------------
# Lifespan to load models and init EE
# ------------------------------
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan: load artifacts, init Earth Engine, build species stats.

    Missing or broken model artifacts are tolerated (the heuristic fallback
    handles prediction); a failed Earth Engine init is fatal.
    """
    global ML_MODEL, SCALER, FEATURE_COLUMNS, SPECIES_STATS_DF, DOY_HIST_MAP
    # Each artifact loads independently; one failure doesn't block the others.
    if MODEL_FILE.exists():
        try:
            ML_MODEL = joblib.load(MODEL_FILE)
            print("βœ… MIL model loaded.")
        except Exception as e:
            print("❌ MIL model load error:", e)
    if SCALER_FILE.exists():
        try:
            SCALER = joblib.load(SCALER_FILE)
            print("βœ… Scaler loaded.")
        except Exception as e:
            print("❌ Scaler load error:", e)
    if FEATURES_FILE.exists():
        try:
            FEATURE_COLUMNS = joblib.load(FEATURES_FILE)
            print("βœ… Features list loaded.")
        except Exception as e:
            print("❌ Features list load error:", e)

    ok = initialize_ee_from_env()
    if not ok:
        raise RuntimeError("Earth Engine initialization failed. Set CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN env vars in Space secrets.")

    try:
        SPECIES_STATS_DF, DOY_HIST_MAP = load_or_build_species_stats()
        print("βœ… Species stats ready. species count:", len(SPECIES_STATS_DF))
    except Exception as e:
        # Degrade gracefully: species ranking is optional.
        print("⚠️ Species stats build error:", e)
        SPECIES_STATS_DF = pd.DataFrame()
        DOY_HIST_MAP = {}

    yield
    print("πŸ”„ Shutting down")
440
+
441
+ # ------------------------------
442
+ # App + endpoints
443
+ # ------------------------------
444
+ app = FastAPI(title="Bloom Prediction (HF Space)", lifespan=lifespan)
445
+
446
+ @app.get("/")
447
+ async def root():
448
+ return {"message": "Bloom Prediction API (HF Space)", "model_loaded": ML_MODEL is not None}
449
+
450
@app.post("/predict", response_model=BloomPredictionResponse)
async def predict_bloom(req: BloomPredictionRequest):
    """Run a 12-month bloom analysis for one coordinate.

    For the 15th of each month of the requested year: fetch Landsat features
    (with a backwards date fallback), run the ML bloom prediction, and — when
    a bloom is predicted — rank likely species by elevation and day of year.

    Raises HTTP 400 for a malformed date; individual months degrade to
    `note`-annotated empty results rather than failing the whole request.
    """
    start = time.time()
    try:
        req_dt = datetime.strptime(req.date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")

    # One elevation sample serves all twelve months.
    elevation = get_elevation_from_ee(req.lat, req.lon)
    year = req_dt.year
    monthly_results: List[MonthlyResult] = []

    for month in range(1, 13):
        sample_date_str = date(year, month, 15).strftime("%Y-%m-%d")
        # get_essential_vegetation_data builds its own point/buffer geometry;
        # the per-iteration ee.Geometry construction the old code did here was
        # dead work and has been removed.
        sat_data = get_essential_vegetation_data(req.lat, req.lon, sample_date_str)

        result = {
            "month": month,
            "sample_date": sample_date_str,
            "ml_bloom_probability": None,
            "ml_prediction": None,
            "ml_confidence": None,
            "species_top": None,
            "species_probs": None,
            "elevation_m": elevation,
            "data_quality": None,
            "satellite": None,
            "note": None,
        }
        if sat_data is None:
            result["note"] = f"No satellite data within {MAX_DAYS_BACK} days for {sample_date_str}"
            monthly_results.append(MonthlyResult(**result))
            continue

        ml_out = predict_bloom_with_ml(sat_data)
        result["ml_bloom_probability"] = float(ml_out.get("bloom_probability", 0.0))
        result["ml_prediction"] = ml_out.get("prediction")
        result["ml_confidence"] = ml_out.get("confidence")
        result["data_quality"] = {
            "satellite": sat_data.get("satellite"),
            "cloud_cover": sat_data.get("cloud_cover"),
            "days_offset": sat_data.get("days_offset"),
            "buffer_radius_m": sat_data.get("buffer_size"),
        }
        result["satellite"] = sat_data.get("satellite")

        # Species ranking is best-effort; failures are noted, not fatal.
        try:
            bloom_bool = (result["ml_prediction"] == "BLOOM") or (result["ml_bloom_probability"] >= 50.0)
            if bloom_bool:
                doy = sat_data.get("day_of_year", None)
                top_species = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
                result["species_top"] = [(s, round(p * 100.0, 2)) for s, p in top_species]
                result["species_probs"] = _species_posterior_pct(elevation, doy)
            else:
                result["species_top"] = []
                result["species_probs"] = {}
        except Exception as e:
            print("❌ species prediction error:", e)
            result["species_top"] = []
            result["species_probs"] = {}
            result["note"] = (result.get("note", "") + " ; species_pred_error") if result.get("note") else "species_pred_error"

        monthly_results.append(MonthlyResult(**result))

    proc_time = round(time.time() - start, 2)
    return BloomPredictionResponse(
        success=True,
        analysis_date=datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
        requested_date=req.date,
        monthly_results=monthly_results,
        processing_time=proc_time,
    )


def _species_posterior_pct(elevation, doy) -> Dict[str, float]:
    """Full posterior over known species, as percent rounded to 6 dp.

    Mirrors predict_species_by_elevation but returns every species. Replaces
    the endpoint's old inline copy, which indexed gaussian_pdf's float return
    value (`gaussian_pdf(x, ...)[0]`) and therefore always raised, wiping the
    species output for every month. Returns {} when stats are not loaded.
    """
    if SPECIES_STATS_DF is None or SPECIES_STATS_DF.empty:
        return {}
    all_species = SPECIES_STATS_DF["species"].tolist()
    priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
    means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
    stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)
    # Flat likelihood when elevation is unknown: priors alone decide.
    if elevation is None:
        like = np.ones(len(all_species))
    else:
        elev = float(elevation)
        like = np.array([gaussian_pdf(elev, means[i], stds[i]) for i in range(len(all_species))])
    post = priors * like
    if post.sum() == 0:
        post = np.ones(len(all_species)) / len(all_species)
    else:
        post = post / post.sum()
    if doy is not None and not np.isnan(doy):
        doy_idx = int(doy) - 1
        doy_probs = np.array([DOY_HIST_MAP.get(s, np.ones(DOY_BINS) / DOY_BINS)[doy_idx] for s in all_species])
        combined = post * doy_probs
        if combined.sum() > 0:
            post = combined / combined.sum()
    return {s: round(float(p * 100.0), 6) for s, p in zip(all_species, post)}
547
 
548
# Run locally if invoked directly (not used by Docker CMD)
if __name__ == "__main__":
    import uvicorn
    # Honor a platform-provided PORT, defaulting to HF Spaces' 7860.
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))