Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import sys | |
| import httpx | |
| from dotenv import load_dotenv | |
# Put the parent of this file's directory on sys.path (presumably so the
# `baseline` package import that follows can be resolved -- confirm layout).
try:
    _here = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is not defined in this execution context; fall back to the
    # current working directory.
    _root = os.getcwd()
else:
    _root = os.path.dirname(_here)

if _root not in sys.path:
    sys.path.insert(0, _root)
| from baseline.prompts import SYSTEM_PROMPT | |
| import os | |
| from openai import OpenAI | |
| from dotenv import load_dotenv | |
# Load variables from a .env file (if any) before reading the environment.
load_dotenv()

# --- Configuration -------------------------------------------------------
# The key is taken from OPENAI_API_KEY first, then GOOGLE_AI_KEY.
# OPENAI_BASE_URL, when set, points the client at another OpenAI-compatible
# endpoint; when unset the OpenAI client's own default endpoint is used.
# NOTE(review): the default model is a Gemini model while the default
# endpoint is OpenAI's -- confirm that BASELINE_MODEL / OPENAI_BASE_URL are
# always set together in deployment.
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GOOGLE_AI_KEY")
base_url = os.getenv("OPENAI_BASE_URL", None)
model = os.getenv("BASELINE_MODEL", "gemini-2.0-flash")
env_base_url = os.getenv("ENV_BASE_URL", "http://localhost:7860")

if not api_key:
    raise ValueError(
        "No API key found. Set OPENAI_API_KEY (for OpenAI) or "
        "GOOGLE_AI_KEY + OPENAI_BASE_URL (for Google AI Studio / other providers)"
    )

# --- Client --------------------------------------------------------------
# base_url is only forwarded when it is non-empty.
client_kwargs = {"api_key": api_key, **({"base_url": base_url} if base_url else {})}
client = OpenAI(**client_kwargs)

# Startup banner. The provider label is a heuristic: any base URL containing
# "google" is reported as Google AI Studio.
_provider_label = "Google AI Studio" if "google" in (base_url or "") else "OpenAI-compatible"
print("Baseline agent initialised:")
print(f" Provider: {_provider_label}")
print(f" Model: {model}")
print(f" Environment: {env_base_url}")

# Base URL of the environment server used by the HTTP calls in this module.
BASE_URL = env_base_url
# Seed sent with the reset request, keyed by task id.
BASELINE_SEEDS = {1: 42, 2: 99, 3: 777}
def format_score_line(task_id: int, score: float) -> str:
    """Return the score report line for a task.

    The line has the form ``SCORE task_<task_id>: <score>`` with the score
    rendered to four decimal places.
    """
    rendered_score = format(score, ".4f")
    return "SCORE task_" + str(task_id) + ": " + rendered_score
def call_llm(messages: list) -> str:
    """Send the conversation to the chat-completions API and return the reply text.

    Args:
        messages: Chat messages passed unchanged to
            ``client.chat.completions.create``.

    Returns:
        The content of the first choice. When the API returns a message with
        no text content (``None``), an empty string is returned so that the
        declared ``str`` return type always holds.

    Any exception raised while calling the API or reading the response is
    treated as fatal: it is printed and the process exits with status 1.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0,
        )
        content = response.choices[0].message.content
    except Exception as e:
        print(f"Fatal OpenAI API crash: {e}")
        sys.exit(1)
    # The message content is nullable in the API response; normalise it so
    # callers can safely use string methods on the result.
    return content if content is not None else ""
| def parse_action(raw_text: str) -> dict: | |
| """Extract and parse action JSON from LLM output, handling all common failure modes.""" | |
| text = raw_text.strip() | |
| # Mode 1: strip markdown code fences (```json ... ``` or ``` ... ```) | |
| fence_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text) | |
| if fence_match: | |
| text = fence_match.group(1).strip() | |
| # Mode 2: find first { ... } JSON object if there's surrounding prose | |
| brace_match = re.search(r'\{[\s\S]*\}', text) | |
| if brace_match: | |
| text = brace_match.group(0) | |
| # Mode 3: fix trailing commas (common LLM mistake) | |
| text = re.sub(r',\s*([}\]])', r'\1', text) | |
| # Mode 4: fix single quotes used instead of double quotes | |
| # Only do this if JSON parse fails first | |
| try: | |
| return json.loads(text) | |
| except json.JSONDecodeError: | |
| try: | |
| # Replace single-quoted keys/values carefully | |
| text_fixed = re.sub(r"'([^']*)'", r'"\1"', text) | |
| return json.loads(text_fixed) | |
| except json.JSONDecodeError: | |
| return None # caller handles None | |
| def safe_action(parsed: dict | None, step_num: int) -> dict: | |
| """Convert parsed dict to valid action, with safe fallbacks.""" | |
| if parsed is None: | |
| # After 3 failed parses in a row, submit to end episode gracefully | |
| return {"action_type": "submit"} | |
| action_type = parsed.get("action_type", "").lower() | |
| if action_type == "query" and "sql" in parsed: | |
| return parsed | |
| elif action_type == "ddl" and "sql" in parsed: | |
| return parsed | |
| elif action_type == "test" and "target_table" in parsed: | |
| return parsed | |
| elif action_type == "submit": | |
| return parsed | |
| elif "sql" in parsed: | |
| # LLM gave SQL but wrong action_type — infer it | |
| sql = parsed["sql"].strip().upper() | |
| inferred_type = "query" if sql.startswith(("SELECT","WITH","EXPLAIN")) else "ddl" | |
| return {"action_type": inferred_type, "sql": parsed["sql"]} | |
| else: | |
| # Completely unparseable — explore schema as safe default | |
| if step_num <= 3: | |
| return {"action_type": "query", "sql": "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')"} | |
| return {"action_type": "submit"} | |
def run_task(task_id: int) -> float:
    """Run one baseline episode for *task_id* and return the grader score.

    Flow:
      1. POST ``/reset`` with the task id and its seed from ``BASELINE_SEEDS``.
      2. For up to ``max_steps`` (taken from the first observation, default
         25): send the observation to the LLM, turn the reply into an action
         and POST it to ``/step``. The loop stops when the environment
         reports ``done`` or ``truncated``, or when a step request fails.
      3. GET ``/grader`` and read its ``score`` field.

    Parse failures are tolerated: the first two consecutive unparseable
    replies are replaced by a read-only schema query so the episode can
    continue, and only the third consecutive failure submits the episode.

    Args:
        task_id: Identifier of the task to run.

    Returns:
        The grader score, or ``0.0`` when the reset or grader request fails.
    """
    print(f"Starting task {task_id}")
    try:
        seed = BASELINE_SEEDS.get(task_id)
        resp = httpx.post(f"{BASE_URL}/reset", json={"task_id": task_id, "seed": seed}, timeout=30.0)
        resp.raise_for_status()
        resp_data = resp.json()
        # Some responses wrap the observation, others are the observation.
        obs = resp_data.get("observation", resp_data)
        session_id = resp_data.get("session_id", "")
    except Exception as e:
        print(f"Failed to reset environment for task {task_id}: {e}")
        return 0.0

    # The session header is only sent when the server issued a session id.
    headers = {"X-Session-ID": session_id} if session_id else {}
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    max_steps = obs.get("max_steps", 25)
    consecutive_parse_failures = 0

    for step in range(max_steps):
        messages.append({"role": "user", "content": json.dumps(obs)})
        try:
            llm_response = call_llm(messages)
            parsed = parse_action(llm_response)
            if parsed is None:
                consecutive_parse_failures += 1
                if consecutive_parse_failures >= 3:
                    print(f"Warning: 3 consecutive parse failures at step {step}. Handing episode submit.")
                    action = {"action_type": "submit"}
                else:
                    # Do not delegate to safe_action here: it answers a None
                    # parse with "submit", which would end the episode on the
                    # very first failure and make the 3-strike rule
                    # unreachable. A read-only schema query keeps it alive.
                    action = {
                        "action_type": "query",
                        "sql": "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')",
                    }
            else:
                consecutive_parse_failures = 0
                action = safe_action(parsed, step)
        except Exception as e:
            print(f"LLM error at step {step}: {e}")
            action = {"action_type": "submit"}

        # Record the action actually sent, not the raw model text.
        messages.append({"role": "assistant", "content": json.dumps(action)})
        try:
            step_resp = httpx.post(f"{BASE_URL}/step", json=action, headers=headers, timeout=30.0)
            step_resp.raise_for_status()
            step_data = step_resp.json()
            obs = step_data.get("observation", step_data)
            if step_data.get("done") or step_data.get("truncated"):
                break
        except Exception as e:
            print(f"Failed to step environment: {e}")
            break

    try:
        grader_resp = httpx.get(f"{BASE_URL}/grader", headers=headers, timeout=10.0)
        grader_resp.raise_for_status()
        final_score = grader_resp.json().get("score", 0.0)
    except Exception as e:
        print(f"Failed to get grader score: {e}")
        final_score = 0.0

    print(format_score_line(task_id, final_score))
    return final_score
def run_baseline():
    """Run tasks 1, 2 and 3 in order, then print a per-task score summary."""
    # Dict comprehensions evaluate in iteration order, so the tasks still run
    # sequentially as 1, 2, 3.
    results = {f"task_{tid}": run_task(tid) for tid in (1, 2, 3)}

    print("\n--- Summary ---")
    for label in results:
        print(f"{label}: {results[label]:.4f}")
# Script entry point: run the full baseline; report any uncaught error and
# exit with status 1.
if __name__ == "__main__":
    try:
        run_baseline()
    except Exception as exc:
        print(f"Top-level execution crash: {exc}")
        raise SystemExit(1)