# Spaces: Runtime error
import json
import os
import re

import mysql.connector
import requests
from mysql.connector import Error
# Matches one Markdown code fence, with an optional language tag on its own
# line (```sql\n...```), and captures the fenced body.
_CODE_FENCE_RE = re.compile(r"```(?:[A-Za-z]*\n)?(.*?)```", re.DOTALL)


def _clean_generated_sql(text, prompt):
    """Return *text* stripped of an echoed *prompt* and a Markdown code fence."""
    text = text.strip()
    # Text-generation endpoints may prepend the prompt to the completion
    # (HF's ``return_full_text`` defaults to true); drop it when present.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    # Models frequently wrap SQL in ```sql ... ```; keep only the fenced body.
    match = _CODE_FENCE_RE.search(text)
    if match:
        text = match.group(1).strip()
    return text


def generate_sql_query(natural_language_query, schema_info, space_url, timeout=30):
    """Generate SQL query using Hugging Face Space API.

    Builds a prompt from *schema_info* and *natural_language_query*, POSTs it
    as JSON to *space_url*, and extracts ``generated_text`` from the reply.

    Args:
        natural_language_query: The user's question in plain language.
        schema_info: ``CREATE TABLE`` text describing what the model may query.
        space_url: URL the JSON payload is POSTed to.
        timeout: Seconds to wait for the HTTP response. Without a timeout
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        The generated SQL (whitespace, an echoed prompt, and one surrounding
        Markdown code fence removed; possibly an empty string if the reply had
        no text), or ``None`` if the request failed or the reply could not be
        parsed. Errors are printed, not raised.
    """
    prompt = (
        "Given this SQL table schema:\n"
        f"{schema_info}\n"
        "Write a SQL query that will:\n"
        f"{natural_language_query}\n"
        "The query should be valid MySQL syntax and include only the SELECT statement."
    )
    payload = {
        "inputs": prompt,
        "options": {
            "use_cache": False,
        },
    }

    try:
        response = requests.post(space_url, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Connection failures, timeouts, malformed URLs, ...
        print(f"API Error: {str(e)}")
        return None

    if response.status_code != 200:
        print(f"API Error: API request failed: {response.text}")
        return None

    try:
        data = response.json()
    except ValueError as e:
        # requests' JSONDecodeError subclasses ValueError in all versions.
        print(f"API Error: response was not valid JSON: {str(e)}")
        return None

    # Text-generation endpoints commonly reply with a one-element list
    # ``[{"generated_text": ...}]``; custom endpoints may return a bare dict.
    # Accept either shape instead of calling ``.get`` on a list.
    if isinstance(data, list):
        data = data[0] if data else {}
    if not isinstance(data, dict):
        print(f"API Error: unexpected response shape: {data!r}")
        return None

    generated = data.get("generated_text") or ""
    return _clean_generated_sql(str(generated), prompt)
# Schema handed to the model so it knows which table and columns exist.
_SALES_SCHEMA = """
CREATE TABLE sales (
    pizza_id DECIMAL(8,2) PRIMARY KEY,
    order_id DECIMAL(8,2),
    pizza_name_id VARCHAR(14),
    quantity DECIMAL(4,2),
    order_date DATE,
    order_time VARCHAR(8),
    unit_price DECIMAL(5,2),
    total_price DECIMAL(5,2),
    pizza_size VARCHAR(3),
    pizza_category VARCHAR(7),
    pizza_ingredients VARCHAR(97),
    pizza_name VARCHAR(42)
);
"""


def _is_single_select(sql_query):
    """Return True if *sql_query* looks like exactly one SELECT statement.

    The SQL comes from a language model, so it is treated as untrusted.
    Anything that does not start with SELECT is rejected, as is anything
    containing a further ``;``-separated statement. This is a coarse guard:
    a SELECT whose string literal contains ``;`` is also rejected. It is no
    substitute for connecting with a read-only database account.
    """
    statement = sql_query.strip().rstrip(";").rstrip()
    if ";" in statement:
        return False
    return re.match(r"SELECT\b", statement, re.IGNORECASE) is not None


def _print_records(column_names, records):
    """Print *records* as a pipe-separated table headed by *column_names*."""
    header = " | ".join(column_names)
    print("\nResults:")
    print(header)
    print("-" * (len(header) + 10))
    for row in records:
        print(" | ".join(str(val) for val in row))


def _run_query_loop(cursor, schema_info, space_url):
    """Prompt for questions until 'exit' or end of input, answering each one."""
    while True:
        print("\nEnter your question (or 'exit' to quit):")
        try:
            natural_language_query = input("> ").strip()
        except EOFError:
            # stdin is closed (Ctrl+D or end of piped input). input() would
            # raise again on every iteration, so stop instead of spinning.
            break
        except KeyboardInterrupt:
            print("\nOperation cancelled by user.")
            continue

        if natural_language_query.lower() == 'exit':
            break
        if not natural_language_query:
            continue  # blank line: nothing to send to the model

        try:
            sql_query = generate_sql_query(natural_language_query, schema_info, space_url)
            if not sql_query:
                continue
            if not _is_single_select(sql_query):
                print(f"\nRefusing to run a non-SELECT query:\n{sql_query}")
                continue
            print(f"\nExecuting SQL Query:\n{sql_query}")
            cursor.execute(sql_query)
            records = cursor.fetchall()
            column_names = [desc[0] for desc in cursor.description]
        except KeyboardInterrupt:
            print("\nOperation cancelled by user.")
            continue
        except Exception as e:
            print(f"\nError: {str(e)}")
            continue

        if records:
            _print_records(column_names, records)
        else:
            print("\nNo results found.")


def main():
    """Run an interactive natural-language-to-SQL session against MySQL.

    The generation endpoint and database settings are read from the
    environment variables SQL_SPACE_URL, DB_HOST, DB_NAME, DB_USER,
    DB_PASSWORD and DB_PORT. Each falls back to a local-development default
    when unset. Errors are printed rather than raised, and the cursor and
    connection are released on exit.
    """
    connection = None
    cursor = None
    try:
        # NOTE(review): the default is the Space's web-page URL. POSTing JSON
        # there is unlikely to return {"generated_text": ...}. Point
        # SQL_SPACE_URL at the Space's real API endpoint (TODO: confirm route).
        space_url = os.environ.get(
            "SQL_SPACE_URL", "https://huggingface.co/spaces/nileshhanotia/sql"
        )
        # Credentials come from the environment so they need not live in source.
        connection = mysql.connector.connect(
            host=os.environ.get("DB_HOST", "localhost"),
            database=os.environ.get("DB_NAME", "pizza"),
            user=os.environ.get("DB_USER", "root"),
            password=os.environ.get("DB_PASSWORD", "root"),
            port=int(os.environ.get("DB_PORT", "8889")),
        )
        if connection.is_connected():
            cursor = connection.cursor()
            print("Database connected successfully!")
            _run_query_loop(cursor, _SALES_SCHEMA, space_url)
    except Error as e:
        print(f"\nDatabase error: {str(e)}")
    except Exception as e:
        print(f"\nApplication error: {str(e)}")
    finally:
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            print("\nMySQL connection closed.")
# Start the interactive session only when run as a script, not when imported.
if __name__ == "__main__":
    main()