Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from env.environment import SQLDebuggerEnvironment | |
| from env.models import ( | |
| Action, ActionType, DifficultyLevel, | |
| BaselineResponse, BaselineResult | |
| ) | |
| # BASELINE AGENT | |
| # Uses rule-based heuristics — no GPT-4 | |
| # Must complete within 60 seconds | |
| # OPENAI_API_KEY must come from environment | |
def _check_api_key():
    """Return the value of the ``OPENAI_API_KEY`` environment variable.

    Returns:
        The key string read from the environment.

    Raises:
        ValueError: if the variable is unset or set to an empty string.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        return api_key
    # Unset and empty are treated the same: both are reported as "not set".
    raise ValueError(
        "OPENAI_API_KEY environment variable is not set. "
        "Please set it before running baseline: "
        "set OPENAI_API_KEY=your-key-here"
    )
def _rule_based_agent(env: SQLDebuggerEnvironment, task: dict) -> tuple[float, int, str]:
    """Run the two-step rule-based baseline against ``env``.

    Step 1 sends an ``IDENTIFY_ERROR`` action built from keyword heuristics.
    Step 2 sends either ``OPTIMIZE_QUERY`` (when the task category is
    ``"performance"``) or ``SUBMIT_ANSWER`` carrying the heuristically
    rewritten query. No API calls are made here.

    Args:
        env: Environment to step. This function does not reset it.
        task: Mapping with a ``"current_context"`` entry. The context keys
            read are ``buggy_query``, ``error_message``, ``error_type_hint``
            and ``category``.

    Returns:
        ``(total_reward, steps_taken, feedback)`` where ``total_reward`` is
        the sum of the per-step ``reward.score`` values and ``feedback`` is
        the ``reward.feedback`` of the last response received.
    """
    context = task.get("current_context") or {}

    def _field(key: str, default: str) -> str:
        # A key that is present but None must behave like a missing key;
        # otherwise the slicing / .upper() calls below raise and the caller
        # records the whole task as an error.
        value = context.get(key)
        return default if value is None else value

    buggy_query = _field("buggy_query", "")
    error_msg = _field("error_message", "")
    error_type = _field("error_type_hint", "syntax")
    category = _field("category", "syntax")

    steps_taken = 0
    total_reward = 0.0

    # -- Step 1: identify the error ------------------------------------
    identify_payload = {
        "error_location": _guess_error_location(buggy_query, error_msg, category),
        "error_type": error_type,
        "explanation": f"Detected {category} issue in query: {error_msg[:100]}",
    }
    action1 = Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload=identify_payload,
    )
    resp1 = env.step(action1)
    total_reward += resp1.reward.score
    steps_taken += 1
    if resp1.done:
        # The environment ended the episode after the first action.
        return total_reward, steps_taken, resp1.reward.feedback

    # -- Step 2: submit the heuristic fix ------------------------------
    fixed_query = _apply_heuristic_fix(buggy_query, category, error_msg)
    explanation = _generate_explanation(buggy_query, fixed_query, category)

    if category == "performance":
        action2 = Action(
            action_type=ActionType.OPTIMIZE_QUERY,
            payload={
                "optimized_query": fixed_query,
                "optimization_type": f"Fix {category} issue: {error_type}",
                "explanation": explanation,
                "root_cause": f"Performance issue detected: {error_msg[:100]}",
                "expected_improvement": "Significant reduction in query execution time",
                "confidence": 0.6,
            },
        )
    else:
        action2 = Action(
            action_type=ActionType.SUBMIT_ANSWER,
            payload={
                "fixed_query": fixed_query,
                "explanation": explanation,
                "error_type": error_type,
                "error_location": identify_payload["error_location"],
                "confidence": 0.6,
            },
        )

    resp2 = env.step(action2)
    total_reward += resp2.reward.score
    steps_taken += 1
    return total_reward, steps_taken, resp2.reward.feedback
def _guess_error_location(query: str, error_msg: str, category: str) -> str:
    """Guess which clause is at fault from keywords in the error message.

    Rules are checked in a fixed order and the first hit wins:
    SELECT/COLUMN, WHERE/FILTER, JOIN/ON, GROUP/HAVING, ORDER. If none
    match, a ``"performance"`` category yields a generic bottleneck label
    and anything else yields ``"Unknown location"``.

    Args:
        query: The buggy query. Currently unused; kept so the signature
            stays stable for callers.
        error_msg: Error text to scan. ``None`` is treated as empty.
        category: Task category, used only for the performance fallback.

    Returns:
        A short human-readable location label.
    """
    import re  # function-scope import, as used elsewhere in this module

    message = (error_msg or "").upper()

    if "SELECT" in message or "COLUMN" in message:
        return "SELECT clause"
    if "WHERE" in message or "FILTER" in message:
        return "WHERE clause"
    # "ON" has to be a whole word: as a plain substring it also matches
    # EXPRESSION, FUNCTION, CONDITION, ... and would shadow the GROUP/ORDER
    # rules below.
    if "JOIN" in message or re.search(r"\bON\b", message):
        return "JOIN condition"
    if "GROUP" in message or "HAVING" in message:
        return "GROUP BY / HAVING clause"
    if "ORDER" in message:
        return "ORDER BY clause"
    if category == "performance":
        return "Query structure — performance bottleneck"
    return "Unknown location"
def _apply_heuristic_fix(query: str, category: str, error_msg: str) -> str:
    """Apply simple, category-specific heuristic rewrites to a buggy query.

    The heuristics are deliberately shallow: the baseline is meant to score
    low-to-medium so the environment leaves room for better agents.

    Args:
        query: The buggy SQL text. ``None`` is treated as empty.
        category: ``"syntax"``, ``"logic"`` or ``"performance"``; any other
            value returns the stripped query unchanged.
        error_msg: Error text used to gate the missing-comma fix. ``None``
            is treated as empty.

    Returns:
        The (possibly rewritten) query with surrounding whitespace stripped.
    """
    import re  # single function-scope import shared by every branch

    q = (query or "").strip()
    message = (error_msg or "").lower()

    if category == "syntax":
        # Missing comma between the first two SELECT columns. The lookaheads
        # stop the rewrite from corrupting valid text such as
        # "SELECT name FROM" or "SELECT DISTINCT name".
        if "syntax error" in message and "name" in message:
            q = re.sub(
                r"\bSELECT\s+(?!(?:DISTINCT|ALL)\b)(\w+)\s+(?!(?:FROM|AS)\b)(\w+)",
                r"SELECT \1, \2",
                q,
                flags=re.IGNORECASE,
            )

        # Missing WHERE: insert exactly one keyword before the earliest
        # "id ="/"name =" predicate so a query with both never ends up with
        # two WHERE keywords.
        if "WHERE" not in q.upper() and "=" in q:
            hits = [pos for pos in (q.find(" id ="), q.find(" name =")) if pos != -1]
            if hits:
                at = min(hits)
                q = f"{q[:at]} WHERE{q[at:]}"

        # Unclosed string literal: an odd number of quotes gets one appended.
        if q.count("'") % 2 != 0:
            q = q + "'"

        # ORDER -> ORDER BY and GROUP -> GROUP BY. The \b after BY keeps
        # identifiers that merely start with "by" (e.g. "bytes") fixable.
        q = re.sub(r"\bORDER\s+(?!BY\b)(\w)", r"ORDER BY \1", q, flags=re.IGNORECASE)
        q = re.sub(r"\bGROUP\s+(?!BY\b)(\w)", r"GROUP BY \1", q, flags=re.IGNORECASE)

    elif category == "logic":
        # INNER JOIN -> LEFT JOIN so unmatched rows are kept; matched
        # case-insensitively so mixed-case input is rewritten too.
        q = re.sub(r"\bINNER\s+JOIN\b", "LEFT JOIN", q, flags=re.IGNORECASE)

        # WHERE <aggregate>( -> HAVING <aggregate>(.
        # NOTE(review): the keyword is swapped in place; the condition is
        # not moved after a following GROUP BY.
        q = re.sub(
            r"WHERE\s+(AVG|SUM|COUNT|MAX|MIN)\s*\(",
            r"HAVING \1(",
            q,
            flags=re.IGNORECASE,
        )

    elif category == "performance":
        # Replace SELECT * with an explicit column list (column names are a
        # fixed guess), case-insensitively.
        q = re.sub(
            r"\bSELECT\s+\*",
            "SELECT id, name, status, created_at",
            q,
            flags=re.IGNORECASE,
        )

    return q
def _generate_explanation(buggy: str, fixed: str, category: str) -> str:
    """Build a short human-readable note describing the applied fix.

    When the stripped queries are equal (no heuristic fired) a generic
    "needs deeper inspection" message is returned. Otherwise the message is
    a per-category summary followed by the first 60 characters of each
    query.
    """
    nothing_changed = buggy.strip() == fixed.strip()
    if nothing_changed:
        return f"Analyzed the {category} issue. The query may require deeper inspection."

    summaries = {
        "syntax": "Fixed syntax error in the SQL query by correcting the query structure.",
        "logic": "Fixed logic error by correcting the JOIN type and query conditions.",
        "performance": "Optimized query performance by restructuring to avoid expensive operations.",
    }
    try:
        summary = summaries[category]
    except KeyError:
        # Unrecognised category: fall back to a neutral sentence.
        summary = "Applied heuristic fix to the SQL query."

    before = buggy[:60]
    after = fixed[:60]
    return f"{summary} Original: '{before}...' Fixed: '{after}...'"
| # ───────────────────────────────────────────── | |
| # MAIN BASELINE RUNNER | |
| # ───────────────────────────────────────────── | |
def run_baseline() -> BaselineResponse:
    """Run the rule-based agent on one fixed task per difficulty level.

    A missing ``OPENAI_API_KEY`` only produces a printed warning; the run
    continues. Each task gets a fresh environment. A task that raises is
    recorded with the minimum score and the error text as feedback instead
    of aborting the run.

    The run is expected to finish within 60 seconds; no timeout is enforced
    here.

    Returns:
        A ``BaselineResponse`` holding one ``BaselineResult`` per task and
        their average score rounded to 4 decimals.
    """
    try:
        _check_api_key()
    except ValueError as missing_key:
        print(f"Warning: {missing_key}")

    task_plan = (
        (DifficultyLevel.EASY, "easy_001"),
        (DifficultyLevel.MEDIUM, "medium_001"),
        (DifficultyLevel.HARD, "hard_001"),
    )
    outcomes = []

    for level, task_id in task_plan:
        env = SQLDebuggerEnvironment()
        try:
            observation = env.reset(difficulty=level.value, task_id=task_id)
            agent_input = {"current_context": observation.current_context}

            started_at = time.time()
            raw_score, step_count, feedback = _rule_based_agent(env, agent_input)
            duration = time.time() - started_at

            # Scores are kept strictly inside (0, 1): cap at 0.999 first,
            # then floor at 0.001, then round.
            capped = min(0.999, float(raw_score))
            safe_score = round(max(0.001, capped), 4)

            outcomes.append(
                BaselineResult(
                    task_id=task_id,
                    difficulty=level,
                    score=safe_score,
                    steps=step_count,
                    feedback=f"{feedback} (elapsed: {duration:.2f}s)",
                )
            )
            print(f"Baseline {level.value}: score={safe_score}, steps={step_count}")
        except Exception as exc:
            # Best-effort boundary: record the failure with the lowest
            # allowed score (0.0 itself is treated as an invalid boundary).
            outcomes.append(
                BaselineResult(
                    task_id=task_id,
                    difficulty=level,
                    score=0.001,
                    steps=0,
                    feedback=f"Error: {exc!s}",
                )
            )

    if outcomes:
        avg = round(sum(item.score for item in outcomes) / len(outcomes), 4)
    else:
        avg = 0.5
    print(f"Baseline average score: {avg}")
    return BaselineResponse(results=outcomes, average_score=avg)
| # ───────────────────────────────────────────── | |
| # DIRECT RUN | |
| # ───────────────────────────────────────────── | |
if __name__ == "__main__":
    # Script entry point: run the baseline once and print a per-task table
    # followed by the average score.
    print("Running baseline agent...")
    report = run_baseline()
    print("\nFinal Results:")
    for entry in report.results:
        level_name = entry.difficulty.value
        print(f" {level_name:8} | {entry.task_id:12} | score={entry.score} | steps={entry.steps}")
    print(f"\nAverage Score: {report.average_score}")