Spaces:
Paused
Paused
| """ | |
| DemandScape — Demand Forecasting module | |
| All logic extracted from the original suite/app.py so it can be imported | |
| by the unified DemandScape Suite app. | |
| """ | |
import json
import math
import os
import tempfile
from datetime import datetime
from typing import Any, Dict

import gradio as gr
import pandas as pd
import requests
# ---------------------------------------------------------------------------
# Sample file URLs
# ---------------------------------------------------------------------------
# Both sample files sit in the same `samples/` folder of the demandscape-api
# repository; the links are shown in the upload help text of the UI.
_SAMPLES_BASE_URL = "https://github.com/diskover-diagnostics/demandscape-api/blob/main/samples"
SAMPLE_CSV_URL = f"{_SAMPLES_BASE_URL}/infer_sample_1.csv"
SAMPLE_SCENARIOS_URL = f"{_SAMPLES_BASE_URL}/scenarios.json"

# Value of the top-level "product" field of every inference payload.
PRODUCT = "demandscape"
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def load_scenarios_from_file(scenarios_file_path: str) -> Dict[str, Any]:
    """Load scenarios from a JSON file, falling back to built-in defaults.

    The file is expected to hold a JSON object mapping scenario names to dicts
    of ``planned_*`` column values (see the defaults below for the shape).

    Args:
        scenarios_file_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The parsed mapping, or a freshly built copy of the default scenarios
        when the file cannot be opened/decoded/parsed or its top-level JSON
        value is not an object. A warning is printed in the fallback case.
    """
    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(scenarios_file_path, "r", encoding="utf-8") as f:
            scenarios = json.load(f)
    except (OSError, TypeError, ValueError) as e:
        # OSError: unreadable path; TypeError: not a path at all;
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load scenarios file. Using defaults. Error: {e}")
    else:
        if isinstance(scenarios, dict):
            return scenarios
        # A list/number/string would otherwise be sent as the scenarios map.
        print(
            "Warning: Could not load scenarios file. Using defaults. "
            f"Error: expected a JSON object, got {type(scenarios).__name__}"
        )
    # Built fresh on every call so callers may mutate the result safely.
    return {
        "base": {},
        "price_up_2pct": {"planned_price_index": 1.02},
        "price_up_5pct": {"planned_price_index": 1.05},
        "price_up_10pct": {"planned_price_index": 1.1},
        "price_down_2pct": {"planned_price_index": 0.98},
        "price_down_5pct": {"planned_price_index": 0.95},
        "price_down_10pct": {"planned_price_index": 0.9},
        "discount_5pct": {"planned_discount_pct": 5},
        "discount_10pct": {"planned_discount_pct": 10},
        "discount_20pct": {"planned_discount_pct": 20},
        "promo_on": {"planned_promo_flag": 1.0},
        "promo_off": {"planned_promo_flag": 0.0},
        "tender_on": {"planned_tender_flag": 1.0},
        "regulatory_event": {"planned_regulatory_event_flag": 1.0},
        "supply_risk_high": {"planned_supply_risk": 1.0},
        "competitor_pressure_high": {"planned_competitor_pressure": 1.0},
        "growth_push": {
            "planned_price_index": 0.95,
            "planned_discount_pct": 10,
            "planned_promo_flag": 1.0,
        },
        "margin_push": {
            "planned_price_index": 1.05,
            "planned_discount_pct": 0.0,
            "planned_promo_flag": 0.0,
        },
        "promo_plus_discount": {
            "planned_discount_pct": 15,
            "planned_promo_flag": 1.0,
        },
        "tender_plus_supply_risk": {
            "planned_tender_flag": 1.0,
            "planned_supply_risk": 1.0,
        },
        "worst_case": {
            "planned_price_index": 1.1,
            "planned_supply_risk": 1.0,
            "planned_competitor_pressure": 1.0,
        },
    }
def csv_to_inference_json(csv_file_path: str, scenarios: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a demand-forecast CSV file to the inference request payload.

    Args:
        csv_file_path: Path to the CSV file. It must contain every column listed
            in ``data_columns`` below; ``month`` may be in any format that
            ``pd.to_datetime(format="mixed")`` understands.
        scenarios: Scenario mapping, passed through unchanged as
            ``inputs.parameters.scenarios``.

    Returns:
        ``{"product": PRODUCT, "inputs": {"data": ..., "parameters": ...}}``
        where ``data`` maps each column name to a list of its values. Missing
        values, unparseable months and +/-inf become ``None`` (JSON ``null``).

    Raises:
        ValueError: If one or more required columns are absent from the CSV.
    """
    # Required input columns, in the order they appear under inputs.data.
    data_columns = (
        "month",
        "product_id",
        "market_id",
        "units_sold",
        "planned_price_index",
        "planned_discount_pct",
        "planned_promo_flag",
        "planned_tender_flag",
        "planned_supply_risk",
        "planned_competitor_pressure",
        "planned_regulatory_event_flag",
    )

    df = pd.read_csv(csv_file_path)
    missing = [name for name in data_columns if name not in df.columns]
    if missing:
        # Report all gaps at once instead of a bare KeyError for the first one.
        raise ValueError("CSV is missing required column(s): " + ", ".join(missing))

    # Normalise to "YYYY-MM"; unparseable dates become NaT (missing after strftime).
    df["month"] = pd.to_datetime(df["month"], format="mixed", errors="coerce").dt.strftime("%Y-%m")

    def json_safe(value: Any) -> Any:
        """Map NaN/NaT/None and +/-inf to None; pass every other value through."""
        # Strict JSON has no NaN/Infinity literals, so these must become null.
        if isinstance(value, float) and math.isinf(value):
            return None
        return None if pd.isna(value) else value

    return {
        "product": PRODUCT,
        "inputs": {
            # Every column (identifiers included) is sanitised, so a numeric id
            # column with blanks cannot leak NaN into the request body.
            "data": {name: [json_safe(v) for v in df[name]] for name in data_columns},
            "parameters": {
                "encoder_length": 12,
                "prediction_length": 6,
                "batch_size": 256,
                "n_samples": 1000,
                "quantiles": [0.1, 0.5, 0.9],
                "scenarios": scenarios,
                "round_outputs": True,
            },
        },
    }
def send_inference_request(payload: Dict[str, Any], endpoint: str, api_key: str) -> Dict[str, Any]:
    """POST *payload* as JSON to the fly.dev orchestrator and return the decoded reply.

    Args:
        payload: Request body; ``requests`` serialises it to JSON.
        endpoint: Full URL of the inference endpoint.
        api_key: Token sent as ``Authorization: Bearer <api_key>``.

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.HTTPError: When the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or the 120 s timeout.
    """
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    reply = requests.post(
        endpoint,
        headers=auth_headers,
        json=payload,
        timeout=120,
    )
    # Surface non-2xx replies as exceptions before attempting to decode them.
    reply.raise_for_status()
    return reply.json()
def parse_forecast_response(response_json: Dict[str, Any]) -> pd.DataFrame:
    """Turn the ``forecasts`` records of an inference response into a DataFrame.

    Only the known forecast columns are kept, in a fixed display order: known
    columns absent from the response are skipped and unknown ones are dropped.
    A missing or empty ``forecasts`` list yields an empty DataFrame.
    """
    records = response_json.get("forecasts", [])
    if not records:
        return pd.DataFrame()

    display_order = [
        # identifiers
        "scenario", "product_id", "market_id", "month", "horizon_step",
        # point forecast and P10/P50/P90 interval columns
        "point_mean", "p10", "p50", "p90",
        # confidence annotations
        "confidence_label", "confidence_score", "requires_review",
        # scenario parameter columns
        "planned_price_index", "planned_discount_pct", "planned_promo_flag",
        "planned_tender_flag", "planned_regulatory_event_flag",
        "planned_supply_risk", "planned_competitor_pressure",
    ]
    frame = pd.DataFrame(records)
    # Must stay a list: a tuple key would be read as a single MultiIndex label.
    present = [name for name in display_order if name in frame.columns]
    return frame[present]
def _format_error_response(response_json: Dict[str, Any]) -> str:
    """Convert an error response dict to a user-visible Markdown string.

    Keys read from *response_json*:
        status: Status code as int or numeric string; 401 and 429 get
            dedicated headings, anything else uses a generic one.
        detail / error: Reason text (``detail`` wins when both are present).
        usage: Optional dict of rate-limit counters, appended as a list.
    """
    status = response_json.get("status", "")
    # Normalise so "429" and 429 select the same heading.
    try:
        status_code = int(status)
    except (TypeError, ValueError):
        status_code = None

    headings = {
        401: ("🔒", "**Unauthorized**"),
        429: ("🚦", "**Too Many Requests**"),
    }
    icon, heading = headings.get(status_code, ("❌", "**Error**"))

    detail = response_json.get("detail", response_json.get("error", "Unknown error"))
    usage = response_json.get("usage", {})
    usage_lines = ""
    if usage:
        usage_lines = (
            f"\n\n**Your usage:**\n"
            f"- Requests last minute: **{usage.get('requests_last_minute', '–')}** "
            f"/ {usage.get('rate_limit_per_minute', '–')}\n"
            f"- Requests today: **{usage.get('requests_today', '–')}** "
            f"/ {usage.get('daily_quota', '–')}"
        )
    return f"{icon} {heading}: {detail}{usage_lines}"
# ---------------------------------------------------------------------------
# Gradio processing function
# ---------------------------------------------------------------------------
def _upload_path(uploaded: Any) -> str:
    """Return the filesystem path of an upload (object with ``.name`` or a plain path)."""
    return getattr(uploaded, "name", uploaded)


def _http_error_payload(err: Exception) -> Dict[str, Any]:
    """Return the JSON error body of a failed HTTP response, or ``{}`` if there is none.

    ``status`` defaults to the HTTP status code and ``error`` to the exception
    text, so the result can be passed straight to ``_format_error_response``.
    """
    response = getattr(err, "response", None)
    # Compare with None: a requests Response is falsy for 4xx/5xx statuses.
    if response is None:
        return {}
    try:
        body = response.json()
    except ValueError:  # body is not JSON
        return {}
    if not isinstance(body, dict):
        return {}
    body.setdefault("status", response.status_code)
    body.setdefault("error", str(err))
    return body


def _usage_section(meta: Dict[str, Any]) -> str:
    """Format the optional ``metadata.usage`` counters of a successful reply."""
    usage = meta.get("usage") or {}
    if not usage:
        return ""
    partner = usage.get("partner", meta.get("partner", ""))
    return (
        f"\n\n**API Usage** ({partner}):\n"
        f"- Requests last minute: **{usage.get('requests_last_minute', '–')}** "
        f"/ {usage.get('rate_limit_per_minute', '–')}\n"
        f"- Requests today: **{usage.get('requests_today', '–')}** "
        f"/ {usage.get('daily_quota', '–')}"
    )


def _build_summary(df: pd.DataFrame, usage_section: str) -> str:
    """Build the Markdown run summary; stats for absent columns are shown as "–"."""
    absent = "–"

    def distinct(column: str) -> Any:
        return df[column].nunique() if column in df.columns else absent

    def label_count(label: str) -> Any:
        if "confidence_label" not in df.columns:
            return absent
        return (df["confidence_label"] == label).sum()

    if "month" in df.columns:
        date_range = f"**{df['month'].min()}** to **{df['month'].max()}**"
    else:
        date_range = f"**{absent}**"

    if "scenario" in df.columns:
        shown = ", ".join(str(name) for name in df["scenario"].unique()[:5])
        scenario_list = shown + ("..." if df["scenario"].nunique() > 5 else "")
    else:
        scenario_list = absent

    # Blank lines keep each bold "heading" from merging into the block above it.
    lines = [
        "✅ **Forecast Generation Successful!**",
        "",
        "**Summary:**",
        f"- 📊 Total forecasts: **{len(df)}**",
        f"- 🎯 Unique scenarios: **{distinct('scenario')}**",
        f"- 📦 Products: **{distinct('product_id')}**",
        f"- 🌍 Markets: **{distinct('market_id')}**",
        f"- 📅 Date range: {date_range}",
        "",
        "**Confidence Distribution:**",
        f"- 🟢 HIGH: **{label_count('HIGH')}** forecasts",
        f"- 🟡 MEDIUM: **{label_count('MEDIUM')}** forecasts",
        f"- 🔴 LOW: **{label_count('LOW')}** forecasts",
        "",
        f"**Scenarios Processed:** {scenario_list}{usage_section}",
    ]
    return "\n".join(lines)


def process_forecast(csv_file, scenarios_file, api_key: str, endpoint: str):
    """Main processing function called by the Gradio submit button.

    Args:
        csv_file: Uploaded historical-data CSV (object with ``.name`` or a path).
        scenarios_file: Uploaded scenarios JSON (object with ``.name`` or a path).
        api_key: DemandScape Suite API key; surrounding whitespace is ignored.
        endpoint: URL of the inference endpoint.

    Returns:
        ``(forecasts_df, summary_markdown, csv_path)``. On validation, HTTP or
        processing errors the first and last items are ``None`` and the middle
        item is a Markdown error message.
    """
    try:
        if not api_key or not api_key.strip():
            return (
                None,
                "🔑 **API Key required**: Please enter your DemandScape Suite API Key. "
                "To request one, email **info@diskoverdiagnostics.com**.",
                None,
            )
        if csv_file is None:
            return None, "❌ **Error**: Please upload a CSV file.", None
        if scenarios_file is None:
            return None, "❌ **Error**: Please upload a scenarios JSON file.", None

        api_key = api_key.strip()
        scenarios = load_scenarios_from_file(_upload_path(scenarios_file))
        payload = csv_to_inference_json(_upload_path(csv_file), scenarios)
        response_json = send_inference_request(payload, endpoint, api_key)
        if "error" in response_json:
            return None, _format_error_response(response_json), None

        df_forecasts = parse_forecast_response(response_json)
        if df_forecasts.empty:
            return None, "⚠️ **Warning**: No forecasts returned from the model.", None

        meta = response_json.get("metadata") or {}  # tolerate "metadata": null
        summary_text = _build_summary(df_forecasts, _usage_section(meta))

        # A fresh temp dir per run: concurrent sessions can't overwrite each
        # other's export and files don't pile up in the working directory.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_csv_path = os.path.join(
            tempfile.mkdtemp(prefix="demandscape_"),
            f"demandscape_forecasts_{stamp}.csv",
        )
        df_forecasts.to_csv(output_csv_path, index=False)
        return df_forecasts, summary_text, output_csv_path
    except requests.HTTPError as e:
        # raise_for_status() fires before the body is read; recover the
        # structured 401/429 error (with usage counters) when there is one.
        error_body = _http_error_payload(e)
        if error_body:
            return None, _format_error_response(error_body), None
        return None, f"❌ **Error**: {str(e)}", None
    except Exception as e:
        return None, f"❌ **Error**: {str(e)}", None
# ---------------------------------------------------------------------------
# Gradio Tab builder
# ---------------------------------------------------------------------------
def build_tab(api_key_input: gr.Textbox, flyio_endpoint: str) -> None:
    """
    Build the DemandScape UI inside the current gr.Tab / gr.Blocks context.

    Args:
        api_key_input: The shared Textbox component from the parent app; its
            value is passed to ``process_forecast`` on every submit.
        flyio_endpoint: Inference endpoint URL bound into the click handler.
    """
    # Built line by line so the Markdown does not depend on source indentation.
    # Blank lines separate blocks: without them a bold "heading" line is
    # absorbed into the preceding list item or paragraph.
    input_help = "\n".join([
        "**Sample files:**",
        f"- 📄 [infer_sample_1.csv]({SAMPLE_CSV_URL})",
        f"- 📋 [scenarios.json]({SAMPLE_SCENARIOS_URL})",
        "",
        "**CSV Format:**",
        "```",
        "month, product_id, market_id, units_sold,",
        "planned_price_index, planned_discount_pct,",
        "planned_promo_flag, planned_tender_flag,",
        "planned_supply_risk, planned_competitor_pressure,",
        "planned_regulatory_event_flag",
        "```",
        "",
        "⚠️ Last 6 rows: leave `units_sold` **empty** (forecast horizon)",
        "",
        "**JSON Format:**",
        "```json",
        "{",
        '  "base": {},',
        '  "price_up_10pct": {"planned_price_index": 1.1}',
        "}",
        "```",
    ])
    output_help = "\n".join([
        "---",
        "**Output Columns:** Point forecasts (mean) and prediction intervals (P10, P50, P90) ·",
        "Confidence scores (HIGH / MEDIUM / LOW) · Review flags · All scenario parameters",
    ])

    gr.Markdown("### Upload your data to generate demand forecasts for multiple scenarios")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Input Files")
            csv_input = gr.File(
                label="1️⃣ Upload CSV File (Historical Data)",
                file_types=[".csv"],
            )
            scenarios_input = gr.File(
                label="2️⃣ Upload Scenarios JSON File",
                file_types=[".json"],
            )
            gr.Markdown(input_help)
            submit_btn = gr.Button("🚀 Generate Forecasts", variant="primary", size="lg")
        with gr.Column(scale=2):
            gr.Markdown("### 📈 Results")
            summary_output = gr.Markdown()
            with gr.Row():
                forecast_table = gr.Dataframe(
                    label="Forecast Results",
                    interactive=False,
                    wrap=True,
                )
            with gr.Row():
                download_btn = gr.File(label="📥 Download Forecasts CSV")
    gr.Markdown(output_help)

    def _run_forecast(csv_f, scen_f, key):
        """Forward the three UI inputs plus the bound endpoint to process_forecast."""
        return process_forecast(csv_f, scen_f, key, flyio_endpoint)

    # Output order must match process_forecast's (table, summary, file) tuple.
    submit_btn.click(
        fn=_run_forecast,
        inputs=[csv_input, scenarios_input, api_key_input],
        outputs=[forecast_table, summary_output, download_btn],
    )