import os
import sys
import json
import torch
import hashlib
from pathlib import Path
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- PATCH FOR TRANSFORMERS VERSION MISMATCH ---
try:
    import transformers.activations
    if not hasattr(transformers.activations, "PytorchGELUTanh"):
        # Map the missing class name onto an existing GELU implementation (numerically close)
        transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
except ImportError:
    pass
# ------------------------------------------------------

# Restrict this process to the two free GPUs (indices 0 and 7); this must run before any CUDA context is created
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"

PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT

# 4-bit AWQ weights are roughly 4x smaller than FP16, so the model fits on fewer GPUs and generates faster
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
INPUT_FILE = "llm_hybrid_templates.json"
OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
VARIATIONS_PER_SQL = 20
BATCH_SIZE = 64  # AWQ allows much larger batches!

SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks."

EXPANSION_PROMPT = """
You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query.
Generate exactly {count} completely different natural language questions that this exact SQL query answers.

RULES:
- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
- Structure: Completely change sentence flow.
- No direct column/table names.

DATABASE SCHEMA:
{schema}

SQL QUERY:
{sql}

OUTPUT FORMAT:
Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}]
"""

def extract_json_array(raw_text):
    """Return the first '[' ... last ']' slice of raw model output, or '[]' if no array is found."""
    text = raw_text.strip()
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        return text[start:end + 1]
    return "[]"

def get_hash(text):
    """Case-insensitive MD5 fingerprint used to de-duplicate (question + SQL) pairs."""
    return hashlib.md5(text.lower().strip().encode('utf-8')).hexdigest()
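
# Optional sketch (not wired into main): rebuild the de-duplication set from an existing
# OUTPUT_FILE when resuming. Assumes each line was written by this script, i.e. the user
# message has the form "SCHEMA: ...\nQUESTION: <q>" and the record has a top-level "sql" key.
def load_seen_hashes(path):
    hashes = set()
    if not os.path.exists(path):
        return hashes
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                record = json.loads(line)
                question = record["prompt"][1]["content"].split("QUESTION: ", 1)[1]
                hashes.add(get_hash(question + record["sql"]))
            except (json.JSONDecodeError, KeyError, IndexError):
                continue  # skip malformed lines rather than failing the resume
    return hashes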

def main():
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        sys.exit(1)
        
    with open(INPUT_FILE, "r") as f:
        base_templates = json.load(f)
        
    print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    # Model loading (AWQ version automatically handles quantization)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16, # AWQ models use float16/bfloat16 for weights
        low_cpu_mem_usage=True
    )
    
    seen_hashes = set()  # de-duplicates only within this run; prior rows are not re-hashed on resume
    total_saved = 0
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            for line in f:
                total_saved += 1  # quick count of records already generated
    
    pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)
    
    # Batch processing
    for i in range(0, len(base_templates), BATCH_SIZE):
        batch = base_templates[i:i + BATCH_SIZE]
        prompts = []
        
        for temp in batch:
            msg = [
                {"role": "system", "content": "You output only JSON arrays."},
                {"role": "user", "content": EXPANSION_PROMPT.format(count=VARIATIONS_PER_SQL, schema=SCHEMA_CONTEXT[temp['domain']], sql=temp['sql'])}
            ]
            prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))

        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        
        try:
            with torch.no_grad():
                # Increased speed: AWQ handles large batches efficiently
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            
            responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
            
            with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
                for idx, resp in enumerate(responses):
                    try:
                        questions_data = json.loads(extract_json_array(resp))
                    except json.JSONDecodeError:
                        # Skip one malformed response instead of aborting the whole batch
                        continue
                    sql = batch[idx]["sql"]
                    domain = batch[idx]["domain"]
                    
                    for item in questions_data:
                        q = item.get("question", "")
                        if len(q) > 10:
                            q_hash = get_hash(q + sql)
                            if q_hash not in seen_hashes:
                                seen_hashes.add(q_hash)
                                record = {
                                    "prompt": [
                                        {"role": "system", "content": SYSTEM_PROMPT},
                                        {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"}
                                    ],
                                    "sql": sql
                                }
                                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                                total_saved += 1
                                pbar.update(1)
                out_file.flush()
        except Exception as e:
            print(f"Batch failed: {e}")
            continue

    pbar.close()

if __name__ == "__main__":
    main()