Spaces:
Sleeping
Sleeping
File size: 14,144 Bytes
4c1a85d c3002ad 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba b99e42b 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d c3002ad 4c1a85d c3002ad 4c1a85d c3002ad 4c1a85d c3002ad 4c1a85d c3002ad 4c1a85d c3002ad 4c1a85d c3002ad 4c1a85d c3002ad cd11aba 4c1a85d cd11aba 4c1a85d cd11aba c3002ad cd11aba c3002ad cd11aba c3002ad cd11aba 4c1a85d cd11aba 4c1a85d c3002ad 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d cd11aba 4c1a85d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 | #!/usr/bin/env python3
"""
DataQA Inference Script — Two-Phase Agent
------------------------------------------
LLM agent that plays the DataQA environment in two phases:
Phase 1: Identify all data quality issues
Phase 2: Propose fixes for identified issues
Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
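Environment variables:
API_BASE_URL - LLM API endpoint (default: https://router.huggingface.co/v1)
MODEL_NAME - Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
HF_TOKEN - HuggingFace token / API key (API_KEY is accepted as a fallback)
ENV_URL - DataQA environment base URL (default: http://localhost:8000)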
STDOUT FORMAT (mandatory for evaluation):
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
from __future__ import annotations
import os
import re
import sys
import time
from typing import List, Optional
import requests
from openai import OpenAI
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint, model, credentials and environment location are read from the
# process environment; the second argument of each lookup is the fallback.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN wins when set (and non-empty); API_KEY is the fallback. May be None.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")

BENCHMARK = "dataqa_env"
# Task ids, in the order main() runs them.
TASKS = [
    "easy",
    "medium",
    "hard",
    "alignment",
    "moderation",
]
# Upper bound on environment steps per task (run_task may stop earlier on done).
MAX_STEPS_PER_TASK = 3
# ---------------------------------------------------------------------------
# Logging helpers (structured stdout — exact format required by evaluation)
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens a task's log on stdout."""
    fields = ("[START]", f"task={task}", f"env={env}", f"model={model}")
    print(" ".join(fields), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line on stdout.

    The evaluation format requires exactly one line per step, so any line
    breaks inside ``action`` or ``error`` are collapsed to single spaces.

    Args:
        step: 1-based step number.
        action: Human-readable summary of the submitted action.
        reward: Reward for this step, printed with two decimals.
        done: Whether the environment reported the episode as finished.
        error: Error text for this step; a falsy value prints as ``null``.
    """

    def _one_line(text: str) -> str:
        # splitlines() handles \n, \r\n, \r and the other Unicode line breaks.
        return " ".join(str(text).splitlines())

    error_val = _one_line(error) if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={_one_line(action)} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line for a task on stdout.

    ``score`` is shown with three decimals and each reward with two; rewards
    are comma-separated and the field is empty when no step ran.
    """
    joined_rewards = ",".join(map("{:.2f}".format, rewards))
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={joined_rewards}"
    )
    print(summary, flush=True)
# ---------------------------------------------------------------------------
# Environment HTTP client
# ---------------------------------------------------------------------------
class EnvHTTPClient:
    """Minimal HTTP client for the DataQA environment.

    All requests go through one ``requests.Session`` so connections are
    reused. Call :meth:`close`, or use the client as a context manager, to
    release the session.

    Args:
        base_url: Environment root URL; a trailing ``/`` is ignored.
        timeout: Seconds to wait on ``/reset`` and ``/step`` requests.
        health_timeout: Seconds to wait on the ``/health`` probe.
    """

    def __init__(self, base_url: str, *, timeout: float = 30, health_timeout: float = 10):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.health_timeout = health_timeout
        self.session = requests.Session()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self) -> "EnvHTTPClient":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def health(self) -> bool:
        """Return True if ``GET /health`` answers with HTTP 200.

        This is a best-effort probe: any error (connection refused, timeout,
        ...) is reported on stderr and yields False instead of raising.
        """
        try:
            r = self.session.get(f"{self.base_url}/health", timeout=self.health_timeout)
        except Exception as exc:  # deliberate: the probe must never raise
            print(f"[DEBUG] Health check failed: {exc}", file=sys.stderr, flush=True)
            return False
        return r.status_code == 200

    def reset(self, task_id: str = "easy") -> dict:
        """POST ``/reset`` for *task_id* and return the decoded JSON body.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        r = self.session.post(
            f"{self.base_url}/reset",
            json={"task_id": task_id},
            timeout=self.timeout,
        )
        r.raise_for_status()
        return r.json()

    def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
        """POST one action (issues + fixes) to ``/step`` and return the JSON body.

        The payload is ``{"action": {"issues": ..., "fixes": ..., "task_id": ...}}``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        r = self.session.post(
            f"{self.base_url}/step",
            json={"action": {"issues": issues, "fixes": fixes, "task_id": task_id}},
            timeout=self.timeout,
        )
        r.raise_for_status()
        return r.json()
# ---------------------------------------------------------------------------
# LLM Prompts
# ---------------------------------------------------------------------------
# Phase-1 system prompt: asks the model to list every data quality issue, one
# per line, as "row:<row_number>,col:<column_name>,issue:<issue_type>" — the
# shape parse_llm_response() extracts from the reply. Row numbers are the
# 1-based position of the data row (header excluded), not an ID column value.
# NOTE(review): the issue-type names below presumably must match the
# environment's grading taxonomy — confirm before renaming any of them.
IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
You will be given:
1. A dataset in CSV format
2. A schema describing expected column types and constraints
3. Validation rules that the data should satisfy
You must identify ALL data quality issues and report each one in EXACTLY this format:
row:<row_number>,col:<column_name>,issue:<issue_type>
Supported issue types:
- missing_value (null, empty, or whitespace-only)
- wrong_type (value doesn't match expected type)
- duplicate_row (exact duplicate or duplicate key)
- out_of_range (value outside valid range)
- format_violation (wrong format, invalid enum value)
- inconsistent_value (computed field doesn't match, logical inconsistency)
- statistical_outlier (value is unreasonable given context)
- referential_integrity (foreign key violation)
CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
- Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
- Row 1 = the FIRST data row after the header
- Row 2 = the SECOND data row after the header
- DO NOT use the employee_id, order_id, or experiment_id as the row number
- Column names must match exactly (use the CSV header names, lowercase)
- Check EVERY row and EVERY column systematically
- Consider cross-column consistency (e.g., total = quantity * price)
- Look for subtle issues like whitespace-only values, near-duplicates
- Report ALL issues you find, even if uncertain
Respond with ONLY the list of issues, one per line. No other text.
Example: row:3,col:salary,issue:missing_value"""
# Phase-2 system prompt: asks the model to propose a corrected value per issue,
# one per line, as "row:<row_number>,col:<column_name>,fix:<corrected_value>" —
# the shape parse_fix_response() extracts. The guideline list covers every
# issue type named in IDENTIFY_SYSTEM_PROMPT. Arrows are plain ASCII ("->") so
# the text the model receives cannot be mangled by encoding issues.
FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
For each issue you identified, propose a fix in EXACTLY this format:
row:<row_number>,col:<column_name>,fix:<corrected_value>
Guidelines for proposing fixes:
- For missing_value: infer the correct value from context, schema, and other rows
- For wrong_type: convert to the correct type (e.g., "seventy-five thousand" -> "75000")
- For out_of_range: propose a value within the valid range that makes sense in context
- For format_violation: correct the format (e.g., "26/01/2024" -> "2024-01-26")
- For inconsistent_value: compute the correct value from related fields
- For duplicate_row: propose a corrected unique key or indicate removal
- For statistical_outlier: propose a reasonable value given the model/context
- For referential_integrity: replace the value with a key that exists in the referenced data
Use the schema, validation rules, and surrounding data to determine the correct fix.
Respond with ONLY the list of fixes, one per line. No other text.
Example: row:3,col:salary,fix:75000"""
def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
    """Render an environment observation as the user message for the LLM.

    Args:
        observation: Observation dict from the environment. Keys read:
            ``task_description``, ``schema_description``, ``validation_rules``,
            ``dataset_csv``, ``num_issues_hint`` and ``feedback``. Missing keys
            and ``None`` values are treated as empty; a ``None`` observation is
            treated as an empty dict.
        include_fixes: When true, append the instruction asking for fixes in
            ``row:<N>,col:<name>,fix:<value>`` form.

    Returns:
        The prompt sections joined by blank lines. SCHEMA, VALIDATION RULES and
        DATASET are always present; TASK, HINT and FEEDBACK only when non-empty.
    """
    obs = observation or {}
    parts = []
    task_description = obs.get("task_description")
    if task_description:
        parts.append(f"TASK: {task_description}")
    # `or ""` keeps an explicit None from being rendered as the text "None".
    parts.append(f"SCHEMA:\n{obs.get('schema_description') or ''}")
    parts.append(f"VALIDATION RULES:\n{obs.get('validation_rules') or ''}")
    parts.append(f"DATASET:\n{obs.get('dataset_csv') or ''}")
    hint = obs.get("num_issues_hint") or 0
    if hint:
        parts.append(f"HINT: There are exactly {hint} issues to find.")
    feedback = obs.get("feedback") or ""
    # Feedback mentioning "reset" is skipped — presumably the environment's
    # post-reset message rather than grading feedback (TODO confirm in env).
    if feedback and "reset" not in feedback.lower():
        parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
    if include_fixes:
        parts.append(
            "Now propose fixes for ALL issues. "
            "Use format: row:<N>,col:<name>,fix:<corrected_value>"
        )
    return "\n\n".join(parts)
def parse_llm_response(response: str) -> list[str]:
    """Pull normalized issue records out of free-form LLM text.

    Each line may carry list decoration ("1.", "2)", "-", "*"), which is
    dropped. Lines that match the ``row:<n>,col:<name>,issue:<type>`` pattern
    (case-insensitive; ``column`` is accepted for ``col``, and ``:``/``=`` and
    ``,``/``;``/whitespace separators are tolerated) are kept. They come back
    as ``row:<n>,col:<name>,issue:<type>`` with the column and type lowercased.
    Everything else is ignored.
    """
    issue_pattern = re.compile(
        r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
        re.IGNORECASE,
    )
    found: list[str] = []
    for raw in response.strip().split("\n"):
        text = raw.strip()
        if not text:
            continue
        text = re.sub(r"^\s*[\d]+[.\)]\s*", "", text)
        text = re.sub(r"^\s*[-*]\s*", "", text).strip()
        lowered = text.lower()
        # Cheap pre-filter before running the full pattern.
        if "row" not in lowered or "col" not in lowered:
            continue
        hit = issue_pattern.search(text)
        if hit is None:
            continue
        row, column, kind = hit.groups()
        found.append(f"row:{row},col:{column.lower()},issue:{kind.lower()}")
    return found
def parse_fix_response(response: str) -> list[str]:
    """Pull normalized fix records out of free-form LLM text.

    List decoration ("1.", "2)", "-", "*") is dropped from each line. Lines
    matching ``row:<n>,col:<name>,fix:<value>`` (case-insensitive, ``column``
    accepted for ``col``) become ``row:<n>,col:<name>,fix:<value>``. The column
    is lowercased and the value is everything after ``fix:`` to the end of the
    line, stripped, so it may contain commas or spaces.
    """
    fix_pattern = re.compile(
        r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
        re.IGNORECASE,
    )

    def _undecorate(raw: str) -> str:
        # Remove a leading "N." / "N)" marker, then a leading bullet.
        without_number = re.sub(r"^\s*[\d]+[.\)]\s*", "", raw.strip())
        return re.sub(r"^\s*[-*]\s*", "", without_number).strip()

    result = []
    for candidate in map(_undecorate, response.strip().split("\n")):
        lowered = candidate.lower()
        # Blank lines fail this check too, so they need no separate test.
        if "row" in lowered and "fix" in lowered:
            if (m := fix_pattern.search(candidate)) is not None:
                row, column, value = m.groups()
                result.append(f"row:{row},col:{column.lower()},fix:{value.strip()}")
    return result
def call_llm(
    client: OpenAI,
    system_prompt: str,
    user_prompt: str,
    *,
    model: Optional[str] = None,
    max_attempts: int = 3,
    base_wait: float = 10,
) -> str:
    """Send one system+user chat request and return the reply text.

    Rate-limit errors (recognized by "rate_limit" or "429" in the error text)
    are retried with linear backoff: ``base_wait * attempt`` seconds between
    attempts. No sleep happens after the final attempt. Any other error is
    logged to stderr and ends the call immediately.

    Args:
        client: OpenAI-compatible client.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Model id; defaults to the module-level ``MODEL_NAME``.
        max_attempts: Total number of requests before giving up.
        base_wait: Backoff unit in seconds.

    Returns:
        The reply content, or ``""`` when the model returned no content or no
        attempt succeeded.
    """
    model_name = model if model is not None else MODEL_NAME
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.1,
                max_tokens=2048,
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            is_rate_limit = "rate_limit" in str(e).lower() or "429" in str(e)
            if not is_rate_limit:
                print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
                return ""
            if attempt == max_attempts:
                # Out of attempts: give up now rather than sleeping for nothing.
                print(
                    f"[DEBUG] Rate limited, giving up after {max_attempts} attempts",
                    file=sys.stderr,
                    flush=True,
                )
                break
            wait = base_wait * attempt
            print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
            time.sleep(wait)
    return ""
def run_task(
    client: OpenAI,
    env: EnvHTTPClient,
    task_id: str,
    *,
    success_threshold: float = 0.5,
) -> float:
    """Play one DataQA task and return the best reward achieved.

    Runs up to ``MAX_STEPS_PER_TASK`` steps. Each step asks the LLM to
    identify issues from the current observation. After the first submission
    that observation is the one returned by ``/step``, so it may include
    feedback on the previous attempt.

    * Step 1: identify issues only.
    * Steps 2+: identify issues, then (if any were parsed) ask for fixes.

    The loop stops early when the environment reports ``done``. Structured
    ``[START]``/``[STEP]``/``[END]`` lines go to stdout, and ``[END]`` is always
    printed, even when a request raises.

    Args:
        client: OpenAI-compatible client used for LLM calls.
        env: HTTP client for the DataQA environment.
        task_id: Task to reset and play.
        success_threshold: Minimum best reward reported as ``success=true``.

    Returns:
        The highest per-step reward observed (0.0 if no step completed).

    Raises:
        Any exception from ``env.reset``/``env.step`` propagates after the
        ``[END]`` line has been printed.
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    rewards: List[float] = []
    steps_taken = 0
    best_score = 0.0
    try:
        reset_response = env.reset(task_id=task_id)
        # Accept both an {"observation": {...}} envelope and a bare observation.
        observation = reset_response.get("observation", reset_response)
        for step_num in range(1, MAX_STEPS_PER_TASK + 1):
            # Phase 1: identify issues.
            identify_prompt = build_user_prompt(observation)
            identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, identify_prompt)
            issues = parse_llm_response(identify_output)
            error_msg = None if issues else "no issues parsed from LLM response"

            # Phase 2: propose fixes, only from step 2 onward and only when
            # phase 1 produced issues to fix.
            fixes: list[str] = []
            if issues and step_num >= 2:
                fix_prompt = build_user_prompt(observation, include_fixes=True)
                fix_prompt += "\n\nISSUES FOUND:\n" + "\n".join(issues)
                fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
                fixes = parse_fix_response(fix_output)

            # Submit. The logged action is truncated to keep [STEP] lines short;
            # the full lists are sent to the environment.
            action_str = ";".join(issues[:5]) if issues else "none"
            if fixes:
                action_str += "|fixes:" + ";".join(fixes[:3])
            step_response = env.step(issues, fixes, task_id=task_id)
            observation = step_response.get("observation", step_response)
            # A missing or null reward counts as 0.0.
            reward = float(step_response.get("reward", 0.0) or 0.0)
            done = bool(step_response.get("done", False))
            best_score = max(best_score, reward)
            rewards.append(reward)
            steps_taken = step_num
            log_step(
                step=step_num,
                action=action_str,
                reward=reward,
                done=done,
                error=error_msg,
            )
            if done:
                break
    finally:
        # Success is derived here so that an exception on a later step cannot
        # hide a success already reached on an earlier one.
        success = best_score >= success_threshold
        log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
    return best_score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Script entry point: check the environment, run every task, report scores.

    Diagnostics are written to stderr so that stdout carries only the
    structured [START]/[STEP]/[END] lines. Exits with status 1 when the
    environment health check fails. A task that raises is logged and scored 0.0.
    """
    print("[DEBUG] DataQA Inference starting", file=sys.stderr, flush=True)
    settings = (
        ("ENV_URL", ENV_URL),
        ("API_BASE_URL", API_BASE_URL),
        ("MODEL_NAME", MODEL_NAME),
    )
    for label, value in settings:
        print(f"[DEBUG] {label}={value}", file=sys.stderr, flush=True)

    env = EnvHTTPClient(ENV_URL)
    # The SDK needs some api_key value even when the endpoint ignores it.
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "no-key")

    healthy = env.health()
    if not healthy:
        print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
        sys.exit(1)
    print("[DEBUG] Environment is healthy", file=sys.stderr, flush=True)

    scores: dict[str, float] = {}
    for task_id in TASKS:
        try:
            scores[task_id] = run_task(llm_client, env, task_id)
        except Exception as exc:
            print(f"[DEBUG] Task {task_id} failed: {exc}", file=sys.stderr, flush=True)
            scores[task_id] = 0.0

    avg_score = 0.0
    if scores:
        avg_score = sum(scores.values()) / len(scores)
    print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
|