File size: 7,272 Bytes
22328de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
import re
import json
import sys
import httpx
from dotenv import load_dotenv

# Put the project root (the parent of this file's directory) at the front of
# sys.path so the `baseline.prompts` import below resolves when this file is
# executed directly as a script.
try:
    _here = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # `__file__` is not defined in some interactive / embedded contexts;
    # fall back to the current working directory.
    _root = os.getcwd()
else:
    _root = os.path.dirname(_here)

if _root not in sys.path:
    sys.path.insert(0, _root)

from baseline.prompts import SYSTEM_PROMPT

import os
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

# Supports OpenAI and any OpenAI-compatible endpoint (Google AI Studio /
# Gemini, Groq, OpenRouter, ...) as a drop-in:
#   * OPENAI_BASE_URL set   -> requests go to that endpoint
#   * OPENAI_BASE_URL unset -> the OpenAI SDK's default endpoint
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GOOGLE_AI_KEY")
base_url = os.getenv("OPENAI_BASE_URL", None)  # None = use OpenAI default
model = os.getenv("BASELINE_MODEL", "gemini-2.0-flash")
# Strip a trailing slash so f"{BASE_URL}/reset" never yields ".../reset" with "//".
env_base_url = os.getenv("ENV_BASE_URL", "http://localhost:7860").rstrip("/")

if not api_key:
    raise ValueError(
        "No API key found. Set OPENAI_API_KEY (for OpenAI) or "
        "GOOGLE_AI_KEY + OPENAI_BASE_URL (for Google AI Studio / other providers)"
    )
if not base_url and not os.getenv("OPENAI_API_KEY"):
    # Only GOOGLE_AI_KEY is set. Without a base URL the client would send that
    # key to the default OpenAI endpoint, where it is not a valid credential.
    raise ValueError(
        "GOOGLE_AI_KEY is set but OPENAI_BASE_URL is not. Set OPENAI_BASE_URL "
        "to the provider's OpenAI-compatible endpoint (e.g. Google AI Studio)."
    )

# Build client — works for OpenAI, Google AI Studio, Groq, OpenRouter
client_kwargs = {"api_key": api_key}
if base_url:
    client_kwargs["base_url"] = base_url

client = OpenAI(**client_kwargs)

print("Baseline agent initialised:")
print(f"  Provider: {'Google AI Studio' if 'google' in (base_url or '') else 'OpenAI-compatible'}")
print(f"  Model: {model}")
print(f"  Environment: {env_base_url}")

# Environment server root used by run_task for /reset, /step and /grader.
BASE_URL = env_base_url
# Fixed reset seed per task id so baseline runs are reproducible.
BASELINE_SEEDS = {1: 42, 2: 99, 3: 777}

def format_score_line(task_id: int, score: float) -> str:
    """Build the machine-readable score line, e.g. ``SCORE task_1: 0.5000``.

    The score is rendered with exactly four decimal places.
    """
    template = "SCORE task_{}: {:.4f}"
    return template.format(task_id, score)

def call_llm(messages: list) -> str:
    """Send *messages* to the configured chat model and return the reply text.

    Uses the module-level OpenAI-compatible ``client`` and ``model`` with
    ``temperature=0.0``.

    Returns:
        The assistant message text. Returns ``""`` when the provider returns
        no choices or a message whose ``content`` is ``None``, so callers
        always receive a ``str``; an empty reply is then handled as an
        unparseable action instead of raising.

    Any exception raised by the API call itself is treated as fatal: the
    error is printed and the process exits with status 1.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0,
        )
    except Exception as e:
        print(f"Fatal OpenAI API crash: {e}")
        sys.exit(1)

    # Some OpenAI-compatible providers return no choices (e.g. blocked
    # output), and `message.content` is Optional in the SDK; normalise both
    # to "" rather than crashing or returning None.
    if not response.choices:
        return ""
    return response.choices[0].message.content or ""

def parse_action(raw_text: str) -> dict:
    """Extract and parse action JSON from LLM output, handling all common failure modes."""
    text = raw_text.strip()
    
    # Mode 1: strip markdown code fences (```json ... ``` or ``` ... ```)
    fence_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
    if fence_match:
        text = fence_match.group(1).strip()
    
    # Mode 2: find first { ... } JSON object if there's surrounding prose
    brace_match = re.search(r'\{[\s\S]*\}', text)
    if brace_match:
        text = brace_match.group(0)
    
    # Mode 3: fix trailing commas (common LLM mistake)
    text = re.sub(r',\s*([}\]])', r'\1', text)
    
    # Mode 4: fix single quotes used instead of double quotes
    # Only do this if JSON parse fails first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        try:
            # Replace single-quoted keys/values carefully
            text_fixed = re.sub(r"'([^']*)'", r'"\1"', text)
            return json.loads(text_fixed)
        except json.JSONDecodeError:
            return None  # caller handles None

def safe_action(parsed: dict | None, step_num: int) -> dict:
    """Convert parsed dict to valid action, with safe fallbacks."""
    if parsed is None:
        # After 3 failed parses in a row, submit to end episode gracefully
        return {"action_type": "submit"}
    
    action_type = parsed.get("action_type", "").lower()
    
    if action_type == "query" and "sql" in parsed:
        return parsed
    elif action_type == "ddl" and "sql" in parsed:
        return parsed
    elif action_type == "test" and "target_table" in parsed:
        return parsed
    elif action_type == "submit":
        return parsed
    elif "sql" in parsed:
        # LLM gave SQL but wrong action_type — infer it
        sql = parsed["sql"].strip().upper()
        inferred_type = "query" if sql.startswith(("SELECT","WITH","EXPLAIN")) else "ddl"
        return {"action_type": inferred_type, "sql": parsed["sql"]}
    else:
        # Completely unparseable — explore schema as safe default
        if step_num <= 3:
            return {"action_type": "query", "sql": "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')"}
        return {"action_type": "submit"}

def run_task(task_id: int) -> float:
    """Play one episode of *task_id* against the environment server.

    Flow:

    1. ``POST {BASE_URL}/reset`` with the task id and its seed from
       ``BASELINE_SEEDS``.
    2. Up to ``max_steps`` rounds (from the observation, default 25) of asking
       the LLM for an action and sending it to ``POST {BASE_URL}/step``.
    3. ``GET {BASE_URL}/grader``.

    When ``/reset`` returns a ``session_id``, it is sent as ``X-Session-ID`` on
    the step and grader requests.

    The loop stops early when the server reports ``done`` or ``truncated``,
    or when a step request fails. The agent submits after three consecutive
    unparseable LLM replies, or when an exception occurs while producing an
    action.

    Args:
        task_id: Environment task identifier.

    Returns:
        The grader score as a float. Returns ``0.0`` if reset fails, if the
        grader request fails, or if the grader's ``score`` is not numeric.
    """
    print(f"Starting task {task_id}")
    # Reuse one pooled connection for the whole episode; the context manager
    # closes it on every exit path.
    with httpx.Client() as http:
        try:
            seed = BASELINE_SEEDS.get(task_id)
            resp = http.post(
                f"{BASE_URL}/reset",
                json={"task_id": task_id, "seed": seed},
                timeout=30.0,
            )
            resp.raise_for_status()
            resp_data = resp.json()
            obs = resp_data.get("observation", resp_data)
            session_id = resp_data.get("session_id", "")
        except Exception as e:
            print(f"Failed to reset environment for task {task_id}: {e}")
            return 0.0

        # The session header is fixed for the episode, so build it once.
        headers = {"X-Session-ID": session_id} if session_id else {}
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        max_steps = obs.get("max_steps", 25)
        consecutive_parse_failures = 0

        for step in range(max_steps):
            messages.append({"role": "user", "content": json.dumps(obs)})

            try:
                llm_response = call_llm(messages)
                parsed = parse_action(llm_response)

                if parsed is None:
                    consecutive_parse_failures += 1
                    if consecutive_parse_failures >= 3:
                        print(f"Warning: 3 consecutive parse failures at step {step}. Handing episode submit.")
                        action = {"action_type": "submit"}
                    else:
                        action = safe_action(parsed, step)
                else:
                    consecutive_parse_failures = 0
                    action = safe_action(parsed, step)

            except Exception as e:
                print(f"LLM error at step {step}: {e}")
                action = {"action_type": "submit"}

            # Record the action actually sent (after sanitising), not the raw reply.
            messages.append({"role": "assistant", "content": json.dumps(action)})

            try:
                step_resp = http.post(f"{BASE_URL}/step", json=action, headers=headers, timeout=30.0)
                step_resp.raise_for_status()
                step_data = step_resp.json()

                obs = step_data.get("observation", step_data)
                if step_data.get("done") or step_data.get("truncated"):
                    break
            except Exception as e:
                print(f"Failed to step environment: {e}")
                break

        try:
            grader_resp = http.get(f"{BASE_URL}/grader", headers=headers, timeout=10.0)
            grader_resp.raise_for_status()
            # Coerce inside the try so a null / non-numeric score takes the
            # 0.0 fallback instead of crashing format_score_line below (which
            # would abort the remaining tasks).
            final_score = float(grader_resp.json().get("score", 0.0))
        except Exception as e:
            print(f"Failed to get grader score: {e}")
            final_score = 0.0

    print(format_score_line(task_id, final_score))
    return final_score

def run_baseline():
    """Run baseline tasks 1, 2 and 3 in order, then print a score summary.

    Each task is run via ``run_task``. The summary prints one
    ``task_<id>: <score>`` line per task, with four decimal places.
    """
    scores = {f"task_{tid}": run_task(tid) for tid in (1, 2, 3)}

    print("\n--- Summary ---")
    for label in scores:
        print(f"{label}: {scores[label]:.4f}")

if __name__ == "__main__":
    # Script entry point: report any uncaught exception from the baseline run
    # and exit with a non-zero status.
    try:
        run_baseline()
    except Exception as exc:
        print(f"Top-level execution crash: {exc}")
        raise SystemExit(1)