# open-dataops-env / baseline/inference.py
# Uploaded by rohan9977 via huggingface_hub (commit 22328de, verified)
import os
import re
import json
import sys
import httpx
from dotenv import load_dotenv
# Put the repository root (the parent of this file's directory) at the front
# of sys.path so the ``baseline`` package import below resolves when this file
# is executed directly as a script.
try:
    _here = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # ``__file__`` is undefined (e.g. interactive use): use the working directory.
    _root = os.getcwd()
else:
    _root = os.path.dirname(_here)

if _root not in sys.path:
    sys.path.insert(0, _root)
from baseline.prompts import SYSTEM_PROMPT
import os
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()

# Supports OpenAI and OpenAI-compatible providers (Google AI Studio / Gemini,
# Groq, OpenRouter, ...) as a drop-in:
#   * OPENAI_BASE_URL set     -> use that endpoint verbatim.
#   * only GOOGLE_AI_KEY set  -> use Google AI Studio's OpenAI-compatible
#     endpoint (a Google key sent to the OpenAI default endpoint cannot work).
#   * otherwise               -> the OpenAI SDK's default endpoint.
_GOOGLE_OPENAI_COMPAT_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"

api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GOOGLE_AI_KEY")
base_url = os.getenv("OPENAI_BASE_URL") or None  # None = use OpenAI default
if base_url is None and not os.getenv("OPENAI_API_KEY") and os.getenv("GOOGLE_AI_KEY"):
    base_url = _GOOGLE_OPENAI_COMPAT_URL
# NOTE: the default model is a Gemini model; OpenAI users must set BASELINE_MODEL.
model = os.getenv("BASELINE_MODEL", "gemini-2.0-flash")
# Drop a trailing slash so f"{BASE_URL}/reset" never becomes ".../" + "/reset".
env_base_url = os.getenv("ENV_BASE_URL", "http://localhost:7860").rstrip("/")

if not api_key:
    raise ValueError(
        "No API key found. Set OPENAI_API_KEY (for OpenAI) or "
        "GOOGLE_AI_KEY + OPENAI_BASE_URL (for Google AI Studio / other providers)"
    )

# Build client — works for OpenAI, Google AI Studio, Groq, OpenRouter
client_kwargs = {"api_key": api_key}
if base_url:
    client_kwargs["base_url"] = base_url
client = OpenAI(**client_kwargs)

print("Baseline agent initialised:")
print(f" Provider: {'Google AI Studio' if 'google' in (base_url or '') else 'OpenAI-compatible'}")
print(f" Model: {model}")
print(f" Environment: {env_base_url}")

# Environment server root used for /reset, /step and /grader requests.
BASE_URL = env_base_url
# Per-task seeds passed to /reset by run_task (presumably for reproducible baselines).
BASELINE_SEEDS = {1: 42, 2: 99, 3: 777}
def format_score_line(task_id: int, score: float) -> str:
    """Build the score report line for one task.

    The score is rendered with exactly four decimal places, e.g.
    ``format_score_line(1, 0.5)`` gives ``"SCORE task_1: 0.5000"``.
    """
    template = "SCORE task_{}: {:.4f}"
    return template.format(task_id, score)
def call_llm(messages: list) -> str:
    """Send the chat transcript to the configured model and return its reply text.

    Uses the module-level ``client`` and ``model`` with ``temperature=0.0``.

    Args:
        messages: Chat messages, each a ``{"role": ..., "content": ...}`` dict.

    Returns:
        The first choice's message text, or ``""`` when the provider returns
        no text content, so callers always receive a ``str``.

    Raises:
        RuntimeError: if the API call fails or the response has no choices;
            the underlying exception is chained as ``__cause__``.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0,
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Raise instead of sys.exit(): SystemExit is not an Exception, so it
        # bypassed run_task's per-step LLM-error handler and aborted every
        # remaining task. Raising lets the caller end just this episode.
        raise RuntimeError(f"LLM API call failed: {e}") from e
    return content or ""
def parse_action(raw_text: str) -> dict:
"""Extract and parse action JSON from LLM output, handling all common failure modes."""
text = raw_text.strip()
# Mode 1: strip markdown code fences (```json ... ``` or ``` ... ```)
fence_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
if fence_match:
text = fence_match.group(1).strip()
# Mode 2: find first { ... } JSON object if there's surrounding prose
brace_match = re.search(r'\{[\s\S]*\}', text)
if brace_match:
text = brace_match.group(0)
# Mode 3: fix trailing commas (common LLM mistake)
text = re.sub(r',\s*([}\]])', r'\1', text)
# Mode 4: fix single quotes used instead of double quotes
# Only do this if JSON parse fails first
try:
return json.loads(text)
except json.JSONDecodeError:
try:
# Replace single-quoted keys/values carefully
text_fixed = re.sub(r"'([^']*)'", r'"\1"', text)
return json.loads(text_fixed)
except json.JSONDecodeError:
return None # caller handles None
def safe_action(parsed: dict | None, step_num: int) -> dict:
"""Convert parsed dict to valid action, with safe fallbacks."""
if parsed is None:
# After 3 failed parses in a row, submit to end episode gracefully
return {"action_type": "submit"}
action_type = parsed.get("action_type", "").lower()
if action_type == "query" and "sql" in parsed:
return parsed
elif action_type == "ddl" and "sql" in parsed:
return parsed
elif action_type == "test" and "target_table" in parsed:
return parsed
elif action_type == "submit":
return parsed
elif "sql" in parsed:
# LLM gave SQL but wrong action_type — infer it
sql = parsed["sql"].strip().upper()
inferred_type = "query" if sql.startswith(("SELECT","WITH","EXPLAIN")) else "ddl"
return {"action_type": inferred_type, "sql": parsed["sql"]}
else:
# Completely unparseable — explore schema as safe default
if step_num <= 3:
return {"action_type": "query", "sql": "SELECT name, sql FROM sqlite_master WHERE type IN ('table','view')"}
return {"action_type": "submit"}
def run_task(task_id: int) -> float:
    """Run one episode of ``task_id`` against the environment server and score it.

    Resets the environment with the task's baseline seed, then alternates LLM
    calls and ``/step`` requests. The loop stops when the environment reports
    ``done``/``truncated``, a step request fails, or ``max_steps`` is reached.
    The final score is then fetched from ``/grader`` and printed.

    Args:
        task_id: Task to run; also used to look up its seed in ``BASELINE_SEEDS``.

    Returns:
        The grader score as a float, or ``0.0`` if resetting or grading fails.
    """
    print(f"Starting task {task_id}")
    try:
        seed = BASELINE_SEEDS.get(task_id)
        resp = httpx.post(
            f"{BASE_URL}/reset",
            json={"task_id": task_id, "seed": seed},
            timeout=30.0,
        )
        resp.raise_for_status()
        resp_data = resp.json()
        # Accept either {"observation": ..., "session_id": ...} or a bare observation.
        obs = resp_data.get("observation", resp_data)
        session_id = resp_data.get("session_id", "")
    except Exception as e:
        print(f"Failed to reset environment for task {task_id}: {e}")
        return 0.0

    # The session id does not change during the episode, so build the header once.
    headers = {"X-Session-ID": session_id} if session_id else {}

    try:
        max_steps = int(obs.get("max_steps", 25))
    except (AttributeError, TypeError, ValueError):
        # Missing or malformed step budget: use the default instead of crashing.
        max_steps = 25

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    consecutive_parse_failures = 0
    for step in range(max_steps):
        # The whole observation is shown to the model as JSON each turn.
        messages.append({"role": "user", "content": json.dumps(obs)})
        try:
            llm_response = call_llm(messages)
            parsed = parse_action(llm_response)
            if parsed is None:
                consecutive_parse_failures += 1
                if consecutive_parse_failures >= 3:
                    print(f"Warning: 3 consecutive parse failures at step {step}. Forcing episode submit.")
                    action = {"action_type": "submit"}
                else:
                    action = safe_action(parsed, step)
            else:
                consecutive_parse_failures = 0
                action = safe_action(parsed, step)
        except Exception as e:
            print(f"LLM error at step {step}: {e}")
            action = {"action_type": "submit"}
        # Record the normalised action (not the raw reply) as the assistant turn.
        messages.append({"role": "assistant", "content": json.dumps(action)})
        try:
            step_resp = httpx.post(f"{BASE_URL}/step", json=action, headers=headers, timeout=30.0)
            step_resp.raise_for_status()
            step_data = step_resp.json()
            obs = step_data.get("observation", step_data)
            if step_data.get("done") or step_data.get("truncated"):
                break
        except Exception as e:
            print(f"Failed to step environment: {e}")
            break

    try:
        grader_resp = httpx.get(f"{BASE_URL}/grader", headers=headers, timeout=10.0)
        grader_resp.raise_for_status()
        # Coerce so a null or string score cannot crash format_score_line below.
        final_score = float(grader_resp.json().get("score", 0.0) or 0.0)
    except Exception as e:
        print(f"Failed to get grader score: {e}")
        final_score = 0.0

    print(format_score_line(task_id, final_score))
    return final_score
def run_baseline(task_ids=(1, 2, 3)) -> dict[str, float]:
    """Run each task in order, print a summary, and return the scores.

    Args:
        task_ids: Task ids to run, in order; defaults to the three baseline tasks.

    Returns:
        Mapping of ``"task_<id>"`` to that task's score (previously discarded).
    """
    scores = {}
    for task_id in task_ids:
        scores[f"task_{task_id}"] = run_task(task_id)
    print("\n--- Summary ---")
    for task, score in scores.items():
        print(f"{task}: {score:.4f}")
    return scores
if __name__ == "__main__":
    try:
        run_baseline()
    except Exception as e:
        import traceback

        print(f"Top-level execution crash: {e}")
        # The one-line message alone hides where the failure happened.
        traceback.print_exc()
        sys.exit(1)