"""
inference.py — DB Schema Migration baseline agent

Reads env vars:
    API_BASE_URL   (default: https://api-inference.huggingface.co/v1)
    MODEL_NAME     (default: meta-llama/Llama-3.1-8B-Instruct)
    HF_TOKEN or API_KEY
    IMAGE_NAME or LOCAL_IMAGE_NAME  (for from_docker_image)
    ENV_URL        (default: http://localhost:7860)
    TASK           (default: easy)

STDOUT format:
    [START] task=<task> env=db-schema-migration model=<model>
    [STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<0.0000> rewards=<r1,r2,...>

Diagnostics that are not part of the protocol above are written to STDERR.
"""

import os
import sys
import json

import requests
from openai import OpenAI

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "hf_placeholder")
IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME", "db-schema-migration:latest")
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
TASK = os.getenv("TASK", "easy")

# Minimum final score for an episode to be reported as a success
# (used for both the [END] line and the process exit code).
SUCCESS_THRESHOLD = 0.8

client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)


# ---------------------------------------------------------------------------
# Env helpers
# ---------------------------------------------------------------------------

def env_reset(task: str) -> dict:
    """POST ``{"task": task}`` to ``{ENV_URL}/reset`` and return the JSON body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    r = requests.post(f"{ENV_URL}/reset", json={"task": task}, timeout=30)
    r.raise_for_status()
    return r.json()


def env_step(action: dict) -> dict:
    """POST *action* as JSON to ``{ENV_URL}/step`` and return the JSON body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    r = requests.post(f"{ENV_URL}/step", json=action, timeout=30)
    r.raise_for_status()
    return r.json()


def env_state() -> dict:
    """GET ``{ENV_URL}/state`` and return the JSON body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    r = requests.get(f"{ENV_URL}/state", timeout=30)
    r.raise_for_status()
    return r.json()


# ---------------------------------------------------------------------------
# Schema pretty-printer
# ---------------------------------------------------------------------------

def format_schema(tables: list) -> str:
    """Render a list of table dicts as plain text for the LLM prompt.

    Each table must have ``name`` and ``columns``; each column must have
    ``name`` and ``data_type`` and may have ``primary_key`` / ``foreign_key``,
    which are rendered as ``[PK]`` / ``[FK -> ...]`` markers.
    """
    lines = []
    for t in tables:
        lines.append(f"TABLE: {t['name']}")
        for c in t["columns"]:
            pk = " [PK]" if c.get("primary_key") else ""
            fk = f" [FK -> {c['foreign_key']}]" if c.get("foreign_key") else ""
            lines.append(f" - {c['name']} {c['data_type']}{pk}{fk}")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """You are a database migration expert agent operating in an RL environment.
You will be given the current database schema and a list of requirements.
Your job is to decide the NEXT single migration action to take.

Available operations:
- rename_table: rename an existing table
- rename_column: rename a column in a table
- add_column: add a new column to a table
- drop_column: remove a column from a table
- change_type: change a column's data type
- add_foreign_key: add a foreign key constraint
- normalize_table: extract a new table from a denormalized table (hard task)
- done: signal you are finished

Respond with ONLY valid JSON matching this schema, nothing else:
{
  "operation": "",
  "table": "",
  "column": "",
  "new_name": "",
  "data_type": "",
  "reference_table": "",
  "reference_column": "",
  "reason": ""
}

Rules:
- Do ONE action per response
- If all requirements are met, use {"operation": "done", "table": "", "reason": "all done"}
- Never repeat a successful action
- Think step by step: rename tables first, then columns, then types, then add/FK
"""


def build_user_prompt(obs: dict, task_desc: str) -> str:
    """Build the per-step user message from the latest observation.

    Reads ``current_schema``, ``target_requirements``, ``max_steps`` and
    ``step_count`` from *obs* (all required), plus the optional
    ``violations`` and ``steps_taken`` lists. Only the last three entries of
    ``steps_taken`` are included.
    """
    schema_str = format_schema(obs["current_schema"])
    reqs = "\n".join(f" {i+1}. {r}" for i, r in enumerate(obs["target_requirements"]))
    violations = obs.get("violations", [])
    steps_left = obs["max_steps"] - obs["step_count"]

    parts = [
        f"TASK: {task_desc}",
        f"\nCURRENT SCHEMA:\n{schema_str}",
        f"\nREQUIREMENTS:\n{reqs}",
    ]
    if violations:
        parts.append("\nVIOLATIONS (fix these!):\n" + "\n".join(f" - {v}" for v in violations))
    if obs.get("steps_taken"):
        last = obs["steps_taken"][-3:]
        hist = "\n".join(
            f" - {s['operation']} on {s['table']}.{s.get('column','')} reward={s['reward']:.2f}"
            for s in last
        )
        parts.append(f"\nLAST 3 ACTIONS:\n{hist}")
    parts.append(f"\nSteps remaining: {steps_left}")
    parts.append("\nWhat is your NEXT single action? Respond with JSON only.")
    return "\n".join(parts)


def call_llm(messages: list) -> str:
    """Send *messages* to the chat-completions endpoint and return the reply text.

    Returns the stripped content of the first choice, or ``""`` when the
    model returns no text content.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=300,
        temperature=0.1,
    )
    # ``content`` can be None; treat that as an empty reply instead of
    # failing with AttributeError on ``.strip()``.
    content = response.choices[0].message.content
    return (content or "").strip()


def parse_action(text: str) -> dict:
    """Parse the model's reply into an action dict.

    Markdown code fences are stripped first. If the remaining text is not
    valid JSON on its own, the first ``{...}`` object embedded in the reply is
    decoded instead, so a short preamble or trailing remark does not discard
    an otherwise usable action.

    Raises:
        ValueError: if no JSON can be decoded (``json.JSONDecodeError``, a
            ``ValueError`` subclass) or the decoded value is not a JSON object.
    """
    text = text.strip()
    candidate = text
    # Strip markdown fences if present
    if candidate.startswith("```"):
        lines = candidate.split("\n")
        candidate = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])

    try:
        action = json.loads(candidate)
    except json.JSONDecodeError:
        # Fall back to the first JSON object anywhere in the raw reply;
        # re-raise the original error when there is nothing to try.
        start = text.find("{")
        if start == -1:
            raise
        action, _ = json.JSONDecoder().raw_decode(text, start)

    # The episode loop calls ``.get`` on the result, so anything other than
    # a JSON object (list, string, number, null) must be rejected here.
    if not isinstance(action, dict):
        raise ValueError(f"expected a JSON object, got {type(action).__name__}")
    return action


# ---------------------------------------------------------------------------
# Main episode loop
# ---------------------------------------------------------------------------

def _coerce_score(value, fallback: float) -> float:
    """Return *value* as a float if it is numeric, otherwise *fallback*."""
    if isinstance(value, (int, float)):
        return float(value)
    return fallback


def _budget_exhausted(obs: dict) -> bool:
    """Return True when *obs* reports ``step_count`` has reached ``max_steps``.

    Returns False when either field is missing or not an integer, so the
    check never ends an episode on incomplete information.
    """
    max_steps = obs.get("max_steps")
    used = obs.get("step_count")
    if not isinstance(max_steps, int) or not isinstance(used, int):
        return False
    return used >= max_steps


def run_episode(task: str = TASK) -> float:
    """Run one episode against the environment and return the final score.

    Resets the environment, then repeatedly asks the LLM for the next action
    and submits it until the environment reports ``done``, a step fails, or
    the observation shows the step budget is used up. Progress is printed to
    STDOUT in the ``[START]`` / ``[STEP]`` / ``[END]`` format described in
    the module docstring.

    Raises:
        requests.RequestException: if the initial reset request fails.
        KeyError: if the reset response or an observation lacks a required key.
    """
    # Reset
    reset_result = env_reset(task)
    obs = reset_result["observation"]
    task_desc = reset_result["task_description"]

    print(f"[START] task={task} env=db-schema-migration model={MODEL_NAME}", flush=True)

    rewards = []
    step = 0
    done = False
    final_score = 0.0
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    while not done:
        step += 1
        user_msg = build_user_prompt(obs, task_desc)
        messages.append({"role": "user", "content": user_msg})

        # Get action from LLM. Any failure (API error, unparsable reply)
        # is converted into a "done" action so the episode ends cleanly.
        try:
            raw = call_llm(messages)
            action = parse_action(raw)
            messages.append({"role": "assistant", "content": raw})
        except Exception as e:
            action = {"operation": "done", "table": "", "reason": f"parse error: {e}"}
            messages.append({"role": "assistant", "content": json.dumps(action)})

        action_str = f"{action.get('operation')}({action.get('table','')}.{action.get('column','') or action.get('new_name','')})"

        # Step env. A failed step ends the episode with a small penalty.
        try:
            result = env_step(action)
            reward = result["reward"]
            done = result["done"]
            obs = result["observation"]
            error = result.get("error") or "null"
            # ``info`` may be absent or null; both mean "no extra info".
            info = result.get("info") or {}
            final_score = _coerce_score(
                info.get("final_score", info.get("partial_score", 0.0)),
                final_score,
            )
        except Exception as e:
            reward = -0.1
            done = True
            error = str(e)
            final_score = 0.0

        rewards.append(reward)
        done_str = "true" if done else "false"
        print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={done_str} error={error}", flush=True)

        # Safety net: never keep calling the LLM once the environment's own
        # step accounting says the budget is spent, even if it did not set
        # ``done``.
        if not done and _budget_exhausted(obs):
            break

    # Get final score from state (best effort: keep the last known score if
    # the state endpoint is unavailable or returns something unusable).
    try:
        state_score = env_state().get("score")
    except Exception as e:
        print(f"warning: could not read final state: {e}", file=sys.stderr, flush=True)
    else:
        final_score = _coerce_score(state_score, final_score)

    success = final_score >= SUCCESS_THRESHOLD
    success_str = "true" if success else "false"
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(f"[END] success={success_str} steps={step} score={final_score:.4f} rewards={rewards_str}", flush=True)

    return final_score


if __name__ == "__main__":
    task = sys.argv[1] if len(sys.argv) > 1 else TASK
    score = run_episode(task)
    sys.exit(0 if score >= SUCCESS_THRESHOLD else 1)