# Page header from the file viewer — Spaces status: Sleeping; file size: 7,272 bytes; revision 22328de.
import os
import re
import json
import sys
import httpx
from dotenv import load_dotenv
# Put the parent of this file's directory on sys.path so that the
# `baseline` package import below resolves when the file is run as a script.
# `__file__` is undefined in some interactive contexts; fall back to the
# current working directory there.
try:
    _here = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _root = os.getcwd()
else:
    _root = os.path.dirname(_here)

if _root not in sys.path:
    sys.path.insert(0, _root)
from baseline.prompts import SYSTEM_PROMPT
import os
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
# Provider configuration, read from the environment.
# Any OpenAI-compatible endpoint can be used: when OPENAI_BASE_URL is set the
# client is pointed at it (Google AI Studio, Groq, OpenRouter, ...); when it
# is unset the OpenAI SDK falls back to its own default endpoint.
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GOOGLE_AI_KEY")
base_url = os.getenv("OPENAI_BASE_URL")  # None = use OpenAI default
model = os.getenv("BASELINE_MODEL", "gemini-2.0-flash")
env_base_url = os.getenv("ENV_BASE_URL", "http://localhost:7860")

# Fail at import time rather than on the first request when no key is configured.
if not api_key:
    raise ValueError(
        "No API key found. Set OPENAI_API_KEY (for OpenAI) or "
        "GOOGLE_AI_KEY + OPENAI_BASE_URL (for Google AI Studio / other providers)"
    )

# Only pass base_url when one was configured so the SDK default stays in effect.
client_kwargs = {"api_key": api_key}
if base_url:
    client_kwargs.update(base_url=base_url)
client = OpenAI(**client_kwargs)

# Start-up banner describing the resolved configuration.
_provider_label = (
    "Google AI Studio" if base_url and "google" in base_url else "OpenAI-compatible"
)
print("Baseline agent initialised:")
print(f" Provider: {_provider_label}")
print(f" Model: {model}")
print(f" Environment: {env_base_url}")

# Root URL of the environment server, and the fixed reset seed used per task id.
BASE_URL = env_base_url
BASELINE_SEEDS = {1: 42, 2: 99, 3: 777}
def format_score_line(task_id: int, score: float) -> str:
    """Return the one-line score report for a task.

    The line has the form ``SCORE task_<id>: <score>`` with the score
    rendered to four decimal places.
    """
    label = f"task_{task_id}"
    value = f"{score:.4f}"
    return f"SCORE {label}: {value}"
def call_llm(messages: list) -> str:
    """Send the chat history to the configured model and return its reply text.

    Args:
        messages: Chat messages in the OpenAI ``{"role": ..., "content": ...}``
            format.

    Returns:
        The text content of the first choice. An empty string is returned when
        the API reports no text content, so callers always receive a ``str``.

    Any exception raised by the API call is treated as fatal: it is printed
    and the process exits with status 1.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0,
        )
        content = response.choices[0].message.content
        # The API may return a null content field; normalise it so that
        # downstream string handling (e.g. .strip()) does not fail on None.
        return content if content is not None else ""
    except Exception as e:
        print(f"Fatal OpenAI API crash: {e}")
        sys.exit(1)
def parse_action(raw_text: str) -> dict:
"""Extract and parse action JSON from LLM output, handling all common failure modes."""
text = raw_text.strip()
# Mode 1: strip markdown code fences (```json ... ``` or ``` ... ```)
fence_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
if fence_match:
text = fence_match.group(1).strip()
# Mode 2: find first { ... } JSON object if there's surrounding prose
brace_match = re.search(r'\{[\s\S]*\}', text)
if brace_match:
text = brace_match.group(0)
# Mode 3: fix trailing commas (common LLM mistake)
text = re.sub(r',\s*([}\]])', r'\1', text)
# Mode 4: fix single quotes used instead of double quotes
# Only do this if JSON parse fails first
try:
return json.loads(text)
except json.JSONDecodeError:
try:
# Replace single-quoted keys/values carefully
text_fixed = re.sub(r"'([^']*)'", r'"\1"', text)
return json.loads(text_fixed)
except json.JSONDecodeError:
return None # caller handles None
def safe_action(parsed: dict | None, step_num: int) -> dict:
"""Convert parsed dict to valid action, with safe fallbacks."""
if parsed is None:
# After 3 failed parses in a row, submit to end episode gracefully
return {"action_type": "submit"}
action_type = parsed.get("action_type", "").lower()
if action_type == "query" and "sql" in parsed:
return parsed
elif action_type == "ddl" and "sql" in parsed:
return parsed
elif action_type == "test" and "target_table" in parsed:
return parsed
elif action_type == "submit":
return parsed
elif "sql" in parsed:
# LLM gave SQL but wrong action_type — infer it
sql = parsed["sql"].strip().upper()
inferred_type = "query" if sql.startswith(("SELECT","WITH","EXPLAIN")) else "ddl"
return {"action_type": inferred_type, "sql": parsed["sql"]}
else:
# Completely unparseable — explore schema as safe default
if step_num <= 3:
return {"action_type": "query", "sql": "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')"}
return {"action_type": "submit"}
# Number of back-to-back unparseable replies after which the episode is submitted.
_PARSE_FAILURE_LIMIT = 3
# Read-only schema listing sent in place of an unparseable reply while the
# failure count is still below the limit, so the episode can continue.
_SCHEMA_PROBE_SQL = "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')"


def run_task(task_id: int) -> float:
    """Run one episode of a task against the environment and return its score.

    Flow:
        1. POST ``/reset`` with the task id and its fixed seed from
           ``BASELINE_SEEDS``.
        2. For up to ``max_steps`` steps (taken from the observation, default
           25): send the observation to the LLM, turn the reply into an
           action, and POST it to ``/step``. The loop ends when the step
           response reports ``done`` or ``truncated``, or when a step request
           fails.
        3. GET ``/grader`` and print the score line.

    A reply that cannot be parsed is replaced by a read-only schema query;
    after ``_PARSE_FAILURE_LIMIT`` such replies in a row the episode is
    submitted instead. Any other error while producing an action also
    submits.

    Args:
        task_id: Identifier of the task to run.

    Returns:
        The grader's score as a float; ``0.0`` when the reset or the grader
        request fails or the score is missing or not numeric.
    """
    print(f"Starting task {task_id}")
    try:
        seed = BASELINE_SEEDS.get(task_id)
        resp = httpx.post(f"{BASE_URL}/reset", json={"task_id": task_id, "seed": seed}, timeout=30.0)
        resp.raise_for_status()
        resp_data = resp.json()
        # Some responses wrap the observation, others are the observation itself.
        obs = resp_data.get("observation", resp_data)
        session_id = resp_data.get("session_id", "")
        # Validated here so a malformed reset payload is reported as a reset
        # failure instead of crashing the whole run.
        max_steps = int(obs.get("max_steps", 25))
    except Exception as e:
        print(f"Failed to reset environment for task {task_id}: {e}")
        return 0.0

    # The session header is identical for every later request of this episode.
    headers = {"X-Session-ID": session_id} if session_id else {}
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    consecutive_parse_failures = 0

    for step in range(max_steps):
        messages.append({"role": "user", "content": json.dumps(obs)})
        try:
            llm_response = call_llm(messages)
            parsed = parse_action(llm_response)
            if parsed is None:
                consecutive_parse_failures += 1
                if consecutive_parse_failures >= _PARSE_FAILURE_LIMIT:
                    print(f"Warning: 3 consecutive parse failures at step {step}. Handing episode submit.")
                    action = {"action_type": "submit"}
                else:
                    # Keep the episode alive with a harmless read-only query so
                    # the model gets another chance; submitting here would end
                    # the episode on the very first bad reply.
                    action = {"action_type": "query", "sql": _SCHEMA_PROBE_SQL}
            else:
                consecutive_parse_failures = 0
                action = safe_action(parsed, step)
        except Exception as e:
            print(f"LLM error at step {step}: {e}")
            action = {"action_type": "submit"}

        # Record the action actually sent (not the raw reply) as the assistant turn.
        messages.append({"role": "assistant", "content": json.dumps(action)})

        try:
            step_resp = httpx.post(f"{BASE_URL}/step", json=action, headers=headers, timeout=30.0)
            step_resp.raise_for_status()
            step_data = step_resp.json()
            obs = step_data.get("observation", step_data)
            if step_data.get("done") or step_data.get("truncated"):
                break
        except Exception as e:
            print(f"Failed to step environment: {e}")
            break

    try:
        grader_resp = httpx.get(f"{BASE_URL}/grader", headers=headers, timeout=10.0)
        grader_resp.raise_for_status()
        raw_score = grader_resp.json().get("score", 0.0)
        # Coerce inside the try: a null or non-numeric score would otherwise
        # crash the formatting below, outside any handler.
        final_score = float(raw_score) if raw_score is not None else 0.0
    except Exception as e:
        print(f"Failed to get grader score: {e}")
        final_score = 0.0

    print(format_score_line(task_id, final_score))
    return final_score
def run_baseline():
    """Run tasks 1, 2 and 3 in order, then print a per-task score summary."""
    # Dict insertion order keeps the summary in execution order.
    scores = {f"task_{task_id}": run_task(task_id) for task_id in (1, 2, 3)}

    print("\n--- Summary ---")
    for task_name, task_score in scores.items():
        print(f"{task_name}: {task_score:.4f}")
# Script entry point: run every baseline task; any uncaught error is reported
# and turned into exit status 1.
if __name__ == "__main__":
    try:
        run_baseline()
    except Exception as exc:
        print(f"Top-level execution crash: {exc}")
        raise SystemExit(1)
|