Spaces:
Sleeping
Sleeping
Commit ·
7b9dfc1
1
Parent(s): aca1396
fixed issues in SQL schema not available to agent and code parsing issues
Browse files- helpers/__init__.py +0 -0
- helpers/response_parser.py +211 -0
- inference.py +9 -106
- server/data_analysis_env.py +27 -2
helpers/__init__.py
ADDED
|
File without changes
|
helpers/response_parser.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})
|
| 6 |
+
# ── Layer 1: Sanitize special characters inside string values ──────────────────
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _sanitize_string_value(match: re.Match) -> str:
|
| 10 |
+
"""
|
| 11 |
+
Receives a regex match of ("key": "value") and cleans only the value part.
|
| 12 |
+
Escapes unescaped newlines, tabs, carriage returns, and inner double quotes.
|
| 13 |
+
This is the core trick LangChain uses in _replace_new_line / _custom_parser.
|
| 14 |
+
"""
|
| 15 |
+
opening = match.group(1) # e.g. "code": "
|
| 16 |
+
value = match.group(2) # raw value content (may span multiple lines)
|
| 17 |
+
closing = match.group(3) # closing "
|
| 18 |
+
|
| 19 |
+
value = re.sub(r"\n", r"\\n", value)
|
| 20 |
+
value = re.sub(r"\r", r"\\r", value)
|
| 21 |
+
value = re.sub(r"\t", r"\\t", value)
|
| 22 |
+
value = re.sub(r'(?<!\\)"', r'\\"', value) # escape unescaped inner quotes
|
| 23 |
+
|
| 24 |
+
return opening + value + closing
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _sanitize_all_string_values(text: str) -> str:
|
| 28 |
+
"""
|
| 29 |
+
Apply _sanitize_string_value to every JSON string value in the text.
|
| 30 |
+
Uses re.DOTALL so values that span multiple lines are handled correctly.
|
| 31 |
+
Generalised version of LangChain's _custom_parser (which only targeted action_input).
|
| 32 |
+
"""
|
| 33 |
+
return re.sub(
|
| 34 |
+
r'("[\w]+"\s*:\s*")(.*?)(")', # ("key": ") VALUE (")
|
| 35 |
+
_sanitize_string_value,
|
| 36 |
+
text,
|
| 37 |
+
flags=re.DOTALL,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ── Layer 2: Pre-parse text fixes ─────────────────────────────────────────────
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _preprocess(text: str) -> str:
|
| 45 |
+
"""Fix common LLM response quirks before attempting JSON parsing."""
|
| 46 |
+
|
| 47 |
+
# Strip markdown code fences (```json ... ``` or ``` ... ```)
|
| 48 |
+
# LangChain uses a regex for this: _json_markdown_re
|
| 49 |
+
match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
|
| 50 |
+
if match:
|
| 51 |
+
text = match.group(1).strip()
|
| 52 |
+
|
| 53 |
+
# Double curly braces {{"k": "v"}} → {"k": "v"}
|
| 54 |
+
text = text.replace("{{", "{").replace("}}", "}")
|
| 55 |
+
|
| 56 |
+
# Python literals → JSON literals
|
| 57 |
+
text = re.sub(r"\bTrue\b", "true", text)
|
| 58 |
+
text = re.sub(r"\bFalse\b", "false", text)
|
| 59 |
+
text = re.sub(r"\bNone\b", "null", text)
|
| 60 |
+
|
| 61 |
+
# Trailing commas before } or ]
|
| 62 |
+
text = re.sub(r",\s*([}\]])", r"\1", text)
|
| 63 |
+
|
| 64 |
+
# Outer single-quote wrap '{"k": "v"}' → {"k": "v"}
|
| 65 |
+
if text.startswith("'") and text.endswith("'"):
|
| 66 |
+
text = text[1:-1].replace("\\'", "'")
|
| 67 |
+
|
| 68 |
+
return text.strip()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ── Layer 3: Extract first JSON blob from surrounding prose ───────────────────
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _extract_json_blob(text: str) -> str:
|
| 75 |
+
"""
|
| 76 |
+
Pull out the first {...} or [...] blob from text that has prose around it.
|
| 77 |
+
Inspired by LangChain's _json_markdown_re fallback in parse_json_markdown.
|
| 78 |
+
"""
|
| 79 |
+
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
|
| 80 |
+
return match.group(1) if match else text
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ── Layer 4: parse_partial_json — LangChain's stack-based closer ──────────────
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _parse_partial_json(s: str) -> Any:
|
| 87 |
+
"""
|
| 88 |
+
Parse JSON that may be truncated / missing closing brackets.
|
| 89 |
+
Adapted from LangChain's parse_partial_json (originally from open-interpreter).
|
| 90 |
+
Uses a stack to track open containers and closes them before parsing.
|
| 91 |
+
"""
|
| 92 |
+
s = s.strip()
|
| 93 |
+
|
| 94 |
+
# Try the string as-is first
|
| 95 |
+
try:
|
| 96 |
+
return json.loads(s)
|
| 97 |
+
except json.JSONDecodeError:
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
# Walk through and auto-close any unclosed {, [, or "
|
| 101 |
+
stack = []
|
| 102 |
+
is_inside = False # inside a string?
|
| 103 |
+
position = 0
|
| 104 |
+
|
| 105 |
+
for i, char in enumerate(s):
|
| 106 |
+
if is_inside:
|
| 107 |
+
if char == '"' and s[i - 1] != "\\":
|
| 108 |
+
is_inside = False
|
| 109 |
+
else:
|
| 110 |
+
if char == '"':
|
| 111 |
+
is_inside = True
|
| 112 |
+
stack.append('"')
|
| 113 |
+
elif char in "{[":
|
| 114 |
+
stack.append(char)
|
| 115 |
+
elif char in "}]":
|
| 116 |
+
if stack and stack[-1] in "{[":
|
| 117 |
+
stack.pop()
|
| 118 |
+
position = i
|
| 119 |
+
|
| 120 |
+
# Close open containers in reverse order
|
| 121 |
+
completed = s[: position + 1]
|
| 122 |
+
for bracket in reversed(stack):
|
| 123 |
+
if bracket == '"':
|
| 124 |
+
completed += '"'
|
| 125 |
+
elif bracket == "{":
|
| 126 |
+
completed += "}"
|
| 127 |
+
elif bracket == "[":
|
| 128 |
+
completed += "]"
|
| 129 |
+
|
| 130 |
+
return json.loads(completed)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ── Layer 5: Direct greedy extraction — last resort for unescaped inner quotes ──
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _extract_fields_direct(text: str) -> dict:
|
| 137 |
+
"""Extract action fields using greedy regex anchored to the last closing quote.
|
| 138 |
+
|
| 139 |
+
Handles the case where the model emits unescaped double-quote characters inside
|
| 140 |
+
a "code" or "answer" value (e.g. df["col"]). The non-greedy `(.*?)` in
|
| 141 |
+
_sanitize_all_string_values stops at the *first* inner quote and corrupts the
|
| 142 |
+
output. By using a greedy `(.*)` anchored with a lookahead for the last `"}`
|
| 143 |
+
boundary we capture the full value regardless of inner quotes.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
text: Pre-processed JSON-like string.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
Dict with 'action' and 'code'/'answer' keys.
|
| 150 |
+
|
| 151 |
+
Raises:
|
| 152 |
+
ValueError: If the action field cannot be found or the value cannot be
|
| 153 |
+
extracted for the detected action type.
|
| 154 |
+
"""
|
| 155 |
+
action_match = re.search(r'"action"\s*:\s*"(\w+)"', text)
|
| 156 |
+
if not action_match:
|
| 157 |
+
raise ValueError("No 'action' field found")
|
| 158 |
+
action_type = action_match.group(1)
|
| 159 |
+
|
| 160 |
+
if action_type == "execute_code":
|
| 161 |
+
m = re.search(r'"code"\s*:\s*"(.*)"(?=\s*})', text, re.DOTALL)
|
| 162 |
+
if m:
|
| 163 |
+
return {"action": "execute_code", "code": m.group(1)}
|
| 164 |
+
elif action_type == "submit_answer":
|
| 165 |
+
m = re.search(r'"answer"\s*:\s*"(.*)"(?=\s*})', text, re.DOTALL)
|
| 166 |
+
if m:
|
| 167 |
+
return {"action": "submit_answer", "answer": m.group(1)}
|
| 168 |
+
|
| 169 |
+
raise ValueError(f"Could not extract value for action_type={action_type!r}")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ── Public API ─────────────────────────────────────────────────────────────────
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def parse_model_action(response_text: str) -> dict:
|
| 176 |
+
"""
|
| 177 |
+
Parse a raw LLM response into an action dict.
|
| 178 |
+
|
| 179 |
+
Pipeline (mirrors LangChain's JsonOutputParser internals):
|
| 180 |
+
1. _preprocess – fix markdown fences, double braces, Python literals …
|
| 181 |
+
2. _sanitize_all_string_values – escape unescaped quotes/newlines inside values
|
| 182 |
+
3. _extract_json_blob – strip surrounding prose
|
| 183 |
+
4. _parse_partial_json – close truncated JSON with a stack algorithm
|
| 184 |
+
|
| 185 |
+
Each strategy is tried independently so a failure in one doesn't block others.
|
| 186 |
+
"""
|
| 187 |
+
text = response_text.strip()
|
| 188 |
+
|
| 189 |
+
strategies = [
|
| 190 |
+
lambda t: _parse_partial_json(t),
|
| 191 |
+
# (preprocessed, sanitized, as-is)
|
| 192 |
+
lambda t: _parse_partial_json(_sanitize_all_string_values(_preprocess(t))),
|
| 193 |
+
# (extract blob first, then preprocess + sanitize)
|
| 194 |
+
lambda t: _parse_partial_json(_sanitize_all_string_values(_preprocess(_extract_json_blob(t)))),
|
| 195 |
+
# (preprocess + extract blob, then sanitize)
|
| 196 |
+
lambda t: _parse_partial_json(_sanitize_all_string_values(_extract_json_blob(_preprocess(t)))),
|
| 197 |
+
# (sanitize raw text, skip preprocess — rare fallback)
|
| 198 |
+
lambda t: _parse_partial_json(_sanitize_all_string_values(t)),
|
| 199 |
+
# greedy extraction — handles unescaped inner quotes in code/answer values
|
| 200 |
+
lambda t: _extract_fields_direct(_preprocess(_extract_json_blob(t))),
|
| 201 |
+
lambda t: _extract_fields_direct(_extract_json_blob(t)),
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
for strategy in strategies:
|
| 205 |
+
try:
|
| 206 |
+
return strategy(text)
|
| 207 |
+
except (json.JSONDecodeError, ValueError):
|
| 208 |
+
continue
|
| 209 |
+
|
| 210 |
+
print(f"JSON Decoding Error while parsing action in response text: {response_text}")
|
| 211 |
+
return json.loads(FALLBACK_ACTION)
|
inference.py
CHANGED
|
@@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
|
| 7 |
from openai import OpenAI
|
| 8 |
|
| 9 |
from client import DataAnalysisClient
|
|
|
|
| 10 |
from models import DataAction
|
| 11 |
|
| 12 |
load_dotenv()
|
|
@@ -20,13 +21,17 @@ ENV_SERVER_URL = os.getenv("ENV_SERVER_URL") or "https://mohammed-altaf-dataanal
|
|
| 20 |
|
| 21 |
SYSTEM_PROMPT = """
|
| 22 |
<ROLE>
|
| 23 |
-
You are a data analyst. You
|
| 24 |
-
|
|
|
|
| 25 |
</ROLE>
|
| 26 |
|
| 27 |
<RULES>
|
| 28 |
-
- Use `print()` to
|
| 29 |
-
-
|
|
|
|
|
|
|
|
|
|
| 30 |
- When you have the answer, submit it in the exact format requested
|
| 31 |
- Be precise with numbers and formatting
|
| 32 |
</RULES>
|
|
@@ -42,8 +47,6 @@ Respond with ONLY the JSON, no other text.
|
|
| 42 |
</NOTE>
|
| 43 |
"""
|
| 44 |
|
| 45 |
-
FALLBACK_ACTION = json.dumps({"action": "submit_answer", "answer": "unknown"})
|
| 46 |
-
|
| 47 |
|
| 48 |
def log_start(task: str, env: str, model: str) -> None:
|
| 49 |
"""Log the start of a task episode.
|
|
@@ -87,105 +90,6 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
|
|
| 87 |
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
|
| 88 |
|
| 89 |
|
| 90 |
-
def parse_model_action(response_text: str) -> dict:
|
| 91 |
-
"""Parse the model's raw text response into an action dict.
|
| 92 |
-
|
| 93 |
-
Handles multiple LLM response edge cases:
|
| 94 |
-
- Markdown code blocks (```json ... ``` or ``` ... ```)
|
| 95 |
-
- Double curly braces e.g. {{"key": "value"}}
|
| 96 |
-
- Single quotes instead of double quotes e.g. {'key': 'value'}
|
| 97 |
-
- Python literals: True/False/None → true/false/null
|
| 98 |
-
- Trailing commas in objects/arrays e.g. {"key": "value",}
|
| 99 |
-
- Extra text/prose before or after the JSON blob
|
| 100 |
-
- Escaped single quotes inside single-quoted strings
|
| 101 |
-
- Whitespace and newline noise
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
response_text: Raw string returned by the model.
|
| 105 |
-
|
| 106 |
-
Returns:
|
| 107 |
-
Parsed action dict, or a fallback submit_answer on failure.
|
| 108 |
-
"""
|
| 109 |
-
|
| 110 |
-
def attempt_parse(text: str) -> dict:
|
| 111 |
-
return json.loads(text)
|
| 112 |
-
|
| 113 |
-
def apply_fixes(text: str) -> str:
|
| 114 |
-
# Strip markdown code blocks
|
| 115 |
-
if text.startswith("```"):
|
| 116 |
-
parts = text.split("```")
|
| 117 |
-
if len(parts) >= 2:
|
| 118 |
-
text = parts[1]
|
| 119 |
-
if text.startswith("json"):
|
| 120 |
-
text = text[4:]
|
| 121 |
-
text = text.strip()
|
| 122 |
-
|
| 123 |
-
# Double curly braces → single
|
| 124 |
-
text = text.replace("{{", "{").replace("}}", "}")
|
| 125 |
-
|
| 126 |
-
# Python literals → JSON literals
|
| 127 |
-
text = re.sub(r"\bTrue\b", "true", text)
|
| 128 |
-
text = re.sub(r"\bFalse\b", "false", text)
|
| 129 |
-
text = re.sub(r"\bNone\b", "null", text)
|
| 130 |
-
|
| 131 |
-
# Trailing commas before } or ]
|
| 132 |
-
text = re.sub(r",\s*([}\]])", r"\1", text)
|
| 133 |
-
|
| 134 |
-
# Single quote handling — two distinct cases:
|
| 135 |
-
#
|
| 136 |
-
# Case 1: Entire JSON is wrapped in outer single quotes
|
| 137 |
-
# e.g. '{"action": "x", "code": "df[\'col\']"}'
|
| 138 |
-
# → strip the outer quotes and unescape internal \'
|
| 139 |
-
if text.startswith("'") and text.endswith("'"):
|
| 140 |
-
text = text[1:-1].replace("\\'", "'")
|
| 141 |
-
|
| 142 |
-
# Case 2: JSON itself uses single quotes as delimiters
|
| 143 |
-
# e.g. {'action': 'execute_code', 'code': 'print()'}
|
| 144 |
-
# → only apply when structure looks single-quote delimited
|
| 145 |
-
# → avoids corrupting double-quoted values that contain bracket notation
|
| 146 |
-
elif text.startswith("{'") or ("': " in text and '": ' not in text):
|
| 147 |
-
text = re.sub(
|
| 148 |
-
r"'((?:\\'|[^'])*)'", lambda m: '"' + m.group(1).replace("\\'", "'").replace('"', '\\"') + '"', text
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
return text
|
| 152 |
-
|
| 153 |
-
def extract_json_blob(text: str) -> str:
|
| 154 |
-
"""Extract the first {...} or [...] blob from text with surrounding prose."""
|
| 155 |
-
match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
|
| 156 |
-
if match:
|
| 157 |
-
return match.group(1)
|
| 158 |
-
return text
|
| 159 |
-
|
| 160 |
-
text = response_text.strip()
|
| 161 |
-
|
| 162 |
-
try:
|
| 163 |
-
return attempt_parse(text)
|
| 164 |
-
except json.JSONDecodeError:
|
| 165 |
-
pass
|
| 166 |
-
|
| 167 |
-
try:
|
| 168 |
-
return attempt_parse(apply_fixes(text))
|
| 169 |
-
except json.JSONDecodeError:
|
| 170 |
-
pass
|
| 171 |
-
|
| 172 |
-
try:
|
| 173 |
-
blob = extract_json_blob(text)
|
| 174 |
-
return attempt_parse(apply_fixes(blob))
|
| 175 |
-
except json.JSONDecodeError:
|
| 176 |
-
pass
|
| 177 |
-
|
| 178 |
-
try:
|
| 179 |
-
fixed = apply_fixes(text)
|
| 180 |
-
blob = extract_json_blob(fixed)
|
| 181 |
-
return attempt_parse(blob)
|
| 182 |
-
except json.JSONDecodeError:
|
| 183 |
-
pass
|
| 184 |
-
|
| 185 |
-
print(f"JSON Decoding Error while parsing action in response text: {response_text}")
|
| 186 |
-
return json.loads(FALLBACK_ACTION)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
|
| 190 |
"""Run a single task episode using the language model as the agent.
|
| 191 |
|
|
@@ -234,7 +138,6 @@ def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
|
|
| 234 |
except Exception as exc:
|
| 235 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 236 |
response_text = FALLBACK_ACTION
|
| 237 |
-
|
| 238 |
action = parse_model_action(response_text)
|
| 239 |
action_type = action.get("action", "")
|
| 240 |
|
|
|
|
| 7 |
from openai import OpenAI
|
| 8 |
|
| 9 |
from client import DataAnalysisClient
|
| 10 |
+
from helpers.response_parser import FALLBACK_ACTION, parse_model_action
|
| 11 |
from models import DataAction
|
| 12 |
|
| 13 |
load_dotenv()
|
|
|
|
| 21 |
|
| 22 |
SYSTEM_PROMPT = """
|
| 23 |
<ROLE>
|
| 24 |
+
You are a data analyst. You have two data sources available:
|
| 25 |
+
1. `df` — a pandas DataFrame (sales CSV, pre-loaded)
|
| 26 |
+
2. A SQLite database at `db_path` — contains additional tables (e.g. customer_profiles, product_catalog)
|
| 27 |
</ROLE>
|
| 28 |
|
| 29 |
<RULES>
|
| 30 |
+
- Use `print()` to output results
|
| 31 |
+
- `pd`, `np`, `sqlite3`, and `db_path` are already in scope — NEVER use import statements (they will fail)
|
| 32 |
+
- `df` is a pandas DataFrame — use pandas operations on it, NEVER SQL
|
| 33 |
+
- To query the SQLite database use: `conn = sqlite3.connect(db_path)` then `pd.read_sql(query, conn)`
|
| 34 |
+
- For cross-source tasks: query SQLite for the extra data, then merge with `df` using pandas
|
| 35 |
- When you have the answer, submit it in the exact format requested
|
| 36 |
- Be precise with numbers and formatting
|
| 37 |
</RULES>
|
|
|
|
| 47 |
</NOTE>
|
| 48 |
"""
|
| 49 |
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def log_start(task: str, env: str, model: str) -> None:
|
| 52 |
"""Log the start of a task episode.
|
|
|
|
| 90 |
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def run_task(openai_client: OpenAI, env_client: Any, task_id: int) -> float:
|
| 94 |
"""Run a single task episode using the language model as the agent.
|
| 95 |
|
|
|
|
| 138 |
except Exception as exc:
|
| 139 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 140 |
response_text = FALLBACK_ACTION
|
|
|
|
| 141 |
action = parse_model_action(response_text)
|
| 142 |
action_type = action.get("action", "")
|
| 143 |
|
server/data_analysis_env.py
CHANGED
|
@@ -35,6 +35,7 @@ class DataAnalysisEnv(Environment):
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
MAX_STEPS = 20
|
|
|
|
| 38 |
|
| 39 |
def __init__(self):
|
| 40 |
"""Initialize the environment with default state."""
|
|
@@ -79,14 +80,38 @@ class DataAnalysisEnv(Environment):
|
|
| 79 |
def _dataset_info(self) -> str:
|
| 80 |
"""Generate a summary of the dataset schema for the agent.
|
| 81 |
|
|
|
|
|
|
|
|
|
|
| 82 |
Returns:
|
| 83 |
-
A string describing column names, dtypes, row count,
|
|
|
|
| 84 |
"""
|
| 85 |
buf = io.StringIO()
|
| 86 |
self._df.info(buf=buf)
|
| 87 |
info_str = buf.getvalue()
|
| 88 |
sample = self._df.head(3).to_string()
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
def reset(
|
| 92 |
self,
|
|
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
MAX_STEPS = 20
|
| 38 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 39 |
|
| 40 |
def __init__(self):
|
| 41 |
"""Initialize the environment with default state."""
|
|
|
|
| 80 |
def _dataset_info(self) -> str:
|
| 81 |
"""Generate a summary of the dataset schema for the agent.
|
| 82 |
|
| 83 |
+
Includes the sales DataFrame schema plus the SQLite database table schemas
|
| 84 |
+
so the agent knows what data is available and where to find it.
|
| 85 |
+
|
| 86 |
Returns:
|
| 87 |
+
A string describing column names, dtypes, row count, a sample for df,
|
| 88 |
+
and table schemas for the SQLite database.
|
| 89 |
"""
|
| 90 |
buf = io.StringIO()
|
| 91 |
self._df.info(buf=buf)
|
| 92 |
info_str = buf.getvalue()
|
| 93 |
sample = self._df.head(3).to_string()
|
| 94 |
+
df_section = f"=== df (pandas DataFrame, pre-loaded from sales CSV) ===\nShape: {self._df.shape}\n{info_str}\nSample rows:\n{sample}"
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
conn = sqlite3.connect(DB_PATH)
|
| 98 |
+
cursor = conn.cursor()
|
| 99 |
+
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
| 100 |
+
tables = [row[0] for row in cursor.fetchall()]
|
| 101 |
+
db_lines = ["\n=== SQLite database (accessible via sqlite3.connect(db_path)) ==="]
|
| 102 |
+
for table in tables:
|
| 103 |
+
cursor.execute(f"PRAGMA table_info({table})")
|
| 104 |
+
cols = [(row[1], row[2]) for row in cursor.fetchall()]
|
| 105 |
+
cursor.execute(f"SELECT COUNT(*) FROM {table}")
|
| 106 |
+
count = cursor.fetchone()[0]
|
| 107 |
+
col_str = ", ".join(f"{c} ({t})" for c, t in cols)
|
| 108 |
+
db_lines.append(f" Table '{table}' ({count} rows): {col_str}")
|
| 109 |
+
conn.close()
|
| 110 |
+
db_section = "\n".join(db_lines)
|
| 111 |
+
except Exception:
|
| 112 |
+
db_section = "\n=== SQLite database: schema unavailable ==="
|
| 113 |
+
|
| 114 |
+
return f"{df_section}\n{db_section}"
|
| 115 |
|
| 116 |
def reset(
|
| 117 |
self,
|