Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import re | |
| import sqlite3 | |
| import requests | |
| import logging | |
| from langchain_community.utilities.sql_database import SQLDatabase | |
| from sqlalchemy import create_engine, text | |
| from sqlalchemy.pool import StaticPool | |
| logger = logging.getLogger(__name__) | |
| _engine = None | |
| _db = None | |
| CHINOOK_SQL_URL = ( | |
| "https://raw.githubusercontent.com/lerocha/chinook-database/" | |
| "master/ChinookDatabase/DataSources/Chinook_Sqlite.sql" | |
| ) | |
| # Use project root for the cached SQL file | |
| LOCAL_SQL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "Chinook_Sqlite.sql") | |
| def _load_sql_script() -> str: | |
| if os.path.isfile(LOCAL_SQL_PATH): | |
| logger.info(f"Loading Chinook SQL from local file: {LOCAL_SQL_PATH}") | |
| with open(LOCAL_SQL_PATH, "r", encoding="utf-8") as f: | |
| return f.read() | |
| logger.info("Local SQL file not found. Downloading from GitHub...") | |
| response = requests.get(CHINOOK_SQL_URL, timeout=60) | |
| response.raise_for_status() | |
| sql_script = response.text | |
| try: | |
| with open(LOCAL_SQL_PATH, "w", encoding="utf-8") as f: | |
| f.write(sql_script) | |
| logger.info(f"Cached SQL script to {LOCAL_SQL_PATH}") | |
| except Exception as e: | |
| logger.warning(f"Could not cache SQL script locally: {e}") | |
| return sql_script | |
| def _create_engine(): | |
| sql_script = _load_sql_script() | |
| connection = sqlite3.connect(":memory:", check_same_thread=False) | |
| connection.executescript(sql_script) | |
| logger.info("Chinook database loaded successfully into memory.") | |
| return create_engine( | |
| "sqlite://", | |
| creator=lambda: connection, | |
| poolclass=StaticPool, | |
| connect_args={"check_same_thread": False}, | |
| ) | |
| def get_engine(): | |
| global _engine | |
| if _engine is None: | |
| _engine = _create_engine() | |
| return _engine | |
| def get_db() -> SQLDatabase: | |
| global _db | |
| if _db is None: | |
| _db = SQLDatabase(get_engine()) | |
| return _db | |
| def run_query_safe(query: str, params: dict = None) -> str: | |
| engine = get_engine() | |
| try: | |
| with engine.connect() as conn: | |
| if params: | |
| result = conn.execute(text(query), params) | |
| else: | |
| result = conn.execute(text(query)) | |
| rows = result.fetchall() | |
| columns = result.keys() | |
| if not rows: | |
| return "[]" | |
| results_list = [dict(zip(columns, row)) for row in rows] | |
| return json.dumps(results_list, default=str) | |
| except Exception as e: | |
| logger.error(f"Query error: {e} | query={query} | params={params}") | |
| raise | |
| def normalize_phone(phone: str) -> str: | |
| if not phone: | |
| return "" | |
| phone = phone.strip() | |
| if phone.startswith("+"): | |
| return "+" + re.sub(r"[^\d]", "", phone[1:]) | |
| return re.sub(r"[^\d]", "", phone) | |
| def verify_database() -> dict: | |
| try: | |
| db = get_db() | |
| tables = db.get_usable_table_names() | |
| result = db.run("SELECT COUNT(*) FROM Customer;") | |
| logger.info(f"Database verification OK. Customer count query returned: {result}") | |
| return {"status": "healthy", "tables": len(tables)} | |
| except Exception as e: | |
| logger.error(f"Database verification failed: {e}") | |
| return {"status": "unhealthy", "error": str(e)} | |