NaseefNazrul committed on
Commit
23c8193
·
verified ·
1 Parent(s): 6e6068d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -57
app.py CHANGED
@@ -30,6 +30,8 @@ SCALER_FILE = Path("mil_scaler.joblib")
30
  FEATURES_FILE = Path("mil_features.joblib")
31
  PHENO_FILE = Path("phenologythingy.csv")
32
  SPECIES_STATS_FILE = Path("species_stats.csv")
 
 
33
 
34
  ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
35
  BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
@@ -68,27 +70,31 @@ class BloomPredictionRequest(BaseModel):
68
  lon: float = Field(..., ge=-180, le=180)
69
  date: str = Field(..., description="YYYY-MM-DD")
70
 
71
- class MonthlyResult(BaseModel):
72
  month: int
73
- sample_date: str
74
- ml_bloom_probability: Optional[float] = None
75
- ml_prediction: Optional[str] = None
76
- ml_confidence: Optional[str] = None
77
- species_top: Optional[List[Tuple[str, float]]] = None
78
- species_probs: Optional[Dict[str, float]] = None
79
- elevation_m: Optional[float] = None
80
- data_quality: Optional[dict] = None
81
- satellite: Optional[str] = None
82
- note: Optional[str] = None
83
 
84
  class BloomPredictionResponse(BaseModel):
85
  success: bool
86
- analysis_date: str
87
  requested_date: str
88
- monthly_results: List[MonthlyResult]
89
- monthly_curve: Dict[int, float] # month -> percent (sums to ~100)
90
- bell_valid: Optional[bool] = None
91
- bell_diagnostics: Optional[Dict[str, float]] = None
 
 
 
 
 
 
 
 
92
  processing_time: float
93
 
94
  # ------------------------------
@@ -696,20 +702,25 @@ def process_month_task(lat, lon, year, month, elevation):
696
  @app.post("/predict", response_model=BloomPredictionResponse)
697
  async def predict_bloom(req: BloomPredictionRequest):
698
  start_time = time.time()
699
- # validate date
 
700
  try:
701
  req_dt = datetime.strptime(req.date, "%Y-%m-%d")
702
  except ValueError:
703
  raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")
704
- # elevation once
 
705
  elevation = get_elevation_from_ee(req.lat, req.lon)
706
  year = req_dt.year
707
 
708
  monthly_results = [None] * 12
709
- # Run monthly tasks in parallel to speed up (bounded workers)
710
- tasks = []
711
  with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
712
- futures = {ex.submit(process_month_task, req.lat, req.lon, year, month, elevation): month for month in range(1, 13)}
 
 
 
713
  for fut in as_completed(futures):
714
  month = futures[fut]
715
  try:
@@ -719,9 +730,9 @@ async def predict_bloom(req: BloomPredictionRequest):
719
  res = {
720
  "month": month,
721
  "sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
722
- "ml_bloom_probability": None,
723
- "ml_prediction": None,
724
- "ml_confidence": None,
725
  "species_top": [],
726
  "species_probs": {},
727
  "elevation_m": elevation,
@@ -729,47 +740,82 @@ async def predict_bloom(req: BloomPredictionRequest):
729
  "satellite": None,
730
  "note": "processing_error"
731
  }
732
- monthly_results[month - 1] = MonthlyResult(**res)
733
 
734
- # Build monthly bell curve: normalize ml_bloom_probability across months
735
- # After monthly_results list is filled (MonthlyResult objects)...
 
 
 
 
736
 
737
- # Build raw ML probs array (0-100)
738
- raw_probs = np.array([(mr.ml_bloom_probability or 0.0) for mr in monthly_results], dtype=float)
739
-
740
- # 1) Compute smoothed bell curve percentages
741
  monthly_perc = smooth_monthly_probs(raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA)
742
  monthly_curve = {i+1: float(monthly_perc[i]) for i in range(12)}
743
-
744
- # 2) Bell-shape verification
745
  bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746
 
747
- # 3) Trim species_probs in each monthly_result to top-K only
748
- for mr in monthly_results:
749
- if isinstance(mr.species_probs, dict) and len(mr.species_probs) > 0:
750
- # sort and keep top K
751
- items = sorted(mr.species_probs.items(), key=lambda x: -float(x[1]))[:TOP_K_SPECIES]
752
- mr.species_probs = {k: round(float(v), 6) for k, v in items}
753
- # also ensure species_top aligns (already top list)
754
- mr.species_top = [(s, float(p)) for s, p in mr.species_top[:TOP_K_SPECIES]]
755
- else:
756
- mr.species_probs = {}
757
- mr.species_top = []
758
-
759
- # include bell verification in response
760
  processing_time = round(time.time() - start_time, 2)
761
 
762
- response = {
763
- "success": True,
764
- "analysis_date": datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
765
- "requested_date": req.date,
766
- "monthly_results": monthly_results,
767
- "monthly_curve": monthly_curve,
768
- "bell_valid": bell_ok,
769
- "bell_diagnostics": bell_diag,
770
- "processing_time": processing_time
771
- }
772
- return BloomPredictionResponse(**response)
 
 
773
 
774
  # ------------------------------
775
  # Local run
 
30
  FEATURES_FILE = Path("mil_features.joblib")
31
  PHENO_FILE = Path("phenologythingy.csv")
32
  SPECIES_STATS_FILE = Path("species_stats.csv")
33
+ MIN_BLOOM_THRESHOLD = float(os.environ.get("MIN_BLOOM_THRESHOLD", 40.0)) # minimum probability to predict species
34
+ MIN_PEAK_FOR_BELL = float(os.environ.get("MIN_PEAK_FOR_BELL", 60.0))
35
 
36
  ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
37
  BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
 
70
  lon: float = Field(..., ge=-180, le=180)
71
  date: str = Field(..., description="YYYY-MM-DD")
72
 
73
+ class SimplifiedMonthlyResult(BaseModel):
74
  month: int
75
+ bloom_probability: float
76
+ prediction: str # "BLOOM" or "NO_BLOOM"
77
+
78
+ class SpeciesResult(BaseModel):
79
+ name: str
80
+ probability: float # as percentage
 
 
 
 
81
 
82
  class BloomPredictionResponse(BaseModel):
83
  success: bool
84
+ status: str # "BLOOM_DETECTED", "NO_BLOOM", "LOW_CONFIDENCE"
85
  requested_date: str
86
+
87
+ # Only include these if there's a valid bloom season
88
+ peak_month: Optional[int] = None
89
+ peak_probability: Optional[float] = None
90
+ bloom_window: Optional[List[int]] = None # months with >40% probability
91
+
92
+ # Only include species if peak > threshold
93
+ top_species: Optional[List[SpeciesResult]] = None
94
+
95
+ # Simplified monthly data (probabilities only)
96
+ monthly_probabilities: Dict[int, float]
97
+
98
  processing_time: float
99
 
100
  # ------------------------------
 
702
  @app.post("/predict", response_model=BloomPredictionResponse)
703
  async def predict_bloom(req: BloomPredictionRequest):
704
  start_time = time.time()
705
+
706
+ # Validate date
707
  try:
708
  req_dt = datetime.strptime(req.date, "%Y-%m-%d")
709
  except ValueError:
710
  raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")
711
+
712
+ # Get elevation once
713
  elevation = get_elevation_from_ee(req.lat, req.lon)
714
  year = req_dt.year
715
 
716
  monthly_results = [None] * 12
717
+
718
+ # Run monthly tasks in parallel
719
  with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
720
+ futures = {
721
+ ex.submit(process_month_task, req.lat, req.lon, year, month, elevation): month
722
+ for month in range(1, 13)
723
+ }
724
  for fut in as_completed(futures):
725
  month = futures[fut]
726
  try:
 
730
  res = {
731
  "month": month,
732
  "sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
733
+ "ml_bloom_probability": 0.0,
734
+ "ml_prediction": "NO_BLOOM",
735
+ "ml_confidence": "LOW",
736
  "species_top": [],
737
  "species_probs": {},
738
  "elevation_m": elevation,
 
740
  "satellite": None,
741
  "note": "processing_error"
742
  }
743
+ monthly_results[month - 1] = res
744
 
745
+ # Extract raw probabilities
746
+ raw_probs = np.array([
747
+ (mr.get("ml_bloom_probability") or 0.0) if isinstance(mr, dict)
748
+ else (mr.ml_bloom_probability or 0.0)
749
+ for mr in monthly_results
750
+ ], dtype=float)
751
 
752
+ # Compute smoothed curve
 
 
 
753
  monthly_perc = smooth_monthly_probs(raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA)
754
  monthly_curve = {i+1: float(monthly_perc[i]) for i in range(12)}
755
+
756
+ # Check bell shape
757
  bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
758
+
759
+ # Find peak
760
+ peak_idx = int(np.argmax(monthly_perc))
761
+ peak_month = peak_idx + 1
762
+ peak_prob = float(monthly_perc[peak_idx])
763
+
764
+ # Determine status and whether to include species
765
+ if peak_prob < MIN_PEAK_FOR_BELL or not bell_ok:
766
+ status = "NO_BLOOM"
767
+ top_species = None
768
+ bloom_window = None
769
+ peak_month_out = None
770
+ peak_prob_out = None
771
+ elif peak_prob < MIN_BLOOM_THRESHOLD:
772
+ status = "LOW_CONFIDENCE"
773
+ top_species = None
774
+ bloom_window = [i+1 for i, p in enumerate(monthly_perc) if p > 10.0]
775
+ peak_month_out = peak_month
776
+ peak_prob_out = peak_prob
777
+ else:
778
+ status = "BLOOM_DETECTED"
779
+ bloom_window = [i+1 for i, p in enumerate(monthly_perc) if p > MIN_BLOOM_THRESHOLD]
780
+ peak_month_out = peak_month
781
+ peak_prob_out = peak_prob
782
+
783
+ # Only predict species if we have a strong bloom signal
784
+ try:
785
+ # Use the peak month's data for species prediction
786
+ peak_result = monthly_results[peak_idx]
787
+ if isinstance(peak_result, dict):
788
+ doy = peak_result.get("day_of_year")
789
+ else:
790
+ # Estimate DOY from month
791
+ doy = date(year, peak_month, 15).timetuple().tm_yday
792
+
793
+ species_predictions = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
794
+
795
+ # Convert to response format (probabilities as percentages)
796
+ top_species = [
797
+ SpeciesResult(name=sp, probability=round(prob * 100.0, 2))
798
+ for sp, prob in species_predictions
799
+ ]
800
+ except Exception as e:
801
+ print(f"❌ species prediction error: {e}")
802
+ top_species = None
803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
  processing_time = round(time.time() - start_time, 2)
805
 
806
+ response = BloomPredictionResponse(
807
+ success=True,
808
+ status=status,
809
+ requested_date=req.date,
810
+ peak_month=peak_month_out,
811
+ peak_probability=peak_prob_out,
812
+ bloom_window=bloom_window,
813
+ top_species=top_species,
814
+ monthly_probabilities=monthly_curve,
815
+ processing_time=processing_time
816
+ )
817
+
818
+ return response
819
 
820
  # ------------------------------
821
  # Local run