Spaces:
Sleeping
Sleeping
fixed errors v2
Browse files- .gitignore +5 -0
- Dockerfile +3 -2
- inference.py +61 -25
- models.py +11 -1
- server/app.py +6 -0
- server/environment.py +56 -14
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
uv.lock
|
| 4 |
+
.env
|
| 5 |
+
.venv/
|
Dockerfile
CHANGED
|
@@ -9,9 +9,10 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|
| 9 |
# Copy all project files
|
| 10 |
COPY . .
|
| 11 |
|
| 12 |
-
# Set
|
| 13 |
ENV PYTHONPATH="/app:$PYTHONPATH"
|
| 14 |
-
|
|
|
|
| 15 |
# Health check
|
| 16 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 17 |
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
|
|
|
|
| 9 |
# Copy all project files
|
| 10 |
COPY . .
|
| 11 |
|
| 12 |
+
# Set environment variables for docker and huggingface
|
| 13 |
ENV PYTHONPATH="/app:$PYTHONPATH"
|
| 14 |
+
ENV PYTHONUNBUFFERED=1
|
| 15 |
+
ENV PORT=7860
|
| 16 |
# Health check
|
| 17 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 18 |
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
|
inference.py
CHANGED
|
@@ -50,17 +50,18 @@ CRITICAL SQLite-specific rules (violations cause immediate errors):
|
|
| 50 |
2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
|
| 51 |
3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
|
| 52 |
4. Execute exactly ONE SQL statement per step.
|
| 53 |
-
5.
|
| 54 |
-
6.
|
| 55 |
-
7. For
|
| 56 |
-
8.
|
| 57 |
-
9.
|
| 58 |
-
10.
|
|
|
|
| 59 |
|
| 60 |
TARGET SCHEMA (achieve this exactly):
|
| 61 |
{target_ddl}
|
| 62 |
|
| 63 |
-
Respond ONLY with valid JSON
|
| 64 |
{{"sql_command": "your SQL here", "reasoning": "why", "submit_final": false}}"""
|
| 65 |
|
| 66 |
ALL_TASKS = [
|
|
@@ -74,6 +75,26 @@ ALL_TASKS = [
|
|
| 74 |
]
|
| 75 |
MAX_PARSE_ERRORS = 5 # Consecutive parse errors before giving up
|
| 76 |
AUTO_SUBMIT_THRESHOLD = 0.95
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
def call_llm(messages: list, timeout: int = 90) -> str:
|
|
@@ -186,13 +207,16 @@ def run_task_local(task_name: str) -> dict:
|
|
| 186 |
history = [{"role": "system", "content": task_system_prompt}]
|
| 187 |
|
| 188 |
# Initial observation
|
| 189 |
-
initial_msg =
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
rewards_list = []
|
| 198 |
consecutive_parse_errors = 0 # D6: Track consecutive only
|
|
@@ -204,8 +228,8 @@ def run_task_local(task_name: str) -> dict:
|
|
| 204 |
if done:
|
| 205 |
break
|
| 206 |
|
| 207 |
-
# --- D5: Context window fix
|
| 208 |
-
messages =
|
| 209 |
|
| 210 |
try:
|
| 211 |
raw_response = call_llm(messages)
|
|
@@ -226,11 +250,21 @@ def run_task_local(task_name: str) -> dict:
|
|
| 226 |
print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
|
| 227 |
done = True
|
| 228 |
break
|
| 229 |
-
|
| 230 |
-
history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
"role": "user",
|
| 232 |
-
"content": 'ERROR: Your response was not valid JSON. Respond
|
| 233 |
-
}
|
| 234 |
continue
|
| 235 |
|
| 236 |
# Build the MigrationAction
|
|
@@ -278,24 +312,26 @@ def run_task_local(task_name: str) -> dict:
|
|
| 278 |
)
|
| 279 |
|
| 280 |
# Add to conversation history
|
|
|
|
| 281 |
history.append({"role": "assistant", "content": json.dumps(action_dict)})
|
| 282 |
|
| 283 |
# --- D5: Lean feedback — NO schema repetition ---
|
| 284 |
-
|
| 285 |
f"EXECUTION RESULT: {obs.last_execution_result}\n"
|
| 286 |
f"Progress: {obs.migration_progress:.2f}"
|
|
|
|
| 287 |
)
|
| 288 |
if done:
|
| 289 |
-
|
| 290 |
elif obs.migration_progress >= 0.9:
|
| 291 |
-
|
| 292 |
"\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
|
| 293 |
"and compare to your expectations. If everything matches, set submit_final to true."
|
| 294 |
)
|
| 295 |
else:
|
| 296 |
-
|
| 297 |
|
| 298 |
-
|
| 299 |
|
| 300 |
# Print END
|
| 301 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards_list) if rewards_list else "0.00"
|
|
|
|
| 50 |
2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
|
| 51 |
3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
|
| 52 |
4. Execute exactly ONE SQL statement per step.
|
| 53 |
+
5. If a table already exists, you MUST drop it before recreating it (e.g., DROP TABLE IF EXISTS users_new).
|
| 54 |
+
6. SQLite strictly expects `INSERT INTO tbl VALUES (...)`, not `VALUE (...)`. Ensure column counts match exactly.
|
| 55 |
+
7. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
|
| 56 |
+
8. For orphaned FK rows: check the TARGET SCHEMA for the anomaly/issues table name. Log invalid records there before dropping.
|
| 57 |
+
9. For text currency (e.g. '$90,000'): strip '$' and ',' then cast to the target type (INTEGER/REAL).
|
| 58 |
+
10. IMPORTANT: Before writing any DDL, execute SELECT * FROM tablename LIMIT 5 to inspect the data format.
|
| 59 |
+
11. Do NOT set submit_final to true until you run SELECT COUNT(*) and verify data matches the task.
|
| 60 |
|
| 61 |
TARGET SCHEMA (achieve this exactly):
|
| 62 |
{target_ddl}
|
| 63 |
|
| 64 |
+
Respond ONLY with a valid JSON object. Do not use markdown backticks (```json). No conversational text.
|
| 65 |
{{"sql_command": "your SQL here", "reasoning": "why", "submit_final": false}}"""
|
| 66 |
|
| 67 |
ALL_TASKS = [
|
|
|
|
| 75 |
]
|
| 76 |
MAX_PARSE_ERRORS = 5 # Consecutive parse errors before giving up
|
| 77 |
AUTO_SUBMIT_THRESHOLD = 0.95
|
| 78 |
+
MAX_HISTORY_PAIRS = 4 # Keep maximum of 4 user/assistant turn pairs
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_messages(system_prompt: str, history: list, current_obs_msg: dict) -> list:
|
| 82 |
+
"""
|
| 83 |
+
Build messages explicitly pruning history to avoid context bloat.
|
| 84 |
+
"""
|
| 85 |
+
system_msg = [{"role": "system", "content": system_prompt}]
|
| 86 |
+
|
| 87 |
+
# We only want assistant/user pairs. Filter out system msgs if any exist in history
|
| 88 |
+
filtered_history = [m for m in history if m["role"] != "system"]
|
| 89 |
+
|
| 90 |
+
# Keep only the last MAX_HISTORY_PAIRS * 2 messages
|
| 91 |
+
max_msgs = MAX_HISTORY_PAIRS * 2
|
| 92 |
+
if len(filtered_history) > max_msgs:
|
| 93 |
+
pruned_history = filtered_history[-max_msgs:]
|
| 94 |
+
else:
|
| 95 |
+
pruned_history = filtered_history
|
| 96 |
+
|
| 97 |
+
return system_msg + pruned_history + [current_obs_msg]
|
| 98 |
|
| 99 |
|
| 100 |
def call_llm(messages: list, timeout: int = 90) -> str:
|
|
|
|
| 207 |
history = [{"role": "system", "content": task_system_prompt}]
|
| 208 |
|
| 209 |
# Initial observation
|
| 210 |
+
initial_msg = {
|
| 211 |
+
"role": "user",
|
| 212 |
+
"content": (
|
| 213 |
+
f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
|
| 214 |
+
f"Status: {obs.last_execution_result}\n"
|
| 215 |
+
f"Migration progress: {obs.migration_progress:.2f}\n\n"
|
| 216 |
+
f"Start by inspecting the source data with SELECT queries, then begin the migration."
|
| 217 |
+
)
|
| 218 |
+
}
|
| 219 |
+
history = []
|
| 220 |
|
| 221 |
rewards_list = []
|
| 222 |
consecutive_parse_errors = 0 # D6: Track consecutive only
|
|
|
|
| 228 |
if done:
|
| 229 |
break
|
| 230 |
|
| 231 |
+
# --- D5: Context window fix: Aggressively prune history via build_messages ---
|
| 232 |
+
messages = build_messages(task_system_prompt, history, initial_msg)
|
| 233 |
|
| 234 |
try:
|
| 235 |
raw_response = call_llm(messages)
|
|
|
|
| 250 |
print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
|
| 251 |
done = True
|
| 252 |
break
|
| 253 |
+
|
| 254 |
+
# CRITICAL: Strip <think> tags before appending to history to prevent 413 Context OOM
|
| 255 |
+
stripped_response = re.sub(r"<think>.*?</think>", "", raw_response, flags=re.DOTALL).strip()
|
| 256 |
+
stripped_response = re.sub(r"<think>.*$", "", stripped_response, flags=re.DOTALL).strip()
|
| 257 |
+
# If it's still huge, truncate it to 500 chars to save context
|
| 258 |
+
if len(stripped_response) > 500:
|
| 259 |
+
stripped_response = stripped_response[:500] + "... [TRUNCATED DUE TO PARSE ERROR]"
|
| 260 |
+
|
| 261 |
+
history.append(initial_msg) # The prompt we sent
|
| 262 |
+
history.append({"role": "assistant", "content": stripped_response}) # The stripped response
|
| 263 |
+
|
| 264 |
+
initial_msg = {
|
| 265 |
"role": "user",
|
| 266 |
+
"content": 'ERROR: Your response was not a valid JSON object. Do not use markdown blocks. Respond strictly with: {"sql_command": "...", "reasoning": "...", "submit_final": false}'
|
| 267 |
+
}
|
| 268 |
continue
|
| 269 |
|
| 270 |
# Build the MigrationAction
|
|
|
|
| 312 |
)
|
| 313 |
|
| 314 |
# Add to conversation history
|
| 315 |
+
history.append(initial_msg)
|
| 316 |
history.append({"role": "assistant", "content": json.dumps(action_dict)})
|
| 317 |
|
| 318 |
# --- D5: Lean feedback — NO schema repetition ---
|
| 319 |
+
feedback_text = (
|
| 320 |
f"EXECUTION RESULT: {obs.last_execution_result}\n"
|
| 321 |
f"Progress: {obs.migration_progress:.2f}"
|
| 322 |
+
f"\nSchema Diff (Missing/Extra constraints vs Target):\n{obs.schema_diff}"
|
| 323 |
)
|
| 324 |
if done:
|
| 325 |
+
feedback_text += "\n\nEpisode complete."
|
| 326 |
elif obs.migration_progress >= 0.9:
|
| 327 |
+
feedback_text += (
|
| 328 |
"\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
|
| 329 |
"and compare to your expectations. If everything matches, set submit_final to true."
|
| 330 |
)
|
| 331 |
else:
|
| 332 |
+
feedback_text += "\n\nContinue the migration. Write your next SQL command."
|
| 333 |
|
| 334 |
+
initial_msg = {"role": "user", "content": feedback_text}
|
| 335 |
|
| 336 |
# Print END
|
| 337 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards_list) if rewards_list else "0.00"
|
models.py
CHANGED
|
@@ -10,7 +10,7 @@ from __future__ import annotations
|
|
| 10 |
from typing import Any, Dict, Optional
|
| 11 |
|
| 12 |
from openenv.core.env_server.types import Action, Observation, State
|
| 13 |
-
from pydantic import Field
|
| 14 |
|
| 15 |
|
| 16 |
class MigrationAction(Action):
|
|
@@ -40,6 +40,11 @@ class MigrationAction(Action):
|
|
| 40 |
description="Set to true when you believe the migration is complete"
|
| 41 |
)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
class MigrationObservation(Observation):
|
| 45 |
"""
|
|
@@ -60,6 +65,7 @@ class MigrationObservation(Observation):
|
|
| 60 |
step_number: Current step count (0 after reset, increments each step).
|
| 61 |
migration_progress: Current grader score from 0.0 to 1.0.
|
| 62 |
task_name: Name of the current task being attempted.
|
|
|
|
| 63 |
"""
|
| 64 |
|
| 65 |
current_schema_sql: str = Field(
|
|
@@ -88,6 +94,10 @@ class MigrationObservation(Observation):
|
|
| 88 |
default="",
|
| 89 |
description="Name of the current task"
|
| 90 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
class MigrationState(State):
|
|
|
|
| 10 |
from typing import Any, Dict, Optional
|
| 11 |
|
| 12 |
from openenv.core.env_server.types import Action, Observation, State
|
| 13 |
+
from pydantic import Field, field_validator
|
| 14 |
|
| 15 |
|
| 16 |
class MigrationAction(Action):
|
|
|
|
| 40 |
description="Set to true when you believe the migration is complete"
|
| 41 |
)
|
| 42 |
|
| 43 |
+
@field_validator("sql_command")
|
| 44 |
+
@classmethod
|
| 45 |
+
def strip_whitespace(cls, v: str) -> str:
|
| 46 |
+
return v.strip()
|
| 47 |
+
|
| 48 |
|
| 49 |
class MigrationObservation(Observation):
|
| 50 |
"""
|
|
|
|
| 65 |
step_number: Current step count (0 after reset, increments each step).
|
| 66 |
migration_progress: Current grader score from 0.0 to 1.0.
|
| 67 |
task_name: Name of the current task being attempted.
|
| 68 |
+
schema_diff: Human-readable diff between current and target schemas.
|
| 69 |
"""
|
| 70 |
|
| 71 |
current_schema_sql: str = Field(
|
|
|
|
| 94 |
default="",
|
| 95 |
description="Name of the current task"
|
| 96 |
)
|
| 97 |
+
schema_diff: Optional[str] = Field(
|
| 98 |
+
default=None,
|
| 99 |
+
description="Human-readable diff between current and expected target schemas"
|
| 100 |
+
)
|
| 101 |
|
| 102 |
|
| 103 |
class MigrationState(State):
|
server/app.py
CHANGED
|
@@ -128,6 +128,11 @@ async def list_tasks() -> Dict[str, Any]:
|
|
| 128 |
"reasoning": "string -- Explanation of the action (optional)",
|
| 129 |
"submit_final": "boolean -- Set true when migration is complete (default: false)",
|
| 130 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
}
|
| 132 |
|
| 133 |
|
|
@@ -170,6 +175,7 @@ async def grade_task(
|
|
| 170 |
}
|
| 171 |
|
| 172 |
return {
|
|
|
|
| 173 |
"tasks": results,
|
| 174 |
"status": "graded",
|
| 175 |
}
|
|
|
|
| 128 |
"reasoning": "string -- Explanation of the action (optional)",
|
| 129 |
"submit_final": "boolean -- Set true when migration is complete (default: false)",
|
| 130 |
},
|
| 131 |
+
"example_action": {
|
| 132 |
+
"sql_command": "CREATE TABLE ...",
|
| 133 |
+
"reasoning": "Creating the new destination table before copying data.",
|
| 134 |
+
"submit_final": False
|
| 135 |
+
}
|
| 136 |
}
|
| 137 |
|
| 138 |
|
|
|
|
| 175 |
}
|
| 176 |
|
| 177 |
return {
|
| 178 |
+
"grader_version": "1.0",
|
| 179 |
"tasks": results,
|
| 180 |
"status": "graded",
|
| 181 |
}
|
server/environment.py
CHANGED
|
@@ -18,6 +18,7 @@ import re
|
|
| 18 |
import sqlite3
|
| 19 |
import threading
|
| 20 |
import uuid
|
|
|
|
| 21 |
from typing import Any, Dict, List, Optional
|
| 22 |
|
| 23 |
# Support both in-repo and standalone imports
|
|
@@ -145,11 +146,23 @@ class DbMigrationEnvironment(Environment):
|
|
| 145 |
return None, "Error: Query exceeded execution time limit (possible infinite loop). Simplify your query."
|
| 146 |
return None, str(e)
|
| 147 |
except sqlite3.Warning as e:
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
| 153 |
return None, str(e)
|
| 154 |
finally:
|
| 155 |
self._conn.set_progress_handler(None, 0)
|
|
@@ -214,8 +227,11 @@ class DbMigrationEnvironment(Environment):
|
|
| 214 |
self._conn = None
|
| 215 |
|
| 216 |
# Create fresh in-memory database
|
| 217 |
-
self._conn = sqlite3.connect(":memory:")
|
| 218 |
|
|
|
|
|
|
|
|
|
|
| 219 |
# CRITICAL: Enable foreign key enforcement
|
| 220 |
self._conn.execute("PRAGMA foreign_keys = ON")
|
| 221 |
|
|
@@ -239,16 +255,28 @@ class DbMigrationEnvironment(Environment):
|
|
| 239 |
|
| 240 |
# Compute initial score
|
| 241 |
initial_score = self._reconciler.score(self._conn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
return MigrationObservation(
|
| 244 |
done=False,
|
| 245 |
reward=0.0,
|
| 246 |
-
current_schema_sql=
|
| 247 |
-
target_schema_sql=
|
| 248 |
last_execution_result="Environment initialized. Ready for migration.",
|
| 249 |
step_number=0,
|
| 250 |
migration_progress=initial_score,
|
| 251 |
task_name=self.task_name,
|
|
|
|
| 252 |
metadata={"status": "ready"},
|
| 253 |
)
|
| 254 |
|
|
@@ -289,7 +317,11 @@ class DbMigrationEnvironment(Environment):
|
|
| 289 |
sql_command = action.sql_command.strip()
|
| 290 |
|
| 291 |
# --- A3: Dangerous SQL Blacklist ---
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
execution_result = (
|
| 294 |
"Error: This SQL command is not allowed for security reasons. "
|
| 295 |
"ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
|
|
@@ -342,9 +374,9 @@ class DbMigrationEnvironment(Environment):
|
|
| 342 |
if self._is_read_query(sql_command):
|
| 343 |
execution_result = self._format_query_results(cursor)
|
| 344 |
else:
|
| 345 |
-
rows_affected = cursor
|
| 346 |
-
execution_result = f"Success:
|
| 347 |
-
#
|
| 348 |
if not self._in_explicit_tx:
|
| 349 |
try:
|
| 350 |
self._conn.commit()
|
|
@@ -384,19 +416,29 @@ class DbMigrationEnvironment(Environment):
|
|
| 384 |
if done:
|
| 385 |
meta["trajectory"] = self._trajectory
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
return MigrationObservation(
|
| 388 |
done=done,
|
| 389 |
reward=step_reward,
|
| 390 |
-
current_schema_sql=
|
| 391 |
-
target_schema_sql=
|
| 392 |
last_execution_result=execution_result,
|
| 393 |
step_number=self._step_count,
|
| 394 |
migration_progress=current_score,
|
| 395 |
task_name=self.task_name,
|
|
|
|
| 396 |
metadata=meta,
|
| 397 |
)
|
| 398 |
|
| 399 |
-
@property
|
| 400 |
def state(self) -> MigrationState:
|
| 401 |
"""Get current environment state."""
|
| 402 |
return self._state
|
|
|
|
| 18 |
import sqlite3
|
| 19 |
import threading
|
| 20 |
import uuid
|
| 21 |
+
import difflib
|
| 22 |
from typing import Any, Dict, List, Optional
|
| 23 |
|
| 24 |
# Support both in-repo and standalone imports
|
|
|
|
| 146 |
return None, "Error: Query exceeded execution time limit (possible infinite loop). Simplify your query."
|
| 147 |
return None, str(e)
|
| 148 |
except sqlite3.Warning as e:
|
| 149 |
+
# Multi-statement fallback
|
| 150 |
+
try:
|
| 151 |
+
self._conn.executescript(sql)
|
| 152 |
+
return None, None
|
| 153 |
+
except Exception as script_e:
|
| 154 |
+
return None, f"Error (Multi-Statement Fallback Failed): {script_e}. Original error: {e}"
|
| 155 |
+
except sqlite3.OperationalError as e:
|
| 156 |
+
err_str = str(e).lower()
|
| 157 |
+
if "table" in err_str and "already exists" in err_str:
|
| 158 |
+
return None, f"Schema Error: {e}. You must DROP the old table first if replacing it."
|
| 159 |
+
if "has no column" in err_str:
|
| 160 |
+
return None, f"Schema Error: {e}. Check table columns."
|
| 161 |
+
return None, str(e)
|
| 162 |
except Exception as e:
|
| 163 |
+
err_str = str(e).lower()
|
| 164 |
+
if "values for" in err_str and "columns" in err_str:
|
| 165 |
+
return None, f"Data Error: {e}. Ensure you are inserting the correct number of columns."
|
| 166 |
return None, str(e)
|
| 167 |
finally:
|
| 168 |
self._conn.set_progress_handler(None, 0)
|
|
|
|
| 227 |
self._conn = None
|
| 228 |
|
| 229 |
# Create fresh in-memory database
|
| 230 |
+
self._conn = sqlite3.connect(":memory:", isolation_level=None)
|
| 231 |
|
| 232 |
+
# Performance PRAGMAs for Docker I/O
|
| 233 |
+
self._conn.execute("PRAGMA journal_mode = MEMORY")
|
| 234 |
+
|
| 235 |
# CRITICAL: Enable foreign key enforcement
|
| 236 |
self._conn.execute("PRAGMA foreign_keys = ON")
|
| 237 |
|
|
|
|
| 255 |
|
| 256 |
# Compute initial score
|
| 257 |
initial_score = self._reconciler.score(self._conn)
|
| 258 |
+
self._state.migration_progress = initial_score
|
| 259 |
+
|
| 260 |
+
current_ddl = self._get_current_schema()
|
| 261 |
+
target_ddl = self._task_config["target_ddl"]
|
| 262 |
+
diff = "\n".join(difflib.unified_diff(
|
| 263 |
+
current_ddl.splitlines(),
|
| 264 |
+
target_ddl.splitlines(),
|
| 265 |
+
fromfile="current_schema",
|
| 266 |
+
tofile="target_schema",
|
| 267 |
+
lineterm=""
|
| 268 |
+
))
|
| 269 |
|
| 270 |
return MigrationObservation(
|
| 271 |
done=False,
|
| 272 |
reward=0.0,
|
| 273 |
+
current_schema_sql=current_ddl,
|
| 274 |
+
target_schema_sql=target_ddl,
|
| 275 |
last_execution_result="Environment initialized. Ready for migration.",
|
| 276 |
step_number=0,
|
| 277 |
migration_progress=initial_score,
|
| 278 |
task_name=self.task_name,
|
| 279 |
+
schema_diff=diff if diff else "Schemas match exactly.",
|
| 280 |
metadata={"status": "ready"},
|
| 281 |
)
|
| 282 |
|
|
|
|
| 317 |
sql_command = action.sql_command.strip()
|
| 318 |
|
| 319 |
# --- A3: Dangerous SQL Blacklist ---
|
| 320 |
+
sql_lower = sql_command.lower()
|
| 321 |
+
if "pragma" in sql_lower and "foreign_keys" in sql_lower and "off" in sql_lower:
|
| 322 |
+
execution_result = "Security Error: Disabling PRAGMA foreign_keys is strictly explicitly forbidden."
|
| 323 |
+
action_error = "pragma_off_blocked"
|
| 324 |
+
elif _DANGEROUS_PATTERNS.search(sql_command):
|
| 325 |
execution_result = (
|
| 326 |
"Error: This SQL command is not allowed for security reasons. "
|
| 327 |
"ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
|
|
|
|
| 374 |
if self._is_read_query(sql_command):
|
| 375 |
execution_result = self._format_query_results(cursor)
|
| 376 |
else:
|
| 377 |
+
rows_affected = getattr(cursor, "rowcount", -1) if cursor else -1
|
| 378 |
+
execution_result = f"Success: Action executed. Rows affected: {rows_affected}"
|
| 379 |
+
# Try to auto-commit
|
| 380 |
if not self._in_explicit_tx:
|
| 381 |
try:
|
| 382 |
self._conn.commit()
|
|
|
|
| 416 |
if done:
|
| 417 |
meta["trajectory"] = self._trajectory
|
| 418 |
|
| 419 |
+
current_ddl = self._get_current_schema()
|
| 420 |
+
target_ddl = self._task_config["target_ddl"]
|
| 421 |
+
diff = "\n".join(difflib.unified_diff(
|
| 422 |
+
current_ddl.splitlines(),
|
| 423 |
+
target_ddl.splitlines(),
|
| 424 |
+
fromfile="current_schema",
|
| 425 |
+
tofile="target_schema",
|
| 426 |
+
lineterm=""
|
| 427 |
+
))
|
| 428 |
+
|
| 429 |
return MigrationObservation(
|
| 430 |
done=done,
|
| 431 |
reward=step_reward,
|
| 432 |
+
current_schema_sql=current_ddl,
|
| 433 |
+
target_schema_sql=target_ddl,
|
| 434 |
last_execution_result=execution_result,
|
| 435 |
step_number=self._step_count,
|
| 436 |
migration_progress=current_score,
|
| 437 |
task_name=self.task_name,
|
| 438 |
+
schema_diff=diff if diff else "Schemas match exactly.",
|
| 439 |
metadata=meta,
|
| 440 |
)
|
| 441 |
|
|
|
|
| 442 |
def state(self) -> MigrationState:
|
| 443 |
"""Get current environment state."""
|
| 444 |
return self._state
|