import re
from typing import Any, Dict, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class FineTunedModelLoader:
    """Loads and manages the fine-tuned Mistral-7B model."""

    def __init__(self,
                 base_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
                 adapter_path: str = "mhdakmal80/Olist-SQL-Agent-Final",
                 use_4bit: bool = True):
        """
        Initialize the fine-tuned model.

        Args:
            base_model_name: HuggingFace model name
            adapter_path: Path to LoRA adapter weights (local dir or Hub repo id)
            use_4bit: Whether to use 4-bit quantization (forced off on CPU)
        """
        self.base_model_name = base_model_name
        self.adapter_path = adapter_path
        self.use_4bit = use_4bit

        print(" Loading fine-tuned model...")
        self.model, self.tokenizer = self._load_model()
        print(" Model loaded successfully!")

    def _load_model(self):
        """Load the base model and LoRA adapters.

        Returns:
            Tuple of (model, tokenizer). The model is the base model wrapped
            with the LoRA adapter via PeftModel.
        """
        # Check if GPU is available
        has_gpu = torch.cuda.is_available()

        if not has_gpu:
            print(" ⚠️ No GPU detected - loading model on CPU (this will be slow)")
            print(" ⚠️ Disabling 4-bit quantization (requires GPU)")
            self.use_4bit = False  # Force disable 4-bit on CPU

        # Configure 4-bit quantization only if GPU available
        if self.use_4bit and has_gpu:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
            )
            print(" ✅ Using 4-bit quantization (GPU)")
        else:
            bnb_config = None
            print(" ℹ️ Using float32 (CPU mode)")

        # Load base model.  bnb_config is already None unless (use_4bit and
        # has_gpu), so it can be passed directly.
        print(f" Loading base model: {self.base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.float32 if not has_gpu else torch.bfloat16,  # float32 for CPU
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # Optimize CPU memory
        )

        # Load tokenizer
        print(" Loading tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True,
        )
        # Mistral ships no pad token; reuse EOS so batching/padding works.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Load LoRA adapter
        print(f" Loading LoRA adapter from: {self.adapter_path}")
        model = PeftModel.from_pretrained(base_model, self.adapter_path)

        return model, tokenizer

    def generate_sql(self, question: str, schema: str) -> Dict[str, Any]:
        """
        Generate SQL query from natural language question.

        Args:
            question: User's natural language question
            schema: Database schema as string

        Returns:
            Dictionary with 'sql', 'success', and 'error' keys.  On failure
            'sql' is an empty string and 'error' carries the message.
        """
        # Format prompt
        prompt = f"""[INST]You are a SQL expert. Generate a valid SQLite query using ONLY the columns and tables listed below. Don't ever use columns that is not in the schema (this need to be followed strictly).Always try to come up the solution based on provided schema only. ### Available Tables and Columns: {schema} ### IMPORTANT: - Use ONLY the column names listed above - Do NOT invent column names - Do NOT use columns that don't exist ### Question: {question} ### Generate SQL using only the columns listed above: [/INST]```sql """

        try:
            # Tokenize.  NOTE: truncation at 512 tokens can clip a very large
            # schema; keep schemas compact or raise max_length if needed.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate.  Greedy decoding (do_sample=False); temperature is
            # intentionally not set — it is ignored under greedy decoding and
            # passing it only triggers a transformers warning.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens.  Slicing off the prompt
            # tokens is robust, unlike string-replacing the prompt text, which
            # fails when the tokenize/decode round trip alters whitespace.
            prompt_len = inputs["input_ids"].shape[1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )

            # Extract SQL from response
            sql_query = self._extract_sql(generated_text, prompt)

            return {
                "sql": sql_query,
                "success": True,
                "error": None,
            }

        except Exception as e:
            return {
                "sql": "",
                "success": False,
                "error": f"Model Error: {str(e)}",
            }

    def _extract_sql(self, generated_text: str, prompt: str) -> str:
        """
        Extract SQL query from generated text.

        Args:
            generated_text: Generated text from model (prompt already removed
                by the caller; the replace below is a harmless fallback)
            prompt: Original prompt (to remove from output if still present)

        Returns:
            Cleaned SQL query, normalized to single-space whitespace and
            guaranteed to end with a semicolon.
        """
        # Remove the prompt from the generated text (no-op when the caller
        # already decoded only the new tokens).
        sql = generated_text.replace(prompt, "").strip()

        # Try markers in order: explicit "### SQL Query:" label, fenced
        # ```sql block, then a bare SELECT statement.
        patterns = [
            r"### SQL Query:\s*(.+?)(?:###|$)",
            r"```sql\s*(.+?)\s*```",
            r"SELECT\s+.+",
        ]
        for pattern in patterns:
            match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
            if match:
                sql = match.group(1) if match.lastindex else match.group(0)
                break

        # Clean up stray fence markers and collapse whitespace.
        sql = sql.replace("```sql", "").replace("```", "")
        sql = " ".join(sql.split())  # Remove extra whitespace
        sql = sql.strip()

        # Ensure it ends with semicolon
        if not sql.endswith(";"):
            sql += ";"

        return sql


# Test function
if __name__ == "__main__":
    # Quick test
    model_loader = FineTunedModelLoader()

    test_schema = """
    Table: orders
    Columns: order_id, customer_id, order_status, order_purchase_timestamp
    """

    result = model_loader.generate_sql(
        "How many orders are there?",
        test_schema,
    )

    print(f"\nSuccess: {result['success']}")
    print(f"SQL: {result['sql']}")
    if result['error']:
        print(f"Error: {result['error']}")