# olist-text2sql / model_loader.py
# (uploaded to HuggingFace Space by mhdakmal80, commit 6a096d0)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from typing import Dict, Any, Optional
import re
class FineTunedModelLoader:
    """Loads and manages the fine-tuned Mistral-7B model (base model + LoRA adapter)."""

    def __init__(self,
                 base_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
                 adapter_path: str = "mhdakmal80/Olist-SQL-Agent-Final",
                 use_4bit: bool = True):
        """
        Initialize the fine-tuned model.

        Args:
            base_model_name: HuggingFace model name of the base model.
            adapter_path: HuggingFace repo id or local path to LoRA adapter weights.
            use_4bit: Whether to use 4-bit quantization (GPU only; automatically
                disabled when no CUDA device is available).
        """
        self.base_model_name = base_model_name
        self.adapter_path = adapter_path
        self.use_4bit = use_4bit
        print(" Loading fine-tuned model...")
        self.model, self.tokenizer = self._load_model()
        print(" Model loaded successfully!")

    def _load_model(self):
        """Load the base model and attach the LoRA adapters.

        Returns:
            Tuple of (peft-wrapped model, tokenizer).
        """
        # bitsandbytes 4-bit kernels require CUDA, so fall back to plain
        # float32 weights when running on CPU.
        has_gpu = torch.cuda.is_available()
        if not has_gpu:
            print(" ⚠️ No GPU detected - loading model on CPU (this will be slow)")
            print(" ⚠️ Disabling 4-bit quantization (requires GPU)")
            self.use_4bit = False  # Force disable 4-bit on CPU

        # Configure 4-bit quantization only if still enabled (i.e. GPU present).
        if self.use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
            )
            print(" ✅ Using 4-bit quantization (GPU)")
        else:
            bnb_config = None
            print(" ℹ️ Using float32 (CPU mode)")

        # Load base model. bnb_config is already None when quantization is
        # disabled, so it can be passed through directly.
        print(f" Loading base model: {self.base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16 if has_gpu else torch.float32,  # float32 for CPU
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # Stream weights in to limit peak RAM
        )

        # Load tokenizer.
        print(" Loading tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True
        )
        # Mistral ships without a pad token; reuse EOS so padded generation works.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Load LoRA adapter on top of the base model.
        print(f" Loading LoRA adapter from: {self.adapter_path}")
        model = PeftModel.from_pretrained(base_model, self.adapter_path)
        return model, tokenizer

    def generate_sql(self, question: str, schema: str) -> Dict[str, Any]:
        """
        Generate a SQL query from a natural language question.

        Args:
            question: User's natural language question.
            schema: Database schema as string.

        Returns:
            Dictionary with 'sql', 'success', and 'error' keys. On failure
            'sql' is "" and 'error' carries the exception message.
        """
        # Format prompt (Mistral [INST] chat format, primed with an opening
        # ```sql fence so the model continues with the query).
        prompt = f"""[INST]You are a SQL expert. Generate a valid SQLite query using ONLY the columns and tables listed below.
Don't ever use columns that is not in the schema (this need to be followed strictly).Always try to come up the
solution based on provided schema only.
### Available Tables and Columns:
{schema}
### IMPORTANT:
- Use ONLY the column names listed above
- Do NOT invent column names
- Do NOT use columns that don't exist
### Question:
{question}
### Generate SQL using only the columns listed above:
[/INST]```sql
"""
        try:
            # Tokenize. NOTE(review): max_length=512 may truncate large
            # schemas — confirm this bound against real Olist schema size.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            prompt_len = inputs["input_ids"].shape[1]

            # Generate. temperature is intentionally NOT passed: with
            # do_sample=False decoding is greedy and temperature is ignored
            # (and newer transformers versions warn about the combination).
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens. Decoding the full
            # sequence and string-replacing the prompt back out is fragile:
            # tokenizer round-trips are not guaranteed byte-exact.
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )

            # Extract SQL from response.
            sql_query = self._extract_sql(generated_text, prompt)
            return {
                "sql": sql_query,
                "success": True,
                "error": None
            }
        except Exception as e:
            # Report the failure instead of raising: the caller (UI layer)
            # decides how to surface model errors.
            return {
                "sql": "",
                "success": False,
                "error": f"Model Error: {str(e)}"
            }

    def _extract_sql(self, generated_text: str, prompt: str) -> str:
        """
        Extract the SQL query from generated text.

        Args:
            generated_text: Generated text from the model.
            prompt: Original prompt (stripped from the output if present).

        Returns:
            Cleaned SQL query ending with a semicolon, or "" when the model
            produced no recognizable SQL.
        """
        # Remove the prompt from the generated text (no-op when the caller
        # already decoded only the new tokens).
        sql = generated_text.replace(prompt, "").strip()

        # Try markers in order of specificity: explicit section header,
        # fenced code block, then a bare SELECT statement.
        patterns = [
            r"### SQL Query:\s*(.+?)(?:###|$)",
            r"```sql\s*(.+?)\s*```",
            r"SELECT\s+.+",
        ]
        for pattern in patterns:
            match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
            if match:
                # Patterns without a capture group fall back to the whole match.
                sql = match.group(1) if match.lastindex else match.group(0)
                break

        # Clean up stray fences and collapse all whitespace runs.
        sql = sql.replace("```sql", "").replace("```", "")
        sql = " ".join(sql.split())
        sql = sql.strip()

        # Guard against returning a bare ";" when nothing was generated.
        if not sql:
            return ""

        # Ensure it ends with a semicolon.
        if not sql.endswith(";"):
            sql += ";"
        return sql
# Manual smoke test: load the model and generate a single query.
if __name__ == "__main__":
    loader = FineTunedModelLoader()

    demo_schema = """
Table: orders
Columns: order_id, customer_id, order_status, order_purchase_timestamp
"""

    outcome = loader.generate_sql(
        "How many orders are there?",
        demo_schema
    )

    print(f"\nSuccess: {outcome['success']}")
    print(f"SQL: {outcome['sql']}")
    if outcome['error']:
        print(f"Error: {outcome['error']}")