jtlevine's picture
Phase 1.4: region-aware recommendation agent + Swahili translation
7afee5d
"""
Neon PostgreSQL persistence layer for Market Intelligence.
Stores pipeline runs, market prices, price forecasts, sell recommendations,
and agent traces. Falls back gracefully when DATABASE_URL is not set --
the app runs in demo mode with in-memory data.
Tables:
pipeline_runs -- run metadata (status, duration, cost, step results)
market_prices -- reconciled mandi prices by commodity
price_forecasts -- 7/14/30d price predictions with confidence intervals
sell_recommendations -- optimal sell options per farmer
agent_traces -- Claude agent tool call traces
model_metrics -- ML model evaluation metrics per run
delivery_logs -- SMS delivery logs per pipeline run
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import (
Column,
DateTime,
Float,
Integer,
String,
Text,
Boolean,
create_engine,
text,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
log = logging.getLogger(__name__)
DATABASE_URL = os.getenv("DATABASE_URL", "")
_engine = None
_SessionLocal = None
class Base(DeclarativeBase):
pass
class PipelineRun(Base):
__tablename__ = "pipeline_runs"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), unique=True, nullable=False, index=True)
started_at = Column(DateTime(timezone=True), nullable=False)
finished_at = Column(DateTime(timezone=True))
status = Column(String(20), nullable=False)
duration_sec = Column(Float)
total_cost_usd = Column(Float, default=0)
mandis_count = Column(Integer, default=0)
commodities_count = Column(Integer, default=0)
step_results = Column(Text)
errors = Column(Text)
price_conflicts = Column(JSONB)
class MarketPrice(Base):
__tablename__ = "market_prices"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), nullable=False, index=True)
mandi_id = Column(String(20), nullable=False, index=True)
commodity_id = Column(String(20), nullable=False, index=True)
date = Column(String(10))
source = Column(String(100))
price_rs = Column(Float)
arrivals_tonnes = Column(Float)
quality_flag = Column(String(20))
full_data = Column(JSONB)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class PriceForecast(Base):
__tablename__ = "price_forecasts"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), nullable=False, index=True)
mandi_id = Column(String(20), nullable=False, index=True)
commodity_id = Column(String(20), nullable=False, index=True)
forecast_date = Column(String(10))
horizon_days = Column(Integer)
predicted_price = Column(Float)
ci_lower = Column(Float)
ci_upper = Column(Float)
model_type = Column(String(30))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class SellRecommendation(Base):
__tablename__ = "sell_recommendations"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), nullable=False, index=True)
farmer_id = Column(String(20), nullable=False, index=True)
commodity_id = Column(String(20), nullable=False)
best_mandi_id = Column(String(20))
best_timing = Column(String(10))
net_price_rs = Column(Float)
potential_gain_rs = Column(Float)
recommendation_text = Column(Text)
full_data = Column(JSONB)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class AgentTrace(Base):
__tablename__ = "agent_traces"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), nullable=False, index=True)
agent_type = Column(String(30), nullable=False) # extraction, reconciliation, recommendation
mandi_id = Column(String(20), index=True)
tool_calls = Column(Text)
reasoning = Column(Text)
tokens_used = Column(Integer, default=0)
cost_usd = Column(Float, default=0)
duration_sec = Column(Float)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class ModelMetric(Base):
__tablename__ = "model_metrics"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), nullable=False, index=True)
model_name = Column(String(50), nullable=False)
metric_name = Column(String(50), nullable=False)
metric_value = Column(Float, nullable=False)
extra_data = Column(Text)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class DeliveryLog(Base):
__tablename__ = "delivery_logs"
id = Column(Integer, primary_key=True, autoincrement=True)
run_id = Column(String(64), nullable=False, index=True)
farmer_id = Column(Text, nullable=False)
farmer_name = Column(Text)
phone = Column(Text)
channel = Column(String(20), default="console")
sms_text = Column(Text)
sms_text_local = Column(Text)
status = Column(String(20), default="dry_run")
error = Column(Text)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), index=True)
def get_engine():
"""Get or create the SQLAlchemy engine."""
global _engine
if _engine is None:
if not DATABASE_URL:
return None
kwargs: dict[str, Any] = {"pool_pre_ping": True, "pool_timeout": 10}
if DATABASE_URL.startswith("sqlite"):
pass
else:
kwargs["pool_size"] = 2
kwargs["max_overflow"] = 3
kwargs["pool_recycle"] = 300
connect_args: dict[str, Any] = {"connect_timeout": 10}
if "sslmode" not in DATABASE_URL:
connect_args["sslmode"] = "require"
kwargs["connect_args"] = connect_args
_engine = create_engine(DATABASE_URL, **kwargs)
return _engine
def get_session() -> Session | None:
"""Get a database session. Returns None if DB not configured."""
global _SessionLocal
engine = get_engine()
if engine is None:
return None
if _SessionLocal is None:
_SessionLocal = sessionmaker(bind=engine)
return _SessionLocal()
_db_initialized = False
def init_db() -> bool:
"""Create all tables if they don't exist. Idempotent."""
global _db_initialized
if _db_initialized:
return True
engine = get_engine()
if engine is None:
log.info("DATABASE_URL not set -- running without persistence")
return False
try:
Base.metadata.create_all(engine)
_db_initialized = True
log.info("Database tables initialized")
return True
except Exception:
log.exception("Failed to initialize database")
return False
def save_pipeline_run(run_result: dict) -> bool:
"""Persist a pipeline run result to the database."""
session = get_session()
if session is None:
return False
try:
run_info = run_result.get("run_info", {})
run = PipelineRun(
run_id=run_info.get("run_id", f"run-{datetime.now(timezone.utc).isoformat()}"),
started_at=datetime.fromisoformat(run_info["started_at"])
if "started_at" in run_info else datetime.now(timezone.utc),
finished_at=datetime.fromisoformat(run_info["finished_at"])
if "finished_at" in run_info else datetime.now(timezone.utc),
status=run_info.get("status", "ok"),
duration_sec=run_info.get("duration_s", 0),
total_cost_usd=run_info.get("total_cost_usd", 0),
mandis_count=len(run_result.get("mandis", [])),
commodities_count=len(set(
p.get("commodity_id") for p in run_result.get("market_prices", [])
)),
step_results=json.dumps(run_info.get("steps", {})),
errors=json.dumps(run_info.get("errors", [])),
price_conflicts=run_result.get("price_conflicts", []),
)
session.add(run)
# Save market prices (with full data blob)
for mp in run_result.get("market_prices", []):
session.add(MarketPrice(
run_id=run.run_id,
mandi_id=mp.get("mandi_id", ""),
commodity_id=mp.get("commodity_id", ""),
date=mp.get("date", ""),
source=mp.get("source_used", ""),
price_rs=mp.get("price_rs"),
arrivals_tonnes=mp.get("arrivals_tonnes"),
quality_flag=mp.get("quality_flag", ""),
full_data=mp,
))
# Save price forecasts
for fc in run_result.get("price_forecasts", []):
for horizon, key in [(7, "price_7d"), (14, "price_14d"), (30, "price_30d")]:
predicted = fc.get(key)
if predicted:
session.add(PriceForecast(
run_id=run.run_id,
mandi_id=fc.get("mandi_id", ""),
commodity_id=fc.get("commodity_id", ""),
forecast_date=datetime.now(timezone.utc).strftime("%Y-%m-%d"),
horizon_days=horizon,
predicted_price=predicted,
ci_lower=fc.get(f"ci_lower_{horizon}d"),
ci_upper=fc.get(f"ci_upper_{horizon}d"),
model_type=run_result.get("model_metrics", {}).get("model_type", ""),
))
# Save sell recommendations (with full data blob)
#
# Phase 1.4 DB-schema gap: the Python dict carries new fields
# `recommendation_local` and `local_language_code` (Kenya migration,
# Option A rename). Those live inside the `full_data` JSONB blob
# here — the SellRecommendation ORM model intentionally has NO
# dedicated columns for them so this rename required zero SQL
# migration. If a future phase wants indexable columns for
# language-aware queries, add them here and backfill from
# full_data->>'local_language_code'. Intentional scope gap for
# Phase 1.4; follow-up tracked in the LastMileBench Kenya pivot
# notes.
for rec in run_result.get("sell_recommendations", []):
best = rec.get("best_option", {})
session.add(SellRecommendation(
run_id=run.run_id,
farmer_id=rec.get("farmer_id", ""),
commodity_id=rec.get("commodity_id", ""),
best_mandi_id=best.get("mandi_id", ""),
best_timing=best.get("sell_timing", ""),
net_price_rs=best.get("net_price_rs"),
potential_gain_rs=rec.get("potential_gain_rs"),
recommendation_text=rec.get("recommendation_text", ""),
full_data=rec,
))
# Save agent traces
for trace in run_result.get("recommendation_reasoning", []):
session.add(AgentTrace(
run_id=run.run_id,
agent_type="recommendation",
tool_calls=json.dumps(trace.get("reasoning_trace", [])),
reasoning=trace.get("recommendation_en", ""),
tokens_used=trace.get("tokens_used", 0),
))
session.commit()
log.info("Pipeline run %s persisted to database", run.run_id)
return True
except Exception:
session.rollback()
log.exception("Failed to persist pipeline run")
return False
finally:
session.close()
def get_recent_runs(limit: int = 20) -> list[dict]:
"""Fetch recent pipeline runs from the database."""
session = get_session()
if session is None:
return []
try:
runs = (
session.query(PipelineRun)
.order_by(PipelineRun.started_at.desc())
.limit(limit)
.all()
)
return [
{
"run_id": r.run_id,
"started_at": r.started_at.isoformat() if r.started_at else None,
"status": r.status,
"duration_sec": r.duration_sec,
"total_cost_usd": r.total_cost_usd,
"mandis_count": r.mandis_count,
"commodities_count": r.commodities_count,
}
for r in runs
]
except Exception:
log.exception("Failed to fetch pipeline runs")
return []
finally:
session.close()
def save_delivery_logs(run_id: str, logs: list[dict]) -> bool:
"""Bulk insert delivery log entries for a pipeline run."""
session = get_session()
if session is None:
return False
try:
for entry in logs:
session.add(DeliveryLog(
run_id=run_id,
farmer_id=entry.get("farmer_id", ""),
farmer_name=entry.get("farmer_name"),
phone=entry.get("phone"),
channel=entry.get("channel", "console"),
sms_text=entry.get("sms_text"),
sms_text_local=entry.get("sms_text_local"),
status=entry.get("status", "dry_run"),
error=entry.get("error"),
))
session.commit()
log.info("Saved %d delivery logs for run %s", len(logs), run_id)
return True
except Exception:
session.rollback()
log.exception("Failed to save delivery logs")
return False
finally:
session.close()
def get_delivery_logs(limit: int = 50) -> list[dict]:
"""Fetch recent delivery logs from the database."""
session = get_session()
if session is None:
return []
try:
rows = (
session.query(DeliveryLog)
.order_by(DeliveryLog.created_at.desc())
.limit(limit)
.all()
)
return [
{
"id": r.id,
"run_id": r.run_id,
"farmer_id": r.farmer_id,
"farmer_name": r.farmer_name,
"phone": r.phone,
"channel": r.channel,
"sms_text": r.sms_text,
"sms_text_local": r.sms_text_local,
"status": r.status,
"error": r.error,
"created_at": r.created_at.isoformat() if r.created_at else None,
}
for r in rows
]
except Exception:
log.exception("Failed to fetch delivery logs")
return []
finally:
session.close()
def health_check() -> dict:
"""Check database connectivity."""
engine = get_engine()
if engine is None:
return {"status": "not_configured", "message": "DATABASE_URL not set"}
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
return {"status": "ok", "message": "Connected to database"}
except Exception as e:
return {"status": "error", "message": str(e)}