|
|
import json |
|
|
import time |
|
|
import logging |
|
|
import re |
|
|
from datetime import datetime |
|
|
from typing import Dict, Any, List, Optional |
|
|
|
|
|
|
|
|
try: |
|
|
from app.core.config import get_groq_client, get_config |
|
|
except ImportError: |
|
|
|
|
|
import logging |
|
|
logging.getLogger(__name__).warning("Could not import config. Mocking for syntax check.") |
|
|
get_groq_client = lambda: None |
|
|
get_config = lambda: type('Config', (), {'MODEL_NAME': 'openai/gpt-oss-120b'})() |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
|
|
) |
|
|
logger = logging.getLogger("rivergen_agents") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_and_parse_json(raw_content: str) -> Dict[str, Any]: |
|
|
""" |
|
|
Production-grade JSON parser that handles common LLM formatting issues. |
|
|
""" |
|
|
try: |
|
|
return json.loads(raw_content) |
|
|
except json.JSONDecodeError: |
|
|
clean_text = re.sub(r"```json\s*|\s*```", "", raw_content, flags=re.IGNORECASE).strip() |
|
|
try: |
|
|
return json.loads(clean_text) |
|
|
except json.JSONDecodeError as e: |
|
|
logger.error(f"JSON Parsing Failed. Raw content sample: {raw_content[:200]}...") |
|
|
raise ValueError(f"LLM returned invalid JSON format: {str(e)}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def router_agent(full_payload: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
Analyzes input to route requests. |
|
|
Includes token usage tracking for cost observability. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
request_id = full_payload.get("request_id", "unknown_id") |
|
|
logger.info(f"π§ [Router] Analyzing Request ID: {request_id}") |
|
|
|
|
|
|
|
|
data_sources = full_payload.get('data_sources', []) |
|
|
source_summary = [] |
|
|
for ds in data_sources: |
|
|
source_summary.append({ |
|
|
"name": ds.get("name"), |
|
|
"type": ds.get("type", "unknown").lower() |
|
|
}) |
|
|
|
|
|
routing_context = { |
|
|
"user_prompt": full_payload.get("user_prompt"), |
|
|
"data_source_count": len(data_sources), |
|
|
"data_sources": source_summary, |
|
|
"context_roles": full_payload.get("user_context", {}).get("roles", []) |
|
|
} |
|
|
|
|
|
system_prompt = """ |
|
|
You are the **Master Router** for RiverGen AI. |
|
|
Route the request based on Data Source Counts and Types. |
|
|
|
|
|
**ROUTING RULES (STRICT):** |
|
|
1. **Multi-Source**: If `data_source_count` > 1 -> SELECT `multi_source_agent` (IMMEDIATELY). |
|
|
2. **Streaming**: If prompt mentions 'consume', 'topic', 'kafka', or 'stream' -> SELECT `stream_agent`. |
|
|
3. **Single Source Logic**: |
|
|
- Type 'postgresql', 'oracle', 'mysql', 'sqlserver' -> `sql_agent` |
|
|
- Type 'mongodb', 'dynamodb', 'redis', 'cassandra' -> `nosql_agent` |
|
|
- Type 'snowflake', 'bigquery', 'redshift', 's3' -> `big_data_agent` |
|
|
- Type 'pinecone', 'weaviate', 'vector' -> `vector_store_agent` |
|
|
4. **Machine Learning**: If prompt mentions 'train', 'model', 'predict' -> SELECT `ml_agent`. |
|
|
|
|
|
**OUTPUT FORMAT:** |
|
|
Return ONLY valid JSON: |
|
|
{ |
|
|
"selected_agent": "agent_name", |
|
|
"confidence": 1.0, |
|
|
"reasoning": "Brief explanation" |
|
|
} |
|
|
""" |
|
|
|
|
|
try: |
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": json.dumps(routing_context, indent=2)} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
raw_response = completion.choices[0].message.content |
|
|
result = clean_and_parse_json(raw_response) |
|
|
|
|
|
|
|
|
|
|
|
usage_stats = { |
|
|
"input_tokens": completion.usage.prompt_tokens, |
|
|
"output_tokens": completion.usage.completion_tokens, |
|
|
"total_tokens": completion.usage.total_tokens |
|
|
} |
|
|
|
|
|
|
|
|
result["usage"] = usage_stats |
|
|
|
|
|
duration = (time.time() - start_time) * 1000 |
|
|
logger.info(f"π [Router] Selected: {result.get('selected_agent')} - {duration:.2f}ms") |
|
|
|
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Router Agent Failed: {str(e)}", exc_info=True) |
|
|
|
|
|
|
|
|
empty_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} |
|
|
|
|
|
|
|
|
if len(data_sources) > 1: |
|
|
return { |
|
|
"selected_agent": "multi_source_agent", |
|
|
"confidence": 0.5, |
|
|
"reasoning": "Fallback: Multiple sources.", |
|
|
"usage": empty_usage |
|
|
} |
|
|
|
|
|
return { |
|
|
"error": "Routing Failed", |
|
|
"selected_agent": "error_handler", |
|
|
"usage": empty_usage |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stream_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4 (Branch D): Generates an Analytical Streaming Execution Plan. |
|
|
Hardened for Windowing, Aggregations, and Anomaly Detection. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"π‘ [Stream Agent] Generating plan... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
data_sources = payload.get('data_sources', []) |
|
|
schema_summary = [] |
|
|
known_fields = [] |
|
|
|
|
|
|
|
|
primary_ds_id = data_sources[0].get("data_source_id", 1) if data_sources else 1 |
|
|
|
|
|
for ds in data_sources: |
|
|
ds_name = ds.get('name', 'Unknown Stream') |
|
|
|
|
|
|
|
|
|
|
|
schemas = ds.get('schemas') or [] |
|
|
topics = ds.get('topics') or [] |
|
|
|
|
|
|
|
|
for schema in schemas: |
|
|
for table in schema.get('tables', []): |
|
|
t_name = table.get('table_name') |
|
|
cols = [c['column_name'] for c in table.get('columns', [])] |
|
|
known_fields.extend(cols) |
|
|
schema_summary.append(f"Source: {ds_name} | Topic: {t_name} | Fields: {', '.join(cols)}") |
|
|
|
|
|
|
|
|
for topic in topics: |
|
|
t_name = topic.get('topic_name') |
|
|
cols = [f['field_name'] for f in topic.get('fields', [])] |
|
|
known_fields.extend(cols) |
|
|
schema_summary.append(f"Source: {ds_name} | Topic: {t_name} | Fields: {', '.join(cols)}") |
|
|
|
|
|
|
|
|
response_template = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"status": "success", |
|
|
"intent_type": "stream_analytics", |
|
|
"execution_plan": { |
|
|
"strategy": "stream_processor", |
|
|
"type": "kafka_streams_config", |
|
|
"operations": [ |
|
|
{ |
|
|
"step": 1, |
|
|
"type": "consume_and_process", |
|
|
"operation_type": "read_process", |
|
|
"data_source_id": primary_ds_id, |
|
|
"query_payload": { |
|
|
"topic": "<<TOPIC_NAME>>", |
|
|
"offset_strategy": "latest", |
|
|
"windowing": { |
|
|
"enabled": False, |
|
|
"window_type": "tumbling", |
|
|
"size_seconds": 60, |
|
|
"aggregation_functions": [] |
|
|
}, |
|
|
"analytics": { |
|
|
"calculate_moving_average": False, |
|
|
"anomaly_detection": False, |
|
|
"metrics": [] |
|
|
}, |
|
|
"filter_expression": {}, |
|
|
"limit": 1000 |
|
|
}, |
|
|
"governance_applied": {"note": "Stream encryption and PII masking applied"} |
|
|
} |
|
|
] |
|
|
}, |
|
|
"ai_metadata": { |
|
|
"confidence_score": 0.0, |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **Stream Agent** for RiverGen AI. |
|
|
Generate high-fidelity Kafka Streams or KSQL configurations. |
|
|
|
|
|
**INPUT CONTEXT:** |
|
|
- User Prompt: "{payload.get('user_prompt')}" |
|
|
- Available Streams: {chr(10).join(schema_summary)} |
|
|
- Current Date: {datetime.now().strftime("%Y-%m-%d")} |
|
|
|
|
|
**STRICT EXECUTION RULES:** |
|
|
|
|
|
1. **Temporal Windowing**: |
|
|
- If "windowing", "time windows", or specific durations (e.g., "per minute") are mentioned, set `windowing.enabled: true`. |
|
|
- Default `size_seconds` is 60. |
|
|
|
|
|
2. **Analytical Logic**: |
|
|
- "Moving average" -> `analytics.calculate_moving_average: true`. |
|
|
- "Anomalies" / "Outliers" -> `analytics.anomaly_detection: true`. |
|
|
|
|
|
3. **Payload Filtering**: |
|
|
- Distill filters (e.g., "only event_type login") into `filter_expression`. |
|
|
- **HALLUCINATION CHECK**: ONLY use fields from: {', '.join(known_fields)}. |
|
|
|
|
|
4. **Consumer Mapping**: |
|
|
- Map the schema "Topic" to the `query_payload.topic` field. |
|
|
- If prompt implies historical analysis (e.g., "replay", "from start"), set `offset_strategy` to 'earliest'. |
|
|
|
|
|
**OUTPUT FORMAT:** |
|
|
Return ONLY a valid JSON object matching the template exactly. |
|
|
{json.dumps(response_template, indent=2)} |
|
|
""" |
|
|
|
|
|
if feedback: |
|
|
system_prompt += f"\n\nπ¨ **FIX PREVIOUS ERROR**: {feedback}" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": f"ID: {payload.get('request_id')}"} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
lean_response = clean_and_parse_json(completion.choices[0].message.content) |
|
|
|
|
|
|
|
|
generation_time_ms = int((time.time() - start_time) * 1000) |
|
|
|
|
|
|
|
|
if "ai_metadata" not in lean_response: |
|
|
lean_response["ai_metadata"] = {} |
|
|
|
|
|
lean_response["ai_metadata"]["generation_time_ms"] = generation_time_ms |
|
|
lean_response["ai_metadata"]["model"] = config.MODEL_NAME |
|
|
|
|
|
return lean_response |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Stream Agent Failed: {e}", exc_info=True) |
|
|
return {"error": f"Stream Agent Failed: {str(e)}"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sql_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4: Generates a Dialect-Aware Execution Plan. |
|
|
Enforces Transaction Safety and Literal RLS Injection. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"πΎ [SQL Agent] Generating plan... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
data_sources = payload.get('data_sources', []) |
|
|
|
|
|
primary_ds = data_sources[0] if data_sources else {} |
|
|
db_type = primary_ds.get('type', 'postgresql').lower() |
|
|
ds_id = primary_ds.get('data_source_id', 1) |
|
|
|
|
|
|
|
|
user_context = payload.get('user_context', {}) |
|
|
user_id = user_context.get("user_id", 0) |
|
|
|
|
|
|
|
|
context_vars = { |
|
|
"user_id": user_id, |
|
|
"org_id": user_context.get("organization_id"), |
|
|
"attributes": user_context.get("attributes", {}) |
|
|
} |
|
|
|
|
|
schema_summary = [] |
|
|
governance_instructions = [] |
|
|
|
|
|
for ds in data_sources: |
|
|
ds_name = ds.get('name', 'Unknown Source') |
|
|
|
|
|
schemas = ds.get('schemas') or [] |
|
|
|
|
|
for schema in schemas: |
|
|
|
|
|
tables = schema.get('tables') or [] |
|
|
for table in tables: |
|
|
t_name = table.get('table_name') |
|
|
|
|
|
cols_data = table.get('columns') or [] |
|
|
cols = [c.get('column_name') for c in cols_data if c.get('column_name')] |
|
|
|
|
|
if cols: |
|
|
schema_summary.append(f"Table: {t_name} | Columns: {', '.join(cols)}") |
|
|
|
|
|
|
|
|
policies = ds.get('governance_policies', {}) |
|
|
if policies: |
|
|
rls = policies.get("row_level_security", {}) |
|
|
if rls.get("enabled"): |
|
|
|
|
|
governance_instructions.append( |
|
|
f"β οΈ MANDATORY RLS FOR '{ds_name}': You MUST add the following filter to the 'customers' table: " |
|
|
f"`region IN (SELECT region FROM user_access WHERE user_id = {user_id})`. " |
|
|
f"Inject the literal value {user_id}." |
|
|
) |
|
|
|
|
|
|
|
|
lean_template = { |
|
|
"intent_summary": "<<BRIEF_SUMMARY>>", |
|
|
"sql_statement": f"<<VALID_{db_type.upper()}_SQL>>", |
|
|
"governance_explanation": "<<CONFIRM_RLS>>", |
|
|
"confidence_score": 0.0, |
|
|
"reasoning_steps": ["<<STEP_1>>", "<<STEP_2>>"], |
|
|
"visualization_config": [], |
|
|
"suggestions": [] |
|
|
} |
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **SQL Agent**. |
|
|
|
|
|
Generate a secure JSON plan for **{db_type.upper()}**. |
|
|
|
|
|
**SQL BEST PRACTICES ({db_type.upper()}):** |
|
|
- Use {db_type} specific syntax (e.g., {'SYSDATE' if db_type == 'oracle' else 'CURRENT_DATE'}). |
|
|
- For WRITE/DELETE, wrap in `BEGIN;` and `COMMIT;`. |
|
|
- RLS: {chr(10).join(governance_instructions) if governance_instructions else "None."} |
|
|
|
|
|
**SCHEMA:** |
|
|
{chr(10).join(schema_summary)} |
|
|
|
|
|
**OUTPUT FORMAT:** |
|
|
Return ONLY a valid JSON object matching the template exactly. |
|
|
{json.dumps(lean_template, indent=2)} |
|
|
""" |
|
|
|
|
|
if feedback: |
|
|
system_prompt += f"\nπ¨ **FIX PREVIOUS ERROR**: {feedback}" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": f"Request ID: {payload.get('request_id')}"} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
lean_response = clean_and_parse_json(completion.choices[0].message.content) |
|
|
|
|
|
|
|
|
end_time = time.time() |
|
|
generation_time_ms = int((end_time - start_time) * 1000) |
|
|
|
|
|
|
|
|
sql_stmt = lean_response.get("sql_statement", "") |
|
|
op_type = "read" if "SELECT" in sql_stmt.upper() and "INSERT" not in sql_stmt.upper() else "write" |
|
|
|
|
|
|
|
|
final_plan = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"execution_id": payload.get("execution_id"), |
|
|
"plan_id": f"plan-{payload.get('request_id')}", |
|
|
"status": "success", |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"intent_type": "analytical_query", |
|
|
"intent_summary": lean_response.get("intent_summary", ""), |
|
|
"execution_plan": { |
|
|
"strategy": "pushdown", |
|
|
"type": "sql_query", |
|
|
"operations": [{ |
|
|
"step": 1, |
|
|
"operation_type": op_type, |
|
|
"compute_engine": db_type, |
|
|
"query": sql_stmt, |
|
|
"query_payload": { |
|
|
"language": "sql", |
|
|
"dialect": db_type, |
|
|
"statement": sql_stmt |
|
|
}, |
|
|
"governance_applied": {"rls_rules": governance_instructions} |
|
|
}] |
|
|
}, |
|
|
"visualization": lean_response.get("visualization_config", []), |
|
|
"ai_metadata": { |
|
|
"generation_time_ms": generation_time_ms, |
|
|
"confidence_score": lean_response.get("confidence_score", 0.0), |
|
|
"explanation": lean_response.get("governance_explanation", ""), |
|
|
"reasoning_steps": lean_response.get("reasoning_steps", []) |
|
|
}, |
|
|
"suggestions": lean_response.get("suggestions", []) |
|
|
} |
|
|
return final_plan |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"SQL Agent Failed: {e}", exc_info=True) |
|
|
return {"error": f"SQL Agent Failed: {str(e)}"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vector_store_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4: Generates a RiverGen Execution Plan for Vector Databases. |
|
|
Hardened for strict Judge compliance and correct query payload structure. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"π― [Vector Agent] Generating plan... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
data_sources = payload.get("data_sources", []) |
|
|
primary_ds = data_sources[0] if data_sources else {} |
|
|
ds_id = primary_ds.get("data_source_id") |
|
|
ds_name = primary_ds.get("name") |
|
|
db_type = primary_ds.get("type", "vector") |
|
|
|
|
|
|
|
|
exec_ctx = payload.get("execution_context", {}) |
|
|
default_top_k = exec_ctx.get("max_rows", 10) |
|
|
|
|
|
|
|
|
schema_summary = [] |
|
|
valid_metadata_fields = [] |
|
|
|
|
|
|
|
|
schemas = primary_ds.get("schemas") or [] |
|
|
|
|
|
for schema in schemas: |
|
|
for table in schema.get("tables", []) or []: |
|
|
t_name = table.get('table_name') |
|
|
cols_data = table.get('columns') or [] |
|
|
cols = [] |
|
|
|
|
|
for c in cols_data: |
|
|
col_name = c.get('column_name') |
|
|
col_type = c.get('column_type', 'unknown') |
|
|
cols.append(f"{col_name} ({col_type})") |
|
|
|
|
|
|
|
|
|
|
|
if "vector" not in col_type.lower() and col_name != "id": |
|
|
valid_metadata_fields.append(col_name) |
|
|
|
|
|
schema_summary.append(f"Index: {t_name} | Fields: {', '.join(cols)}") |
|
|
|
|
|
|
|
|
lean_template = { |
|
|
"intent_summary": "<<BRIEF_SUMMARY>>", |
|
|
"vector_search_config": { |
|
|
"index_name": "<<INDEX_NAME_FROM_SCHEMA>>", |
|
|
"vector_column": "<<VECTOR_COLUMN_FROM_SCHEMA>>", |
|
|
"query_text": "<<SEMANTIC_SEARCH_TEXT>>", |
|
|
"top_k": 10, |
|
|
"filters": {} |
|
|
}, |
|
|
"reasoning_steps": ["<<STEP_1>>", "<<STEP_2>>"], |
|
|
"suggestions": ["<<SUGGESTION>>"] |
|
|
} |
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **Vector Store Agent**. |
|
|
|
|
|
**OBJECTIVE:** |
|
|
Generate a valid vector search configuration for {db_type.upper()}. |
|
|
|
|
|
**INPUT CONTEXT:** |
|
|
- User Prompt: "{payload.get('user_prompt')}" |
|
|
- Default Top-K: {default_top_k} |
|
|
|
|
|
**AVAILABLE SCHEMA:** |
|
|
{chr(10).join(schema_summary)} |
|
|
|
|
|
**VALID FILTERS:** |
|
|
{json.dumps(valid_metadata_fields)} |
|
|
|
|
|
**STRICT RULES:** |
|
|
1. **Target Index**: You MUST use the exact 'Index' name from the Available Schema. |
|
|
2. **Vector Column**: You MUST identify the column with type 'vector(...)'. |
|
|
3. **Query Text**: |
|
|
- If the user provides a search query (e.g., "find shoes"), use it. |
|
|
- If the prompt is generic (e.g., "query vector"), use the **entire user prompt** as the query text. |
|
|
- NEVER leave this empty. |
|
|
4. **Filtering**: Only filter on 'Valid Filters'. If a requested filter is missing, ignore it and note in reasoning. |
|
|
|
|
|
**OUTPUT FORMAT:** |
|
|
Return ONLY a valid JSON object matching this structure: |
|
|
{json.dumps(lean_template, indent=2)} |
|
|
""" |
|
|
|
|
|
if feedback: |
|
|
system_prompt += f"\nπ¨ FIX PREVIOUS ERROR: {feedback}" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": f"Request ID: {payload.get('request_id')}"} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
end_time = time.time() |
|
|
generation_time_ms = int((end_time - start_time) * 1000) |
|
|
input_tokens = completion.usage.prompt_tokens |
|
|
output_tokens = completion.usage.completion_tokens |
|
|
|
|
|
|
|
|
lean_response = clean_and_parse_json(completion.choices[0].message.content) |
|
|
vs_config = lean_response.get("vector_search_config", {}) |
|
|
|
|
|
|
|
|
query_text = vs_config.get("query_text", payload.get('user_prompt')) |
|
|
|
|
|
final_plan = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"execution_id": payload.get("execution_id", f"exec-{payload.get('request_id')}"), |
|
|
"plan_id": f"plan-{int(time.time())}", |
|
|
"status": "success", |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"intent_type": "query", |
|
|
"intent_summary": lean_response.get("intent_summary", "Vector Search"), |
|
|
"execution_plan": { |
|
|
"strategy": "pushdown", |
|
|
"type": "vector_search", |
|
|
"operations": [ |
|
|
{ |
|
|
"step": 1, |
|
|
"step_id": "op-1", |
|
|
"operation_type": "read", |
|
|
"type": "source_query", |
|
|
"description": lean_response.get("intent_summary"), |
|
|
"data_source_id": ds_id, |
|
|
"compute_type": "source_native", |
|
|
"compute_engine": db_type, |
|
|
"dependencies": [], |
|
|
"query": f"search('{query_text}', k={vs_config.get('top_k', 10)})", |
|
|
"query_payload": { |
|
|
"language": "vector", |
|
|
"dialect": None, |
|
|
"statement": f"search('{query_text}')", |
|
|
|
|
|
"parameters": { |
|
|
"index_name": vs_config.get("index_name"), |
|
|
"vector_column": vs_config.get("vector_column"), |
|
|
"query_vector_text": query_text, |
|
|
"top_k": vs_config.get("top_k", 10), |
|
|
"filters": vs_config.get("filters", {}), |
|
|
"search_params": { |
|
|
"metric": "cosine", |
|
|
"queries": [query_text] |
|
|
} |
|
|
} |
|
|
}, |
|
|
"governance_applied": { |
|
|
"rls_rules": [], |
|
|
"masking_rules": [] |
|
|
}, |
|
|
"output_artifact": "similar_vectors" |
|
|
} |
|
|
] |
|
|
}, |
|
|
"visualization": None, |
|
|
"ai_metadata": { |
|
|
"model": config.MODEL_NAME, |
|
|
"input_tokens": input_tokens, |
|
|
"output_tokens": output_tokens, |
|
|
"generation_time_ms": generation_time_ms, |
|
|
"confidence": 0.95, |
|
|
"confidence_score": 0.95, |
|
|
"explanation": "Performed vector similarity search using the provided schema.", |
|
|
"reasoning_steps": lean_response.get("reasoning_steps", []) |
|
|
}, |
|
|
"suggestions": lean_response.get("suggestions", []) |
|
|
} |
|
|
|
|
|
return final_plan |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Vector Agent Failed: {e}", exc_info=True) |
|
|
return {"error": f"Vector Agent Failed: {str(e)}"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def multi_source_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4 (Branch B): Generates a Hybrid/Federated Execution Plan. |
|
|
Hardened for System Table Injection and Multi-Hop Joins. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"π [Multi-Source Agent] Generating hybrid plan... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
data_sources = payload.get('data_sources', []) |
|
|
user_context = payload.get('user_context', {}) |
|
|
user_id = user_context.get("user_id", 0) |
|
|
|
|
|
context_vars = { |
|
|
"user_id": user_id, |
|
|
"org_id": user_context.get("organization_id"), |
|
|
"attributes": user_context.get("attributes", {}) |
|
|
} |
|
|
|
|
|
schema_summary = [] |
|
|
governance_instructions = [] |
|
|
source_map = {} |
|
|
|
|
|
for ds in data_sources: |
|
|
ds_id = ds.get('data_source_id') |
|
|
ds_name = ds.get('name') |
|
|
ds_type = ds.get('type') |
|
|
source_map[ds_name] = ds_id |
|
|
|
|
|
|
|
|
schemas_list = ds.get('schemas') or [] |
|
|
file_meta = ds.get('file_metadata') or {} |
|
|
file_schemas = file_meta.get('schemas') or [] |
|
|
curr_schemas = schemas_list + file_schemas |
|
|
|
|
|
for schema in curr_schemas: |
|
|
for table in schema.get('tables', []) or []: |
|
|
t_name = table.get('table_name') |
|
|
cols_data = table.get('columns') or [] |
|
|
cols = [c.get('column_name') for c in cols_data if c.get('column_name')] |
|
|
|
|
|
if cols: |
|
|
schema_summary.append(f"SOURCE (ID {ds_id}) [{ds_type}] '{ds_name}' -> Table: {t_name} | Columns: {', '.join(cols)}") |
|
|
|
|
|
|
|
|
policies = ds.get('governance_policies', {}) |
|
|
if policies: |
|
|
rls = policies.get("row_level_security", {}) |
|
|
if rls.get("enabled"): |
|
|
|
|
|
governance_instructions.append( |
|
|
f"β οΈ RLS FOR '{ds_name}': You must filter by region. " |
|
|
f"DO NOT query 'user_access' table directly. " |
|
|
f"Instead, INJECT the literal value: `region IN (SELECT region FROM (VALUES ('US-East'), ('EU-West')) AS user_access(region))` " |
|
|
f"OR simply `region = 'US-East'` based on context." |
|
|
) |
|
|
|
|
|
|
|
|
lean_template = { |
|
|
"intent_summary": "<<BRIEF_SUMMARY>>", |
|
|
"intent_type": "join", |
|
|
"confidence_score": 0.0, |
|
|
"execution_plan": { |
|
|
"strategy": "hybrid", |
|
|
"type": "trino_sql", |
|
|
"operations": [ |
|
|
{ |
|
|
"step": 1, |
|
|
"step_id": "<<UNIQUE_ID>>", |
|
|
"operation_type": "read", |
|
|
"type": "source_query", |
|
|
"description": "<<DESC>>", |
|
|
"data_source_id": 1, |
|
|
"compute_type": "source_native", |
|
|
"compute_engine": "<<ENGINE>>", |
|
|
"dependencies": [], |
|
|
"query": "<<SQL_QUERY>>", |
|
|
"query_payload": { |
|
|
"language": "sql", |
|
|
"dialect": "<<DIALECT>>", |
|
|
"statement": "<<SQL_QUERY>>", |
|
|
"parameters": [] |
|
|
}, |
|
|
"governance_applied": { |
|
|
"rls_rules": ["<<RULE_NAME>>"], |
|
|
"masking_rules": [] |
|
|
}, |
|
|
"output_artifact": "<<ARTIFACT_NAME>>" |
|
|
} |
|
|
] |
|
|
}, |
|
|
"reasoning_steps": ["<<STEP_1>>", "<<STEP_2>>"], |
|
|
"suggestions": ["<<SUGGESTION>>"], |
|
|
"visualization_config": [] |
|
|
} |
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **Multi-Source Agent** for RiverGen AI. |
|
|
|
|
|
**OBJECTIVE:** |
|
|
Generate a **Hybrid Execution Plan** to federate data. |
|
|
|
|
|
**INPUT CONTEXT:** |
|
|
- Schema: {chr(10).join(schema_summary)} |
|
|
- Governance: {chr(10).join(governance_instructions) if governance_instructions else "None."} |
|
|
- Literals: {json.dumps(context_vars)} |
|
|
|
|
|
**CRITICAL RULES:** |
|
|
1. **Topology Check**: |
|
|
- If `Orders` table lacks `product_id`, DO NOT join it to `Products`. |
|
|
- Instead, calculate "Customer Metrics" (Orders+Customers) and "Product Metrics" (Sales+Products) as **separate operations**. |
|
|
|
|
|
2. **System Tables**: |
|
|
- Replace `user_access` with the literal values provided in context (e.g., `WHERE region = '...'`). |
|
|
|
|
|
3. **Addressing**: |
|
|
- Use Fully Qualified Names: `datasource_name.schema_name.table_name` (e.g. `postgresql_production.public.customers`). |
|
|
|
|
|
**OUTPUT FORMAT:** |
|
|
Return ONLY a valid JSON object matching the Lean Template exactly. |
|
|
{json.dumps(lean_template, indent=2)} |
|
|
""" |
|
|
|
|
|
if feedback: |
|
|
system_prompt += f"\nπ¨ FIX PREVIOUS ERROR: {feedback}" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": f"Request ID: {payload.get('request_id')}"} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
end_time = time.time() |
|
|
generation_time_ms = int((end_time - start_time) * 1000) |
|
|
input_tokens = completion.usage.prompt_tokens |
|
|
output_tokens = completion.usage.completion_tokens |
|
|
|
|
|
|
|
|
lean_response = clean_and_parse_json(completion.choices[0].message.content) |
|
|
|
|
|
|
|
|
ai_confidence = lean_response.get("confidence_score", 0.0) |
|
|
viz_config = lean_response.get("visualization_config") |
|
|
if not isinstance(viz_config, list): |
|
|
viz_config = [] |
|
|
|
|
|
final_plan = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"execution_id": payload.get("execution_id", f"exec-{payload.get('request_id')}"), |
|
|
"plan_id": f"plan-{int(time.time())}", |
|
|
"status": "success", |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"intent_type": lean_response.get("intent_type", "join"), |
|
|
"intent_summary": lean_response.get("intent_summary", ""), |
|
|
"execution_plan": lean_response.get("execution_plan", {}), |
|
|
"visualization": viz_config, |
|
|
"ai_metadata": { |
|
|
"model": config.MODEL_NAME, |
|
|
"input_tokens": input_tokens, |
|
|
"output_tokens": output_tokens, |
|
|
"generation_time_ms": generation_time_ms, |
|
|
"confidence": ai_confidence, |
|
|
"confidence_score": ai_confidence, |
|
|
"explanation": lean_response.get("intent_summary"), |
|
|
"reasoning_steps": lean_response.get("reasoning_steps", []) |
|
|
}, |
|
|
"suggestions": lean_response.get("suggestions", []) |
|
|
} |
|
|
|
|
|
return final_plan |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Multi-Source Agent Failed: {e}", exc_info=True) |
|
|
return {"error": f"Multi-Source Agent Failed: {str(e)}"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def llm_judge(original_payload: Dict[str, Any], generated_plan: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 5: Universal Quality Gate. |
|
|
Dynamically applies specialized validation rules for SQL, NoSQL, Vector, Stream, ML, or Generic plans. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
try: |
|
|
|
|
|
execution_plan = generated_plan.get("execution_plan", {}) |
|
|
plan_type = execution_plan.get("type", "unknown").lower() |
|
|
|
|
|
|
|
|
data_sources = original_payload.get("data_sources", []) |
|
|
valid_schema_context = [] |
|
|
|
|
|
for ds in data_sources: |
|
|
ds_name = ds.get("name") |
|
|
ds_id = ds.get("data_source_id") |
|
|
|
|
|
|
|
|
|
|
|
schemas = ds.get("schemas") or [] |
|
|
|
|
|
|
|
|
if not schemas: |
|
|
file_meta = ds.get("file_metadata") or {} |
|
|
schemas = file_meta.get("schemas") or [] |
|
|
|
|
|
for schema in schemas: |
|
|
tables = schema.get("tables") or [] |
|
|
for table in tables: |
|
|
valid_schema_context.append({ |
|
|
"data_source_id": ds_id, |
|
|
"source": ds_name, |
|
|
"object": table.get("table_name"), |
|
|
"columns": [c['column_name'].lower() for c in (table.get('columns') or [])] |
|
|
}) |
|
|
|
|
|
|
|
|
topics = ds.get("topics") or [] |
|
|
for topic in topics: |
|
|
valid_schema_context.append({ |
|
|
"data_source_id": ds_id, |
|
|
"source": ds_name, |
|
|
"object": topic.get("topic_name"), |
|
|
"columns": [f['field_name'].lower() for f in (topic.get('fields') or [])] |
|
|
}) |
|
|
|
|
|
|
|
|
valid_schema_context.append({ |
|
|
"source": "SYSTEM_SECURITY", |
|
|
"object": "user_access", |
|
|
"columns": ["user_id", "region", "role", "permissions", "organization_id"] |
|
|
}) |
|
|
|
|
|
|
|
|
multi_source_judge_prompt = f""" |
|
|
You are the **Multi-Source Federation Judge** for RiverGen AI. |
|
|
|
|
|
|
|
|
You validate federated execution plans that combine data across SQL databases, NoSQL databases, and cloud storage (S3, Parquet, Snowflake, etc.). |
|
|
|
|
|
INPUT: |
|
|
1. User Prompt: |
|
|
"{original_payload.get("user_prompt")}" |
|
|
2. Valid Schema (Queryable Sources): |
|
|
{json.dumps(valid_schema_context)} |
|
|
3. Proposed Execution Plan: |
|
|
{json.dumps(generated_plan, indent=2)} |
|
|
|
|
|
RULES: |
|
|
|
|
|
βββββββββββββββββββββββββββββ |
|
|
1) SCHEMA AUTHORITY & HALLUCINATION |
|
|
βββββββββββββββββββββββββββββ |
|
|
- All table references MUST exist in Valid Schema. |
|
|
- SQL or query references to unknown tables/columns β REJECT. |
|
|
- Fully Qualified Names (FQN) required for SQL: `source.schema.table` or aliased equivalent. |
|
|
- S3/NoSQL object references must match provided schema/path exactly. |
|
|
- If a source is claimed as dropped, it MUST NOT appear in any query. |
|
|
|
|
|
βββββββββββββββββββββββββββββ |
|
|
2) DIALECT & SYNTAX COMPLIANCE |
|
|
βββββββββββββββββββββββββββββ |
|
|
- SQL queries must be valid for their declared dialect (PostgreSQL, MySQL, Trino, etc.). |
|
|
- No database-specific proprietary constructs (PL/SQL, T-SQL) unless wrapped in pass-through. |
|
|
- No unsafe operations (e.g., unqualified cross joins, unsupported NoSQL filters). |
|
|
|
|
|
βββββββββββββββββββββββββββββ |
|
|
3) GOVERNANCE & RLS (CRITICAL UPDATE) |
|
|
βββββββββββββββββββββββββββββ |
|
|
- RLS, masking, or row-level filters must be applied where required. |
|
|
- **VALIDATION EXCEPTION**: If the plan replaces a system table reference (e.g., `user_access`) with a **Literal Filter** (e.g., `WHERE region = 'US-East'`) or a **CTE/VALUES clause**, this IS VALID. Do NOT reject it for missing the system table. |
|
|
- Enforcement should be pushed down into the query if supported. |
|
|
- If RLS is missing for a source that requires it β REJECT. |
|
|
|
|
|
βββββββββββββββββββββββββββββ |
|
|
4) FEDERATION & JOIN LOGIC |
|
|
βββββββββββββββββββββββββββββ |
|
|
- **Topology Check**: Do NOT allow joins if the schema does not support them (e.g., joining `Orders` to `Products` without a `product_id` key). |
|
|
- **No Cross Joins**: Unqualified joins (Cartesian products) are strictly FORBIDDEN. |
|
|
- If no join key exists, the plan MUST generate separate operations or use `"SAFE_PARTIAL": true` and document in `limitations`. |
|
|
- Metrics requested by the user must be computed when possible; otherwise, explain in `limitations`. |
|
|
|
|
|
βββββββββββββββββββββββββββββ |
|
|
5) DROPPED & PARTIAL SOURCES |
|
|
βββββββββββββββββββββββββββββ |
|
|
- If a source cannot be queried (schema missing, unsupported type), it must be listed in `dropped_sources`. |
|
|
- Limitations or partial results must be documented in `validation.notes` or `limitations`. |
|
|
|
|
|
βββββββββββββββββββββββββββββ |
|
|
REQUIRED OUTPUT |
|
|
βββββββββββββββββββββββββββββ |
|
|
Return ONLY JSON matching this structure exactly: |
|
|
{{ |
|
|
"approved": boolean, |
|
|
"feedback": "string", |
|
|
"score": float, |
|
|
"governance_enforcement": {{ }}, |
|
|
"validation": {{ |
|
|
"missing_fields": [], |
|
|
"dropped_sources": [], |
|
|
"notes": [], |
|
|
"performance_warnings": [] |
|
|
}} |
|
|
}} |
|
|
Do NOT include any extra text. |
|
|
""" |
|
|
|
|
|
vector_judge_prompt = f""" |
|
|
You are the **Vector Store Judge** for RiverGen AI. You validate vector similarity search plans (Pinecone, Weaviate, etc.). |
|
|
|
|
|
INPUT: |
|
|
1. User Prompt: |
|
|
"{original_payload.get("user_prompt")}" |
|
|
2. Valid Schema (indexes and vector columns): |
|
|
{json.dumps(valid_schema_context)} |
|
|
3. Proposed Execution Plan: |
|
|
{json.dumps(generated_plan, indent=2)} |
|
|
|
|
|
RULES: |
|
|
1) REQUIRED VECTOR PARAMETERS: |
|
|
- `index_name` and `vector_column` must exist in Valid Schema. |
|
|
- `search_params` must include: |
|
|
* `metric` (cosine, euclidean, etc.) |
|
|
* `queries` (non-empty array) OR `embedding_required = true` |
|
|
* `top_k` (positive integer) |
|
|
- `query_vector` may be empty only if `embedding_required = true`. |
|
|
|
|
|
2) METADATA FILTERS: |
|
|
- Only allowed fields from Valid Schema. |
|
|
- Document any omitted filters in `validation.notes`. |
|
|
|
|
|
3) GOVERNANCE: |
|
|
- RLS/masking must be applied if defined in schema. |
|
|
|
|
|
4) SAFE_PARTIAL: |
|
|
- Approve if query returns safe fields and missing fields are documented. |
|
|
|
|
|
OUTPUT: |
|
|
Return ONLY JSON: |
|
|
{{ |
|
|
"approved": boolean, |
|
|
"feedback": "string", |
|
|
"score": float, |
|
|
"governance_enforcement": {{ }}, |
|
|
"validation": {{ |
|
|
"missing_fields": [], |
|
|
"dropped_sources": [], |
|
|
"notes": [], |
|
|
"performance_warnings": [] |
|
|
}} |
|
|
}} |
|
|
No extra text. |
|
|
""" |
|
|
|
|
|
nosql_judge_prompt = f""" |
|
|
You are the **NoSQL Quality Assurance Judge** for RiverGen AI. You validate NoSQL execution plans (MongoDB, DynamoDB, Redis, Elasticsearch). |
|
|
|
|
|
INPUT: |
|
|
1. User Prompt: |
|
|
"{original_payload.get("user_prompt")}" |
|
|
2. Valid Schema (collections/tables & fields): |
|
|
{json.dumps(valid_schema_context)} |
|
|
3. Proposed Execution Plan: |
|
|
{json.dumps(generated_plan, indent=2)} |
|
|
|
|
|
RULES: |
|
|
1) HALLUCINATION CHECK: |
|
|
- Any collection/table/field not in Valid Schema β REJECT. |
|
|
- Include step_id in feedback. |
|
|
|
|
|
2) DIALECT-SPECIFIC VALIDATION: |
|
|
- MongoDB: `find`/`aggregate` must be valid JSON-like docs. |
|
|
- DynamoDB: Check KeyConditionExpression, FilterExpression. |
|
|
- Redis/FT.SEARCH: Index names and field filters must exist. |
|
|
- Elasticsearch: JSON DSL must be valid. |
|
|
|
|
|
3) GOVERNANCE: |
|
|
- RLS/masking enforcement must be documented if applicable. |
|
|
|
|
|
4) SAFE_PARTIAL: |
|
|
- Approve if only safe fields are returned and missing fields documented. |
|
|
|
|
|
OUTPUT: |
|
|
Return ONLY JSON: |
|
|
{{ |
|
|
"approved": boolean, |
|
|
"feedback": "string", |
|
|
"score": float, |
|
|
"governance_enforcement": {{ }}, |
|
|
"validation": {{ |
|
|
"missing_fields": [], |
|
|
"dropped_sources": [], |
|
|
"notes": [], |
|
|
"performance_warnings": [] |
|
|
}} |
|
|
}} |
|
|
No extra text. |
|
|
""" |
|
|
|
|
|
sql_judge_prompt = f""" |
|
|
You are the **SQL Quality Assurance Judge** for RiverGen AI. You validate SQL execution plans for correctness, safety, and schema alignment. |
|
|
|
|
|
INPUT: |
|
|
1. User Prompt: |
|
|
"{original_payload.get("user_prompt")}" |
|
|
2. Valid Schema (tables & columns): |
|
|
{json.dumps(valid_schema_context)} |
|
|
3. Proposed Execution Plan: |
|
|
{json.dumps(generated_plan, indent=2)} |
|
|
4. Target Data Source Engine: |
|
|
"{generated_plan.get('compute_engine')}" # e.g., postgres, mysql, oracle, sqlserver, cassandra |
|
|
|
|
|
RULES: |
|
|
1) HALLUCINATION CHECK: |
|
|
- Any table/column not in Valid Schema β REJECT. |
|
|
- Include step_id in feedback. |
|
|
|
|
|
2) SYNTAX & DIALECT CHECK: |
|
|
- SQL must be valid for the declared engine/dialect. |
|
|
- PostgreSQL: standard SQL, interval/date syntax. |
|
|
- MySQL: use `LIMIT`, backticks if needed. |
|
|
- Oracle: use `SYSDATE`, `INTERVAL`, JSON_ARRAYAGG/JSON_OBJECT for nested data. |
|
|
- SQL Server: use `GETDATE()`, `DATEADD`, JSON functions for nesting. |
|
|
- Cassandra CQL: `ALLOW FILTERING` flagged as performance risk. |
|
|
|
|
|
- If the SQL uses syntax from a different engine than the data source β REJECT. |
|
|
- Provide specific feedback on syntax errors or dialect mismatches. |
|
|
|
|
|
3) GOVERNANCE: |
|
|
- Confirm RLS or masking is applied if defined. |
|
|
- If policy references missing objects, accept only if documented. |
|
|
|
|
|
4) PARTIAL DATA: |
|
|
- Approve if safe and explain missing fields in `validation.missing_fields`. |
|
|
- Include notes for performance issues or risky operations. |
|
|
|
|
|
OUTPUT: |
|
|
Return ONLY a JSON object: |
|
|
{{ |
|
|
"approved": boolean, |
|
|
"feedback": "string", |
|
|
"score": float, |
|
|
"governance_enforcement": {{ }}, |
|
|
"validation": {{ |
|
|
"missing_fields": [], |
|
|
"dropped_sources": [], |
|
|
"notes": [], |
|
|
"performance_warnings": [] |
|
|
}} |
|
|
}} |
|
|
Do NOT include any extra text. |
|
|
""" |
|
|
|
|
|
|
|
|
judge_output_schema = { |
|
|
"approved": "boolean", |
|
|
"score": "float", |
|
|
"feedback": "string", |
|
|
"validation": { |
|
|
"feature_issues": [], |
|
|
"execution_issues": [], |
|
|
"ml_best_practice_violations": [], |
|
|
"notes": [] |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
ml_judge_prompt = f""" |
|
|
You are the **RiverGen ML Quality Assurance Judge**. |
|
|
|
|
|
You validate ML execution plans for: |
|
|
- correctness |
|
|
- ML best practices |
|
|
- execution safety |
|
|
- schema alignment |
|
|
|
|
|
Your decision is FINAL. |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
INPUTS |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
1. User Prompt: |
|
|
"{original_payload.get("user_prompt")}" |
|
|
|
|
|
2. Valid Data Schema: |
|
|
{json.dumps(valid_schema_context)} |
|
|
|
|
|
3. Proposed ML Execution Plan: |
|
|
{json.dumps(generated_plan, indent=2)} |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
VALIDATION RULES (HARD FAILS) |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
|
|
|
### 1οΈβ£ Feature / Label Validation |
|
|
REJECT if: |
|
|
- Target column appears in features |
|
|
- ID / primary key is used as a feature without justification |
|
|
- Features or labels do not exist in schema |
|
|
|
|
|
### 2οΈβ£ Strategy Validation |
|
|
REJECT if: |
|
|
- CSV/file-based workflows use anything other than `sequential_dag` |
|
|
- Distributed strategy used without dataset size justification |
|
|
|
|
|
### 3οΈβ£ Execution Correctness |
|
|
REJECT if: |
|
|
- DuckDB queries reference CSVs as tables |
|
|
- `read_csv_auto()` (or equivalent) is NOT used for CSV ingestion |
|
|
- SQL syntax is invalid for the declared engine |
|
|
|
|
|
### 4οΈβ£ Compute Engine Validation |
|
|
REJECT if: |
|
|
- Pandas is used as a model training engine |
|
|
- ML training lacks a defined ML framework (e.g., sklearn) |
|
|
|
|
|
### 5οΈβ£ Preprocessing Completeness |
|
|
REJECT if: |
|
|
- Missing value handling is absent |
|
|
- Scaling/normalization is missing for numeric features |
|
|
- Train/test split is missing or ambiguous |
|
|
|
|
|
### 6οΈβ£ Metrics Enforcement |
|
|
REJECT if: |
|
|
- Regression tasks do not include BOTH RMSE and RΒ² |
|
|
- Classification tasks do not include Precision, Recall, F1, AUC-ROC |
|
|
|
|
|
### 7οΈβ£ Artifact & Reproducibility |
|
|
REJECT if: |
|
|
- Model output path is missing |
|
|
- Evaluation report path is missing |
|
|
- random_state is missing for splits |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
SCORING GUIDELINES |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
- 1.0 β Production-ready, fully correct |
|
|
- 0.8β0.9 β Minor issues, safe to auto-fix |
|
|
- <0.8 β Must be regenerated |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
OUTPUT FORMAT (JSON ONLY) |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
Return ONLY: |
|
|
{json.dumps(judge_output_schema, indent=2)} |
|
|
|
|
|
NO extra text. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
general_qa_judge_prompt = f""" |
|
|
You are the **Quality Assurance Judge** for RiverGen AI. Evaluate any execution plan (SQL, NoSQL, vector) for: |
|
|
- Schema compliance |
|
|
- Hallucinations |
|
|
- Governance & RLS enforcement |
|
|
- Dialect-specific syntax |
|
|
- Performance & safety |
|
|
- Partial safe fulfillment |
|
|
|
|
|
INPUT: |
|
|
1. User Prompt: |
|
|
"{original_payload.get("user_prompt")}" |
|
|
2. Valid Schema: |
|
|
{json.dumps(valid_schema_context)} |
|
|
3. Proposed Execution Plan: |
|
|
{json.dumps(generated_plan, indent=2)} |
|
|
|
|
|
RULES: |
|
|
1) Any reference to non-existent table/collection/column β reject. |
|
|
2) Vector operations must include index_name, vector_column, top_k, and queries or embedding_required. |
|
|
3) SQL/NoSQL syntax must match the target engine. |
|
|
4) Governance policies must be enforced or documented if omitted. |
|
|
5) Safe partial plans are approvable with missing fields documented. |
|
|
6) Risky operations (full scans, ALLOW FILTERING, large top_k) must include performance warnings. |
|
|
|
|
|
OUTPUT (STRICT JSON): |
|
|
{{ |
|
|
"approved": boolean, |
|
|
"feedback": "string", |
|
|
"score": float, |
|
|
"governance_enforcement": {{ }}, |
|
|
"validation": {{ |
|
|
"missing_fields": [], |
|
|
"dropped_sources": [], |
|
|
"notes": [], |
|
|
"performance_warnings": [] |
|
|
}} |
|
|
}} |
|
|
Do NOT include any text outside the JSON. |
|
|
""" |
|
|
|
|
|
|
|
|
if plan_type == "vector_search": |
|
|
logger.info("π§ Using Vector Store Judge Prompt") |
|
|
system_prompt = vector_judge_prompt |
|
|
elif plan_type == "nosql_query": |
|
|
logger.info("π§ Using NoSQL Judge Prompt") |
|
|
system_prompt = nosql_judge_prompt |
|
|
elif plan_type == "trino_sql": |
|
|
logger.info("π§ Using Multi-Source Judge Prompt") |
|
|
system_prompt = multi_source_judge_prompt |
|
|
elif plan_type == "sql_query": |
|
|
logger.info("π§ Using SQL Judge Prompt") |
|
|
system_prompt = sql_judge_prompt |
|
|
elif plan_type == "ml_workflow": |
|
|
logger.info("π§ Using ML Judge Prompt") |
|
|
system_prompt = ml_judge_prompt |
|
|
else: |
|
|
logger.info("π§ Using General QA Judge Prompt") |
|
|
system_prompt = general_qa_judge_prompt |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[{"role": "system", "content": system_prompt}], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
result = clean_and_parse_json(completion.choices[0].message.content) |
|
|
|
|
|
|
|
|
result["usage"] = { |
|
|
"input_tokens": completion.usage.prompt_tokens, |
|
|
"output_tokens": completion.usage.completion_tokens, |
|
|
"total_tokens": completion.usage.total_tokens |
|
|
} |
|
|
|
|
|
|
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Judge Logic Error: {e}", exc_info=True) |
|
|
|
|
|
return { |
|
|
"approved": False, |
|
|
"feedback": f"Judge Logic Error: {str(e)}", |
|
|
"usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def nosql_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4: Generates a RiverGen Execution Plan for NoSQL Databases. |
|
|
Supported: MongoDB, Redis, Cassandra, DynamoDB. |
|
|
Hardened for Strict Schema Enforcement and Token Optimization. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"π¦ [NoSQL Agent] Generating optimized plan... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
data_sources = payload.get("data_sources", []) |
|
|
primary_ds = data_sources[0] if data_sources else {} |
|
|
ds_id = primary_ds.get("data_source_id") |
|
|
ds_name = primary_ds.get("name") |
|
|
db_type = primary_ds.get("type", "generic_nosql").lower() |
|
|
|
|
|
|
|
|
exec_ctx = payload.get("execution_context", {}) |
|
|
max_rows = exec_ctx.get("max_rows", 1000) |
|
|
|
|
|
|
|
|
schema_summary = [] |
|
|
known_fields = set() |
|
|
|
|
|
|
|
|
schemas = primary_ds.get("schemas") or [] |
|
|
|
|
|
for schema in schemas: |
|
|
for table in schema.get("tables", []) or []: |
|
|
fields = [] |
|
|
cols_data = table.get("columns") or [] |
|
|
|
|
|
for col in cols_data: |
|
|
c_name = col.get('column_name') |
|
|
c_type = col.get('column_type', 'unknown') |
|
|
if c_name: |
|
|
fields.append(f"{c_name} ({c_type})") |
|
|
known_fields.add(c_name.lower()) |
|
|
|
|
|
schema_summary.append( |
|
|
f"Collection/Key: {table.get('table_name')} | Fields: {', '.join(fields)}" |
|
|
) |
|
|
|
|
|
|
|
|
governance_instructions = [] |
|
|
policies = primary_ds.get("governance_policies", {}) |
|
|
if policies: |
|
|
|
|
|
masking = policies.get("column_masking", {}) |
|
|
if masking.get("enabled"): |
|
|
governance_instructions.append( |
|
|
f"β οΈ MASKING REQUIRED: You must exclude or mask these fields if present: {masking.get('rules', 'See Schema')}" |
|
|
) |
|
|
|
|
|
|
|
|
lean_template = { |
|
|
"intent_summary": "<<BRIEF_SUMMARY>>", |
|
|
"nosql_statement": "<<VALID_QUERY_STRING>>", |
|
|
"validation": { |
|
|
"schema_matches": True, |
|
|
"missing_fields": ["<<FIELD_NOT_IN_SCHEMA>>"], |
|
|
"notes": ["<<EXPLAIN_OMISSIONS>>"] |
|
|
}, |
|
|
"governance_applied": { |
|
|
"rls_rules": [], |
|
|
"masking_rules": ["<<APPLIED_MASKING>>"] |
|
|
}, |
|
|
"confidence_score": 0.0, |
|
|
"reasoning_steps": ["<<STEP_1>>", "<<STEP_2>>"], |
|
|
"suggestions": ["<<Q1>>"] |
|
|
} |
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **NoSQL Agent** for RiverGen AI. |
|
|
|
|
|
OBJECTIVE: |
|
|
Generate a valid, safe, and auditable query for a **{db_type.upper()}** NoSQL database (Cassandra, MongoDB, DynamoDB, Redis, Elasticsearch, etc.) based on the user prompt and the available schema. |
|
|
|
|
|
INPUT CONTEXT: |
|
|
- User Prompt: "{payload.get('user_prompt')}" |
|
|
- Max Rows: {max_rows} |
|
|
- AVAILABLE SCHEMA: |
|
|
{chr(10).join(schema_summary) if schema_summary else "No schema provided."} |
|
|
- GOVERNANCE: |
|
|
{chr(10).join(governance_instructions) if governance_instructions else "No active policies."} |
|
|
|
|
|
STRICT RULES (MANDATORY) |
|
|
1. SCHEMA AUTHORITY (ABSOLUTE): |
|
|
- You MUST NOT reference any collection/table/field that does not appear in AVAILABLE SCHEMA. |
|
|
- If the user asks for an object not present, add it to `validation.missing_fields`. |
|
|
- Do NOT invent nested structures or relationships. |
|
|
|
|
|
2. QUERYABILITY & DROPPED SOURCES: |
|
|
- If a source or collection exists in payload but is NOT present in AVAILABLE SCHEMA, treat it as NON-QUERYABLE. |
|
|
- Do NOT generate queries against non-queryable sources; instead, list them under `validation.dropped_sources` and explain why. |
|
|
|
|
|
3. DIALECT-SPECIFIC SYNTAX (EXAMPLES β obey exact dialect): |
|
|
- **MongoDB**: Use `db.collection.find({...})` or aggregation pipeline `db.collection.aggregate([...])`. |
|
|
- **Cassandra**: Use CQL `SELECT ... FROM keyspace.table WHERE ...;` and **avoid** `ALLOW FILTERING` where possible; if used, add a `performance_warnings` note. |
|
|
- **DynamoDB**: Use the expression-style syntax appropriate for DynamoDB (e.g., KeyConditionExpression, FilterExpression). |
|
|
- **Redis (Search)**: Use `FT.SEARCH index "query" FILTER ...` or appropriate native commands. |
|
|
- **Elasticsearch**: Use a JSON DSL query body with `match`, `bool`, `range`, etc. |
|
|
|
|
|
4. DEGRADATION & PARTIAL FULFILLMENT: |
|
|
- If the full user intent is impossible (missing fields/tables), produce: |
|
|
a) A best-effort query that returns whatever is available. |
|
|
b) `validation.missing_fields`: list of requested objects not present. |
|
|
c) `validation.notes`: human-readable explanation of what was omitted and why. |
|
|
d) `suggestions`: concrete next steps (e.g., "provide orders schema", "create secondary index on customer_id"). |
|
|
|
|
|
5. GOVERNANCE & RLS: |
|
|
- If governance_instructions reference tables/objects not in AVAILABLE SCHEMA: |
|
|
- Attempt literal substitution using Context Literals if present. |
|
|
- Otherwise, document omission under `validation.notes` and `governance_enforcement` with status `omitted`. |
|
|
- If RLS can be applied, show exact filter to be injected. |
|
|
|
|
|
6. TEMPORAL & METADATA MAPPING: |
|
|
- Map natural language time windows (e.g., "last 90 days") to explicit range filters using the available date/time fields. |
|
|
- If no date field exists, include a `validation.notes` entry explaining inability to apply time filter. |
|
|
|
|
|
7. PERFORMANCE & SAFETY: |
|
|
- Flag expensive patterns (Cassandra `ALLOW FILTERING`, unbounded scans, missing indexes) in `performance_warnings`. |
|
|
- Prefer query patterns that respect partition/primary keys for the given NoSQL engine. |
|
|
|
|
|
8. OUTPUT STRUCTURE (MANDATORY): |
|
|
- Return ONLY a JSON object that matches the provided lean template exactly. |
|
|
- The JSON MUST include a `validation` block with: |
|
|
- `missing_fields`: [], |
|
|
- `dropped_sources`: [], |
|
|
- `notes`: [], |
|
|
- `performance_warnings`: [] |
|
|
- Also provide `governance_enforcement` and `suggestions`. |
|
|
|
|
|
9. TRANSPARENCY: |
|
|
- If you cannot compute an aggregate (e.g., Lifetime Value) due to missing data, do NOT attempt to compute it; instead add a clear explanation and a suggested data requirement. |
|
|
|
|
|
10. Do not use any placeholders like date use actual date functions or fixed dates. |
|
|
OUTPUT FORMAT: |
|
|
Return ONLY a valid JSON object matching this LEAN structure: |
|
|
{json.dumps(lean_template, indent=2)} |
|
|
""" |
|
|
|
|
|
|
|
|
if feedback: |
|
|
system_prompt += f"\nπ¨ FIX PREVIOUS ERROR: {feedback}" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": f"Request ID: {payload.get('request_id')}"} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
end_time = time.time() |
|
|
generation_time_ms = int((end_time - start_time) * 1000) |
|
|
|
|
|
|
|
|
input_tokens = completion.usage.prompt_tokens |
|
|
output_tokens = completion.usage.completion_tokens |
|
|
|
|
|
|
|
|
lean_response = clean_and_parse_json(completion.choices[0].message.content) |
|
|
|
|
|
|
|
|
final_plan = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"execution_id": payload.get("execution_id", f"exec-{payload.get('request_id')}"), |
|
|
"plan_id": f"plan-{int(time.time())}", |
|
|
"status": "success", |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"intent_type": "query" if not lean_response.get("validation", {}).get("missing_fields") else "partial_query", |
|
|
"intent_summary": lean_response.get("intent_summary", "NoSQL Query Execution"), |
|
|
"execution_plan": { |
|
|
"strategy": "pushdown", |
|
|
"type": "nosql_query", |
|
|
"operations": [ |
|
|
{ |
|
|
"step": 1, |
|
|
"step_id": "op-1", |
|
|
"operation_type": "read", |
|
|
"type": "source_query", |
|
|
"description": lean_response.get("intent_summary"), |
|
|
"data_source_id": ds_id, |
|
|
"compute_type": "source_native", |
|
|
"compute_engine": db_type, |
|
|
"dependencies": [], |
|
|
"query": lean_response.get("nosql_statement"), |
|
|
"query_payload": { |
|
|
"language": db_type, |
|
|
"dialect": None, |
|
|
"statement": lean_response.get("nosql_statement"), |
|
|
"parameters": [] |
|
|
}, |
|
|
"governance_applied": lean_response.get("governance_applied", {}), |
|
|
"output_artifact": "result_cursor" |
|
|
} |
|
|
] |
|
|
}, |
|
|
"visualization": None, |
|
|
"ai_metadata": { |
|
|
"model": config.MODEL_NAME, |
|
|
"input_tokens": input_tokens, |
|
|
"output_tokens": output_tokens, |
|
|
"generation_time_ms": generation_time_ms, |
|
|
"confidence": lean_response.get("confidence_score", 0.0), |
|
|
"confidence_score": lean_response.get("confidence_score", 0.0), |
|
|
"explanation": lean_response.get("validation", {}).get("notes", ["Execution successful"])[0], |
|
|
"reasoning_steps": lean_response.get("reasoning_steps", []) |
|
|
}, |
|
|
"suggestions": lean_response.get("suggestions", []) |
|
|
} |
|
|
|
|
|
|
|
|
if lean_response.get("validation", {}).get("missing_fields"): |
|
|
final_plan["warnings"] = [ |
|
|
f"Missing fields: {', '.join(lean_response['validation']['missing_fields'])}" |
|
|
] |
|
|
|
|
|
return final_plan |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"NoSQL Agent Failed: {e}", exc_info=True) |
|
|
return {"error": f"NoSQL Agent Failed: {str(e)}"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def big_data_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4: Generates a RiverGen Execution Plan for Big Data workloads. |
|
|
Handles Cloud Warehouses (Snowflake, BigQuery) and Data Lakes (S3, Parquet). |
|
|
Supports Self-Correction Loop via 'feedback'. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"π [Big Data Agent] Generating plan... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
data_sources = payload.get('data_sources', []) |
|
|
governance_context = [] |
|
|
source_type_hint = "unknown" |
|
|
|
|
|
|
|
|
primary_ds_id = data_sources[0].get("data_source_id") if data_sources else None |
|
|
|
|
|
for ds in data_sources: |
|
|
|
|
|
ds_type = ds.get('type', 'unknown') |
|
|
ds_name = ds.get('name', 'Unknown Source') |
|
|
|
|
|
|
|
|
if ds_type in ['snowflake', 'bigquery', 'redshift', 's3', 'databricks']: |
|
|
source_type_hint = ds_type |
|
|
|
|
|
policies = ds.get('governance_policies') or {} |
|
|
if policies: |
|
|
governance_context.append(f"Source '{ds_name}': {json.dumps(policies)}") |
|
|
|
|
|
|
|
|
response_template = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"status": "success", |
|
|
"intent_type": "query", |
|
|
"execution_plan": { |
|
|
"strategy": "pushdown", |
|
|
"type": "sql_query", |
|
|
"operations": [ |
|
|
{ |
|
|
"step": 1, |
|
|
"type": "source_query", |
|
|
"operation_type": "read", |
|
|
"data_source_id": primary_ds_id, |
|
|
"query": "SELECT ...", |
|
|
"query_payload": { |
|
|
"language": "sql", |
|
|
"dialect": "snowflake", |
|
|
"statement": "SELECT ..." |
|
|
}, |
|
|
"governance_applied": { |
|
|
"rls_rules": [], |
|
|
"masking_rules": [] |
|
|
} |
|
|
} |
|
|
] |
|
|
}, |
|
|
"ai_metadata": { |
|
|
"confidence_score": 0.0, |
|
|
"reasoning_steps": [] |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **Big Data Agent** for RiverGen AI. |
|
|
|
|
|
[Image of cloud data warehouse architecture] |
|
|
|
|
|
|
|
|
**YOUR TASK:** |
|
|
Generate an optimized Execution Plan for a Big Data workload (Cloud Warehouse or Data Lake). |
|
|
|
|
|
**INPUT CONTEXT:** |
|
|
- User Prompt: "{payload.get('user_prompt')}" |
|
|
- Data Source Schema: {json.dumps(data_sources)} |
|
|
- Primary Source Type: "{source_type_hint}" |
|
|
|
|
|
**GOVERNANCE POLICIES (MUST ENFORCE):** |
|
|
{chr(10).join(governance_context) if governance_context else "No specific policies."} |
|
|
|
|
|
**DIALECT & OPTIMIZATION RULES:** |
|
|
1. **Snowflake**: Use `Snowflake` dialect. Support `QUALIFY`, `FLATTEN`, and strictly use defined database/schema names (e.g. `DB.SCHEMA.TABLE`). |
|
|
2. **BigQuery**: Use `BigQuery` standard SQL. Handle nested fields (`record.field`) if present. Use backticks for project.dataset.table. |
|
|
3. **Data Lakes (S3/ADLS/File)**: |
|
|
- Assume compute engine is **DuckDB** or **Trino**. |
|
|
- **Partition Pruning**: If the schema mentions `partition_columns`, YOU MUST filter by them in the `WHERE` clause if the prompt allows (e.g. "last 30 days" -> `date >= ...`). |
|
|
- Use file functions like `read_parquet('s3://...')` if applicable, or standard SQL if the view is abstracted. |
|
|
|
|
|
**OUTPUT FORMAT:** |
|
|
Return ONLY valid JSON matching the exact template below. Adjust `dialect` field based on the source type (e.g. 'snowflake', 'bigquery', 'duckdb'). |
|
|
|
|
|
**OUTPUT TEMPLATE:** |
|
|
{json.dumps(response_template, indent=2)} |
|
|
""" |
|
|
|
|
|
|
|
|
if feedback: |
|
|
system_prompt += f""" |
|
|
|
|
|
π¨ **CRITICAL: FIX PREVIOUS ERROR** π¨ |
|
|
Your previous plan was rejected by the QA Judge. |
|
|
**FEEDBACK:** "{feedback}" |
|
|
|
|
|
**INSTRUCTIONS FOR FIX:** |
|
|
- If you used the wrong dialect (e.g. BigQuery syntax on Snowflake), fix it. |
|
|
- If you missed a partition filter on a large table, ADD IT. |
|
|
- If you hallucinated a path or table, check the schema string again. |
|
|
""" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": f"Request ID: {payload.get('request_id')}"} |
|
|
], |
|
|
temperature=0, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
lean_response = clean_and_parse_json(completion.choices[0].message.content) |
|
|
|
|
|
|
|
|
generation_time_ms = int((time.time() - start_time) * 1000) |
|
|
|
|
|
|
|
|
if "ai_metadata" not in lean_response: |
|
|
lean_response["ai_metadata"] = {} |
|
|
|
|
|
lean_response["ai_metadata"]["generation_time_ms"] = generation_time_ms |
|
|
lean_response["ai_metadata"]["model"] = config.MODEL_NAME |
|
|
|
|
|
return lean_response |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Big Data Agent Failed: {e}", exc_info=True) |
|
|
return {"error": f"Big Data Agent Failed: {str(e)}"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ml_agent(payload: Dict[str, Any], feedback: str = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Step 3/4: Generates a specialized ML Execution Plan. |
|
|
Orchestrates Feature Engineering, Pre-processing, Model Training, and Evaluation. |
|
|
""" |
|
|
|
|
|
client = get_groq_client() |
|
|
config = get_config() |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"π§ [ML Agent] Building ML Pipeline... Feedback: {bool(feedback)}") |
|
|
|
|
|
try: |
|
|
|
|
|
user_prompt = payload.get('user_prompt') |
|
|
data_sources = payload.get('data_sources', []) |
|
|
user_context = payload.get('user_context', {}) |
|
|
ml_params = payload.get('execution_context', {}).get('ml_params', {}) |
|
|
|
|
|
|
|
|
|
|
|
response_template = { |
|
|
"request_id": payload.get("request_id"), |
|
|
"status": "success", |
|
|
"intent_type": "ml_orchestration", |
|
|
"execution_plan": { |
|
|
"strategy": "sequential_dag", |
|
|
"type": "ml_workflow", |
|
|
"operations": [ |
|
|
{ |
|
|
"step": 1, |
|
|
"operation_type": "feature_extraction", |
|
|
"description": "Extract features and labels using SQL", |
|
|
"query": "SELECT ...", |
|
|
"features": [], |
|
|
"labels": [], |
|
|
"output_artifact": "training_dataset" |
|
|
}, |
|
|
{ |
|
|
"step": 2, |
|
|
"operation_type": "pre_processing", |
|
|
"compute_engine": "python_kernel", |
|
|
"description": "Data cleaning, imputation, and train/test split", |
|
|
"logic": { |
|
|
"imputation": "mean", |
|
|
"scaling": "standard", |
|
|
"split_ratio": 0.8 |
|
|
}, |
|
|
"dependencies": ["step_1"] |
|
|
}, |
|
|
{ |
|
|
"step": 3, |
|
|
"operation_type": "model_execution", |
|
|
"description": "Train model and evaluate performance", |
|
|
"parameters": { |
|
|
"task": "regression", |
|
|
"algorithm": "auto", |
|
|
"metrics": ["rmse", "r2"] |
|
|
}, |
|
|
"dependencies": ["step_2"] |
|
|
} |
|
|
] |
|
|
}, |
|
|
"ai_metadata": { |
|
|
"confidence_score": 0.0, |
|
|
"reasoning_steps": [], |
|
|
"model_task": "" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
system_prompt = f""" |
|
|
You are the **RiverGen ML Architect Agent**. |
|
|
|
|
|
Your responsibility is to design a **fully executable, production-safe machine learning pipeline plan** in **valid JSON only**. |
|
|
|
|
|
This plan will be executed by downstream systems β any ambiguity, invalid syntax, or ML anti-pattern is a FAILURE. |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
CORE OBJECTIVES |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
1. Translate the user request into a correct ML pipeline. |
|
|
2. Explicitly separate FEATURES and LABELS. |
|
|
3. Select the correct execution STRATEGY and COMPUTE ENGINES. |
|
|
4. Enforce ML best practices and execution correctness. |
|
|
5. Return ONLY valid JSON that matches the output template. |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
NON-NEGOTIABLE RULES (CRITICAL) |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
|
|
|
### 1οΈβ£ Feature / Label Discipline |
|
|
- You MUST explicitly define: |
|
|
- `features`: input columns ONLY |
|
|
- `labels`: target column(s) ONLY |
|
|
- NEVER include: |
|
|
- primary keys |
|
|
- surrogate IDs |
|
|
- UUIDs |
|
|
- auto-increment fields |
|
|
**unless the user explicitly requests it.** |
|
|
- If an ID column appears in features, DROP IT and explain in reasoning. |
|
|
|
|
|
### 2οΈβ£ Strategy Selection (MANDATORY) |
|
|
- Use **sequential_dag** when: |
|
|
- CSV / Parquet / files |
|
|
- Pandas / sklearn workflows |
|
|
- Use **pushdown** ONLY for native warehouse ML (BigQuery ML, Snowflake ML). |
|
|
- Use **distributed_training** ONLY if dataset size is explicitly >1M rows. |
|
|
|
|
|
### 3οΈβ£ Data Source Execution Rules |
|
|
- **DuckDB + CSV**: |
|
|
- ALWAYS use `read_csv_auto()` or equivalent. |
|
|
- NEVER reference CSVs as tables. |
|
|
- Example: |
|
|
```sql |
|
|
SELECT col1 FROM read_csv_auto('s3://bucket/file.csv') |
|
|
``` |
|
|
|
|
|
- **SQL Sources**: |
|
|
- Use valid dialect syntax. |
|
|
- Do NOT hallucinate tables or columns. |
|
|
|
|
|
### 4οΈβ£ Preprocessing (REQUIRED) |
|
|
You MUST include: |
|
|
- Missing value handling (imputation) |
|
|
- Scaling or normalization for numeric features |
|
|
- Train / test split with explicit ratio |
|
|
- Fixed `random_state` for reproducibility |
|
|
|
|
|
### 5οΈβ£ Model Execution Rules |
|
|
- Training compute engine MUST be: |
|
|
- `scikit-learn` (or equivalent ML framework) |
|
|
- Pandas is NOT a model training engine. |
|
|
- Explicitly specify: |
|
|
- algorithm |
|
|
- task type |
|
|
- evaluation metrics |
|
|
|
|
|
### 6οΈβ£ Metrics Enforcement |
|
|
- **Regression** β RMSE + RΒ² (MANDATORY) |
|
|
- **Classification** β Precision, Recall, F1, AUC-ROC (MANDATORY) |
|
|
|
|
|
### 7οΈβ£ Output Artifacts (REQUIRED) |
|
|
- You MUST specify: |
|
|
- model artifact path |
|
|
- evaluation report path |
|
|
|
|
|
### 8οΈβ£ Reasoning Transparency |
|
|
- Populate `reasoning_steps` |
|
|
- Explicitly justify: |
|
|
- strategy choice |
|
|
- feature selection |
|
|
- algorithm choice |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
INPUT CONTEXT |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
- User Prompt: "{user_prompt}" |
|
|
- Data Schema / Sources: {json.dumps(data_sources)} |
|
|
- ML Parameters: {json.dumps(ml_params)} |
|
|
- User Context: {json.dumps(user_context)} |
|
|
|
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
OUTPUT FORMAT (STRICT) |
|
|
ββββββββββββββββββββββββββββββββββββββββ |
|
|
Return ONLY valid JSON matching this template exactly: |
|
|
{json.dumps(response_template, indent=2)} |
|
|
|
|
|
DO NOT include explanations outside JSON. |
|
|
DO NOT add extra keys. |
|
|
DO NOT return partial plans. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if feedback: |
|
|
system_prompt += f"\n\nπ¨ **CRITICAL REVISION NEEDED:** {feedback}" |
|
|
|
|
|
|
|
|
completion = client.chat.completions.create( |
|
|
model=config.MODEL_NAME, |
|
|
messages=[{"role": "system", "content": system_prompt}], |
|
|
temperature=0.1, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
lean_response = json.loads(completion.choices[0].message.content) |
|
|
generation_time_ms = int((time.time() - start_time) * 1000) |
|
|
|
|
|
if "ai_metadata" not in lean_response: |
|
|
lean_response["ai_metadata"] = {} |
|
|
|
|
|
lean_response["ai_metadata"]["generation_time_ms"] = generation_time_ms |
|
|
lean_response["ai_metadata"]["model_used"] = config.MODEL_NAME |
|
|
|
|
|
return lean_response |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"ML Agent Error: {str(e)}", exc_info=True) |
|
|
return {"error": f"ML Planning Failed: {str(e)}"} |