# Scrape residue from the Hugging Face Spaces page (the Space reported
# "Runtime error" at capture time) — kept as comments so the file parses:
# Spaces: Runtime error / Runtime error
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from peft import PeftModel | |
| from typing import Dict, Any, Optional | |
| import re | |
class FineTunedModelLoader:
    """Loads and manages the fine-tuned Mistral-7B model.

    Wraps a base Mistral-7B-Instruct checkpoint plus LoRA adapter weights
    (optionally 4-bit quantized when a CUDA GPU is present) and exposes
    ``generate_sql`` to turn a natural-language question plus a schema
    description into a SQLite query string.
    """

    def __init__(self,
                 base_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
                 adapter_path: str = "mhdakmal80/Olist-SQL-Agent-Final",
                 use_4bit: bool = True):
        """
        Initialize the fine-tuned model.

        Args:
            base_model_name: HuggingFace model name
            adapter_path: Path to LoRA adapter weights
            use_4bit: Whether to use 4-bit quantization (force-disabled on CPU)
        """
        self.base_model_name = base_model_name
        self.adapter_path = adapter_path
        self.use_4bit = use_4bit
        print(" Loading fine-tuned model...")
        self.model, self.tokenizer = self._load_model()
        print(" Model loaded successfully!")

    def _load_model(self):
        """Load the base model and LoRA adapters.

        Returns:
            Tuple of (peft-wrapped model, tokenizer).
        """
        # Check if GPU is available
        has_gpu = torch.cuda.is_available()
        if not has_gpu:
            print(" ⚠️ No GPU detected - loading model on CPU (this will be slow)")
            print(" ⚠️ Disabling 4-bit quantization (requires GPU)")
            # bitsandbytes 4-bit kernels require CUDA, so force disable on CPU.
            self.use_4bit = False

        # Configure 4-bit quantization only if still enabled (i.e. GPU present).
        if self.use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
            )
            print(" ✅ Using 4-bit quantization (GPU)")
        else:
            bnb_config = None
            print(" ℹ️ Using float32 (CPU mode)")

        # Load base model. bnb_config is already None when quantization is off,
        # so it can be passed through unconditionally.
        print(f" Loading base model: {self.base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16 if has_gpu else torch.float32,  # float32 for CPU
            device_map="auto",
            # NOTE(review): trust_remote_code executes code from the model repo
            # on load — acceptable for this pinned official checkpoint.
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # Optimize CPU memory during load
        )

        # Load tokenizer
        print(f" Loading tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True
        )
        # Mistral ships no pad token; reuse EOS so padding/generation work.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Load LoRA adapter on top of the base model
        print(f" Loading LoRA adapter from: {self.adapter_path}")
        model = PeftModel.from_pretrained(base_model, self.adapter_path)
        return model, tokenizer

    def generate_sql(self, question: str, schema: str) -> Dict[str, Any]:
        """
        Generate SQL query from natural language question.

        Args:
            question: User's natural language question
            schema: Database schema as string

        Returns:
            Dictionary with 'sql', 'success', and 'error' keys
        """
        # Format prompt (kept verbatim — the adapter was trained on this template)
        prompt = f"""[INST]You are a SQL expert. Generate a valid SQLite query using ONLY the columns and tables listed below.
Don't ever use columns that is not in the schema (this need to be followed strictly).Always try to come up the
solution based on provided schema only.
### Available Tables and Columns:
{schema}
### IMPORTANT:
- Use ONLY the column names listed above
- Do NOT invent column names
- Do NOT use columns that don't exist
### Question:
{question}
### Generate SQL using only the columns listed above:
[/INST]```sql
"""
        try:
            # Tokenize; max_length=512 may truncate very large schemas.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Greedy decoding. FIX: the original also passed temperature=0.1,
            # which is ignored when do_sample=False and makes recent
            # transformers versions warn/raise — removed, output is unchanged.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract SQL from response
            sql_query = self._extract_sql(generated_text, prompt)
            return {
                "sql": sql_query,
                "success": True,
                "error": None
            }
        except Exception as e:
            # Surface any tokenization/generation failure to the caller
            # instead of crashing the app.
            return {
                "sql": "",
                "success": False,
                "error": f"Model Error: {str(e)}"
            }

    def _extract_sql(self, generated_text: str, prompt: str) -> str:
        """
        Extract SQL query from generated text.

        Args:
            generated_text: Full generated text from model
            prompt: Original prompt (to remove from output)

        Returns:
            Cleaned SQL query, whitespace-normalized and semicolon-terminated
        """
        # Remove the prompt from the generated text.
        # NOTE(review): relies on the decoded text containing the prompt
        # verbatim; if the tokenizer round-trip altered whitespace the
        # replace is a no-op, and the patterns below still recover the SQL.
        sql = generated_text.replace(prompt, "").strip()

        # Try increasingly permissive extraction patterns.
        patterns = [
            r"### SQL Query:\s*(.+?)(?:###|$)",
            r"```sql\s*(.+?)\s*```",
            r"SELECT\s+.+",
        ]
        for pattern in patterns:
            match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
            if match:
                # Patterns with a capture group yield group(1); the bare
                # SELECT pattern has no groups, so take the whole match.
                sql = match.group(1) if match.lastindex else match.group(0)
                break

        # Clean up: strip stray code fences and collapse whitespace.
        sql = sql.replace("```sql", "").replace("```", "")
        sql = " ".join(sql.split())  # Remove extra whitespace
        sql = sql.strip()

        # Ensure it ends with semicolon
        if not sql.endswith(";"):
            sql += ";"
        return sql
# Test function
if __name__ == "__main__":
    # Quick smoke test: load the model once and generate a single query.
    loader = FineTunedModelLoader()
    sample_schema = """
Table: orders
Columns: order_id, customer_id, order_status, order_purchase_timestamp
"""
    outcome = loader.generate_sql("How many orders are there?", sample_schema)
    print(f"\nSuccess: {outcome['success']}")
    print(f"SQL: {outcome['sql']}")
    if outcome['error']:
        print(f"Error: {outcome['error']}")