Mohammed-Altaf commited on
Commit
7b9dfc1
·
1 Parent(s): aca1396

fixed issues in SQL schema not available to agent and code parsing issues

Browse files
helpers/__init__.py ADDED
File without changes
helpers/response_parser.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Any
4
+
5
+ FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})
6
+ # ── Layer 1: Sanitize special characters inside string values ──────────────────
7
+
8
+
9
+ def _sanitize_string_value(match: re.Match) -> str:
10
+ """
11
+ Receives a regex match of ("key": "value") and cleans only the value part.
12
+ Escapes unescaped newlines, tabs, carriage returns, and inner double quotes.
13
+ This is the core trick LangChain uses in _replace_new_line / _custom_parser.
14
+ """
15
+ opening = match.group(1) # e.g. "code": "
16
+ value = match.group(2) # raw value content (may span multiple lines)
17
+ closing = match.group(3) # closing "
18
+
19
+ value = re.sub(r"\n", r"\\n", value)
20
+ value = re.sub(r"\r", r"\\r", value)
21
+ value = re.sub(r"\t", r"\\t", value)
22
+ value = re.sub(r'(?<!\\)"', r'\\"', value) # escape unescaped inner quotes
23
+
24
+ return opening + value + closing
25
+
26
+
27
+ def _sanitize_all_string_values(text: str) -> str:
28
+ """
29
+ Apply _sanitize_string_value to every JSON string value in the text.
30
+ Uses re.DOTALL so values that span multiple lines are handled correctly.
31
+ Generalised version of LangChain's _custom_parser (which only targeted action_input).
32
+ """
33
+ return re.sub(
34
+ r'("[\w]+"\s*:\s*")(.*?)(")', # ("key": ") VALUE (")
35
+ _sanitize_string_value,
36
+ text,
37
+ flags=re.DOTALL,
38
+ )
39
+
40
+
41
+ # ── Layer 2: Pre-parse text fixes ─────────────────────────────────────────────
42
+
43
+
44
+ def _preprocess(text: str) -> str:
45
+ """Fix common LLM response quirks before attempting JSON parsing."""
46
+
47
+ # Strip markdown code fences (```json ... ``` or ``` ... ```)
48
+ # LangChain uses a regex for this: _json_markdown_re
49
+ match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
50
+ if match:
51
+ text = match.group(1).strip()
52
+
53
+ # Double curly braces {{"k": "v"}} → {"k": "v"}
54
+ text = text.replace("{{", "{").replace("}}", "}")
55
+
56
+ # Python literals → JSON literals
57
+ text = re.sub(r"\bTrue\b", "true", text)
58
+ text = re.sub(r"\bFalse\b", "false", text)
59
+ text = re.sub(r"\bNone\b", "null", text)
60
+
61
+ # Trailing commas before } or ]
62
+ text = re.sub(r",\s*([}\]])", r"\1", text)
63
+
64
+ # Outer single-quote wrap '{"k": "v"}' → {"k": "v"}
65
+ if text.startswith("'") and text.endswith("'"):
66
+ text = text[1:-1].replace("\\'", "'")
67
+
68
+ return text.strip()
69
+
70
+
71
+ # ── Layer 3: Extract first JSON blob from surrounding prose ───────────────────
72
+
73
+
74
+ def _extract_json_blob(text: str) -> str:
75
+ """
76
+ Pull out the first {...} or [...] blob from text that has prose around it.
77
+ Inspired by LangChain's _json_markdown_re fallback in parse_json_markdown.
78
+ """
79
+ match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
80
+ return match.group(1) if match else text
81
+
82
+
83
+ # ── Layer 4: parse_partial_json — LangChain's stack-based closer ──────────────
84
+
85
+
86
+ def _parse_partial_json(s: str) -> Any:
87
+ """
88
+ Parse JSON that may be truncated / missing closing brackets.
89
+ Adapted from LangChain's parse_partial_json (originally from open-interpreter).
90
+ Uses a stack to track open containers and closes them before parsing.
91
+ """
92
+ s = s.strip()
93
+
94
+ # Try the string as-is first
95
+ try:
96
+ return json.loads(s)
97
+ except json.JSONDecodeError:
98
+ pass
99
+
100
+ # Walk through and auto-close any unclosed {, [, or "
101
+ stack = []
102
+ is_inside = False # inside a string?
103
+ position = 0
104
+
105
+ for i, char in enumerate(s):
106
+ if is_inside:
107
+ if char == '"' and s[i - 1] != "\\":
108
+ is_inside = False
109
+ else:
110
+ if char == '"':
111
+ is_inside = True
112
+ stack.append('"')
113
+ elif char in "{[":
114
+ stack.append(char)
115
+ elif char in "}]":
116
+ if stack and stack[-1] in "{[":
117
+ stack.pop()
118
+ position = i
119
+
120
+ # Close open containers in reverse order
121
+ completed = s[: position + 1]
122
+ for bracket in reversed(stack):
123
+ if bracket == '"':
124
+ completed += '"'
125
+ elif bracket == "{":
126
+ completed += "}"
127
+ elif bracket == "[":
128
+ completed += "]"
129
+
130
+ return json.loads(completed)
131
+
132
+
133
+ # ── Layer 5: Direct greedy extraction — last resort for unescaped inner quotes ──
134
+
135
+
136
+ def _extract_fields_direct(text: str) -> dict:
137
+ """Extract action fields using greedy regex anchored to the last closing quote.
138
+
139
+ Handles the case where the model emits unescaped double-quote characters inside
140
+ a "code" or "answer" value (e.g. df["col"]). The non-greedy `(.*?)` in
141
+ _sanitize_all_string_values stops at the *first* inner quote and corrupts the
142
+ output. By using a greedy `(.*)` anchored with a lookahead for the last `"}`
143
+ boundary we capture the full value regardless of inner quotes.
144
+
145
+ Args:
146
+ text: Pre-processed JSON-like string.
147
+
148
+ Returns:
149
+ Dict with 'action' and 'code'/'answer' keys.
150
+
151
+ Raises:
152
+ ValueError: If the action field cannot be found or the value cannot be
153
+ extracted for the detected action type.
154
+ """
155
+ action_match = re.search(r'"action"\s*:\s*"(\w+)"', text)
156
+ if not action_match:
157
+ raise ValueError("No 'action' field found")
158
+ action_type = action_match.group(1)
159
+
160
+ if action_type == "execute_code":
161
+ m = re.search(r'"code"\s*:\s*"(.*)"(?=\s*})', text, re.DOTALL)
162
+ if m:
163
+ return {"action": "execute_code", "code": m.group(1)}
164
+ elif action_type == "submit_answer":
165
+ m = re.search(r'"answer"\s*:\s*"(.*)"(?=\s*})', text, re.DOTALL)
166
+ if m:
167
+ return {"action": "submit_answer", "answer": m.group(1)}
168
+
169
+ raise ValueError(f"Could not extract value for action_type={action_type!r}")
170
+
171
+
172
+ # ── Public API ─────────────────────────────────────────────────────────────────
173
+
174
+
175
+ def parse_model_action(response_text: str) -> dict:
176
+ """
177
+ Parse a raw LLM response into an action dict.
178
+
179
+ Pipeline (mirrors LangChain's JsonOutputParser internals):
180
+ 1. _preprocess – fix markdown fences, double braces, Python literals …
181
+ 2. _sanitize_all_string_values – escape unescaped quotes/newlines inside values
182
+ 3. _extract_json_blob – strip surrounding prose
183
+ 4. _parse_partial_json – close truncated JSON with a stack algorithm
184
+
185
+ Each strategy is tried independently so a failure in one doesn't block others.
186
+ """
187
+ text = response_text.strip()
188
+
189
+ strategies = [
190
+ lambda t: _parse_partial_json(t),
191
+ # (preprocessed, sanitized, as-is)
192
+ lambda t: _parse_partial_json(_sanitize_all_string_values(_preprocess(t))),
193
+ # (extract blob first, then preprocess + sanitize)
194
+ lambda t: _parse_partial_json(_sanitize_all_string_values(_preprocess(_extract_json_blob(t)))),
195
+ # (preprocess + extract blob, then sanitize)
196
+ lambda t: _parse_partial_json(_sanitize_all_string_values(_extract_json_blob(_preprocess(t)))),
197
+ # (sanitize raw text, skip preprocess — rare fallback)
198
+ lambda t: _parse_partial_json(_sanitize_all_string_values(t)),
199
+ # greedy extraction — handles unescaped inner quotes in code/answer values
200
+ lambda t: _extract_fields_direct(_preprocess(_extract_json_blob(t))),
201
+ lambda t: _extract_fields_direct(_extract_json_blob(t)),
202
+ ]
203
+
204
+ for strategy in strategies:
205
+ try:
206
+ return strategy(text)
207
+ except (json.JSONDecodeError, ValueError):
208
+ continue
209
+
210
+ print(f"JSON Decoding Error while parsing action in response text: {response_text}")
211
+ return json.loads(FALLBACK_ACTION)
inference.py CHANGED
@@ -7,6 +7,7 @@ from dotenv import load_dotenv
7
  from openai import OpenAI
8
 
9
  from client import DataAnalysisClient
 
10
  from models import DataAction
11
 
12
  load_dotenv()
@@ -20,13 +21,17 @@ ENV_SERVER_URL = os.getenv("ENV_SERVER_URL") or "https://mohammed-altaf-dataanal
20
 
21
  SYSTEM_PROMPT = """
22
  <ROLE>
23
- You are a data analyst. You are given a dataset loaded as a pandas DataFrame called `df`.
24
- You can execute Python/pandas code to explore the dataset and answer the question.
 
25
  </ROLE>
26
 
27
  <RULES>
28
- - Use `print()` to see results of your code
29
- - The DataFrame `df` is pre-loaded with pandas as `pd` and numpy as `np`
 
 
 
30
  - When you have the answer, submit it in the exact format requested
31
  - Be precise with numbers and formatting
32
  </RULES>
@@ -42,8 +47,6 @@ Respond with ONLY the JSON, no other text.
42
  </NOTE>
43
  """
44
 
45
- FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})
46
-
47
 
48
  def log_start(task: str, env: str, model: str) -> None:
49
  """Log the start of a task episode.
@@ -87,105 +90,6 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
87
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
88
 
89
 
90
- def parse_model_action(response_text: str) -> dict:
91
- """Parse the model's raw text response into an action dict.
92
-
93
- Handles multiple LLM response edge cases:
94
- - Markdown code blocks (```json ... ``` or ``` ... ```)
95
- - Double curly braces e.g. {{"key": "value"}}
96
- - Single quotes instead of double quotes e.g. {'key': 'value'}
97
- - Python literals: True/False/None → true/false/null
98
- - Trailing commas in objects/arrays e.g. {"key": "value",}
99
- - Extra text/prose before or after the JSON blob
100
- - Escaped single quotes inside single-quoted strings
101
- - Whitespace and newline noise
102
-
103
- Args:
104
- response_text: Raw string returned by the model.
105
-
106
- Returns:
107
- Parsed action dict, or a fallback submit_answer on failure.
108
- """
109
-
110
- def attempt_parse(text: str) -> dict:
111
- return json.loads(text)
112
-
113
- def apply_fixes(text: str) -> str:
114
- # Strip markdown code blocks
115
- if text.startswith("```"):
116
- parts = text.split("```")
117
- if len(parts) >= 2:
118
- text = parts[1]
119
- if text.startswith("json"):
120
- text = text[4:]
121
- text = text.strip()
122
-
123
- # Double curly braces → single
124
- text = text.replace("{{", "{").replace("}}", "}")
125
-
126
- # Python literals → JSON literals
127
- text = re.sub(r"\bTrue\b", "true", text)
128
- text = re.sub(r"\bFalse\b", "false", text)
129
- text = re.sub(r"\bNone\b", "null", text)
130
-
131
- # Trailing commas before } or ]
132
- text = re.sub(r",\s*([}\]])", r"\1", text)
133
-
134
- # Single quote handling — two distinct cases:
135
- #
136
- # Case 1: Entire JSON is wrapped in outer single quotes
137
- # e.g. '{"action": "x", "code": "df[\'col\']"}'
138
- # → strip the outer quotes and unescape internal \'
139
- if text.startswith("'") and text.endswith("'"):
140
- text = text[1:-1].replace("\\'", "'")
141
-
142
- # Case 2: JSON itself uses single quotes as delimiters
143
- # e.g. {'action': 'execute_code', 'code': 'print()'}
144
- # → only apply when structure looks single-quote delimited
145
- # → avoids corrupting double-quoted values that contain bracket notation
146
- elif text.startswith("{'") or ("': " in text and '": ' not in text):
147
- text = re.sub(
148
- r"'((?:\\'|[^'])*)'", lambda m: '"' + m.group(1).replace("\\'", "'").replace('"', '\\"') + '"', text
149
- )
150
-
151
- return text
152
-
153
- def extract_json_blob(text: str) -> str:
154
- """Extract the first {...} or [...] blob from text with surrounding prose."""
155
- match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
156
- if match:
157
- return match.group(1)
158
- return text
159
-
160
- text = response_text.strip()
161
-
162
- try:
163
- return attempt_parse(text)
164
- except json.JSONDecodeError:
165
- pass
166
-
167
- try:
168
- return attempt_parse(apply_fixes(text))
169
- except json.JSONDecodeError:
170
- pass
171
-
172
- try:
173
- blob = extract_json_blob(text)
174
- return attempt_parse(apply_fixes(blob))
175
- except json.JSONDecodeError:
176
- pass
177
-
178
- try:
179
- fixed = apply_fixes(text)
180
- blob = extract_json_blob(fixed)
181
- return attempt_parse(blob)
182
- except json.JSONDecodeError:
183
- pass
184
-
185
- print(f"JSON Decoding Error while parsing action in response text: {response_text}")
186
- return json.loads(FALLBACK_ACTION)
187
-
188
-
189
  def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
190
  """Run a single task episode using the language model as the agent.
191
 
@@ -234,7 +138,6 @@ def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
234
  except Exception as exc:
235
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
236
  response_text = FALLBACK_ACTION
237
-
238
  action = parse_model_action(response_text)
239
  action_type = action.get("action", "")
240
 
 
7
  from openai import OpenAI
8
 
9
  from client import DataAnalysisClient
10
+ from helpers.response_parser import FALLBACK_ACTION, parse_model_action
11
  from models import DataAction
12
 
13
  load_dotenv()
 
21
 
22
  SYSTEM_PROMPT = """
23
  <ROLE>
24
+ You are a data analyst. You have two data sources available:
25
+ 1. `df` a pandas DataFrame (sales CSV, pre-loaded)
26
+ 2. A SQLite database at `db_path` — contains additional tables (e.g. customer_profiles, product_catalog)
27
  </ROLE>
28
 
29
  <RULES>
30
+ - Use `print()` to output results
31
+ - `pd`, `np`, `sqlite3`, and `db_path` are already in scope NEVER use import statements (they will fail)
32
+ - `df` is a pandas DataFrame — use pandas operations on it, NEVER SQL
33
+ - To query the SQLite database use: `conn = sqlite3.connect(db_path)` then `pd.read_sql(query, conn)`
34
+ - For cross-source tasks: query SQLite for the extra data, then merge with `df` using pandas
35
  - When you have the answer, submit it in the exact format requested
36
  - Be precise with numbers and formatting
37
  </RULES>
 
47
  </NOTE>
48
  """
49
 
 
 
50
 
51
  def log_start(task: str, env: str, model: str) -> None:
52
  """Log the start of a task episode.
 
90
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
94
  """Run a single task episode using the language model as the agent.
95
 
 
138
  except Exception as exc:
139
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
140
  response_text = FALLBACK_ACTION
 
141
  action = parse_model_action(response_text)
142
  action_type = action.get("action", "")
143
 
server/data_analysis_env.py CHANGED
@@ -35,6 +35,7 @@ class DataAnalysisEnv(Environment):
35
  """
36
 
37
  MAX_STEPS = 20
 
38
 
39
  def __init__(self):
40
  """Initialize the environment with default state."""
@@ -79,14 +80,38 @@ class DataAnalysisEnv(Environment):
79
  def _dataset_info(self) -> str:
80
  """Generate a summary of the dataset schema for the agent.
81
 
 
 
 
82
  Returns:
83
- A string describing column names, dtypes, row count, and a sample.
 
84
  """
85
  buf = io.StringIO()
86
  self._df.info(buf=buf)
87
  info_str = buf.getvalue()
88
  sample = self._df.head(3).to_string()
89
- return f"Dataset shape: {self._df.shape}\n\n{info_str}\nSample rows:\n{sample}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def reset(
92
  self,
 
35
  """
36
 
37
  MAX_STEPS = 20
38
+ SUPPORTS_CONCURRENT_SESSIONS = True
39
 
40
  def __init__(self):
41
  """Initialize the environment with default state."""
 
80
  def _dataset_info(self) -> str:
81
  """Generate a summary of the dataset schema for the agent.
82
 
83
+ Includes the sales DataFrame schema plus the SQLite database table schemas
84
+ so the agent knows what data is available and where to find it.
85
+
86
  Returns:
87
+ A string describing column names, dtypes, row count, a sample for df,
88
+ and table schemas for the SQLite database.
89
  """
90
  buf = io.StringIO()
91
  self._df.info(buf=buf)
92
  info_str = buf.getvalue()
93
  sample = self._df.head(3).to_string()
94
+ df_section = f"=== df (pandas DataFrame, pre-loaded from sales CSV) ===\nShape: {self._df.shape}\n{info_str}\nSample rows:\n{sample}"
95
+
96
+ try:
97
+ conn = sqlite3.connect(DB_PATH)
98
+ cursor = conn.cursor()
99
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
100
+ tables = [row[0] for row in cursor.fetchall()]
101
+ db_lines = ["\n=== SQLite database (accessible via sqlite3.connect(db_path)) ==="]
102
+ for table in tables:
103
+ cursor.execute(f"PRAGMA table_info({table})")
104
+ cols = [(row[1], row[2]) for row in cursor.fetchall()]
105
+ cursor.execute(f"SELECT COUNT(*) FROM {table}")
106
+ count = cursor.fetchone()[0]
107
+ col_str = ", ".join(f"{c} ({t})" for c, t in cols)
108
+ db_lines.append(f" Table '{table}' ({count} rows): {col_str}")
109
+ conn.close()
110
+ db_section = "\n".join(db_lines)
111
+ except Exception:
112
+ db_section = "\n=== SQLite database: schema unavailable ==="
113
+
114
+ return f"{df_section}\n{db_section}"
115
 
116
  def reset(
117
  self,