# db-schema-migration / inference.py
# Uploaded by hissterical ("Upload 9 files", commit a5c89a3, verified).
"""
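inference.py — DB Schema Migration baseline agent
Usage: python inference.py [task]  (exit status 0 if final score >= 0.8, else 1)
Reads env vars:
API_BASE_URL (default: https://api-inference.huggingface.co/v1)
MODEL_NAME (default: meta-llama/Llama-3.1-8B-Instruct)
HF_TOKEN or API_KEY (default: hf_placeholder)
IMAGE_NAME or LOCAL_IMAGE_NAME (default: db-schema-migration:latest; for from_docker_image)
ENV_URL (default: http://localhost:7860)
TASK (default: easy; the first CLI argument takes precedence)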
STDOUT format:
[START] task=<task_name> env=db-schema-migration model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...>
"""
import os
import sys
import json
import requests
from openai import OpenAI
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Settings are read from environment variables, falling back to the defaults
# shown here.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
# HF_TOKEN wins when set and non-empty; otherwise API_KEY, then a placeholder.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "hf_placeholder")
# IMAGE_NAME is not referenced elsewhere in this script.
IMAGE_NAME = os.environ.get("IMAGE_NAME") or os.environ.get(
    "LOCAL_IMAGE_NAME", "db-schema-migration:latest"
)
# Base URL of the environment server that serves /reset, /step and /state.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
# Task used when no task name is given on the command line.
TASK = os.environ.get("TASK", "easy")

# Single OpenAI-compatible client shared by every LLM call in this module.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# ---------------------------------------------------------------------------
# Env helpers
# ---------------------------------------------------------------------------
def _json_or_raise(response) -> dict:
    """Raise for HTTP error statuses, otherwise return the decoded JSON body."""
    response.raise_for_status()
    return response.json()


def env_reset(task: str) -> dict:
    """Start a new episode for *task* via ``POST {ENV_URL}/reset`` (30 s timeout)."""
    return _json_or_raise(
        requests.post(f"{ENV_URL}/reset", json={"task": task}, timeout=30)
    )


def env_step(action: dict) -> dict:
    """Submit one *action* dict via ``POST {ENV_URL}/step`` (30 s timeout)."""
    return _json_or_raise(requests.post(f"{ENV_URL}/step", json=action, timeout=30))


def env_state() -> dict:
    """Fetch the env state via ``GET {ENV_URL}/state`` (30 s timeout)."""
    return _json_or_raise(requests.get(f"{ENV_URL}/state", timeout=30))
# ---------------------------------------------------------------------------
# Schema pretty-printer
# ---------------------------------------------------------------------------
def format_schema(tables: list) -> str:
    """Render a list of table dicts as plain text for the LLM prompt.

    Each table becomes a ``TABLE: <name>`` line followed by one line per column
    giving its name and data type, plus ``[PK]`` / ``[FK -> <target>]`` markers
    when the column's ``primary_key`` / ``foreign_key`` fields are truthy.
    Returns an empty string for an empty list.
    """

    def describe(col: dict) -> str:
        markers = []
        if col.get("primary_key"):
            markers.append(" [PK]")
        if col.get("foreign_key"):
            markers.append(f" [FK -> {col['foreign_key']}]")
        return f" - {col['name']} {col['data_type']}" + "".join(markers)

    rendered = []
    for table in tables:
        rendered.append(f"TABLE: {table['name']}")
        rendered.extend(describe(col) for col in table["columns"])
    return "\n".join(rendered)
# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------
# System message placed first in the chat history at the start of each episode
# (run_episode seeds `messages` with it). The JSON shape requested below is what
# the agent sends to the env: the parsed reply is passed to env_step unchanged
# and POSTed to /step. NOTE(review): presumably the env validates these field
# names, so keep them in sync with the server's action schema — confirm.
SYSTEM_PROMPT = """You are a database migration expert agent operating in an RL environment.
You will be given the current database schema and a list of requirements.
Your job is to decide the NEXT single migration action to take.
Available operations:
- rename_table: rename an existing table
- rename_column: rename a column in a table
- add_column: add a new column to a table
- drop_column: remove a column from a table
- change_type: change a column's data type
- add_foreign_key: add a foreign key constraint
- normalize_table: extract a new table from a denormalized table (hard task)
- done: signal you are finished
Respond with ONLY valid JSON matching this schema, nothing else:
{
"operation": "<operation_name>",
"table": "<current_table_name>",
"column": "<column_name_or_null>",
"new_name": "<new_name_or_null>",
"data_type": "<type_or_null>",
"reference_table": "<ref_table_or_null>",
"reference_column": "<ref_col_or_null>",
"reason": "<one sentence why>"
}
Rules:
- Do ONE action per response
- If all requirements are met, use {"operation": "done", "table": "", "reason": "all done"}
- Never repeat a successful action
- Think step by step: rename tables first, then columns, then types, then add/FK
"""
def build_user_prompt(obs: dict, task_desc: str) -> str:
    """Render the per-turn user message for the LLM from an env observation.

    Args:
        obs: observation dict from the env. Keys read here: ``current_schema``,
            ``target_requirements``, ``max_steps``, ``step_count`` and, when
            present, ``violations`` and ``steps_taken`` (each history entry
            needs ``operation``, ``table`` and ``reward``; ``column`` is optional).
        task_desc: task description returned by ``/reset``.

    Returns:
        The prompt text, sections separated by blank lines.
    """
    schema_str = format_schema(obs["current_schema"])
    reqs = "\n".join(
        f" {i}. {r}" for i, r in enumerate(obs["target_requirements"], start=1)
    )
    violations = obs.get("violations", [])
    steps_left = obs["max_steps"] - obs["step_count"]
    parts = [
        f"TASK: {task_desc}",
        f"\nCURRENT SCHEMA:\n{schema_str}",
        f"\nREQUIREMENTS:\n{reqs}",
    ]
    if violations:
        parts.append("\nVIOLATIONS (fix these!):\n" + "\n".join(f" - {v}" for v in violations))
    history = obs.get("steps_taken")
    if history:
        # `or ''` (not a .get default) so a JSON-null column renders as empty
        # instead of the literal text "None".
        hist = "\n".join(
            f" - {s['operation']} on {s['table']}.{s.get('column') or ''} reward={s['reward']:.2f}"
            for s in history[-3:]
        )
        parts.append(f"\nLAST 3 ACTIONS:\n{hist}")
    parts.append(f"\nSteps remaining: {steps_left}")
    parts.append("\nWhat is your NEXT single action? Respond with JSON only.")
    return "\n".join(parts)
def call_llm(messages: list, *, max_tokens: int = 300, temperature: float = 0.1) -> str:
    """Send the chat *messages* to MODEL_NAME and return the stripped reply text.

    Args:
        messages: chat history as a list of ``{"role": ..., "content": ...}`` dicts.
        max_tokens: completion token limit (default 300).
        temperature: sampling temperature (default 0.1).

    Returns:
        The first choice's content with surrounding whitespace removed, or
        ``""`` if the API returned no text content.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    # The SDK types message.content as Optional[str]; map None to "" so the
    # caller's JSON parsing reports the failure instead of an AttributeError.
    content = response.choices[0].message.content
    return (content or "").strip()
def parse_action(text: str) -> dict:
    """Parse the LLM's reply into an action dict.

    Accepts bare JSON and JSON wrapped in a Markdown code fence (with or
    without a language tag). As a fallback, for replies that put the object
    inside prose or on the same line as the fences, it decodes the span from
    the first ``{`` to the last ``}``.

    Args:
        text: raw reply text from the model.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: if no JSON can be decoded from the reply.
        ValueError: if the reply decodes to something other than a JSON object.
    """
    raw = text.strip()
    body = raw
    # Strip markdown fences if present ("```json ... ```" or "``` ... ```").
    if body.startswith("```"):
        lines = body.split("\n")
        body = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
    try:
        action = json.loads(body)
    except json.JSONDecodeError:
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end < start:
            raise
        action = json.loads(raw[start:end + 1])
    # Callers use dict methods on the result; reject lists/strings/numbers here
    # so the failure surfaces as a parse error rather than a later crash.
    if not isinstance(action, dict):
        raise ValueError(f"expected a JSON object, got {type(action).__name__}")
    return action
# ---------------------------------------------------------------------------
# Main episode loop
# ---------------------------------------------------------------------------
def _one_line(text) -> str:
    """Collapse line breaks so a value cannot split a single-line log record."""
    return " ".join(str(text).splitlines())


def _format_action(action: dict) -> str:
    """Render *action* as ``operation(table.column_or_new_name)`` for [STEP] lines.

    Missing or JSON-null ``table``/``column``/``new_name`` render as empty text
    rather than the literal ``None``.
    """
    target = action.get("column") or action.get("new_name") or ""
    return _one_line(f"{action.get('operation')}({action.get('table') or ''}.{target})")


def run_episode(task: str = TASK):
    """Run one episode of *task* against the env, printing the structured log.

    Prints one ``[START]`` line, one ``[STEP]`` line per action and one
    ``[END]`` line, in the STDOUT format given in the module docstring.

    Args:
        task: task name sent to ``/reset``.

    Returns:
        The final score. This is ``score`` from ``/state`` when that request
        succeeds and the value is non-null. Otherwise it is the last step's
        ``info`` ``final_score``/``partial_score``, or 0.0 if a step request
        failed.
    """
    reset_result = env_reset(task)
    obs = reset_result["observation"]
    task_desc = reset_result["task_description"]
    print(f"[START] task={task} env=db-schema-migration model={MODEL_NAME}", flush=True)

    rewards = []
    step = 0
    done = False
    final_score = 0.0
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    while not done:
        # Safety net against an env that never reports done: stop once the
        # observation's step budget is used up. NOTE(review): assumes
        # obs["max_steps"] is the env-side action limit — confirm on the server.
        step_limit = obs.get("max_steps")
        if step_limit is not None and step >= step_limit:
            break
        step += 1
        messages.append({"role": "user", "content": build_user_prompt(obs, task_desc)})

        # Ask the LLM for the next action. Any failure (API error, empty or
        # non-JSON reply) falls back to a "done" action carrying the error text.
        try:
            raw = call_llm(messages)
            action = parse_action(raw)
            if not isinstance(action, dict):
                raise ValueError("LLM reply is not a JSON object")
            messages.append({"role": "assistant", "content": raw})
        except Exception as e:
            action = {"operation": "done", "table": "", "reason": f"parse error: {e}"}
            messages.append({"role": "assistant", "content": json.dumps(action)})
        action_str = _format_action(action)

        # Step the env; a failed request or malformed response ends the episode.
        try:
            result = env_step(action)
            reward = result["reward"]
            done = result["done"]
            obs = result["observation"]
            error = result.get("error") or "null"
            info = result.get("info") or {}
            final_score = info.get("final_score", info.get("partial_score", 0.0))
        except Exception as e:
            reward = -0.1
            done = True
            error = str(e)
            final_score = 0.0

        rewards.append(reward)
        done_str = "true" if done else "false"
        print(
            f"[STEP] step={step} action={action_str} reward={reward:.2f} "
            f"done={done_str} error={_one_line(error)}",
            flush=True,
        )

    # Prefer the env's own score; keep the per-step score when the state request
    # fails (deliberate best-effort) or reports no usable score.
    try:
        state_score = env_state().get("score")
    except Exception:
        state_score = None
    if state_score is not None:
        final_score = state_score

    success = final_score >= 0.8
    success_str = "true" if success else "false"
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={success_str} steps={step} score={final_score:.4f} rewards={rewards_str}",
        flush=True,
    )
    return final_score
if __name__ == "__main__":
    # An optional first CLI argument selects the task; otherwise use TASK.
    cli_args = sys.argv[1:]
    chosen_task = cli_args[0] if cli_args else TASK
    final = run_episode(chosen_task)
    # Exit status 0 signals success (final score >= 0.8), 1 otherwise.
    exit_code = 0 if final >= 0.8 else 1
    sys.exit(exit_code)