space-DemandScape / demandscape.py
SYS2011's picture
Upload 5 files
89f848a verified
"""
DemandScape — Demand Forecasting module
All logic extracted from the original suite/app.py so it can be imported
by the unified DemandScape Suite app.
"""
import json
import os
from datetime import datetime
from typing import Any, Dict
import gradio as gr
import pandas as pd
import requests
# ---------------------------------------------------------------------------
# Sample file URLs (linked from the input-help Markdown in build_tab) and the
# product tag written into every inference payload.
# ---------------------------------------------------------------------------
_SAMPLES_BASE_URL = "https://github.com/diskover-diagnostics/demandscape-api/blob/main/samples"
SAMPLE_CSV_URL = f"{_SAMPLES_BASE_URL}/infer_sample_1.csv"
SAMPLE_SCENARIOS_URL = f"{_SAMPLES_BASE_URL}/scenarios.json"
PRODUCT = "demandscape"
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def _default_scenarios() -> Dict[str, Any]:
    """Return a fresh copy of the built-in scenario set (safe for callers to mutate)."""
    return {
        "base": {},
        "price_up_2pct": {"planned_price_index": 1.02},
        "price_up_5pct": {"planned_price_index": 1.05},
        "price_up_10pct": {"planned_price_index": 1.1},
        "price_down_2pct": {"planned_price_index": 0.98},
        "price_down_5pct": {"planned_price_index": 0.95},
        "price_down_10pct": {"planned_price_index": 0.9},
        "discount_5pct": {"planned_discount_pct": 5},
        "discount_10pct": {"planned_discount_pct": 10},
        "discount_20pct": {"planned_discount_pct": 20},
        "promo_on": {"planned_promo_flag": 1.0},
        "promo_off": {"planned_promo_flag": 0.0},
        "tender_on": {"planned_tender_flag": 1.0},
        "regulatory_event": {"planned_regulatory_event_flag": 1.0},
        "supply_risk_high": {"planned_supply_risk": 1.0},
        "competitor_pressure_high": {"planned_competitor_pressure": 1.0},
        "growth_push": {
            "planned_price_index": 0.95,
            "planned_discount_pct": 10,
            "planned_promo_flag": 1.0,
        },
        "margin_push": {
            "planned_price_index": 1.05,
            "planned_discount_pct": 0.0,
            "planned_promo_flag": 0.0,
        },
        "promo_plus_discount": {
            "planned_discount_pct": 15,
            "planned_promo_flag": 1.0,
        },
        "tender_plus_supply_risk": {
            "planned_tender_flag": 1.0,
            "planned_supply_risk": 1.0,
        },
        "worst_case": {
            "planned_price_index": 1.1,
            "planned_supply_risk": 1.0,
            "planned_competitor_pressure": 1.0,
        },
    }


def load_scenarios_from_file(scenarios_file_path: str) -> Dict[str, Any]:
    """Load scenarios from a JSON file, falling back to built-in defaults.

    The file is expected to hold a JSON object mapping scenario names to
    dicts of ``planned_*`` overrides, e.g.
    ``{"base": {}, "promo_on": {"planned_promo_flag": 1.0}}``.

    Args:
        scenarios_file_path: Path of the scenarios JSON file.

    Returns:
        The parsed scenarios mapping. If the file cannot be opened, decoded or
        parsed, or its top-level value is not a JSON object, a warning is
        printed and a fresh copy of the built-in default scenarios is returned.
    """
    try:
        # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale.
        with open(scenarios_file_path, "r", encoding="utf-8") as f:
            scenarios = json.load(f)
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers JSONDecodeError and UnicodeDecodeError;
        # TypeError covers a missing/None path. Loading is best-effort.
        print(f"Warning: Could not load scenarios file. Using defaults. Error: {e}")
        return _default_scenarios()
    if not isinstance(scenarios, dict):
        # A JSON list/number/string is not a scenario mapping; treat it as invalid.
        print(
            "Warning: Could not load scenarios file. Using defaults. "
            f"Error: expected a JSON object, got {type(scenarios).__name__}"
        )
        return _default_scenarios()
    return scenarios
# Columns the inference payload requires, in the order they are sent.
_CSV_COLUMNS = (
    "month",
    "product_id",
    "market_id",
    "units_sold",
    "planned_price_index",
    "planned_discount_pct",
    "planned_promo_flag",
    "planned_tender_flag",
    "planned_supply_risk",
    "planned_competitor_pressure",
    "planned_regulatory_event_flag",
)


def _to_json_safe_list(series: pd.Series) -> list:
    """Return *series* as a plain list with NaN/NaT/None and +/-inf mapped to None.

    Strict JSON has no NaN or Infinity literals, so these must become null.
    """
    result = []
    for val in series:
        if val is None or pd.isna(val):
            result.append(None)
        elif isinstance(val, (float, int)) and (val == float("inf") or val == float("-inf")):
            result.append(None)
        else:
            result.append(val)
    return result


def csv_to_inference_json(csv_file_path: str, scenarios: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a demand-forecast CSV file to the inference request payload.

    Args:
        csv_file_path: Path of a CSV containing every column in ``_CSV_COLUMNS``.
            Whitespace around header names (e.g. ``month, product_id``) is ignored.
        scenarios: Scenario mapping forwarded verbatim as ``parameters.scenarios``.

    Returns:
        A JSON-serialisable payload dict. ``month`` is normalised to ``YYYY-MM``
        (unparseable dates become None). Every column is passed through
        ``_to_json_safe_list``, so missing or infinite values are sent as null.

    Raises:
        ValueError: If any required column is absent from the CSV.
    """
    df = pd.read_csv(csv_file_path)
    # Tolerate "a, b, c" style headers such as the format shown in the UI help.
    df.columns = [c.strip() if isinstance(c, str) else c for c in df.columns]
    missing = [col for col in _CSV_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError("CSV is missing required column(s): " + ", ".join(missing))
    df["month"] = pd.to_datetime(df["month"], format="mixed", errors="coerce").dt.strftime("%Y-%m")
    return {
        "product": PRODUCT,
        "inputs": {
            # Identifier columns are sanitised too: a blank id or unparseable
            # month would otherwise be sent as NaN, which is invalid JSON.
            "data": {col: _to_json_safe_list(df[col]) for col in _CSV_COLUMNS},
            "parameters": {
                "encoder_length": 12,
                "prediction_length": 6,
                "batch_size": 256,
                "n_samples": 1000,
                "quantiles": [0.1, 0.5, 0.9],
                "scenarios": scenarios,
                "round_outputs": True,
            },
        },
    }
def send_inference_request(payload: Dict[str, Any], endpoint: str, api_key: str) -> Dict[str, Any]:
    """POST *payload* as JSON to the fly.dev orchestrator and return the decoded reply.

    Args:
        payload: Request body, serialised by ``requests`` via ``json=``.
        endpoint: Orchestrator URL.
        api_key: Sent as a bearer token in the ``Authorization`` header.

    Returns:
        The response body parsed as JSON.

    Raises:
        requests.HTTPError: For non-2xx responses (via ``raise_for_status``).
        requests.RequestException: On network failures or the 120 s timeout.
    """
    reply = requests.post(
        endpoint,
        json=payload,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        timeout=120,
    )
    reply.raise_for_status()
    return reply.json()
def parse_forecast_response(response_json: Dict[str, Any]) -> pd.DataFrame:
    """Turn the ``forecasts`` records of an API reply into a DataFrame.

    Only the known forecast/scenario columns are kept, in a fixed display
    order; columns absent from the reply are skipped and unknown ones dropped.
    A reply with no (or empty) ``forecasts`` yields an empty DataFrame.
    """
    records = response_json.get("forecasts", [])
    if not records:
        return pd.DataFrame()

    preferred_order = (
        # identity
        "scenario", "product_id", "market_id", "month", "horizon_step",
        # forecast distribution
        "point_mean", "p10", "p50", "p90",
        # quality flags
        "confidence_label", "confidence_score", "requires_review",
        # scenario inputs echoed back
        "planned_price_index", "planned_discount_pct", "planned_promo_flag",
        "planned_tender_flag", "planned_regulatory_event_flag",
        "planned_supply_risk", "planned_competitor_pressure",
    )
    frame = pd.DataFrame(records)
    available = set(frame.columns)
    return frame[[name for name in preferred_order if name in available]]
def _format_error_response(response_json: Dict[str, Any]):
    """Render an API error payload as a Markdown message for the results panel.

    ``status`` picks the icon and heading (401 -> Unauthorized, 429 -> Too Many
    Requests, anything else -> Error). The message text comes from ``detail``,
    or else ``error``, or else "Unknown error". A non-empty ``usage`` dict
    appends the caller's rate-limit and daily-quota counters.
    """
    status = response_json.get("status", "")
    icon, heading = "❌", "**Error**"
    for code, code_icon, code_heading in (
        (401, "🔒", "**Unauthorized**"),
        (429, "🚦", "**Too Many Requests**"),
    ):
        if status == code:
            icon, heading = code_icon, code_heading
            break

    if "detail" in response_json:
        detail = response_json["detail"]
    else:
        detail = response_json.get("error", "Unknown error")

    usage = response_json.get("usage", {})
    usage_lines = ""
    if usage:
        usage_lines = "\n".join([
            "",
            "",
            "**Your usage:**",
            f"- Requests last minute: **{usage.get('requests_last_minute', '–')}** "
            f"/ {usage.get('rate_limit_per_minute', '–')}",
            f"- Requests today: **{usage.get('requests_today', '–')}** "
            f"/ {usage.get('daily_quota', '–')}",
        ])
    return f"{icon} {heading}: {detail}{usage_lines}"
# ---------------------------------------------------------------------------
# Gradio processing function
# ---------------------------------------------------------------------------
def _uploaded_path(uploaded: Any) -> str:
    """Return the filesystem path of a Gradio upload (plain path string or object with ``.name``)."""
    return getattr(uploaded, "name", uploaded)


def _error_body_from_exception(err: Exception) -> Dict[str, Any]:
    """Extract a JSON error payload carried by an HTTP error, or ``{}`` if there is none.

    ``requests`` exceptions expose the failed reply as ``err.response`` (None
    for connection errors). When that reply is a JSON object with ``error`` or
    ``detail``, it is returned with ``status`` defaulted to the HTTP status
    code. That lets ``_format_error_response`` choose its 401/429 heading.
    """
    response = getattr(err, "response", None)
    if response is None:
        return {}
    try:
        body = response.json()
    except ValueError:  # non-JSON error page
        return {}
    if not isinstance(body, dict) or not ("error" in body or "detail" in body):
        return {}
    body = dict(body)
    body.setdefault("status", getattr(response, "status_code", ""))
    return body


def _usage_markdown(usage: Dict[str, Any], partner: Any) -> str:
    """Format the API quota block appended to a successful forecast summary."""
    return (
        f"\n\n**API Usage** ({partner}):\n"
        f"- Requests last minute: **{usage.get('requests_last_minute', '–')}** "
        f"/ {usage.get('rate_limit_per_minute', '–')}\n"
        f"- Requests today: **{usage.get('requests_today', '–')}** "
        f"/ {usage.get('daily_quota', '–')}"
    )


def _build_summary(df: pd.DataFrame, usage_section: str) -> str:
    """Render the Markdown success summary for a non-empty forecast table.

    Columns missing from the API reply are shown as "–", and their confidence
    counts as 0, so a partial reply still reaches the user instead of raising.
    """
    def nunique(col: str) -> Any:
        return df[col].nunique() if col in df.columns else "–"

    labels = df["confidence_label"] if "confidence_label" in df.columns else pd.Series(dtype=object)
    if "month" in df.columns:
        first_month, last_month = df["month"].min(), df["month"].max()
    else:
        first_month = last_month = "–"
    if "scenario" in df.columns:
        # str() so non-string scenario names (ints, NaN) cannot break the join.
        scenario_names = [str(s) for s in df["scenario"].unique()[:5]]
        more = "..." if df["scenario"].nunique() > 5 else ""
    else:
        scenario_names, more = [], ""

    lines = [
        "",
        "✅ **Forecast Generation Successful!**",
        "**Summary:**",
        f"- 📊 Total forecasts: **{len(df)}**",
        f"- 🎯 Unique scenarios: **{nunique('scenario')}**",
        f"- 📦 Products: **{nunique('product_id')}**",
        f"- 🌍 Markets: **{nunique('market_id')}**",
        f"- 📅 Date range: **{first_month}** to **{last_month}**",
        "**Confidence Distribution:**",
        f"- 🟢 HIGH: **{(labels == 'HIGH').sum()}** forecasts",
        f"- 🟡 MEDIUM: **{(labels == 'MEDIUM').sum()}** forecasts",
        f"- 🔴 LOW: **{(labels == 'LOW').sum()}** forecasts",
        f"**Scenarios Processed:** {', '.join(scenario_names)}{more}{usage_section}",
        "",
    ]
    return "\n".join(lines)


def _output_csv_path() -> str:
    """Return a unique file name for the downloadable forecast CSV.

    A second-resolution timestamp alone lets two requests in the same second
    (e.g. two users of a shared Gradio app) overwrite each other's file; the
    random suffix keeps names unique.
    """
    import uuid

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"demandscape_forecasts_{stamp}_{uuid.uuid4().hex[:8]}.csv"


def process_forecast(csv_file, scenarios_file, api_key: str, endpoint: str):
    """Main processing function called by the Gradio submit button.

    Args:
        csv_file: Uploaded historical-data CSV (path string or object with ``.name``).
        scenarios_file: Uploaded scenarios JSON (same forms as ``csv_file``).
        api_key: DemandScape Suite API key; surrounding whitespace is ignored.
        endpoint: Orchestrator URL the payload is POSTed to.

    Returns:
        ``(forecast_df, summary_markdown, csv_path)`` on success, otherwise
        ``(None, message_markdown, None)`` for validation failures, API-reported
        errors, or any exception (this is the UI boundary, so nothing propagates).
    """
    try:
        if not api_key or not api_key.strip():
            return (
                None,
                "🔑 **API Key required**: Please enter your DemandScape Suite API Key. "
                "To request one, email **info@diskoverdiagnostics.com**.",
                None,
            )
        if csv_file is None:
            return None, "❌ **Error**: Please upload a CSV file.", None
        if scenarios_file is None:
            return None, "❌ **Error**: Please upload a scenarios JSON file.", None
        api_key = api_key.strip()
        scenarios = load_scenarios_from_file(_uploaded_path(scenarios_file))
        payload = csv_to_inference_json(_uploaded_path(csv_file), scenarios)
        try:
            response_json = send_inference_request(payload, endpoint, api_key)
        except Exception as err:
            # raise_for_status() turns 4xx/5xx replies into exceptions; recover
            # the server's JSON error body (if any) so the 401/429 headings and
            # usage info in _format_error_response are shown instead of a
            # generic "4xx Client Error" string.
            error_body = _error_body_from_exception(err)
            if not error_body:
                raise
            return None, _format_error_response(error_body), None
        if "error" in response_json:
            return None, _format_error_response(response_json), None
        df_forecasts = parse_forecast_response(response_json)
        if df_forecasts.empty:
            return None, "⚠️ **Warning**: No forecasts returned from the model.", None
        meta = response_json.get("metadata", {})
        usage = meta.get("usage", {})
        usage_section = ""
        if usage:
            usage_section = _usage_markdown(usage, usage.get("partner", meta.get("partner", "")))
        summary_text = _build_summary(df_forecasts, usage_section)
        output_csv_path = _output_csv_path()
        df_forecasts.to_csv(output_csv_path, index=False)
        return df_forecasts, summary_text, output_csv_path
    except Exception as e:
        return None, f"❌ **Error**: {str(e)}", None
# ---------------------------------------------------------------------------
# Gradio Tab builder
# ---------------------------------------------------------------------------
def build_tab(api_key_input: gr.Textbox, flyio_endpoint: str) -> None:
    """
    Build the DemandScape UI inside the current gr.Tab / gr.Blocks context.

    Must be called while a Gradio Blocks/Tab context is open; the components
    created here attach to that context. Layout: a left column (scale=1) with
    the two file uploads, format help and the submit button, and a right
    column (scale=2) with the summary, results table and CSV download.

    Args:
        api_key_input: The shared API-key Textbox component from the parent
            app; its value is passed to ``process_forecast`` on each submit.
        flyio_endpoint: Orchestrator URL captured by the click handler and
            forwarded unchanged to ``process_forecast``.
    """
    gr.Markdown("### Upload your data to generate demand forecasts for multiple scenarios")
    with gr.Row():
        # Left column: inputs, format help and the submit button.
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Input Files")
            csv_input = gr.File(
                label="1️⃣ Upload CSV File (Historical Data)",
                file_types=[".csv"],
            )
            scenarios_input = gr.File(
                label="2️⃣ Upload Scenarios JSON File",
                file_types=[".json"],
            )
            # f-string: the sample URLs are interpolated, and doubled braces
            # ({{ }}) render as literal braces in the JSON example.
            gr.Markdown(
                f"""
**Sample files:**
- 📄 [infer_sample_1.csv]({SAMPLE_CSV_URL})
- 📋 [scenarios.json]({SAMPLE_SCENARIOS_URL})
**CSV Format:**
```
month, product_id, market_id, units_sold,
planned_price_index, planned_discount_pct,
planned_promo_flag, planned_tender_flag,
planned_supply_risk, planned_competitor_pressure,
planned_regulatory_event_flag
```
⚠️ Last 6 rows: leave `units_sold` **empty** (forecast horizon)
**JSON Format:**
```json
{{
"base": {{}},
"price_up_10pct": {{"planned_price_index": 1.1}}
}}
```
"""
            )
            submit_btn = gr.Button("🚀 Generate Forecasts", variant="primary", size="lg")
        # Right column (twice as wide): summary text, results table, CSV download.
        with gr.Column(scale=2):
            gr.Markdown("### 📈 Results")
            summary_output = gr.Markdown()
            with gr.Row():
                forecast_table = gr.Dataframe(
                    label="Forecast Results",
                    interactive=False,
                    wrap=True,
                )
            with gr.Row():
                download_btn = gr.File(label="📥 Download Forecasts CSV")
    # NOTE(review): nesting of this footer (top level vs. inside the results
    # column) was reconstructed; confirm against the intended layout.
    gr.Markdown(
        """
---
**Output Columns:** Point forecasts (mean) and prediction intervals (P10, P50, P90) ·
Confidence scores (HIGH / MEDIUM / LOW) · Review flags · All scenario parameters
"""
    )
    # Wire up
    # The lambda binds flyio_endpoint so the handler takes exactly the three UI
    # inputs; its 3-tuple result maps positionally onto the three outputs.
    submit_btn.click(
        fn=lambda csv_f, scen_f, key: process_forecast(csv_f, scen_f, key, flyio_endpoint),
        inputs=[csv_input, scenarios_input, api_key_input],
        outputs=[forecast_table, summary_output, download_btn],
    )