Eishaan committed on
Commit
05c4751
·
1 Parent(s): 62d9b39

fixed errors v2

Browse files
Files changed (6) hide show
  1. .gitignore +5 -0
  2. Dockerfile +3 -2
  3. inference.py +61 -25
  4. models.py +11 -1
  5. server/app.py +6 -0
  6. server/environment.py +56 -14
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ uv.lock
4
+ .env
5
+ .venv/
Dockerfile CHANGED
@@ -9,9 +9,10 @@ RUN pip install --no-cache-dir -r requirements.txt
9
  # Copy all project files
10
  COPY . .
11
 
12
- # Set Python path for imports
13
  ENV PYTHONPATH="/app:$PYTHONPATH"
14
-
 
15
  # Health check
16
  HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
17
  CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
 
9
  # Copy all project files
10
  COPY . .
11
 
12
+ # Set environment variables for Docker and Hugging Face Spaces
13
  ENV PYTHONPATH="/app:$PYTHONPATH"
14
+ ENV PYTHONUNBUFFERED=1
15
+ ENV PORT=7860
16
  # Health check
17
  HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
18
  CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
inference.py CHANGED
@@ -50,17 +50,18 @@ CRITICAL SQLite-specific rules (violations cause immediate errors):
50
  2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
51
  3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
52
  4. Execute exactly ONE SQL statement per step.
53
- 5. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
54
- 6. For orphaned FK rows: check the TARGET SCHEMA for the correct anomaly/issues table name (it varies per task). Log invalid records there before dropping.
55
- 7. For text currency columns like '$90,000' or '$1,234.56': strip '$' and ',' then cast to the type in the target schema (INTEGER for whole numbers, REAL for decimals).
56
- 8. IMPORTANT: Before writing any DDL, execute SELECT * FROM tablename LIMIT 5 for each source table to inspect the actual data format and identify edge cases like empty strings, leading whitespace, NULL values, and special characters.
57
- 9. Do NOT set submit_final to true until you have run SELECT COUNT(*) on your target tables and verified the counts and data match what the task requires.
58
- 10. When migration is complete and verified, set submit_final to true.
 
59
 
60
  TARGET SCHEMA (achieve this exactly):
61
  {target_ddl}
62
 
63
- Respond ONLY with valid JSON no markdown, no code blocks, no text outside the object:
64
  {{"sql_command": "your SQL here", "reasoning": "why", "submit_final": false}}"""
65
 
66
  ALL_TASKS = [
@@ -74,6 +75,26 @@ ALL_TASKS = [
74
  ]
75
  MAX_PARSE_ERRORS = 5 # Consecutive parse errors before giving up
76
  AUTO_SUBMIT_THRESHOLD = 0.95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  def call_llm(messages: list, timeout: int = 90) -> str:
@@ -186,13 +207,16 @@ def run_task_local(task_name: str) -> dict:
186
  history = [{"role": "system", "content": task_system_prompt}]
187
 
188
  # Initial observation
189
- initial_msg = (
190
- f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
191
- f"Status: {obs.last_execution_result}\n"
192
- f"Migration progress: {obs.migration_progress:.2f}\n\n"
193
- f"Start by inspecting the source data with SELECT queries, then begin the migration."
194
- )
195
- history.append({"role": "user", "content": initial_msg})
 
 
 
196
 
197
  rewards_list = []
198
  consecutive_parse_errors = 0 # D6: Track consecutive only
@@ -204,8 +228,8 @@ def run_task_local(task_name: str) -> dict:
204
  if done:
205
  break
206
 
207
- # --- D5: Context window fix only keep last 10 messages + system ---
208
- messages = [history[0]] + history[-10:]
209
 
210
  try:
211
  raw_response = call_llm(messages)
@@ -226,11 +250,21 @@ def run_task_local(task_name: str) -> dict:
226
  print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
227
  done = True
228
  break
229
- history.append({"role": "assistant", "content": raw_response})
230
- history.append({
 
 
 
 
 
 
 
 
 
 
231
  "role": "user",
232
- "content": 'ERROR: Your response was not valid JSON. Respond ONLY with: {"sql_command": "...", "reasoning": "...", "submit_final": false}',
233
- })
234
  continue
235
 
236
  # Build the MigrationAction
@@ -278,24 +312,26 @@ def run_task_local(task_name: str) -> dict:
278
  )
279
 
280
  # Add to conversation history
 
281
  history.append({"role": "assistant", "content": json.dumps(action_dict)})
282
 
283
  # --- D5: Lean feedback — NO schema repetition ---
284
- feedback_msg = (
285
  f"EXECUTION RESULT: {obs.last_execution_result}\n"
286
  f"Progress: {obs.migration_progress:.2f}"
 
287
  )
288
  if done:
289
- feedback_msg += "\n\nEpisode complete."
290
  elif obs.migration_progress >= 0.9:
291
- feedback_msg += (
292
  "\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
293
  "and compare to your expectations. If everything matches, set submit_final to true."
294
  )
295
  else:
296
- feedback_msg += "\n\nContinue the migration. Write your next SQL command."
297
 
298
- history.append({"role": "user", "content": feedback_msg})
299
 
300
  # Print END
301
  rewards_str = ",".join(f"{r:.2f}" for r in rewards_list) if rewards_list else "0.00"
 
50
  2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
51
  3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
52
  4. Execute exactly ONE SQL statement per step.
53
+ 5. If a table already exists, you MUST drop it before recreating it (e.g., DROP TABLE IF EXISTS users_new).
54
+ 6. SQLite strictly expects `INSERT INTO tbl VALUES (...)`, not `VALUE (...)`. Ensure column counts match exactly.
55
+ 7. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
56
+ 8. For orphaned FK rows: check the TARGET SCHEMA for the anomaly/issues table name. Log invalid records there before dropping.
57
+ 9. For text currency (e.g. '$90,000'): strip '$' and ',' then cast to the target type (INTEGER/REAL).
58
+ 10. IMPORTANT: Before writing any DDL, execute SELECT * FROM tablename LIMIT 5 to inspect the data format.
59
+ 11. Do NOT set submit_final to true until you run SELECT COUNT(*) and verify data matches the task.
60
 
61
  TARGET SCHEMA (achieve this exactly):
62
  {target_ddl}
63
 
64
+ Respond ONLY with a valid JSON object. Do not use markdown backticks (```json). No conversational text.
65
  {{"sql_command": "your SQL here", "reasoning": "why", "submit_final": false}}"""
66
 
67
  ALL_TASKS = [
 
75
  ]
76
  MAX_PARSE_ERRORS = 5 # Consecutive parse errors before giving up
77
  AUTO_SUBMIT_THRESHOLD = 0.95
78
+ MAX_HISTORY_PAIRS = 4 # Keep maximum of 4 user/assistant turn pairs
79
+
80
+
81
+ def build_messages(system_prompt: str, history: list, current_obs_msg: dict) -> list:
82
+ """
83
+ Build the LLM message list: system prompt, the most recent MAX_HISTORY_PAIRS * 2 history messages, then the current observation message.
84
+ """
85
+ system_msg = [{"role": "system", "content": system_prompt}]
86
+
87
+ # We only want assistant/user pairs. Filter out system msgs if any exist in history
88
+ filtered_history = [m for m in history if m["role"] != "system"]
89
+
90
+ # Keep only the last MAX_HISTORY_PAIRS * 2 messages
91
+ max_msgs = MAX_HISTORY_PAIRS * 2
92
+ if len(filtered_history) > max_msgs:
93
+ pruned_history = filtered_history[-max_msgs:]
94
+ else:
95
+ pruned_history = filtered_history
96
+
97
+ return system_msg + pruned_history + [current_obs_msg]
98
 
99
 
100
  def call_llm(messages: list, timeout: int = 90) -> str:
 
207
  history = [{"role": "system", "content": task_system_prompt}]
208
 
209
  # Initial observation
210
+ initial_msg = {
211
+ "role": "user",
212
+ "content": (
213
+ f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
214
+ f"Status: {obs.last_execution_result}\n"
215
+ f"Migration progress: {obs.migration_progress:.2f}\n\n"
216
+ f"Start by inspecting the source data with SELECT queries, then begin the migration."
217
+ )
218
+ }
219
+ history = []
220
 
221
  rewards_list = []
222
  consecutive_parse_errors = 0 # D6: Track consecutive only
 
228
  if done:
229
  break
230
 
231
+ # --- D5: Context window fix: Aggressively prune history via build_messages ---
232
+ messages = build_messages(task_system_prompt, history, initial_msg)
233
 
234
  try:
235
  raw_response = call_llm(messages)
 
250
  print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
251
  done = True
252
  break
253
+
254
+ # CRITICAL: Strip <think> tags before appending to history to prevent 413 Context OOM
255
+ stripped_response = re.sub(r"<think>.*?</think>", "", raw_response, flags=re.DOTALL).strip()
256
+ stripped_response = re.sub(r"<think>.*$", "", stripped_response, flags=re.DOTALL).strip()
257
+ # If it's still huge, truncate it to 500 chars to save context
258
+ if len(stripped_response) > 500:
259
+ stripped_response = stripped_response[:500] + "... [TRUNCATED DUE TO PARSE ERROR]"
260
+
261
+ history.append(initial_msg) # The prompt we sent
262
+ history.append({"role": "assistant", "content": stripped_response}) # The stripped response
263
+
264
+ initial_msg = {
265
  "role": "user",
266
+ "content": 'ERROR: Your response was not a valid JSON object. Do not use markdown blocks. Respond strictly with: {"sql_command": "...", "reasoning": "...", "submit_final": false}'
267
+ }
268
  continue
269
 
270
  # Build the MigrationAction
 
312
  )
313
 
314
  # Add to conversation history
315
+ history.append(initial_msg)
316
  history.append({"role": "assistant", "content": json.dumps(action_dict)})
317
 
318
  # --- D5: Lean feedback — NO schema repetition ---
319
+ feedback_text = (
320
  f"EXECUTION RESULT: {obs.last_execution_result}\n"
321
  f"Progress: {obs.migration_progress:.2f}"
322
+ f"\nSchema Diff (Missing/Extra constraints vs Target):\n{obs.schema_diff}"
323
  )
324
  if done:
325
+ feedback_text += "\n\nEpisode complete."
326
  elif obs.migration_progress >= 0.9:
327
+ feedback_text += (
328
  "\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
329
  "and compare to your expectations. If everything matches, set submit_final to true."
330
  )
331
  else:
332
+ feedback_text += "\n\nContinue the migration. Write your next SQL command."
333
 
334
+ initial_msg = {"role": "user", "content": feedback_text}
335
 
336
  # Print END
337
  rewards_str = ",".join(f"{r:.2f}" for r in rewards_list) if rewards_list else "0.00"
models.py CHANGED
@@ -10,7 +10,7 @@ from __future__ import annotations
10
  from typing import Any, Dict, Optional
11
 
12
  from openenv.core.env_server.types import Action, Observation, State
13
- from pydantic import Field
14
 
15
 
16
  class MigrationAction(Action):
@@ -40,6 +40,11 @@ class MigrationAction(Action):
40
  description="Set to true when you believe the migration is complete"
41
  )
42
 
 
 
 
 
 
43
 
44
  class MigrationObservation(Observation):
45
  """
@@ -60,6 +65,7 @@ class MigrationObservation(Observation):
60
  step_number: Current step count (0 after reset, increments each step).
61
  migration_progress: Current grader score from 0.0 to 1.0.
62
  task_name: Name of the current task being attempted.
 
63
  """
64
 
65
  current_schema_sql: str = Field(
@@ -88,6 +94,10 @@ class MigrationObservation(Observation):
88
  default="",
89
  description="Name of the current task"
90
  )
 
 
 
 
91
 
92
 
93
  class MigrationState(State):
 
10
  from typing import Any, Dict, Optional
11
 
12
  from openenv.core.env_server.types import Action, Observation, State
13
+ from pydantic import Field, field_validator
14
 
15
 
16
  class MigrationAction(Action):
 
40
  description="Set to true when you believe the migration is complete"
41
  )
42
 
43
+ @field_validator("sql_command")
44
+ @classmethod
45
+ def strip_whitespace(cls, v: str) -> str:
46
+ return v.strip()
47
+
48
 
49
  class MigrationObservation(Observation):
50
  """
 
65
  step_number: Current step count (0 after reset, increments each step).
66
  migration_progress: Current grader score from 0.0 to 1.0.
67
  task_name: Name of the current task being attempted.
68
+ schema_diff: Human-readable diff between current and target schemas.
69
  """
70
 
71
  current_schema_sql: str = Field(
 
94
  default="",
95
  description="Name of the current task"
96
  )
97
+ schema_diff: Optional[str] = Field(
98
+ default=None,
99
+ description="Human-readable diff between current and expected target schemas"
100
+ )
101
 
102
 
103
  class MigrationState(State):
server/app.py CHANGED
@@ -128,6 +128,11 @@ async def list_tasks() -> Dict[str, Any]:
128
  "reasoning": "string -- Explanation of the action (optional)",
129
  "submit_final": "boolean -- Set true when migration is complete (default: false)",
130
  },
 
 
 
 
 
131
  }
132
 
133
 
@@ -170,6 +175,7 @@ async def grade_task(
170
  }
171
 
172
  return {
 
173
  "tasks": results,
174
  "status": "graded",
175
  }
 
128
  "reasoning": "string -- Explanation of the action (optional)",
129
  "submit_final": "boolean -- Set true when migration is complete (default: false)",
130
  },
131
+ "example_action": {
132
+ "sql_command": "CREATE TABLE ...",
133
+ "reasoning": "Creating the new destination table before copying data.",
134
+ "submit_final": False
135
+ }
136
  }
137
 
138
 
 
175
  }
176
 
177
  return {
178
+ "grader_version": "1.0",
179
  "tasks": results,
180
  "status": "graded",
181
  }
server/environment.py CHANGED
@@ -18,6 +18,7 @@ import re
18
  import sqlite3
19
  import threading
20
  import uuid
 
21
  from typing import Any, Dict, List, Optional
22
 
23
  # Support both in-repo and standalone imports
@@ -145,11 +146,23 @@ class DbMigrationEnvironment(Environment):
145
  return None, "Error: Query exceeded execution time limit (possible infinite loop). Simplify your query."
146
  return None, str(e)
147
  except sqlite3.Warning as e:
148
- return None, (
149
- f"Error: SQLite requires one statement per step. "
150
- f"Split your commands into separate steps. Original error: {e}"
151
- )
 
 
 
 
 
 
 
 
 
152
  except Exception as e:
 
 
 
153
  return None, str(e)
154
  finally:
155
  self._conn.set_progress_handler(None, 0)
@@ -214,8 +227,11 @@ class DbMigrationEnvironment(Environment):
214
  self._conn = None
215
 
216
  # Create fresh in-memory database
217
- self._conn = sqlite3.connect(":memory:")
218
 
 
 
 
219
  # CRITICAL: Enable foreign key enforcement
220
  self._conn.execute("PRAGMA foreign_keys = ON")
221
 
@@ -239,16 +255,28 @@ class DbMigrationEnvironment(Environment):
239
 
240
  # Compute initial score
241
  initial_score = self._reconciler.score(self._conn)
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  return MigrationObservation(
244
  done=False,
245
  reward=0.0,
246
- current_schema_sql=self._get_current_schema(),
247
- target_schema_sql=self._task_config["target_ddl"],
248
  last_execution_result="Environment initialized. Ready for migration.",
249
  step_number=0,
250
  migration_progress=initial_score,
251
  task_name=self.task_name,
 
252
  metadata={"status": "ready"},
253
  )
254
 
@@ -289,7 +317,11 @@ class DbMigrationEnvironment(Environment):
289
  sql_command = action.sql_command.strip()
290
 
291
  # --- A3: Dangerous SQL Blacklist ---
292
- if _DANGEROUS_PATTERNS.search(sql_command):
 
 
 
 
293
  execution_result = (
294
  "Error: This SQL command is not allowed for security reasons. "
295
  "ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
@@ -342,9 +374,9 @@ class DbMigrationEnvironment(Environment):
342
  if self._is_read_query(sql_command):
343
  execution_result = self._format_query_results(cursor)
344
  else:
345
- rows_affected = cursor.rowcount
346
- execution_result = f"Success: {rows_affected} rows affected"
347
- # Only auto-commit if not in explicit transaction (A4)
348
  if not self._in_explicit_tx:
349
  try:
350
  self._conn.commit()
@@ -384,19 +416,29 @@ class DbMigrationEnvironment(Environment):
384
  if done:
385
  meta["trajectory"] = self._trajectory
386
 
 
 
 
 
 
 
 
 
 
 
387
  return MigrationObservation(
388
  done=done,
389
  reward=step_reward,
390
- current_schema_sql=self._get_current_schema(),
391
- target_schema_sql=self._task_config["target_ddl"],
392
  last_execution_result=execution_result,
393
  step_number=self._step_count,
394
  migration_progress=current_score,
395
  task_name=self.task_name,
 
396
  metadata=meta,
397
  )
398
 
399
- @property
400
  def state(self) -> MigrationState:
401
  """Get current environment state."""
402
  return self._state
 
18
  import sqlite3
19
  import threading
20
  import uuid
21
+ import difflib
22
  from typing import Any, Dict, List, Optional
23
 
24
  # Support both in-repo and standalone imports
 
146
  return None, "Error: Query exceeded execution time limit (possible infinite loop). Simplify your query."
147
  return None, str(e)
148
  except sqlite3.Warning as e:
149
+ # Multi-statement fallback
150
+ try:
151
+ self._conn.executescript(sql)
152
+ return None, None
153
+ except Exception as script_e:
154
+ return None, f"Error (Multi-Statement Fallback Failed): {script_e}. Original error: {e}"
155
+ except sqlite3.OperationalError as e:
156
+ err_str = str(e).lower()
157
+ if "table" in err_str and "already exists" in err_str:
158
+ return None, f"Schema Error: {e}. You must DROP the old table first if replacing it."
159
+ if "has no column" in err_str:
160
+ return None, f"Schema Error: {e}. Check table columns."
161
+ return None, str(e)
162
  except Exception as e:
163
+ err_str = str(e).lower()
164
+ if "values for" in err_str and "columns" in err_str:
165
+ return None, f"Data Error: {e}. Ensure you are inserting the correct number of columns."
166
  return None, str(e)
167
  finally:
168
  self._conn.set_progress_handler(None, 0)
 
227
  self._conn = None
228
 
229
  # Create fresh in-memory database
230
+ self._conn = sqlite3.connect(":memory:", isolation_level=None)
231
 
232
+ # Keep the rollback journal in memory (an in-memory database has no disk I/O; this is effectively a no-op for ":memory:")
233
+ self._conn.execute("PRAGMA journal_mode = MEMORY")
234
+
235
  # CRITICAL: Enable foreign key enforcement
236
  self._conn.execute("PRAGMA foreign_keys = ON")
237
 
 
255
 
256
  # Compute initial score
257
  initial_score = self._reconciler.score(self._conn)
258
+ self._state.migration_progress = initial_score
259
+
260
+ current_ddl = self._get_current_schema()
261
+ target_ddl = self._task_config["target_ddl"]
262
+ diff = "\n".join(difflib.unified_diff(
263
+ current_ddl.splitlines(),
264
+ target_ddl.splitlines(),
265
+ fromfile="current_schema",
266
+ tofile="target_schema",
267
+ lineterm=""
268
+ ))
269
 
270
  return MigrationObservation(
271
  done=False,
272
  reward=0.0,
273
+ current_schema_sql=current_ddl,
274
+ target_schema_sql=target_ddl,
275
  last_execution_result="Environment initialized. Ready for migration.",
276
  step_number=0,
277
  migration_progress=initial_score,
278
  task_name=self.task_name,
279
+ schema_diff=diff if diff else "Schemas match exactly.",
280
  metadata={"status": "ready"},
281
  )
282
 
 
317
  sql_command = action.sql_command.strip()
318
 
319
  # --- A3: Dangerous SQL Blacklist ---
320
+ sql_lower = sql_command.lower()
321
+ if "pragma" in sql_lower and "foreign_keys" in sql_lower and "off" in sql_lower:
322
+ execution_result = "Security Error: Disabling PRAGMA foreign_keys is strictly explicitly forbidden."
323
+ action_error = "pragma_off_blocked"
324
+ elif _DANGEROUS_PATTERNS.search(sql_command):
325
  execution_result = (
326
  "Error: This SQL command is not allowed for security reasons. "
327
  "ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
 
374
  if self._is_read_query(sql_command):
375
  execution_result = self._format_query_results(cursor)
376
  else:
377
+ rows_affected = getattr(cursor, "rowcount", -1) if cursor else -1
378
+ execution_result = f"Success: Action executed. Rows affected: {rows_affected}"
379
+ # Try to auto-commit
380
  if not self._in_explicit_tx:
381
  try:
382
  self._conn.commit()
 
416
  if done:
417
  meta["trajectory"] = self._trajectory
418
 
419
+ current_ddl = self._get_current_schema()
420
+ target_ddl = self._task_config["target_ddl"]
421
+ diff = "\n".join(difflib.unified_diff(
422
+ current_ddl.splitlines(),
423
+ target_ddl.splitlines(),
424
+ fromfile="current_schema",
425
+ tofile="target_schema",
426
+ lineterm=""
427
+ ))
428
+
429
  return MigrationObservation(
430
  done=done,
431
  reward=step_reward,
432
+ current_schema_sql=current_ddl,
433
+ target_schema_sql=target_ddl,
434
  last_execution_result=execution_result,
435
  step_number=self._step_count,
436
  migration_progress=current_score,
437
  task_name=self.task_name,
438
+ schema_diff=diff if diff else "Schemas match exactly.",
439
  metadata=meta,
440
  )
441
 
 
442
  def state(self) -> MigrationState:
443
  """Get current environment state."""
444
  return self._state