jtlevine's picture
Complete Claude agents, Chronos-2 forecasting, and frontend polish
e837c6a
"""
Claude extraction agent for the Market Intelligence pipeline.
Normalizes heterogeneous price data from Agmarknet and eNAM into
canonical commodity IDs, standardized units, and validated dates.
Flags stale entries and anomalies. Falls back to rule-based regex
extraction when the Anthropic API is unavailable.
"""
from __future__ import annotations
import json
import logging
import math
import os
from dataclasses import dataclass, field
from typing import Any
from config import COMMODITY_MAP, COMMODITIES, MANDI_MAP, MANDIS
log = logging.getLogger(__name__)
# ── Output dataclass ────────────────────────────────────────────────────
@dataclass
class ExtractionResult:
"""Structured extraction output for a single mandi."""
mandi_id: str
normalized_prices: list[dict] = field(default_factory=list)
stale_entries: list[dict] = field(default_factory=list)
anomalies: list[dict] = field(default_factory=list)
commodity_mappings: dict = field(default_factory=dict)
extraction_method: str = "rule_based" # "claude" | "rule_based"
confidence: float = 0.0
tokens_used: int = 0
# ── Commodity name mapping (canonical aliases) ──────────────────────────
COMMODITY_ALIASES: dict[str, str] = {
# Rice variants
"paddy(samba)": "RICE-SAMBA",
"paddy samba": "RICE-SAMBA",
"samba paddy": "RICE-SAMBA",
"rice(paddy)": "RICE-SAMBA",
"rice paddy": "RICE-SAMBA",
"paddy": "RICE-SAMBA",
"rice": "RICE-SAMBA",
"samba": "RICE-SAMBA",
# Groundnut variants
"groundnut": "GNUT-POD",
"groundnut pods": "GNUT-POD",
"groundnut pods (raw)": "GNUT-POD",
"moongphali": "GNUT-POD",
"peanut": "GNUT-POD",
"groundnut(pods)": "GNUT-POD",
# Turmeric
"turmeric": "TUR-FIN",
"turmeric(finger)": "TUR-FIN",
"haldi": "TUR-FIN",
"turmeric finger": "TUR-FIN",
# Cotton
"cotton": "COT-MCU",
"cotton(kapas)": "COT-MCU",
"kapas": "COT-MCU",
"cotton kapas": "COT-MCU",
# Onion
"onion": "ONI-RED",
"onion red": "ONI-RED",
"vengayam": "ONI-RED",
# Coconut/Copra
"copra": "COP-DRY",
"coconut": "COP-DRY",
"coconut(copra)": "COP-DRY",
"copra(dry)": "COP-DRY",
# Maize
"maize": "MZE-YEL",
"maize(yellow)": "MZE-YEL",
"corn": "MZE-YEL",
"makka": "MZE-YEL",
# Black gram
"urad": "URD-BLK",
"urad dal": "URD-BLK",
"urad (black gram)": "URD-BLK",
"black gram": "URD-BLK",
"blackgram": "URD-BLK",
# Green gram
"moong": "MNG-GRN",
"moong(green gram)": "MNG-GRN",
"green gram": "MNG-GRN",
"greengram": "MNG-GRN",
"moong dal": "MNG-GRN",
# Banana
"banana": "BAN-ROB",
"banana(robusta)": "BAN-ROB",
"vazhai": "BAN-ROB",
}
# ── Claude tool definitions ─────────────────────────────────────────────
TOOLS = [
{
"name": "parse_agmarknet_entry",
"description": (
"Normalize an Agmarknet price record: map commodity name variants to "
"canonical ID, standardize units to per-quintal, validate date format. "
"Returns normalized record with canonical commodity_id."
),
"input_schema": {
"type": "object",
"properties": {
"raw_commodity_name": {"type": "string"},
"price_rs": {"type": "number"},
"unit": {"type": "string"},
"date_str": {"type": "string"},
"mandi_name": {"type": "string"},
},
"required": ["raw_commodity_name", "price_rs"],
},
},
{
"name": "parse_enam_listing",
"description": (
"Parse eNAM scraped data, handling the difference between "
"last traded price (eNAM) vs modal price (Agmarknet). "
"Returns normalized record."
),
"input_schema": {
"type": "object",
"properties": {
"raw_commodity_name": {"type": "string"},
"last_traded_price_rs": {"type": "number"},
"lot_size_quintals": {"type": "number"},
"trade_date": {"type": "string"},
"mandi_name": {"type": "string"},
},
"required": ["raw_commodity_name", "last_traded_price_rs"],
},
},
{
"name": "detect_stale_data",
"description": (
"Flag entries where price hasn't changed in 3+ consecutive days. "
"This is a common copy-paste artifact in mandi price reporting."
),
"input_schema": {
"type": "object",
"properties": {
"price_series": {
"type": "array",
"items": {"type": "object"},
"description": "Array of {date, price} objects sorted by date.",
},
"commodity_id": {"type": "string"},
"mandi_id": {"type": "string"},
},
"required": ["price_series"],
},
},
{
"name": "normalize_commodity",
"description": (
"Map a variant commodity name to the canonical taxonomy. "
"Handle: 'Groundnut' vs 'Groundnut Pods' vs 'Moongphali', "
"'Paddy' vs 'Paddy(Samba)' vs 'Rice(Paddy)', etc."
),
"input_schema": {
"type": "object",
"properties": {
"raw_name": {"type": "string"},
},
"required": ["raw_name"],
},
},
{
"name": "flag_anomalies",
"description": (
"Identify prices >3 standard deviations from 30-day rolling mean "
"for same mandi/commodity pair. Returns flagged anomalies."
),
"input_schema": {
"type": "object",
"properties": {
"price_series": {
"type": "array",
"items": {"type": "object"},
},
"commodity_id": {"type": "string"},
"mandi_id": {"type": "string"},
},
"required": ["price_series"],
},
},
]
SYSTEM_PROMPT = (
"You are a market data extraction agent for Tamil Nadu agricultural prices. "
"Normalize heterogeneous price records from Agmarknet (government mandi data) "
"and eNAM (electronic trading platform) into canonical commodity IDs and "
"standardized units. Flag stale or anomalous entries.\n\n"
"Commodity taxonomy:\n"
+ "\n".join(f" {c['id']}: {c['name']} ({c['agmarknet_name']})" for c in COMMODITIES)
+ "\n\nWhen normalizing commodity names, always map to the closest canonical ID. "
"Common variants: 'Paddy(Samba)' -> RICE-SAMBA, 'Groundnut pods' -> GNUT-POD, "
"'Cotton(Kapas)' -> COT-MCU, 'Urad' -> URD-BLK."
)
# ── Tool execution (local logic) ────────────────────────────────────────
def _execute_tool(tool_name: str, tool_input: dict) -> dict:
"""Execute a tool call locally, returning structured results."""
if tool_name == "parse_agmarknet_entry":
return _tool_parse_agmarknet(tool_input)
elif tool_name == "parse_enam_listing":
return _tool_parse_enam(tool_input)
elif tool_name == "detect_stale_data":
return _tool_detect_stale(tool_input)
elif tool_name == "normalize_commodity":
return _tool_normalize_commodity(tool_input)
elif tool_name == "flag_anomalies":
return _tool_flag_anomalies(tool_input)
else:
return {"error": f"Unknown tool: {tool_name}"}
def _tool_parse_agmarknet(inp: dict) -> dict:
"""Normalize an Agmarknet price entry."""
raw_name = inp.get("raw_commodity_name", "")
commodity_id = _match_commodity(raw_name)
price = inp.get("price_rs", 0)
unit = inp.get("unit", "quintal").lower()
# Unit conversion
if unit == "tonne" or unit == "mt":
price = price / 10 # convert tonne to quintal
return {
"commodity_id": commodity_id,
"original_name": raw_name,
"price_rs_per_quintal": price,
"unit_standardized": "quintal",
"valid": commodity_id is not None,
}
def _tool_parse_enam(inp: dict) -> dict:
"""Normalize an eNAM listing."""
raw_name = inp.get("raw_commodity_name", "")
commodity_id = _match_commodity(raw_name)
last_traded = inp.get("last_traded_price_rs", 0)
return {
"commodity_id": commodity_id,
"original_name": raw_name,
"price_rs_per_quintal": last_traded,
"price_type": "last_traded",
"note": "eNAM reports last traded price, not modal. May be higher than Agmarknet modal.",
"valid": commodity_id is not None,
}
def _tool_detect_stale(inp: dict) -> dict:
"""Detect stale (unchanged) prices in a series."""
series = inp.get("price_series", [])
stale_runs = []
if len(series) < 3:
return {"stale_entries": [], "note": "Series too short to detect staleness."}
# Sort by date
sorted_series = sorted(series, key=lambda x: x.get("date", ""))
current_run = [sorted_series[0]]
for i in range(1, len(sorted_series)):
if sorted_series[i].get("price") == sorted_series[i - 1].get("price"):
current_run.append(sorted_series[i])
else:
if len(current_run) >= 3:
stale_runs.append({
"start_date": current_run[0].get("date"),
"end_date": current_run[-1].get("date"),
"consecutive_days": len(current_run),
"price": current_run[0].get("price"),
})
current_run = [sorted_series[i]]
if len(current_run) >= 3:
stale_runs.append({
"start_date": current_run[0].get("date"),
"end_date": current_run[-1].get("date"),
"consecutive_days": len(current_run),
"price": current_run[0].get("price"),
})
return {"stale_entries": stale_runs, "total_stale_runs": len(stale_runs)}
def _tool_normalize_commodity(inp: dict) -> dict:
"""Map a variant name to canonical commodity ID."""
raw_name = inp.get("raw_name", "")
commodity_id = _match_commodity(raw_name)
commodity = COMMODITY_MAP.get(commodity_id) if commodity_id else None
return {
"raw_name": raw_name,
"commodity_id": commodity_id,
"canonical_name": commodity["name"] if commodity else None,
"category": commodity["category"] if commodity else None,
"match_confidence": 0.95 if commodity_id else 0.0,
}
def _tool_flag_anomalies(inp: dict) -> dict:
"""Flag prices >3 sigma from rolling mean."""
series = inp.get("price_series", [])
if len(series) < 10:
return {"anomalies": [], "note": "Series too short for anomaly detection."}
prices = [s.get("price", 0) for s in sorted(series, key=lambda x: x.get("date", ""))]
anomalies = []
window = 30
for i in range(window, len(prices)):
window_prices = prices[max(0, i - window):i]
mean = sum(window_prices) / len(window_prices)
variance = sum((p - mean) ** 2 for p in window_prices) / len(window_prices)
std = math.sqrt(variance) if variance > 0 else 1
if abs(prices[i] - mean) > 3 * std:
anomalies.append({
"index": i,
"date": series[i].get("date") if i < len(series) else None,
"price": prices[i],
"rolling_mean": round(mean, 0),
"rolling_std": round(std, 0),
"z_score": round((prices[i] - mean) / std, 2),
})
return {"anomalies": anomalies, "total_anomalies": len(anomalies)}
# ── Commodity matching ───────────────────────────────────────────────────
# Pre-sorted alias keys (longest first) — avoids re-sorting on every call
_SORTED_ALIASES = sorted(COMMODITY_ALIASES.keys(), key=len, reverse=True)
def _match_commodity(raw_name: str) -> str | None:
"""Match a raw commodity name to canonical ID."""
if not raw_name:
return None
name_lower = raw_name.lower().strip()
# Direct alias match
if name_lower in COMMODITY_ALIASES:
return COMMODITY_ALIASES[name_lower]
# Substring match (longest alias first)
for alias in _SORTED_ALIASES:
if alias in name_lower:
return COMMODITY_ALIASES[alias]
# Try matching against canonical names
for c in COMMODITIES:
if c["name"].lower() in name_lower or name_lower in c["name"].lower():
return c["id"]
if c["agmarknet_name"].lower() in name_lower or name_lower in c["agmarknet_name"].lower():
return c["id"]
return None
# ── Rule-based fallback ─────────────────────────────────────────────────
class RuleBasedExtractor:
"""Regex-based extraction when Claude is unavailable."""
@classmethod
def extract_prices(cls, price_records: list[dict], mandi_id: str) -> ExtractionResult:
"""Normalize and validate a list of price records for a mandi."""
result = ExtractionResult(mandi_id=mandi_id, extraction_method="rule_based")
mappings: dict[str, str] = {}
for rec in price_records:
# Normalize commodity name
raw_name = rec.get("commodity_name", rec.get("commodity_id", ""))
commodity_id = rec.get("commodity_id")
if commodity_id not in COMMODITY_MAP:
commodity_id = _match_commodity(raw_name)
if commodity_id:
mappings[raw_name] = commodity_id
if commodity_id is None:
continue
normalized = {
"mandi_id": mandi_id,
"commodity_id": commodity_id,
"date": rec.get("date"),
"min_price_rs": rec.get("min_price_rs", 0),
"max_price_rs": rec.get("max_price_rs", 0),
"modal_price_rs": rec.get("modal_price_rs", 0),
"arrivals_tonnes": rec.get("arrivals_tonnes", 0),
"source": rec.get("source", "unknown"),
"quality_flag": rec.get("quality_flag", "good"),
}
result.normalized_prices.append(normalized)
# Detect stale entries
cls._detect_stale_entries(result)
# Detect anomalies
cls._detect_anomalies(result)
result.commodity_mappings = mappings
result.confidence = 0.75 if result.normalized_prices else 0.3
return result
@classmethod
def _detect_stale_entries(cls, result: ExtractionResult):
"""Flag entries where price hasn't changed for 3+ days."""
from collections import defaultdict
by_commodity: dict[str, list[dict]] = defaultdict(list)
for p in result.normalized_prices:
by_commodity[p["commodity_id"]].append(p)
for commodity_id, prices in by_commodity.items():
sorted_prices = sorted(prices, key=lambda x: x.get("date", ""))
current_run = [sorted_prices[0]] if sorted_prices else []
for i in range(1, len(sorted_prices)):
if sorted_prices[i]["modal_price_rs"] == sorted_prices[i - 1]["modal_price_rs"]:
current_run.append(sorted_prices[i])
else:
if len(current_run) >= 3:
for entry in current_run:
entry["quality_flag"] = "stale"
result.stale_entries.append(entry)
current_run = [sorted_prices[i]]
if len(current_run) >= 3:
for entry in current_run:
entry["quality_flag"] = "stale"
result.stale_entries.append(entry)
@classmethod
def _detect_anomalies(cls, result: ExtractionResult):
"""Flag prices >3 sigma from 30-day rolling mean."""
from collections import defaultdict
by_commodity: dict[str, list[dict]] = defaultdict(list)
for p in result.normalized_prices:
by_commodity[p["commodity_id"]].append(p)
window = 30
for commodity_id, prices in by_commodity.items():
sorted_prices = sorted(prices, key=lambda x: x.get("date", ""))
modal_prices = [p["modal_price_rs"] for p in sorted_prices]
# Need at least 10 data points before we can start detecting
if len(modal_prices) < 10:
continue
for i in range(len(sorted_prices)):
# Build the rolling window from preceding entries
window_start = max(0, i - window)
window_vals = modal_prices[window_start:i]
# Need sufficient history to compute meaningful stats
if len(window_vals) < 10:
continue
mean = sum(window_vals) / len(window_vals)
variance = sum((v - mean) ** 2 for v in window_vals) / len(window_vals)
std = math.sqrt(variance) if variance > 0 else 1
z_score = (modal_prices[i] - mean) / std
if abs(z_score) > 3:
sorted_prices[i]["quality_flag"] = "anomalous"
result.anomalies.append({
**sorted_prices[i],
"z_score": round(z_score, 2),
"rolling_mean": round(mean, 0),
})
# ── Claude agent ────────────────────────────────────────────────────────
class ExtractionAgent:
"""Multi-round Claude tool-use agent for market data extraction.
Falls back to RuleBasedExtractor when the Anthropic API is unavailable
or ANTHROPIC_API_KEY is not set.
"""
MAX_ROUNDS = 6
def __init__(self, model: str = "claude-sonnet-4-20250514"):
self.model = model
self._client = None
self._fallback = RuleBasedExtractor()
def _get_client(self):
"""Lazy-init the Anthropic client."""
if self._client is not None:
return self._client
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
log.warning("ANTHROPIC_API_KEY not set -- using rule-based fallback")
return None
try:
import anthropic
self._client = anthropic.Anthropic(api_key=api_key)
return self._client
except ImportError:
log.warning("anthropic package not installed -- using rule-based fallback")
return None
def extract(
self,
mandi_id: str,
agmarknet_records: list[dict] | None = None,
enam_records: list[dict] | None = None,
) -> ExtractionResult:
"""Run extraction for a single mandi.
Attempts Claude agent loop first; falls back to regex if unavailable.
"""
client = self._get_client()
if client is not None:
return self._claude_extract(client, mandi_id, agmarknet_records, enam_records)
return self._rule_based_extract(mandi_id, agmarknet_records, enam_records)
def _claude_extract(
self,
client: Any,
mandi_id: str,
agmarknet_records: list[dict] | None,
enam_records: list[dict] | None,
) -> ExtractionResult:
"""Multi-round tool-use loop with Claude."""
result = ExtractionResult(mandi_id=mandi_id, extraction_method="claude")
tools_used: list[str] = []
total_tokens = 0
# Collect tool results across rounds for post-loop aggregation
collected_normalized: list[dict] = []
collected_stale: list[dict] = []
collected_anomalies: list[dict] = []
collected_mappings: dict[str, str] = {}
mandi = MANDI_MAP.get(mandi_id)
parts = [f"Extract and normalize price data for mandi {mandi_id}"]
if mandi:
parts.append(f"({mandi.name}, {mandi.district}, type={mandi.market_type})")
if agmarknet_records:
parts.append(f"\n--- AGMARKNET RECORDS ({len(agmarknet_records)} entries) ---")
for rec in agmarknet_records[:20]:
parts.append(json.dumps(rec, default=str))
if enam_records:
parts.append(f"\n--- eNAM RECORDS ({len(enam_records)} entries) ---")
for rec in enam_records[:20]:
parts.append(json.dumps(rec, default=str))
parts.append(
"\nUse the available tools to:\n"
"1. Call normalize_commodity for each unique commodity name\n"
"2. Call parse_agmarknet_entry / parse_enam_listing for each record\n"
"3. Call detect_stale_data for each mandi/commodity price series\n"
"4. Call flag_anomalies for each mandi/commodity price series\n"
"Return normalized prices in JSON."
)
messages: list[dict] = [{"role": "user", "content": "\n".join(parts)}]
for round_num in range(self.MAX_ROUNDS):
try:
response = client.messages.create(
model=self.model,
max_tokens=4096,
system=SYSTEM_PROMPT,
tools=TOOLS,
messages=messages,
)
except Exception as e:
log.error("Claude API error on round %d: %s", round_num, e)
return self._rule_based_extract(mandi_id, agmarknet_records, enam_records)
if hasattr(response, "usage"):
total_tokens += getattr(response.usage, "input_tokens", 0)
total_tokens += getattr(response.usage, "output_tokens", 0)
tool_calls = []
for block in response.content:
if block.type == "tool_use":
tool_calls.append(block)
tools_used.append(block.name)
if response.stop_reason == "end_turn" or not tool_calls:
break
messages.append({"role": "assistant", "content": response.content})
tool_results = []
for tc in tool_calls:
tool_result = _execute_tool(tc.name, tc.input)
tool_results.append({
"type": "tool_result",
"tool_use_id": tc.id,
"content": json.dumps(tool_result),
})
# Accumulate structured data from each tool's output
self._accumulate_tool_result(
tc.name, tc.input, tool_result, mandi_id,
collected_normalized, collected_stale,
collected_anomalies, collected_mappings,
)
messages.append({"role": "user", "content": tool_results})
# If Claude called parse tools, use accumulated results directly.
# If Claude only called normalize/detect/flag tools but not parse,
# build normalized prices from raw records + commodity mappings.
if collected_normalized:
result.normalized_prices = collected_normalized
else:
# Claude may have only used normalize_commodity + detect/flag.
# Build normalized prices from raw input using the mappings Claude found.
all_records = (agmarknet_records or []) + (enam_records or [])
for rec in all_records:
raw_name = rec.get("commodity_name", rec.get("commodity_id", ""))
commodity_id = rec.get("commodity_id")
# Try Claude's mappings first, then our local matcher
if commodity_id not in COMMODITY_MAP:
commodity_id = collected_mappings.get(raw_name) or _match_commodity(raw_name)
if commodity_id is None:
continue
result.normalized_prices.append({
"mandi_id": mandi_id,
"commodity_id": commodity_id,
"date": rec.get("date"),
"min_price_rs": rec.get("min_price_rs", 0),
"max_price_rs": rec.get("max_price_rs", 0),
"modal_price_rs": rec.get("modal_price_rs", 0),
"arrivals_tonnes": rec.get("arrivals_tonnes", 0),
"source": rec.get("source", "unknown"),
"quality_flag": rec.get("quality_flag", "good"),
})
result.stale_entries = collected_stale
result.anomalies = collected_anomalies
result.commodity_mappings = collected_mappings
result.tokens_used = total_tokens
result.confidence = 0.85 if result.normalized_prices else 0.5
return result
@staticmethod
def _accumulate_tool_result(
tool_name: str,
tool_input: dict,
tool_result: dict,
mandi_id: str,
collected_normalized: list[dict],
collected_stale: list[dict],
collected_anomalies: list[dict],
collected_mappings: dict[str, str],
):
"""Parse a single tool call result into the appropriate collection."""
if tool_name == "parse_agmarknet_entry":
cid = tool_result.get("commodity_id")
if cid and tool_result.get("valid"):
collected_normalized.append({
"mandi_id": mandi_id,
"commodity_id": cid,
"date": tool_input.get("date_str"),
"min_price_rs": 0,
"max_price_rs": 0,
"modal_price_rs": tool_result.get("price_rs_per_quintal", 0),
"arrivals_tonnes": 0,
"source": "agmarknet",
"quality_flag": "good",
})
raw = tool_input.get("raw_commodity_name", "")
if raw:
collected_mappings[raw] = cid
elif tool_name == "parse_enam_listing":
cid = tool_result.get("commodity_id")
if cid and tool_result.get("valid"):
collected_normalized.append({
"mandi_id": mandi_id,
"commodity_id": cid,
"date": tool_input.get("trade_date"),
"min_price_rs": 0,
"max_price_rs": 0,
"modal_price_rs": tool_result.get("price_rs_per_quintal", 0),
"arrivals_tonnes": 0,
"source": "enam",
"quality_flag": "good",
})
raw = tool_input.get("raw_commodity_name", "")
if raw:
collected_mappings[raw] = cid
elif tool_name == "normalize_commodity":
cid = tool_result.get("commodity_id")
raw = tool_result.get("raw_name", "")
if cid and raw:
collected_mappings[raw] = cid
elif tool_name == "detect_stale_data":
stale_entries = tool_result.get("stale_entries", [])
cid = tool_input.get("commodity_id", "unknown")
mid = tool_input.get("mandi_id", mandi_id)
for entry in stale_entries:
collected_stale.append({
"mandi_id": mid,
"commodity_id": cid,
"start_date": entry.get("start_date"),
"end_date": entry.get("end_date"),
"consecutive_days": entry.get("consecutive_days"),
"price": entry.get("price"),
"quality_flag": "stale",
})
elif tool_name == "flag_anomalies":
anomalies = tool_result.get("anomalies", [])
cid = tool_input.get("commodity_id", "unknown")
mid = tool_input.get("mandi_id", mandi_id)
for entry in anomalies:
collected_anomalies.append({
"mandi_id": mid,
"commodity_id": cid,
"date": entry.get("date"),
"price": entry.get("price"),
"z_score": entry.get("z_score"),
"rolling_mean": entry.get("rolling_mean"),
"rolling_std": entry.get("rolling_std"),
"quality_flag": "anomalous",
})
def _rule_based_extract(
self,
mandi_id: str,
agmarknet_records: list[dict] | None,
enam_records: list[dict] | None,
) -> ExtractionResult:
"""Fallback extraction using regex-based approach."""
all_records = []
if agmarknet_records:
all_records.extend(agmarknet_records)
if enam_records:
all_records.extend(enam_records)
return self._fallback.extract_prices(all_records, mandi_id)