Spaces:
Runtime error
Runtime error
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import os | |
| import json | |
| import math | |
| import traceback | |
| import uuid | |
| from typing import Tuple | |
| import numpy as np | |
| import pandas as pd | |
| import joblib | |
| import tensorflow as tf | |
| from tensorflow.keras.utils import load_img, img_to_array | |
| # Hybrid ARIMA | |
| from statsmodels.tsa.arima.model import ARIMA | |
# Flask application object; served by app.run() in the __main__ block at the end of the file.
app = Flask(__name__)
# Enable CORS on the app with flask_cors defaults (presumably so a browser
# front-end on another origin can call this API — confirm).
CORS(app)
# ---------------------------------------------------------------------
# BASE DIRS
# ---------------------------------------------------------------------
# Directory containing this source file; default artifact/data paths are built from it.
BASE_DIR = os.path.dirname(__file__)
# Default folder for the yield model and the leaf-model weights/labels.
MODEL_DIR = os.path.join(BASE_DIR, "model")
# ---------------------------------------------------------------------
# (A) TEA PRICE (NEW HYBRID ARIMA + RF)
# ---------------------------------------------------------------------
# Artifact directory and dataset location; both can be overridden via environment variables.
TEA_ARTIFACT_DIR = os.getenv("TEA_ARTIFACT_DIR", os.path.join(BASE_DIR, "artifacts_tea_hybrid"))
TEA_DATA_PATH = os.getenv("TEA_DATA_PATH", os.path.join(BASE_DIR, "tea_auction_advanced_dataset.csv"))
TEA_MODEL_PATH = os.path.join(TEA_ARTIFACT_DIR, "hybrid_arima_rf_model.joblib")
TEA_CFG_PATH = os.path.join(TEA_ARTIFACT_DIR, "hybrid_config.json")
# Segments with fewer history rows than this get no fitted ARIMA model.
TEA_MIN_ARIMA_POINTS = int(os.getenv("TEA_MIN_ARIMA_POINTS", "60"))
# Loaded state, populated by _tea_safe_load(); the *_error strings hold failure messages.
tea_model = None
tea_cfg = None
tea_df_all = None
tea_load_error = None
tea_data_error = None
# derived
# Defaults below may be replaced by values from hybrid_config.json in _tea_safe_load().
TEA_TARGET_COL = "auction_price_rs_per_kg"
TEA_DATE_COL = "date_week"
tea_cat_cols = ["elevation", "grade"]
tea_num_cols = []
TEA_ARIMA_ORDER = (2, 1, 2)
TEA_GROUP_COLS = ["elevation", "grade"]
tea_arima_models = {} # key: tuple of TEA_GROUP_COLS values, i.e. (elevation, grade) by default -> fitted ARIMA
# Reference ("typical") value per feature, used by the local sensitivity explanation.
tea_ref_values = {}
# Column used as the ARIMA stand-in when a segment has no usable fitted model.
tea_fallback_col = None
# Median target price over the whole dataset; last-resort fallback value.
tea_global_median = None
def month_sin_cos(month_num: int):
    """Map a calendar month (1-12) onto the unit circle.

    Returns ``(sin, cos)`` of ``2*pi*(month_num - 1)/12`` as Python floats, so
    that December and January end up next to each other.
    """
    theta = 2.0 * np.pi * (month_num - 1) / 12.0
    sin_part, cos_part = np.sin(theta), np.cos(theta)
    return float(sin_part), float(cos_part)
def _tea_safe_load():
    """Load the tea hybrid model, its JSON config and the history dataset.

    Never raises. Failures are recorded in the module globals
    ``tea_load_error`` (model/config) and ``tea_data_error`` (dataset), and
    the affected objects are reset to ``None``. When the config loads, its
    values replace the column/ARIMA defaults. When the dataset loads,
    ``tea_fallback_col``, ``tea_global_median`` and ``tea_ref_values`` are
    derived from it.
    """
    global tea_model, tea_cfg, tea_df_all
    global TEA_TARGET_COL, TEA_DATE_COL, tea_cat_cols, tea_num_cols, TEA_ARIMA_ORDER, TEA_GROUP_COLS
    global tea_load_error, tea_data_error
    global tea_fallback_col, tea_global_median, tea_ref_values
    # load artifacts
    try:
        if not os.path.exists(TEA_MODEL_PATH):
            raise FileNotFoundError(f"Missing tea model file: {TEA_MODEL_PATH}")
        if not os.path.exists(TEA_CFG_PATH):
            raise FileNotFoundError(f"Missing tea config file: {TEA_CFG_PATH}")
        # joblib.load unpickles the file: TEA_ARTIFACT_DIR must only contain trusted artifacts.
        tea_model = joblib.load(TEA_MODEL_PATH)
        with open(TEA_CFG_PATH, "r", encoding="utf-8") as f:
            tea_cfg = json.load(f)
        # Config keys override the module defaults; absent keys keep the current value.
        TEA_TARGET_COL = tea_cfg.get("TARGET_COL", TEA_TARGET_COL)
        TEA_DATE_COL = tea_cfg.get("DATE_COL", TEA_DATE_COL)
        tea_cat_cols = tea_cfg.get("cat_cols", tea_cat_cols)
        tea_num_cols = tea_cfg.get("num_cols", tea_num_cols)
        TEA_ARIMA_ORDER = tuple(tea_cfg.get("arima_order", list(TEA_ARIMA_ORDER)))
        TEA_GROUP_COLS = tea_cfg.get("group_cols", TEA_GROUP_COLS)
    except Exception as e:
        tea_load_error = f"Failed to load tea hybrid artifacts: {e}"
        tea_model = None
        tea_cfg = None
    # load data
    # Runs even when artifact loading failed, using whichever TEA_DATE_COL /
    # TEA_TARGET_COL are in effect at this point.
    try:
        if not os.path.exists(TEA_DATA_PATH):
            raise FileNotFoundError(f"Missing tea dataset CSV: {TEA_DATA_PATH}")
        tea_df_all = pd.read_csv(TEA_DATA_PATH)
        # Unparseable dates become NaT and are dropped together with rows lacking a target.
        tea_df_all[TEA_DATE_COL] = pd.to_datetime(tea_df_all[TEA_DATE_COL], errors="coerce")
        tea_df_all = tea_df_all.dropna(subset=[TEA_DATE_COL, TEA_TARGET_COL]).sort_values(TEA_DATE_COL).reset_index(drop=True)
        tea_fallback_col = "price_lag_1w_rs" if "price_lag_1w_rs" in tea_df_all.columns else None
        tea_global_median = float(tea_df_all[TEA_TARGET_COL].median())
        # typical values for local explanation
        # numeric features -> median; categorical features -> most frequent value ("" if none)
        tea_ref_values = {}
        for c in (tea_num_cols or []):
            if c in tea_df_all.columns and pd.api.types.is_numeric_dtype(tea_df_all[c]):
                tea_ref_values[c] = float(tea_df_all[c].median())
        for c in (tea_cat_cols or []):
            if c in tea_df_all.columns:
                mode = tea_df_all[c].dropna().mode()
                tea_ref_values[c] = str(mode.iloc[0]) if len(mode) else ""
        # "arima_pred" is appended to the model input at predict time, so it needs a reference too.
        tea_ref_values["arima_pred"] = tea_global_median
    except Exception as e:
        tea_data_error = f"Failed to load tea dataset: {e}"
        tea_df_all = None
def _tea_fit_arima_models():
    """Fit one ARIMA(TEA_ARIMA_ORDER) model per segment defined by TEA_GROUP_COLS.

    Results go into ``tea_arima_models``, keyed by a tuple of the group values
    (``(elevation, grade)`` with the default group columns). The function
    returns without fitting anything when the dataset is not loaded, no group
    columns are configured, or a group column is missing from the dataset.
    Segments with fewer than TEA_MIN_ARIMA_POINTS rows, or whose fit raises,
    are skipped as a best effort. tea_get_arima_pred then falls back for them.
    """
    import warnings

    if tea_df_all is None:
        return
    # An empty list would make ``groupby`` raise ("No group keys passed!"),
    # and this function runs at import time.
    if not TEA_GROUP_COLS:
        return
    if not all(c in tea_df_all.columns for c in TEA_GROUP_COLS):
        return
    tea_arima_models.clear()
    for key, g in tea_df_all.groupby(TEA_GROUP_COLS):
        # Some pandas versions yield scalar keys when grouping by a one-element
        # list; tuple("abc") would split such a key into characters.
        seg_key = key if isinstance(key, tuple) else (key,)
        y = g.sort_values(TEA_DATE_COL)[TEA_TARGET_COL].astype(float).values
        if len(y) < TEA_MIN_ARIMA_POINTS:
            continue
        try:
            with warnings.catch_warnings():
                # statsmodels may warn (convergence, frequency) for every
                # segment; keep startup logs readable.
                warnings.simplefilter("ignore")
                tea_arima_models[seg_key] = ARIMA(y, order=TEA_ARIMA_ORDER).fit()
        except Exception:
            # Best effort: an unfittable segment simply has no ARIMA model.
            continue
def tea_build_next_week_input(elevation: str, grade: str, overrides=None):
    """Build the one-row model input for the week after a segment's latest observation.

    The segment's most recent row is copied and its date advanced by 7 days.
    Calendar features (year, month, month_sin/month_cos) and price lag /
    rolling-mean features are recomputed when those columns exist. Caller
    overrides are then applied, and the target is blanked (unknown).

    Args:
        elevation: segment elevation value (compared to the "elevation" column).
        grade: segment grade value (compared to the "grade" column).
        overrides: optional mapping column -> value applied last.

    Returns:
        A one-row DataFrame with the dataset's columns. Dtypes are inferred per
        column, so numeric features are numeric rather than ``object``.

    Raises:
        ValueError: dataset not loaded, elevation/grade columns missing, or
            fewer than 10 history rows for the segment.
        KeyError: an override names a column that is not in the dataset.
    """
    overrides = overrides or {}
    if tea_df_all is None:
        raise ValueError(f"Tea dataset not loaded. {tea_data_error or ''}".strip())
    if "elevation" not in tea_df_all.columns or "grade" not in tea_df_all.columns:
        raise ValueError("Tea dataset missing elevation/grade columns.")
    seg = tea_df_all[(tea_df_all["elevation"] == elevation) & (tea_df_all["grade"] == grade)].sort_values(TEA_DATE_COL)
    if len(seg) < 10:
        raise ValueError("Not enough history for this (elevation, grade). Need >= 10 rows.")

    cols = tea_df_all.columns
    prices = seg[TEA_TARGET_COL]
    n_hist = len(seg)
    last = seg.iloc[-1]
    # Start from the latest row. Any feature not recomputed below (including a
    # lag/window longer than the available history) keeps the last row's value.
    next_row = last.copy()

    # next week (+7 days)
    next_date = pd.to_datetime(last[TEA_DATE_COL]) + pd.Timedelta(days=7)
    next_row[TEA_DATE_COL] = next_date

    # calendar fields if present
    if "year" in cols:
        next_row["year"] = int(next_date.year)
    if "month" in cols:
        next_row["month"] = int(next_date.month)
    if "month_sin" in cols and "month_cos" in cols:
        s, c = month_sin_cos(int(next_date.month))
        next_row["month_sin"] = s
        next_row["month_cos"] = c

    # Lag k for the next week is the k-th most recent observation; rolling
    # means average the last k observations (assuming one row per week).
    for col, k in (("price_lag_1w_rs", 1), ("price_lag_4w_rs", 4),
                   ("price_lag_12w_rs", 12), ("price_lag_48w_rs", 48)):
        if col in cols and n_hist >= k:
            next_row[col] = float(prices.iloc[-k])
    for col, k in (("price_rollmean_4w_rs", 4), ("price_rollmean_12w_rs", 12),
                   ("price_rollmean_48w_rs", 48)):
        if col in cols and n_hist >= k:
            next_row[col] = float(prices.tail(k).mean())

    # apply overrides (only existing columns are accepted)
    for k, v in overrides.items():
        if k not in next_row.index:
            raise KeyError(f"Unknown override column: {k}")
        next_row[k] = v

    # target unknown
    next_row[TEA_TARGET_COL] = np.nan
    # Transposing a mixed-type row Series yields all-object columns; restore
    # proper per-column dtypes before the frame reaches the model.
    return next_row.to_frame().T.infer_objects()
def tea_get_arima_pred(elevation: str, grade: str, built_row: pd.DataFrame):
    """Return the time-series (ARIMA) component for one segment's next week.

    The following are tried in order, returning the first finite value:
      1. the one-step forecast of the segment's fitted ARIMA model;
      2. ``tea_fallback_col`` from ``built_row`` (e.g. last week's price);
      3. the dataset-wide median price;
    and 0.0 if none of these is available. Non-finite values (NaN/inf) are
    never returned, because they would make the JSON response invalid.
    """
    key = (elevation, grade)
    fitted = tea_arima_models.get(key)
    if fitted is not None:
        try:
            # asarray + ravel handle both ndarray and Series forecasts positionally.
            forecast = float(np.asarray(fitted.forecast(steps=1)).ravel()[0])
        except Exception:
            forecast = None  # best effort: fall through to the fallbacks
        if forecast is not None and math.isfinite(forecast):
            return forecast
    if tea_fallback_col and tea_fallback_col in built_row.columns:
        raw = built_row[tea_fallback_col].iloc[0]
        if not pd.isna(raw):
            try:
                fallback = float(raw)
            except (TypeError, ValueError):
                fallback = None  # e.g. a non-numeric override value
            if fallback is not None and math.isfinite(fallback):
                return fallback
    if tea_global_median is not None and math.isfinite(tea_global_median):
        return float(tea_global_median)
    return 0.0
def tea_local_sensitivity_explain(model, X: pd.DataFrame, pred: float, ref_values: dict, top_k: int = 6):
    """Estimate each feature's local influence on a single tea price prediction.

    For every column of ``X`` that has an entry in ``ref_values``, the column
    is replaced by its typical value and the model is re-run. The impact is
    ``pred`` minus that re-prediction, so a positive impact means the actual
    value pushes the price up. Columns already equal to their typical value
    (string comparison, or both missing) are skipped, as are columns whose
    re-prediction raises.

    Returns:
        Up to ``top_k`` dicts ``{feature, value, typical, impact}``, ordered by
        absolute impact, largest first.
    """
    results = []
    for col in [c for c in X.columns if c in ref_values]:
        actual = X[col].iloc[0]
        typical = ref_values[col]
        try:
            unchanged = (pd.isna(actual) and pd.isna(typical)) or str(actual) == str(typical)
        except Exception:
            unchanged = False  # values that cannot be compared count as different
        if unchanged:
            continue
        probe = X.copy()
        probe[col] = typical
        try:
            pred_typical = float(model.predict(probe)[0])
        except Exception:
            continue  # best effort: skip features the model cannot re-score
        if pd.isna(actual):
            shown = None
        elif isinstance(actual, (int, float, np.number)):
            shown = float(actual)
        else:
            shown = str(actual)
        results.append({
            "feature": col,
            "value": shown,
            "typical": typical,
            "impact": float(pred - pred_typical),
        })
    results.sort(key=lambda item: abs(item["impact"]), reverse=True)
    return results[:top_k]
def tea_segment_context(elevation: str, grade: str):
    """Summarise recent price history for one (elevation, grade) segment.

    Returns ``None`` when the dataset is not loaded or the segment has no rows.
    Otherwise returns a dict with the last price, 4- and 12-row averages
    (``None`` when there is too little history), the last price's direction
    relative to the 4-row average, and the number of history rows.
    """
    if tea_df_all is None:
        return None
    in_segment = (tea_df_all["elevation"] == elevation) & (tea_df_all["grade"] == grade)
    seg = tea_df_all[in_segment].sort_values(TEA_DATE_COL)
    count = len(seg)
    if not count:
        return None
    prices = seg[TEA_TARGET_COL]
    last_price = float(prices.iloc[-1])
    avg_4w = float(prices.tail(4).mean()) if count >= 4 else None
    avg_12w = float(prices.tail(12).mean()) if count >= 12 else None
    if avg_4w is None:
        trend = None
    elif last_price > avg_4w:
        trend = "up"
    elif last_price < avg_4w:
        trend = "down"
    else:
        trend = "flat"
    return {
        "last_price": last_price,
        "avg_4w": avg_4w,
        "avg_12w": avg_12w,
        "trend_vs_4w_avg": trend,
        "history_points": int(count),
    }
def tea_describe_direction(val, typical):
    """Describe how ``val`` compares to ``typical`` in plain words.

    Values within 2% of the typical value count as "close to usual". When
    either value cannot be converted to a finite float, the answer is
    "different from usual".
    """
    try:
        v, t = float(val), float(typical)
    except Exception:
        return "different from usual"
    if not (np.isfinite(v) and np.isfinite(t)):
        return "different from usual"
    tolerance = 0.02 * (abs(t) + 1e-6)
    if abs(v - t) <= tolerance:
        return "close to usual"
    return "higher than usual" if v > t else "lower than usual"
def tea_feature_display_name(f):
    """Return a human-friendly label for feature column ``f``.

    Known columns map to curated labels. Any other name is shown with its
    underscores replaced by spaces.
    """
    spaced = f.replace("_", " ")
    labels = {
        "fx_lkr_per_usd": "USD→LKR exchange rate",
        "rainfall_mm": "rainfall",
        "temperature_c": "temperature",
        "arima_pred": "recent price trend (time-series)",
        "price_lag_1w_rs": "last week price",
        "price_rollmean_4w_rs": "last 4-week average price",
        "price_rollmean_12w_rs": "last 12-week average price",
    }
    return labels.get(f, spaced)
def tea_build_explanation_text(pred, factors, segment_ctx=None, top_k=5):
    """Compose a plain-language (markdown) explanation for a tea price forecast.

    Args:
        pred: predicted price.
        factors: impact dicts with keys feature/value/typical/impact (as built
            by tea_local_sensitivity_explain). Only the first ``top_k`` are used.
        segment_ctx: optional dict as returned by tea_segment_context.
        top_k: number of leading factors to consider.

    Returns:
        ``(summary, bullets)``: a summary sentence and a list of reason
        sentences. Factors with ``|impact| < 0.5`` produce no bullet.
    """
    considered = factors[:top_k]
    reasons = []
    for factor in considered:
        feature = factor["feature"]
        shown_value = factor["value"]
        usual_value = factor["typical"]
        effect = factor["impact"]
        if abs(effect) < 0.5:
            continue  # too small to be worth a sentence
        if feature == "arima_pred":
            # The time-series input is described through the segment trend, not a raw number.
            trend = segment_ctx.get("trend_vs_4w_avg") if segment_ctx else None
            if trend:
                verb = 'pushes up' if effect > 0 else 'pulls down'
                reasons.append(
                    f"Recent segment trend looks **{trend}**, which {verb} the prediction (time-series effect)."
                )
            else:
                reasons.append("Recent price pattern in this segment influences the forecast (time-series effect).")
            continue
        label = tea_feature_display_name(feature)
        direction = tea_describe_direction(shown_value, usual_value)
        outcome = "higher" if effect > 0 else "lower"
        reasons.append(
            f"{label} is **{direction}** ({shown_value} vs typical {usual_value}), so the model expects price to be **{outcome}**."
        )

    seg_line = None
    if segment_ctx:
        last_price = segment_ctx.get("last_price")
        avg_4w = segment_ctx.get("avg_4w")
        if last_price is not None and avg_4w is not None:
            seg_line = f"Last recorded price was **{last_price:.2f}** and the 4-week average is **{avg_4w:.2f}**."

    if reasons:
        # Net direction across all considered factors, including the small ones.
        net_effect = sum(factor["impact"] for factor in considered)
        push = "higher" if net_effect > 0 else "lower"
        summary = f"Predicted price is **{pred:.2f}** mainly because the strongest inputs/trend signals push the model **{push}** compared to typical conditions."
    else:
        summary = f"Predicted price is **{pred:.2f}** based on learned patterns from history for this segment and the provided inputs."
    if seg_line:
        summary = f"{summary} {seg_line}"
    return summary, reasons
# init tea
# Runs at import time. Load artifacts and the dataset, then fit the per-segment
# ARIMA models only when both the hybrid model and the dataset loaded. Both
# calls are synchronous, so module import blocks until all segment fits finish.
_tea_safe_load()
if tea_model is not None and tea_df_all is not None:
    _tea_fit_arima_models()
# ---------------------------------------------------------------------
# (B) YIELD MODEL (KEEP EXISTING)
# ---------------------------------------------------------------------
# Model/data paths and column names; each can be overridden by the env var of the same name.
YIELD_MODEL_PATH = os.getenv("YIELD_MODEL_PATH", os.path.join(MODEL_DIR, "smarttea_yield_model.joblib"))
YIELD_DATA_PATH = os.getenv("YIELD_DATA_PATH", os.path.join(BASE_DIR, "data/smarttea_monthly_yield_dataset_sri_lanka_synthetic_2000_2025.csv"))
YIELD_DATE_COL = os.getenv("YIELD_DATE_COL", "date")
YIELD_TARGET_COL = os.getenv("YIELD_TARGET_COL", "yield_kg_per_ha")
# Static attributes per supported region. build_yield_row rejects regions not
# listed here and copies these values into the feature row.
REGION_DEFAULTS = {
    "Nuwara_Eliya": {"elevation_band": "high", "elevation_m": 1850, "country": "Sri_Lanka"},
    "Uva": {"elevation_band": "mid", "elevation_m": 1200, "country": "Sri_Lanka"},
    "Kandy": {"elevation_band": "mid", "elevation_m": 900, "country": "Sri_Lanka"},
    "Sabaragamuwa": {"elevation_band": "low", "elevation_m": 300, "country": "Sri_Lanka"},
    "Galle": {"elevation_band": "low", "elevation_m": 50, "country": "Sri_Lanka"},
}
# Loaded state, filled in by the loading code that follows; *_error holds a message on failure.
yield_model = None
yield_feature_cols = None
yield_load_error = None
yield_df = None
yield_data_error = None
def unwrap_model(obj):
    """Split a loaded model artifact into ``(model, feature_cols, target)``.

    A dict bundle provides ``model`` (the dict itself when that key is
    missing), ``feature_cols`` and ``target``. A bare estimator is returned
    as-is, with feature columns taken from ``feature_names_in_`` when present
    and no target.
    """
    if not isinstance(obj, dict):
        names = getattr(obj, "feature_names_in_", None)
        return obj, (None if names is None else list(names)), None
    return obj.get("model", obj), obj.get("feature_cols"), obj.get("target")
# Load the yield model at import time. A missing file or a load failure is
# recorded in yield_load_error instead of raising.
try:
    if os.path.exists(YIELD_MODEL_PATH):
        yield_model_raw = joblib.load(YIELD_MODEL_PATH)
        yield_model, yield_feature_cols, _ = unwrap_model(yield_model_raw)
    else:
        yield_load_error = f"Yield model file not found at: {YIELD_MODEL_PATH}"
except Exception as e:
    yield_load_error = f"Failed to load YIELD model: {e}"
# Load the yield history used for lag/rolling features: rows with an
# unparseable date are dropped and the rest sorted by date.
try:
    if os.path.exists(YIELD_DATA_PATH):
        yield_df = pd.read_csv(YIELD_DATA_PATH)
        yield_df[YIELD_DATE_COL] = pd.to_datetime(yield_df[YIELD_DATE_COL], errors="coerce")
        yield_df = yield_df.dropna(subset=[YIELD_DATE_COL]).sort_values([YIELD_DATE_COL]).reset_index(drop=True)
    else:
        yield_data_error = f"Yield dataset not found at: {YIELD_DATA_PATH}"
except Exception as e:
    yield_data_error = f"Failed to load YIELD dataset: {e}"
# Payload keys that build_yield_row requires; it reports any that are missing.
YIELD_REQUIRED_INPUTS = [
    "region", "year", "month",
    "rainfall_mm", "temp_avg_c", "temp_min_c", "temp_max_c",
    "humidity_pct", "soil_ph", "soil_ec_ds_m",
    "fertilizer_kg_per_ha", "disease_index",
]
def get_region_history(region: str, current_date: pd.Timestamp) -> pd.DataFrame:
    """Return the region's history rows dated strictly before ``current_date``.

    The result is sorted oldest to newest. An empty DataFrame is returned when
    no dataset (or no "region" column) is available.
    """
    if yield_df is None or "region" not in yield_df.columns:
        return pd.DataFrame()
    before_date = (yield_df["region"] == region) & (yield_df[YIELD_DATE_COL] < current_date)
    history = yield_df.loc[before_date].copy()
    return history.sort_values(YIELD_DATE_COL)
def compute_yield_lags_rolls(region_hist: pd.DataFrame):
    """Derive lagged and rolling-mean yield features from a region's history.

    ``region_hist`` is expected oldest to newest. A lag reaching further back
    than the history falls back to the most recent value. Rolling means use
    whatever history exists, up to the window size. All values are ``None``
    when there is no usable history.
    """
    keys = ("yield_lag_1", "yield_lag_3", "yield_lag_12",
            "yield_rollmean_3", "yield_rollmean_6", "yield_rollmean_12")
    no_history = region_hist is None or len(region_hist) == 0
    if no_history or YIELD_TARGET_COL not in region_hist.columns:
        return dict.fromkeys(keys, None)

    series = region_hist[YIELD_TARGET_COL].astype(float).values
    count = len(series)
    newest = float(series[-1])

    def lagged(k):
        return float(series[-k]) if count >= k else newest

    def trailing_mean(k):
        return float(np.mean(series[-min(k, count):]))

    return {
        "yield_lag_1": lagged(1),
        "yield_lag_3": lagged(3),
        "yield_lag_12": lagged(12),
        "yield_rollmean_3": trailing_mean(3),
        "yield_rollmean_6": trailing_mean(6),
        "yield_rollmean_12": trailing_mean(12),
    }
def compute_exog_rolls(region_hist: pd.DataFrame):
    """Trailing means of weather/input columns over the last 3 and 6 history rows.

    When fewer rows than the window exist, all rows are averaged. A value is
    ``None`` when there is no history or the column is absent.
    """
    def trailing_mean(column, window):
        if region_hist is None or len(region_hist) == 0 or column not in region_hist.columns:
            return None
        values = region_hist[column].astype(float).values
        if len(values) == 0:
            return None
        recent = values[-window:] if len(values) >= window else values
        return float(np.mean(recent))

    layout = (
        ("rain_rollmean_3", "rainfall_mm", 3),
        ("rain_rollmean_6", "rainfall_mm", 6),
        ("temp_rollmean_3", "temp_avg_c", 3),
        ("temp_rollmean_6", "temp_avg_c", 6),
        ("fert_rollmean_3", "fertilizer_kg_per_ha", 3),
        ("fert_rollmean_6", "fertilizer_kg_per_ha", 6),
        ("disease_rollmean_3", "disease_index", 3),
        ("disease_rollmean_6", "disease_index", 6),
    )
    return {name: trailing_mean(column, window) for name, column, window in layout}
def local_feature_impact(model_pipeline, X_row: pd.DataFrame, numeric_features, steps=0.03, top_n=6):
    """Central-difference sensitivity of a one-row prediction to numeric features.

    Each present, non-missing feature is nudged by ``±max(|v| * steps, 0.01)``.
    The effect is half the resulting prediction difference, rounded to 3
    decimals.

    Returns:
        ``(base_prediction, impacts)``, where ``impacts`` lists up to ``top_n``
        dicts ``{feature, impact_kg_per_ha, direction}`` sorted by absolute impact.
    """
    base = float(model_pipeline.predict(X_row)[0])
    row_label = X_row.index[0]
    found = []
    for feat in (name for name in numeric_features if name in X_row.columns):
        current = X_row.iloc[0][feat]
        if pd.isna(current):
            continue
        x0 = float(current)
        h = max(abs(x0) * steps, 0.01)
        bumped_up, bumped_down = X_row.copy(), X_row.copy()
        bumped_up.loc[row_label, feat] = x0 + h
        bumped_down.loc[row_label, feat] = x0 - h
        effect = (float(model_pipeline.predict(bumped_up)[0]) - float(model_pipeline.predict(bumped_down)[0])) / 2.0
        found.append({
            "feature": feat,
            "impact_kg_per_ha": round(float(effect), 3),
            "direction": "increases" if effect > 0 else "decreases",
        })
    ranked = sorted(found, key=lambda item: abs(item["impact_kg_per_ha"]), reverse=True)
    return base, ranked[:top_n]
def build_yield_row(payload: dict):
    """Turn a request payload into the one-row feature frame for the yield model.

    Args:
        payload: mapping that must contain every key in YIELD_REQUIRED_INPUTS.

    Returns:
        ``(X, date_str, history_rows)``:
          - X: the feature DataFrame, reordered/padded (NaN) to
            ``yield_feature_cols`` when those are known;
          - date_str: the first day of the requested month as "YYYY-MM-DD";
          - history_rows: number of history rows used for lag/rolling features.

    Raises:
        ValueError: model not loaded, required fields missing, or unknown
            region. Conversion errors from int()/float()/pd.Timestamp propagate.
    """
    if yield_model is None:
        raise ValueError("Yield model is not loaded. Check YIELD_MODEL_PATH.")
    absent = [field for field in YIELD_REQUIRED_INPUTS if field not in payload]
    if absent:
        raise ValueError(f"Missing required fields: {absent}")

    region = str(payload["region"])
    year = int(payload["year"])
    month = int(payload["month"])
    current_date = pd.Timestamp(f"{year}-{month:02d}-01")

    region_info = REGION_DEFAULTS.get(region)
    if not region_info:
        raise ValueError(f"Unknown region '{region}'. Allowed: {list(REGION_DEFAULTS.keys())}")

    month_sin, month_cos = month_sin_cos(month)
    history = get_region_history(region, current_date)
    lag_features = compute_yield_lags_rolls(history)
    exog_features = compute_exog_rolls(history)

    # The insertion order of this dict is X's column order when
    # yield_feature_cols is unknown, so keep it stable.
    row = {
        "region": region,
        "country": region_info["country"],
        "elevation_band": region_info["elevation_band"],
        "elevation_m": region_info["elevation_m"],
        "year": year,
        "month": month,
        "month_sin": month_sin,
        "month_cos": month_cos,
    }
    for field in ("rainfall_mm", "temp_avg_c", "temp_min_c", "temp_max_c",
                  "humidity_pct", "soil_ph", "soil_ec_ds_m",
                  "fertilizer_kg_per_ha", "disease_index"):
        row[field] = float(payload[field])
    row.update(lag_features)
    row.update(exog_features)

    X = pd.DataFrame([row])
    if yield_feature_cols:
        for col in yield_feature_cols:
            if col not in X.columns:
                X[col] = np.nan
        X = X[yield_feature_cols]
    return X, str(current_date.date()), int(len(history))
# ---------------------------------------------------------------------
# (C) LEAF DISEASE MODEL (KEEP EXISTING)
# ---------------------------------------------------------------------
# Weights/labels locations and the temporary upload folder (all env-overridable).
LEAF_WEIGHTS_PATH = os.getenv("LEAF_WEIGHTS_PATH", os.path.join(MODEL_DIR, "tea_mobilenet_v2.weights.h5"))
LEAF_LABELS_PATH = os.getenv("LEAF_LABELS_PATH", os.path.join(MODEL_DIR, "labels.json"))
UPLOAD_DIR = os.getenv("UPLOAD_DIR", os.path.join(BASE_DIR, "uploads"))
# Image size used both for the model input shape and for resizing uploads.
IMG_SIZE: Tuple[int, int] = (224, 224)
# Created at import time so uploaded images can be saved here.
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Loaded state, filled in by the leaf loading code; leaf_load_error holds a message on failure.
leaf_class_names = None
leaf_model = None
leaf_load_error = None
def build_leaf_model(num_classes: int) -> tf.keras.Model:
    """Build the MobileNetV2-based leaf-disease classifier used for inference.

    Architecture: frozen MobileNetV2 backbone (no top) -> global average
    pooling -> dropout(0.2) -> dense softmax with ``num_classes`` outputs.
    The model is not compiled; trained weights are loaded separately with
    ``load_weights``.

    Args:
        num_classes: number of output classes (the loader passes len(labels)).

    Returns:
        A ``tf.keras.Model`` named "tea_mobilenet_v2_inference".
    """
    # Backbone requests ImageNet weights (Keras may download them if not cached).
    # NOTE(review): these look redundant once the trained weights file is loaded —
    # confirm the weights file covers the backbone before switching to weights=None.
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=IMG_SIZE + (3,),
        include_top=False,
        weights="imagenet",
    )
    base_model.trainable = False
    # Layer names are set explicitly; presumably they must match the names in the
    # saved weights file, so keep them (and the layer order) unchanged.
    inputs = tf.keras.Input(shape=IMG_SIZE + (3,), name="input_layer_1")
    # MobileNetV2 preprocessing is part of the graph, so callers pass raw pixel arrays.
    x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
    x = base_model(x, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D(name="global_avg_pool")(x)
    x = tf.keras.layers.Dropout(0.2, name="dropout")(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax", name="dense")(x)
    return tf.keras.Model(inputs, outputs, name="tea_mobilenet_v2_inference")
# Build the leaf classifier and load its trained weights at import time.
# Any failure is recorded in leaf_load_error and leaves model and labels unset.
try:
    if os.path.exists(LEAF_LABELS_PATH):
        with open(LEAF_LABELS_PATH, "r", encoding="utf-8") as f:
            # Assumed to be a JSON list whose index matches the model's output
            # index (predict_leaf_image indexes it that way) — confirm.
            leaf_class_names = json.load(f)
    else:
        raise FileNotFoundError(f"Leaf labels not found at: {LEAF_LABELS_PATH}")
    leaf_model = build_leaf_model(num_classes=len(leaf_class_names))
    if not os.path.exists(LEAF_WEIGHTS_PATH):
        raise FileNotFoundError(f"Leaf weights not found at: {LEAF_WEIGHTS_PATH}")
    leaf_model.load_weights(LEAF_WEIGHTS_PATH)
except Exception as e:
    leaf_load_error = f"Failed to load LEAF model: {e}"
    leaf_model = None
    leaf_class_names = None
def predict_leaf_image(image_path: str):
    """Classify the image stored at ``image_path`` with the loaded leaf model.

    Returns:
        ``(label, confidence, probabilities)``: the top class name, its
        probability, and a dict mapping every class name to its probability.

    Raises:
        RuntimeError: when the model or its labels are not loaded.
    """
    if leaf_model is None or leaf_class_names is None:
        raise RuntimeError(leaf_load_error or "Leaf model not loaded.")
    pixels = img_to_array(load_img(image_path, target_size=IMG_SIZE))
    # The model expects a batch dimension; preprocessing happens inside the model.
    scores = leaf_model.predict(np.expand_dims(pixels, axis=0), verbose=0)[0]
    best = int(np.argmax(scores))
    top_label = leaf_class_names[best]
    top_score = float(scores[best])
    as_floats = [float(s) for s in scores]
    by_label = {name: as_floats[i] for i, name in enumerate(leaf_class_names)}
    return top_label, top_score, by_label
# ---------------------------------------------------------------------
# ROUTES
# ---------------------------------------------------------------------
# NOTE(review): no ``@app.route`` registration is visible for this handler in
# this file; confirm the URL it is meant to be served at (e.g. "/" or "/health").
def health():
    """Report the load status, configured paths, load errors and endpoint map of all models."""
    paths = {
        "tea_artifact_dir": TEA_ARTIFACT_DIR,
        "tea_model_path": TEA_MODEL_PATH,
        "tea_cfg_path": TEA_CFG_PATH,
        "tea_data_path": TEA_DATA_PATH,
        "yield_model_path": YIELD_MODEL_PATH,
        "yield_data_path": YIELD_DATA_PATH,
        "leaf_weights_path": LEAF_WEIGHTS_PATH,
        "leaf_labels_path": LEAF_LABELS_PATH,
    }
    errors = {
        "tea_load_error": tea_load_error,
        "tea_data_error": tea_data_error,
        "yield_load_error": yield_load_error,
        "yield_data_error": yield_data_error,
        "leaf_load_error": leaf_load_error,
    }
    endpoints = {
        "GET /tea-price/meta": "Tea price metadata (elevations, grades, override keys)",
        "POST /tea-price/predict-next-week": "Tea price next-week forecast (elevation+grade + overrides + explain)",
        "POST /predict/yield-simple": "Yield prediction",
        "POST /predict/leaf": "Leaf disease prediction (image upload)",
    }
    arima_segments = int(len(tea_arima_models)) if tea_model is not None else 0
    return jsonify({
        "status": "ok",
        "tea_price_hybrid_loaded": tea_model is not None,
        "tea_price_segments_with_arima": arima_segments,
        "yield_model_loaded": yield_model is not None,
        "leaf_model_loaded": leaf_model is not None,
        "paths": paths,
        "errors": errors,
        "endpoints": endpoints,
    })
# -------------------------
# TEA PRICE: health/meta/predict-next-week
# -------------------------
# NOTE(review): no route registration is visible for this handler in this file;
# confirm the URL it is meant to be served at.
def tea_price_health():
    """Report tea-price subsystem status; always ``ok: True``, with load problems in ``error``."""
    history_rows = 0 if tea_df_all is None else int(len(tea_df_all))
    status = {
        "ok": True,
        "model_loaded": tea_model is not None,
        "cfg_loaded": tea_cfg is not None,
        "rows_in_history": history_rows,
        "segments_with_arima": int(len(tea_arima_models)),
        "target": TEA_TARGET_COL,
        "date_col": TEA_DATE_COL,
        "error": tea_load_error or tea_data_error,
    }
    return jsonify(status)
@app.route("/tea-price/meta", methods=["GET"])
def tea_price_meta():
    """Describe the tea-price model inputs for API clients.

    Returns the target/date column names, the categorical and numeric feature
    lists, the dataset columns other than the target (usable as ``overrides``
    keys), and the sorted unique elevations and grades. Responds 500 when the
    dataset is not loaded.
    """
    if tea_df_all is None:
        return jsonify({"ok": False, "error": tea_data_error or "Tea dataset not loaded"}), 500
    return jsonify({
        "ok": True,
        "target": TEA_TARGET_COL,
        "date_col": TEA_DATE_COL,
        "cat_cols": tea_cat_cols,
        "num_cols": tea_num_cols,
        "example_override_keys": [c for c in tea_df_all.columns if c not in [TEA_TARGET_COL]],
        "unique_elevations": sorted(tea_df_all["elevation"].dropna().unique().tolist()) if "elevation" in tea_df_all.columns else [],
        "unique_grades": sorted(tea_df_all["grade"].dropna().unique().tolist()) if "grade" in tea_df_all.columns else [],
    })
@app.route("/tea-price/predict-next-week", methods=["POST"])
def tea_price_predict_next_week():
    """Forecast next week's tea auction price for one (elevation, grade) segment.

    JSON body:
        elevation, grade: required strings identifying the segment.
        overrides: optional object mapping dataset column -> value.
        explain: optional flag; when truthy, adds an ``explanation`` payload.

    Responses:
        200 with the hybrid prediction, the ARIMA component and the forecast date;
        400 for invalid input (missing elevation/grade, non-object overrides,
            unknown override column, unknown segment or too little history);
        500 when the model/dataset is not loaded or prediction fails.
    """
    if tea_model is None:
        return jsonify({"ok": False, "error": tea_load_error or "Tea hybrid model not loaded"}), 500
    if tea_df_all is None:
        return jsonify({"ok": False, "error": tea_data_error or "Tea dataset not loaded"}), 500
    body = request.get_json(silent=True) or {}
    if not isinstance(body, dict):
        body = {}  # a JSON array/scalar carries no elevation/grade -> 400 below
    elevation = str(body.get("elevation", "")).strip()
    grade = str(body.get("grade", "")).strip()
    overrides = body.get("overrides") or {}
    if not elevation or not grade:
        return jsonify({"ok": False, "error": "elevation and grade are required"}), 400
    if not isinstance(overrides, dict):
        return jsonify({"ok": False, "error": "overrides must be an object/dict"}), 400

    try:
        row = tea_build_next_week_input(elevation, grade, overrides=overrides)
    except (KeyError, ValueError) as e:
        # Client-side problems (unknown override column, unknown segment or too
        # little history) are bad requests, not server errors.
        return jsonify({"ok": False, "error": str(e)}), 400

    try:
        arima_pred = tea_get_arima_pred(elevation, grade, row)
        # build X exactly like notebook expects: configured features + ARIMA output
        needed_cols = (tea_cat_cols or []) + (tea_num_cols or [])
        X = row.copy()
        # ensure required cols exist
        for c in needed_cols:
            if c not in X.columns:
                X[c] = np.nan
        X = X[needed_cols].copy()
        X["arima_pred"] = arima_pred
        pred = float(tea_model.predict(X)[0])
        want_explain = bool(body.get("explain", False))
        explain_payload = None
        if want_explain:
            factors = tea_local_sensitivity_explain(tea_model, X, pred, tea_ref_values, top_k=8)
            seg_ctx = tea_segment_context(elevation, grade)
            summary, bullets = tea_build_explanation_text(pred, factors, seg_ctx, top_k=5)
            explain_payload = {
                "summary": summary,
                "reasons": bullets,
                "top_factors": factors,
                "segment_context": seg_ctx,
                "disclaimer": "These reasons explain what the model learned from data (correlations), not guaranteed real-world causation."
            }
        return jsonify({
            "ok": True,
            "elevation": elevation,
            "grade": grade,
            "predicted_price": pred,
            "arima_pred": arima_pred,
            "next_date": str(pd.to_datetime(row[TEA_DATE_COL].iloc[0]).date()),
            "explanation": explain_payload
        })
    except KeyError as e:
        return jsonify({"ok": False, "error": str(e)}), 400
    except Exception as e:
        return jsonify({"ok": False, "error": str(e), "trace": traceback.format_exc()}), 500
# -------------------------
# YIELD
# -------------------------
# NOTE(review): no route registration is visible for this handler in this file.
# It also returns tracebacks, so confirm it is not meant to be publicly exposed.
def debug_yield_model():
    """Re-load the yield model file from disk and describe what it contains.

    Returns the path, the Python type of the loaded object and, for dict
    bundles, its keys. Any failure yields a 500 with the error and traceback.
    """
    try:
        loaded = joblib.load(YIELD_MODEL_PATH)
        bundle_keys = list(loaded.keys()) if isinstance(loaded, dict) else None
        report = {
            "ok": True,
            "path": YIELD_MODEL_PATH,
            "type": str(type(loaded)),
            "keys": bundle_keys,
        }
        return jsonify(report)
    except Exception as e:
        failure = {"ok": False, "error": str(e), "trace": traceback.format_exc()}
        return jsonify(failure), 500
def _yield_explain_sentence(top_impacts):
    """Summarise signed feature impacts as one sentence.

    Names up to two positive and two negative contributors, in the given order
    (local_feature_impact sorts them by absolute impact).
    """
    pos = [i for i in top_impacts if i["impact_kg_per_ha"] > 0][:2]
    neg = [i for i in top_impacts if i["impact_kg_per_ha"] < 0][:2]
    parts = []
    if pos:
        parts.append("higher " + " & ".join(p["feature"] for p in pos))
    if neg:
        parts.append("lower " + " & ".join(n["feature"] for n in neg))
    return "Prediction is mainly influenced by " + (", and ".join(parts) if parts else "the input factors.")


def _estimate_labourers(total_harvest_kg, productivity, plucking_days, efficiency):
    """Return how many pluckers are needed to harvest ``total_harvest_kg``.

    One worker harvests ``productivity * plucking_days * efficiency`` kg, with
    efficiency floored at 0.01. The count is rounded up and never negative.

    Raises:
        ValueError: if productivity or plucking_days is not positive.
    """
    if productivity <= 0 or plucking_days <= 0:
        raise ValueError("productivity_kg_per_worker_day and plucking_days must be positive")
    capacity_per_worker = productivity * plucking_days * max(efficiency, 0.01)
    return max(0, int(math.ceil(total_harvest_kg / capacity_per_worker)))


@app.route("/predict/yield-simple", methods=["POST"])
def predict_yield():
    """Predict monthly tea yield (kg/ha) and estimate the labour needed.

    JSON body: the YIELD_REQUIRED_INPUTS fields, plus optional ``area_ha``
    (default 1.0, must be >= 0), ``plucking_days`` (22),
    ``productivity_kg_per_worker_day`` (20.0) and ``efficiency`` (0.9).

    Responses: 200 with the prediction, explainability and meta; 500 when the
    model is not loaded; 400 for any other failure (including invalid input).
    """
    try:
        if yield_model is None:
            return jsonify({
                "success": False,
                "error": "Yield model not loaded",
                "details": yield_load_error,
                "hint": "Put your yield .joblib file in the model folder and set YIELD_MODEL_PATH if needed."
            }), 500
        payload = request.get_json(silent=True) or {}
        if not isinstance(payload, dict):
            raise ValueError("Request body must be a JSON object")
        X, pred_date, history_months = build_yield_row(payload)
        pred = float(yield_model.predict(X)[0])
        numeric_for_explain = [
            "rainfall_mm", "temp_avg_c", "humidity_pct",
            "soil_ph", "soil_ec_ds_m", "fertilizer_kg_per_ha", "disease_index",
            "yield_lag_1", "yield_lag_3", "yield_lag_12",
            "rain_rollmean_3", "rain_rollmean_6",
            "temp_rollmean_3", "temp_rollmean_6",
            "fert_rollmean_3", "fert_rollmean_6",
            "disease_rollmean_3", "disease_rollmean_6",
        ]
        _, top_impacts = local_feature_impact(yield_model, X, numeric_for_explain)
        explain_sentence = _yield_explain_sentence(top_impacts)

        # labour estimation (optional inputs with defaults)
        area_ha = float(payload.get("area_ha", 1.0))
        plucking_days = int(payload.get("plucking_days", 22))
        productivity = float(payload.get("productivity_kg_per_worker_day", 20.0))
        efficiency = float(payload.get("efficiency", 0.9))
        if area_ha < 0:
            raise ValueError("area_ha must be non-negative")
        total_harvest_kg = pred * area_ha
        labourers_needed = _estimate_labourers(total_harvest_kg, productivity, plucking_days, efficiency)

        warning_msgs = []
        if history_months < 12 and yield_df is not None:
            warning_msgs.append(
                f"Only {history_months} months of history were available before {pred_date}; some lag/rolling features may be weak."
            )
        return jsonify({
            "success": True,
            "prediction": {
                "yield_kg_per_ha": round(pred, 2),
                "for_month": pred_date,
                "area_ha": area_ha,
                "total_harvest_kg": round(total_harvest_kg, 2),
                "labourers_needed": labourers_needed,
                "assumptions": {
                    "plucking_days": plucking_days,
                    "productivity_kg_per_worker_day": productivity,
                    "efficiency": efficiency
                }
            },
            "explainability": {
                "summary": explain_sentence,
                "top_factors": top_impacts
            },
            "meta": {
                "region": payload.get("region"),
                "history_months_used": history_months,
                "warnings": warning_msgs
            }
        })
    except Exception as e:
        return jsonify({"success": False, "error": str(e), "trace": traceback.format_exc()}), 400
# -------------------------
# LEAF
# -------------------------
@app.route("/predict/leaf", methods=["POST"])
def predict_leaf():
    """Classify an uploaded tea-leaf image (multipart form field ``image``).

    The upload is saved to UPLOAD_DIR under a random name. Only the validated
    extension is taken from the client's filename, so the client cannot
    influence the path. The file is classified and then always deleted.

    Responses: 200 with label, confidence and per-class probabilities; 400 for
    a missing file part, empty filename or unsupported extension; 500 when the
    model is not loaded or saving/processing fails.
    """
    if leaf_model is None or leaf_class_names is None:
        return jsonify({
            "ok": False,
            "error": "Leaf model not loaded",
            "details": leaf_load_error,
            "hint": "Make sure model/labels.json and model/tea_mobilenet_v2.weights.h5 exist."
        }), 500
    if "image" not in request.files:
        return jsonify({"ok": False, "error": "No file part 'image' in the request"}), 400
    file = request.files["image"]
    filename = file.filename or ""  # may be None for some multipart clients
    if filename == "":
        return jsonify({"ok": False, "error": "No file selected"}), 400
    allowed_ext = (".jpg", ".jpeg", ".png")
    lower_name = filename.lower()
    ext = next((e for e in allowed_ext if lower_name.endswith(e)), None)
    if ext is None:
        return jsonify({"ok": False, "error": "Unsupported file type. Use JPG or PNG."}), 400
    # Never embed the client filename in the path (it may contain separators or "..").
    temp_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}{ext}")
    try:
        file.save(temp_path)
        label, confidence, probs_dict = predict_leaf_image(temp_path)
        return jsonify({
            "ok": True,
            "prediction": label,
            "confidence": confidence,
            "probabilities": probs_dict
        })
    except Exception as e:
        return jsonify({
            "ok": False,
            "error": "Failed to process image",
            "details": str(e),
            "trace": traceback.format_exc()
        }), 500
    finally:
        # Best-effort cleanup; a failed delete must not replace the response.
        try:
            if os.path.exists(temp_path):
                os.remove(temp_path)
        except OSError:
            pass
# ---------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces; the port comes from $PORT (default 7860).
    port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=port)