Yash1178 commited on
Commit
2c3c5f5
·
1 Parent(s): 054baf7

Deploy CommodiSense v1.0

Browse files
.gitignore ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ venv/
9
+ ENV/
10
+ build/
11
+ develop-eggs/
12
+ dist/
13
+ downloads/
14
+ eggs/
15
+ .eggs/
16
+ lib/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ *.egg-info/
23
+ .installed.cfg
24
+ *.egg
25
+
26
+ # IDE
27
+ .vscode/
28
+ .idea/
29
+ *.swp
30
+ *.swo
31
+ *~
32
+
33
+ # Environment
34
+ .env
35
+ .env.local
36
+ .env.*.local
37
+
38
+ # Data/Cache/Database (binary files)
39
+ *.db
40
+ *.duckdb
41
+ data/cache/
42
+ data/*.parquet
43
+ data/*.csv.bak
44
+ data/collector_cache/
45
+ model/models/
46
+ model/cache/
47
+ .cache/
48
+ *.pkl
49
+ *.pickle
50
+
51
+ # Logs
52
+ *.log
53
+ logs/
54
+ runs/
55
+
56
+ # Streamlit cache (keep config.toml for deployment)
57
+ .streamlit_cache/
58
+ .streamlit/.cache/
59
+ .streamlit/__pycache__/
60
+
61
+ # Jupyter
62
+ .ipynb_checkpoints/
63
+ *.ipynb
.streamlit/config.toml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor = "#3D7FFF"
3
+ backgroundColor = "#060A0F"
4
+ secondaryBackgroundColor = "#0D1117"
5
+ textColor = "#E6EDF3"
6
+
7
+ [client]
8
+ showErrorDetails = true
9
+
10
+ [server]
11
+ port = 7860
12
+ headless = true
13
+ enableCORS = false
14
+ enableXsrfProtection = true
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 7860
11
+
12
+ CMD ["streamlit", "run", "dashboard/app.py", "--server.port=7860", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -1,10 +1,569 @@
1
  ---
2
- title: Commodisense
3
- emoji: 🐢
4
- colorFrom: pink
5
- colorTo: purple
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CommodiSense
3
+ colorFrom: gray
4
+ colorTo: gray
 
5
  sdk: docker
6
+ app_file: dashboard/app.py
7
  pinned: false
8
  ---
9
 
10
+ # CommodiSense Global Commodity Intelligence Engine
11
+
12
+ <div align="center">
13
+
14
+ ![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=flat-square&logo=python&logoColor=white)
15
+ ![Streamlit](https://img.shields.io/badge/Streamlit-1.28+-FF4B4B?style=flat-square&logo=streamlit&logoColor=white)
16
+ ![XGBoost](https://img.shields.io/badge/XGBoost-2.0+-006400?style=flat-square)
17
+ ![LightGBM](https://img.shields.io/badge/LightGBM-4.0+-5B8C5A?style=flat-square)
18
+ ![DuckDB](https://img.shields.io/badge/DuckDB-0.10+-FFF000?style=flat-square)
19
+ ![License](https://img.shields.io/badge/License-MIT-blue?style=flat-square)
20
+ ![Status](https://img.shields.io/badge/Status-Live-00D97E?style=flat-square)
21
+
22
+ **Zero-cost commodity price direction forecaster for 10 global markets.**
23
+ Powered by XGBoost + LightGBM ensemble, SHAP explainability, FinBERT NLP sentiment,
24
+ CFTC COT positioning, EIA inventory data, USDA crop signals, and FRED macro indicators.
25
+
26
+ [**Live Demo**](https://commodisense.streamlit.app) · [**Report Bug**](https://github.com/Yashvardhansharma112/commodisense/issues) · [**Request Feature**](https://github.com/Yashvardhansharma112/commodisense/issues)
27
+
28
+ </div>
29
+
30
+ ---
31
+
32
+ ## Table of Contents
33
+
34
+ - [Overview](#overview)
35
+ - [Features](#features)
36
+ - [How It Works](#how-it-works)
37
+ - [Data Sources](#data-sources)
38
+ - [Model Architecture](#model-architecture)
39
+ - [Accuracy Results](#accuracy-results)
40
+ - [Tech Stack](#tech-stack)
41
+ - [Project Structure](#project-structure)
42
+ - [Getting Started](#getting-started)
43
+ - [Configuration](#configuration)
44
+ - [Deployment](#deployment)
45
+ - [Daily Pipeline](#daily-pipeline)
46
+ - [API Keys](#api-keys)
47
+
48
+ ---
49
+
50
+ ## Overview
51
+
52
+ CommodiSense is a production-grade commodity intelligence platform that forecasts price direction (UP / STABLE / DOWN) for 10 global commodity futures over 7-day and 30-day horizons.
53
+
54
+ Unlike most financial ML projects that rely on price technicals alone, CommodiSense fuses **8 independent data sources** — including institutional positioning data (CFTC COT), energy inventory surprises (EIA), crop condition ratings (USDA), and macroeconomic indicators (FRED) — into a single ensemble model per commodity.
55
+
56
+ The entire system runs at **zero ongoing cost** using free public APIs, GitHub Actions for scheduling, Streamlit Cloud for hosting, and DuckDB as a serverless embedded database.
57
+
58
+ ```
59
+ Data Collection → Feature Engineering → Ensemble Training → Live Dashboard
60
+ (8 sources) (65+ features) (XGBoost+LGBM) (Streamlit Cloud)
61
+ ```
62
+
63
+ ---
64
+
65
+ ## Features
66
+
67
+ ### Forecasting Engine
68
+ - **10 commodity markets**: Crude Oil (CL=F), Natural Gas (NG=F), Gold (GC=F), Wheat (ZW=F), Corn (ZC=F), Soybeans (ZS=F), Cotton (CT=F), Sugar (SB=F), USD/INR (USDINR=X), Copper (HG=F)
69
+ - **Dual horizons**: 7-day and 30-day directional forecasts
70
+ - **3-class output**: UP (>threshold%), STABLE, DOWN (<-threshold%) with per-commodity calibrated thresholds
71
+ - **Probability scores** with isotonic calibration for reliable confidence estimates
72
+ - **HIGH / MEDIUM / LOW confidence tiers** based on model probability
73
+ - **Signal confirmation filter**: 4 independent signals must agree to issue a HIGH-confidence call (price momentum, COT commercial positioning, EIA supply signal, USDA crop trend)
74
+
75
+ ### Data Intelligence
76
+ - **CFTC COT Reports**: 13 years of weekly institutional positioning (commercial hedgers vs managed money). The single most valuable commodity signal — smart money positioning often leads price by 1–3 weeks.
77
+ - **EIA Inventory**: Weekly crude oil stocks (2,278 rows back to 1982) and natural gas storage (856 rows). Inventory surprises vs 5-year average directly drive energy price moves.
78
+ - **USDA NASS**: Weekly crop condition (% good + excellent) for corn, wheat, soybeans, cotton. Annual production estimates. Declining crop condition → bullish price signal.
79
+ - **FRED Macro**: USD Index (DXY), VIX volatility, 10-year Treasury yield, Fed Funds rate, Industrial Production. Gold inversely correlates with real yields; copper tracks industrial output.
80
+ - **FinBERT NLP**: GDELT news articles scored for financial sentiment (bullish/bearish/neutral). Rolling 1-day, 3-day, 7-day sentiment aggregates per commodity.
81
+ - **spaCy Event Extraction**: Supply shock, policy change, and geopolitical event detection from news headlines.
82
+ - **Open-Meteo Weather**: Drought index, heat stress days, precipitation anomaly for agricultural commodity regions.
83
+ - **ACLED Geopolitical**: Risk scores for regions that supply each commodity.
84
+
85
+ ### Explainability
86
+ - **SHAP values** for every forecast — top 5 signal drivers shown in the dashboard
87
+ - Human-readable feature labels (e.g., "COT Smart Money Positioning", "EIA Crude Inventory Surprise")
88
+ - **AI Analyst Reports** generated via Groq LLM (Llama 3) contextualizing each forecast
89
+
90
+ ### Dashboard (Dark Luxury Terminal)
91
+ - Live animated ticker strip with all 10 markets
92
+ - Macro environment bar: DXY, VIX, yield curve, spread, copper demand proxy
93
+ - Direction-colored commodity cards with confidence badges
94
+ - Candlestick chart with 20-day SMA and forecast zone overlay
95
+ - COT positioning chart (commercial vs managed money, 2-year history)
96
+ - EIA inventory bar chart with 4-week rolling average
97
+ - News sentiment chart with bull/bear zones
98
+ - Weather signal metrics
99
+ - AI analyst report per commodity
100
+ - Recent news feed with sentiment scores
101
+
102
+ ### Infrastructure
103
+ - **GitHub Actions** daily pipeline (Mon–Fri 6am UTC): collect → process → retrain → forecast → commit
104
+ - **DuckDB** embedded database (no server required, zero cost)
105
+ - **Streamlit Cloud** free-tier hosting with auto-deploy on push
106
+ - Full **error isolation** — one failing step doesn't halt the rest of the pipeline
107
+
108
+ ---
109
+
110
+ ## How It Works
111
+
112
+ ```
113
+ ┌─────────────────────────────────────────────────────────────────┐
114
+ │ DAILY PIPELINE (13 Steps) │
115
+ ├─────────────────────────────────────────────────────────────────┤
116
+ │ Step 1 Collect prices yfinance → DuckDB │
117
+ │ Step 2 Collect news GDELT → DuckDB │
118
+ │ Step 3 Collect weather Open-Meteo → DuckDB │
119
+ │ Step 4 Collect geopolitical ACLED → DuckDB │
120
+ │ Step 5 Collect COT CFTC → DuckDB │
121
+ │ Step 6 Collect FRED macro FRED CSV + yfinance → DuckDB │
122
+ │ Step 7 Collect EIA inventory EIA API v2 → DuckDB │
123
+ │ Step 8 Collect USDA crop USDA NASS API → DuckDB │
124
+ │ Step 9 Score NLP sentiment FinBERT → sentiment_daily │
125
+ │ Step 10 Extract events spaCy → extracted_events │
126
+ │ Step 11 Generate forecasts XGBoost+LightGBM → accuracy_log │
127
+ │ Step 12 Generate AI reports Groq LLM → reports │
128
+ │ Step 13 Log accuracy Compare 7-day-old forecasts │
129
+ └─────────────────────────────────────────────────────────────────┘
130
+
131
+ ↓ pushes to GitHub ↓
132
+
133
+ Streamlit Cloud auto-deploys
134
+ ```
135
+
136
+ ---
137
+
138
+ ## Data Sources
139
+
140
+ | Source | Type | Coverage | Update Frequency | Key |
141
+ |--------|------|----------|-----------------|-----|
142
+ | **yfinance** | Price OHLCV | 12,613 rows · 5yr | Daily | None |
143
+ | **CFTC COT** | Futures positioning | 8,826 rows · 13yr | Weekly (Friday) | None |
144
+ | **FRED** | Macro indicators | 7,193 rows · 16yr | Daily/Weekly/Monthly | None |
145
+ | **EIA** | Energy inventory | 3,134 rows · 40yr crude | Weekly (Wednesday) | Free |
146
+ | **USDA NASS** | Crop condition & stocks | 1,104 rows · 5yr | Weekly/Quarterly | Free |
147
+ | **GDELT** | Global news | 392 articles | Daily | None |
148
+ | **Open-Meteo** | Agricultural weather | 210 rows | Daily | None |
149
+ | **ACLED** | Geopolitical events | 20 events | Weekly | None |
150
+
151
+ ### Free API Keys Required
152
+
153
+ | API | Data | Register |
154
+ |-----|------|---------|
155
+ | EIA | Crude oil & natural gas weekly inventory | [eia.gov/opendata](https://www.eia.gov/opendata/register.php) |
156
+ | USDA NASS | Crop condition, stocks, production | [quickstats.nass.usda.gov/api](https://quickstats.nass.usda.gov/api) |
157
+ | Groq | AI analyst report generation | [console.groq.com](https://console.groq.com) |
158
+
159
+ ---
160
+
161
+ ## Model Architecture
162
+
163
+ ### Per-Symbol Ensemble
164
+
165
+ Each of the 10 commodities has **two independent models** trained: one for the 7-day horizon and one for the 30-day horizon.
166
+
167
+ ```
168
+ Raw Features (65+)
169
+
170
+
171
+ Feature Selection ← drops columns with <5% non-zero values
172
+ (sparse filter) auto-excludes missing data sources
173
+
174
+
175
+ StandardScaler ← fit on training data, saved per symbol
176
+
177
+ ├─────────────────────────────────────────────┐
178
+ ▼ ▼
179
+ XGBoost Classifier LightGBM Classifier
180
+ (300 trees, max_depth=5) (300 trees, 31 leaves)
181
+ + Isotonic Calibration
182
+ │ │
183
+ └──────────────┬──────────────────────────────┘
184
+
185
+ Ensemble (avg probabilities)
186
+
187
+
188
+ Direction + Probability
189
+ (UP / STABLE / DOWN)
190
+
191
+
192
+ Signal Confirmation Filter ← 4-signal cross-check
193
+ (momentum + COT + EIA + USDA)
194
+
195
+
196
+ HIGH / MEDIUM / LOW confidence
197
+ ```
198
+
199
+ ### Feature Groups (65+ total)
200
+
201
+ | Group | Features | Count |
202
+ |-------|----------|-------|
203
+ | **Price technicals** | RSI-14, MACD, Bollinger Band position, ATR, SMA crossover | 5 |
204
+ | **Price momentum** | Return 1d/7d/14d/30d/60d, momentum score | 6 |
205
+ | **Seasonality** | Month sin/cos, harvest season flag, days to OPEC meeting | 4 |
206
+ | **Cross-commodity** | Oil/Gold ratio, DXY proxy | 2 |
207
+ | **CFTC COT** | Commercial net %, MM net %, week-over-week changes, open interest | 7 |
208
+ | **FRED macro** | DXY, VIX, 10Y yield, Fed Funds, INDPRO, yield inversion, copper basis | 12 |
209
+ | **EIA inventory** | Stocks level, weekly change, z-score vs 5yr avg, draw flag | 5 |
210
+ | **USDA crop** | Condition score, week-over-week change, stocks, production | 5 |
211
+ | **NLP sentiment** | 1-day/3-day/7-day sentiment, article count, positive ratio | 5 |
212
+ | **Event signals** | Bullish/bearish events, max severity, supply shock, policy change | 6 |
213
+ | **Geopolitical** | Risk score 7d, risk score 30d | 2 |
214
+ | **Weather** | Drought index, heat stress days, precipitation anomaly | 3 |
215
+ | **Data flags** | has_cot_data, has_fred_data, has_eia_data, has_usda_data | 4 |
216
+
217
+ ### Training Strategy
218
+
219
+ - **Walk-forward validation**: 5-fold cross-validation on 80% of data, tested on most recent 20%
220
+ - **Class balancing**: `compute_sample_weight("balanced")` addresses UP/DOWN/STABLE imbalance
221
+ - **Commodity-specific thresholds**: USDINR uses ±0.4% threshold (managed float), NG=F uses ±3.5% (highly volatile)
222
+ - **Regime detection**: TRENDING / VOLATILE / RANGE_BOUND classification per row
223
+ - **Interaction features**: `sentiment × momentum`, `event × momentum`, `high_volatility_flag`
224
+ - **SHAP explainer**: TreeExplainer run post-training, top 5 features saved per forecast
225
+
226
+ ---
227
+
228
+ ## Accuracy Results
229
+
230
+ > Measured on held-out test set (most recent 20% of data). Random chance = 33.3% (3-class problem).
231
+
232
+ | Commodity | 7-Day | 30-Day | vs Baseline |
233
+ |-----------|-------|--------|------------|
234
+ | Crude Oil (CL=F) | 30.7% | 31.5% | +4.0% |
235
+ | Natural Gas (NG=F) | 36.3% | 44.6% | +3.6% |
236
+ | Gold (GC=F) | 37.1% | **54.2%** | +6.8% 30d |
237
+ | Wheat (ZW=F) | **44.6%** | 23.1% | +0.4% 7d |
238
+ | Corn (ZC=F) | 16.7%⚠ | **48.2%** | — |
239
+ | **Soybeans (ZS=F)** | **62.2%** | 48.6% | **+18.0%** |
240
+ | Cotton (CT=F) | **45.8%** | 34.7% | +0.8% |
241
+ | Sugar (SB=F) | 35.9% | 36.7% | — |
242
+ | USD/INR (USDINR=X) | 41.2% | **50.8%** | **+28.1%** 30d |
243
+ | Copper (HG=F) | 16.3%⚠ | 23.1% | — |
244
+ | **Average** | **36.7%** | **39.6%** | +5.4% vs random |
245
+
246
+ > ⚠ ZC=F 7d and HG=F have below-random accuracy due to structural market regime breaks in 2024–2026 (South American corn oversupply, HG=F name change in CFTC files limiting history). Use 30d forecasts for these symbols.
247
+
248
+ **Best performers:**
249
+ - 🥇 **ZS=F 7d: 62.2%** — USDA soybean crop condition is a dominant signal
250
+ - 🥈 **USDINR=X 30d: 50.8%** — FRED DXY + Fed Funds rate highly predictive for USD/INR
251
+ - 🥉 **GC=F 30d: 54.2%** — Gold responds strongly to yield curve and inflation expectations
252
+
253
+ ---
254
+
255
+ ## Tech Stack
256
+
257
+ ```
258
+ Language Python 3.10+
259
+ Database DuckDB 0.10+ (embedded, zero-config, serverless)
260
+ ML XGBoost 2.0, LightGBM 4.0, scikit-learn 1.3
261
+ Explainability SHAP 0.42
262
+ NLP HuggingFace Transformers (FinBERT), spaCy 3.5
263
+ Dashboard Streamlit 1.28, Plotly 5.15
264
+ LLM Reports Groq API (Llama 3)
265
+ Data APIs yfinance, requests, FRED CSV, EIA API v2, USDA NASS API
266
+ Scheduling GitHub Actions (cron)
267
+ Hosting Streamlit Cloud (free tier)
268
+ ```
269
+
270
+ ---
271
+
272
+ ## Project Structure
273
+
274
+ ```
275
+ commodisense/
276
+
277
+ ├── data/ # Data collection layer
278
+ │ ├── db.py # DuckDB connection + schema init (9 tables)
279
+ │ ├── collector_prices.py # yfinance OHLCV prices
280
+ │ ├── collector_news.py # GDELT news articles
281
+ │ ├── collector_weather.py # Open-Meteo agricultural weather
282
+ │ ├── collector_geopolitical.py # ACLED geopolitical events
283
+ │ ├── collector_cot.py # CFTC COT weekly positioning (2013–2026)
284
+ │ ├── collector_fred.py # FRED macro + yfinance DXY/VIX
285
+ │ ├── collector_eia.py # EIA crude oil + natural gas inventory
286
+ │ └── collector_usda.py # USDA crop condition + stocks + production
287
+
288
+ ├── signals/ # Feature engineering layer
289
+ │ ├── price_features.py # RSI, MACD, momentum, seasonality, cross-commodity
290
+ │ ├── nlp_sentiment.py # FinBERT sentiment scoring pipeline
291
+ │ ├── nlp_events.py # spaCy event extraction
292
+ │ ├── weather_features.py # Drought/heat/precip aggregation by commodity region
293
+ │ └── macro_features.py # COT + FRED + EIA + USDA feature engineering
294
+
295
+ ├── model/ # ML layer
296
+ │ ├── feature_builder.py # Assembles all signals → training matrix (no lookahead)
297
+ │ ├── trainer.py # XGBoost + LightGBM training, calibration, SHAP
298
+ │ ├── predictor.py # Inference with signal confirmation filter
299
+ │ └── explainer.py # AI report generation via Groq
300
+
301
+ ├── pipeline/
302
+ │ └── daily_run.py # 13-step orchestrator with error isolation
303
+
304
+ ├── dashboard/
305
+ │ └── app.py # Streamlit dashboard (dark luxury terminal UI)
306
+
307
+ ├── models/ # Trained model artifacts (committed to git)
308
+ │ ├── xgb_{SYMBOL}_{horizon}.pkl
309
+ │ ├── lgbm_{SYMBOL}_{horizon}.pkl
310
+ │ ├── scaler_{SYMBOL}_{horizon}.pkl
311
+ │ ├── feature_names_{SYMBOL}_{horizon}.json
312
+ │ └── accuracy_report.json
313
+
314
+ ├── tests/
315
+ │ └── test_accuracy.py # Walk-forward backtesting framework (6 boosters)
316
+
317
+ ├── .github/workflows/
318
+ │ └── daily_pipeline.yml # GitHub Actions cron (Mon–Fri 06:00 UTC)
319
+
320
+ ├── .env.example # Environment variable template
321
+ ├── requirements.txt # Python dependencies
322
+ └── README.md
323
+ ```
324
+
325
+ ### Database Schema (9 tables)
326
+
327
+ | Table | Description |
328
+ |-------|-------------|
329
+ | `prices` | Daily OHLCV per symbol |
330
+ | `news_raw` | Raw news articles with NLP scores |
331
+ | `sentiment_daily` | Aggregated daily sentiment per commodity |
332
+ | `extracted_events` | spaCy-extracted supply shocks, policy changes |
333
+ | `weather_features` | Drought/heat/precip by region and commodity |
334
+ | `geopolitical_events` | Risk scores per region/commodity |
335
+ | `accuracy_log` | Live forecast vs actual outcome tracking |
336
+ | `cot_data` | CFTC COT weekly positioning per symbol |
337
+ | `fred_data` | FRED macro series (daily, forward-filled) |
338
+ | `eia_inventory` | EIA weekly energy storage |
339
+ | `usda_crop` | USDA crop condition, stocks, production |
340
+
341
+ ---
342
+
343
+ ## Getting Started
344
+
345
+ ### Prerequisites
346
+
347
+ - Python 3.10+
348
+ - Git
349
+
350
+ ### Installation
351
+
352
+ ```bash
353
+ # Clone the repository
354
+ git clone https://github.com/Yashvardhansharma112/commodisense.git
355
+ cd commodisense
356
+
357
+ # Create virtual environment
358
+ python -m venv venv
359
+
360
+ # Activate (Windows)
361
+ venv\Scripts\activate
362
+
363
+ # Activate (macOS/Linux)
364
+ source venv/bin/activate
365
+
366
+ # Install dependencies
367
+ pip install -r requirements.txt
368
+
369
+ # Download spaCy model
370
+ python -m spacy download en_core_web_sm
371
+ ```
372
+
373
+ ### Environment Variables
374
+
375
+ ```bash
376
+ # Copy the example and fill in your keys
377
+ cp .env.example .env
378
+ ```
379
+
380
+ Edit `.env`:
381
+ ```env
382
+ GROQ_API_KEY=your_groq_key_here # groq.com — free, for AI reports
383
+ EIA_API_KEY=your_eia_key_here # eia.gov/opendata — free
384
+ USDA_API_KEY=your_usda_key_here # quickstats.nass.usda.gov/api — free
385
+ ```
386
+
387
+ ### First Run (Full Backfill)
388
+
389
+ ```bash
390
+ # Initialize database schema
391
+ python data/db.py
392
+
393
+ # Backfill all data sources (takes ~15 minutes)
394
+ python pipeline/daily_run.py --backfill
395
+
396
+ # Train models for all 10 commodities
397
+ for symbol in CL=F NG=F GC=F ZW=F ZC=F ZS=F CT=F SB=F USDINR=X HG=F; do
398
+ python model/trainer.py --symbol $symbol --horizon both
399
+ done
400
+
401
+ # Launch dashboard
402
+ streamlit run dashboard/app.py
403
+ ```
404
+
405
+ The dashboard will be available at **http://localhost:8501**
406
+
407
+ ### Individual Commands
408
+
409
+ ```bash
410
+ # Collect specific data source
411
+ python data/collector_prices.py --backfill
412
+ python data/collector_cot.py --backfill
413
+ python data/collector_fred.py --backfill
414
+ python data/collector_eia.py --backfill
415
+ python data/collector_usda.py --backfill
416
+
417
+ # Run NLP pipeline
418
+ python signals/nlp_sentiment.py --limit 500
419
+ python signals/nlp_events.py --limit 500
420
+
421
+ # Generate forecast for a single symbol
422
+ python model/predictor.py --symbol ZS=F
423
+
424
+ # Generate all forecasts
425
+ python model/predictor.py --all
426
+
427
+ # Run accuracy backtest
428
+ python tests/test_accuracy.py --symbol ZS=F
429
+
430
+ # Run only a specific pipeline step (for debugging)
431
+ python pipeline/daily_run.py --step 7
432
+ ```
433
+
434
+ ---
435
+
436
+ ## Configuration
437
+
438
+ ### Per-Commodity Direction Thresholds
439
+
440
+ Different commodities have different volatility profiles. Thresholds are set in `model/feature_builder.py`:
441
+
442
+ | Symbol | Threshold | Rationale |
443
+ |--------|-----------|-----------|
444
+ | USDINR=X | ±0.4% | Managed float — rarely moves >1% in a week |
445
+ | GC=F | ±1.5% | Gold — moderately volatile |
446
+ | NG=F | ±3.5% | Natural gas — highly volatile seasonally |
447
+ | Others | ±2.0% | Default threshold |
448
+
449
+ ### Adding a New Commodity
450
+
451
+ 1. Add the ticker to `ALL_SYMBOLS` in `signals/price_features.py`
452
+ 2. Add a human-readable name to `SYMBOL_NAMES` in `model/predictor.py`
453
+ 3. Run `python data/collector_prices.py --backfill`
454
+ 4. Train: `python model/trainer.py --symbol NEW=F --horizon both`
455
+
456
+ ---
457
+
458
+ ## Deployment
459
+
460
+ ### Streamlit Cloud (Recommended — Free)
461
+
462
+ 1. Fork or push to GitHub
463
+ 2. Go to [share.streamlit.io](https://share.streamlit.io)
464
+ 3. Click **New app** → connect your GitHub repo
465
+ 4. Set:
466
+ - **Repository**: `Yashvardhansharma112/commodisense`
467
+ - **Branch**: `main`
468
+ - **Main file path**: `dashboard/app.py`
469
+ 5. Click **Advanced settings** → paste in **Secrets** (TOML format):
470
+ ```toml
471
+ GROQ_API_KEY = "your_key"
472
+ EIA_API_KEY = "your_key"
473
+ USDA_API_KEY = "your_key"
474
+ ```
475
+ 6. Click **Deploy**
476
+
477
+ ### GitHub Actions (Daily Pipeline)
478
+
479
+ Add the same 3 keys as **Repository Secrets** at:
480
+ `Settings → Secrets → Actions → New repository secret`
481
+
482
+ The pipeline runs automatically Mon–Fri at 06:00 UTC. It:
483
+ 1. Collects fresh data from all 8 sources
484
+ 2. Runs NLP sentiment + event extraction
485
+ 3. Generates new forecasts for all 10 symbols
486
+ 4. Commits the updated `data/commodisense.duckdb` back to the repo
487
+ 5. Streamlit Cloud auto-deploys on the new commit
488
+
489
+ ---
490
+
491
+ ## Daily Pipeline
492
+
493
+ The pipeline is defined in `pipeline/daily_run.py`. Each step is isolated in a `try/except` — one failure doesn't stop the rest.
494
+
495
+ ```
496
+ Step 1 Collect prices ~30s
497
+ Step 2 Collect news ~60s (GDELT rate-limited)
498
+ Step 3 Collect weather ~45s
499
+ Step 4 Collect geopolitical ~15s
500
+ Step 5 Collect COT ~30s (CFTC public ZIP download)
501
+ Step 6 Collect FRED macro ~30s (7 series + yfinance fallback)
502
+ Step 7 Collect EIA inventory ~15s (2 series via API)
503
+ Step 8 Collect USDA crop ~60s (4 commodities × 3 queries)
504
+ Step 9 Score NLP sentiment ~120s (FinBERT on GPU/CPU)
505
+ Step 10 Extract events ~60s (spaCy NER)
506
+ Step 11 Generate forecasts ~30s (10 symbols, cached models)
507
+ Step 12 Generate AI reports ~90s (Groq API, 10 LLM calls)
508
+ Step 13 Log accuracy ~5s (compare 7-day-old forecasts)
509
+ ─────────────────────────────────────────
510
+ Total ~8-12 minutes
511
+ ```
512
+
513
+ Manual trigger: Go to **Actions** tab → **Daily CommodiSense Pipeline** → **Run workflow**
514
+
515
+ ---
516
+
517
+ ## API Keys
518
+
519
+ | Key | Where to get | Cost | What it enables |
520
+ |-----|-------------|------|----------------|
521
+ | `GROQ_API_KEY` | [console.groq.com](https://console.groq.com) | Free tier | AI analyst reports via Llama 3 |
522
+ | `EIA_API_KEY` | [eia.gov/opendata/register.php](https://www.eia.gov/opendata/register.php) | Free | Crude oil + natural gas weekly inventory data |
523
+ | `USDA_API_KEY` | [quickstats.nass.usda.gov/api](https://quickstats.nass.usda.gov/api) | Free | Crop condition, stocks, production |
524
+
525
+ The system runs without any API keys — it will skip those data collection steps and fall back to price technicals only. Accuracy improves significantly with all keys set.
526
+
527
+ ---
528
+
529
+ ## Accuracy Improvement Roadmap
530
+
531
+ | Data Source | Expected Gain | Status |
532
+ |------------|--------------|--------|
533
+ | CFTC COT (13yr history) | +5–8% avg | ✅ Implemented |
534
+ | EIA crude + natgas inventory | +10–13% for CL=F | ✅ Implemented |
535
+ | USDA crop condition | +15–18% for ZS=F | ✅ Implemented |
536
+ | FRED macro (DXY, VIX, yields) | +21% USDINR=X 30d | ✅ Implemented |
537
+ | South American crop data (CONAB) | +10–15% ZC=F | 🔲 Planned |
538
+ | LME copper warehouse stocks | +8–12% HG=F | 🔲 Planned |
539
+ | Heating/Cooling Degree Days (NOAA) | +5–8% NG=F | 🔲 Planned |
540
+ | WASDE monthly projections | +5–7% grains | 🔲 Planned |
541
+
542
+ ---
543
+
544
+ ## License
545
+
546
+ MIT License — see [LICENSE](LICENSE) for details.
547
+
548
+ ---
549
+
550
+ ## Acknowledgements
551
+
552
+ - **CFTC** for free public COT disaggregated reports
553
+ - **Federal Reserve (FRED)** for free macroeconomic data API
554
+ - **U.S. Energy Information Administration (EIA)** for free energy inventory API
555
+ - **USDA NASS** for free agricultural statistics API
556
+ - **GDELT Project** for free global news event database
557
+ - **Open-Meteo** for free historical weather API
558
+ - **yfinance** community for the excellent Yahoo Finance wrapper
559
+ - **Groq** for free Llama 3 inference API
560
+
561
+ ---
562
+
563
+ <div align="center">
564
+
565
+ Built with Python · Deployed on Streamlit Cloud · Data from CFTC, FRED, EIA, USDA, GDELT
566
+
567
+ **[⭐ Star this repo](https://github.com/Yashvardhansharma112/commodisense)** if you find it useful
568
+
569
+ </div>
dashboard/app.py ADDED
@@ -0,0 +1,1077 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CommodiSense Dashboard — Global Commodity Intelligence Engine
3
+ Dark luxury financial terminal UI.
4
+
5
+ Run: streamlit run dashboard/app.py
6
+ Deploy: Streamlit Cloud → main file: dashboard/app.py → secret: GROQ_API_KEY
7
+ """
8
+
9
+ import sys
10
+ from datetime import date, datetime, timedelta
11
+ from pathlib import Path
12
+
13
+ import pandas as pd
14
+ import plotly.graph_objects as go
15
+ import streamlit as st
16
+
17
+ ROOT = Path(__file__).parent.parent
18
+ sys.path.insert(0, str(ROOT))
19
+
20
+ from data.db import get_conn, init_schema
21
+ from model.explainer import load_latest_reports, generate_report
22
+ from model.predictor import predict, SYMBOL_NAMES
23
+
24
+ # ── page config ────────────────────────────────────────────────────────────────
25
+
26
+ st.set_page_config(
27
+ page_title="CommodiSense",
28
+ page_icon="◈",
29
+ layout="wide",
30
+ initial_sidebar_state="collapsed",
31
+ )
32
+
33
+ # ── design tokens ──────────────────────────────────────────────────────────────
34
+
35
+ C = {
36
+ "bg": "#060A0F",
37
+ "surface": "#0D1117",
38
+ "surface2": "#161B22",
39
+ "border": "rgba(255,255,255,0.07)",
40
+ "border_hi": "rgba(255,255,255,0.14)",
41
+ "up": "#00D97E",
42
+ "down": "#FF3B55",
43
+ "stable": "#7A8899",
44
+ "up_dim": "rgba(0,217,126,0.12)",
45
+ "down_dim": "rgba(255,59,85,0.12)",
46
+ "stable_dim": "rgba(122,136,153,0.10)",
47
+ "accent": "#3D7FFF",
48
+ "accent_dim": "rgba(61,127,255,0.12)",
49
+ "gold": "#FFBB00",
50
+ "text": "#E6EDF3",
51
+ "text2": "#8B949E",
52
+ "text3": "#484F58",
53
+ "conf_high": "#00D97E",
54
+ "conf_mid": "#FFBB00",
55
+ "conf_low": "#7A8899",
56
+ }
57
+
58
+ DIR_COLOR = {"UP": C["up"], "DOWN": C["down"], "STABLE": C["stable"]}
59
+ DIR_DIM = {"UP": C["up_dim"],"DOWN": C["down_dim"],"STABLE": C["stable_dim"]}
60
+ DIR_ICON = {"UP": "▲", "DOWN": "▼", "STABLE": "◆"}
61
+ CONF_COLOR = {"HIGH": C["conf_high"], "MEDIUM": C["conf_mid"], "LOW": C["conf_low"]}
62
+
63
+ ALL_SYMBOLS = list(SYMBOL_NAMES.keys())
64
+
65
+ # ── CSS ────────────────────────────────────────────────────────────────────────
66
+
67
+ def _inject_css():
68
+ st.markdown(f"""
69
+ <style>
70
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
71
+
72
+ html, body, [class*="css"] {{
73
+ font-family: 'Inter', -apple-system, sans-serif;
74
+ background-color: {C['bg']};
75
+ color: {C['text']};
76
+ }}
77
+ .stApp {{ background-color: {C['bg']}; }}
78
+ .block-container {{ padding: 1.2rem 2rem 3rem 2rem; max-width: 1600px; }}
79
+
80
+ /* Hide default Streamlit chrome */
81
+ #MainMenu, footer, header {{ visibility: hidden; }}
82
+ .stDeployButton {{ display: none; }}
83
+ [data-testid="stSidebar"] {{ background: {C['surface']}; border-right: 1px solid {C['border']}; }}
84
+
85
+ /* Scrollbar */
86
+ ::-webkit-scrollbar {{ width: 4px; height: 4px; }}
87
+ ::-webkit-scrollbar-track {{ background: {C['bg']}; }}
88
+ ::-webkit-scrollbar-thumb {{ background: {C['border_hi']}; border-radius: 2px; }}
89
+
90
+ /* Buttons */
91
+ .stButton > button {{
92
+ background: transparent;
93
+ border: 1px solid {C['border_hi']};
94
+ color: {C['text2']};
95
+ border-radius: 6px;
96
+ font-size: 0.78rem;
97
+ padding: 4px 10px;
98
+ transition: all 0.15s ease;
99
+ font-family: 'Inter', sans-serif;
100
+ }}
101
+ .stButton > button:hover {{
102
+ border-color: {C['accent']};
103
+ color: {C['accent']};
104
+ background: {C['accent_dim']};
105
+ }}
106
+
107
+ /* Metric cards */
108
+ div[data-testid="metric-container"] {{
109
+ background: {C['surface']};
110
+ border: 1px solid {C['border']};
111
+ border-radius: 10px;
112
+ padding: 14px 16px;
113
+ }}
114
+ div[data-testid="metric-container"] label {{
115
+ color: {C['text2']} !important;
116
+ font-size: 0.72rem !important;
117
+ letter-spacing: 0.06em;
118
+ text-transform: uppercase;
119
+ }}
120
+ div[data-testid="metric-container"] [data-testid="stMetricValue"] {{
121
+ color: {C['text']} !important;
122
+ font-size: 1.3rem !important;
123
+ font-weight: 600;
124
+ font-family: 'JetBrains Mono', monospace;
125
+ }}
126
+
127
+ /* Radio + select */
128
+ .stRadio > div {{ gap: 8px; }}
129
+ .stRadio label {{ font-size: 0.8rem; color: {C['text2']}; }}
130
+ .stSelectbox label {{ color: {C['text2']}; font-size: 0.8rem; }}
131
+
132
+ /* Tabs */
133
+ .stTabs [data-baseweb="tab-list"] {{
134
+ gap: 4px;
135
+ background: transparent;
136
+ border-bottom: 1px solid {C['border']};
137
+ }}
138
+ .stTabs [data-baseweb="tab"] {{
139
+ background: transparent;
140
+ border: none;
141
+ color: {C['text2']};
142
+ font-size: 0.82rem;
143
+ padding: 6px 14px;
144
+ border-radius: 6px 6px 0 0;
145
+ }}
146
+ .stTabs [aria-selected="true"] {{
147
+ background: {C['surface']} !important;
148
+ color: {C['text']} !important;
149
+ border-bottom: 2px solid {C['accent']};
150
+ }}
151
+
152
+ /* Ticker animation */
153
+ @keyframes ticker-scroll {{
154
+ 0% {{ transform: translateX(0); }}
155
+ 100% {{ transform: translateX(-50%); }}
156
+ }}
157
+ .ticker-wrap {{
158
+ overflow: hidden;
159
+ background: {C['surface']};
160
+ border-top: 1px solid {C['border']};
161
+ border-bottom: 1px solid {C['border']};
162
+ padding: 8px 0;
163
+ margin: -1rem -2rem 1.4rem -2rem;
164
+ }}
165
+ .ticker-inner {{
166
+ display: flex;
167
+ animation: ticker-scroll 40s linear infinite;
168
+ width: max-content;
169
+ }}
170
+ .ticker-item {{
171
+ display: inline-flex;
172
+ align-items: center;
173
+ gap: 6px;
174
+ padding: 0 28px;
175
+ white-space: nowrap;
176
+ font-family: 'JetBrains Mono', monospace;
177
+ font-size: 0.78rem;
178
+ border-right: 1px solid {C['border']};
179
+ }}
180
+ .ticker-sep {{
181
+ padding: 0 28px;
182
+ color: {C['text3']};
183
+ font-size: 0.6rem;
184
+ border-right: 1px solid {C['border']};
185
+ }}
186
+
187
+ /* Commodity cards */
188
+ .comm-card {{
189
+ background: {C['surface']};
190
+ border: 1px solid {C['border']};
191
+ border-radius: 12px;
192
+ padding: 16px;
193
+ cursor: pointer;
194
+ transition: all 0.18s ease;
195
+ height: 100%;
196
+ position: relative;
197
+ overflow: hidden;
198
+ }}
199
+ .comm-card::before {{
200
+ content: '';
201
+ position: absolute;
202
+ top: 0; left: 0;
203
+ width: 3px; height: 100%;
204
+ border-radius: 12px 0 0 12px;
205
+ }}
206
+ .comm-card:hover {{
207
+ border-color: {C['border_hi']};
208
+ transform: translateY(-1px);
209
+ box-shadow: 0 8px 24px rgba(0,0,0,0.4);
210
+ }}
211
+ .comm-card.active {{
212
+ border-color: {C['accent']} !important;
213
+ background: linear-gradient(135deg, {C['surface']} 0%, rgba(61,127,255,0.05) 100%);
214
+ }}
215
+ .comm-card.up::before {{ background: {C['up']}; }}
216
+ .comm-card.down::before {{ background: {C['down']}; }}
217
+ .comm-card.stable::before {{ background: {C['stable']}; }}
218
+
219
+ /* Signal pill */
220
+ .signal-pill {{
221
+ display: inline-block;
222
+ padding: 2px 8px;
223
+ border-radius: 20px;
224
+ font-size: 0.68rem;
225
+ font-weight: 600;
226
+ letter-spacing: 0.04em;
227
+ text-transform: uppercase;
228
+ }}
229
+
230
+ /* Macro bar */
231
+ .macro-item {{
232
+ text-align: center;
233
+ padding: 10px 16px;
234
+ background: {C['surface']};
235
+ border: 1px solid {C['border']};
236
+ border-radius: 8px;
237
+ }}
238
+ .macro-label {{ font-size: 0.65rem; color: {C['text3']}; letter-spacing: 0.08em; text-transform: uppercase; margin-bottom: 3px; }}
239
+ .macro-value {{ font-size: 1.05rem; font-weight: 600; font-family: 'JetBrains Mono', monospace; color: {C['text']}; }}
240
+ .macro-change {{ font-size: 0.68rem; margin-top: 2px; }}
241
+
242
+ /* AI report */
243
+ .ai-report {{
244
+ background: linear-gradient(135deg, {C['surface2']} 0%, rgba(61,127,255,0.04) 100%);
245
+ border: 1px solid {C['border']};
246
+ border-left: 3px solid {C['accent']};
247
+ border-radius: 10px;
248
+ padding: 16px 20px;
249
+ line-height: 1.7;
250
+ font-size: 0.9rem;
251
+ color: {C['text']};
252
+ }}
253
+
254
+ /* News row */
255
+ .news-row {{
256
+ padding: 10px 0;
257
+ border-bottom: 1px solid {C['border']};
258
+ display: flex;
259
+ align-items: flex-start;
260
+ gap: 12px;
261
+ }}
262
+
263
+ /* COT bar */
264
+ .cot-label {{ font-size: 0.7rem; color: {C['text2']}; margin-bottom: 4px; }}
265
+ .cot-bar-wrap {{
266
+ height: 6px;
267
+ background: {C['surface2']};
268
+ border-radius: 3px;
269
+ overflow: hidden;
270
+ margin-bottom: 10px;
271
+ }}
272
+
273
+ /* Section header */
274
+ .section-header {{
275
+ display: flex;
276
+ align-items: center;
277
+ gap: 10px;
278
+ margin-bottom: 12px;
279
+ padding-bottom: 8px;
280
+ border-bottom: 1px solid {C['border']};
281
+ }}
282
+ .section-title {{
283
+ font-size: 0.7rem;
284
+ font-weight: 600;
285
+ letter-spacing: 0.12em;
286
+ text-transform: uppercase;
287
+ color: {C['text2']};
288
+ }}
289
+ .section-dot {{ width: 6px; height: 6px; border-radius: 50%; background: {C['accent']}; }}
290
+
291
+ /* Confidence arc */
292
+ .conf-badge {{
293
+ display: inline-flex;
294
+ align-items: center;
295
+ gap: 5px;
296
+ padding: 4px 10px;
297
+ border-radius: 20px;
298
+ font-size: 0.72rem;
299
+ font-weight: 600;
300
+ letter-spacing: 0.05em;
301
+ }}
302
+ </style>
303
+ """, unsafe_allow_html=True)
304
+
305
+
306
+ # ── data loaders ───────────────────────────────────────────────────────────────
307
+
308
+ @st.cache_resource
309
+ def _ensure_schema():
310
+ init_schema()
311
+
312
+ @st.cache_data(ttl=3600)
313
+ def _load_forecast(symbol: str) -> dict:
314
+ return predict(symbol)
315
+
316
+ @st.cache_data(ttl=3600)
317
+ def _load_all_forecasts(symbols: tuple) -> dict:
318
+ return {s: _load_forecast(s) for s in symbols}
319
+
320
+ @st.cache_data(ttl=3600)
321
+ def _load_price_history(symbol: str, days: int = 90) -> pd.DataFrame:
322
+ conn = get_conn()
323
+ df = conn.execute(
324
+ "SELECT date, open, high, low, close FROM prices "
325
+ "WHERE symbol = ? AND date >= ? ORDER BY date",
326
+ [symbol, (date.today() - timedelta(days=days)).isoformat()],
327
+ ).df()
328
+ conn.close()
329
+ return df
330
+
331
+ @st.cache_data(ttl=3600)
332
+ def _load_sentiment_history(symbol: str, days: int = 60) -> pd.DataFrame:
333
+ conn = get_conn()
334
+ df = conn.execute(
335
+ "SELECT date, sentiment_score, article_count FROM sentiment_daily "
336
+ "WHERE commodity = ? AND date >= ? ORDER BY date",
337
+ [symbol, (date.today() - timedelta(days=days)).isoformat()],
338
+ ).df()
339
+ conn.close()
340
+ return df
341
+
342
+ @st.cache_data(ttl=3600)
343
+ def _load_cot_history(symbol: str, weeks: int = 104) -> pd.DataFrame:
344
+ conn = get_conn()
345
+ df = conn.execute(
346
+ "SELECT date, commercial_net_pct, mm_net_pct, open_interest "
347
+ "FROM cot_data WHERE symbol = ? ORDER BY date DESC LIMIT ?",
348
+ [symbol, weeks],
349
+ ).df()
350
+ conn.close()
351
+ return df.sort_values("date").reset_index(drop=True) if not df.empty else df
352
+
353
+ @st.cache_data(ttl=3600)
354
+ def _load_macro_env() -> dict:
355
+ conn = get_conn()
356
+ try:
357
+ row = conn.execute(
358
+ "SELECT dxy, vix, treasury_10y, fedfunds, financial_stress, copper_basis "
359
+ "FROM fred_data WHERE dxy IS NOT NULL ORDER BY date DESC LIMIT 1"
360
+ ).fetchone()
361
+ except Exception:
362
+ row = None
363
+ conn.close()
364
+ if row:
365
+ return {"dxy": row[0], "vix": row[1], "t10y": row[2],
366
+ "fedfunds": row[3], "stress": row[4], "copper_basis": row[5]}
367
+ return {}
368
+
369
+ @st.cache_data(ttl=3600)
370
+ def _load_recent_news(symbol: str, limit: int = 15) -> pd.DataFrame:
371
+ conn = get_conn()
372
+ df = conn.execute(
373
+ "SELECT published_date, title, url, sentiment_score FROM news_raw "
374
+ "WHERE commodity_tags LIKE ? ORDER BY published_date DESC LIMIT ?",
375
+ [f"%{symbol}%", limit],
376
+ ).df()
377
+ conn.close()
378
+ return df
379
+
380
+ @st.cache_data(ttl=3600)
381
+ def _load_weather(symbol: str) -> dict:
382
+ from signals.weather_features import get_weather_features
383
+ return get_weather_features(symbol, days=30)
384
+
385
+ @st.cache_data(ttl=3600)
386
+ def _load_eia_history(series: str, weeks: int = 52) -> pd.DataFrame:
387
+ conn = get_conn()
388
+ df = conn.execute(
389
+ "SELECT date, value, chg_1w, vs_5yr_avg FROM eia_inventory "
390
+ "WHERE series = ? ORDER BY date DESC LIMIT ?",
391
+ [series, weeks],
392
+ ).df()
393
+ conn.close()
394
+ return df.sort_values("date").reset_index(drop=True) if not df.empty else df
395
+
396
+ # ── header ─────────────────────────────────────────────────────────────────────
397
+
398
+ def _render_header():
399
+ now = datetime.now()
400
+ st.markdown(f"""
401
+ <div style="display:flex;align-items:center;justify-content:space-between;
402
+ padding:16px 0 12px 0;border-bottom:1px solid {C['border']};margin-bottom:0;">
403
+ <div style="display:flex;align-items:center;gap:14px;">
404
+ <div style="font-size:1.6rem;font-weight:700;letter-spacing:-0.02em;
405
+ background:linear-gradient(135deg,{C['text']} 0%,{C['accent']} 100%);
406
+ -webkit-background-clip:text;-webkit-text-fill-color:transparent;">
407
+ ◈ CommodiSense
408
+ </div>
409
+ <div style="display:flex;align-items:center;gap:5px;
410
+ background:{C['surface']};border:1px solid {C['border']};
411
+ border-radius:20px;padding:3px 10px;">
412
+ <div style="width:6px;height:6px;border-radius:50%;background:{C['up']};
413
+ box-shadow:0 0 6px {C['up']};animation:pulse 2s infinite;"></div>
414
+ <span style="font-size:0.68rem;color:{C['up']};font-weight:600;letter-spacing:0.06em;">LIVE</span>
415
+ </div>
416
+ </div>
417
+ <div style="text-align:right;">
418
+ <div style="font-size:0.7rem;color:{C['text3']};letter-spacing:0.06em;text-transform:uppercase;">
419
+ Global Commodity Intelligence
420
+ </div>
421
+ <div style="font-size:0.78rem;color:{C['text2']};font-family:'JetBrains Mono',monospace;">
422
+ {now.strftime('%a %d %b %Y %H:%M')} UTC
423
+ </div>
424
+ </div>
425
+ </div>
426
+ <style>
427
+ @keyframes pulse {{
428
+ 0%,100% {{ opacity:1; }} 50% {{ opacity:0.4; }}
429
+ }}
430
+ </style>
431
+ """, unsafe_allow_html=True)
432
+
433
+
434
+ # ── ticker strip ───────────────────────────────────────────────────────────────
435
+
436
+ def _render_ticker(forecasts: dict, horizon_key: str):
437
+ fk = "forecast_7d" if horizon_key == "7d" else "forecast_30d"
438
+ items_html = ""
439
+ for sym in ALL_SYMBOLS:
440
+ fc = forecasts.get(sym, {})
441
+ if "error" in fc or not fc:
442
+ continue
443
+ f = fc.get(fk, {})
444
+ price = fc.get("current_price", 0)
445
+ dir_ = f.get("direction", "STABLE")
446
+ prob = f.get("probability", 0)
447
+ icon = DIR_ICON.get(dir_, "◆")
448
+ col = DIR_COLOR.get(dir_, C["stable"])
449
+ name = SYMBOL_NAMES.get(sym, sym).upper()
450
+ items_html += f"""
451
+ <div class="ticker-item">
452
+ <span style="color:{C['text3']};font-size:0.65rem;">{sym}</span>
453
+ <span style="color:{C['text']};font-weight:500;">{name}</span>
454
+ <span style="color:{C['text2']};">${price:,.2f}</span>
455
+ <span style="color:{col};font-weight:600;">{icon} {prob:.0%}</span>
456
+ </div>"""
457
+
458
+ # Double for seamless loop
459
+ st.markdown(f"""
460
+ <div class="ticker-wrap">
461
+ <div class="ticker-inner">{items_html}{items_html}</div>
462
+ </div>
463
+ """, unsafe_allow_html=True)
464
+
465
+
466
+ # ── macro environment bar ──────────────────────────────────────────────────────
467
+
468
+ def _render_macro_bar():
469
+ macro = _load_macro_env()
470
+ if not macro:
471
+ return
472
+
473
+ def _change_html(val, neutral=0, invert=False, fmt=".2f", suffix=""):
474
+ if val is None:
475
+ return ""
476
+ diff = val - neutral
477
+ if invert:
478
+ diff = -diff
479
+ col = C["up"] if diff > 0 else (C["down"] if diff < 0 else C["stable"])
480
+ sign = "+" if diff > 0 else ""
481
+ return f'<span style="color:{col}">{sign}{diff:{fmt}}{suffix}</span>'
482
+
483
+ vix = macro.get("vix") or 0
484
+ vix_regime = "HIGH FEAR" if vix > 30 else ("CAUTION" if vix > 20 else "CALM")
485
+ vix_col = C["down"] if vix > 30 else (C["gold"] if vix > 20 else C["up"])
486
+
487
+ dxy = macro.get("dxy") or 0
488
+ t10y = macro.get("t10y") or 0
489
+ ff = macro.get("fedfunds") or 0
490
+ yield_inv = t10y < ff
491
+ spread = t10y - ff
492
+
493
+ st.markdown(f"""
494
+ <div style="display:grid;grid-template-columns:repeat(6,1fr);gap:8px;margin-bottom:20px;">
495
+ <div class="macro-item">
496
+ <div class="macro-label">USD Index (DXY)</div>
497
+ <div class="macro-value">{dxy:.1f}</div>
498
+ <div class="macro-change" style="color:{C['text3']}">Broad USD Strength</div>
499
+ </div>
500
+ <div class="macro-item">
501
+ <div class="macro-label">VIX Volatility</div>
502
+ <div class="macro-value" style="color:{vix_col}">{vix:.1f}</div>
503
+ <div class="macro-change" style="color:{vix_col}">{vix_regime}</div>
504
+ </div>
505
+ <div class="macro-item">
506
+ <div class="macro-label">10Y Treasury</div>
507
+ <div class="macro-value">{t10y:.2f}%</div>
508
+ <div class="macro-change" style="color:{C['text3']}">US Yield</div>
509
+ </div>
510
+ <div class="macro-item">
511
+ <div class="macro-label">Fed Funds</div>
512
+ <div class="macro-value">{ff:.2f}%</div>
513
+ <div class="macro-change" style="color:{C['text3']}">Policy Rate</div>
514
+ </div>
515
+ <div class="macro-item">
516
+ <div class="macro-label">Yield Spread</div>
517
+ <div class="macro-value" style="color:{C['down'] if yield_inv else C['up']}">{spread:+.2f}%</div>
518
+ <div class="macro-change" style="color:{C['down'] if yield_inv else C['text3']}">
519
+ {'⚠ INVERTED' if yield_inv else 'Normal'}
520
+ </div>
521
+ </div>
522
+ <div class="macro-item">
523
+ <div class="macro-label">Copper 3M Trend</div>
524
+ <div class="macro-value" style="color:{C['up'] if (macro.get('copper_basis') or 0) > 0 else C['down']}">
525
+ {(macro.get('copper_basis') or 0):+.1f}%
526
+ </div>
527
+ <div class="macro-change" style="color:{C['text3']}">Industrial Demand</div>
528
+ </div>
529
+ </div>
530
+ """, unsafe_allow_html=True)
531
+
532
+
533
+ # ── commodity grid ─────────────────────────────────────────────────────────────
534
+
535
+ def _render_commodity_grid(forecasts: dict, horizon_key: str, active_sym: str) -> str | None:
536
+ fk = "forecast_7d" if horizon_key == "7d" else "forecast_30d"
537
+
538
+ st.markdown(f"""
539
+ <div class="section-header">
540
+ <div class="section-dot"></div>
541
+ <div class="section-title">Market Overview — {horizon_key.upper()} Forecast</div>
542
+ </div>
543
+ """, unsafe_allow_html=True)
544
+
545
+ clicked = None
546
+ rows = [ALL_SYMBOLS[i:i+5] for i in range(0, len(ALL_SYMBOLS), 5)]
547
+
548
+ for row_syms in rows:
549
+ cols = st.columns(len(row_syms))
550
+ for col, sym in zip(cols, row_syms):
551
+ fc = forecasts.get(sym, {})
552
+ f = fc.get(fk, {}) if fc and "error" not in fc else {}
553
+ dir_ = f.get("direction", "STABLE")
554
+ conf = f.get("confidence", "LOW")
555
+ prob = f.get("probability", 0.5)
556
+ price = fc.get("current_price", 0) if fc else 0
557
+ name = SYMBOL_NAMES.get(sym, sym)
558
+ icon = DIR_ICON.get(dir_, "◆")
559
+ dcol = DIR_COLOR.get(dir_, C["stable"])
560
+ ddim = DIR_DICT = DIR_DIM.get(dir_, C["stable_dim"])
561
+ ccol = CONF_COLOR.get(conf, C["conf_low"])
562
+ is_active = sym == active_sym
563
+ warn = fc.get("forecast_7d", {}).get("model_warning") if fc else None
564
+
565
+ with col:
566
+ st.markdown(f"""
567
+ <div class="comm-card {dir_.lower()} {'active' if is_active else ''}"
568
+ style="background:linear-gradient(145deg,{C['surface']} 0%,{ddim} 100%);">
569
+ <div style="display:flex;justify-content:space-between;align-items:flex-start;margin-bottom:8px;">
570
+ <div>
571
+ <div style="font-size:0.65rem;color:{C['text3']};letter-spacing:0.08em;font-family:'JetBrains Mono',monospace;">{sym}</div>
572
+ <div style="font-size:0.88rem;font-weight:600;color:{C['text']};margin-top:1px;">{name}</div>
573
+ </div>
574
+ <div style="background:{ccol}22;border:1px solid {ccol}44;border-radius:4px;
575
+ padding:2px 6px;font-size:0.6rem;font-weight:700;color:{ccol};
576
+ letter-spacing:0.06em;">{conf}</div>
577
+ </div>
578
+ <div style="font-size:1.05rem;font-weight:600;color:{C['text']};
579
+ font-family:'JetBrains Mono',monospace;margin-bottom:6px;">
580
+ ${price:,.2f}
581
+ </div>
582
+ <div style="display:flex;align-items:center;gap:6px;">
583
+ <span style="font-size:1.5rem;color:{dcol};font-weight:700;line-height:1;">{icon}</span>
584
+ <div>
585
+ <div style="font-size:0.82rem;color:{dcol};font-weight:600;">{dir_}</div>
586
+ <div style="font-size:0.65rem;color:{C['text3']};">{prob:.0%} probability</div>
587
+ </div>
588
+ </div>
589
+ {'<div style="margin-top:6px;font-size:0.62rem;color:' + C["gold"] + ';background:' + C["gold"] + '15;border-radius:3px;padding:2px 5px;">⚠ Use 30d model</div>' if warn and horizon_key == "7d" else ''}
590
+ </div>
591
+ """, unsafe_allow_html=True)
592
+
593
+ if st.button("Analyze →", key=f"btn_{sym}", use_container_width=True):
594
+ clicked = sym
595
+
596
+ return clicked
597
+
598
+
599
+ # ── deep dive ──────────────────────────────────────────────────────────────────
600
+
601
+ def _price_chart(symbol: str, days: int, fc: dict, horizon_key: str):
602
+ df = _load_price_history(symbol, days)
603
+ if df.empty:
604
+ st.info("No price history — run the price collector.")
605
+ return
606
+
607
+ fk = "forecast_7d" if horizon_key == "7d" else "forecast_30d"
608
+ fcast = fc.get(fk, {})
609
+ dir_ = fcast.get("direction", "STABLE")
610
+ dcol = DIR_COLOR.get(dir_, C["stable"])
611
+ low = fcast.get("price_range_low")
612
+ high = fcast.get("price_range_high")
613
+
614
+ fig = go.Figure()
615
+
616
+ fig.add_trace(go.Candlestick(
617
+ x=df["date"], open=df["open"], high=df["high"],
618
+ low=df["low"], close=df["close"], name="Price",
619
+ increasing=dict(line=dict(color=C["up"], width=1), fillcolor=C["up_dim"]),
620
+ decreasing=dict(line=dict(color=C["down"], width=1), fillcolor=C["down_dim"]),
621
+ ))
622
+
623
+ # 20-day SMA
624
+ df["sma20"] = df["close"].rolling(20, min_periods=1).mean()
625
+ fig.add_trace(go.Scatter(
626
+ x=df["date"], y=df["sma20"], mode="lines",
627
+ line=dict(color=C["accent"], width=1.2, dash="dot"),
628
+ name="SMA 20", opacity=0.6,
629
+ ))
630
+
631
+ # Forecast zone
632
+ if low and high and not df.empty:
633
+ last_date = pd.to_datetime(df["date"].max())
634
+ fwd = last_date + timedelta(days=7 if horizon_key == "7d" else 30)
635
+ fig.add_shape(type="rect",
636
+ x0=str(last_date.date()), x1=str(fwd.date()),
637
+ y0=low, y1=high,
638
+ fillcolor=dcol, opacity=0.10,
639
+ line=dict(color=dcol, width=1, dash="dot"),
640
+ )
641
+ fig.add_annotation(
642
+ x=str(fwd.date()), y=(low + high) / 2,
643
+ text=f" {DIR_ICON.get(dir_,'')} {dir_} {fcast.get('probability',0):.0%}",
644
+ showarrow=False, font=dict(color=dcol, size=11, family="JetBrains Mono"),
645
+ bgcolor=C["surface2"], bordercolor=dcol,
646
+ )
647
+
648
+ fig.update_layout(
649
+ template="plotly_dark",
650
+ paper_bgcolor=C["bg"], plot_bgcolor=C["bg"],
651
+ xaxis_rangeslider_visible=False,
652
+ height=360,
653
+ margin=dict(l=0, r=0, t=8, b=0),
654
+ legend=dict(orientation="h", x=0, y=1.06, font=dict(size=10, color=C["text2"])),
655
+ xaxis=dict(gridcolor=C["surface2"], showgrid=True),
656
+ yaxis=dict(gridcolor=C["surface2"], showgrid=True),
657
+ font=dict(family="Inter", color=C["text2"]),
658
+ )
659
+ st.plotly_chart(fig, use_container_width=True)
660
+
661
+
662
+ def _shap_chart(fc: dict):
663
+ signals = fc.get("top_signals", [])
664
+ if not signals:
665
+ st.caption("No SHAP signals — retrain models to enable.")
666
+ return
667
+
668
+ labels = [s.get("label", s.get("feature", ""))[:32] for s in signals]
669
+ weights = [s["weight"] if s["impact"] == "BULLISH" else -s["weight"] for s in signals]
670
+ colors = [C["up"] if w > 0 else C["down"] for w in weights]
671
+
672
+ fig = go.Figure(go.Bar(
673
+ x=weights, y=labels, orientation="h",
674
+ marker=dict(color=colors, opacity=0.85),
675
+ text=[f"{'▲' if w>0 else '▼'} {abs(w):.3f}" for w in weights],
676
+ textposition="outside", textfont=dict(size=10, family="JetBrains Mono", color=C["text2"]),
677
+ ))
678
+ fig.update_layout(
679
+ template="plotly_dark",
680
+ paper_bgcolor=C["bg"], plot_bgcolor=C["bg"],
681
+ title=dict(text="Top Signal Drivers (SHAP)", font=dict(size=11, color=C["text2"])),
682
+ xaxis=dict(gridcolor=C["surface2"], zeroline=True, zerolinecolor=C["border_hi"],
683
+ showticklabels=False),
684
+ yaxis=dict(gridcolor="transparent"),
685
+ height=260, margin=dict(l=0, r=40, t=32, b=0),
686
+ showlegend=False,
687
+ )
688
+ st.plotly_chart(fig, use_container_width=True)
689
+
690
+
691
+ def _cot_chart(symbol: str):
692
+ df = _load_cot_history(symbol)
693
+ if df.empty:
694
+ st.caption("No COT data for this symbol.")
695
+ return
696
+
697
+ fig = go.Figure()
698
+ fig.add_trace(go.Scatter(
699
+ x=df["date"], y=df["commercial_net_pct"] * 100,
700
+ mode="lines", fill="tozeroy",
701
+ line=dict(color=C["up"], width=1.5),
702
+ fillcolor="rgba(0,217,126,0.08)",
703
+ name="Commercial (Smart $)",
704
+ ))
705
+ fig.add_trace(go.Scatter(
706
+ x=df["date"], y=df["mm_net_pct"] * 100,
707
+ mode="lines", fill="tozeroy",
708
+ line=dict(color=C["accent"], width=1.5),
709
+ fillcolor="rgba(61,127,255,0.08)",
710
+ name="Managed Money",
711
+ ))
712
+ fig.add_hline(y=0, line_dash="dot", line_color=C["border_hi"], line_width=1)
713
+ fig.update_layout(
714
+ template="plotly_dark",
715
+ paper_bgcolor=C["bg"], plot_bgcolor=C["bg"],
716
+ title=dict(text="COT Positioning — % of Open Interest", font=dict(size=11, color=C["text2"])),
717
+ xaxis=dict(gridcolor=C["surface2"]),
718
+ yaxis=dict(gridcolor=C["surface2"], ticksuffix="%"),
719
+ height=220, margin=dict(l=0, r=0, t=32, b=0),
720
+ legend=dict(orientation="h", x=0, y=1.12, font=dict(size=10)),
721
+ font=dict(family="Inter", color=C["text2"]),
722
+ )
723
+ st.plotly_chart(fig, use_container_width=True)
724
+
725
+
726
+ def _sentiment_chart(symbol: str):
727
+ df = _load_sentiment_history(symbol)
728
+ if df.empty:
729
+ st.caption("No sentiment data — run the NLP processor.")
730
+ return
731
+
732
+ colors = [C["up"] if float(s) > 0.1 else (C["down"] if float(s) < -0.1 else C["stable"])
733
+ for s in df["sentiment_score"].fillna(0)]
734
+
735
+ fig = go.Figure()
736
+ fig.add_hrect(y0=0.1, y1=1, fillcolor=C["up_dim"], opacity=1, line_width=0)
737
+ fig.add_hrect(y0=-1, y1=-0.1, fillcolor=C["down_dim"], opacity=1, line_width=0)
738
+ fig.add_trace(go.Scatter(
739
+ x=df["date"], y=df["sentiment_score"],
740
+ mode="lines+markers",
741
+ line=dict(color=C["text2"], width=1.5),
742
+ marker=dict(color=colors, size=5),
743
+ fill="tozeroy", fillcolor="rgba(139,148,158,0.06)",
744
+ name="Sentiment",
745
+ ))
746
+ fig.add_hline(y=0, line_dash="solid", line_color=C["border_hi"], line_width=1)
747
+ fig.update_layout(
748
+ template="plotly_dark",
749
+ paper_bgcolor=C["bg"], plot_bgcolor=C["bg"],
750
+ title=dict(text="News Sentiment (60-day)", font=dict(size=11, color=C["text2"])),
751
+ yaxis=dict(range=[-1, 1], gridcolor=C["surface2"], tickformat=".1f"),
752
+ xaxis=dict(gridcolor=C["surface2"]),
753
+ height=200, margin=dict(l=0, r=0, t=32, b=0),
754
+ showlegend=False,
755
+ )
756
+ st.plotly_chart(fig, use_container_width=True)
757
+
758
+
759
+ def _eia_chart(symbol: str):
760
+ series = {"CL=F": "crude_stocks", "NG=F": "natgas_storage"}.get(symbol)
761
+ if not series:
762
+ return
763
+ df = _load_eia_history(series)
764
+ if df.empty:
765
+ return
766
+
767
+ label = "Crude Oil Stocks (Mbbls)" if symbol == "CL=F" else "Natural Gas Storage (Bcf)"
768
+ div = 1000 if symbol == "CL=F" else 1
769
+
770
+ fig = go.Figure()
771
+ fig.add_trace(go.Bar(
772
+ x=df["date"], y=df["value"] / div,
773
+ name=label,
774
+ marker=dict(
775
+ color=[C["down_dim"] if (v or 0) > 0 else C["up_dim"] for v in df.get("chg_1w", [])],
776
+ line=dict(width=0),
777
+ ),
778
+ opacity=0.8,
779
+ ))
780
+ fig.add_trace(go.Scatter(
781
+ x=df["date"], y=(df["value"] / div).rolling(4).mean(),
782
+ mode="lines", line=dict(color=C["accent"], width=1.5, dash="dot"),
783
+ name="4-wk avg",
784
+ ))
785
+ fig.update_layout(
786
+ template="plotly_dark",
787
+ paper_bgcolor=C["bg"], plot_bgcolor=C["bg"],
788
+ title=dict(text=label, font=dict(size=11, color=C["text2"])),
789
+ height=200, margin=dict(l=0, r=0, t=32, b=0),
790
+ legend=dict(orientation="h", x=0, y=1.15, font=dict(size=10)),
791
+ xaxis=dict(gridcolor=C["surface2"]),
792
+ yaxis=dict(gridcolor=C["surface2"]),
793
+ )
794
+ st.plotly_chart(fig, use_container_width=True)
795
+
796
+
797
+ def _render_deep_dive(symbol: str, days: int, horizon_key: str):
798
+ fc = _load_forecast(symbol)
799
+ name = SYMBOL_NAMES.get(symbol, symbol)
800
+
801
+ if "error" in fc:
802
+ st.warning(f"No forecast for {name} — run `python model/trainer.py --symbol {symbol}`")
803
+ return
804
+
805
+ fk = "forecast_7d" if horizon_key == "7d" else "forecast_30d"
806
+ fcast = fc.get(fk, {})
807
+ dir_ = fcast.get("direction", "STABLE")
808
+ prob = fcast.get("probability", 0.5)
809
+ conf = fcast.get("confidence", "LOW")
810
+ price = fc.get("current_price", 0)
811
+ dcol = DIR_COLOR.get(dir_, C["stable"])
812
+ ddim = DIR_DIM.get(dir_, C["stable_dim"])
813
+ icon = DIR_ICON.get(dir_, "◆")
814
+ ccol = CONF_COLOR.get(conf, C["conf_low"])
815
+ warn = fcast.get("model_warning")
816
+
817
+ # Breadcrumb + headline
818
+ st.markdown(f"""
819
+ <div style="display:flex;align-items:center;gap:8px;margin-bottom:16px;">
820
+ <div style="font-size:0.7rem;color:{C['text3']};letter-spacing:0.08em;">ANALYSIS</div>
821
+ <div style="font-size:0.7rem;color:{C['text3']};">›</div>
822
+ <div style="font-size:0.85rem;font-weight:600;color:{C['text']};">{name}</div>
823
+ <div style="font-size:0.65rem;color:{C['text3']};font-family:'JetBrains Mono',monospace;">{symbol}</div>
824
+ </div>
825
+ <div style="display:flex;align-items:center;gap:16px;padding:18px 20px;
826
+ background:linear-gradient(135deg,{C['surface']} 0%,{ddim} 100%);
827
+ border:1px solid {dcol}44;border-radius:12px;margin-bottom:16px;">
828
+ <div style="font-size:3rem;color:{dcol};line-height:1;">{icon}</div>
829
+ <div>
830
+ <div style="font-size:1.9rem;font-weight:700;color:{C['text']};font-family:'JetBrains Mono',monospace;">
831
+ ${price:,.2f}
832
+ </div>
833
+ <div style="display:flex;align-items:center;gap:8px;margin-top:4px;">
834
+ <span style="font-size:1.1rem;font-weight:700;color:{dcol};">{dir_}</span>
835
+ <span style="background:{ccol}22;border:1px solid {ccol}55;color:{ccol};
836
+ font-size:0.72rem;font-weight:700;padding:3px 8px;border-radius:20px;">
837
+ {conf} CONF
838
+ </span>
839
+ <span style="font-size:0.85rem;color:{C['text2']};">{prob:.1%} probability · {horizon_key.upper()}</span>
840
+ </div>
841
+ {f'<div style="margin-top:6px;font-size:0.72rem;color:{C["gold"]};background:{C["gold"]}18;padding:4px 10px;border-radius:6px;display:inline-block;">⚠ {warn}</div>' if warn else ''}
842
+ </div>
843
+ {f'''<div style="margin-left:auto;text-align:right;">
844
+ <div style="font-size:0.65rem;color:{C["text3"]};text-transform:uppercase;letter-spacing:0.1em;">Price Target Range</div>
845
+ <div style="font-size:1.1rem;font-weight:600;color:{C["text"]};font-family:'JetBrains Mono',monospace;">
846
+ ${fcast.get("price_range_low",0):,.0f} – ${fcast.get("price_range_high",0):,.0f}
847
+ </div>
848
+ </div>''' if fcast.get("price_range_low") else ''}
849
+ </div>
850
+ """, unsafe_allow_html=True)
851
+
852
+ # Main layout: chart | signals
853
+ chart_col, signal_col = st.columns([3, 2])
854
+
855
+ with chart_col:
856
+ st.markdown(f'<div class="section-header"><div class="section-dot"></div><div class="section-title">Price Chart</div></div>', unsafe_allow_html=True)
857
+ _price_chart(symbol, days, fc, horizon_key)
858
+
859
+ with signal_col:
860
+ st.markdown(f'<div class="section-header"><div class="section-dot"></div><div class="section-title">Signal Drivers</div></div>', unsafe_allow_html=True)
861
+ _shap_chart(fc)
862
+
863
+ # Both 7d and 30d forecast side by side
864
+ f7 = fc.get("forecast_7d", {})
865
+ f30 = fc.get("forecast_30d", {})
866
+ st.markdown(f"""
867
+ <div style="display:grid;grid-template-columns:1fr 1fr;gap:8px;margin-top:8px;">
868
+ <div style="background:{C['surface2']};border:1px solid {C['border']};border-radius:8px;padding:10px;text-align:center;">
869
+ <div style="font-size:0.6rem;color:{C['text3']};letter-spacing:0.1em;text-transform:uppercase;margin-bottom:4px;">7-Day</div>
870
+ <div style="font-size:1.1rem;font-weight:700;color:{DIR_COLOR.get(f7.get('direction','STABLE'),C['stable'])};">
871
+ {DIR_ICON.get(f7.get('direction','STABLE'),'◆')} {f7.get('direction','—')}
872
+ </div>
873
+ <div style="font-size:0.7rem;color:{C['text3']};">{f7.get('probability',0):.0%}</div>
874
+ </div>
875
+ <div style="background:{C['surface2']};border:1px solid {C['border']};border-radius:8px;padding:10px;text-align:center;">
876
+ <div style="font-size:0.6rem;color:{C['text3']};letter-spacing:0.1em;text-transform:uppercase;margin-bottom:4px;">30-Day</div>
877
+ <div style="font-size:1.1rem;font-weight:700;color:{DIR_COLOR.get(f30.get('direction','STABLE'),C['stable'])};">
878
+ {DIR_ICON.get(f30.get('direction','STABLE'),'◆')} {f30.get('direction','—')}
879
+ </div>
880
+ <div style="font-size:0.7rem;color:{C['text3']};">{f30.get('probability',0):.0%}</div>
881
+ </div>
882
+ </div>
883
+ """, unsafe_allow_html=True)
884
+
885
+ # Tabbed data panels
886
+ tab_labels = ["COT Positioning", "Sentiment", "EIA Inventory", "Weather", "AI Report"]
887
+ tabs = st.tabs(tab_labels)
888
+
889
+ with tabs[0]:
890
+ _cot_chart(symbol)
891
+
892
+ with tabs[1]:
893
+ _sentiment_chart(symbol)
894
+
895
+ with tabs[2]:
896
+ _eia_chart(symbol)
897
+ if symbol not in ("CL=F", "NG=F"):
898
+ st.caption("EIA inventory data is available for Crude Oil (CL=F) and Natural Gas (NG=F) only.")
899
+
900
+ with tabs[3]:
901
+ weather = _load_weather(symbol)
902
+ if weather and weather.get("drought_index", 0) > 0:
903
+ w1, w2, w3 = st.columns(3)
904
+ w1.metric("Drought Index", f"{weather['drought_index']:.2f}", help="0=normal, 1=extreme drought")
905
+ w2.metric("Heat Stress Days", weather["heat_stress_days"])
906
+ w3.metric("Precip Anomaly", f"{weather['precip_anomaly_pct']:+.1f}%")
907
+ else:
908
+ st.caption("No weather data available. Weather signals apply to agricultural commodities.")
909
+
910
+ with tabs[4]:
911
+ reports = load_latest_reports()
912
+ report_text = reports.get(symbol, "")
913
+ if not report_text:
914
+ with st.spinner("Generating AI analysis..."):
915
+ report_text = generate_report(fc)
916
+ if report_text:
917
+ st.markdown(f'<div class="ai-report">🤖&nbsp; <strong>AI Analyst</strong><br><br>{report_text}</div>', unsafe_allow_html=True)
918
+ else:
919
+ st.caption("AI report unavailable — set GROQ_API_KEY in your .env file.")
920
+
921
+
922
+ # ── news feed ──────────────────────────────────────────────────────────────────
923
+
924
+ def _render_news(symbol: str):
925
+ st.markdown(f"""
926
+ <div class="section-header" style="margin-top:8px;">
927
+ <div class="section-dot"></div>
928
+ <div class="section-title">Recent News — {SYMBOL_NAMES.get(symbol, symbol)}</div>
929
+ </div>
930
+ """, unsafe_allow_html=True)
931
+
932
+ df = _load_recent_news(symbol)
933
+ if df.empty:
934
+ st.caption("No news data — run the news collector.")
935
+ return
936
+
937
+ for _, row in df.iterrows():
938
+ score = float(row.get("sentiment_score") or 0)
939
+ scol = C["up"] if score > 0.1 else (C["down"] if score < -0.1 else C["stable"])
940
+ sign = "+" if score > 0 else ""
941
+ title = str(row.get("title", ""))[:120]
942
+ url = str(row.get("url", "#"))
943
+ pub = str(row.get("published_date", ""))[:10]
944
+
945
+ st.markdown(f"""
946
+ <div class="news-row">
947
+ <div style="min-width:80px;font-size:0.68rem;color:{C['text3']};
948
+ font-family:'JetBrains Mono',monospace;padding-top:1px;">{pub}</div>
949
+ <div style="min-width:42px;text-align:center;">
950
+ <span style="background:{scol}22;color:{scol};border-radius:4px;
951
+ padding:2px 6px;font-size:0.68rem;font-weight:600;
952
+ font-family:'JetBrains Mono',monospace;">{sign}{score:.2f}</span>
953
+ </div>
954
+ <div style="flex:1;font-size:0.84rem;color:{C['text']};">
955
+ <a href="{url}" target="_blank"
956
+ style="color:{C['text']};text-decoration:none;"
957
+ onmouseover="this.style.color='{C['accent']}'"
958
+ onmouseout="this.style.color='{C['text']}'">{title}</a>
959
+ </div>
960
+ </div>
961
+ """, unsafe_allow_html=True)
962
+
963
+
964
+ # ── sidebar controls ───────────────────────────────────────────────────────────
965
+
966
+ def _render_sidebar() -> tuple[str, int]:
967
+ with st.sidebar:
968
+ st.markdown(f"""
969
+ <div style="padding:12px 0 16px 0;border-bottom:1px solid {C['border']};margin-bottom:16px;">
970
+ <div style="font-size:1.1rem;font-weight:700;
971
+ background:linear-gradient(135deg,{C['text']} 0%,{C['accent']} 100%);
972
+ -webkit-background-clip:text;-webkit-text-fill-color:transparent;">
973
+ ◈ CommodiSense
974
+ </div>
975
+ <div style="font-size:0.65rem;color:{C['text3']};margin-top:3px;letter-spacing:0.06em;">
976
+ COMMODITY INTELLIGENCE
977
+ </div>
978
+ </div>
979
+ """, unsafe_allow_html=True)
980
+
981
+ horizon = st.radio("Forecast Horizon", ["7d", "30d"], index=0,
982
+ format_func=lambda x: "7-Day" if x == "7d" else "30-Day")
983
+
984
+ days = st.slider("Chart History", 30, 365, 90, step=15,
985
+ format="%d days")
986
+
987
+ st.markdown("---")
988
+
989
+ if st.button("↺ Refresh Data", use_container_width=True):
990
+ st.cache_data.clear()
991
+ st.rerun()
992
+
993
+ st.markdown(f"""
994
+ <div style="margin-top:16px;">
995
+ <div style="font-size:0.65rem;color:{C['text3']};letter-spacing:0.08em;
996
+ text-transform:uppercase;margin-bottom:10px;">Data Sources</div>
997
+ """, unsafe_allow_html=True)
998
+
999
+ sources = [
1000
+ ("Prices", "yfinance", "12,613 rows"),
1001
+ ("COT", "CFTC", "8,826 rows"),
1002
+ ("Macro", "FRED", "7,193 rows"),
1003
+ ("EIA", "DOE", "3,134 rows"),
1004
+ ("USDA", "NASS", "1,104 rows"),
1005
+ ("News", "GDELT", "392 articles"),
1006
+ ("Weather", "Open-Meteo", "210 rows"),
1007
+ ]
1008
+ for name, src, count in sources:
1009
+ st.markdown(f"""
1010
+ <div style="display:flex;justify-content:space-between;align-items:center;
1011
+ padding:5px 0;border-bottom:1px solid {C['border']};">
1012
+ <div style="font-size:0.72rem;color:{C['text2']};font-weight:500;">{name}</div>
1013
+ <div style="text-align:right;">
1014
+ <div style="font-size:0.62rem;color:{C['text3']};font-family:'JetBrains Mono',monospace;">{count}</div>
1015
+ </div>
1016
+ </div>
1017
+ """, unsafe_allow_html=True)
1018
+
1019
+ st.markdown("</div>", unsafe_allow_html=True)
1020
+
1021
+ st.markdown(f"""
1022
+ <div style="margin-top:20px;padding:10px;background:{C['surface']};
1023
+ border:1px solid {C['border']};border-radius:8px;font-size:0.65rem;
1024
+ color:{C['text3']};line-height:1.6;">
1025
+ <div style="color:{C['text2']};font-weight:600;margin-bottom:4px;">Pipeline</div>
1026
+ GitHub Actions · Mon–Fri 06:00 UTC<br>
1027
+ XGBoost + LightGBM ensemble<br>
1028
+ SHAP explainability · FinBERT NLP
1029
+ </div>
1030
+ """, unsafe_allow_html=True)
1031
+
1032
+ return horizon, days
1033
+
1034
+
1035
+ # ── main ───────────────────────────────────────────────────────────────────────
1036
+
1037
+ def main():
1038
+ _ensure_schema()
1039
+ _inject_css()
1040
+ _render_header()
1041
+
1042
+ horizon, days = _render_sidebar()
1043
+
1044
+ # Load all forecasts at once
1045
+ forecasts = _load_all_forecasts(tuple(ALL_SYMBOLS))
1046
+
1047
+ # Ticker strip
1048
+ _render_ticker(forecasts, horizon)
1049
+
1050
+ # Macro environment
1051
+ _render_macro_bar()
1052
+
1053
+ # Commodity grid — track active symbol in session state
1054
+ clicked = _render_commodity_grid(forecasts, horizon,
1055
+ st.session_state.get("active_sym", ALL_SYMBOLS[0]))
1056
+
1057
+ if clicked:
1058
+ st.session_state["active_sym"] = clicked
1059
+
1060
+ active = st.session_state.get("active_sym")
1061
+ if not active:
1062
+ active = ALL_SYMBOLS[0]
1063
+ st.session_state["active_sym"] = active
1064
+
1065
+ # Divider
1066
+ st.markdown(f'<div style="height:1px;background:linear-gradient(90deg,transparent,{C["border_hi"]},transparent);margin:20px 0;"></div>', unsafe_allow_html=True)
1067
+
1068
+ # Deep dive
1069
+ _render_deep_dive(active, days, horizon)
1070
+
1071
+ # News
1072
+ st.markdown(f'<div style="height:1px;background:linear-gradient(90deg,transparent,{C["border_hi"]},transparent);margin:20px 0 16px;"></div>', unsafe_allow_html=True)
1073
+ _render_news(active)
1074
+
1075
+
1076
+ if __name__ == "__main__":
1077
+ main()
model/__init__.py ADDED
File without changes
model/explainer.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Explainer — generates plain-English 3-sentence forecast reports using
3
+ Groq API (llama-3.3-70b-versatile, free tier: 14,400 req/day).
4
+ Falls back to a deterministic template if Groq is unavailable.
5
+
6
+ Reports are cached to data/reports/report_{date}.json.
7
+
8
+ Usage:
9
+ python model/explainer.py --symbol ZW=F
10
+ python model/explainer.py --all
11
+ """
12
+
13
+ import json
14
+ import logging
15
+ import os
16
+ import sys
17
+ from datetime import date
18
+ from pathlib import Path
19
+
20
+ sys.path.insert(0, str(Path(__file__).parent.parent))
21
+ from model.predictor import predict, predict_all, SYMBOL_NAMES
22
+
23
+ log = logging.getLogger(__name__)
24
+
25
+ REPORTS_DIR = Path(__file__).parent.parent / "data" / "reports"
26
+ REPORTS_DIR.mkdir(parents=True, exist_ok=True)
27
+
28
+ GROQ_MODEL = "llama-3.3-70b-versatile"
29
+
30
+ # ── Groq client (lazy) ─────────────────────────────────────────────────────────
31
+
32
+ _groq_client = None
33
+
34
+
35
+ def _get_groq_client():
36
+ global _groq_client
37
+ if _groq_client is not None:
38
+ return _groq_client
39
+ api_key = os.getenv("GROQ_API_KEY")
40
+ if not api_key:
41
+ log.warning("GROQ_API_KEY not set — using template fallback")
42
+ return None
43
+ try:
44
+ from groq import Groq
45
+ _groq_client = Groq(api_key=api_key)
46
+ return _groq_client
47
+ except ImportError:
48
+ log.warning("groq package not installed — using template fallback")
49
+ return None
50
+
51
+
52
+ # ── helpers ────────────────────────────────────────────────────────────────────
53
+
54
+
55
+ def _format_signals(signals: list[dict]) -> str:
56
+ """Format top signals as numbered list for the LLM prompt."""
57
+ lines = []
58
+ for i, sig in enumerate(signals[:5], 1):
59
+ label = sig.get("label", sig.get("feature", "unknown"))
60
+ value = sig.get("value", 0)
61
+ impact = sig.get("impact", "NEUTRAL")
62
+ weight = sig.get("weight", 0)
63
+ lines.append(f" {i}. {label}: {value:.3g} | Impact: {impact} | Weight: {weight:.3f}")
64
+ return "\n".join(lines) if lines else " (no signal data available)"
65
+
66
+
67
+ def _pick_risk_factor(prediction: dict) -> str:
68
+ """Return the top bearish signal as the risk factor for the report."""
69
+ signals = prediction.get("top_signals", [])
70
+ bearish = [s for s in signals if s.get("impact") == "BEARISH"]
71
+ if bearish:
72
+ return bearish[0].get("label", "adverse signal reversal")
73
+ # Generic risks per commodity type
74
+ symbol = prediction.get("symbol", "")
75
+ risk_map = {
76
+ "CL=F": "unexpected OPEC output increase",
77
+ "NG=F": "warmer-than-expected winter forecast",
78
+ "GC=F": "stronger-than-expected US jobs data",
79
+ "ZW=F": "favourable Black Sea weather reducing supply fears",
80
+ "ZC=F": "USDA upward crop estimate revision",
81
+ "ZS=F": "Brazil harvest beating expectations",
82
+ "CT=F": "recovery in monsoon rainfall",
83
+ "SB=F": "Brazil supply-side recovery",
84
+ "USDINR=X":"RBI unexpected rate cut",
85
+ "HG=F": "China demand slowdown data",
86
+ }
87
+ return risk_map.get(symbol, "unexpected policy reversal")
88
+
89
+
90
+ def _template_report(prediction: dict) -> str:
91
+ """
92
+ Deterministic template-based report. Used when Groq is unavailable.
93
+ No LLM needed — readable and fast.
94
+ """
95
+ name = prediction.get("commodity_name", prediction.get("symbol", "Commodity"))
96
+ fc7 = prediction.get("forecast_7d", {})
97
+ fc30 = prediction.get("forecast_30d", {})
98
+ direction= fc7.get("direction", "STABLE")
99
+ prob = fc7.get("probability", 0.5)
100
+ conf = fc7.get("confidence", "LOW")
101
+ dir30 = fc30.get("direction", "STABLE")
102
+ prob30 = fc30.get("probability", 0.5)
103
+ signals = prediction.get("top_signals", [])
104
+
105
+ sig1 = signals[0] if len(signals) > 0 else {}
106
+ sig2 = signals[1] if len(signals) > 1 else {}
107
+ s1_label = sig1.get("label", "market momentum")
108
+ s1_val = sig1.get("value", 0)
109
+ s2_label = sig2.get("label", "news sentiment")
110
+ s2_val = sig2.get("value", 0)
111
+ risk = _pick_risk_factor(prediction)
112
+
113
+ dir_phrase = {
114
+ "UP": "rise",
115
+ "DOWN": "fall",
116
+ "STABLE": "remain stable",
117
+ }.get(direction, "remain stable")
118
+
119
+ return (
120
+ f"{name} is forecast to {dir_phrase} over the next 7 days "
121
+ f"({prob:.0%} confidence, {conf}); 30-day view is {dir30} ({prob30:.0%}). "
122
+ f"Primary drivers are {s1_label} at {s1_val:.3g} and "
123
+ f"{s2_label} at {s2_val:.3g}. "
124
+ f"Key risk: {risk} could invalidate this forecast."
125
+ )
126
+
127
+
128
+ def _groq_report(prediction: dict) -> str:
129
+ """Call Groq API to generate a 3-sentence analyst report."""
130
+ client = _get_groq_client()
131
+ if client is None:
132
+ return _template_report(prediction)
133
+
134
+ name = prediction.get("commodity_name", prediction.get("symbol"))
135
+ price = prediction.get("current_price", 0)
136
+ fc7 = prediction.get("forecast_7d", {})
137
+ fc30 = prediction.get("forecast_30d", {})
138
+ signals = prediction.get("top_signals", [])
139
+
140
+ prompt = f"""You are a commodity market analyst. Based on the following data signals, write a 3-sentence forecast report. Be specific. Cite the signals. Use numbers.
141
+
142
+ Commodity: {name}
143
+ Current price: {price}
144
+ 7-day forecast: {fc7.get('direction')} with {fc7.get('probability', 0):.0%} confidence ({fc7.get('confidence')} tier)
145
+ 30-day forecast: {fc30.get('direction')} with {fc30.get('probability', 0):.0%} confidence
146
+
147
+ Top 5 driving signals:
148
+ {_format_signals(signals)}
149
+
150
+ Rules:
151
+ - Sentence 1: State the forecast and confidence level.
152
+ - Sentence 2: Name the top 2 signals and their specific values.
153
+ - Sentence 3: Name one risk factor that could invalidate this forecast.
154
+ - Write in plain English. No jargon. Max 80 words total.
155
+ - Do not use phrases like "based on the data" or "analysis suggests".
156
+ - Start directly: "{name} is forecast to..."
157
+ """
158
+
159
+ try:
160
+ response = client.chat.completions.create(
161
+ model=GROQ_MODEL,
162
+ messages=[{"role": "user", "content": prompt}],
163
+ max_tokens=150,
164
+ temperature=0.3,
165
+ )
166
+ text = response.choices[0].message.content.strip()
167
+ # Sanity: if response is empty or too short, fall back to template
168
+ if len(text) < 30:
169
+ return _template_report(prediction)
170
+ return text
171
+ except Exception as exc:
172
+ log.warning("Groq API error: %s — using template fallback", exc)
173
+ return _template_report(prediction)
174
+
175
+
176
+ # ── public API ─────────────────────────────────────────────────────────────────
177
+
178
+
179
+ def generate_report(prediction: dict) -> str:
180
+ """
181
+ Generate a plain-English 3-sentence forecast report for a commodity.
182
+
183
+ Uses Groq if GROQ_API_KEY is set, otherwise falls back to template.
184
+
185
+ Args:
186
+ prediction: Dict returned by predictor.predict()
187
+
188
+ Returns:
189
+ 3-sentence report string.
190
+ """
191
+ if "error" in prediction:
192
+ return f"{prediction.get('symbol', 'Commodity')}: forecast unavailable ({prediction['error']})."
193
+
194
+ return _groq_report(prediction)
195
+
196
+
197
+ def generate_all_reports(as_of_date: str = None) -> dict[str, str]:
198
+ """
199
+ Generate reports for all 10 commodities.
200
+ Calls predict() + generate_report() for each.
201
+ Caches results to data/reports/report_{date}.json.
202
+
203
+ Args:
204
+ as_of_date: ISO date string. Defaults to today.
205
+
206
+ Returns:
207
+ Dict mapping symbol → report string.
208
+ """
209
+ today = as_of_date or date.today().isoformat()
210
+ cache_path = REPORTS_DIR / f"report_{today}.json"
211
+
212
+ # Return cached reports if already generated today
213
+ if cache_path.exists():
214
+ log.info("Loading cached reports from %s", cache_path)
215
+ with open(cache_path) as f:
216
+ return json.load(f)
217
+
218
+ forecasts = predict_all(as_of_date)
219
+ reports: dict[str, str] = {}
220
+
221
+ for symbol, fc in forecasts.items():
222
+ report = generate_report(fc)
223
+ reports[symbol] = report
224
+ name = SYMBOL_NAMES.get(symbol, symbol)
225
+ log.info("%s: report generated", name)
226
+
227
+ # Cache to disk
228
+ with open(cache_path, "w") as f:
229
+ json.dump(reports, f, indent=2)
230
+ log.info("Reports saved to %s", cache_path)
231
+
232
+ return reports
233
+
234
+
235
+ def load_latest_reports() -> dict[str, str]:
236
+ """
237
+ Return the most recently generated report file, or empty dict if none.
238
+ Used by the dashboard to display reports without regenerating.
239
+ """
240
+ report_files = sorted(REPORTS_DIR.glob("report_*.json"), reverse=True)
241
+ if not report_files:
242
+ return {}
243
+ with open(report_files[0]) as f:
244
+ return json.load(f)
245
+
246
+
247
+ if __name__ == "__main__":
248
+ import argparse
249
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
250
+
251
+ parser = argparse.ArgumentParser(description="CommodiSense explainer")
252
+ parser.add_argument("--symbol", default=None)
253
+ parser.add_argument("--all", action="store_true")
254
+ parser.add_argument("--date", default=None)
255
+ args = parser.parse_args()
256
+
257
+ if args.all:
258
+ reports = generate_all_reports(args.date)
259
+ for sym, report in reports.items():
260
+ print(f"\n[{sym}]\n{report}")
261
+ elif args.symbol:
262
+ fc = predict(args.symbol, args.date)
263
+ report = generate_report(fc)
264
+ print(report)
265
+ else:
266
+ parser.print_help()
model/feature_builder.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feature Builder — assembles all signals (price, sentiment, events, weather,
3
+ geopolitical) into a single feature matrix per commodity.
4
+
5
+ CRITICAL: zero lookahead. All signal windows use T-1 to T-N only.
6
+ Target variable uses T+7 and T+30 prices (shifted forward, excluded from features).
7
+
8
+ Usage:
9
+ from model.feature_builder import build_training_data, build_prediction_features
10
+ """
11
+
12
+ import logging
13
+ import sys
14
+ from datetime import date, datetime, timedelta
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ sys.path.insert(0, str(Path(__file__).parent.parent))
21
+ from data.db import get_conn
22
+ from signals.price_features import build_feature_matrix, ALL_SYMBOLS
23
+ from signals.weather_features import get_weather_dataframe
24
+ from signals.macro_features import build_macro_dataframe, get_macro_features
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+ # Per-commodity direction thresholds — calibrated to each asset's typical volatility.
29
+ # USDINR is a managed float (rarely moves ±2% in 7 days → extreme STABLE imbalance).
30
+ # NG=F is highly volatile → needs wider threshold to avoid noise.
31
+ DIRECTION_THRESHOLDS: dict[str, float] = {
32
+ "CL=F": 2.0,
33
+ "NG=F": 3.5,
34
+ "GC=F": 1.5,
35
+ "ZW=F": 2.0,
36
+ "ZC=F": 2.0,
37
+ "ZS=F": 2.0,
38
+ "CT=F": 2.0,
39
+ "SB=F": 2.0,
40
+ "USDINR=X": 0.4,
41
+ "HG=F": 2.0,
42
+ }
43
+ DIRECTION_THRESHOLD_PCT = 2.0 # fallback
44
+
45
+ # ── helpers ────────────────────────────────────────────────────────────────────
46
+
47
+
48
+ def _load_prices_for_target(symbol: str) -> pd.DataFrame:
49
+ """Load close prices with enough future rows to compute T+7 and T+30 targets."""
50
+ conn = get_conn()
51
+ df = conn.execute(
52
+ "SELECT date, close FROM prices WHERE symbol = ? ORDER BY date",
53
+ [symbol],
54
+ ).df()
55
+ conn.close()
56
+ df["date"] = pd.to_datetime(df["date"]).dt.date
57
+ return df.sort_values("date").reset_index(drop=True)
58
+
59
+
60
+ def _compute_targets(price_df: pd.DataFrame, symbol: str = None) -> pd.DataFrame:
61
+ """
62
+ Compute direction_7d and direction_30d target columns.
63
+
64
+ Labels:
65
+ 1 (UP) if future price > current * 1.02
66
+ 0 (STABLE) if within ±2%
67
+ -1 (DOWN) if future price < current * 0.98
68
+ """
69
+ df = price_df.copy().sort_values("date").reset_index(drop=True)
70
+ closes = df["close"].values
71
+ threshold = DIRECTION_THRESHOLDS.get(symbol, DIRECTION_THRESHOLD_PCT) if symbol else DIRECTION_THRESHOLD_PCT
72
+
73
+ def _direction(current: float, future: float) -> int:
74
+ if future == 0 or current == 0:
75
+ return 0
76
+ chg = (future - current) / current * 100
77
+ if chg > threshold:
78
+ return 1
79
+ if chg < -threshold:
80
+ return -1
81
+ return 0
82
+
83
+ dir_7d, dir_30d = [], []
84
+ n = len(closes)
85
+ for i in range(n):
86
+ # Find the index approximately 7 / 30 trading days forward
87
+ # Use calendar-day shifted date to find the nearest actual price row
88
+ fwd7 = df[df["date"] >= (df.at[i, "date"] + timedelta(days=7))].head(1)
89
+ fwd30 = df[df["date"] >= (df.at[i, "date"] + timedelta(days=30))].head(1)
90
+
91
+ dir_7d.append(
92
+ _direction(closes[i], float(fwd7["close"].values[0])) if not fwd7.empty else None
93
+ )
94
+ dir_30d.append(
95
+ _direction(closes[i], float(fwd30["close"].values[0])) if not fwd30.empty else None
96
+ )
97
+
98
+ df["direction_7d"] = dir_7d
99
+ df["direction_30d"] = dir_30d
100
+ return df
101
+
102
+
103
+ def _load_sentiment_series(symbol: str) -> pd.DataFrame:
104
+ """Load daily sentiment aggregates for a commodity from DuckDB."""
105
+ conn = get_conn()
106
+ df = conn.execute(
107
+ """
108
+ SELECT date, sentiment_score, article_count, positive_count
109
+ FROM sentiment_daily
110
+ WHERE commodity = ?
111
+ ORDER BY date
112
+ """,
113
+ [symbol],
114
+ ).df()
115
+ conn.close()
116
+ if df.empty:
117
+ return df
118
+ df["date"] = pd.to_datetime(df["date"]).dt.date
119
+ df = df.sort_values("date").reset_index(drop=True)
120
+ # Rolling aggregates
121
+ df["sentiment_3d"] = df["sentiment_score"].rolling(3, min_periods=1).mean()
122
+ df["sentiment_7d"] = df["sentiment_score"].rolling(7, min_periods=1).mean()
123
+ df["article_count_7d"] = df["article_count"].rolling(7, min_periods=1).sum()
124
+ df["positive_ratio_7d"] = (
125
+ df["positive_count"].rolling(7, min_periods=1).sum()
126
+ / df["article_count_7d"].replace(0, 1)
127
+ )
128
+ return df.rename(columns={"sentiment_score": "sentiment_score_1d"})
129
+
130
+
131
+ def _load_event_series(symbol: str) -> pd.DataFrame:
132
+ """Load daily event aggregates for a commodity from DuckDB."""
133
+ conn = get_conn()
134
+ df = conn.execute(
135
+ """
136
+ SELECT date, event_type, direction, severity
137
+ FROM extracted_events
138
+ WHERE commodity = ?
139
+ ORDER BY date
140
+ """,
141
+ [symbol],
142
+ ).df()
143
+ conn.close()
144
+ if df.empty:
145
+ return pd.DataFrame()
146
+
147
+ df["date"] = pd.to_datetime(df["date"]).dt.date
148
+ df["dir_score"] = df["direction"].map({"BULLISH": 1, "BEARISH": -1, "NEUTRAL": 0}).fillna(0)
149
+
150
+ agg = df.groupby("date").agg(
151
+ bullish_events_7d =("direction", lambda x: int((x == "BULLISH").sum())),
152
+ bearish_events_7d =("direction", lambda x: int((x == "BEARISH").sum())),
153
+ max_severity_7d =("severity", "max"),
154
+ direction_score_7d =("dir_score", "sum"),
155
+ supply_shock_flag =("event_type", lambda x: int((x == "SUPPLY_SHOCK").any())),
156
+ policy_change_flag =("event_type", lambda x: int((x == "POLICY_CHANGE").any())),
157
+ ).reset_index()
158
+
159
+ # Rolling 7-day window for event counts
160
+ agg = agg.sort_values("date").reset_index(drop=True)
161
+ for col in ["bullish_events_7d", "bearish_events_7d", "direction_score_7d"]:
162
+ agg[col] = agg[col].rolling(7, min_periods=1).sum()
163
+ return agg
164
+
165
+
166
+ def _load_geo_series(symbol: str) -> pd.DataFrame:
167
+ """Load rolling geopolitical risk scores for a commodity."""
168
+ conn = get_conn()
169
+ df = conn.execute(
170
+ "SELECT date, risk_score FROM geopolitical_events WHERE commodity = ? ORDER BY date",
171
+ [symbol],
172
+ ).df()
173
+ conn.close()
174
+ if df.empty:
175
+ return pd.DataFrame()
176
+ df["date"] = pd.to_datetime(df["date"]).dt.date
177
+ agg = df.groupby("date")["risk_score"].mean().reset_index()
178
+ agg = agg.sort_values("date").reset_index(drop=True)
179
+ agg["risk_score_7d"] = agg["risk_score"].rolling(7, min_periods=1).mean()
180
+ agg["risk_score_30d"] = agg["risk_score"].rolling(30, min_periods=1).mean()
181
+ return agg[["date", "risk_score_7d", "risk_score_30d"]]
182
+
183
+
184
+ def _safe_merge(base: pd.DataFrame, other: pd.DataFrame, on: str = "date") -> pd.DataFrame:
185
+ """Left-join `other` onto `base`, filling NaN with 0."""
186
+ if other.empty:
187
+ return base
188
+ merged = base.merge(other, on=on, how="left")
189
+ merged = merged.fillna(0)
190
+ return merged
191
+
192
+
193
+ # ── public API ─────────────────────────────────────────────────────────────────
194
+
195
+
196
+ def build_training_data(
197
+ symbol: str,
198
+ ) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
199
+ """
200
+ Assemble the full feature matrix + targets for a commodity.
201
+
202
+ Uses all available history in DuckDB. No lookahead: signal features
203
+ reflect data known at close of each trading day.
204
+
205
+ Args:
206
+ symbol: Commodity ticker, e.g. "ZW=F"
207
+
208
+ Returns:
209
+ (X, y_7d, y_30d) where:
210
+ X — DataFrame, one row per date, all feature columns
211
+ y_7d — Series of direction labels {-1, 0, 1} for 7-day horizon
212
+ y_30d — Series of direction labels {-1, 0, 1} for 30-day horizon
213
+ """
214
+ # Price features (covers ~5yr history)
215
+ end_date = date.today().isoformat()
216
+ start_date = (date.today() - timedelta(days=365 * 5)).isoformat()
217
+ price_feat = build_feature_matrix(symbol, start_date, end_date)
218
+ if price_feat.empty:
219
+ log.warning("%s: no price features available", symbol)
220
+ return pd.DataFrame(), pd.Series(dtype=int), pd.Series(dtype=int)
221
+
222
+ # Targets — computed from raw close prices with per-commodity threshold
223
+ prices = _load_prices_for_target(symbol)
224
+ targets = _compute_targets(prices, symbol=symbol)[["date", "direction_7d", "direction_30d"]]
225
+
226
+ # All signal series
227
+ sentiment = _load_sentiment_series(symbol)
228
+ events = _load_event_series(symbol)
229
+ geo = _load_geo_series(symbol)
230
+ weather = get_weather_dataframe(symbol, days=365 * 5)
231
+ if not weather.empty:
232
+ weather["date"] = pd.to_datetime(weather["date"]).dt.date
233
+
234
+ macro = build_macro_dataframe(symbol, start_date, end_date)
235
+ if not macro.empty:
236
+ macro["date"] = pd.to_datetime(macro["date"]).dt.date
237
+
238
+ # Merge everything onto price_feat (left join → zero-fill missing signal days)
239
+ df = price_feat.copy()
240
+ df = _safe_merge(df, targets, on="date")
241
+ df = _safe_merge(df, sentiment[["date", "sentiment_score_1d", "sentiment_3d",
242
+ "sentiment_7d", "article_count_7d",
243
+ "positive_ratio_7d"]] if not sentiment.empty else pd.DataFrame(),
244
+ on="date")
245
+ df = _safe_merge(df, events, on="date")
246
+ df = _safe_merge(df, geo, on="date")
247
+ df = _safe_merge(df, weather, on="date")
248
+ df = _safe_merge(df, macro, on="date")
249
+
250
+ # Add binary indicator: 1 on days where we have real news signal, 0 elsewhere.
251
+ # This lets the model learn "trust sentiment when has_news_signal=1" rather than
252
+ # treating zero-padded sentiment rows as neutral-sentiment days.
253
+ if "sentiment_score_1d" in df.columns:
254
+ df["has_news_signal"] = (df["sentiment_score_1d"].abs() > 0.01).astype(int)
255
+ else:
256
+ df["has_news_signal"] = 0
257
+
258
+ # Drop rows where targets are unavailable (last 30 days have no T+30 target)
259
+ df = df.dropna(subset=["direction_7d", "direction_30d"])
260
+ df = df.sort_values("date").reset_index(drop=True)
261
+
262
+ feature_cols = [c for c in df.columns if c not in
263
+ ("date", "direction_7d", "direction_30d")]
264
+
265
+ X = df[feature_cols].fillna(0).astype(float)
266
+ y_7d = df["direction_7d"].astype(int)
267
+ y_30d = df["direction_30d"].astype(int)
268
+
269
+ log.info("%s: training data shape %s, class dist 7d: %s",
270
+ symbol, X.shape, y_7d.value_counts().to_dict())
271
+ return X, y_7d, y_30d
272
+
273
+
274
+ def build_prediction_features(symbol: str, as_of_date: str = None) -> pd.Series:
275
+ """
276
+ Build a single-row feature vector for inference.
277
+
278
+ Uses only data available up to (and including) as_of_date.
279
+ No future data touches this vector.
280
+
281
+ Args:
282
+ symbol: Commodity ticker
283
+ as_of_date: ISO date string. Defaults to today.
284
+
285
+ Returns:
286
+ pd.Series with the same feature names as build_training_data returns.
287
+ """
288
+ from signals.price_features import get_price_features
289
+ from signals.weather_features import get_weather_features
290
+
291
+ target_date = as_of_date or date.today().isoformat()
292
+
293
+ # Price features (T-1 based internally)
294
+ price_f = get_price_features(symbol, target_date)
295
+
296
+ # Sentiment: last 7 days before target_date
297
+ cutoff = (datetime.strptime(target_date, "%Y-%m-%d").date() - timedelta(days=7)).isoformat()
298
+ conn = get_conn()
299
+ sent_rows = conn.execute(
300
+ """
301
+ SELECT date, sentiment_score, article_count, positive_count
302
+ FROM sentiment_daily
303
+ WHERE commodity = ? AND date >= ? AND date <= ?
304
+ ORDER BY date DESC
305
+ """,
306
+ [symbol, cutoff, target_date],
307
+ ).df()
308
+ conn.close()
309
+
310
+ sentiment_1d = float(sent_rows.iloc[0]["sentiment_score"]) if not sent_rows.empty else 0.0
311
+ sentiment_3d = float(sent_rows.head(3)["sentiment_score"].mean()) if len(sent_rows) >= 1 else 0.0
312
+ sentiment_7d = float(sent_rows["sentiment_score"].mean()) if not sent_rows.empty else 0.0
313
+ article_count_7d = int(sent_rows["article_count"].sum()) if not sent_rows.empty else 0
314
+ positive_ratio_7d = (
315
+ float(sent_rows["positive_count"].sum() / max(article_count_7d, 1))
316
+ if not sent_rows.empty else 0.0
317
+ )
318
+
319
+ # Events: last 7 days
320
+ conn = get_conn()
321
+ evt_rows = conn.execute(
322
+ """
323
+ SELECT event_type, direction, severity
324
+ FROM extracted_events
325
+ WHERE commodity = ? AND date >= ? AND date <= ?
326
+ """,
327
+ [symbol, cutoff, target_date],
328
+ ).df()
329
+ conn.close()
330
+
331
+ bullish_events_7d = int((evt_rows["direction"] == "BULLISH").sum()) if not evt_rows.empty else 0
332
+ bearish_events_7d = int((evt_rows["direction"] == "BEARISH").sum()) if not evt_rows.empty else 0
333
+ max_severity_7d = int(evt_rows["severity"].max()) if not evt_rows.empty else 0
334
+ dir_map = {"BULLISH": 1, "BEARISH": -1, "NEUTRAL": 0}
335
+ direction_score_7d = int(evt_rows["direction"].map(dir_map).fillna(0).sum()) if not evt_rows.empty else 0
336
+ supply_shock_flag = int((evt_rows["event_type"] == "SUPPLY_SHOCK").any()) if not evt_rows.empty else 0
337
+ policy_change_flag = int((evt_rows["event_type"] == "POLICY_CHANGE").any()) if not evt_rows.empty else 0
338
+
339
+ # Geopolitical risk
340
+ cutoff_30 = (datetime.strptime(target_date, "%Y-%m-%d").date() - timedelta(days=30)).isoformat()
341
+ conn = get_conn()
342
+ geo_rows = conn.execute(
343
+ "SELECT risk_score FROM geopolitical_events WHERE commodity = ? AND date >= ? AND date <= ?",
344
+ [symbol, cutoff_30, target_date],
345
+ ).df()
346
+ conn.close()
347
+ risk_score_7d = float(geo_rows.tail(7)["risk_score"].mean()) if not geo_rows.empty else 0.05
348
+ risk_score_30d = float(geo_rows["risk_score"].mean()) if not geo_rows.empty else 0.05
349
+
350
+ # Weather
351
+ weather_f = get_weather_features(symbol, days=90)
352
+
353
+ macro_f = get_macro_features(symbol, target_date)
354
+
355
+ features = {
356
+ **price_f,
357
+ "sentiment_score_1d": sentiment_1d,
358
+ "sentiment_3d": sentiment_3d,
359
+ "sentiment_7d": sentiment_7d,
360
+ "article_count_7d": article_count_7d,
361
+ "positive_ratio_7d": positive_ratio_7d,
362
+ "bullish_events_7d": bullish_events_7d,
363
+ "bearish_events_7d": bearish_events_7d,
364
+ "max_severity_7d": max_severity_7d,
365
+ "direction_score_7d": direction_score_7d,
366
+ "supply_shock_flag": supply_shock_flag,
367
+ "policy_change_flag": policy_change_flag,
368
+ "risk_score_7d": risk_score_7d,
369
+ "risk_score_30d": risk_score_30d,
370
+ **weather_f,
371
+ **macro_f,
372
+ }
373
+
374
+ return pd.Series(features)
model/predictor.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Predictor — loads saved XGBoost + LightGBM models and generates forecasts
3
+ at inference time. Runs entirely on CPU.
4
+
5
+ Usage:
6
+ python model/predictor.py --symbol ZW=F
7
+ python model/predictor.py --all
8
+ """
9
+
10
+ import argparse
11
+ import json
12
+ import logging
13
+ import pickle
14
+ import sys
15
+ from datetime import date, datetime, timedelta
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ sys.path.insert(0, str(Path(__file__).parent.parent))
22
+ from model.feature_builder import build_prediction_features
23
+ from data.db import get_conn
24
+
25
+ log = logging.getLogger(__name__)
26
+
27
+ MODELS_DIR = Path(__file__).parent.parent / "models"
28
+
29
+ SYMBOL_NAMES: dict[str, str] = {
30
+ "CL=F": "Crude Oil",
31
+ "NG=F": "Natural Gas",
32
+ "GC=F": "Gold",
33
+ "ZW=F": "Wheat",
34
+ "ZC=F": "Corn",
35
+ "ZS=F": "Soybeans",
36
+ "CT=F": "Cotton",
37
+ "SB=F": "Sugar",
38
+ "USDINR=X":"USD/INR",
39
+ "HG=F": "Copper",
40
+ }
41
+
42
+ # Human-readable labels for SHAP feature display
43
+ FEATURE_LABELS: dict[str, str] = {
44
+ "rsi_14": "RSI (14-day)",
45
+ "macd_signal": "MACD crossover",
46
+ "bb_position": "Bollinger Band position",
47
+ "atr_14": "Average True Range",
48
+ "atr_pct": "Volatility %",
49
+ "sma_20_50_cross": "SMA 20/50 crossover",
50
+ "return_1d": "1-day return %",
51
+ "return_7d": "7-day return %",
52
+ "return_30d": "30-day return %",
53
+ "momentum_score": "Momentum score",
54
+ "month_sin": "Seasonal cycle (sin)",
55
+ "month_cos": "Seasonal cycle (cos)",
56
+ "harvest_season_flag": "Harvest season",
57
+ "days_to_opec_meeting":"Days to OPEC meeting",
58
+ "oil_gold_ratio": "Oil/Gold ratio",
59
+ "dxy_proxy": "USD strength proxy",
60
+ "sentiment_score_1d": "News sentiment (24h)",
61
+ "sentiment_3d": "News sentiment (3-day)",
62
+ "sentiment_7d": "News sentiment (7-day)",
63
+ "article_count_7d": "Article volume (7-day)",
64
+ "positive_ratio_7d": "Positive news ratio",
65
+ "bullish_events_7d": "Bullish events (7-day)",
66
+ "bearish_events_7d": "Bearish events (7-day)",
67
+ "max_severity_7d": "Max event severity",
68
+ "direction_score_7d": "Net event direction",
69
+ "supply_shock_flag": "Supply shock detected",
70
+ "policy_change_flag": "Policy change detected",
71
+ "risk_score_7d": "Geopolitical risk (7-day)",
72
+ "risk_score_30d": "Geopolitical risk (30-day)",
73
+ "drought_index": "Drought index",
74
+ "heat_stress_days": "Heat stress days",
75
+ "precip_anomaly_pct": "Precipitation anomaly %",
76
+ }
77
+
78
+ # Expected return by predicted direction (base, adjusted per-commodity)
79
+ DIRECTION_EXPECTED_RETURN: dict[str, float] = {
80
+ "UP": 3.0,
81
+ "STABLE": 0.0,
82
+ "DOWN": -3.0,
83
+ }
84
+
85
+ # ── model cache (loaded once per process) ─────────────────────────────────────
86
+
87
+ _model_cache: dict[str, dict] = {}
88
+
89
+
90
+ def _load_models(symbol: str, horizon: str = "7d") -> dict | None:
91
+ """
92
+ Load XGBoost, LightGBM, scaler, and feature names for a symbol.
93
+ Caches in memory for the process lifetime.
94
+
95
+ Returns None if models not found (not trained yet).
96
+ """
97
+ cache_key = f"{symbol}_{horizon}"
98
+ if cache_key in _model_cache:
99
+ return _model_cache[cache_key]
100
+
101
+ xgb_path = MODELS_DIR / f"xgb_{symbol}_{horizon}.pkl"
102
+ lgbm_path = MODELS_DIR / f"lgbm_{symbol}_{horizon}.pkl"
103
+ scaler_path = MODELS_DIR / f"scaler_{symbol}_{horizon}.pkl"
104
+ feat_path = MODELS_DIR / f"feature_names_{symbol}_{horizon}.json"
105
+
106
+ if not all(p.exists() for p in [xgb_path, lgbm_path, scaler_path, feat_path]):
107
+ log.warning("Models not found for %s %s — run model/trainer.py first", symbol, horizon)
108
+ return None
109
+
110
+ with open(xgb_path, "rb") as f:
111
+ xgb_model = pickle.load(f)
112
+ with open(lgbm_path, "rb") as f:
113
+ lgbm_model = pickle.load(f)
114
+ with open(scaler_path, "rb") as f:
115
+ scaler = pickle.load(f)
116
+ with open(feat_path) as f:
117
+ feature_names = json.load(f)
118
+
119
+ bundle = {
120
+ "xgb": xgb_model,
121
+ "lgbm": lgbm_model,
122
+ "scaler": scaler,
123
+ "features": feature_names,
124
+ }
125
+ _model_cache[cache_key] = bundle
126
+ return bundle
127
+
128
+
129
+ def _get_shap_top5(xgb_model, X_row: np.ndarray, feature_names: list[str], pred_class: int) -> list[dict]:
130
+ """
131
+ Compute SHAP values for XGBoost and return top 5 features by |shap_value|
132
+ for the predicted class.
133
+ """
134
+ try:
135
+ import shap
136
+ explainer = shap.TreeExplainer(xgb_model)
137
+ shap_vals = explainer.shap_values(X_row) # shape: (n_classes, n_features) or (1, n_classes, n_features)
138
+
139
+ # shap_values shape varies by XGBoost version
140
+ if isinstance(shap_vals, list):
141
+ vals = shap_vals[pred_class][0] # for predicted class
142
+ else:
143
+ vals = shap_vals[0, :, pred_class] if shap_vals.ndim == 3 else shap_vals[0]
144
+
145
+ top_idx = np.argsort(np.abs(vals))[::-1][:5]
146
+ result = []
147
+ for i in top_idx:
148
+ fname = feature_names[i] if i < len(feature_names) else f"feature_{i}"
149
+ fval = float(X_row[0][i])
150
+ shap_v = float(vals[i])
151
+ result.append({
152
+ "feature": fname,
153
+ "label": FEATURE_LABELS.get(fname, fname),
154
+ "value": round(fval, 4),
155
+ "impact": "BULLISH" if shap_v > 0 else "BEARISH",
156
+ "weight": round(abs(shap_v), 4),
157
+ })
158
+ return result
159
+
160
+ except Exception as exc:
161
+ log.debug("SHAP error: %s", exc)
162
+ return []
163
+
164
+
165
+ def _get_current_price(symbol: str) -> tuple[float, float]:
166
+ """Return (current_close, atr_pct) from latest DB row."""
167
+ conn = get_conn()
168
+ rows = conn.execute(
169
+ "SELECT close FROM prices WHERE symbol = ? ORDER BY date DESC LIMIT 2",
170
+ [symbol],
171
+ ).fetchall()
172
+ conn.close()
173
+ if not rows:
174
+ return 0.0, 0.02
175
+ close = float(rows[0][0])
176
+ # Rough ATR proxy: |today - yesterday| / today
177
+ atr_pct = abs(float(rows[0][0]) - float(rows[1][0])) / close if len(rows) > 1 and close > 0 else 0.02
178
+ return close, atr_pct
179
+
180
+
181
+ # ── public API ─────────────────────────────────────────────────────────────────
182
+
183
+
184
+ def predict(symbol: str, as_of_date: str = None) -> dict:
185
+ """
186
+ Generate a forecast for a single commodity.
187
+
188
+ Args:
189
+ symbol: Commodity ticker, e.g. "ZW=F"
190
+ as_of_date: ISO date string. Defaults to today.
191
+
192
+ Returns:
193
+ Forecast dict with symbol, current price, 7d + 30d forecasts,
194
+ top_signals, and confidence levels. Returns error dict if models
195
+ are not trained.
196
+ """
197
+ as_of = as_of_date or date.today().isoformat()
198
+
199
+ bundle_7d = _load_models(symbol, "7d")
200
+ bundle_30d = _load_models(symbol, "30d")
201
+
202
+ if bundle_7d is None:
203
+ return {"symbol": symbol, "error": "models_not_trained", "as_of_date": as_of}
204
+
205
+ # Build feature vector
206
+ features_series = build_prediction_features(symbol, as_of)
207
+ if features_series.empty:
208
+ return {"symbol": symbol, "error": "no_features", "as_of_date": as_of}
209
+
210
+ # Align to trained feature names
211
+ feat_names_7d = bundle_7d["features"]
212
+ X_raw = features_series.reindex(feat_names_7d, fill_value=0).values.reshape(1, -1)
213
+ X_scaled_7d = bundle_7d["scaler"].transform(pd.DataFrame(X_raw, columns=feat_names_7d))
214
+
215
+ # Ensemble prediction — 7d
216
+ X_df_7d = pd.DataFrame(X_scaled_7d, columns=feat_names_7d)
217
+ xgb_proba_7d = bundle_7d["xgb"].predict_proba(X_df_7d)[0]
218
+ lgbm_proba_7d = bundle_7d["lgbm"].predict_proba(X_df_7d)[0]
219
+ ensemble_proba_7d = (xgb_proba_7d + lgbm_proba_7d) / 2
220
+ pred_class_7d = int(ensemble_proba_7d.argmax())
221
+ # Map encoded class back: 0=DOWN, 1=STABLE, 2=UP
222
+ direction_map = {0: "DOWN", 1: "STABLE", 2: "UP"}
223
+ direction_7d = direction_map[pred_class_7d]
224
+ prob_7d = float(ensemble_proba_7d[pred_class_7d])
225
+
226
+ # Ensemble prediction — 30d (may not be trained)
227
+ direction_30d, prob_30d = "STABLE", 0.5
228
+ if bundle_30d:
229
+ feat_names_30d = bundle_30d["features"]
230
+ X_raw_30d = features_series.reindex(feat_names_30d, fill_value=0).values.reshape(1, -1)
231
+ X_scaled_30d = bundle_30d["scaler"].transform(pd.DataFrame(X_raw_30d, columns=feat_names_30d))
232
+ X_df_30d = pd.DataFrame(X_scaled_30d, columns=feat_names_30d)
233
+ xgb_proba_30d = bundle_30d["xgb"].predict_proba(X_df_30d)[0]
234
+ lgbm_proba_30d = bundle_30d["lgbm"].predict_proba(X_df_30d)[0]
235
+ ensemble_proba_30d = (xgb_proba_30d + lgbm_proba_30d) / 2
236
+ pred_class_30d = int(ensemble_proba_30d.argmax())
237
+ direction_30d = direction_map[pred_class_30d]
238
+ prob_30d = float(ensemble_proba_30d[pred_class_30d])
239
+
240
+ # Confidence tier — base probability threshold
241
+ def _confidence(prob: float) -> str:
242
+ if prob >= 0.70:
243
+ return "HIGH"
244
+ if prob >= 0.55:
245
+ return "MEDIUM"
246
+ return "LOW"
247
+
248
+ # High-confidence signal confirmation: require 2+ independent signals to agree.
249
+ # Signals checked: price momentum, COT commercial positioning, EIA/USDA flag.
250
+ def _confirmed_confidence(prob: float, direction: str, feat: pd.Series) -> str:
251
+ base = _confidence(prob)
252
+ if base == "LOW":
253
+ return "LOW"
254
+ confirming = 0
255
+ # Signal 1: price momentum agrees
256
+ mom = float(feat.get("momentum_score", 0) or 0)
257
+ ret7 = float(feat.get("return_7d", 0) or 0)
258
+ if direction == "UP" and (mom > 0 or ret7 > 0): confirming += 1
259
+ if direction == "DOWN" and (mom < 0 or ret7 < 0): confirming += 1
260
+ # Signal 2: COT commercial positioning agrees (commercials = smart money)
261
+ cot_net = float(feat.get("cot_commercial_net_pct", 0) or 0)
262
+ cot_chg = float(feat.get("cot_commercial_chg_1w", 0) or 0)
263
+ if direction == "UP" and (cot_net > 0.05 or cot_chg > 0): confirming += 1
264
+ if direction == "DOWN" and (cot_net < -0.05 or cot_chg < 0): confirming += 1
265
+ # Signal 3: EIA supply signal agrees (for CL=F and NG=F)
266
+ eia_draw = float(feat.get("eia_crude_draw", 0) or feat.get("eia_natgas_draw", 0) or 0)
267
+ eia_vs5yr = float(feat.get("eia_crude_vs_5yr", 0) or feat.get("eia_natgas_vs_5yr", 0) or 0)
268
+ if direction == "UP" and (eia_draw > 0 or eia_vs5yr < -0.5): confirming += 1
269
+ if direction == "DOWN" and eia_vs5yr > 0.5: confirming += 1
270
+ # Signal 4: USDA crop condition trend agrees (for grain/ag symbols)
271
+ crop_chg = float(feat.get("usda_crop_good_exc_chg", 0) or 0)
272
+ if direction == "DOWN" and crop_chg < -2: confirming += 1
273
+ if direction == "UP" and crop_chg > 2: confirming += 1
274
+ # Upgrade if 2+ signals confirm; downgrade if none confirm
275
+ if confirming >= 2 and base == "MEDIUM":
276
+ return "HIGH"
277
+ if confirming == 0 and base == "MEDIUM":
278
+ return "LOW"
279
+ return base
280
+
281
+ # Price range using ATR
282
+ current_price, atr_pct = _get_current_price(symbol)
283
+ exp_ret = DIRECTION_EXPECTED_RETURN.get(direction_7d, 0.0) / 100
284
+ price_range_low = round(current_price * (1 + exp_ret - 1.5 * atr_pct), 2)
285
+ price_range_high = round(current_price * (1 + exp_ret + 1.5 * atr_pct), 2)
286
+
287
+ # SHAP top signals
288
+ top_signals = _get_shap_top5(bundle_7d["xgb"], X_scaled_7d, feat_names_7d, pred_class_7d)
289
+
290
+ conf_7d = _confirmed_confidence(prob_7d, direction_7d, features_series)
291
+ conf_30d = _confirmed_confidence(prob_30d, direction_30d, features_series)
292
+
293
+ # Symbols where 7d model has known accuracy issues — surface a warning.
294
+ UNRELIABLE_7D = {"ZC=F", "HG=F"}
295
+ model_warning = (
296
+ "7d model accuracy is low for this symbol — use 30d forecast instead"
297
+ if symbol in UNRELIABLE_7D else None
298
+ )
299
+
300
+ return {
301
+ "symbol": symbol,
302
+ "commodity_name": SYMBOL_NAMES.get(symbol, symbol),
303
+ "as_of_date": as_of,
304
+ "current_price": current_price,
305
+ "forecast_7d": {
306
+ "direction": direction_7d,
307
+ "probability": round(prob_7d, 4),
308
+ "price_range_low": price_range_low,
309
+ "price_range_high": price_range_high,
310
+ "confidence": conf_7d,
311
+ "model_warning": model_warning,
312
+ },
313
+ "forecast_30d": {
314
+ "direction": direction_30d,
315
+ "probability": round(prob_30d, 4),
316
+ "confidence": conf_30d,
317
+ },
318
+ "top_signals": top_signals,
319
+ }
320
+
321
+
322
+ def predict_all(as_of_date: str = None) -> dict[str, dict]:
323
+ """
324
+ Generate forecasts for all 10 commodities and save to DuckDB.
325
+
326
+ Returns:
327
+ Dict mapping symbol → forecast dict.
328
+ """
329
+ from signals.price_features import ALL_SYMBOLS
330
+
331
+ results = {}
332
+ for symbol in ALL_SYMBOLS:
333
+ try:
334
+ fc = predict(symbol, as_of_date)
335
+ results[symbol] = fc
336
+ if "error" not in fc:
337
+ _save_forecast(fc)
338
+ except Exception as exc:
339
+ log.error("predict %s failed: %s", symbol, exc)
340
+ results[symbol] = {"symbol": symbol, "error": str(exc)}
341
+
342
+ return results
343
+
344
+
345
+ def _save_forecast(fc: dict) -> None:
346
+ """Persist a forecast to DuckDB for accuracy tracking."""
347
+ conn = get_conn()
348
+ try:
349
+ conn.execute(
350
+ """
351
+ INSERT OR REPLACE INTO accuracy_log
352
+ (date, symbol, forecast_direction, actual_direction, was_correct, confidence)
353
+ VALUES (?, ?, ?, NULL, NULL, ?)
354
+ """,
355
+ [
356
+ fc["as_of_date"],
357
+ fc["symbol"],
358
+ fc["forecast_7d"]["direction"],
359
+ fc["forecast_7d"]["confidence"],
360
+ ],
361
+ )
362
+ except Exception as exc:
363
+ log.debug("Forecast save error: %s", exc)
364
+ finally:
365
+ conn.close()
366
+
367
+
368
+ if __name__ == "__main__":
369
+ parser = argparse.ArgumentParser(description="CommodiSense predictor")
370
+ parser.add_argument("--symbol", default=None, help="Single symbol to predict")
371
+ parser.add_argument("--all", action="store_true", help="Predict all symbols")
372
+ parser.add_argument("--date", default=None, help="As-of date YYYY-MM-DD")
373
+ args = parser.parse_args()
374
+
375
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
376
+
377
+ if args.all:
378
+ results = predict_all(args.date)
379
+ for sym, fc in results.items():
380
+ if "error" not in fc:
381
+ d7 = fc["forecast_7d"]
382
+ print(f"{sym:<12} {d7['direction']:<7} {d7['probability']:.0%} [{d7['confidence']}]")
383
+ elif args.symbol:
384
+ fc = predict(args.symbol, args.date)
385
+ print(json.dumps(fc, indent=2, default=str))
386
+ else:
387
+ parser.print_help()
model/trainer.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Trainer — trains XGBoost + LightGBM ensemble per commodity.
3
+ Designed to run on Kaggle free notebooks (GPU available there) but
4
+ works on CPU locally.
5
+
6
+ IMPORTANT: Run this on Kaggle for GPU acceleration, or locally with CPU.
7
+ Saves trained models to models/ directory.
8
+
9
+ Usage:
10
+ python model/trainer.py # train all symbols
11
+ python model/trainer.py --symbol GC=F # train one symbol
12
+ python model/trainer.py --symbol ZW=F --horizon 7d
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import logging
18
+ import pickle
19
+ from datetime import date, timedelta
20
+ import sys
21
+ import warnings
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ from sklearn.calibration import CalibratedClassifierCV
27
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
28
+ from sklearn.model_selection import TimeSeriesSplit
29
+ from sklearn.preprocessing import StandardScaler
30
+
31
+ warnings.filterwarnings("ignore")
32
+
33
+ sys.path.insert(0, str(Path(__file__).parent.parent))
34
+ from model.feature_builder import build_training_data
35
+ from signals.price_features import ALL_SYMBOLS
36
+
37
+ MODELS_DIR = Path(__file__).parent.parent / "models"
38
+ MODELS_DIR.mkdir(exist_ok=True)
39
+
40
+ logging.basicConfig(
41
+ level=logging.INFO,
42
+ format="%(asctime)s %(levelname)s %(message)s",
43
+ )
44
+ log = logging.getLogger(__name__)
45
+
46
+ # Label encoding: -1 → 0 (DOWN), 0 → 1 (STABLE), 1 → 2 (UP) for XGBoost
47
+ LABEL_MAP = {-1: 0, 0: 1, 1: 2}
48
+ LABEL_REVERSE = {0: -1, 1: 0, 2: 1}
49
+ LABEL_NAMES = {0: "DOWN", 1: "STABLE", 2: "UP"}
50
+
51
+ # ── Phase 6: Booster 2 — commodity-specific feature weight multipliers ─────────
52
+ # Applied to sample weights at training time so the model learns that certain
53
+ # features matter more for specific commodities.
54
+ COMMODITY_FEATURE_WEIGHTS: dict[str, dict[str, float]] = {
55
+ "CL=F": {"risk_score_7d": 1.5, "risk_score_30d": 1.5, "days_to_opec_meeting": 1.4,
56
+ "drought_index": 0.5},
57
+ "NG=F": {"days_to_opec_meeting": 1.4, "return_60d": 1.3, "atr_14": 1.3},
58
+ "GC=F": {"dxy_proxy": 1.8, "risk_score_7d": 1.3, "sentiment_score_1d": 1.2},
59
+ "ZW=F": {"drought_index": 2.0, "sentiment_score_1d": 1.2, "precip_anomaly_pct": 1.5},
60
+ "ZC=F": {"harvest_season_flag": 1.5, "drought_index": 1.8, "precip_anomaly_pct": 1.4},
61
+ "ZS=F": {"harvest_season_flag": 1.5, "drought_index": 1.6, "precip_anomaly_pct": 1.3},
62
+ "CT=F": {"harvest_season_flag": 1.6, "heat_stress_days": 2.0, "precip_anomaly_pct": 1.5},
63
+ "SB=F": {"harvest_season_flag": 1.5, "precip_anomaly_pct": 1.4},
64
+ "USDINR=X": {"return_60d": 1.4, "momentum_score": 1.3, "macd_signal": 1.2},
65
+ "HG=F": {"risk_score_7d": 1.3, "return_60d": 1.4, "momentum_score": 1.2},
66
+ }
67
+
68
+ # ── model configs ──────────────────────────────────────────────────────────────
69
+
70
+ XGB_PARAMS = {
71
+ "n_estimators": 500,
72
+ "max_depth": 6,
73
+ "learning_rate": 0.05,
74
+ "subsample": 0.8,
75
+ "colsample_bytree": 0.8,
76
+ "objective": "multi:softprob",
77
+ "num_class": 3,
78
+ "eval_metric": "mlogloss",
79
+ "early_stopping_rounds": 50, # constructor param in XGBoost 3.x
80
+ "random_state": 42,
81
+ "n_jobs": -1,
82
+ }
83
+
84
+ LGBM_PARAMS = {
85
+ "n_estimators": 500,
86
+ "num_leaves": 31,
87
+ "learning_rate": 0.05,
88
+ "feature_fraction": 0.8,
89
+ "bagging_fraction": 0.8,
90
+ "bagging_freq": 5,
91
+ "objective": "multiclass",
92
+ "num_class": 3,
93
+ "metric": "multi_logloss",
94
+ "verbose": -1,
95
+ "random_state": 42,
96
+ "n_jobs": -1,
97
+ }
98
+
99
+ # ── helpers ────────────────────────────────────────────────────────────────────
100
+
101
+
102
+ def _encode_labels(y: pd.Series) -> np.ndarray:
103
+ return y.map(LABEL_MAP).values
104
+
105
+
106
+ def _compute_sample_weights(y_encoded: np.ndarray) -> np.ndarray:
107
+ """Inverse-frequency sample weights. Falls back to uniform if not all 3 classes present."""
108
+ from sklearn.utils.class_weight import compute_sample_weight
109
+ if len(np.unique(y_encoded)) < 3:
110
+ return np.ones(len(y_encoded), dtype=float)
111
+ return compute_sample_weight("balanced", y_encoded)
112
+
113
+
114
+ def _select_top_features(
115
+ X: pd.DataFrame,
116
+ importances: np.ndarray,
117
+ top_n: int = 20,
118
+ min_importance: float = 0.01,
119
+ ) -> list[str]:
120
+ """Return top_n feature names by importance, filtering below min_importance."""
121
+ feat_imp = pd.Series(importances, index=X.columns).sort_values(ascending=False)
122
+ selected = feat_imp[feat_imp >= min_importance].head(top_n).index.tolist()
123
+ if len(selected) < 5:
124
+ selected = feat_imp.head(top_n).index.tolist()
125
+ return selected
126
+
127
+
128
+ def _detect_regime(X: pd.DataFrame) -> np.ndarray:
129
+ """
130
+ Booster 3 — Regime Detection.
131
+ Returns per-row regime array: 0=RANGE_BOUND, 1=TRENDING, 2=VOLATILE.
132
+ Uses ATR% and absolute 30-day return to classify market state.
133
+ Only applied when X has enough rows to compute rolling stats (>60).
134
+ """
135
+ if len(X) < 60 or "atr_pct" not in X.columns:
136
+ return np.zeros(len(X), dtype=int)
137
+
138
+ atr_pct = X["atr_pct"].fillna(0)
139
+ ret_30d = X.get("return_30d", pd.Series(0, index=X.index)).abs().fillna(0)
140
+
141
+ atr_mean = atr_pct.rolling(60, min_periods=20).mean().fillna(atr_pct.mean())
142
+ atr_std = atr_pct.rolling(60, min_periods=20).std().fillna(atr_pct.std())
143
+ atr_thresh_volatile = atr_mean + 1.5 * atr_std
144
+
145
+ regime = np.zeros(len(X), dtype=int)
146
+ regime[ret_30d.values > 10.0] = 1 # TRENDING
147
+ regime[atr_pct.values > atr_thresh_volatile.values] = 2 # VOLATILE
148
+ return regime
149
+
150
+
151
+ def _apply_commodity_weights(
152
+ sample_weights: np.ndarray,
153
+ X: pd.DataFrame,
154
+ symbol: str,
155
+ regime: np.ndarray,
156
+ ) -> np.ndarray:
157
+ """
158
+ Booster 2+3 combined — scale sample weights by commodity-specific feature
159
+ importance multipliers, then dampen VOLATILE-regime rows (trust nothing when
160
+ the market is in a chaotic state).
161
+ """
162
+ w = sample_weights.copy().astype(float)
163
+
164
+ # Commodity-specific: up-weight rows where the key signal is strong
165
+ for feat, mult in COMMODITY_FEATURE_WEIGHTS.get(symbol, {}).items():
166
+ if feat in X.columns:
167
+ signal_strength = X[feat].abs().fillna(0)
168
+ percentile_75 = np.percentile(signal_strength, 75)
169
+ if percentile_75 > 0:
170
+ strong_rows = (signal_strength >= percentile_75).values
171
+ w[strong_rows] *= mult
172
+
173
+ # Regime: dampen volatile rows (Booster 3 — "trust nothing when volatile")
174
+ w[regime == 2] *= 0.6
175
+ # Trending rows: trust momentum features more — mild up-weight
176
+ w[regime == 1] *= 1.2
177
+
178
+ # Renormalise so total weight is unchanged
179
+ total = w.sum()
180
+ if total > 0:
181
+ w = w / total * len(w)
182
+ return w
183
+
184
+
185
+ def _directional_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
186
+ """Accuracy of predicting UP/DOWN/STABLE direction correctly."""
187
+ return float(np.mean(y_true == y_pred))
188
+
189
+
190
+ def _sharpe_ratio(y_true_raw: pd.Series, y_pred_encoded: np.ndarray) -> float:
191
+ """
192
+ Naive Sharpe: long when model predicts UP, short when DOWN, flat when STABLE.
193
+ Uses true direction as proxy for daily return sign.
194
+ """
195
+ pred_dirs = pd.Series(y_pred_encoded).map(LABEL_REVERSE)
196
+ returns = pred_dirs * y_true_raw.values # +1 correct, -1 wrong
197
+ mu = returns.mean()
198
+ sigma = returns.std()
199
+ return round(float(mu / sigma * np.sqrt(252)) if sigma > 0 else 0.0, 3)
200
+
201
+
202
+ # ── training ───────────────────────────────────────────────────────────────────
203
+
204
+
205
+ def train_symbol(
206
+ symbol: str,
207
+ horizon: str = "7d",
208
+ add_lag_features: bool = True,
209
+ last_days: int = None,
210
+ ) -> dict:
211
+ """
212
+ Train XGBoost + LightGBM ensemble for a single commodity and horizon.
213
+
214
+ Args:
215
+ symbol: Commodity ticker, e.g. "ZW=F"
216
+ horizon: "7d" or "30d"
217
+ add_lag_features: Add interaction features (accuracy booster)
218
+ last_days: If set, train only on the most recent N calendar days.
219
+ Use this when NLP signals only cover a short window —
220
+ avoids 4+ years of zero-padded sentiment rows diluting the model.
221
+
222
+ Returns:
223
+ Dict with accuracy metrics for this symbol/horizon.
224
+ """
225
+ log.info("Training %s | horizon=%s | window=%s",
226
+ symbol, horizon, f"last {last_days}d" if last_days else "full")
227
+
228
+ X, y_7d, y_30d = build_training_data(symbol)
229
+ if X.empty:
230
+ log.warning("%s: no training data, skipping", symbol)
231
+ return {"symbol": symbol, "horizon": horizon, "error": "no_data"}
232
+
233
+ # Trim to short window — keeps only rows where NLP signals are non-zero
234
+ if last_days is not None:
235
+ cutoff = date.today() - timedelta(days=last_days)
236
+ if "date" in X.columns:
237
+ mask = pd.to_datetime(X["date"]).dt.date >= cutoff
238
+ else:
239
+ # date is the index order — take the last last_days * 0.7 rows (trading days)
240
+ trading_days = int(last_days * 0.71)
241
+ mask = pd.Series([False] * len(X))
242
+ mask.iloc[-trading_days:] = True
243
+ X = X[mask.values].reset_index(drop=True)
244
+ y_7d = y_7d[mask.values].reset_index(drop=True)
245
+ y_30d = y_30d[mask.values].reset_index(drop=True)
246
+ log.info("%s: trimmed to %d rows (last %d days)", symbol, len(X), last_days)
247
+
248
+ y = y_7d if horizon == "7d" else y_30d
249
+
250
+ # Skip if one class dominates >95% — model would just memorise the majority class
251
+ class_counts = y.value_counts(normalize=True)
252
+ if class_counts.max() > 0.95:
253
+ log.warning("%s %s: dominant class %.0f%% — skipping (too imbalanced to learn from)",
254
+ symbol, horizon, class_counts.max() * 100)
255
+ return {"symbol": symbol, "horizon": horizon, "error": "extreme_class_imbalance"}
256
+
257
+ # ── Phase 6 Booster 4: lag + interaction features ──
258
+ if add_lag_features:
259
+ X = X.copy()
260
+ # Interaction: sentiment × momentum (strong when both agree)
261
+ if "sentiment_score_1d" in X.columns and "momentum_score" in X.columns:
262
+ X["sentiment_x_momentum"] = X["sentiment_score_1d"] * X["momentum_score"]
263
+ # Interaction: event direction × price momentum
264
+ if "direction_score_7d" in X.columns and "return_7d" in X.columns:
265
+ X["event_x_momentum"] = X["direction_score_7d"] * np.sign(X["return_7d"].fillna(0))
266
+ # Volatility regime flag (standalone feature for the model)
267
+ if "atr_pct" in X.columns and len(X) >= 60:
268
+ atr_mean = X["atr_pct"].rolling(60, min_periods=20).mean().fillna(X["atr_pct"].mean())
269
+ X["high_volatility_flag"] = (X["atr_pct"] > atr_mean * 1.5).astype(int)
270
+
271
+ y_enc = _encode_labels(y)
272
+ sample_weights = _compute_sample_weights(y_enc)
273
+
274
+ # Phase 6 Boosters 2+3: regime detection + commodity-specific weights
275
+ if len(X) >= 60:
276
+ regime = _detect_regime(X)
277
+ sample_weights = _apply_commodity_weights(sample_weights, X, symbol, regime)
278
+ trending_pct = (regime == 1).mean() * 100
279
+ volatile_pct = (regime == 2).mean() * 100
280
+ log.info("%s: regime — %.0f%% trending, %.0f%% volatile, %.0f%% range-bound",
281
+ symbol, trending_pct, volatile_pct, 100 - trending_pct - volatile_pct)
282
+
283
+ # Short-window mode: use fewer folds + lighter model to avoid overfitting
284
+ is_short_window = last_days is not None and len(X) < 200
285
+ n_splits = 3 if is_short_window else 5
286
+ xgb_params_cv = {**XGB_PARAMS, "n_estimators": 200, "max_depth": 3} if is_short_window else XGB_PARAMS
287
+
288
+ tscv = TimeSeriesSplit(n_splits=n_splits)
289
+ fold_accs: list[float] = []
290
+ best_features: list[str] | None = None
291
+ last_fold_idx = n_splits - 1
292
+
293
+ # ── cross-validation to find stable feature set ──
294
+ for fold, (train_idx, val_idx) in enumerate(tscv.split(X)):
295
+ X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
296
+ y_train, y_val = y_enc[train_idx], y_enc[val_idx]
297
+ sw_train = sample_weights[train_idx]
298
+
299
+ # Skip folds where val set has fewer than 3 samples or missing classes
300
+ if len(y_val) < 3:
301
+ continue
302
+
303
+ scaler_fold = StandardScaler()
304
+ X_tr_s = scaler_fold.fit_transform(X_train)
305
+ X_vl_s = scaler_fold.transform(X_val)
306
+
307
+ import xgboost as xgb
308
+ xgb_fold = xgb.XGBClassifier(**xgb_params_cv)
309
+ xgb_fold.fit(
310
+ X_tr_s, y_train,
311
+ sample_weight=sw_train,
312
+ eval_set=[(X_vl_s, y_val)],
313
+ verbose=False,
314
+ )
315
+
316
+ fold_acc = accuracy_score(y_val, xgb_fold.predict(X_vl_s))
317
+ fold_accs.append(fold_acc)
318
+
319
+ if fold == last_fold_idx:
320
+ best_features = _select_top_features(X, xgb_fold.feature_importances_)
321
+
322
+ if not fold_accs:
323
+ return {"symbol": symbol, "horizon": horizon, "error": "all_folds_skipped"}
324
+
325
+ cv_accuracy = float(np.mean(fold_accs))
326
+ log.info("%s %s: CV accuracy %.3f (folds: %s)",
327
+ symbol, horizon, cv_accuracy, [f"{a:.3f}" for a in fold_accs])
328
+
329
+ # Short window: use lighter final model to avoid overfitting on small data
330
+ if is_short_window:
331
+ XGB_PARAMS_BOOSTED = {**XGB_PARAMS, "n_estimators": 300, "max_depth": 4, "learning_rate": 0.03}
332
+ LGBM_PARAMS_BOOSTED = {**LGBM_PARAMS, "n_estimators": 300, "num_leaves": 15}
333
+ elif cv_accuracy < 0.90 and add_lag_features:
334
+ log.info("%s: below 90%%, boosting n_estimators to 1000", symbol)
335
+ XGB_PARAMS_BOOSTED = {**XGB_PARAMS, "n_estimators": 1000}
336
+ LGBM_PARAMS_BOOSTED = {**LGBM_PARAMS, "n_estimators": 1000}
337
+ else:
338
+ XGB_PARAMS_BOOSTED = XGB_PARAMS
339
+ LGBM_PARAMS_BOOSTED = LGBM_PARAMS
340
+
341
+ # ── final training on full dataset using best_features ──
342
+ X_selected = X[best_features] if best_features else X
343
+
344
+ scaler = StandardScaler()
345
+ X_s = scaler.fit_transform(X_selected)
346
+
347
+ # Short window: 70/30 split to keep a meaningful test set; else 80/20
348
+ test_frac = 0.30 if is_short_window else 0.20
349
+ split = int(len(X_s) * (1 - test_frac))
350
+ X_train_f, X_test_f = X_s[:split], X_s[split:]
351
+ y_train_f, y_test_f = y_enc[:split], y_enc[split:]
352
+ sw_f = sample_weights[:split]
353
+
354
+ import xgboost as xgb
355
+ import lightgbm as lgb
356
+
357
+ xgb_model = xgb.XGBClassifier(**XGB_PARAMS_BOOSTED)
358
+ xgb_model.fit(
359
+ X_train_f, y_train_f,
360
+ sample_weight=sw_f,
361
+ eval_set=[(X_test_f, y_test_f)],
362
+ verbose=False,
363
+ )
364
+
365
+ lgbm_model = lgb.LGBMClassifier(**LGBM_PARAMS_BOOSTED)
366
+ lgbm_model.fit(
367
+ X_train_f, y_train_f,
368
+ sample_weight=sw_f,
369
+ eval_set=[(X_test_f, y_test_f)],
370
+ callbacks=[lgb.early_stopping(50, verbose=False), lgb.log_evaluation(period=-1)],
371
+ )
372
+
373
+ # Phase 6 Booster 5 — Platt/isotonic calibration on XGBoost
374
+ # Uses the test split as held-out calibration data (cv="prefit")
375
+ cal_cv = min(3, max(2, len(X_train_f) // 100))
376
+ try:
377
+ from sklearn.calibration import CalibratedClassifierCV
378
+ xgb_calibrated = CalibratedClassifierCV(xgb_model, method="isotonic", cv="prefit")
379
+ xgb_calibrated.fit(X_test_f, y_test_f)
380
+ except Exception:
381
+ xgb_calibrated = xgb_model # fallback: uncalibrated
382
+
383
+ # Soft-voting ensemble on test set (calibrated XGB + raw LGBM)
384
+ xgb_proba = xgb_calibrated.predict_proba(X_test_f)
385
+ lgbm_proba = lgbm_model.predict_proba(X_test_f)
386
+ ensemble_proba = (xgb_proba + lgbm_proba) / 2
387
+ ensemble_pred = ensemble_proba.argmax(axis=1)
388
+
389
+ test_accuracy = _directional_accuracy(y_test_f, ensemble_pred)
390
+ sharpe = _sharpe_ratio(y.iloc[split:].reset_index(drop=True), ensemble_pred)
391
+
392
+ # Classification report
393
+ report = classification_report(
394
+ y_test_f, ensemble_pred,
395
+ target_names=["DOWN", "STABLE", "UP"],
396
+ output_dict=True,
397
+ )
398
+
399
+ # Feature importance (top 10 for report)
400
+ top10_features = (
401
+ pd.Series(xgb_model.feature_importances_, index=X_selected.columns)
402
+ .sort_values(ascending=False)
403
+ .head(10)
404
+ .to_dict()
405
+ )
406
+
407
+ log.info("%s %s: test accuracy=%.3f, Sharpe=%.2f", symbol, horizon, test_accuracy, sharpe)
408
+
409
+ # ── save artifacts ──
410
+ with open(MODELS_DIR / f"xgb_{symbol}_{horizon}.pkl", "wb") as f:
411
+ pickle.dump(xgb_calibrated, f)
412
+ with open(MODELS_DIR / f"lgbm_{symbol}_{horizon}.pkl", "wb") as f:
413
+ pickle.dump(lgbm_model, f)
414
+ with open(MODELS_DIR / f"scaler_{symbol}_{horizon}.pkl", "wb") as f:
415
+ pickle.dump(scaler, f)
416
+ with open(MODELS_DIR / f"feature_names_{symbol}_{horizon}.json", "w") as f:
417
+ json.dump(X_selected.columns.tolist(), f)
418
+
419
+ return {
420
+ "symbol": symbol,
421
+ "horizon": horizon,
422
+ "cv_accuracy": round(cv_accuracy, 4),
423
+ "test_accuracy": round(test_accuracy, 4),
424
+ "sharpe_ratio": sharpe,
425
+ "n_features": len(X_selected.columns),
426
+ "n_train_samples": split,
427
+ "n_test_samples": len(X_test_f),
428
+ "top10_features": top10_features,
429
+ "classification_report": report,
430
+ }
431
+
432
+
433
+ def train_all(horizons: list[str] = None, last_days: int = None) -> dict:
434
+ """
435
+ Train models for all 10 commodities and save accuracy report.
436
+
437
+ Args:
438
+ horizons: List of horizons to train. Default: ["7d", "30d"]
439
+ last_days: If set, train each symbol on only the most recent N days.
440
+
441
+ Returns:
442
+ Dict mapping symbol → accuracy metrics per horizon.
443
+ """
444
+ if horizons is None:
445
+ horizons = ["7d", "30d"]
446
+
447
+ results: dict = {}
448
+ for symbol in ALL_SYMBOLS:
449
+ results[symbol] = {}
450
+ for horizon in horizons:
451
+ try:
452
+ metrics = train_symbol(symbol, horizon=horizon, last_days=last_days)
453
+ results[symbol][horizon] = metrics
454
+ except Exception as exc:
455
+ log.error("Failed to train %s %s: %s", symbol, horizon, exc)
456
+ results[symbol][horizon] = {"error": str(exc)}
457
+
458
+ # Save combined accuracy report
459
+ report_path = MODELS_DIR / "accuracy_report.json"
460
+ with open(report_path, "w") as f:
461
+ json.dump(results, f, indent=2, default=str)
462
+ log.info("Accuracy report saved to %s", report_path)
463
+
464
+ # Print summary table
465
+ print("\n" + "=" * 85)
466
+ print(f"{'Commodity':<15} {'7d Accuracy':>12} {'30d Accuracy':>13} {'Sharpe (7d)':>12} {'Samples':>8}")
467
+ print("=" * 85)
468
+ for symbol, res in results.items():
469
+ r7 = res.get("7d", {})
470
+ r30 = res.get("30d", {})
471
+ acc7 = f"{r7.get('test_accuracy', 0):.1%}" if "test_accuracy" in r7 else "ERR"
472
+ acc30 = f"{r30.get('test_accuracy', 0):.1%}" if "test_accuracy" in r30 else "ERR"
473
+ sh7 = f"{r7.get('sharpe_ratio', 0):.2f}" if "sharpe_ratio" in r7 else "ERR"
474
+ n = r7.get("n_train_samples", 0)
475
+ print(f"{symbol:<15} {acc7:>12} {acc30:>13} {sh7:>12} {n:>8}")
476
+ print("=" * 85)
477
+
478
+ return results
479
+
480
+
481
+ if __name__ == "__main__":
482
+ parser = argparse.ArgumentParser(description="CommodiSense model trainer")
483
+ parser.add_argument("--symbol", default=None, help="Single symbol to train")
484
+ parser.add_argument("--horizon", default="both", choices=["7d", "30d", "both"])
485
+ parser.add_argument("--days", default=None, type=int,
486
+ help="Train on only the most recent N calendar days (short-window mode)")
487
+ args = parser.parse_args()
488
+
489
+ if args.symbol:
490
+ horizons = ["7d", "30d"] if args.horizon == "both" else [args.horizon]
491
+ for h in horizons:
492
+ result = train_symbol(args.symbol, horizon=h, last_days=args.days)
493
+ print(json.dumps({k: v for k, v in result.items()
494
+ if k != "classification_report"}, indent=2, default=str))
495
+ else:
496
+ train_all(last_days=args.days)
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CommodiSense Dashboard — Hugging Face Spaces deployment
2
+ # Streamlit Cloud removed (forces Python 3.14, incompatible with numba ecosystem)
3
+ # pandas-ta removed (requires numba which doesn't support Python 3.14)
4
+ # shap removed (all versions require numba)
5
+ # All technical indicators implemented with pure pandas/numpy
6
+
7
+ duckdb>=0.10.0
8
+ pandas>=2.0.0
9
+ numpy>=1.24.0
10
+ yfinance>=0.2.0
11
+ requests>=2.28.0
12
+ xgboost>=2.0.0
13
+ lightgbm>=4.0.0
14
+ scikit-learn>=1.3.0
15
+ groq>=0.4.0
16
+ streamlit>=1.28.0
17
+ plotly>=5.15.0
18
+ python-dotenv>=1.0.0
signals/__init__.py ADDED
File without changes
signals/macro_features.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Macro Feature Engineering — COT, FRED, EIA, USDA signals.
3
+
4
+ No lookahead guarantee: all features use data available at or before as_of_date.
5
+ Missing data returns zero (model learns to weight it accordingly via has_*_data flags).
6
+
7
+ Public API:
8
+ build_macro_dataframe(symbol, start_date, end_date) → pd.DataFrame (training)
9
+ get_macro_features(symbol, as_of_date) → dict (inference)
10
+ """
11
+
12
+ import logging
13
+ import sys
14
+ from datetime import date, datetime, timedelta
15
+ from pathlib import Path
16
+
17
+ import pandas as pd
18
+
19
+ sys.path.insert(0, str(Path(__file__).parent.parent))
20
+ from data.db import get_conn
21
+
22
+ log = logging.getLogger(__name__)
23
+
24
+ # Symbols that have EIA inventory data
25
+ EIA_SYMBOL_MAP = {
26
+ "CL=F": "crude_stocks",
27
+ "NG=F": "natgas_storage",
28
+ }
29
+
30
+ # Symbols that have USDA crop data
31
+ USDA_SYMBOLS = {"ZW=F", "ZC=F", "ZS=F", "CT=F"}
32
+
33
+ # All macro feature names — used to guarantee consistent columns across training and inference
34
+ ALL_MACRO_FEATURES = [
35
+ # COT
36
+ "cot_commercial_net",
37
+ "cot_commercial_net_pct",
38
+ "cot_mm_net",
39
+ "cot_mm_net_pct",
40
+ "cot_commercial_chg_1w",
41
+ "cot_mm_chg_1w",
42
+ "cot_open_interest",
43
+ "has_cot_data",
44
+ # FRED
45
+ "fred_dxy",
46
+ "fred_dxy_chg_1w",
47
+ "fred_dxy_chg_4w",
48
+ "fred_inflation_exp",
49
+ "fred_vix",
50
+ "fred_vix_chg_1w",
51
+ "fred_vix_high",
52
+ "fred_treasury_10y",
53
+ "fred_financial_stress",
54
+ "fred_indpro",
55
+ "fred_fedfunds",
56
+ "fred_yield_inv",
57
+ "fred_china_pmi",
58
+ "fred_copper_basis",
59
+ "has_fred_data",
60
+ # EIA
61
+ "eia_crude_stocks",
62
+ "eia_crude_chg_1w",
63
+ "eia_crude_vs_5yr",
64
+ "eia_crude_draw",
65
+ "eia_natgas_stocks",
66
+ "eia_natgas_chg_1w",
67
+ "eia_natgas_vs_5yr",
68
+ "eia_natgas_draw",
69
+ "has_eia_data",
70
+ # USDA
71
+ "usda_crop_good_exc",
72
+ "usda_crop_good_exc_chg",
73
+ "usda_stocks",
74
+ "usda_stocks_yoy",
75
+ "usda_production",
76
+ "has_usda_data",
77
+ ]
78
+
79
+ _ZERO_ROW = {k: 0.0 for k in ALL_MACRO_FEATURES}
80
+
81
+
82
+ # ── training dataframes ────────────────────────────────────────────────────────
83
+
84
+
85
+ def _load_cot(symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
86
+ conn = get_conn()
87
+ df = conn.execute("""
88
+ SELECT date,
89
+ commercial_net_long AS cot_commercial_net,
90
+ commercial_net_pct AS cot_commercial_net_pct,
91
+ mm_net_long AS cot_mm_net,
92
+ mm_net_pct AS cot_mm_net_pct,
93
+ commercial_chg_1w AS cot_commercial_chg_1w,
94
+ mm_chg_1w AS cot_mm_chg_1w,
95
+ open_interest AS cot_open_interest
96
+ FROM cot_data
97
+ WHERE symbol = ? AND date >= ? AND date <= ?
98
+ ORDER BY date
99
+ """, [symbol, start_date, end_date]).df()
100
+ conn.close()
101
+ if df.empty:
102
+ return pd.DataFrame()
103
+ df["date"] = pd.to_datetime(df["date"]).dt.date
104
+ df["has_cot_data"] = 1
105
+ return df.sort_values("date").reset_index(drop=True)
106
+
107
+
108
+ def _load_fred(start_date: str, end_date: str) -> pd.DataFrame:
109
+ conn = get_conn()
110
+ # Try to select new columns; fall back gracefully if they don't exist yet
111
+ try:
112
+ df = conn.execute("""
113
+ SELECT date, dxy, inflation_exp, vix, treasury_10y,
114
+ financial_stress, indpro, fedfunds, china_pmi, copper_basis
115
+ FROM fred_data
116
+ WHERE date >= ? AND date <= ?
117
+ ORDER BY date
118
+ """, [start_date, end_date]).df()
119
+ except Exception:
120
+ df = conn.execute("""
121
+ SELECT date, dxy, inflation_exp, vix, treasury_10y,
122
+ financial_stress, indpro, fedfunds
123
+ FROM fred_data
124
+ WHERE date >= ? AND date <= ?
125
+ ORDER BY date
126
+ """, [start_date, end_date]).df()
127
+ df["china_pmi"] = None
128
+ df["copper_basis"] = None
129
+ conn.close()
130
+ if df.empty:
131
+ return pd.DataFrame()
132
+
133
+ df["date"] = pd.to_datetime(df["date"]).dt.date
134
+ df = df.sort_values("date").reset_index(drop=True)
135
+
136
+ for col in df.columns[1:]:
137
+ df[col] = df[col].ffill()
138
+
139
+ df["fred_dxy_chg_1w"] = df["dxy"].diff(5)
140
+ df["fred_dxy_chg_4w"] = df["dxy"].diff(20)
141
+ df["fred_vix_chg_1w"] = df["vix"].diff(5)
142
+ df["fred_vix_high"] = (df["vix"] > 25).astype(float)
143
+ fedfunds_safe = df["fedfunds"].fillna(0)
144
+ t10y_safe = df["treasury_10y"].fillna(0)
145
+ df["fred_yield_inv"] = (t10y_safe < fedfunds_safe).astype(float)
146
+ df["has_fred_data"] = 1
147
+
148
+ return df.rename(columns={
149
+ "dxy": "fred_dxy",
150
+ "inflation_exp": "fred_inflation_exp",
151
+ "vix": "fred_vix",
152
+ "treasury_10y": "fred_treasury_10y",
153
+ "financial_stress": "fred_financial_stress",
154
+ "indpro": "fred_indpro",
155
+ "fedfunds": "fred_fedfunds",
156
+ "china_pmi": "fred_china_pmi",
157
+ "copper_basis": "fred_copper_basis",
158
+ })
159
+
160
+
161
+ def _load_eia(symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
162
+ series_name = EIA_SYMBOL_MAP.get(symbol)
163
+ if not series_name:
164
+ return pd.DataFrame()
165
+
166
+ conn = get_conn()
167
+ df = conn.execute("""
168
+ SELECT date, value, chg_1w, vs_5yr_avg
169
+ FROM eia_inventory
170
+ WHERE series = ? AND date >= ? AND date <= ?
171
+ ORDER BY date
172
+ """, [series_name, start_date, end_date]).df()
173
+ conn.close()
174
+
175
+ if df.empty:
176
+ return pd.DataFrame()
177
+
178
+ df["date"] = pd.to_datetime(df["date"]).dt.date
179
+ prefix = "eia_crude" if symbol == "CL=F" else "eia_natgas"
180
+ df = df.rename(columns={
181
+ "value": f"{prefix}_stocks",
182
+ "chg_1w": f"{prefix}_chg_1w",
183
+ "vs_5yr_avg": f"{prefix}_vs_5yr",
184
+ })
185
+ # Drawdown flag: inventory fell (bullish supply signal)
186
+ chg_col = f"{prefix}_chg_1w"
187
+ df[f"{prefix}_draw"] = (df[chg_col].fillna(0) < -500).astype(float)
188
+ df["has_eia_data"] = 1
189
+ return df.sort_values("date").reset_index(drop=True)
190
+
191
+
192
+ def _load_usda(symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
193
+ if symbol not in USDA_SYMBOLS:
194
+ return pd.DataFrame()
195
+
196
+ conn = get_conn()
197
+ df = conn.execute("""
198
+ SELECT date, metric, value, yoy_chg_pct
199
+ FROM usda_crop
200
+ WHERE commodity = ? AND date >= ? AND date <= ?
201
+ ORDER BY date
202
+ """, [symbol, start_date, end_date]).df()
203
+ conn.close()
204
+
205
+ if df.empty:
206
+ return pd.DataFrame()
207
+
208
+ df["date"] = pd.to_datetime(df["date"]).dt.date
209
+
210
+ # Crop condition: sum % good + % excellent per date
211
+ cond = (
212
+ df[df["metric"].str.upper().str.contains("PCT GOOD|PCT EXCELLENT", na=False)]
213
+ .groupby("date")["value"].sum()
214
+ .reset_index()
215
+ .rename(columns={"value": "usda_crop_good_exc"})
216
+ .sort_values("date")
217
+ )
218
+ cond["usda_crop_good_exc_chg"] = cond["usda_crop_good_exc"].diff()
219
+
220
+ # Stocks
221
+ stk = (
222
+ df[df["metric"].str.upper().str.contains("STOCKS", na=False)]
223
+ .groupby("date")
224
+ .agg(usda_stocks=("value", "mean"), usda_stocks_yoy=("yoy_chg_pct", "mean"))
225
+ .reset_index()
226
+ .sort_values("date")
227
+ )
228
+
229
+ # Annual production (forward-filled across year)
230
+ prd = (
231
+ df[df["metric"].str.upper().str.contains("PRODUCTION", na=False)]
232
+ .groupby("date")
233
+ .agg(usda_production=("value", "mean"))
234
+ .reset_index()
235
+ .sort_values("date")
236
+ )
237
+
238
+ parts = [p for p in [cond, stk, prd] if not p.empty]
239
+ if not parts:
240
+ return pd.DataFrame()
241
+ result = parts[0]
242
+ for p in parts[1:]:
243
+ result = result.merge(p, on="date", how="outer")
244
+
245
+ result["has_usda_data"] = 1
246
+ return result.sort_values("date").reset_index(drop=True)
247
+
248
+
249
+ def _safe_merge(base: pd.DataFrame, other: pd.DataFrame) -> pd.DataFrame:
250
+ """Left-merge other onto base by date, zero-fill NaN."""
251
+ if other.empty:
252
+ return base
253
+ merged = base.merge(other, on="date", how="left")
254
+ return merged.fillna(0)
255
+
256
+
257
+ def build_macro_dataframe(symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
258
+ """
259
+ Assemble all macro feature columns for a symbol over a date range.
260
+ Returns a DataFrame keyed on 'date' with one row per calendar day
261
+ that has at least one non-zero macro feature. Missing data → zeros.
262
+
263
+ Designed for left-joining onto the price feature matrix in feature_builder.
264
+ """
265
+ cot = _load_cot(symbol, start_date, end_date)
266
+ fred = _load_fred(start_date, end_date)
267
+ eia = _load_eia(symbol, start_date, end_date)
268
+ usda = _load_usda(symbol, start_date, end_date)
269
+
270
+ if all(df.empty for df in [cot, fred, eia, usda]):
271
+ return pd.DataFrame()
272
+
273
+ # Use FRED as the date spine (widest coverage); fall back to other sources
274
+ if not fred.empty:
275
+ base = fred[["date"]].copy()
276
+ elif not cot.empty:
277
+ base = cot[["date"]].copy()
278
+ else:
279
+ base = pd.DataFrame({"date": pd.date_range(start_date, end_date, freq="D").date})
280
+
281
+ df = base.copy()
282
+ df = _safe_merge(df, cot)
283
+ df = _safe_merge(df, fred)
284
+ df = _safe_merge(df, eia)
285
+ df = _safe_merge(df, usda)
286
+
287
+ # Ensure all expected columns present
288
+ for col in ALL_MACRO_FEATURES:
289
+ if col not in df.columns:
290
+ df[col] = 0.0
291
+
292
+ # COT and EIA are weekly; forward-fill within the merged frame
293
+ cot_cols = [c for c in df.columns if c.startswith("cot_")]
294
+ eia_cols = [c for c in df.columns if c.startswith("eia_")]
295
+ usda_cols = [c for c in df.columns if c.startswith("usda_")]
296
+ for col_group in [cot_cols, eia_cols, usda_cols]:
297
+ df[col_group] = df[col_group].replace(0, float("nan")).ffill().fillna(0)
298
+
299
+ # Drop columns that are >95% zero — they have no signal and add noise.
300
+ # This auto-excludes EIA/USDA when no API keys are set.
301
+ feature_cols = [c for c in ALL_MACRO_FEATURES if c in df.columns]
302
+ nonzero_frac = (df[feature_cols].abs() > 0).mean()
303
+ active_cols = nonzero_frac[nonzero_frac >= 0.05].index.tolist()
304
+ if not active_cols:
305
+ return pd.DataFrame()
306
+
307
+ return df[["date"] + active_cols].sort_values("date").reset_index(drop=True)
308
+
309
+
310
+ # ── inference: single-row feature dict ────────────────────────────────────────
311
+
312
+
313
+ def get_macro_features(symbol: str, as_of_date: str = None) -> dict:
314
+ """
315
+ Return a flat dict of all macro features for the given symbol and date.
316
+ Guaranteed to return all keys in ALL_MACRO_FEATURES (zeros for missing data).
317
+ """
318
+ target = as_of_date or date.today().isoformat()
319
+ conn = get_conn()
320
+ result = dict(_ZERO_ROW)
321
+
322
+ # ── COT ──────────────────────────────────────────────────────────────────
323
+ row = conn.execute("""
324
+ SELECT commercial_net_long, commercial_net_pct, mm_net_long, mm_net_pct,
325
+ commercial_chg_1w, mm_chg_1w, open_interest
326
+ FROM cot_data WHERE symbol = ? AND date <= ?
327
+ ORDER BY date DESC LIMIT 1
328
+ """, [symbol, target]).fetchone()
329
+
330
+ if row:
331
+ result.update({
332
+ "cot_commercial_net": row[0] or 0,
333
+ "cot_commercial_net_pct": row[1] or 0,
334
+ "cot_mm_net": row[2] or 0,
335
+ "cot_mm_net_pct": row[3] or 0,
336
+ "cot_commercial_chg_1w": row[4] or 0,
337
+ "cot_mm_chg_1w": row[5] or 0,
338
+ "cot_open_interest": row[6] or 0,
339
+ "has_cot_data": 1.0,
340
+ })
341
+
342
+ # ── FRED ─────────────────────────────────────────────────────────────────
343
+ try:
344
+ fred_now = conn.execute("""
345
+ SELECT dxy, inflation_exp, vix, treasury_10y, financial_stress,
346
+ indpro, fedfunds, china_pmi, copper_basis
347
+ FROM fred_data WHERE date <= ? ORDER BY date DESC LIMIT 1
348
+ """, [target]).fetchone()
349
+ except Exception:
350
+ fred_now = conn.execute("""
351
+ SELECT dxy, inflation_exp, vix, treasury_10y, financial_stress,
352
+ indpro, fedfunds
353
+ FROM fred_data WHERE date <= ? ORDER BY date DESC LIMIT 1
354
+ """, [target]).fetchone()
355
+ fred_now = (fred_now + (None, None)) if fred_now else None
356
+
357
+ week_ago = (datetime.strptime(target, "%Y-%m-%d").date() - timedelta(days=7)).isoformat()
358
+ fred_wk = conn.execute("""
359
+ SELECT dxy, vix FROM fred_data WHERE date <= ? ORDER BY date DESC LIMIT 1
360
+ """, [week_ago]).fetchone()
361
+
362
+ month_ago = (datetime.strptime(target, "%Y-%m-%d").date() - timedelta(days=28)).isoformat()
363
+ fred_mo = conn.execute("""
364
+ SELECT dxy FROM fred_data WHERE date <= ? ORDER BY date DESC LIMIT 1
365
+ """, [month_ago]).fetchone()
366
+
367
+ if fred_now:
368
+ dxy = fred_now[0] or 0
369
+ vix = fred_now[2] or 0
370
+ t10y = fred_now[3] or 0
371
+ ff = fred_now[6] or 0
372
+ dxy_w = (fred_wk[0] or dxy) if fred_wk else dxy
373
+ vix_w = (fred_wk[1] or vix) if fred_wk else vix
374
+ dxy_m = (fred_mo[0] or dxy) if fred_mo else dxy
375
+ result.update({
376
+ "fred_dxy": dxy,
377
+ "fred_dxy_chg_1w": dxy - dxy_w,
378
+ "fred_dxy_chg_4w": dxy - dxy_m,
379
+ "fred_inflation_exp": fred_now[1] or 0,
380
+ "fred_vix": vix,
381
+ "fred_vix_chg_1w": vix - vix_w,
382
+ "fred_vix_high": float(vix > 25),
383
+ "fred_treasury_10y": t10y,
384
+ "fred_financial_stress": fred_now[4] or 0,
385
+ "fred_indpro": fred_now[5] or 0,
386
+ "fred_fedfunds": ff,
387
+ "fred_yield_inv": float(t10y < ff),
388
+ "fred_china_pmi": float(fred_now[7]) if fred_now[7] is not None else 0,
389
+ "fred_copper_basis": float(fred_now[8]) if fred_now[8] is not None else 0,
390
+ "has_fred_data": 1.0,
391
+ })
392
+
393
+ # ── EIA ──────────────────────────────────────────────────────────────────
394
+ series_name = EIA_SYMBOL_MAP.get(symbol)
395
+ prefix = "eia_crude" if symbol == "CL=F" else "eia_natgas"
396
+ if series_name:
397
+ eia_row = conn.execute("""
398
+ SELECT value, chg_1w, vs_5yr_avg FROM eia_inventory
399
+ WHERE series = ? AND date <= ? ORDER BY date DESC LIMIT 1
400
+ """, [series_name, target]).fetchone()
401
+ if eia_row:
402
+ chg = eia_row[1] or 0
403
+ result.update({
404
+ f"{prefix}_stocks": eia_row[0] or 0,
405
+ f"{prefix}_chg_1w": chg,
406
+ f"{prefix}_vs_5yr": eia_row[2] or 0,
407
+ f"{prefix}_draw": float(chg < -500),
408
+ "has_eia_data": 1.0,
409
+ })
410
+
411
+ # ── USDA ─────────────────────────────────────────────────────────────────
412
+ if symbol in USDA_SYMBOLS:
413
+ latest_date_row = conn.execute("""
414
+ SELECT MAX(date) FROM usda_crop WHERE commodity = ? AND date <= ?
415
+ """, [symbol, target]).fetchone()
416
+ latest = latest_date_row[0] if latest_date_row and latest_date_row[0] else None
417
+
418
+ if latest:
419
+ cond_row = conn.execute("""
420
+ SELECT SUM(value) FROM usda_crop
421
+ WHERE commodity = ? AND date = ?
422
+ AND (UPPER(metric) LIKE '%PCT GOOD%' OR UPPER(metric) LIKE '%PCT EXCELLENT%')
423
+ """, [symbol, latest]).fetchone()
424
+
425
+ stk_row = conn.execute("""
426
+ SELECT AVG(value), AVG(yoy_chg_pct) FROM usda_crop
427
+ WHERE commodity = ? AND date = ? AND UPPER(metric) LIKE '%STOCKS%'
428
+ """, [symbol, latest]).fetchone()
429
+
430
+ # Previous week for crop condition change
431
+ prev_date = (datetime.strptime(str(latest), "%Y-%m-%d").date() - timedelta(days=7)).isoformat()
432
+ prev_cond = conn.execute("""
433
+ SELECT SUM(value) FROM usda_crop
434
+ WHERE commodity = ? AND date = ?
435
+ AND (UPPER(metric) LIKE '%PCT GOOD%' OR UPPER(metric) LIKE '%PCT EXCELLENT%')
436
+ """, [symbol, prev_date]).fetchone()
437
+
438
+ prod_row = conn.execute("""
439
+ SELECT value FROM usda_crop
440
+ WHERE commodity = ? AND UPPER(metric) LIKE '%PRODUCTION%'
441
+ AND date <= ?
442
+ ORDER BY date DESC LIMIT 1
443
+ """, [symbol, target]).fetchone()
444
+
445
+ crop_now = float(cond_row[0]) if cond_row and cond_row[0] else 0
446
+ crop_prev = float(prev_cond[0]) if prev_cond and prev_cond[0] else crop_now
447
+ result.update({
448
+ "usda_crop_good_exc": crop_now,
449
+ "usda_crop_good_exc_chg": crop_now - crop_prev,
450
+ "usda_stocks": float(stk_row[0]) if stk_row and stk_row[0] else 0,
451
+ "usda_stocks_yoy": float(stk_row[1]) if stk_row and stk_row[1] else 0,
452
+ "usda_production": float(prod_row[0]) if prod_row and prod_row[0] else 0,
453
+ "has_usda_data": 1.0,
454
+ })
455
+
456
+ conn.close()
457
+ return result
signals/nlp_events.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Extractor — uses spaCy + rule-based patterns to detect commodity-relevant
3
+ events in news headlines and classify them as BULLISH / BEARISH / NEUTRAL.
4
+
5
+ Usage:
6
+ python signals/nlp_events.py # process recent news_raw articles
7
+ python signals/nlp_events.py --limit 200
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ import sys
13
+ from datetime import date, timedelta
14
+ from pathlib import Path
15
+
16
+ import pandas as pd
17
+
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+ from data.db import get_conn, init_schema
20
+
21
+ LOG_PATH = Path(__file__).parent.parent / "data" / "logs" / "events.log"
22
+ LOG_PATH.parent.mkdir(exist_ok=True)
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format="%(asctime)s %(levelname)s %(message)s",
26
+ handlers=[logging.FileHandler(LOG_PATH), logging.StreamHandler()],
27
+ )
28
+ log = logging.getLogger(__name__)
29
+
30
+ # ── event pattern definitions ──────────────────────────────────────────────────
31
+
32
+ # Each entry: (event_type, trigger_phrases, default_direction, base_severity)
33
+ EVENT_PATTERNS: list[tuple[str, list[str], str, int]] = [
34
+ ("SUPPLY_SHOCK", ["production cut", "harvest failure", "pipeline explosion",
35
+ "pipeline attack", "port strike", "port closure",
36
+ "sanctions imposed", "export ban", "output cut",
37
+ "supply disruption", "refinery fire", "mine closure"], "BULLISH", 4),
38
+ ("SUPPLY_INCREASE", ["production increase", "record output", "supply glut",
39
+ "oversupply", "output raised", "inventory build",
40
+ "stockpile rise"], "BEARISH", 3),
41
+ ("DEMAND_SURGE", ["record imports", "stockpile build", "demand forecast raised",
42
+ "strong demand", "demand surge", "buying spree"], "BULLISH", 3),
43
+ ("DEMAND_DROP", ["demand falls", "demand drop", "weak demand",
44
+ "economic slowdown", "recession fears", "demand cut"], "BEARISH", 3),
45
+ ("POLICY_CHANGE", ["opec decision", "fed rate", "interest rate hike",
46
+ "interest rate cut", "tariff imposed", "trade deal",
47
+ "subsidy cut", "subsidy increase", "central bank"], "NEUTRAL", 2),
48
+ ("WEATHER_EVENT", ["drought", "flood", "frost", "la niña", "el niño",
49
+ "monsoon failure", "heatwave", "crop damage",
50
+ "hurricane", "cyclone", "typhoon"], "BULLISH", 4),
51
+ ("GEOPOLITICAL", ["war", "armed conflict", "sanctions", "embargo",
52
+ "coup", "invasion", "airstrike", "blockade"], "BULLISH", 5),
53
+ ]
54
+
55
+ # Commodity-specific policy direction overrides
56
+ # (event_type, commodity) → direction
57
+ POLICY_DIRECTION_OVERRIDES: dict[tuple[str, str], str] = {
58
+ ("POLICY_CHANGE", "GC=F"): "BULLISH", # rate cuts → gold up
59
+ ("POLICY_CHANGE", "USDINR=X"):"BEARISH", # rate hikes → stronger USD → bearish INR
60
+ ("POLICY_CHANGE", "CL=F"): "BEARISH", # trade deal → supply up
61
+ }
62
+
63
+ # Region → commodities most affected by weather in that region
64
+ REGION_COMMODITY_WEATHER: dict[str, list[str]] = {
65
+ "ukraine": ["ZW=F", "ZC=F"],
66
+ "russia": ["ZW=F"],
67
+ "brazil": ["ZS=F", "CT=F", "SB=F"],
68
+ "india": ["CT=F", "SB=F"],
69
+ "us": ["ZC=F", "ZS=F", "CL=F", "NG=F"],
70
+ "texas": ["CL=F", "NG=F"],
71
+ "chile": ["HG=F"],
72
+ "middle east": ["CL=F", "NG=F"],
73
+ "opec": ["CL=F", "NG=F"],
74
+ "gulf": ["CL=F", "NG=F"],
75
+ }
76
+
77
+ # Commodity keywords for tagging (same as news collector)
78
+ COMMODITY_KEYWORDS: dict[str, list[str]] = {
79
+ "CL=F": ["oil", "petroleum", "crude", "opec", "brent", "wti"],
80
+ "NG=F": ["natural gas", "lng", "gas pipeline"],
81
+ "GC=F": ["gold", "bullion", "safe haven"],
82
+ "ZW=F": ["wheat", "grain", "flour"],
83
+ "ZC=F": ["corn", "maize"],
84
+ "ZS=F": ["soybean", "soy"],
85
+ "CT=F": ["cotton"],
86
+ "SB=F": ["sugar", "cane"],
87
+ "USDINR=X":["rupee", "inr", "india forex"],
88
+ "HG=F": ["copper"],
89
+ }
90
+
91
+ # ── spaCy loader (lazy) ────────────────────────────────────────────────────────
92
+
93
+ _nlp = None
94
+
95
+
96
+ def _load_nlp():
97
+ global _nlp
98
+ if _nlp is None:
99
+ import spacy
100
+ try:
101
+ _nlp = spacy.load("en_core_web_sm")
102
+ except OSError:
103
+ log.warning("en_core_web_sm not found — run: python -m spacy download en_core_web_sm")
104
+ _nlp = None
105
+ return _nlp
106
+
107
+
108
+ # ── helpers ────────────────────────────────────────────────────────────────────
109
+
110
+
111
+ def _detect_commodities(text: str) -> list[str]:
112
+ lower = text.lower()
113
+ return [sym for sym, kws in COMMODITY_KEYWORDS.items() if any(k in lower for k in kws)]
114
+
115
+
116
+ def _detect_location(text: str) -> str:
117
+ """Extract first recognised location from text using spaCy GPE entities."""
118
+ nlp = _load_nlp()
119
+ if nlp is None:
120
+ return "unknown"
121
+ doc = nlp(text[:300])
122
+ for ent in doc.ents:
123
+ if ent.label_ in ("GPE", "LOC"):
124
+ return ent.text
125
+ return "unknown"
126
+
127
+
128
+ def _resolve_direction(event_type: str, commodities: list[str], default: str) -> str:
129
+ """Apply commodity-specific overrides to the default direction."""
130
+ if not commodities:
131
+ return default
132
+ for commodity in commodities:
133
+ override = POLICY_DIRECTION_OVERRIDES.get((event_type, commodity))
134
+ if override:
135
+ return override
136
+ return default
137
+
138
+
139
+ def _severity_from_text(text: str, base: int) -> int:
140
+ """Bump severity +1 if text contains intensifiers."""
141
+ intensifiers = ["massive", "unprecedented", "historic", "emergency",
142
+ "catastrophic", "record", "major", "severe"]
143
+ lower = text.lower()
144
+ bump = sum(1 for w in intensifiers if w in lower)
145
+ return min(5, base + bump)
146
+
147
+
148
+ # ── public API ─────────────────────────────────────────────────────────────────
149
+
150
+
151
+ def extract_events(text: str, event_date: str) -> list[dict]:
152
+ """
153
+ Extract commodity-relevant events from a text string.
154
+
155
+ Args:
156
+ text: Article headline or summary.
157
+ event_date: ISO date string "YYYY-MM-DD" for the event.
158
+
159
+ Returns:
160
+ List of dicts with keys: date, headline, event_type, commodity,
161
+ location, severity, direction, source.
162
+ """
163
+ lower = text.lower()
164
+ events: list[dict] = []
165
+
166
+ for evt_type, phrases, default_direction, base_severity in EVENT_PATTERNS:
167
+ matched_phrase = next((p for p in phrases if p in lower), None)
168
+ if not matched_phrase:
169
+ continue
170
+
171
+ commodities = _detect_commodities(text)
172
+ if not commodities:
173
+ # For weather/geopolitical, try to infer commodity from location
174
+ location = _detect_location(text)
175
+ loc_lower = location.lower()
176
+ for region, syms in REGION_COMMODITY_WEATHER.items():
177
+ if region in loc_lower:
178
+ commodities = syms
179
+ break
180
+
181
+ if not commodities:
182
+ commodities = ["CL=F"] # fallback to crude oil as most globally traded
183
+
184
+ direction = _resolve_direction(evt_type, commodities, default_direction)
185
+ severity = _severity_from_text(text, base_severity)
186
+ location = _detect_location(text)
187
+
188
+ for commodity in commodities:
189
+ events.append({
190
+ "date": event_date,
191
+ "headline": text[:500],
192
+ "event_type": evt_type,
193
+ "commodity": commodity,
194
+ "location": location,
195
+ "severity": severity,
196
+ "direction": direction,
197
+ "source": "nlp_events",
198
+ })
199
+
200
+ return events
201
+
202
+
203
+ def process_batch(limit: int = 100) -> int:
204
+ """
205
+ Extract events from recent news_raw articles and store in extracted_events.
206
+
207
+ Args:
208
+ limit: Max articles to scan.
209
+
210
+ Returns:
211
+ Count of events extracted.
212
+ """
213
+ conn = get_conn()
214
+ df = conn.execute(
215
+ """
216
+ SELECT id, title, summary, published_date
217
+ FROM news_raw
218
+ ORDER BY published_date DESC
219
+ LIMIT ?
220
+ """,
221
+ [limit],
222
+ ).df()
223
+ conn.close()
224
+
225
+ if df.empty:
226
+ return 0
227
+
228
+ total_events = 0
229
+ conn = get_conn()
230
+ for _, row in df.iterrows():
231
+ text = f"{row.get('title', '')} {row.get('summary', '')}".strip()
232
+ pub = str(row.get("published_date", date.today()))[:10]
233
+ events = extract_events(text, pub)
234
+ for evt in events:
235
+ try:
236
+ conn.execute(
237
+ """
238
+ INSERT INTO extracted_events
239
+ (date, headline, event_type, commodity, location,
240
+ severity, direction, source)
241
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
242
+ """,
243
+ [
244
+ evt["date"], evt["headline"], evt["event_type"],
245
+ evt["commodity"], evt["location"], evt["severity"],
246
+ evt["direction"], evt["source"],
247
+ ],
248
+ )
249
+ total_events += 1
250
+ except Exception as exc:
251
+ log.debug("Event insert error: %s", exc)
252
+ conn.close()
253
+
254
+ log.info("Extracted %d events from %d articles", total_events, len(df))
255
+ return total_events
256
+
257
+
258
+ def get_event_features(commodity: str, days: int = 30) -> pd.DataFrame:
259
+ """
260
+ Return aggregated event features for a commodity over a date window.
261
+
262
+ Args:
263
+ commodity: Ticker symbol, e.g. "CL=F"
264
+ days: Look-back window in calendar days
265
+
266
+ Returns:
267
+ DataFrame with one row per date, columns:
268
+ event_count, bullish_count, bearish_count, max_severity,
269
+ direction_score (bullish=+1, bearish=-1, neutral=0, summed),
270
+ supply_shock_flag (1 if any SUPPLY_SHOCK that day),
271
+ policy_change_flag (1 if any POLICY_CHANGE that day)
272
+ """
273
+ cutoff = date.today() - timedelta(days=days)
274
+ conn = get_conn()
275
+ df = conn.execute(
276
+ """
277
+ SELECT date, event_type, direction, severity
278
+ FROM extracted_events
279
+ WHERE commodity = ? AND date >= ?
280
+ ORDER BY date
281
+ """,
282
+ [commodity, cutoff],
283
+ ).df()
284
+ conn.close()
285
+
286
+ if df.empty:
287
+ return pd.DataFrame(columns=[
288
+ "date", "event_count", "bullish_count", "bearish_count",
289
+ "max_severity", "direction_score", "supply_shock_flag", "policy_change_flag",
290
+ ])
291
+
292
+ df["dir_score"] = df["direction"].map({"BULLISH": 1, "BEARISH": -1, "NEUTRAL": 0}).fillna(0)
293
+
294
+ agg = df.groupby("date").agg(
295
+ event_count=("event_type", "count"),
296
+ bullish_count=("direction", lambda x: (x == "BULLISH").sum()),
297
+ bearish_count=("direction", lambda x: (x == "BEARISH").sum()),
298
+ max_severity=("severity", "max"),
299
+ direction_score=("dir_score", "sum"),
300
+ supply_shock_flag=("event_type", lambda x: int((x == "SUPPLY_SHOCK").any())),
301
+ policy_change_flag=("event_type", lambda x: int((x == "POLICY_CHANGE").any())),
302
+ ).reset_index()
303
+
304
+ return agg
305
+
306
+
307
+ if __name__ == "__main__":
308
+ parser = argparse.ArgumentParser()
309
+ parser.add_argument("--limit", type=int, default=100)
310
+ args = parser.parse_args()
311
+ init_schema()
312
+ n = process_batch(limit=args.limit)
313
+ print(f"Extracted {n} events")
signals/nlp_sentiment.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NLP Sentiment Engine — scores commodity news articles using FinBERT,
3
+ aggregates into daily sentiment features per commodity.
4
+
5
+ Usage:
6
+ python signals/nlp_sentiment.py # process all unscored articles
7
+ python signals/nlp_sentiment.py --limit 200
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ import sys
13
+ from datetime import date, datetime, timedelta, timezone
14
+ from pathlib import Path
15
+
16
+ import pandas as pd
17
+
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+ from data.db import get_conn, init_schema
20
+ from data.collector_news import get_unprocessed_news, mark_processed
21
+
22
+ LOG_PATH = Path(__file__).parent.parent / "data" / "logs" / "sentiment.log"
23
+ LOG_PATH.parent.mkdir(exist_ok=True)
24
+ logging.basicConfig(
25
+ level=logging.INFO,
26
+ format="%(asctime)s %(levelname)s %(message)s",
27
+ handlers=[logging.FileHandler(LOG_PATH), logging.StreamHandler()],
28
+ )
29
+ log = logging.getLogger(__name__)
30
+
31
+ # Max tokens for FinBERT input (model limit is 512, we use 256 for speed)
32
+ MAX_TOKENS = 256
33
+
34
+ # ── model loading (lazy, cached) ───────────────────────────────────────────────
35
+
36
+ _pipeline = None
37
+
38
+
39
+ def _load_pipeline():
40
+ """Load FinBERT pipeline once and cache it. Falls back to DistilBERT."""
41
+ global _pipeline
42
+ if _pipeline is not None:
43
+ return _pipeline
44
+
45
+ from transformers import pipeline as hf_pipeline
46
+
47
+ try:
48
+ log.info("Loading FinBERT (ProsusAI/finbert)...")
49
+ _pipeline = hf_pipeline(
50
+ "text-classification",
51
+ model="ProsusAI/finbert",
52
+ tokenizer="ProsusAI/finbert",
53
+ top_k=None, # return all 3 class probabilities
54
+ device=-1, # CPU
55
+ truncation=True,
56
+ max_length=MAX_TOKENS,
57
+ )
58
+ log.info("FinBERT loaded")
59
+ except Exception as exc:
60
+ log.warning("FinBERT load failed (%s), falling back to DistilBERT", exc)
61
+ _pipeline = hf_pipeline(
62
+ "text-classification",
63
+ model="distilbert-base-uncased-finetuned-sst-2-english",
64
+ top_k=None,
65
+ device=-1,
66
+ truncation=True,
67
+ max_length=MAX_TOKENS,
68
+ )
69
+ log.info("DistilBERT loaded as fallback")
70
+
71
+ return _pipeline
72
+
73
+
74
+ # ── keyword-based baseline (for ensemble uncertainty check) ────────────────────
75
+
76
+ _BULLISH_WORDS = {
77
+ "surge", "rally", "gain", "rise", "boom", "shortage", "record high",
78
+ "supply cut", "output cut", "strong demand", "bullish",
79
+ }
80
+ _BEARISH_WORDS = {
81
+ "fall", "drop", "crash", "decline", "surplus", "oversupply",
82
+ "demand drop", "weak demand", "bearish", "glut",
83
+ }
84
+
85
+
86
+ def _keyword_sentiment(text: str) -> float:
87
+ """Fast keyword-based sentiment score in [-1, +1]."""
88
+ lower = text.lower()
89
+ pos = sum(1 for w in _BULLISH_WORDS if w in lower)
90
+ neg = sum(1 for w in _BEARISH_WORDS if w in lower)
91
+ total = pos + neg
92
+ return (pos - neg) / total if total > 0 else 0.0
93
+
94
+
95
+ # ── public API ─────────────────────────────────────────────────────────────────
96
+
97
+
98
+ def score_article(text: str) -> float:
99
+ """
100
+ Score a single text string using FinBERT.
101
+
102
+ Returns:
103
+ Sentiment score in [-1.0, +1.0].
104
+ positive_prob - negative_prob from FinBERT.
105
+ Falls back to keyword score if model unavailable.
106
+ """
107
+ if not text or len(text.strip()) < 10:
108
+ return 0.0
109
+
110
+ try:
111
+ pipe = _load_pipeline()
112
+ results = pipe(text[:512])[0] # list of {label, score} dicts
113
+
114
+ scores = {r["label"].lower(): r["score"] for r in results}
115
+
116
+ # FinBERT labels: positive / negative / neutral
117
+ # DistilBERT labels: POSITIVE / NEGATIVE — normalize
118
+ pos = scores.get("positive", scores.get("label_1", 0.0))
119
+ neg = scores.get("negative", scores.get("label_0", 0.0))
120
+
121
+ ml_score = pos - neg # range [-1, +1]
122
+
123
+ # Ensemble uncertainty check: if ML and keyword disagree strongly, use neutral
124
+ kw_score = _keyword_sentiment(text)
125
+ if abs(ml_score - kw_score) > 0.4 and abs(kw_score) > 0.1:
126
+ return 0.0
127
+
128
+ return round(ml_score, 4)
129
+
130
+ except Exception as exc:
131
+ log.debug("score_article error: %s", exc)
132
+ return _keyword_sentiment(text)
133
+
134
+
135
+ def process_batch(limit: int = 100) -> int:
136
+ """
137
+ Score unprocessed articles from news_raw using batched inference and store
138
+ aggregated daily sentiment.
139
+
140
+ Batched pipeline call is 10-30x faster than scoring one article at a time.
141
+
142
+ Args:
143
+ limit: Max articles to process per call.
144
+
145
+ Returns:
146
+ Count of articles processed.
147
+ """
148
+ df = get_unprocessed_news(limit=limit)
149
+ if df.empty:
150
+ log.info("No unprocessed articles found")
151
+ return 0
152
+
153
+ log.info("Processing %d articles (batched)...", len(df))
154
+
155
+ # Build text list and IDs together
156
+ texts = [
157
+ f"{row.get('title', '')} {row.get('summary', '')}".strip()[:512]
158
+ for _, row in df.iterrows()
159
+ ]
160
+ ids = df["id"].tolist()
161
+
162
+ # Single batched pipeline call — far faster than N individual calls
163
+ pipe = _load_pipeline()
164
+ try:
165
+ batch_results = pipe(texts, batch_size=16, truncation=True)
166
+ except Exception as exc:
167
+ log.warning("Batched inference failed (%s), falling back to per-article", exc)
168
+ batch_results = [pipe(t)[0] for t in texts]
169
+
170
+ scores: list[float] = []
171
+ for text, result in zip(texts, batch_results):
172
+ try:
173
+ # result is a list of dicts when top_k=None
174
+ label_scores = {r["label"].lower(): r["score"] for r in result}
175
+ pos = label_scores.get("positive", label_scores.get("label_1", 0.0))
176
+ neg = label_scores.get("negative", label_scores.get("label_0", 0.0))
177
+ ml_score = pos - neg
178
+
179
+ kw_score = _keyword_sentiment(text)
180
+ if abs(ml_score - kw_score) > 0.4 and abs(kw_score) > 0.1:
181
+ scores.append(0.0)
182
+ else:
183
+ scores.append(round(ml_score, 4))
184
+ except Exception:
185
+ scores.append(_keyword_sentiment(text))
186
+
187
+ # Bulk update in one connection
188
+ conn = get_conn()
189
+ for article_id, score in zip(ids, scores):
190
+ conn.execute(
191
+ "UPDATE news_raw SET sentiment_score = ?, processed = TRUE WHERE id = ?",
192
+ [score, article_id],
193
+ )
194
+ conn.close()
195
+
196
+ _aggregate_daily_sentiment()
197
+ log.info("Processed %d articles", len(ids))
198
+ return len(ids)
199
+
200
+
201
+ def _aggregate_daily_sentiment() -> None:
202
+ """
203
+ Recompute sentiment_daily table from scored news_raw rows.
204
+ Applies time-decay weights: 1.0 (<24h), 0.5 (24–48h), 0.25 (48–72h).
205
+ """
206
+ conn = get_conn()
207
+ # Get scored articles from last 7 days
208
+ cutoff = (datetime.now(timezone.utc) - timedelta(days=7)).strftime("%Y-%m-%d")
209
+ df = conn.execute(
210
+ """
211
+ SELECT id, published_date, commodity_tags, sentiment_score
212
+ FROM news_raw
213
+ WHERE processed = TRUE
214
+ AND sentiment_score IS NOT NULL
215
+ AND published_date >= ?
216
+ """,
217
+ [cutoff],
218
+ ).df()
219
+ conn.close()
220
+
221
+ if df.empty:
222
+ return
223
+
224
+ now = datetime.now(timezone.utc)
225
+ rows_to_upsert: list[dict] = []
226
+
227
+ # Explode commodity tags — one row per commodity mention
228
+ records = []
229
+ for _, row in df.iterrows():
230
+ tags = str(row.get("commodity_tags") or "").split(",")
231
+ pub = row["published_date"]
232
+ if isinstance(pub, str):
233
+ try:
234
+ pub = datetime.fromisoformat(pub.replace("Z", "+00:00"))
235
+ except Exception:
236
+ pub = now
237
+ if pub.tzinfo is None:
238
+ pub = pub.replace(tzinfo=timezone.utc)
239
+
240
+ age_hours = (now - pub).total_seconds() / 3600
241
+ weight = 1.0 if age_hours < 24 else (0.5 if age_hours < 48 else 0.25)
242
+
243
+ for tag in tags:
244
+ tag = tag.strip()
245
+ if tag:
246
+ records.append({
247
+ "date": pub.date(),
248
+ "commodity": tag,
249
+ "score": row["sentiment_score"],
250
+ "weight": weight,
251
+ })
252
+
253
+ if not records:
254
+ return
255
+
256
+ df_exp = pd.DataFrame(records)
257
+
258
+ # Weighted average per (date, commodity)
259
+ def _wavg(g):
260
+ w = g["weight"]
261
+ s = g["score"]
262
+ total_w = w.sum()
263
+ return {
264
+ "sentiment_score": (s * w).sum() / total_w if total_w > 0 else 0.0,
265
+ "article_count": len(g),
266
+ "positive_count": int((s > 0.1).sum()),
267
+ "negative_count": int((s < -0.1).sum()),
268
+ }
269
+
270
+ summary = df_exp.groupby(["date", "commodity"]).apply(_wavg).reset_index()
271
+
272
+ conn = get_conn()
273
+ for _, row in summary.iterrows():
274
+ vals = row[0] if isinstance(row[0], dict) else row.to_dict()
275
+ conn.execute(
276
+ """
277
+ INSERT OR REPLACE INTO sentiment_daily
278
+ (date, commodity, sentiment_score, article_count,
279
+ positive_count, negative_count)
280
+ VALUES (?, ?, ?, ?, ?, ?)
281
+ """,
282
+ [
283
+ row["date"],
284
+ row["commodity"],
285
+ vals.get("sentiment_score", 0.0),
286
+ vals.get("article_count", 0),
287
+ vals.get("positive_count", 0),
288
+ vals.get("negative_count", 0),
289
+ ],
290
+ )
291
+ conn.close()
292
+
293
+
294
+ def get_sentiment_features(commodity: str, days: int = 30) -> pd.DataFrame:
295
+ """
296
+ Return daily sentiment features for a commodity with rolling averages.
297
+
298
+ Args:
299
+ commodity: Ticker symbol, e.g. "ZW=F"
300
+ days: Look-back window in calendar days
301
+
302
+ Returns:
303
+ DataFrame with columns: date, sentiment_score, article_count,
304
+ sentiment_3d, sentiment_7d, positive_ratio_7d
305
+ """
306
+ cutoff = date.today() - timedelta(days=days)
307
+ conn = get_conn()
308
+ df = conn.execute(
309
+ """
310
+ SELECT * FROM sentiment_daily
311
+ WHERE commodity = ? AND date >= ?
312
+ ORDER BY date
313
+ """,
314
+ [commodity, cutoff],
315
+ ).df()
316
+ conn.close()
317
+
318
+ if df.empty:
319
+ return df
320
+
321
+ df = df.sort_values("date").reset_index(drop=True)
322
+ df["sentiment_3d"] = df["sentiment_score"].rolling(3, min_periods=1).mean()
323
+ df["sentiment_7d"] = df["sentiment_score"].rolling(7, min_periods=1).mean()
324
+ df["positive_ratio_7d"] = (
325
+ df["positive_count"].rolling(7, min_periods=1).sum()
326
+ / df["article_count"].rolling(7, min_periods=1).sum().replace(0, 1)
327
+ )
328
+ return df
329
+
330
+
331
+ if __name__ == "__main__":
332
+ parser = argparse.ArgumentParser()
333
+ parser.add_argument("--limit", type=int, default=100)
334
+ args = parser.parse_args()
335
+ init_schema()
336
+ n = process_batch(limit=args.limit)
337
+ print(f"Processed {n} articles")
signals/price_features.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Price Feature Engineer — computes all technical, momentum, seasonality, and
3
+ cross-commodity features from stored price data.
4
+
5
+ All features are derived from DuckDB prices table — no live API calls.
6
+
7
+ Usage (standalone):
8
+ python signals/price_features.py --symbol GC=F --date 2024-06-01
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import logging
14
+ import sys
15
+ from datetime import date, datetime, timedelta
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ sys.path.insert(0, str(Path(__file__).parent.parent))
22
+ from data.db import get_conn
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+ # ── commodity metadata ─────────────────────────────────────────────────────────
27
+
28
+ SYMBOL_NAMES: dict[str, str] = {
29
+ "CL=F": "Crude Oil",
30
+ "NG=F": "Natural Gas",
31
+ "GC=F": "Gold",
32
+ "ZW=F": "Wheat",
33
+ "ZC=F": "Corn",
34
+ "ZS=F": "Soybeans",
35
+ "CT=F": "Cotton",
36
+ "SB=F": "Sugar",
37
+ "USDINR=X":"USD/INR",
38
+ "HG=F": "Copper",
39
+ }
40
+
41
+ ALL_SYMBOLS = list(SYMBOL_NAMES.keys())
42
+
43
+ # Harvest season windows (month_start, month_end) — inclusive
44
+ HARVEST_SEASONS: dict[str, list[tuple[int, int]]] = {
45
+ "ZW=F": [(6, 8)], # Northern hemisphere wheat: June–August
46
+ "ZC=F": [(9, 11)], # US corn: September–November
47
+ "ZS=F": [(9, 11), (3, 5)], # US + Brazil soy harvest windows
48
+ "CT=F": [(9, 12)], # US cotton: September–December
49
+ "SB=F": [(4, 6), (10, 12)],# Brazil + India sugar
50
+ }
51
+
52
+ # OPEC meeting dates (ISO strings) — extend annually
53
+ OPEC_MEETING_DATES: list[str] = [
54
+ "2024-06-02", "2024-11-26",
55
+ "2025-03-03", "2025-05-28", "2025-11-05",
56
+ "2026-03-02", "2026-06-01",
57
+ ]
58
+
59
+ # ── helpers ────────────────────────────────────────────────────────────────────
60
+
61
+
62
+ def _load_prices(symbol: str, days: int = 400) -> pd.DataFrame:
63
+ """
64
+ Load OHLCV data for a symbol from DuckDB.
65
+ Returns DataFrame sorted by date ascending with at least `days` rows of buffer.
66
+ """
67
+ cutoff = date.today() - timedelta(days=days)
68
+ conn = get_conn()
69
+ df = conn.execute(
70
+ "SELECT date, open, high, low, close, volume, adj_close FROM prices "
71
+ "WHERE symbol = ? AND date >= ? ORDER BY date",
72
+ [symbol, cutoff],
73
+ ).df()
74
+ conn.close()
75
+ df["date"] = pd.to_datetime(df["date"])
76
+ return df.sort_values("date").reset_index(drop=True)
77
+
78
+
79
+ def _load_all_prices_latest() -> dict[str, float]:
80
+ """Return latest close price for every symbol (used for cross-commodity ratios)."""
81
+ conn = get_conn()
82
+ rows = conn.execute(
83
+ """
84
+ SELECT symbol, close FROM prices p
85
+ WHERE date = (SELECT MAX(date) FROM prices p2 WHERE p2.symbol = p.symbol)
86
+ """
87
+ ).fetchall()
88
+ conn.close()
89
+ return {r[0]: r[1] for r in rows}
90
+
91
+
92
+ def _days_to_next_opec(as_of: date) -> int:
93
+ """Return calendar days until the next OPEC meeting on or after `as_of`."""
94
+ future = [
95
+ (datetime.strptime(d, "%Y-%m-%d").date() - as_of).days
96
+ for d in OPEC_MEETING_DATES
97
+ if datetime.strptime(d, "%Y-%m-%d").date() >= as_of
98
+ ]
99
+ return min(future) if future else 180 # default if no upcoming date in list
100
+
101
+
102
+ def _harvest_season_flag(symbol: str, month: int) -> int:
103
+ """Return 1 if `month` falls within any harvest window for the symbol."""
104
+ windows = HARVEST_SEASONS.get(symbol, [])
105
+ for start_m, end_m in windows:
106
+ if start_m <= month <= end_m:
107
+ return 1
108
+ return 0
109
+
110
+
111
+ def _compute_ta_features(df: pd.DataFrame) -> pd.DataFrame:
112
+ """
113
+ Append technical analysis columns using pandas-ta.
114
+ Works on a copy of df — returns augmented DataFrame.
115
+ """
116
+ try:
117
+ import pandas_ta as ta
118
+
119
+ df = df.copy()
120
+ df.ta.rsi(length=14, append=True) # RSI_14
121
+ df.ta.macd(fast=12, slow=26, signal=9, append=True) # MACD_12_26_9, etc.
122
+ df.ta.bbands(length=20, std=2, append=True) # BBL_20_2.0, BBM_20_2.0, BBU_20_2.0
123
+ df.ta.atr(length=14, append=True) # ATRr_14
124
+ df.ta.sma(length=20, append=True) # SMA_20
125
+ df.ta.sma(length=50, append=True) # SMA_50
126
+
127
+ except ImportError:
128
+ log.warning("pandas-ta not installed — TA features will be NaN")
129
+
130
+ return df
131
+
132
+
133
+ def _safe(val) -> float:
134
+ """Return 0.0 for NaN/None values to keep feature vector clean."""
135
+ if val is None:
136
+ return 0.0
137
+ try:
138
+ v = float(val)
139
+ return 0.0 if (v != v) else v # NaN check without numpy
140
+ except (TypeError, ValueError):
141
+ return 0.0
142
+
143
+
144
+ # ── public API ─────────────────────────────────────────────────────────────────
145
+
146
+
147
+ def get_price_features(symbol: str, as_of_date: str = None) -> dict:
148
+ """
149
+ Compute all price-based features for a symbol on a given date.
150
+
151
+ Args:
152
+ symbol: Commodity ticker, e.g. "GC=F"
153
+ as_of_date: ISO date string. Defaults to today.
154
+
155
+ Returns:
156
+ Flat dict of feature_name → float value.
157
+ All values are guaranteed non-NaN (NaN → 0.0).
158
+ """
159
+ target_date = (
160
+ datetime.strptime(as_of_date, "%Y-%m-%d").date()
161
+ if as_of_date
162
+ else date.today()
163
+ )
164
+
165
+ df = _load_prices(symbol, days=400)
166
+ if df.empty or len(df) < 20:
167
+ log.warning("%s: insufficient price history for feature engineering", symbol)
168
+ return {}
169
+
170
+ df = _compute_ta_features(df)
171
+
172
+ # Locate the row nearest to target_date (T-1 to avoid lookahead)
173
+ df["_date"] = df["date"].dt.date
174
+ available = df[df["_date"] <= target_date]
175
+ if available.empty:
176
+ return {}
177
+ row = available.iloc[-1]
178
+ idx = available.index[-1]
179
+
180
+ close = _safe(row["close"])
181
+ if close == 0:
182
+ return {}
183
+
184
+ # ── momentum / returns ──
185
+ def _pct_change(lookback_days: int) -> float:
186
+ past = df[df["_date"] <= (target_date - timedelta(days=lookback_days))]
187
+ if past.empty:
188
+ return 0.0
189
+ past_close = _safe(past.iloc[-1]["close"])
190
+ return round((close - past_close) / past_close * 100, 4) if past_close else 0.0
191
+
192
+ ret_1d = _pct_change(1)
193
+ ret_7d = _pct_change(7)
194
+ ret_14d = _pct_change(14)
195
+ ret_30d = _pct_change(30)
196
+ ret_60d = _pct_change(60)
197
+ momentum_score = float(np.sign(ret_7d) + np.sign(ret_30d)) # -2 to +2
198
+
199
+ # ── technical indicators from pandas-ta ──
200
+ rsi = _safe(row.get("RSI_14"))
201
+
202
+ macd = _safe(row.get("MACD_12_26_9"))
203
+ macd_signal_line = _safe(row.get("MACDs_12_26_9"))
204
+ macd_signal = 1 if macd > macd_signal_line else (-1 if macd < macd_signal_line else 0)
205
+
206
+ bb_lower = _safe(row.get("BBL_20_2.0"))
207
+ bb_upper = _safe(row.get("BBU_20_2.0"))
208
+ bb_range = bb_upper - bb_lower
209
+ bb_position = ((close - bb_lower) / bb_range) if bb_range > 0 else 0.5
210
+
211
+ atr = _safe(row.get("ATRr_14"))
212
+ atr_pct = (atr / close * 100) if close > 0 else 0.0
213
+
214
+ sma20 = _safe(row.get("SMA_20"))
215
+ sma50 = _safe(row.get("SMA_50"))
216
+ sma_20_50_cross = 1 if sma20 > sma50 else -1
217
+
218
+ # ── seasonality ──
219
+ month = target_date.month
220
+ day_of_week = target_date.weekday() # 0=Monday
221
+ month_sin = float(np.sin(2 * np.pi * month / 12))
222
+ month_cos = float(np.cos(2 * np.pi * month / 12))
223
+ harvest_flag = _harvest_season_flag(symbol, month)
224
+
225
+ # Oil/gas: days to next OPEC meeting
226
+ days_opec = _days_to_next_opec(target_date) if symbol in ("CL=F", "NG=F") else 0
227
+
228
+ # ── cross-commodity features ──
229
+ latest_prices = _load_all_prices_latest()
230
+ cl_price = latest_prices.get("CL=F", 0)
231
+ gc_price = latest_prices.get("GC=F", 0)
232
+ oil_gold_ratio = round(cl_price / gc_price, 6) if gc_price > 0 else 0.0
233
+
234
+ # DXY proxy: inverted gold price normalised (gold up → USD weak)
235
+ gc_hist_mean = df["close"].mean() if not df.empty else 1.0
236
+ dxy_proxy = round(1 - (gc_price / gc_hist_mean) if gc_hist_mean > 0 else 0.5, 4)
237
+
238
+ return {
239
+ # Technical
240
+ "rsi_14": round(rsi, 4),
241
+ "macd_signal": macd_signal,
242
+ "bb_position": round(bb_position, 4),
243
+ "atr_14": round(atr, 4),
244
+ "atr_pct": round(atr_pct, 4),
245
+ "sma_20_50_cross": sma_20_50_cross,
246
+ # Momentum
247
+ "return_1d": ret_1d,
248
+ "return_7d": ret_7d,
249
+ "return_14d": ret_14d,
250
+ "return_30d": ret_30d,
251
+ "return_60d": ret_60d,
252
+ "momentum_score": momentum_score,
253
+ # Seasonality
254
+ "month_sin": round(month_sin, 4),
255
+ "month_cos": round(month_cos, 4),
256
+ "day_of_week": day_of_week,
257
+ "harvest_season_flag": harvest_flag,
258
+ "days_to_opec_meeting": days_opec,
259
+ # Cross-commodity
260
+ "oil_gold_ratio": oil_gold_ratio,
261
+ "dxy_proxy": dxy_proxy,
262
+ }
263
+
264
+
265
+ def build_feature_matrix(
266
+ symbol: str,
267
+ start_date: str,
268
+ end_date: str,
269
+ ) -> pd.DataFrame:
270
+ """
271
+ Build a feature matrix for model training — one row per trading day.
272
+
273
+ Args:
274
+ symbol: Commodity ticker
275
+ start_date: ISO date string "YYYY-MM-DD"
276
+ end_date: ISO date string "YYYY-MM-DD"
277
+
278
+ Returns:
279
+ DataFrame with one row per date, all price feature columns.
280
+ Does NOT include target variable — caller adds that.
281
+ """
282
+ start = datetime.strptime(start_date, "%Y-%m-%d").date()
283
+ end = datetime.strptime(end_date, "%Y-%m-%d").date()
284
+
285
+ # Load full price history once
286
+ df_prices = _load_prices(symbol, days=(end - start).days + 500)
287
+ if df_prices.empty:
288
+ return pd.DataFrame()
289
+
290
+ df_prices = _compute_ta_features(df_prices)
291
+ df_prices["_date"] = df_prices["date"].dt.date
292
+
293
+ latest_prices = _load_all_prices_latest()
294
+ cl_price = latest_prices.get("CL=F", 0)
295
+ gc_price = latest_prices.get("GC=F", 0)
296
+ gc_hist_mean = df_prices["close"].mean() if not df_prices.empty else 1.0
297
+
298
+ rows: list[dict] = []
299
+ for _, price_row in df_prices.iterrows():
300
+ row_date = price_row["_date"]
301
+ if not (start <= row_date <= end):
302
+ continue
303
+
304
+ close = _safe(price_row["close"])
305
+ if close == 0:
306
+ continue
307
+
308
+ # Returns — look back within df_prices to avoid reloading
309
+ def _ret(days: int) -> float:
310
+ past = df_prices[df_prices["_date"] <= (row_date - timedelta(days=days))]
311
+ if past.empty:
312
+ return 0.0
313
+ pc = _safe(past.iloc[-1]["close"])
314
+ return round((close - pc) / pc * 100, 4) if pc else 0.0
315
+
316
+ ret_1d = _ret(1)
317
+ ret_7d = _ret(7)
318
+ ret_14d = _ret(14)
319
+ ret_30d = _ret(30)
320
+ ret_60d = _ret(60)
321
+
322
+ rsi = _safe(price_row.get("RSI_14"))
323
+ macd = _safe(price_row.get("MACD_12_26_9"))
324
+ macd_sig = _safe(price_row.get("MACDs_12_26_9"))
325
+ bb_lower = _safe(price_row.get("BBL_20_2.0"))
326
+ bb_upper = _safe(price_row.get("BBU_20_2.0"))
327
+ bb_range = bb_upper - bb_lower
328
+ atr = _safe(price_row.get("ATRr_14"))
329
+ sma20 = _safe(price_row.get("SMA_20"))
330
+ sma50 = _safe(price_row.get("SMA_50"))
331
+
332
+ month = row_date.month
333
+ rows.append({
334
+ "date": row_date,
335
+ "rsi_14": round(rsi, 4),
336
+ "macd_signal": 1 if macd > macd_sig else (-1 if macd < macd_sig else 0),
337
+ "bb_position": round((close - bb_lower) / bb_range, 4) if bb_range > 0 else 0.5,
338
+ "atr_14": round(atr, 4),
339
+ "atr_pct": round(atr / close * 100, 4) if close > 0 else 0.0,
340
+ "sma_20_50_cross": 1 if sma20 > sma50 else -1,
341
+ "return_1d": ret_1d,
342
+ "return_7d": ret_7d,
343
+ "return_14d": ret_14d,
344
+ "return_30d": ret_30d,
345
+ "return_60d": ret_60d,
346
+ "momentum_score": float(np.sign(ret_7d) + np.sign(ret_30d)),
347
+ "month_sin": round(float(np.sin(2 * np.pi * month / 12)), 4),
348
+ "month_cos": round(float(np.cos(2 * np.pi * month / 12)), 4),
349
+ "day_of_week": row_date.weekday(),
350
+ "harvest_season_flag": _harvest_season_flag(symbol, month),
351
+ "days_to_opec_meeting": _days_to_next_opec(row_date) if symbol in ("CL=F", "NG=F") else 0,
352
+ "oil_gold_ratio": round(cl_price / gc_price, 6) if gc_price > 0 else 0.0,
353
+ "dxy_proxy": round(1 - gc_price / gc_hist_mean, 4) if gc_hist_mean > 0 else 0.5,
354
+ })
355
+
356
+ return pd.DataFrame(rows)
357
+
358
+
359
+ if __name__ == "__main__":
360
+ parser = argparse.ArgumentParser()
361
+ parser.add_argument("--symbol", default="GC=F")
362
+ parser.add_argument("--date", default=None)
363
+ args = parser.parse_args()
364
+ features = get_price_features(args.symbol, args.date)
365
+ print(json.dumps(features, indent=2))
signals/weather_features.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Weather Features — thin wrapper that surfaces weather_features table data
3
+ as commodity-specific signals for the feature builder.
4
+
5
+ The heavy lifting (fetching + engineering drought_index etc.) is done in
6
+ data/collector_weather.py. This module just shapes the data for ML consumption.
7
+ """
8
+
9
+ import sys
10
+ from datetime import date, timedelta
11
+ from pathlib import Path
12
+
13
+ import pandas as pd
14
+
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+ from data.db import get_conn
17
+
18
+ # Which regions matter most per commodity
19
+ COMMODITY_REGIONS: dict[str, list[str]] = {
20
+ "CL=F": ["middle_east_oil", "texas_energy"],
21
+ "NG=F": ["texas_energy", "middle_east_oil"],
22
+ "GC=F": ["south_africa_gold"],
23
+ "ZW=F": ["black_sea_ukraine"],
24
+ "ZC=F": ["us_corn_belt", "black_sea_ukraine"],
25
+ "ZS=F": ["us_corn_belt", "brazil_soy"],
26
+ "CT=F": ["india_monsoon", "brazil_soy"],
27
+ "SB=F": ["india_monsoon", "brazil_soy"],
28
+ "USDINR=X":["india_monsoon"],
29
+ "HG=F": ["chile_copper"],
30
+ }
31
+
32
+
33
+ def get_weather_features(commodity: str, days: int = 90) -> dict:
34
+ """
35
+ Return the latest aggregated weather signals for a commodity.
36
+
37
+ Averages drought_index, heat_stress_days, and precip_anomaly_pct across
38
+ the commodity's primary regions over the last 30 days.
39
+
40
+ Args:
41
+ commodity: Ticker symbol, e.g. "ZW=F"
42
+ days: Look-back window in calendar days (used for region filter)
43
+
44
+ Returns:
45
+ Dict with keys: drought_index, heat_stress_days, precip_anomaly_pct.
46
+ Returns zeros if no data found.
47
+ """
48
+ regions = COMMODITY_REGIONS.get(commodity, [])
49
+ if not regions:
50
+ return {"drought_index": 0.0, "heat_stress_days": 0, "precip_anomaly_pct": 0.0}
51
+
52
+ cutoff = date.today() - timedelta(days=30) # use last 30 days for signal
53
+ placeholders = ",".join(["?"] * len(regions))
54
+
55
+ conn = get_conn()
56
+ df = conn.execute(
57
+ f"""
58
+ SELECT drought_index, heat_stress_days, precip_anomaly_pct
59
+ FROM weather_features
60
+ WHERE commodity = ?
61
+ AND region IN ({placeholders})
62
+ AND date >= ?
63
+ """,
64
+ [commodity] + regions + [cutoff],
65
+ ).df()
66
+ conn.close()
67
+
68
+ if df.empty:
69
+ return {"drought_index": 0.0, "heat_stress_days": 0, "precip_anomaly_pct": 0.0}
70
+
71
+ return {
72
+ "drought_index": round(float(df["drought_index"].mean()), 4),
73
+ "heat_stress_days": int(df["heat_stress_days"].mean()),
74
+ "precip_anomaly_pct": round(float(df["precip_anomaly_pct"].mean()), 2),
75
+ }
76
+
77
+
78
+ def get_weather_dataframe(commodity: str, days: int = 90) -> pd.DataFrame:
79
+ """
80
+ Return time-series weather data for a commodity (all relevant regions).
81
+ Used by the feature builder to join weather signals into the training matrix.
82
+ """
83
+ regions = COMMODITY_REGIONS.get(commodity, [])
84
+ if not regions:
85
+ return pd.DataFrame()
86
+
87
+ cutoff = date.today() - timedelta(days=days)
88
+ placeholders = ",".join(["?"] * len(regions))
89
+
90
+ conn = get_conn()
91
+ df = conn.execute(
92
+ f"""
93
+ SELECT date, region,
94
+ drought_index, heat_stress_days, precip_anomaly_pct
95
+ FROM weather_features
96
+ WHERE commodity = ?
97
+ AND region IN ({placeholders})
98
+ AND date >= ?
99
+ ORDER BY date
100
+ """,
101
+ [commodity] + regions + [cutoff],
102
+ ).df()
103
+ conn.close()
104
+
105
+ if df.empty:
106
+ return df
107
+
108
+ # Average across regions per date
109
+ return (
110
+ df.groupby("date")
111
+ .agg(
112
+ drought_index=("drought_index", "mean"),
113
+ heat_stress_days=("heat_stress_days", "mean"),
114
+ precip_anomaly_pct=("precip_anomaly_pct", "mean"),
115
+ )
116
+ .reset_index()
117
+ .sort_values("date")
118
+ )