Spaces:
Sleeping
Sleeping
Fix CPU thread oversubscription: cap n_threads to 2 (cpu-basic vCPUs) for faster generation
4853f12 verified | import os | |
| import re | |
| import sqlite3 | |
| import time | |
| import pandas as pd | |
| from dataset import SCHEMA_PROMPT | |
| # NOTE: torch / transformers / peft are imported lazily inside the CUDA-only code | |
| # paths. The CPU (HF Space) image serves via llama.cpp and does NOT install them, | |
| # so importing them at module top would crash the Space on startup. | |
class TextToSQLInference:
    """Natural-language-to-SQL pipeline with two hardware-selected backends.

    * CUDA: ``transformers`` base model loaded with 4-bit bitsandbytes
      quantization, plus a PEFT adapter when one is available.
    * CPU: ``llama.cpp`` serving a GGUF file that already contains the merged
      fine-tuned weights.

    torch / transformers / peft / llama_cpp / huggingface_hub are imported
    lazily inside the code path that needs them, so a deployment only has to
    install the packages for its own backend.
    """

    # Thread count used for llama.cpp when N_THREADS is unset or unparsable.
    _DEFAULT_N_THREADS = 2

    # First SQL statement up to its terminating semicolon. A leading CTE
    # ("WITH name AS (") is recognised so that the match does not start at the
    # SELECT nested inside the CTE and return a truncated statement.
    _SQL_STATEMENT_RE = re.compile(
        r"((?:\bWITH\s+(?:RECURSIVE\s+)?\w+(?:\s*\([^)]*\))?\s+AS\s*\(|SELECT\s+).*?;)",
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self, base_model_id="microsoft/Phi-3-mini-4k-instruct", adapter_path="Bhuvandesai/phi3-text-to-sql-adapter"):
        """Configure the pipeline and load the model for the detected device.

        Args:
            base_model_id: Hub id of the base model (CUDA backend and HF API
                base comparison).
            adapter_path: Local path or Hub repo id of the PEFT adapter.
        """
        self.base_model_id = base_model_id
        self.adapter_path = adapter_path
        self.db_path = "data/company_sales.db"
        self.tokenizer = None
        self.model = None
        self.llm = None      # llama_cpp.Llama instance (CPU/GGUF backend only)
        self.backend = None  # "transformers" (CUDA) or "llama_cpp" (CPU)

        # GGUF source for CPU serving (override via env on the Space)
        self.gguf_repo_id = os.environ.get("GGUF_REPO_ID", "Bhuvandesai/phi3-text-to-sql-gguf")
        self.gguf_filename = os.environ.get("GGUF_FILENAME", "phi3-text-to-sql-Q4_K_M.gguf")

        # Detect CUDA without hard-importing torch, so the CPU Space (which has
        # no torch installed) doesn't crash here.
        self.device = self._detect_device()

        # Only the CUDA/PEFT path consumes the adapter lookup. The CPU path
        # sets has_adapter itself in _load_llama_cpp_cpu, so the (possibly
        # networked) Hub check is skipped there.
        if self.device == "cuda":
            self.has_adapter = self._adapter_available(adapter_path)
        else:
            self.has_adapter = False

        self._load_model_and_tokenizer()

    def _detect_device(self):
        """Return "cuda" when torch is importable and reports CUDA, else "cpu"."""
        try:
            import torch
            return "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            return "cpu"

    def _adapter_available(self, path):
        """Return True if the adapter exists locally or as a HF Hub model repo.

        Any failure while querying the Hub is reported and treated as
        "not available" so the caller can continue with the base model.
        """
        if os.path.exists(path):
            return True
        try:
            from huggingface_hub import repo_exists
            return repo_exists(path, repo_type="model")
        except Exception as exc:
            # Best-effort lookup: keep going, but say why the adapter was skipped.
            print(f"WARNING: could not verify adapter '{path}' on the Hub ({exc}); treating it as unavailable.")
            return False

    def _load_model_and_tokenizer(self):
        """Select the backend from ``self.device`` and load its model.

        - CUDA (laptop/dev): transformers + 4-bit bitsandbytes + PEFT adapter
        - CPU (HF Space):    llama.cpp + pre-merged GGUF (Q4_K_M)
        """
        if self.device == "cuda":
            self.backend = "transformers"
            self._load_transformers_cuda()
        else:
            self.backend = "llama_cpp"
            self._load_llama_cpp_cpu()

    def _load_transformers_cuda(self):
        """Load tokenizer, 4-bit quantized base model and (if present) the adapter."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import PeftModel

        print(f"Loading tokenizer: {self.base_model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id, trust_remote_code=False)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        print("CUDA detected — configuring 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16
        )

        print("Loading base model with 4-bit quantization...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=False,
            attn_implementation="eager"
        )

        if self.has_adapter:
            print(f"Loading fine-tuned adapter from: {self.adapter_path}...")
            self.model = PeftModel.from_pretrained(
                self.model,
                self.adapter_path,
                adapter_name="default"
            )
            print("Adapter loaded successfully!")
        else:
            print("WARNING: No adapter found. Will run base model only.")

    @staticmethod
    def _resolve_n_threads(default=2):
        """Return the llama.cpp thread count from the N_THREADS env variable.

        Integer values are passed through unchanged. An unset or non-integer
        value yields ``default`` (with a warning for the latter) instead of
        raising ``ValueError`` during model loading.
        """
        raw = os.environ.get("N_THREADS")
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            print(f"WARNING: ignoring non-integer N_THREADS={raw!r}; using {default}.")
            return default

    def _load_llama_cpp_cpu(self):
        """Load the GGUF model through llama.cpp (local file first, else the Hub)."""
        # GGUF already contains the merged fine-tuned weights, so there is no
        # separate adapter to load at runtime.
        self.has_adapter = True
        print("CPU detected — using llama.cpp + GGUF (Q4_K_M) backend.")
        from llama_cpp import Llama
        from huggingface_hub import hf_hub_download

        # Allow a local GGUF (committed under models/) to take precedence; otherwise pull from Hub.
        local_candidate = os.path.join("models", self.gguf_filename)
        if os.path.exists(local_candidate):
            model_path = local_candidate
            print(f"Loading local GGUF: {model_path}")
        else:
            print(f"Downloading GGUF {self.gguf_filename} from {self.gguf_repo_id}...")
            model_path = hf_hub_download(repo_id=self.gguf_repo_id, filename=self.gguf_filename)
            print(f"GGUF cached at: {model_path}")

        # cpu-basic has 2 vCPUs, but os.cpu_count() reports the host's many cores.
        # Using that many threads oversubscribes the 2 vCPUs and makes generation
        # dramatically slower. Cap to the real vCPU count (override via N_THREADS).
        n_threads = self._resolve_n_threads(self._DEFAULT_N_THREADS)
        self.llm = Llama(
            model_path=model_path,
            n_ctx=2048,  # schema prompt (~300 tok) + question + 150 generated
            n_threads=n_threads,
            n_batch=512,
            logits_all=False,
            verbose=False,
        )
        print(f"llama.cpp model loaded (n_threads={n_threads}).")

    def clean_sql(self, raw_output):
        """Extract the SQL statement from raw model output.

        Drops everything before the last ``<|assistant|>`` marker, removes
        markdown code fences and ChatML-style ``<|...|>`` tokens, then returns
        the first statement ending in a semicolon (plain ``SELECT`` or a
        ``WITH ... AS (`` CTE). If no such statement is found, the cleaned
        text is returned as-is.
        """
        # Print for raw debugging
        print(f"Raw model response: {repr(raw_output)}")

        # Keep only the assistant text if the model echoed the prompt too
        if "<|assistant|>" in raw_output:
            raw_output = raw_output.split("<|assistant|>")[-1]

        # Strip markdown wrappers (e.g. ```sql ... ```)
        cleaned = re.sub(r"```sql\s*", "", raw_output, flags=re.IGNORECASE)
        cleaned = re.sub(r"```", "", cleaned)

        # Remove ChatML tokens and surrounding whitespace
        cleaned = re.sub(r"<\|.*?\|>", "", cleaned).strip()

        sql_match = self._SQL_STATEMENT_RE.search(cleaned)
        if sql_match:
            return sql_match.group(1).strip()

        # No terminated statement found: hand back whatever is left.
        return cleaned

    def execute_query(self, sql_query):
        """Run ``sql_query`` against the SQLite database at ``self.db_path``.

        Returns:
            ``(DataFrame, None)`` on success, or ``(None, error_message)`` when
            the query is blank / contains no SELECT, the database file is
            missing, or execution fails.
        """
        if not sql_query or "select" not in sql_query.lower():
            return None, "Error: Invalid or blank SQL query generated."

        if not os.path.exists(self.db_path):
            return None, f"Error: Database file '{self.db_path}' not found. Run database.py first."

        try:
            conn = sqlite3.connect(self.db_path)
            try:
                # The SQL is model-generated: forbid writes on this connection
                # so a stray DELETE/UPDATE/DROP cannot modify the database.
                conn.execute("PRAGMA query_only = ON")
                # Run query with pandas to easily get columns and visual formatting
                df = pd.read_sql_query(sql_query, conn)
            finally:
                # Always release the connection, including when the query fails.
                conn.close()
            return df, None
        except Exception as e:
            return None, str(e)

    def generate_sql(self, question, use_adapter=True):
        """Dispatches to the active backend. Output contract is identical for both.

        Returns:
            ``(sql_query, metrics)`` where ``metrics`` has the keys
            ``latency_sec``, ``tokens_generated`` and ``tokens_per_second``.
        """
        if self.backend == "llama_cpp":
            return self._generate_llama(question, use_adapter=use_adapter)
        return self._generate_transformers(question, use_adapter=use_adapter)

    def _generate_llama(self, question, use_adapter=True):
        """CPU/GGUF generation via llama.cpp. The GGUF is the merged fine-tuned model.

        Raises:
            RuntimeError: if ``use_adapter`` is False, since the GGUF has no
                base-model toggle.
        """
        if not use_adapter:
            # GGUF has the fine-tuned weights baked in; there is no base toggle on CPU.
            # The base comparison is served by generate_sql_base_via_api (HF API) instead.
            # query_pipeline catches this and returns a clean error rather than crashing.
            raise RuntimeError(
                "Local base-model comparison is unavailable in CPU/GGUF mode; "
                "use the HF Inference API path for base comparison."
            )

        prompt = f"{SCHEMA_PROMPT}\n\nQuestion: {question}"
        print("--- Generating with Fine-tuned GGUF (llama.cpp) ---")

        start_time = time.perf_counter()
        result = self.llm.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=150,
            temperature=0.0,  # greedy / deterministic, matches do_sample=False
        )
        latency_sec = time.perf_counter() - start_time

        raw_output = result["choices"][0]["message"]["content"]
        usage = result.get("usage", {}) or {}
        tokens_generated = usage.get("completion_tokens", 0) or 0
        tokens_per_second = tokens_generated / latency_sec if latency_sec > 0 else 0.0

        sql_query = self.clean_sql(raw_output)
        metrics = {
            "latency_sec": latency_sec,
            "tokens_generated": tokens_generated,
            "tokens_per_second": tokens_per_second
        }
        return sql_query, metrics

    def _generate_transformers(self, question, use_adapter=True):
        """Generate SQL for ``question`` with the transformers (CUDA) backend.

        When an adapter is loaded and ``use_adapter`` is False, generation runs
        inside PEFT's ``disable_adapter()`` context to get base-model output.

        Returns:
            ``(sql_query, metrics)`` with the same metric keys as the llama.cpp path.
        """
        import torch
        from contextlib import nullcontext

        prompt = f"{SCHEMA_PROMPT}\n\nQuestion: {question}"
        messages = [
            {"role": "user", "content": prompt}
        ]

        # Build prompt using the tokenizer's chat template
        chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.device)
        input_token_count = inputs.input_ids.shape[1]

        # Toggle fine-tuned adapter weights or use raw base model
        run_base_only = self.has_adapter and not use_adapter
        if self.has_adapter:
            if use_adapter:
                self.model.set_adapter("default")
                print("--- Generating with Fine-tuned Adapter Active ---")
            else:
                print("--- Generating with Base Model (Adapter Disabled) ---")

        # PEFT context manager temporarily disables adapter weights; otherwise
        # a no-op context keeps a single generate() call for both cases.
        adapter_ctx = self.model.disable_adapter() if run_base_only else nullcontext()

        # Run generation with latency tracking
        start_time = time.perf_counter()
        with torch.no_grad(), adapter_ctx:
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=False,  # Greedy decoding for stable code output
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id
            )
        end_time = time.perf_counter()

        # Calculate generation metrics
        latency_sec = end_time - start_time
        output_token_count = outputs.shape[1]
        tokens_generated = max(0, output_token_count - input_token_count)
        tokens_per_second = tokens_generated / latency_sec if latency_sec > 0 else 0.0

        # Decode and extract the assistant response only
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
        assistant_marker = "<|assistant|>\n"
        if assistant_marker in decoded:
            response = decoded.split(assistant_marker)[-1]
        else:
            # Fallback if the chat template renders differently: decode only the
            # newly generated tokens. (input_token_count is a token count, so
            # it must index the token sequence, not the decoded string.)
            response = self.tokenizer.decode(outputs[0][input_token_count:], skip_special_tokens=False)

        sql_query = self.clean_sql(response)
        metrics = {
            "latency_sec": latency_sec,
            "tokens_generated": tokens_generated,
            "tokens_per_second": tokens_per_second
        }
        return sql_query, metrics

    def generate_sql_base_via_api(self, question):
        """Calls HF Inference API for base model comparison. Runs on HF's GPU, not local hardware.

        Uses the ``HF_TOKEN`` environment variable (if set) for authentication.

        Returns:
            ``(sql_query, metrics)``; ``tokens_per_second`` is rounded to one
            decimal and is 0 when the API reports no completion token count.
        """
        from huggingface_hub import InferenceClient

        token = os.environ.get("HF_TOKEN")
        client = InferenceClient(model=self.base_model_id, token=token)
        prompt = f"{SCHEMA_PROMPT}\n\nQuestion: {question}"

        start = time.perf_counter()
        response = client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=150,
        )
        elapsed = time.perf_counter() - start

        raw = response.choices[0].message.content
        sql = self.clean_sql(raw)

        usage = getattr(response, "usage", None)
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        metrics = {
            "latency_sec": elapsed,
            "tokens_generated": completion_tokens,
            "tokens_per_second": round(completion_tokens / elapsed, 1) if elapsed > 0 and completion_tokens else 0
        }
        return sql, metrics

    def query_pipeline(self, question, use_adapter=True):
        """Full pipeline: Question -> Generated SQL -> Database Execution -> Results with metrics.

        With ``use_adapter=False`` the base-model SQL is requested from the HF
        Inference API first, falling back to local base inference if that fails.

        Returns:
            A dict with keys ``question``, ``sql``, ``success``, ``error``,
            ``columns``, ``results`` and ``metrics``. This method does not
            raise: any exception is reported through ``error`` with
            ``success`` set to False.
        """
        try:
            if not use_adapter:
                # Base model comparison: delegate to HF Inference API (their GPU, not our CPU)
                try:
                    sql_query, gen_metrics = self.generate_sql_base_via_api(question)
                    print(f"Base model via HF API returned: {sql_query}")
                except Exception as api_err:
                    print(f"HF API failed ({api_err}), falling back to local base inference...")
                    try:
                        sql_query, gen_metrics = self.generate_sql(question, use_adapter=False)
                    except Exception as local_err:
                        # Surface both causes; otherwise the API error that
                        # triggered the fallback is lost from the result.
                        raise RuntimeError(
                            f"HF API base-model call failed ({api_err}); "
                            f"local fallback also failed ({local_err})"
                        ) from local_err
            else:
                sql_query, gen_metrics = self.generate_sql(question, use_adapter=True)

            # Execute query on SQLite and time it
            start_exec = time.perf_counter()
            df, error = self.execute_query(sql_query)
            end_exec = time.perf_counter()
            exec_latency_sec = end_exec - start_exec

            # Convert the DataFrame to plain records for the web UI
            results = None
            if df is not None:
                results = df.to_dict(orient="records")
                columns = list(df.columns)
            else:
                columns = []

            return {
                "question": question,
                "sql": sql_query,
                "success": error is None,
                "error": error,
                "columns": columns,
                "results": results,
                "metrics": {
                    "generation_time_ms": round(gen_metrics["latency_sec"] * 1000, 1),
                    "tokens_generated": gen_metrics["tokens_generated"],
                    "tokens_per_second": round(gen_metrics["tokens_per_second"], 1),
                    "db_exec_time_ms": round(exec_latency_sec * 1000, 1)
                }
            }
        except Exception as e:
            return {
                "question": question,
                "sql": "Generation Failed",
                "success": False,
                "error": f"Inference pipeline crash: {str(e)}",
                "columns": [],
                "results": None,
                "metrics": {
                    "generation_time_ms": 0.0,
                    "tokens_generated": 0,
                    "tokens_per_second": 0.0,
                    "db_exec_time_ms": 0.0
                }
            }
if __name__ == "__main__":
    # Manual smoke test: run this module directly from a terminal.
    print("Testing Text-to-SQL Inference Pipeline...")

    # Create the SQLite database first if it has not been generated yet.
    if not os.path.exists("data/company_sales.db"):
        from database import setup_database
        setup_database()

    # Likewise generate the datasets when the training file is missing.
    if not os.path.exists("data/train_dataset.jsonl"):
        from dataset import generate_datasets
        generate_datasets()

    # Building the pipeline loads the model (and the adapter, if present).
    pipeline = TextToSQLInference()
    test_question = "What is the total revenue generated from all sales?"
    print(f"\nQuestion: {test_question}")

    # First pass uses the adapter; the base-model pass is only a comparison
    # and is skipped when the pipeline reports no adapter.
    runs = (
        (True, "\n[Adapter Enabled] Pipeline Result:"),
        (False, "\n[Adapter Disabled - Base Model] Pipeline Result:"),
    )
    for adapter_on, banner in runs:
        if not adapter_on and not pipeline.has_adapter:
            continue
        outcome = pipeline.query_pipeline(test_question, use_adapter=adapter_on)
        print(banner)
        print("Generated SQL:", outcome["sql"])
        print("Success:", outcome["success"])
        if outcome["success"]:
            print("Results (First 3):", outcome["results"][:3])
        else:
            print("Error:", outcome["error"])