# EDA_Explorer/core/llm_planner.py
# Page header from the repository file view: commit 9eecab5,
# "Fresh deployment with LFS tracking", by ProfessionalMario.
import json
import os
import requests
from pathlib import Path
from utils.logger import logger
# OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
# OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "hf.co/bartowski/gemma-2-2b-it-GGUF:Q5_K_M")
class LLMPlanner:
    """
    Schema-aware query planner that turns a natural-language query into a
    JSON execution plan with the keys agent, operation, dataset and column.

    Before calling the model, the planner reads the dataset's JSON metadata
    (fast — the parquet data is never touched) and injects column names and
    types into the prompt, so the model can resolve natural-language column
    references to their exact names and pick the right operation.

    In this version plans are generated through the Hugging Face router's
    OpenAI-compatible chat-completions endpoint (model ``HF_MODEL``); the
    local Ollama request is disabled. Failures (missing ``HF_TOKEN``,
    network errors, malformed or invalid model output) are logged and
    surface as a ``None`` plan instead of an exception.
    """

    VALID_AGENTS = {
        "metadata_agent",
        "dataframe_agent",
        "visualization_agent",
        "transformer_agent",
        # NOTE(review): analysis_agent (and analyze/analyse below) pass
        # validation but are not described in SYSTEM_PROMPT, so the model is
        # never told to emit them — confirm whether the prompt should list them.
        "analysis_agent",
    }

    VALID_OPERATIONS = {
        # metadata
        "columns", "numeric_columns", "categorical_columns",
        "missing_values", "column_count",
        # dataframe
        "head", "row_count", "mean", "max", "min",
        # visualization
        "histogram", "bar_chart",
        # transformer — cleaning
        "drop_duplicates", "fill_nulls", "drop_column",
        "drop_constant_columns", "strip_whitespace",
        "drop_missing_rows", "drop_missing_cols",
        # transformer — explicit fill strategies
        "fill_mean", "fill_median", "fill_mode", "fill_zero",
        # transformer — transforms
        "normalize", "standardize", "encode", "onehot", "rename",
        # analysis
        "analyze", "analyse",
    }

    # Directory holding one <dataset>.json metadata file per dataset,
    # resolved relative to the current working directory.
    META_DIR = "data/metadata"

    # Hugging Face router (OpenAI-compatible chat completions) settings.
    HF_URL = "https://router.huggingface.co/v1/chat/completions"
    HF_MODEL = "google/gemma-3-12b-it"
    HF_TIMEOUT = 15  # seconds
    HF_MAX_TOKENS = 100

    # String values treated as "no value" for dataset/column. The prompt
    # template shows `"<... or null>"` inside quotes, which invites the model
    # to answer with the string "null" instead of a JSON null.
    _NULL_STRINGS = frozenset({"", "null", "none"})

    SYSTEM_PROMPT = """\
You are a planner for a data analysis CLI system.
Convert the user query into a JSON execution plan.
Return ONLY valid JSON with exactly this structure:
{
"agent": "<agent_name>",
"operation": "<operation>",
"dataset": "<exact dataset name or null>",
"column": "<exact column name from schema or null>"
}
Valid agents:
metadata_agent — schema / structure queries
dataframe_agent — statistics, row previews
visualization_agent — charts
transformer_agent — cleaning, filling, encoding, scaling
metadata_agent operations:
columns, numeric_columns, categorical_columns, missing_values, column_count
dataframe_agent operations:
head, row_count, mean, max, min
visualization_agent operations:
histogram, bar_chart
transformer_agent operations:
Cleaning : drop_duplicates, drop_column, drop_constant_columns,
strip_whitespace, drop_missing_rows, drop_missing_cols
Filling : fill_nulls (smart — auto picks mean/median/mode),
fill_mean, fill_median, fill_mode, fill_zero
Scaling : normalize (min-max [0,1]), standardize (z-score)
Encoding : encode (label), onehot (one-hot / get_dummies)
Other : rename
Rules:
- Output ONLY JSON — no explanation, no markdown, no extra keys.
- Use the EXACT column name from the schema context provided.
- If the query covers all columns (e.g. "fill all nulls"), set column to null.
- For queries about listing/showing structure → metadata_agent.
- For queries about previewing data or computing statistics → dataframe_agent.
- For fill operations: choose fill_mean/fill_median/fill_mode/fill_zero when the
user explicitly names a strategy; use fill_nulls when they don't.
"""

    def __init__(self):
        # Public flag set on construction; not read anywhere in this class.
        self.enabled = True

    # ── schema context ─────────────────────────────────────────────────────

    @staticmethod
    def _format_schema(name, schema):
        """
        Render parsed metadata as the compact text block injected into the prompt.

        Reads the optional keys ``columns``, ``numeric_columns``,
        ``categorical_columns``, ``column_types`` and ``rows`` from ``schema``.
        Column names are passed through ``str`` so non-string labels (e.g.
        integers) do not break the join.

        Raises:
            AttributeError / TypeError: ``schema`` (or one of its values) does
                not have the expected dict/list shape.
        """
        cols = schema.get("columns", [])
        numeric = schema.get("numeric_columns", [])
        cats = schema.get("categorical_columns", [])
        col_types = schema.get("column_types", {})
        rows = schema.get("rows", "?")
        lines = [
            f"Dataset '{name}' ({rows} rows, {len(cols)} columns):",
            f" All columns : {', '.join(map(str, cols))}",
            f" Numeric : {', '.join(map(str, numeric)) if numeric else 'none'}",
            f" Categorical : {', '.join(map(str, cats)) if cats else 'none'}",
            f" Column types : {', '.join(f'{k}:{v}' for k, v in col_types.items())}",
        ]
        return "\n".join(lines)

    def _load_schema_context(self, query, meta_dir=None):
        """
        Find the dataset named in ``query`` and render its metadata for the prompt.

        A dataset matches when its metadata file stem (lower-cased) occurs
        anywhere in the lower-cased query; ``*_clean`` metadata files are
        skipped. When several names match, the longest one wins, so a query
        about "sales_2023" is not shadowed by a dataset called "sales".
        Only the small JSON metadata file is read — parquet is never touched.

        Args:
            query: The user's natural-language query.
            meta_dir: Metadata directory; defaults to ``META_DIR``.

        Returns:
            ``(dataset_name, schema_text)``, or ``(None, "")`` when the
            directory is missing, nothing matches, or every matching file
            fails to load (each failure is logged as a warning).
        """
        directory = Path(self.META_DIR if meta_dir is None else meta_dir)
        if not directory.is_dir():
            return None, ""

        q = (query or "").lower()
        candidates = [
            path for path in directory.glob("*.json")
            if not path.stem.endswith("_clean") and path.stem.lower() in q
        ]
        # Most specific (longest) name first; the name breaks ties deterministically.
        candidates.sort(key=lambda path: (-len(path.stem), path.stem))

        for meta_file in candidates:
            name = meta_file.stem
            try:
                schema = json.loads(meta_file.read_text(encoding="utf-8"))
                return name, self._format_schema(name, schema)
            except (OSError, ValueError, TypeError, AttributeError) as e:
                # Unreadable / malformed metadata: try the next candidate.
                logger.warning(f"Schema load failed for '{name}' | {e}")
        return None, ""

    # ── LLM call ───────────────────────────────────────────────────────────

    def _build_prompt(self, user_query, schema_context=""):
        """Combine the system prompt, optional schema context and the user query."""
        schema_block = (
            f"\n\nSchema context (use exact column names from here):\n{schema_context}"
            if schema_context else ""
        )
        return f"{self.SYSTEM_PROMPT}{schema_block}\n\nUser Query: {user_query}\n\nJSON:"

    def _call_ollama(self, user_query, schema_context=""):
        """
        Send the planning prompt to the LLM backend and return its raw text.

        Despite the name (kept for existing callers), the local Ollama request
        is disabled in this version: the prompt goes straight to the Hugging
        Face router's chat-completions endpoint, authorised with the
        ``HF_TOKEN`` environment variable, with JSON-object output requested.

        Returns:
            The stripped message content, or ``None`` if ``HF_TOKEN`` is unset,
            the request or response handling fails, or the reply is empty
            (the reason is logged).
        """
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            logger.error("Hugging Face planner skipped: HF_TOKEN environment variable not set.")
            return None

        payload = {
            "model": self.HF_MODEL,
            "messages": [
                {"role": "user", "content": self._build_prompt(user_query, schema_context)},
            ],
            "temperature": 0,
            "max_tokens": self.HF_MAX_TOKENS,
            "response_format": {"type": "json_object"},
        }
        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json",
        }

        try:
            response = requests.post(
                self.HF_URL, headers=headers, json=payload, timeout=self.HF_TIMEOUT
            )
            response.raise_for_status()
            # OpenAI-style completion: choices[0].message.content holds the text.
            content = response.json()["choices"][0]["message"]["content"]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as err:
            logger.error(f"Hugging Face request failed | {err}")
            return None

        if not isinstance(content, str) or not content.strip():
            logger.error("Hugging Face returned an empty plan.")
            return None

        logger.info("Successfully fetched plan from Hugging Face.")
        return content.strip()

    # ── response parsing ───────────────────────────────────────────────────

    @staticmethod
    def _strip_code_fence(content):
        """
        Return ``content`` without a surrounding Markdown code fence, if any.

        Handles an optional (case-insensitive) ``json`` language tag; with an
        unterminated fence everything after the opening fence is kept.
        """
        text = content.strip()
        if text.startswith("```"):
            text = text.split("```")[1]
            if text[:4].lower() == "json":
                text = text[4:]
        return text.strip()

    @classmethod
    def _normalize_null(cls, value):
        """Map string placeholders such as "null" / "" to ``None``; keep other values."""
        if isinstance(value, str) and value.strip().lower() in cls._NULL_STRINGS:
            return None
        return value

    def _parse_plan(self, content):
        """
        Parse and validate raw model output into a plan dict.

        The returned dict always contains ``dataset`` and ``column`` (``None``
        when absent or given as a "null"-like string).

        Raises:
            json.JSONDecodeError: the (fence-stripped) text is not valid JSON.
            ValueError: the JSON is not an object, or its agent/operation is
                not a string listed in VALID_AGENTS / VALID_OPERATIONS.
        """
        parsed = json.loads(self._strip_code_fence(content))
        if not isinstance(parsed, dict):
            raise ValueError(f"LLM returned a non-object plan: {parsed!r}")

        agent = parsed.get("agent")
        operation = parsed.get("operation")
        # isinstance guards keep unhashable values (lists, dicts) out of set lookups.
        if not isinstance(agent, str) or agent not in self.VALID_AGENTS:
            raise ValueError(f"LLM returned invalid agent: {agent!r}")
        if not isinstance(operation, str) or operation not in self.VALID_OPERATIONS:
            raise ValueError(f"LLM returned invalid operation: {operation!r}")

        for key in ("dataset", "column"):
            parsed[key] = self._normalize_null(parsed.get(key))
        return parsed

    # ── public API ─────────────────────────────────────────────────────────

    def plan(self, query):
        """
        Return a validated execution plan dict, or ``None`` if unavailable.

        The plan always contains: agent, operation, dataset, column. ``None``
        is returned (with the reason logged) when the LLM call yields nothing
        or its output fails parsing/validation.
        """
        _, schema_ctx = self._load_schema_context(query)
        content = self._call_ollama(query, schema_ctx)
        if not content:
            return None

        try:
            parsed = self._parse_plan(content)
        except json.JSONDecodeError as e:
            logger.error(f"LLM response not valid JSON | {e} | raw: {content!r}")
            return None
        except ValueError as e:
            logger.error(str(e))
            return None

        logger.info(f"LLMPlanner plan → {parsed}")
        return parsed