"""Baseline LLM agent for the SQL (SQLite-backed) environment.

Drives an OpenAI-compatible chat model (OpenAI, Google AI Studio / Gemini,
Groq, OpenRouter, ...) through tasks 1-3 of the HTTP environment at
ENV_BASE_URL and prints one ``SCORE task_<id>: <score>`` line per task.

Configuration (environment variables, optionally loaded from ``.env``):
    OPENAI_API_KEY / GOOGLE_AI_KEY  API key (one of them is required).
    OPENAI_BASE_URL                 Optional OpenAI-compatible endpoint.
    BASELINE_MODEL                  Model name (default "gemini-2.0-flash").
    ENV_BASE_URL                    Environment server (default http://localhost:7860).
"""
import json
import os
import re
import sys

import httpx
from dotenv import load_dotenv
from openai import OpenAI

# Put the parent of this file's directory on sys.path so the ``baseline``
# package is importable when this file is executed directly as a script.
try:
    _here = os.path.dirname(os.path.abspath(__file__))
    _root = os.path.dirname(_here)
except NameError:
    # __file__ is undefined in some interactive / embedded contexts.
    _root = os.getcwd()
if _root not in sys.path:
    sys.path.insert(0, _root)

from baseline.prompts import SYSTEM_PROMPT  # noqa: E402  (needs the sys.path tweak above)

load_dotenv()

# Supports both OpenAI and Google AI Studio (Gemini) as drop-in.
# If OPENAI_BASE_URL is set, use it (Google AI Studio or other compatible API);
# otherwise default to OpenAI.
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GOOGLE_AI_KEY")
base_url = os.getenv("OPENAI_BASE_URL", None)  # None = use OpenAI default
model = os.getenv("BASELINE_MODEL", "gemini-2.0-flash")
env_base_url = os.getenv("ENV_BASE_URL", "http://localhost:7860")

if not api_key:
    raise ValueError(
        "No API key found. Set OPENAI_API_KEY (for OpenAI) or "
        "GOOGLE_AI_KEY + OPENAI_BASE_URL (for Google AI Studio / other providers)"
    )

# Build client — works for OpenAI, Google AI Studio, Groq, OpenRouter
client_kwargs = {"api_key": api_key}
if base_url:
    client_kwargs["base_url"] = base_url
client = OpenAI(**client_kwargs)

print("Baseline agent initialised:")
print(f" Provider: {'Google AI Studio' if 'google' in (base_url or '') else 'OpenAI-compatible'}")
print(f" Model: {model}")
print(f" Environment: {env_base_url}")

BASE_URL = env_base_url
# Fixed per-task seeds so baseline runs are reproducible.
BASELINE_SEEDS = {1: 42, 2: 99, 3: 777}

# Read-only fallback action: lists tables/views so the model gets fresh schema context.
SCHEMA_PROBE_ACTION = {
    "action_type": "query",
    "sql": "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')",
}
# Consecutive unparseable LLM replies tolerated before the episode is force-submitted.
MAX_CONSECUTIVE_PARSE_FAILURES = 3


def format_score_line(task_id: int, score: float) -> str:
    """Return the machine-readable score line, e.g. ``SCORE task_1: 0.5000``."""
    return f"SCORE task_{task_id}: {score:.4f}"


def call_llm(messages: list) -> str:
    """Send the conversation to the chat model and return the reply text.

    Uses temperature 0.0 so baseline runs are as repeatable as the provider
    allows. Any API error is treated as fatal: it is printed and the process
    exits with status 1 (``SystemExit`` is not caught by ``except Exception``).

    Returns an empty string when the reply carries no text content
    (``message.content`` is ``None``), so the caller treats it as an
    unparseable reply instead of crashing on ``None``.
    """
    try:
        response = client.chat.completions.create(
            model=model, messages=messages, temperature=0.0
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        print(f"Fatal OpenAI API crash: {e}")
        sys.exit(1)


def parse_action(raw_text: str) -> dict | None:
    """Extract and parse an action JSON object from raw LLM output.

    Repairs applied, in order:
      1. Markdown code fences (```json ... ``` or ``` ... ```): keep the inside.
      2. Surrounding prose: keep the span from the first "{" to the last "}".
      3. Trailing commas before "}" or "]".
      4. Single-quoted strings, tried only if a strict parse fails.

    Returns:
        The parsed JSON object as a dict, or ``None`` if the text cannot be
        parsed or decodes to something other than an object (list, number...).
    """
    text = (raw_text or "").strip()

    # Mode 1: strip markdown code fences
    fence_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
    if fence_match:
        text = fence_match.group(1).strip()

    # Mode 2: greedy on purpose, so nested objects stay intact.
    brace_match = re.search(r'\{[\s\S]*\}', text)
    if brace_match:
        text = brace_match.group(0)

    # Mode 3: fix trailing commas (common LLM mistake)
    text = re.sub(r',\s*([}\]])', r'\1', text)

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Mode 4: naive quote swap. It also rewrites apostrophes inside values,
        # so it is only a last resort after the strict parse has failed.
        try:
            parsed = json.loads(re.sub(r"'([^']*)'", r'"\1"', text))
        except json.JSONDecodeError:
            return None  # caller handles None

    # Only JSON objects are actions; lists/scalars would break .get() downstream.
    return parsed if isinstance(parsed, dict) else None


def safe_action(parsed: dict | None, step_num: int) -> dict:
    """Coerce a parsed LLM reply into an action the environment accepts.

    Args:
        parsed: Output of ``parse_action``; ``None`` means nothing usable.
        step_num: Zero-based step index. At steps 0-3 an unusable reply falls
            back to ``SCHEMA_PROBE_ACTION``; later it submits.

    Returns:
        * the reply itself, with ``action_type`` stripped and lower-cased, when
          it is a ``query``/``ddl`` (has ``sql``), ``test`` (has
          ``target_table``) or ``submit`` action;
        * a ``query``/``ddl`` action inferred from a string ``sql`` field when
          the action type is missing or invalid;
        * otherwise ``SCHEMA_PROBE_ACTION`` at steps 0-3, else ``submit``.
        ``None`` always maps to ``submit``; tolerating a few parse failures is
        the caller's job (see ``run_task``).
    """
    if parsed is None:
        return {"action_type": "submit"}

    raw_type = parsed.get("action_type")
    action_type = raw_type.strip().lower() if isinstance(raw_type, str) else ""
    normalized = {**parsed, "action_type": action_type}

    if action_type in ("query", "ddl") and "sql" in parsed:
        return normalized
    if action_type == "test" and "target_table" in parsed:
        return normalized
    if action_type == "submit":
        return normalized

    sql = parsed.get("sql")
    if isinstance(sql, str):
        # LLM gave SQL but wrong action_type — infer it from the leading keyword.
        head = sql.strip().upper()
        inferred_type = "query" if head.startswith(("SELECT", "WITH", "EXPLAIN")) else "ddl"
        return {"action_type": inferred_type, "sql": sql}

    # Completely unusable — explore schema early on, give up later.
    if step_num <= 3:
        return dict(SCHEMA_PROBE_ACTION)
    return {"action_type": "submit"}


def run_task(task_id: int) -> float:
    """Run one episode of ``task_id`` against the environment and return its score.

    Flow:
      1. POST /reset with the task's fixed seed from ``BASELINE_SEEDS``.
      2. Up to ``max_steps`` rounds of LLM call -> action -> POST /step.
      3. GET /grader.

    Prints the score line and returns the score. Returns 0.0 if reset or
    grading fails.
    """
    print(f"Starting task {task_id}")
    try:
        seed = BASELINE_SEEDS.get(task_id)
        resp = httpx.post(
            f"{BASE_URL}/reset", json={"task_id": task_id, "seed": seed}, timeout=30.0
        )
        resp.raise_for_status()
        resp_data = resp.json()
        # Accept both a wrapped {"observation": ...} response and a bare observation.
        obs = resp_data.get("observation", resp_data)
        session_id = resp_data.get("session_id", "")
    except Exception as e:
        print(f"Failed to reset environment for task {task_id}: {e}")
        return 0.0

    # Route every later request to this episode's session, if the server issued one.
    headers = {"X-Session-ID": session_id} if session_id else {}

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    max_steps = obs.get("max_steps", 25)
    consecutive_parse_failures = 0

    for step in range(max_steps):
        messages.append({"role": "user", "content": json.dumps(obs)})
        try:
            llm_response = call_llm(messages)
            parsed = parse_action(llm_response)
            if parsed is None:
                consecutive_parse_failures += 1
                if consecutive_parse_failures >= MAX_CONSECUTIVE_PARSE_FAILURES:
                    print(
                        f"Warning: {MAX_CONSECUTIVE_PARSE_FAILURES} consecutive parse "
                        f"failures at step {step}. Handing episode submit."
                    )
                    action = {"action_type": "submit"}
                else:
                    # Non-terminal fallback so one malformed reply doesn't end the
                    # episode; the probe is read-only and refreshes schema context.
                    action = dict(SCHEMA_PROBE_ACTION)
            else:
                consecutive_parse_failures = 0
                action = safe_action(parsed, step)
        except Exception as e:
            print(f"LLM error at step {step}: {e}")
            action = {"action_type": "submit"}

        # History records the action actually sent, not the raw LLM text.
        messages.append({"role": "assistant", "content": json.dumps(action)})

        try:
            step_resp = httpx.post(
                f"{BASE_URL}/step", json=action, headers=headers, timeout=30.0
            )
            step_resp.raise_for_status()
            step_data = step_resp.json()
            obs = step_data.get("observation", step_data)
            if step_data.get("done") or step_data.get("truncated"):
                break
        except Exception as e:
            print(f"Failed to step environment: {e}")
            break

    try:
        grader_resp = httpx.get(f"{BASE_URL}/grader", headers=headers, timeout=10.0)
        grader_resp.raise_for_status()
        final_score = grader_resp.json().get("score", 0.0)
    except Exception as e:
        print(f"Failed to get grader score: {e}")
        final_score = 0.0

    print(format_score_line(task_id, final_score))
    return final_score


def run_baseline():
    """Run tasks 1, 2 and 3 in order and print a per-task score summary."""
    scores = {}
    for task_id in [1, 2, 3]:
        score = run_task(task_id)
        scores[f"task_{task_id}"] = score

    print("\n--- Summary ---")
    for task, score in scores.items():
        print(f"{task}: {score:.4f}")


if __name__ == "__main__":
    try:
        run_baseline()
    except Exception as e:
        print(f"Top-level execution crash: {e}")
        sys.exit(1)