Spaces:
Running
Running
| """ | |
| inference.py — NL2SQL-Bench Baseline Inference Script | |
| ======================================================== | |
| MANDATORY COMPLIANCE | |
| -------------------- | |
| - Named `inference.py`, placed in project root. | |
| - Uses OpenAI client for all LLM calls. | |
| - Reads: API_BASE_URL, MODEL_NAME, HF_TOKEN from environment. | |
| - Emits [START] / [STEP] / [END] lines to stdout in the exact format below. | |
| - Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB. | |
| STDOUT FORMAT (exact — any deviation breaks scoring) | |
| ---------------------------------------------------- | |
| [START] task=<task_name> env=nl2sql-bench model=<model_name> | |
| [STEP] step=<n> action=<sql_one_line> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...> | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| import sys | |
| import textwrap | |
| from typing import List, Optional | |
| from openai import OpenAI | |
| # # ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") | |
| # MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct") | |
| # API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "") | |
| # IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest") | |
| # SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000") | |
| # BENCHMARK = "nl2sql-bench" | |
| # MAX_STEPS = 5 | |
| # TEMPERATURE = 0.2 # Low temp for SQL generation | |
| # MAX_TOKENS = 512 | |
| # SUCCESS_THRESHOLD = 0.7 # score >= 0.7 β success | |
| # TASKS = ["simple-filter", "join-aggregation", "analytics-window"] | |
# ── Configuration ──────────────────────────────────────────────────────────
# Every value below can be overridden through the environment variable of the
# same name (LOCAL_IMAGE_NAME for IMAGE_NAME).

# Base URL of the OpenAI-compatible endpoint used for all LLM calls.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Defaults to the fine-tuned NL2SQL weights.
MODEL_NAME = os.getenv("MODEL_NAME", "ritvik360/qwen-7b-nl2sql-merged_1")
# Credential lookup order: API_KEY (checked first to satisfy the evaluator's
# LiteLLM proxy), then HF_TOKEN, then OPENAI_API_KEY. The empty-string
# fallback sits at the END of the chain so API_KEY is always a str; with the
# fallback on a middle lookup the chain could evaluate to None.
API_KEY = (
    os.getenv("API_KEY")
    or os.getenv("HF_TOKEN")
    or os.getenv("OPENAI_API_KEY")
    or ""
)
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
# Defaults to the live Hugging Face Space hosting the environment.
SPACE_URL = os.getenv("SPACE_URL", "https://ritvik360-nl2sql-bench.hf.space")
BENCHMARK = "nl2sql-bench"
MAX_STEPS = 5
TEMPERATURE = 0.2  # Low temp for SQL generation
MAX_TOKENS = 512
SUCCESS_THRESHOLD = 0.7  # score >= 0.7 → success
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
# ── System prompt ──────────────────────────────────────────────────────────
# Instruction text sent as the "system" message on every LLM call.
# The literal is indented for readability only: textwrap.dedent() removes the
# common leading whitespace and .strip() drops the surrounding blank lines.
# Mis-encoded characters in the schema/rules text are restored to the
# intended "∈" (element-of) and "—" (em dash).
SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert SQL analyst working with a SQLite e-commerce database.
    DATABASE SCHEMA
    ---------------
    categories(id, name)
    products(id, name, category_id, price, stock_quantity)
    customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
    orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
    created_at, total_amount)
    order_items(id, order_id, product_id, quantity, unit_price)
    reviews(id, product_id, customer_id, rating∈1-5, created_at)
    RULES
    -----
    1. Write a single SELECT query — no INSERT/UPDATE/DELETE.
    2. Output ONLY the SQL query, nothing else. No markdown, no explanation.
    3. Use SQLite syntax: strftime('%Y-%m', date_col) for month, ROUND(x, 2) for decimals.
    4. Window functions (RANK, DENSE_RANK, ROW_NUMBER, running SUM) are supported.
    5. CTEs (WITH ... AS (...)) are supported.
    6. If you receive an error, fix it carefully in your next attempt.
    7. If you receive partial results, refine your query to match the expected output.
""").strip()
# ── Stdout logging (mandatory format) ──────────────────────────────────────
def log_start(task: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task episode.

    Args:
        task: Name of the task about to run.
        model: Identifier of the model generating the SQL.
    """
    header = f"[START] task={task} env={BENCHMARK} model={model}"
    # Flush immediately so the record is not held back by stdout buffering.
    print(header, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` record describing a single environment step.

    Args:
        step: 1-based step number within the episode.
        action: SQL text submitted for this step; may span several lines.
        reward: Reward returned for the step, printed with two decimals.
        done: Whether the episode finished with this step.
        error: Error message for the step, or ``None``/empty when there is none.
    """
    # A record must stay on exactly one line, so every whitespace run
    # (newlines, carriage returns, tabs, repeated spaces) becomes one space.
    action_single = " ".join(action.split())
    # The error text gets the same treatment as the action; replacing only
    # "\n" would let a "\r" from a "\r\n" message leak into the record.
    # A missing, empty or whitespace-only error is reported as "null".
    error_val = " ".join((error or "").split()) or "null"
    print(
        f"[STEP] step={step} action={action_single!r} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )
def log_end(
    success: bool, steps: int, score: float, rewards: List[float]
) -> None:
    """Print the ``[END]`` record that closes a task episode.

    Args:
        success: Whether the episode counts as solved.
        steps: Number of steps that were executed.
        score: Final episode score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    fields = (
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        # An empty reward list yields a bare "rewards=" field.
        "rewards=" + ",".join(format(value, ".2f") for value in rewards),
    )
    print("[END] " + " ".join(fields), flush=True)
# ── LLM interaction ────────────────────────────────────────────────────────
def build_user_prompt(
    question: str,
    schema_context: str,
    step: int,
    last_query: str,
    last_error: Optional[str],
    last_result: list,
    result_columns: list,
) -> str:
    """Build the user message for one step of an episode.

    On the first step the prompt is just the question plus an instruction to
    write a query. On later steps it also shows the previous SQL and either
    the error it produced or a preview of its result.

    Args:
        question: Natural-language question to answer.
        schema_context: Accepted for interface compatibility; this function
            does not include it in the prompt text.
        step: 1-based step number; values above 1 produce a follow-up prompt.
        last_query: SQL submitted in the previous step. ``None`` is tolerated
            and rendered as an empty line.
        last_error: Error from the previous step, if any.
        last_result: Rows returned by the previous step; only the first three
            are shown.
        result_columns: Column names of ``last_result``.

    Returns:
        The prompt as a single newline-joined string.
    """
    parts = [f"QUESTION: {question}", ""]
    if step <= 1:
        parts.append("Write a SQL query to answer the question.")
        return "\n".join(parts)

    # Collapse the previous query onto one line; guard against a missing
    # query so a follow-up prompt can still be built.
    previous_sql = " ".join((last_query or "").split())
    parts += [
        f"Your previous SQL (step {step - 1}):",
        f" {previous_sql}",
        "",
    ]
    # An error takes precedence over a result preview.
    if last_error:
        parts.append(f"ERROR: {last_error}")
    elif last_result:
        preview = str(last_result[:3]).replace("\n", " ")
        parts.append(f"RESULT PREVIEW (first 3 rows): {preview}")
        parts.append(f"COLUMNS: {result_columns}")
    parts += ["", "Please correct or refine your query."]
    return "\n".join(parts)
def call_llm(client: OpenAI, user_prompt: str) -> str:
    """Request a SQL query from the model and return it as plain text.

    Sends the system prompt plus ``user_prompt`` as a non-streaming chat
    completion. Any exception is logged to stderr and the placeholder query
    ``SELECT 1`` is returned instead, which is also the result for an empty
    completion.

    Args:
        client: OpenAI-compatible client used for the request.
        user_prompt: Content of the user message.

    Returns:
        The model's reply with Markdown code-fence lines removed, or
        ``"SELECT 1"`` as a fallback.
    """
    try:
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=conversation,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        sql_text = (completion.choices[0].message.content or "").strip()
        # When the reply opens with a fence (```sql ... ```), drop every
        # line that is itself a fence marker and keep the rest.
        if sql_text.startswith("```"):
            kept_lines = [
                line
                for line in sql_text.split("\n")
                if not line.strip().startswith("```")
            ]
            sql_text = "\n".join(kept_lines).strip()
        if not sql_text:
            return "SELECT 1"
        return sql_text
    except Exception as exc:
        print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
        return "SELECT 1"
# ── Single-task episode ────────────────────────────────────────────────────
async def run_task(client: OpenAI, env, task_name: str) -> dict:
    """Run one full episode for the given task and return its result dict.

    Emits ``[START]``, one ``[STEP]`` per executed step and a final ``[END]``
    record. Exceptions raised during the episode are logged to stderr and do
    not propagate; the ``[END]`` record is emitted in every case.

    Args:
        client: OpenAI-compatible client used to generate SQL.
        env: Environment object exposing awaitable ``reset()`` and ``step()``.
        task_name: Name of the task, used for logging and in the result.

    Returns:
        A dict with the keys ``task``, ``success``, ``score`` and ``rewards``.
        ``score`` is the mean of the collected rewards, clamped into
        [0.001, 0.999]; an episode that raised is never marked successful.
    """
    rewards: List[float] = []
    steps_taken = 0
    completed = False  # set only when the step loop ends without raising
    log_start(task_name, MODEL_NAME)
    try:
        # Imported here (once per episode) to avoid a circular import at
        # module level.
        from models import NL2SQLAction

        # reset() is called without arguments: OpenEnv reset() may not accept
        # task args via HTTP, so task selection relies on the
        # NL2SQL_DEFAULT_TASK env-var (or server-side handling) instead.
        result = await env.reset()
        obs = result.observation
        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break
            user_prompt = build_user_prompt(
                question=obs.question,
                schema_context=obs.schema_context,
                step=step,
                last_query=obs.last_query,
                last_error=obs.last_error,
                last_result=obs.last_result,
                result_columns=obs.result_columns,
            )
            sql = call_llm(client, user_prompt)
            result = await env.step(NL2SQLAction(query=sql))
            obs = result.observation
            reward = obs.reward or 0.0  # a missing reward counts as zero
            done = obs.done
            rewards.append(reward)
            steps_taken = step
            log_step(
                step=step, action=sql, reward=reward, done=done,
                error=obs.last_error,
            )
            if done:
                break
        completed = True
    except Exception as exc:
        print(
            f"[DEBUG] Episode error for {task_name}: {exc}",
            file=sys.stderr,
            flush=True,
        )
    finally:
        # CRITICAL: the evaluator requires score strictly in (0, 1) — not 0.0,
        # not 1.0. A perfect solve (1.0) is clamped to 0.999 and an all-fail
        # run (0.0) to 0.001. Computing the score here, rather than only on
        # the success path, keeps that guarantee when the episode raised and
        # keeps the score consistent with the rewards reported alongside it.
        raw_score = sum(rewards) / max(len(rewards), 1)
        score = round(min(max(raw_score, 0.001), 0.999), 4)
        success = completed and raw_score >= SUCCESS_THRESHOLD
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return {"task": task_name, "success": success, "score": score, "rewards": rewards}
# ── Main ───────────────────────────────────────────────────────────────────
async def main() -> None:
    """Run every task in ``TASKS`` and print a summary to stderr.

    A fresh environment session is opened per task. If a session cannot be
    used, a complete ``[START]``/``[END]`` pair with a minimal score is still
    emitted for that task so the stdout log stays well-formed.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    # Import here to avoid import errors if openenv not installed during lint
    from client import NL2SQLEnv

    all_results = []
    for task_name in TASKS:
        # Set the default task for the server session via env-var approach.
        # For the hosted Space, we rely on the task cycling implemented in
        # the task registry's round-robin iterator.
        os.environ["NL2SQL_DEFAULT_TASK"] = task_name
        result = None
        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                result = await run_task(client, env, task_name)
        except Exception as exc:
            if result is None:
                # The episode never produced a result, so no [START]/[END]
                # pair was completed for this task: emit both here.
                print(
                    f"[DEBUG] Failed to connect for task {task_name}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )
                log_start(task_name, MODEL_NAME)
                # 0.001 rather than 0.0: this file documents that the
                # evaluator requires scores strictly inside (0, 1).
                log_end(success=False, steps=0, score=0.001, rewards=[])
                result = {
                    "task": task_name,
                    "success": False,
                    "score": 0.001,
                    "rewards": [],
                }
            else:
                # run_task already logged and returned; the failure happened
                # afterwards (presumably while closing the session). Keep the
                # real result instead of logging a second [END] record.
                print(
                    f"[DEBUG] Environment teardown failed for task {task_name}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )
        all_results.append(result)

    # Summary to stderr (not scored, for human readability)
    print("\n=== Baseline Summary ===", file=sys.stderr)
    for r in all_results:
        print(
            f" {r['task']:20s} score={r['score']:.3f} "
            f"success={r['success']}",
            file=sys.stderr,
        )
    avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)
    print(f" {'AVERAGE':20s} score={avg:.3f}", file=sys.stderr)


if __name__ == "__main__":
    asyncio.run(main())