Spaces:
Sleeping
Sleeping
| """ | |
| Multi-component grading system for SQL query evaluation. | |
| Scores agent queries against ground truth with partial credit: | |
| - syntax_score (0.1): Query parses and executes without error | |
| - column_score (0.2): Fraction of expected columns present | |
| - row_score (0.3): Fraction of expected rows matching | |
| - exact_score (0.4): Full result set matches ground truth exactly (half of this component if all expected rows are found but extra rows were returned) | |
| Total reward per question is in (0.0, 1.0) — strictly between 0 and 1. | |
| """ | |
| from typing import Any, List, Optional, Tuple | |
| from .database import Database, QueryResult | |
| # Epsilon to ensure scores are strictly between 0 and 1 (never exactly 0.0 or 1.0) | |
| _EPS = 0.001 | |
| def _clamp_reward(reward: float) -> float: | |
| """Clamp reward to be strictly within (0, 1).""" | |
| return min(max(reward, _EPS), 1.0 - _EPS) | |
| def _normalize_value(val: Any) -> Any: | |
| """Normalize a value for comparison (handle float/int equivalence, None).""" | |
| if val is None: | |
| return None | |
| if isinstance(val, float): | |
| if val == int(val): | |
| return int(val) | |
| return round(val, 2) | |
| if isinstance(val, str): | |
| return val.strip() | |
| return val | |
def _normalize_row(row: Tuple) -> Tuple:
    """Return a new tuple with ``_normalize_value`` applied to every cell of *row*."""
    return tuple(map(_normalize_value, row))
| def _normalize_column_name(col: str) -> str: | |
| """Normalize column name for comparison (lowercase, strip).""" | |
| return col.strip().lower() | |
def grade_query(
    db: Database,
    agent_sql: str,
    expected_columns: List[str],
    expected_rows: List[List],
    order_matters: bool = True,
) -> dict:
    """
    Grade an agent's SQL query against expected results.

    The query is executed once via ``db.execute_query``. A query that fails
    receives only the clamped minimum reward. A successful query earns
    weighted partial credit: syntax (0.1), column coverage (0.2), row
    matches (0.3) and an exact-match component (0.4, halved when all
    expected rows were found but extra rows were returned).

    Args:
        db: Active Database instance.
        agent_sql: The SQL query submitted by the agent.
        expected_columns: Expected column names; compared after stripping
            whitespace and lower-casing. Extra columns in the result are not
            penalized.
        expected_rows: Expected row values (list of lists), with values in
            the same order as ``expected_columns``.
        order_matters: If True, rows are compared position-by-position; if
            False, each expected row may match any not-yet-used result row.

    Returns:
        Dictionary with:
        - reward: float strictly within (0.0, 1.0) (clamped by
          ``_clamp_reward``; rounded to 4 decimals for successful queries)
        - syntax_score: float
        - column_score: float
        - row_score: float
        - exact_score: float
        - query_result: QueryResult object
        - feedback: str describing what was right/wrong
    """
    result = db.execute_query(agent_sql)

    # Component weights (they sum to 1.0).
    W_SYNTAX = 0.1
    W_COLUMN = 0.2
    W_ROW = 0.3
    W_EXACT = 0.4

    # --- Syntax Score ---
    if not result.success:
        return {
            "reward": _clamp_reward(0.0),
            "syntax_score": 0.0,
            "column_score": 0.0,
            "row_score": 0.0,
            "exact_score": 0.0,
            "query_result": result,
            "feedback": f"SQL error: {result.error}",
        }
    syntax_score = 1.0

    # --- Column Score ---
    expected_cols_normalized = [_normalize_column_name(c) for c in expected_columns]
    actual_cols_normalized = [_normalize_column_name(c) for c in result.columns]
    if not expected_cols_normalized:
        column_score = 1.0 if not actual_cols_normalized else 0.0
    else:
        matched_cols = sum(
            1 for c in expected_cols_normalized if c in actual_cols_normalized
        )
        column_score = matched_cols / len(expected_cols_normalized)

    # --- Row Score ---
    expected_rows_normalized = [_normalize_row(tuple(row)) for row in expected_rows]
    actual_rows_normalized = [_normalize_row(row) for row in result.rows]
    # Keep the integer count: reconstructing it from the float ratio later
    # (e.g. int(1 / 49 * 49) == 0) reports the wrong number in feedback.
    matched_rows = _count_matched_rows(
        expected_rows_normalized,
        actual_rows_normalized,
        expected_cols_normalized,
        actual_cols_normalized,
        order_matters,
    )
    if not expected_rows_normalized:
        row_score = 1.0 if not actual_rows_normalized else 0.0
    else:
        row_score = matched_rows / len(expected_rows_normalized)

    # --- Exact Score ---
    exact_score = 0.0
    if column_score == 1.0 and row_score == 1.0:
        # Every expected column and row was found; full credit only when the
        # query returned no extra rows, otherwise partial credit.
        if len(actual_rows_normalized) == len(expected_rows_normalized):
            exact_score = 1.0
        else:
            exact_score = 0.5

    # --- Total Reward ---
    reward = (
        W_SYNTAX * syntax_score
        + W_COLUMN * column_score
        + W_ROW * row_score
        + W_EXACT * exact_score
    )
    reward = round(_clamp_reward(reward), 4)

    # --- Feedback ---
    # The query succeeded if we got this far.
    feedback_parts = ["Query executed successfully."]
    if column_score < 1.0:
        missing = [c for c in expected_cols_normalized if c not in actual_cols_normalized]
        feedback_parts.append(
            f"Missing columns: {missing}. Expected: {expected_cols_normalized}, "
            f"Got: {actual_cols_normalized}"
        )
    if row_score < 1.0:
        n_expected = len(expected_rows_normalized)
        feedback_parts.append(
            f"Row match: {row_score:.0%} ({matched_rows}/{n_expected} rows correct). "
            f"Expected {n_expected} rows, got {len(actual_rows_normalized)}."
        )
    if exact_score == 1.0:
        feedback_parts.append("Perfect match!")
    elif exact_score == 0.5:
        feedback_parts.append(
            f"All expected rows found but got {len(actual_rows_normalized)} rows "
            f"instead of {len(expected_rows_normalized)} (extra rows)."
        )

    return {
        "reward": reward,
        "syntax_score": syntax_score,
        "column_score": column_score,
        "row_score": row_score,
        "exact_score": exact_score,
        "query_result": result,
        "feedback": " ".join(feedback_parts),
    }


def _count_matched_rows(
    expected_rows: List[Tuple],
    actual_rows: List[Tuple],
    expected_cols: List[str],
    actual_cols: List[str],
    order_matters: bool,
) -> int:
    """Return how many expected rows are matched (via ``_rows_match``) by actual rows."""
    if order_matters:
        # Position-by-position; zip stops at the shorter list, so expected rows
        # with no counterpart at the same index count as unmatched.
        return sum(
            1
            for exp_row, act_row in zip(expected_rows, actual_rows)
            if _rows_match(exp_row, act_row, expected_cols, actual_cols)
        )
    # Unordered: greedily pair each expected row with the first still-unused
    # actual row that matches it, consuming that actual row.
    matched = 0
    remaining = list(actual_rows)
    for exp_row in expected_rows:
        for idx, act_row in enumerate(remaining):
            if _rows_match(exp_row, act_row, expected_cols, actual_cols):
                matched += 1
                del remaining[idx]
                break
    return matched
def _rows_match(
    expected_row: Tuple,
    actual_row: Tuple,
    expected_cols: List[str],
    actual_cols: List[str],
) -> bool:
    """
    Check whether an actual result row matches an expected row.

    Columns are matched by name rather than position, so the agent may return
    columns in any order; extra actual columns are ignored.

    Args:
        expected_row: Expected values, one per entry of ``expected_cols``
            (compared as given, so callers pass already-normalized values).
        actual_row: Actual row values; each is passed through
            ``_normalize_value`` before comparison.
        expected_cols: Normalized expected column names.
        actual_cols: Normalized column names of the actual result.

    Returns:
        True if every expected column is present and each value is equal or,
        when both convert to ``float``, within 0.01 of the expected value.
        False if the expected row width disagrees with ``expected_cols``, an
        expected column is absent, or any value differs.
    """
    if len(expected_cols) != len(expected_row):
        return False

    # Resolve every expected column to its actual position before comparing any
    # values; a missing column fails the row outright. Duplicate actual column
    # names resolve to their first occurrence.
    positions = []
    for exp_idx, col in enumerate(expected_cols):
        if col not in actual_cols:
            return False
        positions.append((exp_idx, actual_cols.index(col)))

    for exp_idx, act_idx in positions:
        if act_idx >= len(actual_row):
            return False
        exp_val = expected_row[exp_idx]
        act_val = _normalize_value(actual_row[act_idx])
        if exp_val == act_val:
            continue
        # Two distinct ints always differ by at least 1, so they can never be
        # within tolerance; converting them with float() would round large
        # values (beyond 2**53) together and report a false match.
        if isinstance(exp_val, int) and isinstance(act_val, int):
            return False
        # Numeric comparison with tolerance (also lets numeric strings match
        # numbers). OverflowError covers ints too large for float().
        try:
            if abs(float(exp_val) - float(act_val)) < 0.01:
                continue
        except (TypeError, ValueError, OverflowError):
            pass
        return False
    return True