"""Fine-tuned Mistral-7B SQL-generation model loader (HF Spaces app module)."""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from typing import Dict, Any, Optional
import re
class FineTunedModelLoader:
    """Loads and manages the fine-tuned Mistral-7B model.

    Wraps a base causal LM plus a LoRA adapter and exposes one high-level
    entry point, ``generate_sql``, which turns a natural-language question
    and a schema description into a cleaned SQLite query string.
    """

    def __init__(self,
                 base_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
                 adapter_path: str = "mhdakmal80/Olist-SQL-Agent-Final",
                 use_4bit: bool = True):
        """
        Initialize the fine-tuned model.

        Args:
            base_model_name: HuggingFace model name
            adapter_path: Path to LoRA adapter weights (local path or hub id)
            use_4bit: Whether to use 4-bit quantization (forced off on CPU)
        """
        self.base_model_name = base_model_name
        self.adapter_path = adapter_path
        self.use_4bit = use_4bit
        print(" Loading fine-tuned model...")
        # Loading happens eagerly so failures surface at construction time.
        self.model, self.tokenizer = self._load_model()
        print(" Model loaded successfully!")

    def _load_model(self):
        """Load the base model and LoRA adapters.

        Returns:
            Tuple of (PeftModel, tokenizer). On CPU hosts, 4-bit
            quantization is disabled and float32 weights are used.
        """
        # bitsandbytes 4-bit kernels require CUDA, so probe the GPU first.
        has_gpu = torch.cuda.is_available()
        if not has_gpu:
            print(" ⚠️ No GPU detected - loading model on CPU (this will be slow)")
            print(" ⚠️ Disabling 4-bit quantization (requires GPU)")
            self.use_4bit = False  # Force disable 4-bit on CPU

        # Configure 4-bit quantization only if GPU available
        if self.use_4bit and has_gpu:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
            )
            print(" ✅ Using 4-bit quantization (GPU)")
        else:
            bnb_config = None
            print(" ℹ️ Using float32 (CPU mode)")

        # Load base model; device_map="auto" lets accelerate place layers.
        print(f" Loading base model: {self.base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=bnb_config,  # None when 4-bit is disabled
            torch_dtype=torch.float32 if not has_gpu else torch.bfloat16,  # float32 for CPU
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # Optimize CPU memory
        )

        # Load tokenizer
        print(" Loading tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True
        )
        # Mistral ships no pad token; reuse EOS so batching/generation work.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Load LoRA adapter on top of the (possibly quantized) base model.
        print(f" Loading LoRA adapter from: {self.adapter_path}")
        model = PeftModel.from_pretrained(base_model, self.adapter_path)
        return model, tokenizer

    def generate_sql(self, question: str, schema: str) -> Dict[str, Any]:
        """
        Generate SQL query from natural language question.

        Args:
            question: User's natural language question
            schema: Database schema as string

        Returns:
            Dictionary with 'sql', 'success', and 'error' keys. On failure
            'sql' is "" and 'error' carries the exception message.
        """
        # Mistral [INST] chat format; the schema is inlined so the model
        # only references real tables/columns.
        prompt = f"""[INST]You are a SQL expert. Generate a valid SQLite query using ONLY the columns and tables listed below.
Don't ever use columns that is not in the schema (this need to be followed strictly).Always try to come up the
solution based on provided schema only.
### Available Tables and Columns:
{schema}
### IMPORTANT:
- Use ONLY the column names listed above
- Do NOT invent column names
- Do NOT use columns that don't exist
### Question:
{question}
### Generate SQL using only the columns listed above:
[/INST]```sql
"""
        try:
            # Tokenize; truncation guards against oversized schemas.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Greedy decoding. NOTE: `temperature` was removed because it is
            # ignored (and triggers a warning) when do_sample=False.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode and strip the echoed prompt to isolate the SQL.
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            sql_query = self._extract_sql(generated_text, prompt)
            return {
                "sql": sql_query,
                "success": True,
                "error": None
            }
        except Exception as e:
            # Surface failures to the caller instead of raising; the UI
            # layer decides how to present them.
            return {
                "sql": "",
                "success": False,
                "error": f"Model Error: {str(e)}"
            }

    def _extract_sql(self, generated_text: str, prompt: str) -> str:
        """
        Extract SQL query from generated text.

        Args:
            generated_text: Full generated text from model
            prompt: Original prompt (to remove from output)

        Returns:
            Cleaned SQL query ending with a semicolon, or "" when no text
            remains after removing the prompt.
        """
        # Remove the prompt from the generated text
        sql = generated_text.replace(prompt, "").strip()

        # Markers tried in order of specificity; the SELECT fallback catches
        # a bare statement when the model emitted no fences.
        patterns = [
            r"### SQL Query:\s*(.+?)(?:###|$)",
            r"```sql\s*(.+?)\s*```",
            r"SELECT\s+.+",
        ]
        for pattern in patterns:
            match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
            if match:
                # Patterns without a capture group yield the full match.
                sql = match.group(1) if match.lastindex else match.group(0)
                break

        # Clean up leftover code fences and collapse all whitespace runs.
        sql = sql.replace("```sql", "").replace("```", "")
        sql = " ".join(sql.split())
        sql = sql.strip()

        # Nothing extractable -> return "" rather than a bare ";".
        if not sql:
            return ""
        # Ensure it ends with semicolon
        if not sql.endswith(";"):
            sql += ";"
        return sql
# Test function
# Manual smoke test: load the model and run one question end to end.
if __name__ == "__main__":
    loader = FineTunedModelLoader()
    demo_schema = """
    Table: orders
    Columns: order_id, customer_id, order_status, order_purchase_timestamp
    """
    outcome = loader.generate_sql("How many orders are there?", demo_schema)
    print(f"\nSuccess: {outcome['success']}")
    print(f"SQL: {outcome['sql']}")
    if outcome['error']:
        print(f"Error: {outcome['error']}")
|