# Spaces: Sleeping
| import sys | |
| import os | |
| import pandas as pd | |
| import json | |
| from pathlib import Path | |
| import psycopg2 | |
| from dotenv import load_dotenv | |
| import time | |
| from tqdm import tqdm | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from main import get_gemini_response, sql_prompt | |
def validate_sql_query(query, conn):
    """Check whether a SQL query is accepted by PostgreSQL, without running it.

    The query is prefixed with EXPLAIN and executed inside a read-only
    transaction that is always rolled back, so nothing done during validation
    is ever persisted.

    Args:
        query: SQL text to check. It is model-generated, so treat it as untrusted.
        conn: An open psycopg2 connection.

    Returns:
        ``(True, None)`` if the EXPLAIN succeeded, otherwise
        ``(False, error_message)`` with the psycopg2 error text.
    """
    try:
        # Clear any transaction left aborted by an earlier failed statement.
        conn.rollback()
        with conn.cursor() as cursor:
            # EXPLAIN only applies to the first statement in the string; any
            # further statements (e.g. "SELECT ...; DROP TABLE ...") would be
            # executed for real. A read-only transaction makes such writes fail.
            cursor.execute("SET TRANSACTION READ ONLY")
            cursor.execute("EXPLAIN " + query)
        return True, None
    except psycopg2.Error as e:
        return False, str(e)
    finally:
        # Always discard the validation transaction instead of committing it.
        conn.rollback()
def handle_api_error(error):
    """Classify an API exception into a ``(message, wait_seconds)`` pair.

    If the error text contains "429" (HTTP Too Many Requests), it is reported
    as a quota problem with a 30-second back-off. Any other error is returned
    as its string form with no wait.
    """
    text = str(error)
    rate_limited = "429" in text
    return ("API quota exceeded", 30) if rate_limited else (text, 0)
# Maps typographic double quotes (U+201C / U+201D, which cp1252 bytes
# 0x93 / 0x94 decode to) to plain ASCII double quotes.
_SMART_QUOTES = str.maketrans({'\u201c': '"', '\u201d': '"'})


def _load_test_data(csv_path):
    """Read the evaluation CSV and normalise curly double quotes in the queries.

    The file is decoded as cp1252 first; if that fails, latin-1 is used, since
    it accepts any byte sequence.
    """
    try:
        test_data = pd.read_csv(csv_path, encoding='cp1252')
    except UnicodeDecodeError:
        # Fallback to latin-1 if cp1252 fails
        test_data = pd.read_csv(csv_path, encoding='latin-1')
    test_data['Natural Language Query'] = (
        test_data['Natural Language Query'].str.translate(_SMART_QUOTES)
    )
    return test_data


def _load_results(output_file):
    """Return previously saved results, or an empty dict if there are none."""
    if output_file.exists():
        with open(output_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}


def _save_results(results, output_file):
    """Write ``results`` as JSON to ``output_file`` via a temporary file.

    The data is written to a sibling ``.tmp`` file first and then moved over
    the target. An interruption mid-write therefore cannot leave a truncated
    results file that would break the next run's ``json.load``.
    """
    tmp_file = output_file.with_name(output_file.name + '.tmp')
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    tmp_file.replace(output_file)


def _process_query(nl_query, difficulty, conn, max_retries=3):
    """Generate SQL for one natural-language query and validate it.

    Up to ``max_retries`` attempts are made. Any exception from generation or
    validation counts as a failed attempt. Errors that ``handle_api_error``
    assigns a wait time cause a pause, but only if another attempt follows.

    Returns:
        The result record stored in the results JSON. ``sql_query`` is None
        when every attempt failed.
    """
    error_msg = None
    for attempt in range(1, max_retries + 1):
        try:
            sql_query = get_gemini_response(nl_query, sql_prompt)
            # Strip Markdown code fences the model may wrap around the SQL.
            sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
            is_valid, error_msg = validate_sql_query(sql_query, conn)
        except Exception as e:
            error_msg, wait_time = handle_api_error(e)
            if wait_time > 0 and attempt < max_retries:
                print(f"\nAPI quota exceeded. Waiting {wait_time} seconds...")
                time.sleep(wait_time)
        else:
            return {
                'natural_language_query': nl_query,
                'sql_query': sql_query,
                'difficulty': difficulty,
                'is_valid': is_valid,
                'error': error_msg,
            }
    return {
        'natural_language_query': nl_query,
        'sql_query': None,
        'difficulty': difficulty,
        'is_valid': False,
        'error': error_msg,
    }


def run_query_tests():
    """Generate and validate SQL for every query in the Pagila eval dataset.

    Database settings come from environment variables (loaded via
    ``load_dotenv``). The dataset is read from the CSV next to this file.
    Results are written to ``results/query_results.json`` after each query,
    so an interrupted run resumes: queries whose saved record already has
    SQL that passed validation are skipped.
    """
    load_dotenv()
    base_dir = Path(__file__).parent

    # Load inputs before connecting so a bad CSV cannot leak a connection.
    test_data = _load_test_data(base_dir / 'Pagila Evals Dataset(Sheet1).csv')

    results_dir = base_dir / 'results'
    results_dir.mkdir(exist_ok=True)
    output_file = results_dir / 'query_results.json'
    results = _load_results(output_file)

    conn = psycopg2.connect(
        dbname=os.getenv('DB_NAME'),
        user=os.getenv('DB_USER'),
        password=os.getenv('DB_PASSWORD'),
        host=os.getenv('DB_HOST'),
        port=os.getenv('DB_PORT', '5432'),
    )
    try:
        for _, row in tqdm(test_data.iterrows(), total=len(test_data), desc="Processing queries"):
            query_num = str(row['Query Number'])
            # Skip if already processed successfully
            previous = results.get(query_num) or {}
            if previous.get('sql_query') and previous.get('is_valid'):
                continue
            results[query_num] = _process_query(
                row['Natural Language Query'], row['Difficulty'], conn
            )
            _save_results(results, output_file)
    finally:
        conn.close()

    print(f"\nResults saved to {output_file}")
# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    run_query_tests()