File size: 14,393 Bytes
1b42f19
 
 
 
6a32325
1b42f19
 
6a32325
 
 
 
 
 
 
 
 
1b42f19
 
 
 
 
 
 
 
 
 
 
6a32325
1b42f19
 
 
 
 
 
 
6a32325
53f659d
 
6a32325
53f659d
1b42f19
6a32325
71fa486
1b42f19
6a32325
 
 
 
 
 
 
 
05c4751
 
 
 
 
 
 
6a32325
 
71fa486
1b42f19
05c4751
71fa486
 
 
 
 
 
 
 
 
 
 
6a32325
1b42f19
05c4751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b42f19
 
 
6a32325
1b42f19
 
 
 
 
 
 
 
6a32325
1b42f19
 
 
 
6a32325
 
 
 
 
 
 
 
 
 
 
 
 
 
1b42f19
 
 
 
 
 
 
 
 
 
 
6a32325
 
1b42f19
 
 
6a32325
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a32325
1b42f19
 
 
 
 
 
 
 
6a32325
 
1b42f19
6a32325
 
 
1b42f19
6a32325
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
71fa486
1b42f19
 
6a32325
 
71fa486
1b42f19
 
 
 
6a32325
 
 
 
 
71fa486
 
6a32325
05c4751
 
 
 
 
 
 
 
 
 
1b42f19
 
6a32325
1b42f19
 
 
 
71fa486
1b42f19
 
 
05c4751
 
1b42f19
 
 
 
 
 
 
 
 
 
 
 
6a32325
 
 
1b42f19
6a32325
 
1b42f19
 
05c4751
 
 
 
 
 
 
 
 
 
 
 
1b42f19
05c4751
 
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a32325
1b42f19
 
 
 
 
 
 
 
 
 
6a32325
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
05c4751
1b42f19
 
6a32325
05c4751
6a32325
71fa486
05c4751
1b42f19
 
05c4751
1b42f19
05c4751
6a32325
 
1b42f19
 
05c4751
1b42f19
05c4751
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a32325
1b42f19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71fa486
1b42f19
71fa486
1b42f19
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
#!/usr/bin/env python3
"""
Baseline Inference Script for SQL Migration Environment.

Runs all 7 migration tasks sequentially using an LLM via OpenAI-compatible API.
Outputs structured [START]/[STEP]/[END] format for automated evaluation.

Fixes Applied:
- D1: Task description injected into system prompt
- D2: Hardcoded system prompt traps removed (no more audit_log/INTEGER traps)
- D3: Data discovery rule added (agent runs SELECT before DDL)
- D4: Submit guard added (agent must verify before submitting)
- D5: Context window bloat fixed (schema not repeated every step)
- D6: Parse error counter tracks consecutive errors only
- D7: response_format JSON mode with fallback

Usage:
    python inference.py

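Environment Variables:
    API_BASE_URL: LLM inference endpoint (default: HF router)
    MODEL_NAME: Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
    OPENAI_API_KEY, HF_TOKEN, or API_KEY: Authentication token (the first one set, in that order, is used)
    ENV_BASE_URL: Environment server URL (default: http://localhost:7860)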
"""

import json
import os
import re
import sys
import time
import traceback

# Server URL for the environment (overridable via ENV_BASE_URL).
# NOTE(review): not referenced by any function in the visible part of this
# file — run_task_local() instantiates the environment in-process instead.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")

# LLM Configuration: endpoint and model passed to the OpenAI-compatible client
# in call_llm(); both can be overridden through environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# Credential precedence: OPENAI_API_KEY, then HF_TOKEN, then API_KEY.
# Because of the `or` chain an empty-string value is skipped like an unset one.
API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN or os.getenv("API_KEY")

# --- D2: Cleaned system prompt — no hardcoded table names or type traps ---
# This is a str.format() template with two placeholders, {task_description}
# and {target_ddl}, which run_task_local() fills in once per task.
# The doubled braces on the last line are format() escapes: they render as a
# literal single-brace JSON example in the final prompt.
SYSTEM_PROMPT_TEMPLATE = """You are an autonomous SQLite database migration engine. You receive the current schema and a target schema. Write SQL to transform the current state to the target state without losing row data.

TASK OBJECTIVE:
{task_description}

CRITICAL SQLite-specific rules (violations cause immediate errors):
1. SQLite does NOT support ALTER TABLE ADD CONSTRAINT, ALTER COLUMN, or ADD PRIMARY KEY.
2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
4. Execute exactly ONE SQL statement per step.
5. If a table already exists, you MUST drop it before recreating it (e.g., DROP TABLE IF EXISTS users_new).
6. SQLite strictly expects `INSERT INTO tbl VALUES (...)`, not `VALUE (...)`. Ensure column counts match exactly.
7. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
8. For orphaned FK rows: check the TARGET SCHEMA for the anomaly/issues table name. Log invalid records there before dropping.
9. For text currency (e.g. '$90,000'): strip '$' and ',' then cast to the target type (INTEGER/REAL).
10. IMPORTANT: Before writing any DDL, execute SELECT * FROM tablename LIMIT 5 to inspect the data format.
11. Do NOT set submit_final to true until you run SELECT COUNT(*) and verify data matches the task.

TARGET SCHEMA (achieve this exactly):
{target_ddl}

Respond ONLY with a valid JSON object. Do not use markdown backticks (```json). No conversational text.
{{"sql_command": "your SQL here", "reasoning": "why", "submit_final": false}}"""

# Task identifiers in execution order. main() iterates this list and
# run_task_local() uses each name both as the seeds.TASKS key and as the
# environment's task_name.
ALL_TASKS = [
    "column-restructure",
    "soft-delete-restoration",
    "table-normalization",
    "schema-version-merge",
    "multi-entity-extraction",
    "cascade-migration",
    "dual-source-consolidation",
]
MAX_PARSE_ERRORS = 5  # Consecutive parse errors before giving up
# migration_progress at or above this value triggers a forced final submit.
AUTO_SUBMIT_THRESHOLD = 0.95
MAX_HISTORY_PAIRS = 4  # Keep maximum of 4 user/assistant turn pairs


def build_messages(system_prompt: str, history: list, current_obs_msg: dict) -> list:
    """
    Assemble the chat payload for one LLM call with a bounded history window.

    The result is always: one system message built from ``system_prompt``,
    then at most ``MAX_HISTORY_PAIRS * 2`` of the most recent non-system
    entries from ``history``, then ``current_obs_msg`` as the final message.
    ``history`` itself is not modified.
    """
    # Drop any system entries: the system prompt is re-added explicitly below.
    turns = [entry for entry in history if entry["role"] != "system"]

    # Two messages (user + assistant) per retained turn pair.
    window = MAX_HISTORY_PAIRS * 2
    recent = turns[-window:] if len(turns) > window else turns

    return [
        {"role": "system", "content": system_prompt},
        *recent,
        current_obs_msg,
    ]


def call_llm(messages: list, timeout: int = 90) -> str:
    """
    Call the LLM API with JSON mode fallback.

    A request is first made with ``response_format={"type": "json_object"}``.
    If that attempt fails for any reason (including an empty reply), the cause
    is reported on stderr and the same request is retried once without
    ``response_format``.

    Args:
        messages: Chat messages in OpenAI format (dicts with role/content).
        timeout: Per-request timeout in seconds passed to the client.

    Returns:
        The stripped text content of the first choice.

    Raises:
        TimeoutError: If the plain-mode retry also fails. The original
            exception is chained as ``__cause__``. (The name is historical:
            every API failure is reported this way, not only timeouts, and
            callers catch exactly this type.)
    """
    from openai import OpenAI

    client = OpenAI(
        base_url=API_BASE_URL,
        api_key=API_KEY,
        timeout=timeout,
    )

    def _complete(**extra) -> str:
        """Send one chat completion request and return its stripped text."""
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0,
            max_tokens=1024,
            **extra,
        )
        content = response.choices[0].message.content
        if content is None:
            # Content can be absent (e.g. refusal / tool-call only replies);
            # surface that explicitly instead of an AttributeError on .strip().
            raise ValueError("LLM returned a message with no content")
        return content.strip()

    # --- D7: Try JSON mode first, fallback to plain ---
    try:
        return _complete(response_format={"type": "json_object"})
    except Exception as json_mode_error:
        # Best-effort: not every backend supports response_format, so retry in
        # plain mode — but leave a trace so real failures are diagnosable.
        print(
            f"[WARN] JSON-mode request failed ({str(json_mode_error)[:200]}); "
            "retrying without response_format",
            file=sys.stderr,
            flush=True,
        )

    # Fallback: plain text mode
    try:
        return _complete()
    except Exception as e:
        raise TimeoutError(f"LLM API error: {e}") from e


# Single-character JSON escapes handled by the manual fallback decoder.
_JSON_SIMPLE_ESCAPES = {
    '"': '"',
    "\\": "\\",
    "/": "/",
    "b": "\b",
    "f": "\f",
    "n": "\n",
    "r": "\r",
    "t": "\t",
}


def _decode_json_string_fragment(fragment: str) -> str:
    """Decode the inside of a JSON string literal (text between the quotes)."""
    try:
        # strict=False tolerates raw control characters (e.g. real newlines)
        # that models often emit inside the SQL string.
        return json.loads(f'"{fragment}"', strict=False)
    except json.JSONDecodeError:
        pass
    # Invalid escape somewhere (e.g. \'): decode what we recognise in a single
    # left-to-right pass so an escaped backslash is never re-interpreted, and
    # leave unknown escapes untouched.
    return re.sub(
        r"\\(.)",
        lambda m: _JSON_SIMPLE_ESCAPES.get(m.group(1), m.group(0)),
        fragment,
        flags=re.DOTALL,
    )


def parse_action(raw_text: str) -> dict:
    """
    Parse LLM output into an action dict.

    Handles: raw JSON, markdown-wrapped JSON, <think>...</think> blocks,
    escaped quotes in SQL, and truncated output recovery.

    Args:
        raw_text: The raw completion text returned by the model.

    Returns:
        The decoded JSON object. When only the ``sql_command`` string could be
        recovered from a malformed response, a dict with that command, a fixed
        ``reasoning`` note and ``submit_final`` set to False.

    Raises:
        ValueError: If no JSON object and no ``sql_command`` string can be
            found (valid JSON that is not an object, such as a list or a bare
            string, is rejected as well).
    """
    text = raw_text.strip()

    # Strip <think>...</think> blocks (Qwen3, DeepSeek-R1); the second pattern
    # removes an unterminated <think> left by a truncated response.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    text = re.sub(r"<think>.*$", "", text, flags=re.DOTALL).strip()

    # Strip markdown code block fences
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.split("\n") if not line.strip().startswith("```")
        ).strip()

    # Candidate 1: the whole text. Candidate 2: the outermost {...} span.
    candidates = [text]
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        candidates.append(text[start:end])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers use dict methods on the result, so only accept JSON objects.
        if isinstance(parsed, dict):
            return parsed

    # Last resort: pull out just the sql_command string literal; the pattern
    # accepts backslash-escaped characters (including \") inside the string.
    sql_match = re.search(r'"sql_command"\s*:\s*"((?:[^"\\]|\\.)*)"', text)
    if sql_match:
        return {
            "sql_command": _decode_json_string_fragment(sql_match.group(1)),
            "reasoning": "auto-extracted from malformed response",
            "submit_final": False,
        }

    raise ValueError(f"Could not parse JSON from LLM response: {text[:200]}")


def run_task_local(task_name: str) -> dict:
    """
    Run a single task using a local environment instance (no server needed).

    Drives one episode: resets the environment, then repeatedly asks the LLM
    for a JSON action, executes it with ``env.step`` and feeds the result back
    as the next user message. The loop ends when the environment reports
    ``done``, the step budget is used up, the LLM call fails, or
    ``MAX_PARSE_ERRORS`` consecutive responses cannot be parsed. Progress is
    printed to stdout as ``[START]`` / ``[STEP]`` / ``[END]`` lines.

    Args:
        task_name: Key into ``seeds.TASKS``; also passed to the environment.

    Returns:
        A dict with ``task_name``, ``score`` (the last observed
        ``migration_progress``), ``steps`` (1-based loop index of the last
        executed step) and ``rewards`` (rewards of the executed steps).

    Raises:
        Any exception raised by the environment, the ``seeds.TASKS`` lookup or
        prompt formatting. The environment is closed before it propagates and
        no ``[END]`` line is printed in that case.
    """
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from server.environment import DbMigrationEnvironment
    from models import MigrationAction
    import seeds

    env = DbMigrationEnvironment(task_name=task_name)
    try:
        task_config = seeds.TASKS[task_name]
        task_max_steps = task_config.get("max_steps", 20)

        print(f"[START] task={task_name} env=sql-migration-agent model={MODEL_NAME}", flush=True)

        obs = env.reset()

        # --- D1: Inject task description into system prompt ---
        task_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
            task_description=task_config["description"],
            target_ddl=obs.target_schema_sql,
        )

        # The user message that will be sent on the next LLM call. It starts
        # as the initial observation and is replaced by feedback after each turn.
        pending_msg = {
            "role": "user",
            "content": (
                f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
                f"Status: {obs.last_execution_result}\n"
                f"Migration progress: {obs.migration_progress:.2f}\n\n"
                f"Start by inspecting the source data with SELECT queries, then begin the migration."
            )
        }
        # Completed user/assistant turns only; the system prompt is supplied
        # separately to build_messages on every call.
        history = []

        rewards_list = []
        consecutive_parse_errors = 0  # D6: Track consecutive only
        final_score = 0.0
        steps_taken = 0
        done = False

        for step in range(task_max_steps):
            if done:
                break

            # --- D5: Context window fix: Aggressively prune history via build_messages ---
            messages = build_messages(task_system_prompt, history, pending_msg)

            try:
                raw_response = call_llm(messages)
            except TimeoutError as e:
                error_msg = str(e)[:100]
                print(f"[STEP] step={step+1} action=API_TIMEOUT reward=0.00 done=true error={error_msg}", flush=True)
                done = True
                break

            # Parse the action
            try:
                action_dict = parse_action(raw_response)
                consecutive_parse_errors = 0  # D6: Reset on success
            except ValueError:
                consecutive_parse_errors += 1
                print(f"[STEP] step={step+1} action=PARSE_ERROR reward=0.00 done=false error=parse_error", flush=True)
                if consecutive_parse_errors >= MAX_PARSE_ERRORS:
                    print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
                    done = True
                    break

                # CRITICAL: Strip <think> tags before appending to history to prevent 413 Context OOM
                stripped_response = re.sub(r"<think>.*?</think>", "", raw_response, flags=re.DOTALL).strip()
                stripped_response = re.sub(r"<think>.*$", "", stripped_response, flags=re.DOTALL).strip()
                # If it's still huge, truncate it to 500 chars to save context
                if len(stripped_response) > 500:
                    stripped_response = stripped_response[:500] + "... [TRUNCATED DUE TO PARSE ERROR]"

                history.append(pending_msg)  # The prompt we sent
                history.append({"role": "assistant", "content": stripped_response})  # The stripped response

                pending_msg = {
                    "role": "user",
                    "content": 'ERROR: Your response was not a valid JSON object. Do not use markdown blocks. Respond strictly with: {"sql_command": "...", "reasoning": "...", "submit_final": false}'
                }
                continue

            # Build the MigrationAction
            try:
                action = MigrationAction(
                    sql_command=action_dict.get("sql_command", ""),
                    reasoning=action_dict.get("reasoning", ""),
                    submit_final=action_dict.get("submit_final", False),
                )
            except Exception as e:
                print(f"[STEP] step={step+1} action=INVALID_ACTION reward=0.00 done=false error={str(e)[:50]}", flush=True)
                # Requests use temperature 0.0, so re-sending the same prompt
                # would most likely reproduce the same rejected action until
                # the step budget runs out. Record the turn and tell the model
                # what was wrong instead.
                history.append(pending_msg)
                history.append({"role": "assistant", "content": json.dumps(action_dict)})
                pending_msg = {
                    "role": "user",
                    "content": (
                        "ERROR: Your action was rejected before execution: "
                        + str(e)[:200]
                        + '. Respond strictly with: {"sql_command": "...", "reasoning": "...", "submit_final": false}'
                    ),
                }
                continue

            # Execute the action
            obs = env.step(action)
            steps_taken = step + 1
            step_reward = obs.reward if obs.reward is not None else 0.0
            rewards_list.append(step_reward)
            final_score = obs.migration_progress
            done = obs.done

            # AUTO-SUBMIT: If we reached near-perfect score, force submit.
            # The extra submit step is not counted in steps_taken and its
            # reward is not appended to rewards_list; only the score is refreshed.
            if final_score >= AUTO_SUBMIT_THRESHOLD and not done:
                done = True
                submit_action = MigrationAction(
                    sql_command="SELECT 1",
                    reasoning="Migration complete — auto-submitting",
                    submit_final=True,
                )
                obs = env.step(submit_action)
                final_score = obs.migration_progress

            # Log
            sql_abbrev = action.sql_command[:50].replace("\n", " ")
            if len(action.sql_command) > 50:
                sql_abbrev += "..."
            # A missing, None or empty error is reported as "null"; anything
            # else is stringified so slicing cannot fail on a non-str value.
            raw_error = obs.metadata.get("error") if obs.metadata else None
            error_str = str(raw_error)[:80] if raw_error else "null"
            print(
                f"[STEP] step={steps_taken} action={sql_abbrev} "
                f"reward={step_reward:.2f} done={'true' if done else 'false'} "
                f"error={error_str}",
                flush=True,
            )

            # Add to conversation history
            history.append(pending_msg)
            history.append({"role": "assistant", "content": json.dumps(action_dict)})

            # --- D5: Lean feedback — NO schema repetition ---
            feedback_text = (
                f"EXECUTION RESULT: {obs.last_execution_result}\n"
                f"Progress: {obs.migration_progress:.2f}"
                f"\nSchema Diff (Missing/Extra constraints vs Target):\n{obs.schema_diff}"
            )
            if done:
                feedback_text += "\n\nEpisode complete."
            elif obs.migration_progress >= 0.9:
                feedback_text += (
                    "\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
                    "and compare to your expectations. If everything matches, set submit_final to true."
                )
            else:
                feedback_text += "\n\nContinue the migration. Write your next SQL command."

            pending_msg = {"role": "user", "content": feedback_text}

        # Print END
        rewards_str = ",".join(f"{r:.2f}" for r in rewards_list) if rewards_list else "0.00"
        success = "true" if final_score >= 0.8 else "false"
        print(
            f"[END] success={success} steps={steps_taken} "
            f"score={final_score:.2f} rewards={rewards_str}",
            flush=True,
        )
    finally:
        # Release the environment even when the episode aborts with an error.
        env.close()

    return {
        "task_name": task_name,
        "score": final_score,
        "steps": steps_taken,
        "rewards": rewards_list,
    }


def main():
    """
    Run every task in ALL_TASKS in order and print a one-line summary.

    Exits with status 1 (after a message on stderr) when no API key is
    configured. A task that raises is reported on stderr with its traceback
    and scored 0.0; the remaining tasks still run. The final ``[SUMMARY]``
    line on stdout lists each task's score and their average.
    """
    if not API_KEY:
        print("WARNING: No API key found. Set HF_TOKEN or API_KEY.", file=sys.stderr)
        sys.exit(1)

    score_by_task = {}
    for name in ALL_TASKS:
        try:
            outcome = run_task_local(name)
            score_by_task[name] = outcome["score"]
        except Exception as exc:
            # Top-level boundary: report the failure and keep going.
            print(f"[ERROR] task={name} error={str(exc)[:200]}", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            score_by_task[name] = 0.0

    # Summary
    all_scores = list(score_by_task.values())
    if all_scores:
        avg = sum(all_scores) / len(all_scores)
    else:
        avg = 0.0
    scores_str = " ".join(
        f"{task}={score:.2f}" for task, score in score_by_task.items()
    )
    print(f"[SUMMARY] {scores_str} avg={avg:.2f}", flush=True)


# Script entry point: run the full task suite only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()