testing_multi_agent / src /db /database.py
adilpanwar's picture
Upload 16 files
927c050 verified
Raw
History Blame Contribute Delete
3.3 kB
import json
import os
import re
import sqlite3
import requests
import logging
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine, text
from sqlalchemy.pool import StaticPool
logger = logging.getLogger(__name__)

# Upstream location of the Chinook sample database as a SQLite SQL script.
CHINOOK_SQL_URL = (
    "https://raw.githubusercontent.com/lerocha/chinook-database/"
    "master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
)

# Where the downloaded script is cached: the directory reached by walking up
# three os.path.dirname() levels from this file (the project root).
LOCAL_SQL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
    "Chinook_Sqlite.sql",
)

# Lazily created singletons, populated on first use by get_engine() / get_db().
_engine = None
_db = None
def _load_sql_script() -> str:
    """Return the Chinook SQLite SQL script, downloading and caching it if needed.

    If ``LOCAL_SQL_PATH`` exists it is read and returned. Otherwise the script
    is fetched from ``CHINOOK_SQL_URL`` and written to ``LOCAL_SQL_PATH`` on a
    best-effort basis; failure to cache is logged as a warning and ignored.

    Returns:
        The full SQL script text.

    Raises:
        requests.RequestException: if the download fails or returns an HTTP
            error status (via ``raise_for_status``).
        OSError: if the cached file exists but cannot be opened/read.
    """
    if os.path.isfile(LOCAL_SQL_PATH):
        logger.info("Loading Chinook SQL from local file: %s", LOCAL_SQL_PATH)
        with open(LOCAL_SQL_PATH, "r", encoding="utf-8") as f:
            return f.read()

    logger.info("Local SQL file not found. Downloading from GitHub...")
    response = requests.get(CHINOOK_SQL_URL, timeout=60)
    response.raise_for_status()
    sql_script = response.text

    # Write to a per-process temp file and atomically rename it into place.
    # Writing LOCAL_SQL_PATH directly could leave a truncated script behind on
    # an interrupted write, and the isfile() check above would then treat it
    # as a valid cache on every later run.
    tmp_path = f"{LOCAL_SQL_PATH}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(sql_script)
        os.replace(tmp_path, LOCAL_SQL_PATH)
        logger.info("Cached SQL script to %s", LOCAL_SQL_PATH)
    except Exception as e:
        # Caching is an optimisation only; the downloaded script is still usable.
        logger.warning("Could not cache SQL script locally: %s", e)
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone
    return sql_script
def _create_engine():
    """Build a SQLAlchemy engine over an in-memory SQLite Chinook database.

    The script returned by ``_load_sql_script`` is executed once into a single
    raw ``sqlite3`` ``:memory:`` connection. ``StaticPool`` plus a ``creator``
    that always returns that same connection means every SQLAlchemy checkout
    sees the populated database. A new ``:memory:`` connection would be a
    separate, empty database.

    Returns:
        A SQLAlchemy ``Engine`` bound to the populated connection.

    Raises:
        Whatever ``_load_sql_script`` raises, or ``sqlite3.Error`` if the
        script fails to execute. In that case the raw connection is closed
        before re-raising.
    """
    sql_script = _load_sql_script()
    # check_same_thread=False allows the shared connection to be used from
    # threads other than the one that created it.
    connection = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        connection.executescript(sql_script)
    except Exception:
        # Do not leak a half-populated connection when the script is bad.
        connection.close()
        raise
    logger.info("Chinook database loaded successfully into memory.")
    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        # NOTE(review): with ``creator`` supplied SQLAlchemy presumably never
        # calls the DBAPI connect(), so this is likely redundant; the raw
        # connection above already sets check_same_thread=False.
        connect_args={"check_same_thread": False},
    )
def get_engine():
    """Return the module-wide SQLAlchemy engine, building it on first call.

    The engine is created by ``_create_engine()`` the first time this is
    called and cached in the module-level ``_engine`` afterwards.
    """
    global _engine
    if _engine is not None:
        return _engine
    _engine = _create_engine()
    return _engine
def get_db() -> SQLDatabase:
    """Return the shared LangChain ``SQLDatabase`` wrapper, creating it lazily.

    On first call the wrapper is constructed over ``get_engine()`` and stored
    in the module-level ``_db``; later calls return that same object.
    """
    global _db
    if _db is not None:
        return _db
    _db = SQLDatabase(get_engine())
    return _db
def run_query_safe(query: str, params: dict = None) -> str:
    """Execute ``query`` and return the result rows as a JSON array string.

    Args:
        query: SQL text, wrapped with SQLAlchemy ``text()``.
        params: Optional mapping of bound parameters. An empty or ``None``
            value executes the statement without parameters.

    Returns:
        A JSON array of ``{column: value}`` objects, one per row (``"[]"``
        when there are no rows). Values that are not JSON-serialisable are
        converted with ``str``.

    Raises:
        Re-raises any exception from executing the query or serialising the
        result, after logging it together with the query and params.
        Errors from ``get_engine()`` itself propagate without being logged here.

    NOTE(review): ``fetchall()`` presumably fails for statements that return
    no rows (e.g. INSERT/UPDATE), so this looks intended for SELECTs; confirm.
    """
    engine = get_engine()
    try:
        with engine.connect() as conn:
            statement = text(query)
            cursor = (
                conn.execute(statement, params)
                if params
                else conn.execute(statement)
            )
            records = cursor.fetchall()
            column_names = cursor.keys()
            # json.dumps([]) is "[]", so no separate empty-result branch is needed.
            return json.dumps(
                [dict(zip(column_names, record)) for record in records],
                default=str,
            )
    except Exception as e:
        logger.error(f"Query error: {e} | query={query} | params={params}")
        raise
def normalize_phone(phone: str) -> str:
    """Reduce a phone number to its digits, keeping one leading ``+`` if present.

    Surrounding whitespace is trimmed first. If the trimmed value starts with
    ``+``, the result is ``+`` followed by every decimal digit in the rest of
    the string. Otherwise it is just the digits. Empty or ``None`` input gives
    ``""``. Digit matching uses ``re``'s Unicode-aware ``\\D``, so non-ASCII
    decimal digits are kept as-is.
    """
    if not phone:
        return ""
    cleaned = phone.strip()
    prefix = "+" if cleaned.startswith("+") else ""
    return prefix + re.sub(r"\D", "", cleaned[len(prefix):])
def verify_database() -> dict:
    """Run a quick health check against the Chinook database.

    Lists the usable tables and runs ``SELECT COUNT(*) FROM Customer;`` through
    the shared ``SQLDatabase`` wrapper from ``get_db()``. Any exception is
    caught and reported in the return value rather than raised. This includes
    failures while building the database on first use.

    Returns:
        ``{"status": "healthy", "tables": <number of usable tables>}`` on
        success, or ``{"status": "unhealthy", "error": <exception message>}``
        on failure.
    """
    try:
        db = get_db()
        tables = db.get_usable_table_names()
        result = db.run("SELECT COUNT(*) FROM Customer;")
        logger.info(
            "Database verification OK. Customer count query returned: %s", result
        )
        return {"status": "healthy", "tables": len(tables)}
    except Exception as e:
        # A health check must not raise. Log with the full traceback
        # (logger.exception) so the root cause is diagnosable, then report it.
        logger.exception("Database verification failed: %s", e)
        return {"status": "unhealthy", "error": str(e)}