Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import requests | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
# ---------------------------------------------------------------------------
# Runtime configuration, resolved once at import time.
# ---------------------------------------------------------------------------

# Load variables from a .env file (python-dotenv) before os.environ is read.
load_dotenv()


def _env_flag(name: str, default: str = "1") -> bool:
    """Return True when env var *name* (or *default*) is "1" or "true".

    The comparison ignores case and surrounding whitespace, so "TRUE" and
    " True " are accepted alongside "1", "true" and "True".
    """
    return os.environ.get(name, default).strip().lower() in {"1", "true"}


# Required settings: a missing variable raises KeyError at import time.
API_BASE_URL: str = os.environ["API_BASE_URL"]
API_KEY: str = os.environ["API_KEY"]

# Optional settings with defaults.
MODEL_NAME: str = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:novita")
# Trailing slashes are dropped so f"{ENV_BASE_URL}/reset" never contains "//",
# matching how main() normalizes the --env-url override.
ENV_BASE_URL: str = os.environ.get("ENV_BASE_URL", "http://localhost:7860").rstrip("/")
# True -> tasks are driven by the LLM agent; False -> by the stored queries.
USE_LLM_AGENT: bool = _env_flag("USE_LLM_AGENT")
# True -> a failed proxy ping is re-raised by ping_llm_proxy().
PROXY_PING_REQUIRED: bool = _env_flag("PROXY_PING_REQUIRED")

# Upper bound on the number of LLM-driven steps per episode.
MAX_STEPS: int = 26
# Tasks run by main(), in this order.
TASK_IDS: List[str] = [
    "fix_broken_query",
    "inventory_restock_alerts",
    "find_data_anomalies",
    "detect_subscription_issues",
    "repair_data_pipeline",
    "multi_channel_attribution",
]
# Value reported as env= in the [START] log line.
BENCHMARK: str = "sql-data-analyst-env"
# A final score at or above this value is logged as success=true.
SUCCESS_SCORE_THRESHOLD: float = 0.95

# Few-shot sampling knobs consumed by build_sql_examples().
FEW_SHOT_EXAMPLE_COUNT: int = int(os.environ.get("FEW_SHOT_EXAMPLE_COUNT", "5"))
HARD_EXAMPLE_RATIO: float = float(os.environ.get("HARD_EXAMPLE_RATIO", "0.6"))
MEDIUM_EXAMPLE_RATIO: float = float(os.environ.get("MEDIUM_EXAMPLE_RATIO", "0.2"))
# Probability that choose_query() picks a harder variant over the primary SQL.
DETERMINISTIC_HARD_QUERY_RATE: float = float(os.environ.get("DETERMINISTIC_HARD_QUERY_RATE", "0.35"))
# Relative paths are resolved against this file's directory by load_training_queries().
TRAINING_QUERY_PATH: str = os.environ.get("TRAINING_QUERY_PATH", "training_queries.json")

# Shared client for all LLM requests, pointed at the configured base URL.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# System message sent with every chat-completion request: it restricts the
# reply to one compact JSON action object and lists the allowed action types.
SYSTEM_PROMPT = (
    "You are an expert SQL data analyst agent.\n"
    'Respond ONLY with compact JSON: {"action_type":"...","sql_query":"...","answer":{...}}\n'
    "Allowed action_type: list_tables, describe_table, execute_query, submit_answer, noop\n"
)
def load_training_queries() -> Dict[str, Any]:
    """Read and validate the training-query JSON file.

    The file location comes from ``TRAINING_QUERY_PATH``; a relative path is
    resolved against the directory containing this module.

    Returns:
        The decoded JSON object. It is guaranteed to contain the keys
        ``"few_shot"`` and ``"deterministic"``, each mapping to a JSON object.

    Raises:
        RuntimeError: if the file is missing, is not valid JSON, is not a
            JSON object, lacks a required key, or has a required section
            that is not itself a JSON object.
    """
    path = Path(TRAINING_QUERY_PATH)
    if not path.is_absolute():
        path = Path(__file__).resolve().parent / path

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError as exc:
        raise RuntimeError(f"Training query file not found: {path}") from exc
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Invalid JSON in training query file: {path} ({exc})") from exc

    if not isinstance(payload, dict):
        raise RuntimeError(f"Training query file must contain a JSON object: {path}")
    if "few_shot" not in payload or "deterministic" not in payload:
        raise RuntimeError(f"Training query file missing required keys 'few_shot' and 'deterministic': {path}")

    # Consumers call .get() on both sections, so reject lists/null/scalars
    # here instead of failing later with an unrelated AttributeError.
    for section in ("few_shot", "deterministic"):
        if not isinstance(payload[section], dict):
            raise RuntimeError(f"Training query section '{section}' must be a JSON object: {path}")
    return payload
# Loaded once at import time; load_training_queries() raises RuntimeError if
# the file is missing or malformed, which aborts the import.
TRAINING_QUERIES: Dict[str, Any] = load_training_queries()
| def _clip(value: str, limit: int = 180) -> str: | |
| flat = " ".join(value.split()) | |
| if len(flat) <= limit: | |
| return flat | |
| return f"{flat[: limit - 3]}..." | |
def _json_preview(value: Any, limit: int = 180) -> str:
    """Return a compact, ASCII-only JSON rendering of *value*, clipped to *limit*.

    Objects that ``json`` cannot encode natively are rendered through ``str``.
    """
    serialized = json.dumps(
        value,
        ensure_ascii=True,
        default=str,
        separators=(",", ":"),
    )
    return _clip(serialized, limit=limit)
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` marker line for a task run and flush stdout."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: Dict[str, Any], reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing an executed action.

    Args:
        step: Step number within the episode.
        action: Action payload; rendered as compact ASCII JSON.
        reward: Step reward, printed with four decimals.
        done: Episode-finished flag, printed as ``true``/``false``.
        error: Optional error text; flattened and clipped to 220 characters,
            or printed as an empty value when falsy.
    """
    action_text = json.dumps(action, separators=(",", ":"), ensure_ascii=True, default=str)
    error_text = _clip(str(error), limit=220) if error else ""
    done_text = str(done).lower()
    line = f"[STEP] step={step} action={action_text} reward={reward:.4f} done={done_text} error={error_text}"
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line for a task run.

    Rewards are rounded to four decimals and rendered as a compact JSON list
    clipped to 260 characters; ``success`` is printed as ``true``/``false``.
    """
    rounded_rewards = [round(value, 4) for value in rewards]
    rewards_text = _json_preview(rounded_rewards, limit=260)
    success_text = str(success).lower()
    print(
        f"[END] success={success_text} steps={steps} score={score:.4f} rewards={rewards_text}",
        flush=True,
    )
def env_reset(task_id: str) -> Dict[str, Any]:
    """POST to the environment's ``/reset`` endpoint and return its JSON body.

    ``task_id`` is sent as a query parameter with a 30-second timeout. An
    error status is surfaced through ``raise_for_status()``.
    """
    endpoint = f"{ENV_BASE_URL}/reset"
    response = requests.post(endpoint, params={"task_id": task_id}, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(action: Dict[str, Any]) -> Dict[str, Any]:
    """POST *action* as JSON to the environment's ``/step`` endpoint.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: when the response status is 400 or above; the message
            carries the status code and the response text.
    """
    response = requests.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
    if response.status_code < 400:
        return response.json()
    raise RuntimeError(f"HTTP {response.status_code} /step failed: {response.text}")
def validate_action(action: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce an arbitrary action payload to the fields the environment accepts.

    Args:
        action: Candidate action, typically decoded from model output. Any
            value is tolerated; anything that is not a dict with a recognized
            string ``action_type`` is replaced by a ``noop`` action.

    Returns:
        A new dict holding ``action_type`` plus, when present and not
        ``None``, ``sql_query`` (coerced to ``str``) and ``answer`` (passed
        through unchanged). All other keys are dropped.
    """
    valid_types = {"execute_query", "describe_table", "submit_answer", "list_tables", "noop"}

    # Model output may decode to a list, string or null; never call .get on it.
    if not isinstance(action, dict):
        return {"action_type": "noop"}

    action_type = action.get("action_type")
    # The isinstance check also keeps unhashable values (list/dict) away from
    # the set lookup, which would raise TypeError.
    if not isinstance(action_type, str) or action_type not in valid_types:
        return {"action_type": "noop"}

    cleaned: Dict[str, Any] = {"action_type": action_type}
    sql_query = action.get("sql_query")
    if sql_query is not None:
        cleaned["sql_query"] = str(sql_query)
    answer = action.get("answer")
    if answer is not None:
        cleaned["answer"] = answer
    return cleaned
def build_sql_examples() -> List[Dict[str, str]]:
    """Sample a shuffled mix of few-shot SQL examples across difficulty tiers.

    The ``easy``/``medium``/``hard`` pools come from
    ``TRAINING_QUERIES["few_shot"]``. ``HARD_EXAMPLE_RATIO`` and
    ``MEDIUM_EXAMPLE_RATIO`` set the share of each tier (at least one hard
    example whenever any example is requested); easy examples fill the rest.
    If a tier is too small, the shortfall is topped up from unused entries of
    the other pools.

    Returns:
        At most ``FEW_SHOT_EXAMPLE_COUNT`` examples in random order; an empty
        list when that count is zero or negative.
    """
    wanted = max(0, FEW_SHOT_EXAMPLE_COUNT)
    if wanted == 0:
        # Without this guard the "at least one hard example" rule would
        # return an example even though none were requested.
        return []

    few_shot = TRAINING_QUERIES.get("few_shot", {})
    easy_pool = list(few_shot.get("easy", []))
    medium_pool = list(few_shot.get("medium", []))
    hard_pool = list(few_shot.get("hard", []))
    random.shuffle(easy_pool)
    random.shuffle(medium_pool)
    random.shuffle(hard_pool)

    # Each target is capped by what is still unallocated, so ratios that sum
    # to more than 1.0 cannot push the total past ``wanted``.
    target_hard = max(1, int(round(wanted * HARD_EXAMPLE_RATIO)))
    target_hard = min(target_hard, wanted, len(hard_pool))
    target_medium = max(0, int(round(wanted * MEDIUM_EXAMPLE_RATIO)))
    target_medium = min(target_medium, wanted - target_hard, len(medium_pool))
    target_easy = min(wanted - target_hard - target_medium, len(easy_pool))

    examples = hard_pool[:target_hard] + medium_pool[:target_medium] + easy_pool[:target_easy]

    shortfall = wanted - len(examples)
    if shortfall > 0:
        # Top up from whatever was not picked, hardest tier first.
        leftovers = hard_pool[target_hard:] + medium_pool[target_medium:] + easy_pool[target_easy:]
        examples.extend(leftovers[:shortfall])

    random.shuffle(examples)
    return examples
def choose_query(primary_sql: str, harder_variants: List[str]) -> str:
    """Pick the SQL to run: usually *primary_sql*, sometimes a harder variant.

    With no variants the primary query is returned without drawing a random
    number. Otherwise a random variant is chosen with probability
    ``DETERMINISTIC_HARD_QUERY_RATE``.
    """
    if not harder_variants:
        return primary_sql
    use_harder = random.random() < DETERMINISTIC_HARD_QUERY_RATE
    return random.choice(harder_variants) if use_harder else primary_sql
def _parse_action_json(text: str) -> Any:
    """Decode a model reply as JSON, tolerating code fences and stray prose.

    A reply wrapped in backtick fences (optionally tagged ``json``) is
    unwrapped first. If decoding still fails, the span from the first ``{``
    to the last ``}`` is tried before the original error is re-raised.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`")
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[4:]
        cleaned = cleaned.strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(cleaned[start : end + 1])


def llm_action(observation: Dict[str, Any], history: List[str]) -> Dict[str, Any]:
    """Ask the LLM for the next action given the current observation.

    Args:
        observation: Latest environment observation; the task id, goal,
            schema info and last query/action results are forwarded.
        history: Previously issued actions as JSON strings; only the last six
            are sent.

    Returns:
        A validated action dict. Any failure (request error, empty or
        non-JSON reply, JSON that is not an object) yields
        ``{"action_type": "noop"}`` so the episode can continue.
    """
    user_prompt = json.dumps(
        {
            "task_id": observation.get("task_id"),
            "goal": observation.get("goal"),
            "schema_info": observation.get("schema_info"),
            "last_query_result": observation.get("last_query_result"),
            "last_query_error": observation.get("last_query_error"),
            "last_action_error": observation.get("last_action_error"),
            "history": history[-6:],
            "sql_examples": build_sql_examples(),
        },
        ensure_ascii=True,
    )
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.0,
            max_tokens=256,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
        parsed = _parse_action_json(text)
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        return validate_action(parsed)
    except Exception as exc:
        # Deliberate best-effort: one bad reply must not abort the episode.
        # The cause goes to stderr so the structured stdout log
        # ([START]/[STEP]/[END]) keeps its exact format.
        print(f"[WARN] llm_action fell back to noop: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        return {"action_type": "noop"}
def ping_llm_proxy() -> None:
    """Send one request through the configured LLM proxy and report the outcome.

    Prints ``[INFO] proxy_ping=ok`` on success. On failure it prints an
    ``[ERROR]`` line and re-raises only when ``PROXY_PING_REQUIRED`` is set.
    """
    try:
        # Listing models is more robust than a model-specific completion call.
        client.models.list()
        print("[INFO] proxy_ping=ok", flush=True)
    except Exception as exc:
        print(f"[ERROR] proxy_ping_failed={exc}", flush=True)
        if not PROXY_PING_REQUIRED:
            return
        raise
def execute_and_log(step_no: int, action: Dict[str, Any], rewards: List[float]) -> Dict[str, Any]:
    """Validate *action*, send it to the environment, log it and record the reward.

    Args:
        step_no: Step number used in the ``[STEP]`` log line.
        action: Raw action; it is passed through ``validate_action`` first.
        rewards: Running reward list; the step's reward is appended in place.

    Returns:
        The environment's step result as returned by ``env_step``.
    """
    cleaned_action = validate_action(action)
    step_result = env_step(cleaned_action)

    step_reward = float(step_result.get("reward", 0.0) or 0.0)
    finished = bool(step_result.get("done", False))

    observation = step_result["observation"]
    # An action-level error takes precedence over a query-level error.
    step_error = observation.get("last_action_error")
    if not step_error:
        step_error = observation.get("last_query_error")

    log_step(step=step_no, action=cleaned_action, reward=step_reward, done=finished, error=step_error)
    rewards.append(step_reward)
    return step_result
def run_fixed_query_task(task_id: str, sql: str, submit_key: str) -> float:
    """Run one query and submit its rows as the answer under *submit_key*.

    The episode always takes two steps: an ``execute_query`` followed by a
    ``submit_answer`` whose answer is ``{submit_key: <query rows>}``.

    Returns:
        The ``final_score`` reported in the submit step's ``info`` (0.0 when
        absent or falsy).
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    env_reset(task_id)

    rewards: List[float] = []
    query_action = {"action_type": "execute_query", "sql_query": sql}
    query_result = execute_and_log(1, query_action, rewards)
    rows = query_result["observation"].get("last_query_result") or []

    total_steps = 2
    submit_action = {"action_type": "submit_answer", "answer": {submit_key: rows}}
    submit_result = execute_and_log(total_steps, submit_action, rewards)

    final_score = float(submit_result.get("info", {}).get("final_score", 0.0) or 0.0)
    log_end(
        success=final_score >= SUCCESS_SCORE_THRESHOLD,
        steps=total_steps,
        score=final_score,
        rewards=rewards,
    )
    return final_score
def run_count_dict_task(task_id: str, queries: List[Tuple[str, str]]) -> float:
    """Run one counting query per key and submit the ``{key: count}`` mapping.

    For each ``(key, sql)`` pair the query is executed and the ``"c"`` column
    of the first result row is read as an integer. A missing row, a non-dict
    row, or a value that cannot be converted counts as 0.

    Returns:
        The ``final_score`` reported in the submit step's ``info`` (0.0 when
        absent or falsy).
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    env_reset(task_id)

    rewards: List[float] = []
    counts: Dict[str, int] = {}
    for step_no, (key, sql) in enumerate(queries, start=1):
        outcome = execute_and_log(step_no, {"action_type": "execute_query", "sql_query": sql}, rewards)
        rows = outcome["observation"].get("last_query_result") or []
        first_row = rows[0] if rows else None
        count = 0
        if isinstance(first_row, dict):
            try:
                count = int(first_row.get("c", 0))
            except (TypeError, ValueError):
                count = 0
        counts[key] = count

    # The submit step follows the last query step.
    total_steps = len(queries) + 1
    submit_result = execute_and_log(total_steps, {"action_type": "submit_answer", "answer": counts}, rewards)

    final_score = float(submit_result.get("info", {}).get("final_score", 0.0) or 0.0)
    log_end(
        success=final_score >= SUCCESS_SCORE_THRESHOLD,
        steps=total_steps,
        score=final_score,
        rewards=rewards,
    )
    return final_score
def run_deterministic(task_id: str) -> float:
    """Run *task_id* from its stored configuration instead of asking the LLM.

    The configuration is read from ``TRAINING_QUERIES["deterministic"]``. Its
    ``type`` selects the strategy:

    * ``"fixed"``: run ``primary_sql`` (or, at random, one of
      ``hard_variants``) and submit the rows under ``submit_key``.
    * ``"count_dict"``: run every ``{"key", "sql"}`` entry in ``queries`` and
      submit the resulting counts.

    Returns:
        The final score of the executed task.

    Raises:
        ValueError: for an unknown task id, an unknown task type, a missing
            ``primary_sql``, or an empty list of usable count queries.
    """
    config = TRAINING_QUERIES.get("deterministic", {}).get(task_id)
    if not isinstance(config, dict):
        raise ValueError(f"Unsupported task id: {task_id}")

    task_type = config.get("type")

    if task_type == "fixed":
        primary_sql = str(config.get("primary_sql", ""))
        hard_variants = [str(variant) for variant in config.get("hard_variants", [])]
        submit_key = str(config.get("submit_key", "rows"))
        if not primary_sql:
            raise ValueError(f"Missing primary_sql for task id: {task_id}")
        chosen_sql = choose_query(primary_sql, hard_variants)
        return run_fixed_query_task(task_id, chosen_sql, submit_key)

    if task_type == "count_dict":
        # Normalize each dict entry to a stripped (key, sql) pair, then keep
        # only pairs where both parts are non-empty.
        candidates = [
            (str(entry.get("key", "")).strip(), str(entry.get("sql", "")).strip())
            for entry in config.get("queries", [])
            if isinstance(entry, dict)
        ]
        parsed_queries: List[Tuple[str, str]] = [pair for pair in candidates if all(pair)]
        if not parsed_queries:
            raise ValueError(f"No valid count_dict queries for task id: {task_id}")
        return run_count_dict_task(task_id, parsed_queries)

    raise ValueError(f"Unsupported deterministic task type '{task_type}' for task id: {task_id}")
def run_episode_with_llm(task_id: str) -> float:
    """Drive one episode of *task_id* with LLM-chosen actions.

    After a reset, the model is asked for an action each step until the
    environment reports ``done`` or ``MAX_STEPS`` actions have been taken.

    Returns:
        The ``final_score`` in the last step's ``info`` (0.0 when absent or
        falsy, including when no step reported one).
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    initial = env_reset(task_id)

    rewards: List[float] = []
    history: List[str] = []
    # Seed the loop with a synthetic "not done" result built from the reset.
    result = {"observation": initial["observation"], "done": False, "reward": 0.0, "info": {}}

    steps_taken = 0
    while steps_taken < MAX_STEPS and not result.get("done"):
        action = llm_action(result["observation"], history)
        steps_taken += 1
        result = execute_and_log(steps_taken, action, rewards)
        history.append(json.dumps(action, ensure_ascii=True))

    final_score = float(result.get("info", {}).get("final_score", 0.0) or 0.0)
    log_end(
        success=final_score >= SUCCESS_SCORE_THRESHOLD,
        steps=steps_taken,
        score=final_score,
        rewards=rewards,
    )
    return final_score
def main() -> None:
    """Command-line entry point: check connectivity, run every task, print a summary.

    Steps, in order:

    1. Apply the optional ``--env-url`` override (trailing slashes removed).
    2. Query the environment's ``/health`` endpoint; exit with status 1 on failure.
    3. Print the run settings and ping the LLM proxy; exit with status 1 if
       the ping raises.
    4. Run each id in ``TASK_IDS`` with the LLM agent or the deterministic
       runner, depending on ``USE_LLM_AGENT``; any exception ends the run
       with status 1.
    5. Print per-task scores, their average and the elapsed seconds.
    """
    global ENV_BASE_URL

    cli = argparse.ArgumentParser(description="SQL Data Analyst Env - Baseline Inference")
    cli.add_argument("--env-url", default=None)
    options = cli.parse_args()
    if options.env_url:
        ENV_BASE_URL = options.env_url.rstrip("/")

    print(f"Connecting to environment at {ENV_BASE_URL} ...", flush=True)
    try:
        health = requests.get(f"{ENV_BASE_URL}/health", timeout=10)
        health.raise_for_status()
        print(f"Health: {health.json()}", flush=True)
    except Exception as exc:
        print(f"ERROR: Cannot reach environment server: {exc}", flush=True)
        sys.exit(1)

    mode_label = "LLM" if USE_LLM_AGENT else "deterministic"
    settings = (
        f"Model: {MODEL_NAME}",
        f"API: {API_BASE_URL}",
        f"Mode: {mode_label}",
        f"Training Queries: {TRAINING_QUERY_PATH}",
    )
    for line in settings:
        print(line, flush=True)

    try:
        ping_llm_proxy()
    except Exception as exc:
        print(f"ERROR: LLM proxy validation failed: {exc}", flush=True)
        sys.exit(1)

    started_at = time.time()
    scores: Dict[str, float] = {}
    try:
        for task_id in TASK_IDS:
            run_task = run_episode_with_llm if USE_LLM_AGENT else run_deterministic
            scores[task_id] = run_task(task_id)
    except Exception as exc:
        print(f"ERROR: Inference failed: {exc}", flush=True)
        sys.exit(1)

    average = sum(scores.values()) / len(scores)
    elapsed = round(time.time() - started_at, 1)

    print("\n=== RUN SUMMARY ===", flush=True)
    for task_id in TASK_IDS:
        print(f"{task_id}: {scores.get(task_id, 0.0):.4f}", flush=True)
    print(f"average: {average:.4f}", flush=True)
    print(f"elapsed_seconds: {elapsed:.1f}", flush=True)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()