Spaces:
Configuration error
Configuration error
| """ | |
inference.py — DB Schema Migration baseline agent.

Reads env vars:
    API_BASE_URL (default: https://api-inference.huggingface.co/v1)
    MODEL_NAME (default: meta-llama/Llama-3.1-8B-Instruct)
    HF_TOKEN or API_KEY
    IMAGE_NAME or LOCAL_IMAGE_NAME (for from_docker_image; default: db-schema-migration:latest)
    ENV_URL (default: http://localhost:7860)
    TASK (default: easy; the first command-line argument takes precedence)

STDOUT format:
    [START] task=<task_name> env=db-schema-migration model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...>

The process exits with status 0 when the final score is at least 0.8, otherwise 1.
| """ | |
| import os | |
| import sys | |
| import json | |
| import requests | |
| from openai import OpenAI | |
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Every setting comes from the environment; the second argument of each lookup
# is the fallback used when the variable is unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
# HF_TOKEN wins whenever it is set to a non-empty value; otherwise API_KEY,
# and finally a placeholder string.
# NOTE(review): with no credentials at all the placeholder still lets the client
# be constructed; presumably authentication only fails on the first request.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "hf_placeholder")
# IMAGE_NAME takes precedence over LOCAL_IMAGE_NAME when non-empty.
# NOTE(review): not referenced by the code visible here — confirm it is still needed.
IMAGE_NAME = os.environ.get("IMAGE_NAME") or os.environ.get("LOCAL_IMAGE_NAME", "db-schema-migration:latest")
# Base URL of the environment server exposing /reset, /step and /state.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
# Task name used when none is given on the command line.
TASK = os.environ.get("TASK", "easy")

# Shared OpenAI-compatible client, created once at import time.
client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
# ---------------------------------------------------------------------------
# Env helpers
# ---------------------------------------------------------------------------
def env_reset(task: str) -> dict:
    """Start a new episode for ``task`` via ``POST {ENV_URL}/reset``.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        requests.RequestException: on connection problems or after the 30 s timeout.
    """
    payload = {"task": task}
    response = requests.post(f"{ENV_URL}/reset", json=payload, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(action: dict) -> dict:
    """Send one ``action`` to ``POST {ENV_URL}/step`` and return the decoded JSON reply.

    The action dict is used unchanged as the JSON request body.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        requests.RequestException: on connection problems or after the 30 s timeout.
    """
    url = f"{ENV_URL}/step"
    response = requests.post(url, json=action, timeout=30)
    response.raise_for_status()
    return response.json()
def env_state() -> dict:
    """Fetch the environment's current state via ``GET {ENV_URL}/state``.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        requests.RequestException: on connection problems or after the 30 s timeout.
    """
    url = f"{ENV_URL}/state"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.json()
# ---------------------------------------------------------------------------
# Schema pretty-printer
# ---------------------------------------------------------------------------
def format_schema(tables: list) -> str:
    """Render a list of table dicts as plain text for the LLM prompt.

    Each table contributes a ``TABLE: <name>`` line followed by one
    `` - <column> <data_type>`` line per entry in its ``columns`` list.
    A column line gets a `` [PK]`` suffix when ``primary_key`` is truthy, and a
    `` [FK -> <target>]`` suffix when ``foreign_key`` is truthy.
    Lines are joined with newlines (no trailing newline); an empty list gives ``""``.
    """

    def describe(col: dict) -> str:
        suffix = ""
        if col.get("primary_key"):
            suffix += " [PK]"
        if col.get("foreign_key"):
            suffix += f" [FK -> {col['foreign_key']}]"
        return f" - {col['name']} {col['data_type']}{suffix}"

    rendered = []
    for table in tables:
        rendered.append(f"TABLE: {table['name']}")
        rendered.extend(describe(col) for col in table["columns"])
    return "\n".join(rendered)
# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------
# System message placed first in every episode's chat history. It defines the
# reply contract: exactly one JSON object per turn. The object's keys are
# operation, table, column, new_name, data_type, reference_table,
# reference_column and reason. The parsed object is posted as-is to the
# environment's /step endpoint.
# NOTE(review): the schema has no field naming which columns normalize_table
# should extract — confirm against the environment's action model.
# NOTE(review): "Think step by step" may lead the model to write reasoning before
# the JSON; confirm the reply parser tolerates that.
SYSTEM_PROMPT = """You are a database migration expert agent operating in an RL environment.
You will be given the current database schema and a list of requirements.
Your job is to decide the NEXT single migration action to take.
Available operations:
- rename_table: rename an existing table
- rename_column: rename a column in a table
- add_column: add a new column to a table
- drop_column: remove a column from a table
- change_type: change a column's data type
- add_foreign_key: add a foreign key constraint
- normalize_table: extract a new table from a denormalized table (hard task)
- done: signal you are finished
Respond with ONLY valid JSON matching this schema, nothing else:
{
"operation": "<operation_name>",
"table": "<current_table_name>",
"column": "<column_name_or_null>",
"new_name": "<new_name_or_null>",
"data_type": "<type_or_null>",
"reference_table": "<ref_table_or_null>",
"reference_column": "<ref_col_or_null>",
"reason": "<one sentence why>"
}
Rules:
- Do ONE action per response
- If all requirements are met, use {"operation": "done", "table": "", "reason": "all done"}
- Never repeat a successful action
- Think step by step: rename tables first, then columns, then types, then add/FK
"""
def build_user_prompt(obs: dict, task_desc: str) -> str:
    """Build the per-step user message from the current observation.

    Args:
        obs: Observation dict. Reads ``current_schema`` (rendered with
            ``format_schema``), ``target_requirements``, ``max_steps`` and
            ``step_count``. Optionally reads ``violations`` and ``steps_taken``;
            each ``steps_taken`` entry needs ``operation``, ``table`` and a
            numeric ``reward``, and may carry ``column``.
        task_desc: Free-text task description shown at the top of the message.

    Returns:
        The prompt text. It contains, in order: the task, the schema, the
        numbered requirements, any violations, up to the three most recent
        actions, the steps remaining and the closing instruction.
    """
    schema_str = format_schema(obs["current_schema"])
    reqs = "\n".join(f" {i}. {r}" for i, r in enumerate(obs["target_requirements"], start=1))
    violations = obs.get("violations", [])
    steps_left = obs["max_steps"] - obs["step_count"]
    parts = [
        f"TASK: {task_desc}",
        f"\nCURRENT SCHEMA:\n{schema_str}",
        f"\nREQUIREMENTS:\n{reqs}",
    ]
    if violations:
        parts.append("\nVIOLATIONS (fix these!):\n" + "\n".join(f" - {v}" for v in violations))
    recent = (obs.get("steps_taken") or [])[-3:]
    if recent:
        # `or ''` rather than a .get() default: a step recorded with column=None
        # (e.g. a table-level operation) must not be shown to the model as "table.None".
        hist = "\n".join(
            f" - {s['operation']} on {s['table']}.{s.get('column') or ''} reward={s['reward']:.2f}"
            for s in recent
        )
        # Label with the real count: early in an episode fewer than three steps exist.
        parts.append(f"\nLAST {len(recent)} ACTIONS:\n{hist}")
    parts.append(f"\nSteps remaining: {steps_left}")
    parts.append("\nWhat is your NEXT single action? Respond with JSON only.")
    return "\n".join(parts)
def call_llm(messages: list, *, max_tokens: int = 300, temperature: float = 0.1) -> str:
    """Send ``messages`` to the chat-completions endpoint and return the reply text.

    Uses the module-level ``client`` and ``MODEL_NAME``. The sampling settings
    are keyword-only and default to the values the agent has always used.

    Args:
        messages: Chat history in OpenAI ``{"role": ..., "content": ...}`` form.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Returns:
        The first choice's message content with surrounding whitespace stripped.
        Returns an empty string when the API returns no content; ``content`` is
        optional in the chat-completion message schema, and calling ``.strip()``
        on ``None`` would raise AttributeError.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    return (content or "").strip()
def _first_json_object(text: str):
    """Return the first JSON object embedded anywhere in ``text``, or None.

    Tries each ``{`` in turn, so leading prose or a stray brace before the real
    payload does not prevent a match.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # The slice starts with "{", so a successful decode yields a dict.
            obj, _end = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return obj
    return None


def parse_action(text: str) -> dict:
    """Parse an LLM reply into an action dict.

    Accepts three forms:
      * bare JSON;
      * JSON inside a Markdown code fence, with or without a language tag and
        whether or not the closing fence is on its own line;
      * JSON surrounded by prose.
    In the last two cases the first decodable JSON object in the text is used.

    Raises:
        ValueError: if no JSON object can be decoded (``json.JSONDecodeError``
            is a subclass), or if the reply decodes to a non-object value such
            as a list. Callers use the result as a dict.
    """
    text = text.strip()
    # Strip markdown fences if present
    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as err:
        # Trailing fence on the JSON line, or prose around the payload.
        parsed = _first_json_object(text)
        if parsed is None:
            raise err
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed
# ---------------------------------------------------------------------------
# Main episode loop
# ---------------------------------------------------------------------------
def run_episode(task: str = TASK):
    """Run one episode of ``task`` against the environment and print its log lines.

    Prints to stdout one ``[START]`` line, then one ``[STEP]`` line per action,
    then one ``[END]`` line. Each step works as follows:
      1. The LLM sees the full chat history plus a fresh prompt built from the
         latest observation.
      2. The reply is parsed into an action dict and posted to ``/step``.
      3. If the LLM call fails, or its reply is not a JSON object, a ``done``
         action is posted instead (presumably ending the episode; the loop
         itself only stops when the environment reports ``done``). If the
         ``/step`` call fails, the loop stops with reward -0.1.

    Args:
        task: Task name passed to ``/reset``. Defaults to ``TASK``.

    Returns:
        The final score. This is ``score`` from ``/state`` when it is numeric;
        otherwise the last step's ``info["final_score"]`` or
        ``info["partial_score"]``; otherwise 0.0.

    Raises:
        Whatever ``env_reset`` raises. The reset is not guarded, so in that case
        no ``[START]``/``[END]`` lines are printed.
    """

    def one_line(value) -> str:
        # Collapse every whitespace run (newlines included) into one space so a
        # value can never split a log record across several stdout lines.
        return " ".join(str(value).split())

    # Reset
    reset_result = env_reset(task)
    obs = reset_result["observation"]
    task_desc = reset_result["task_description"]
    print(f"[START] task={task} env=db-schema-migration model={MODEL_NAME}", flush=True)

    rewards = []
    step = 0
    done = False
    final_score = 0.0
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    while not done:
        step += 1
        user_msg = build_user_prompt(obs, task_desc)
        messages.append({"role": "user", "content": user_msg})

        # Get action from LLM. Any failure (API error, unparseable reply, or a
        # JSON value that is not an object) is replaced by a "done" action.
        try:
            raw = call_llm(messages)
            action = parse_action(raw)
            if not isinstance(action, dict):
                # e.g. the model answered with a JSON array; action.get() below would crash.
                raise ValueError(f"expected a JSON object, got {type(action).__name__}")
            messages.append({"role": "assistant", "content": raw})
        except Exception as e:
            action = {"operation": "done", "table": "", "reason": f"parse error: {e}"}
            messages.append({"role": "assistant", "content": json.dumps(action)})

        # `or ""` rather than a .get() default: keys that are present but null
        # must not be rendered as the text "None".
        target = action.get("column") or action.get("new_name") or ""
        action_str = one_line(f"{action.get('operation')}({action.get('table') or ''}.{target})")

        # Step env
        try:
            result = env_step(action)
            # float() makes a null/non-numeric reward fail here, inside the try,
            # instead of crashing the format spec in the print below.
            reward = float(result["reward"])
            done = result["done"]
            obs = result["observation"]
            error = result.get("error") or "null"
            info = result.get("info") or {}
            final_score = info.get("final_score", info.get("partial_score", 0.0))
        except Exception as e:
            # Environment unreachable or malformed reply: penalise and stop.
            reward = -0.1
            done = True
            # Some exceptions stringify to "", which would leave "error=" blank.
            error = str(e) or type(e).__name__
            final_score = 0.0

        rewards.append(reward)
        done_str = "true" if done else "false"
        error_str = one_line(error) or "null"
        print(
            f"[STEP] step={step} action={action_str} reward={reward:.2f} "
            f"done={done_str} error={error_str}",
            flush=True,
        )

    # Prefer the score reported by /state. Keep the last step's score if that
    # call fails or does not return a number (best effort: never crash here).
    try:
        state_score = env_state().get("score")
    except Exception:
        state_score = None
    if isinstance(state_score, (int, float)):
        final_score = state_score
    elif not isinstance(final_score, (int, float)):
        # A step may have reported a null score; treat it as 0 rather than crash formatting.
        final_score = 0.0

    success = final_score >= 0.8
    success_str = "true" if success else "false"
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(f"[END] success={success_str} steps={step} score={final_score:.4f} rewards={rewards_str}", flush=True)
    return final_score
if __name__ == "__main__":
    # An optional first command-line argument overrides the TASK setting.
    cli_args = sys.argv[1:]
    task = cli_args[0] if cli_args else TASK
    score = run_episode(task)
    # Exit status 0 only when the episode reaches the 0.8 success threshold.
    sys.exit(0 if score >= 0.8 else 1)