phi3-text-to-sql-studio / src /inference.py
Bhuvandesai's picture
Fix CPU thread oversubscription: cap n_threads to 2 (cpu-basic vCPUs) for faster generation
4853f12 verified
Raw
History Blame Contribute Delete
18.2 kB
import os
import re
import sqlite3
import time
import pandas as pd
from dataset import SCHEMA_PROMPT
# NOTE: torch / transformers / peft are imported lazily inside the CUDA-only code
# paths. The CPU (HF Space) image serves via llama.cpp and does NOT install them,
# so importing them at module top would crash the Space on startup.
class TextToSQLInference:
    """Natural-language question -> SQL -> SQLite results, on one of two backends.

    * CUDA available: transformers + 4-bit bitsandbytes base model, optionally
      wrapped with a PEFT adapter (``backend == "transformers"``).
    * Otherwise: llama.cpp serving a GGUF that already contains the merged
      fine-tuned weights (``backend == "llama_cpp"``).

    Generated SQL is executed (read-only) against the SQLite file at ``db_path``.
    """

    # Leftmost SQL statement in a model response, up to the first semicolon:
    # either a CTE ("WITH [RECURSIVE] name[(cols)] AS (") or a plain SELECT.
    # The CTE branch is deliberately specific so prose such as "with the total"
    # is not mistaken for the start of a query.
    _SQL_STATEMENT_RE = re.compile(
        r"((?:\bWITH\s+(?:RECURSIVE\s+)?\w+(?:\s*\([^)]*\))?\s+AS\s*\(|\bSELECT\s).*?;)",
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self, base_model_id="microsoft/Phi-3-mini-4k-instruct", adapter_path="Bhuvandesai/phi3-text-to-sql-adapter"):
        """Pick a backend for the current hardware and load the model.

        Args:
            base_model_id: HF Hub id of the base model (used by the CUDA path and
                by the HF Inference API base-model comparison).
            adapter_path: Local directory or HF Hub repo id of the PEFT adapter
                (only loaded on the CUDA path).
        """
        self.base_model_id = base_model_id
        self.adapter_path = adapter_path
        self.db_path = "data/company_sales.db"
        self.tokenizer = None
        self.model = None
        self.llm = None  # llama_cpp.Llama instance (CPU/GGUF backend only)
        self.backend = None  # "transformers" (CUDA) or "llama_cpp" (CPU)
        # GGUF source for CPU serving (override via env on the Space)
        self.gguf_repo_id = os.environ.get("GGUF_REPO_ID", "Bhuvandesai/phi3-text-to-sql-gguf")
        self.gguf_filename = os.environ.get("GGUF_FILENAME", "phi3-text-to-sql-Q4_K_M.gguf")
        # Detect CUDA without hard-importing torch, so the CPU Space (which has
        # no torch installed) doesn't crash here.
        self.device = self._detect_device()
        # Only the CUDA path loads a separate adapter. The CPU loader sets
        # has_adapter itself (the GGUF has the adapter merged in), so skip the
        # Hub round-trip that repo_exists() would otherwise make there.
        self.has_adapter = self._adapter_available(adapter_path) if self.device == "cuda" else False
        self._load_model_and_tokenizer()

    def _detect_device(self):
        """Return "cuda" if torch is importable and sees a GPU, else "cpu"."""
        try:
            import torch
            return "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            return "cpu"

    def _adapter_available(self, path):
        """Returns True if adapter exists locally or as a HF Hub repo."""
        if os.path.exists(path):
            return True
        try:
            from huggingface_hub import repo_exists
            return repo_exists(path, repo_type="model")
        except Exception:
            # Best-effort probe: offline / auth / missing package all mean "no adapter".
            return False

    def _load_model_and_tokenizer(self):
        """Dispatch to the loader matching ``self.device`` and record the backend.

        - CUDA (laptop/dev): transformers + 4-bit bitsandbytes + PEFT adapter
        - CPU (HF Space): llama.cpp + pre-merged GGUF (Q4_K_M)
        The CPU path never imports torch-side quantization or PEFT, and llama_cpp
        is imported lazily so the CUDA path doesn't need it installed.
        """
        if self.device == "cuda":
            self.backend = "transformers"
            self._load_transformers_cuda()
        else:
            self.backend = "llama_cpp"
            self._load_llama_cpp_cpu()

    def _load_transformers_cuda(self):
        """Load tokenizer + 4-bit NF4 base model, then the PEFT adapter if present."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import PeftModel

        print(f"Loading tokenizer: {self.base_model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id, trust_remote_code=False)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        print("CUDA detected — configuring 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        print("Loading base model with 4-bit quantization...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=False,
            attn_implementation="eager",
        )

        if self.has_adapter:
            print(f"Loading fine-tuned adapter from: {self.adapter_path}...")
            self.model = PeftModel.from_pretrained(
                self.model,
                self.adapter_path,
                adapter_name="default",
            )
            print("Adapter loaded successfully!")
        else:
            print("WARNING: No adapter found. Will run base model only.")

    def _load_llama_cpp_cpu(self):
        """Load the merged GGUF with llama.cpp (local models/ copy first, else Hub)."""
        # GGUF already contains the merged fine-tuned weights, so there is no
        # separate adapter to load at runtime.
        self.has_adapter = True
        print("CPU detected — using llama.cpp + GGUF (Q4_K_M) backend.")
        from llama_cpp import Llama
        from huggingface_hub import hf_hub_download

        # Allow a local GGUF (committed under models/) to take precedence; otherwise pull from Hub.
        local_candidate = os.path.join("models", self.gguf_filename)
        if os.path.exists(local_candidate):
            model_path = local_candidate
            print(f"Loading local GGUF: {model_path}")
        else:
            print(f"Downloading GGUF {self.gguf_filename} from {self.gguf_repo_id}...")
            model_path = hf_hub_download(repo_id=self.gguf_repo_id, filename=self.gguf_filename)
            print(f"GGUF cached at: {model_path}")

        # cpu-basic has 2 vCPUs, but os.cpu_count() reports the host's many cores.
        # Using that many threads oversubscribes the 2 vCPUs and makes generation
        # dramatically slower. Cap to the real vCPU count (override via N_THREADS);
        # clamp to >= 1 so a bad override can't request zero/negative threads.
        n_threads = max(1, int(os.environ.get("N_THREADS", "2")))
        self.llm = Llama(
            model_path=model_path,
            n_ctx=2048,  # schema prompt (~300 tok) + question + 150 generated
            n_threads=n_threads,
            n_batch=512,
            logits_all=False,
            verbose=False,
        )
        print(f"llama.cpp model loaded (n_threads={n_threads}).")

    def clean_sql(self, raw_output):
        """Extract the SQL statement from raw model text.

        Strips any echoed prompt (everything up to the last ``<|assistant|>``),
        markdown code fences and ``<|...|>`` chat tokens, then returns the first
        SELECT / WITH-CTE statement up to its semicolon. If no terminated
        statement is found, the cleaned text is returned as-is. ``None`` is
        treated as an empty response.
        """
        raw_output = raw_output or ""
        # Print for raw debugging
        print(f"Raw model response: {repr(raw_output)}")

        # Find assistant text if model outputted the prompt too
        if "<|assistant|>" in raw_output:
            raw_output = raw_output.split("<|assistant|>")[-1]

        # Clean markdown wrappers (e.g. ```sql ... ```)
        cleaned = re.sub(r"```sql\s*", "", raw_output, flags=re.IGNORECASE)
        cleaned = cleaned.replace("```", "")
        # Remove chat special tokens such as <|end|>, then surrounding whitespace
        cleaned = re.sub(r"<\|.*?\|>", "", cleaned).strip()

        sql_match = self._SQL_STATEMENT_RE.search(cleaned)
        if sql_match:
            return sql_match.group(1).strip()
        return cleaned

    def execute_query(self, sql_query):
        """Execute a generated query on the SQLite DB and return (DataFrame, error).

        The database is opened read-only (SQLite URI ``mode=ro``), so a
        mis-generated write/DDL statement that merely mentions "select" (e.g. in
        a comment) is rejected by SQLite instead of modifying the data.

        Returns:
            ``(df, None)`` on success, or ``(None, message)`` on any failure.
        """
        from contextlib import closing
        from pathlib import Path

        if not sql_query or "select" not in sql_query.lower():
            return None, "Error: Invalid or blank SQL query generated."

        if not os.path.exists(self.db_path):
            return None, f"Error: Database file '{self.db_path}' not found. Run database.py first."

        # as_uri() needs an absolute path and percent-encodes special characters.
        db_uri = Path(self.db_path).resolve().as_uri() + "?mode=ro"
        try:
            # closing() releases the connection even when the query raises;
            # sqlite3's own context manager only commits/rolls back.
            with closing(sqlite3.connect(db_uri, uri=True)) as conn:
                # Run query with pandas to easily get columns and visual formatting
                df = pd.read_sql_query(sql_query, conn)
        except Exception as e:
            return None, str(e)
        return df, None

    def generate_sql(self, question, use_adapter=True):
        """Dispatches to the active backend. Output contract is identical for both.

        Returns:
            ``(sql_query, metrics)`` where metrics has ``latency_sec``,
            ``tokens_generated`` and ``tokens_per_second``.
        """
        if self.backend == "llama_cpp":
            return self._generate_llama(question, use_adapter=use_adapter)
        return self._generate_transformers(question, use_adapter=use_adapter)

    def _generate_llama(self, question, use_adapter=True):
        """CPU/GGUF generation via llama.cpp. The GGUF is the merged fine-tuned model.

        Raises:
            RuntimeError: if ``use_adapter`` is False — there are no separate base
                weights to switch to in GGUF mode.
        """
        if not use_adapter:
            # GGUF has the fine-tuned weights baked in; there is no base toggle on CPU.
            # The base comparison is served by generate_sql_base_via_api (HF API) instead.
            raise RuntimeError(
                "Local base-model comparison is unavailable in CPU/GGUF mode; "
                "use the HF Inference API path for base comparison."
            )

        prompt = f"{SCHEMA_PROMPT}\n\nQuestion: {question}"
        print("--- Generating with Fine-tuned GGUF (llama.cpp) ---")
        start_time = time.perf_counter()
        result = self.llm.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=150,
            temperature=0.0,  # greedy / deterministic, matches do_sample=False
        )
        latency_sec = time.perf_counter() - start_time

        raw_output = result["choices"][0]["message"]["content"]
        usage = result.get("usage", {}) or {}
        tokens_generated = usage.get("completion_tokens", 0) or 0
        tokens_per_second = tokens_generated / latency_sec if latency_sec > 0 else 0.0

        sql_query = self.clean_sql(raw_output)
        metrics = {
            "latency_sec": latency_sec,
            "tokens_generated": tokens_generated,
            "tokens_per_second": tokens_per_second,
        }
        return sql_query, metrics

    def _generate_transformers(self, question, use_adapter=True):
        """Generate SQL on the CUDA/transformers backend with greedy decoding.

        When an adapter is loaded, ``use_adapter=False`` temporarily disables it
        (PEFT ``disable_adapter`` context) to obtain base-model output.
        """
        from contextlib import nullcontext
        import torch

        prompt = f"{SCHEMA_PROMPT}\n\nQuestion: {question}"
        messages = [{"role": "user", "content": prompt}]
        # Build prompt using the tokenizer's chat template
        chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.device)
        input_token_count = inputs.input_ids.shape[1]

        # Toggle fine-tuned adapter weights or use raw base model
        adapter_ctx = nullcontext()
        if self.has_adapter:
            if use_adapter:
                self.model.set_adapter("default")
                print("--- Generating with Fine-tuned Adapter Active ---")
            else:
                # Entered below; disables adapter weights only for this generate call.
                adapter_ctx = self.model.disable_adapter()
                print("--- Generating with Base Model (Adapter Disabled) ---")

        start_time = time.perf_counter()
        with torch.no_grad(), adapter_ctx:
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=False,  # Greedy decoding for stable code output
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        latency_sec = time.perf_counter() - start_time

        tokens_generated = max(0, outputs.shape[1] - input_token_count)
        tokens_per_second = tokens_generated / latency_sec if latency_sec > 0 else 0.0

        # Decode only the newly generated token ids so the prompt is never fed
        # to SQL extraction (slicing the decoded text by a token count would cut
        # at an arbitrary character). Special tokens are stripped by clean_sql.
        response = self.tokenizer.decode(outputs[0][input_token_count:], skip_special_tokens=False)
        sql_query = self.clean_sql(response)

        metrics = {
            "latency_sec": latency_sec,
            "tokens_generated": tokens_generated,
            "tokens_per_second": tokens_per_second,
        }
        return sql_query, metrics

    def generate_sql_base_via_api(self, question):
        """Calls HF Inference API for base model comparison. Runs on HF's GPU, not local hardware.

        Uses the ``HF_TOKEN`` environment variable when set. Metrics follow the
        same (unrounded) contract as the local backends.
        """
        from huggingface_hub import InferenceClient

        token = os.environ.get("HF_TOKEN")
        client = InferenceClient(model=self.base_model_id, token=token)
        prompt = f"{SCHEMA_PROMPT}\n\nQuestion: {question}"

        start = time.perf_counter()
        response = client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=150,
        )
        elapsed = time.perf_counter() - start

        sql = self.clean_sql(response.choices[0].message.content)
        usage = getattr(response, "usage", None)
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        metrics = {
            "latency_sec": elapsed,
            "tokens_generated": completion_tokens,
            "tokens_per_second": completion_tokens / elapsed if elapsed > 0 else 0.0,
        }
        return sql, metrics

    def _generate_base_comparison(self, question):
        """Generate SQL with the un-tuned base model for side-by-side comparison.

        Tries the HF Inference API first. Only the transformers backend can fall
        back to local base inference (adapter disabled); on the GGUF backend a
        local fallback is guaranteed to fail, so the API error is surfaced instead.
        """
        try:
            sql_query, gen_metrics = self.generate_sql_base_via_api(question)
        except Exception as api_err:
            if self.backend != "transformers":
                raise RuntimeError(
                    f"HF Inference API base-model call failed ({api_err}); "
                    "local base-model comparison is unavailable in CPU/GGUF mode."
                ) from api_err
            print(f"HF API failed ({api_err}), falling back to local base inference...")
            return self.generate_sql(question, use_adapter=False)
        print(f"Base model via HF API returned: {sql_query}")
        return sql_query, gen_metrics

    def query_pipeline(self, question, use_adapter=True):
        """Full pipeline: Question -> Generated SQL -> Database Execution -> Results with metrics.

        Never raises: any failure is reported in the returned dict with
        ``success=False`` and an ``error`` message.
        """
        try:
            if use_adapter:
                sql_query, gen_metrics = self.generate_sql(question, use_adapter=True)
            else:
                # Base model comparison: delegate to HF Inference API (their GPU, not our CPU)
                sql_query, gen_metrics = self._generate_base_comparison(question)

            # Execute query on SQLite and time it
            start_exec = time.perf_counter()
            df, error = self.execute_query(sql_query)
            exec_latency_sec = time.perf_counter() - start_exec

            # Rows as a list of dicts plus column names for the web UI
            if df is not None:
                results = df.to_dict(orient="records")
                columns = list(df.columns)
            else:
                results = None
                columns = []

            return {
                "question": question,
                "sql": sql_query,
                "success": error is None,
                "error": error,
                "columns": columns,
                "results": results,
                "metrics": {
                    "generation_time_ms": round(gen_metrics["latency_sec"] * 1000, 1),
                    "tokens_generated": gen_metrics["tokens_generated"],
                    "tokens_per_second": round(gen_metrics["tokens_per_second"], 1),
                    "db_exec_time_ms": round(exec_latency_sec * 1000, 1),
                },
            }
        except Exception as e:
            return {
                "question": question,
                "sql": "Generation Failed",
                "success": False,
                "error": f"Inference pipeline crash: {str(e)}",
                "columns": [],
                "results": None,
                "metrics": {
                    "generation_time_ms": 0.0,
                    "tokens_generated": 0,
                    "tokens_per_second": 0.0,
                    "db_exec_time_ms": 0.0,
                },
            }
if __name__ == "__main__":
    # Manual smoke test: make sure local data exists, then push one question
    # through the pipeline with the fine-tuned model and, if available, the base model.

    def _report(header, outcome):
        """Print a pipeline result dict: header, SQL, success flag, then rows or error."""
        print(header)
        print("Generated SQL:", outcome["sql"])
        print("Success:", outcome["success"])
        if not outcome["success"]:
            print("Error:", outcome["error"])
        else:
            print("Results (First 3):", outcome["results"][:3])

    print("Testing Text-to-SQL Inference Pipeline...")

    # Create the SQLite database on first run.
    if not os.path.exists("data/company_sales.db"):
        from database import setup_database
        setup_database()

    # Build the training dataset if it has not been generated yet.
    if not os.path.exists("data/train_dataset.jsonl"):
        from dataset import generate_datasets
        generate_datasets()

    # Loading here picks the backend and loads the adapter/GGUF if present.
    pipeline = TextToSQLInference()
    test_question = "What is the total revenue generated from all sales?"
    print(f"\nQuestion: {test_question}")

    # Fine-tuned run (on CUDA without an adapter this is just the base model).
    _report("\n[Adapter Enabled] Pipeline Result:", pipeline.query_pipeline(test_question, use_adapter=True))

    # Side-by-side base-model comparison only makes sense when an adapter exists.
    if pipeline.has_adapter:
        _report(
            "\n[Adapter Disabled - Base Model] Pipeline Result:",
            pipeline.query_pipeline(test_question, use_adapter=False),
        )