Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,6 +30,8 @@ SCALER_FILE = Path("mil_scaler.joblib")
|
|
| 30 |
FEATURES_FILE = Path("mil_features.joblib")
|
| 31 |
PHENO_FILE = Path("phenologythingy.csv")
|
| 32 |
SPECIES_STATS_FILE = Path("species_stats.csv")
|
|
|
|
|
|
|
| 33 |
|
| 34 |
ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
|
| 35 |
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
|
|
@@ -68,27 +70,31 @@ class BloomPredictionRequest(BaseModel):
|
|
| 68 |
lon: float = Field(..., ge=-180, le=180)
|
| 69 |
date: str = Field(..., description="YYYY-MM-DD")
|
| 70 |
|
| 71 |
-
class
|
| 72 |
month: int
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
elevation_m: Optional[float] = None
|
| 80 |
-
data_quality: Optional[dict] = None
|
| 81 |
-
satellite: Optional[str] = None
|
| 82 |
-
note: Optional[str] = None
|
| 83 |
|
| 84 |
class BloomPredictionResponse(BaseModel):
|
| 85 |
success: bool
|
| 86 |
-
|
| 87 |
requested_date: str
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
processing_time: float
|
| 93 |
|
| 94 |
# ------------------------------
|
|
@@ -696,20 +702,25 @@ def process_month_task(lat, lon, year, month, elevation):
|
|
| 696 |
@app.post("/predict", response_model=BloomPredictionResponse)
|
| 697 |
async def predict_bloom(req: BloomPredictionRequest):
|
| 698 |
start_time = time.time()
|
| 699 |
-
|
|
|
|
| 700 |
try:
|
| 701 |
req_dt = datetime.strptime(req.date, "%Y-%m-%d")
|
| 702 |
except ValueError:
|
| 703 |
raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")
|
| 704 |
-
|
|
|
|
| 705 |
elevation = get_elevation_from_ee(req.lat, req.lon)
|
| 706 |
year = req_dt.year
|
| 707 |
|
| 708 |
monthly_results = [None] * 12
|
| 709 |
-
|
| 710 |
-
tasks
|
| 711 |
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
|
| 712 |
-
futures = {
|
|
|
|
|
|
|
|
|
|
| 713 |
for fut in as_completed(futures):
|
| 714 |
month = futures[fut]
|
| 715 |
try:
|
|
@@ -719,9 +730,9 @@ async def predict_bloom(req: BloomPredictionRequest):
|
|
| 719 |
res = {
|
| 720 |
"month": month,
|
| 721 |
"sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
|
| 722 |
-
"ml_bloom_probability":
|
| 723 |
-
"ml_prediction":
|
| 724 |
-
"ml_confidence":
|
| 725 |
"species_top": [],
|
| 726 |
"species_probs": {},
|
| 727 |
"elevation_m": elevation,
|
|
@@ -729,47 +740,82 @@ async def predict_bloom(req: BloomPredictionRequest):
|
|
| 729 |
"satellite": None,
|
| 730 |
"note": "processing_error"
|
| 731 |
}
|
| 732 |
-
monthly_results[month - 1] =
|
| 733 |
|
| 734 |
-
#
|
| 735 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
|
| 737 |
-
#
|
| 738 |
-
raw_probs = np.array([(mr.ml_bloom_probability or 0.0) for mr in monthly_results], dtype=float)
|
| 739 |
-
|
| 740 |
-
# 1) Compute smoothed bell curve percentages
|
| 741 |
monthly_perc = smooth_monthly_probs(raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA)
|
| 742 |
monthly_curve = {i+1: float(monthly_perc[i]) for i in range(12)}
|
| 743 |
-
|
| 744 |
-
#
|
| 745 |
bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
|
| 747 |
-
# 3) Trim species_probs in each monthly_result to top-K only
|
| 748 |
-
for mr in monthly_results:
|
| 749 |
-
if isinstance(mr.species_probs, dict) and len(mr.species_probs) > 0:
|
| 750 |
-
# sort and keep top K
|
| 751 |
-
items = sorted(mr.species_probs.items(), key=lambda x: -float(x[1]))[:TOP_K_SPECIES]
|
| 752 |
-
mr.species_probs = {k: round(float(v), 6) for k, v in items}
|
| 753 |
-
# also ensure species_top aligns (already top list)
|
| 754 |
-
mr.species_top = [(s, float(p)) for s, p in mr.species_top[:TOP_K_SPECIES]]
|
| 755 |
-
else:
|
| 756 |
-
mr.species_probs = {}
|
| 757 |
-
mr.species_top = []
|
| 758 |
-
|
| 759 |
-
# include bell verification in response
|
| 760 |
processing_time = round(time.time() - start_time, 2)
|
| 761 |
|
| 762 |
-
response =
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
|
|
|
|
|
|
| 773 |
|
| 774 |
# ------------------------------
|
| 775 |
# Local run
|
|
|
|
| 30 |
FEATURES_FILE = Path("mil_features.joblib")
|
| 31 |
PHENO_FILE = Path("phenologythingy.csv")
|
| 32 |
SPECIES_STATS_FILE = Path("species_stats.csv")
|
| 33 |
+
MIN_BLOOM_THRESHOLD = float(os.environ.get("MIN_BLOOM_THRESHOLD", 40.0)) # minimum probability to predict species
|
| 34 |
+
MIN_PEAK_FOR_BELL = float(os.environ.get("MIN_PEAK_FOR_BELL", 60.0))
|
| 35 |
|
| 36 |
ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
|
| 37 |
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
|
|
|
|
| 70 |
lon: float = Field(..., ge=-180, le=180)
|
| 71 |
date: str = Field(..., description="YYYY-MM-DD")
|
| 72 |
|
| 73 |
+
class SimplifiedMonthlyResult(BaseModel):
|
| 74 |
month: int
|
| 75 |
+
bloom_probability: float
|
| 76 |
+
prediction: str # "BLOOM" or "NO_BLOOM"
|
| 77 |
+
|
| 78 |
+
class SpeciesResult(BaseModel):
|
| 79 |
+
name: str
|
| 80 |
+
probability: float # as percentage
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
class BloomPredictionResponse(BaseModel):
|
| 83 |
success: bool
|
| 84 |
+
status: str # "BLOOM_DETECTED", "NO_BLOOM", "LOW_CONFIDENCE"
|
| 85 |
requested_date: str
|
| 86 |
+
|
| 87 |
+
# Only include these if there's a valid bloom season
|
| 88 |
+
peak_month: Optional[int] = None
|
| 89 |
+
peak_probability: Optional[float] = None
|
| 90 |
+
bloom_window: Optional[List[int]] = None # months with >40% probability
|
| 91 |
+
|
| 92 |
+
# Only include species if peak > threshold
|
| 93 |
+
top_species: Optional[List[SpeciesResult]] = None
|
| 94 |
+
|
| 95 |
+
# Simplified monthly data (probabilities only)
|
| 96 |
+
monthly_probabilities: Dict[int, float]
|
| 97 |
+
|
| 98 |
processing_time: float
|
| 99 |
|
| 100 |
# ------------------------------
|
|
|
|
| 702 |
@app.post("/predict", response_model=BloomPredictionResponse)
|
| 703 |
async def predict_bloom(req: BloomPredictionRequest):
|
| 704 |
start_time = time.time()
|
| 705 |
+
|
| 706 |
+
# Validate date
|
| 707 |
try:
|
| 708 |
req_dt = datetime.strptime(req.date, "%Y-%m-%d")
|
| 709 |
except ValueError:
|
| 710 |
raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")
|
| 711 |
+
|
| 712 |
+
# Get elevation once
|
| 713 |
elevation = get_elevation_from_ee(req.lat, req.lon)
|
| 714 |
year = req_dt.year
|
| 715 |
|
| 716 |
monthly_results = [None] * 12
|
| 717 |
+
|
| 718 |
+
# Run monthly tasks in parallel
|
| 719 |
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
|
| 720 |
+
futures = {
|
| 721 |
+
ex.submit(process_month_task, req.lat, req.lon, year, month, elevation): month
|
| 722 |
+
for month in range(1, 13)
|
| 723 |
+
}
|
| 724 |
for fut in as_completed(futures):
|
| 725 |
month = futures[fut]
|
| 726 |
try:
|
|
|
|
| 730 |
res = {
|
| 731 |
"month": month,
|
| 732 |
"sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
|
| 733 |
+
"ml_bloom_probability": 0.0,
|
| 734 |
+
"ml_prediction": "NO_BLOOM",
|
| 735 |
+
"ml_confidence": "LOW",
|
| 736 |
"species_top": [],
|
| 737 |
"species_probs": {},
|
| 738 |
"elevation_m": elevation,
|
|
|
|
| 740 |
"satellite": None,
|
| 741 |
"note": "processing_error"
|
| 742 |
}
|
| 743 |
+
monthly_results[month - 1] = res
|
| 744 |
|
| 745 |
+
# Extract raw probabilities
|
| 746 |
+
raw_probs = np.array([
|
| 747 |
+
(mr.get("ml_bloom_probability") or 0.0) if isinstance(mr, dict)
|
| 748 |
+
else (mr.ml_bloom_probability or 0.0)
|
| 749 |
+
for mr in monthly_results
|
| 750 |
+
], dtype=float)
|
| 751 |
|
| 752 |
+
# Compute smoothed curve
|
|
|
|
|
|
|
|
|
|
| 753 |
monthly_perc = smooth_monthly_probs(raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA)
|
| 754 |
monthly_curve = {i+1: float(monthly_perc[i]) for i in range(12)}
|
| 755 |
+
|
| 756 |
+
# Check bell shape
|
| 757 |
bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
|
| 758 |
+
|
| 759 |
+
# Find peak
|
| 760 |
+
peak_idx = int(np.argmax(monthly_perc))
|
| 761 |
+
peak_month = peak_idx + 1
|
| 762 |
+
peak_prob = float(monthly_perc[peak_idx])
|
| 763 |
+
|
| 764 |
+
# Determine status and whether to include species
|
| 765 |
+
if peak_prob < MIN_PEAK_FOR_BELL or not bell_ok:
|
| 766 |
+
status = "NO_BLOOM"
|
| 767 |
+
top_species = None
|
| 768 |
+
bloom_window = None
|
| 769 |
+
peak_month_out = None
|
| 770 |
+
peak_prob_out = None
|
| 771 |
+
elif peak_prob < MIN_BLOOM_THRESHOLD:
|
| 772 |
+
status = "LOW_CONFIDENCE"
|
| 773 |
+
top_species = None
|
| 774 |
+
bloom_window = [i+1 for i, p in enumerate(monthly_perc) if p > 10.0]
|
| 775 |
+
peak_month_out = peak_month
|
| 776 |
+
peak_prob_out = peak_prob
|
| 777 |
+
else:
|
| 778 |
+
status = "BLOOM_DETECTED"
|
| 779 |
+
bloom_window = [i+1 for i, p in enumerate(monthly_perc) if p > MIN_BLOOM_THRESHOLD]
|
| 780 |
+
peak_month_out = peak_month
|
| 781 |
+
peak_prob_out = peak_prob
|
| 782 |
+
|
| 783 |
+
# Only predict species if we have a strong bloom signal
|
| 784 |
+
try:
|
| 785 |
+
# Use the peak month's data for species prediction
|
| 786 |
+
peak_result = monthly_results[peak_idx]
|
| 787 |
+
if isinstance(peak_result, dict):
|
| 788 |
+
doy = peak_result.get("day_of_year")
|
| 789 |
+
else:
|
| 790 |
+
# Estimate DOY from month
|
| 791 |
+
doy = date(year, peak_month, 15).timetuple().tm_yday
|
| 792 |
+
|
| 793 |
+
species_predictions = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
|
| 794 |
+
|
| 795 |
+
# Convert to response format (probabilities as percentages)
|
| 796 |
+
top_species = [
|
| 797 |
+
SpeciesResult(name=sp, probability=round(prob * 100.0, 2))
|
| 798 |
+
for sp, prob in species_predictions
|
| 799 |
+
]
|
| 800 |
+
except Exception as e:
|
| 801 |
+
print(f"❌ species prediction error: {e}")
|
| 802 |
+
top_species = None
|
| 803 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
processing_time = round(time.time() - start_time, 2)
|
| 805 |
|
| 806 |
+
response = BloomPredictionResponse(
|
| 807 |
+
success=True,
|
| 808 |
+
status=status,
|
| 809 |
+
requested_date=req.date,
|
| 810 |
+
peak_month=peak_month_out,
|
| 811 |
+
peak_probability=peak_prob_out,
|
| 812 |
+
bloom_window=bloom_window,
|
| 813 |
+
top_species=top_species,
|
| 814 |
+
monthly_probabilities=monthly_curve,
|
| 815 |
+
processing_time=processing_time
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
return response
|
| 819 |
|
| 820 |
# ------------------------------
|
| 821 |
# Local run
|