"""
inference.py — SQL Query Debugger OpenEnv

Follows the mandatory [START]/[STEP]/[END] stdout format.
Uses OpenAI client with API_BASE_URL, MODEL_NAME, HF_TOKEN.
"""

import os
import json
import textwrap
from typing import List, Optional

from openai import OpenAI

from env.environment import SQLDebuggerEnvironment
from env.models import Action, ActionType, DifficultyLevel

# ─────────────────────────────────────────────
# ENVIRONMENT VARIABLES
# ─────────────────────────────────────────────
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or "dummy-key"
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

BENCHMARK = "sql-query-debugger"
MAX_STEPS = 10
SUCCESS_SCORE_THRESHOLD = 0.5


# ─────────────────────────────────────────────
# LOGGING FUNCTIONS — exact format required
# ─────────────────────────────────────────────
def _one_line(text: str) -> str:
    """Collapse every whitespace run (including newlines) into a single space.

    The [START]/[STEP]/[END] records must each stay on one physical line,
    but exception messages may span several lines.
    """
    return " ".join(text.split())


def log_start(task: str, env: str, model: str) -> None:
    """Print the mandatory [START] record for one episode."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print the mandatory [STEP] record.

    ``reward`` is printed with 2 decimals and ``done`` lower-cased. A missing
    or blank ``error`` is printed as ``null``. Otherwise the error is
    flattened to a single line so it cannot split the record.
    """
    error_val = _one_line(error or "") or "null"
    done_val = str(done).lower()
    print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the mandatory [END] record (score to 3 decimals, rewards to 2)."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)


# ─────────────────────────────────────────────
# SYSTEM PROMPT
# ─────────────────────────────────────────────
# NOTE(review): the JSON value placeholders below are empty strings, so the
# model gets no hint of the expected content (e.g. allowed error_type values).
SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert SQL debugger. You will be given a buggy SQL query and must fix it.
    You must respond with a JSON object only — no explanation outside the JSON.

    For syntax/logic errors, respond with:
    {
      "action_type": "submit_answer",
      "fixed_query": "",
      "explanation": "",
      "error_type": "",
      "error_location": "",
      "confidence": 0.9
    }

    For performance issues, respond with:
    {
      "action_type": "optimize_query",
      "optimized_query": "",
      "optimization_type": "",
      "explanation": "",
      "root_cause": "",
      "expected_improvement": "",
      "confidence": 0.85
    }

    Always provide valid JSON. Never include markdown code blocks.
""").strip()


def build_user_prompt(obs) -> str:
    """Render the current observation as the user message for the LLM.

    Reads ``obs.task_description``, ``obs.difficulty`` and the
    ``obs.current_context`` mapping. Missing context keys fall back to
    placeholder values ('N/A', 'unknown', an empty schema, 20 steps).

    The prompt is assembled line by line rather than with ``textwrap.dedent``
    on an indented f-string. The interpolated schema JSON and any multi-line
    query are not indented, which would defeat dedent's common-prefix
    detection and leave the template indentation in the prompt.
    """
    ctx = obs.current_context
    # If difficulty is an Enum (the file imports DifficultyLevel), show its
    # value; plain strings have no ``.value`` and pass through unchanged.
    difficulty = getattr(obs.difficulty, "value", obs.difficulty)
    lines = [
        f"Task: {obs.task_description}",
        f"Difficulty: {difficulty}",
        "",
        "Buggy Query:",
        f"{ctx.get('buggy_query', 'N/A')}",
        "",
        f"Error Message: {ctx.get('error_message', 'N/A')}",
        "",
        "Database Schema:",
        json.dumps(ctx.get('database_schema', {}), indent=2),
        "",
        f"Error Type Hint: {ctx.get('error_type_hint', 'unknown')}",
        f"Category: {ctx.get('category', 'unknown')}",
        f"Steps Remaining: {ctx.get('steps_remaining', 20)}",
        "",
        "Analyze the buggy query and provide your fix as a JSON object.",
    ]
    return "\n".join(lines).strip()


# ─────────────────────────────────────────────
# LLM CALL
# ─────────────────────────────────────────────
def _extract_json_text(text: str) -> str:
    """Return the JSON-object portion of an LLM reply.

    Keeps the span from the first ``{`` to the last ``}``. This drops markdown
    code fences (with or without a ``json`` tag) and any prose around the
    object. When no such span exists, the stripped text is returned so that
    ``json.loads`` reports the problem.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text.strip()


def _as_float(value, default: float) -> float:
    """Coerce a model-supplied number to float, returning *default* on failure.

    Prevents ``"confidence": null`` or a non-numeric string from discarding
    an otherwise usable answer.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def get_llm_action(client: OpenAI, obs, step: int) -> Action:
    """Ask the LLM for the next action and parse its reply into an ``Action``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current observation, rendered with ``build_user_prompt``.
        step: 1-based step number (currently unused; kept for interface
            compatibility).

    Returns:
        An ``OPTIMIZE_QUERY`` action when the reply's ``action_type`` is
        ``"optimize_query"``, otherwise a ``SUBMIT_ANSWER`` action. If the API
        call fails or the reply cannot be parsed as a JSON object, an
        ``IDENTIFY_ERROR`` fallback action is returned instead of raising.
    """
    user_prompt = build_user_prompt(obs)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.3,
            max_tokens=512,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()

        data = json.loads(_extract_json_text(text))
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")

        action_type = data.get("action_type", "submit_answer")
        if action_type == "optimize_query":
            return Action(
                action_type=ActionType.OPTIMIZE_QUERY,
                payload={
                    "optimized_query": data.get("optimized_query", "SELECT 1"),
                    "optimization_type": data.get("optimization_type", "Performance fix"),
                    "explanation": data.get("explanation", ""),
                    "root_cause": data.get("root_cause", ""),
                    "expected_improvement": data.get("expected_improvement", ""),
                    "confidence": _as_float(data.get("confidence", 0.7), 0.7),
                },
            )
        return Action(
            action_type=ActionType.SUBMIT_ANSWER,
            payload={
                "fixed_query": data.get("fixed_query", "SELECT 1"),
                "explanation": data.get("explanation", ""),
                "error_type": data.get("error_type", "syntax"),
                "error_location": data.get("error_location", "unknown"),
                "confidence": _as_float(data.get("confidence", 0.7), 0.7),
            },
        )
    except Exception as exc:
        # Deliberately broad: an API or parsing failure must not abort the
        # episode, so degrade to a low-information action instead.
        print(f"[DEBUG] LLM call failed: {_one_line(str(exc))}", flush=True)
        return Action(
            action_type=ActionType.IDENTIFY_ERROR,
            payload={
                "error_location": "unknown",
                "error_type": "syntax",
                "explanation": "LLM call failed, using fallback",
            },
        )


# ─────────────────────────────────────────────
# MAIN INFERENCE LOOP
# ─────────────────────────────────────────────
def run_episode(client: OpenAI, difficulty: str, task_id: str) -> dict:
    """Run one full episode and return a summary of the results.

    Always prints exactly one [START] and one [END] record, even if the
    environment fails to construct, reset or step.

    Returns:
        A dict with keys ``task_id``, ``difficulty``, ``score``, ``steps``
        and ``success``.
    """
    rewards: List[float] = []
    steps = 0
    success = False
    score = 0.0

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        # Construct and reset inside the try so a failure here still yields
        # an [END] record and lets main() continue with the next task.
        env = SQLDebuggerEnvironment()
        obs = env.reset(difficulty=difficulty, task_id=task_id)

        for step in range(1, MAX_STEPS + 1):
            if env.state().done:
                break

            action = get_llm_action(client, obs, step)
            action_str = f"{action.action_type.value}"
            error_str = None

            try:
                resp = env.step(action)
                reward = resp.reward.score
                done = resp.done
                obs = resp.observation
            except Exception as e:
                # A failed step is penalised but does not end the episode;
                # the next action is generated from the previous observation.
                reward = -0.1
                done = False
                error_str = str(e)[:100]

            rewards.append(reward)
            steps = step

            log_step(step=step, action=action_str, reward=reward, done=done, error=error_str)

            if done:
                break

        # Normalise by MAX_STEPS and clamp to [0, 1]. This assumes each step
        # reward is at most 1.0 — TODO confirm against the env's reward range.
        total_reward = sum(rewards)
        score = min(max(total_reward / MAX_STEPS, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Episode error: {_one_line(str(e))}", flush=True)
    finally:
        log_end(success=success, steps=steps, score=score, rewards=rewards)

    return {
        "task_id": task_id,
        "difficulty": difficulty,
        "score": score,
        "steps": steps,
        "success": success,
    }


def main() -> None:
    """Run one episode per difficulty level and print a [DEBUG] summary."""
    print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True)
    print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    tasks = [
        ("easy", "easy_001"),
        ("medium", "medium_001"),
        ("hard", "hard_001"),
    ]

    results = []
    for difficulty, task_id in tasks:
        results.append(run_episode(client, difficulty, task_id))

    # Final summary
    avg_score = sum(r["score"] for r in results) / len(results)
    print(f"\n[DEBUG] Average Score: {avg_score:.3f}", flush=True)
    for r in results:
        print(f"[DEBUG] {r['difficulty']:8} | {r['task_id']:12} | score={r['score']:.3f} | steps={r['steps']}", flush=True)


if __name__ == "__main__":
    main()