File size: 7,108 Bytes
d60cb1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a096d0
 
 
 
 
 
 
 
 
 
d60cb1f
 
 
 
 
 
6a096d0
d60cb1f
 
6a096d0
d60cb1f
 
 
 
 
6a096d0
 
d60cb1f
 
6a096d0
d60cb1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from typing import Dict, Any, Optional
import re


class FineTunedModelLoader:
    """Loads and manages the fine-tuned Mistral-7B SQL-generation model.

    Wraps a HuggingFace base model plus a LoRA adapter (loaded via PEFT) and
    exposes ``generate_sql`` to turn a natural-language question and a schema
    description into a SQLite query. Falls back to CPU/float32 (no 4-bit
    quantization) when no CUDA device is available.
    """

    def __init__(self,
                 base_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
                 adapter_path: str = "mhdakmal80/Olist-SQL-Agent-Final",
                 use_4bit: bool = True):
        """
        Initialize and immediately load the fine-tuned model.

        Args:
            base_model_name: HuggingFace model name of the base model
            adapter_path: HuggingFace repo id or local path of the LoRA adapter
            use_4bit: Whether to use 4-bit quantization (honored only on GPU;
                forcibly disabled when no CUDA device is present)
        """
        self.base_model_name = base_model_name
        self.adapter_path = adapter_path
        self.use_4bit = use_4bit

        print(" Loading fine-tuned model...")
        self.model, self.tokenizer = self._load_model()
        print(" Model loaded successfully!")

    def _load_model(self):
        """Load the base model, its tokenizer, and the LoRA adapter.

        Returns:
            Tuple of (model, tokenizer); ``model`` is a PeftModel wrapping
            the (possibly 4-bit quantized) base model.
        """
        # Check if GPU is available
        has_gpu = torch.cuda.is_available()

        if not has_gpu:
            print("  ⚠️ No GPU detected - loading model on CPU (this will be slow)")
            print("  ⚠️ Disabling 4-bit quantization (requires GPU)")
            self.use_4bit = False  # bitsandbytes 4-bit kernels require CUDA

        # Configure 4-bit quantization. self.use_4bit is already False on CPU,
        # so the extra "and has_gpu" re-check the original carried here (and
        # again below in from_pretrained) was dead logic and is dropped.
        if self.use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
            )
            print("  ✅ Using 4-bit quantization (GPU)")
        else:
            bnb_config = None
            print("  ℹ️ Using float32 (CPU mode)")

        # Load base model (quantization_config=None simply disables it)
        print(f"  Loading base model: {self.base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16 if has_gpu else torch.float32,  # float32 for CPU
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # reduce peak host RAM while loading
        )

        # Load tokenizer; Mistral ships no pad token, so reuse EOS for padding.
        print("  Loading tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True
        )
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Attach the fine-tuned LoRA weights on top of the base model
        print(f"  Loading LoRA adapter from: {self.adapter_path}")
        model = PeftModel.from_pretrained(base_model, self.adapter_path)

        return model, tokenizer

    def generate_sql(self, question: str, schema: str) -> Dict[str, Any]:
        """
        Generate SQL query from natural language question.

        Args:
            question: User's natural language question
            schema: Database schema as string (tables and columns)

        Returns:
            Dictionary with 'sql' (str), 'success' (bool), and 'error'
            (str or None) keys; on failure 'sql' is "" and 'error' holds
            the exception message.
        """
        # Format prompt — kept byte-identical to the prompt format the adapter
        # was fine-tuned with; do not reflow the blank lines or wording.
        prompt = f"""[INST]You are a SQL expert. Generate a valid SQLite query using ONLY the columns and tables listed below.

Don't ever use columns that is not in the schema (this need to be followed strictly).Always try to come up the

solution based on provided schema only.



### Available Tables and Columns:



{schema}



### IMPORTANT:

- Use ONLY the column names listed above

- Do NOT invent column names

- Do NOT use columns that don't exist



### Question:

{question}



### Generate SQL using only the columns listed above:

[/INST]```sql

"""

        try:
            # Tokenize. NOTE(review): max_length=512 silently truncates long
            # prompts — confirm the schemas used actually fit this budget.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Greedy decoding. The original also passed temperature=0.1, which
            # transformers ignores (and warns about) when do_sample=False, so
            # it is dropped here; the decoded output is unchanged.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode the full sequence (prompt echo + continuation)
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract just the SQL statement from the model output
            sql_query = self._extract_sql(generated_text, prompt)

            return {
                "sql": sql_query,
                "success": True,
                "error": None
            }

        except Exception as e:
            # Surface model/tokenizer failures to the caller instead of
            # raising, so the calling agent loop can report them gracefully.
            return {
                "sql": "",
                "success": False,
                "error": f"Model Error: {str(e)}"
            }

    def _extract_sql(self, generated_text: str, prompt: str) -> str:
        """
        Extract and clean the SQL query from generated text.

        Args:
            generated_text: Full decoded text from the model
            prompt: Original prompt (stripped from the output first)

        Returns:
            Cleaned single-line SQL query terminated with a semicolon
        """
        # Drop the echoed prompt. NOTE(review): decode(skip_special_tokens=True)
        # may not reproduce the prompt byte-for-byte, in which case this
        # replace is a no-op and the regex fallbacks below take over.
        sql = generated_text.replace(prompt, "").strip()

        # Progressively looser extraction patterns; first match wins.
        patterns = [
            r"### SQL Query:\s*(.+?)(?:###|$)",  # explicit marker, if present
            r"```sql\s*(.+?)\s*```",             # fenced code block
            r"SELECT\s+.+",                      # bare SELECT statement
        ]

        for pattern in patterns:
            match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
            if match:
                # Patterns without a capture group fall back to the full match
                sql = match.group(1) if match.lastindex else match.group(0)
                break

        # Strip any leftover fencing and collapse all whitespace runs
        sql = sql.replace("```sql", "").replace("```", "")
        sql = " ".join(sql.split())
        sql = sql.strip()

        # Normalize: always terminate with a semicolon
        if not sql.endswith(";"):
            sql += ";"

        return sql


# Manual smoke test: run this module directly to exercise the loader end-to-end.
if __name__ == "__main__":
    loader = FineTunedModelLoader()

    demo_schema = """

    Table: orders

    Columns: order_id, customer_id, order_status, order_purchase_timestamp

    """

    # Ask one simple question against the demo schema and show the outcome.
    outcome = loader.generate_sql("How many orders are there?", demo_schema)

    print(f"\nSuccess: {outcome['success']}")
    print(f"SQL: {outcome['sql']}")
    if outcome['error']:
        print(f"Error: {outcome['error']}")