# NOTE: pasted-from-viewer artifact removed (was: "Spaces: / Running / Running / File size: 5,783 Bytes")
"""Dataset expansion script.

Expands each SQL template into many natural-language question phrasings
using a quantized Qwen model and appends prompt/SQL training records to a
JSONL file.
"""
import os
import sys

# Must be set BEFORE torch touches CUDA, otherwise the restriction to the
# two free GPUs (0 and 7) is silently ignored once CUDA has initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"

import hashlib
import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- PATCH FOR TRANSFORMERS VERSION MISMATCH ---
# Some checkpoints reference PytorchGELUTanh, which newer transformers
# releases dropped; alias the old name to the existing GELU activation.
try:
    import transformers.activations
    if not hasattr(transformers.activations, "PytorchGELUTanh"):
        # Mapping the old name to the new existing one
        transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
except ImportError:
    pass
# ------------------------------------------------------

# Make the project root importable so data_factory resolves regardless of
# the working directory the script is launched from.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT

# AWQ model is 4x smaller and much faster
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
INPUT_FILE = "llm_hybrid_templates.json"
OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
VARIATIONS_PER_SQL = 20
BATCH_SIZE = 64  # AWQ allows much larger batches!

SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks."

EXPANSION_PROMPT = """
You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query.
Generate exactly {count} completely different natural language questions that this exact SQL query answers.
RULES:
- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
- Structure: Completely change sentence flow.
- No direct column/table names.
DATABASE SCHEMA:
{schema}
SQL QUERY:
{sql}
OUTPUT FORMAT:
Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}]
"""
def extract_json_array(raw_text):
    """Return the bracketed JSON-array span of *raw_text*, or "[]" if none.

    The model sometimes wraps its JSON in prose; slicing from the first
    '[' to the last ']' strips that wrapping.

    Args:
        raw_text: Raw decoded model output.

    Returns:
        The substring from the first '[' through the last ']', or the
        literal "[]" when no well-formed span exists.
    """
    text = raw_text.strip()
    start = text.find("[")
    end = text.rfind("]")
    # The closing bracket must come AFTER the opening one; otherwise
    # (e.g. "] noise [") the slice would be empty or reversed garbage
    # that json.loads rejects instead of the safe "[]" fallback.
    if start != -1 and end > start:
        return text[start:end + 1]
    return "[]"
def get_hash(text):
    """MD5 hex digest of *text* after lowercasing and trimming whitespace.

    Serves as a cheap case-insensitive deduplication key for generated
    question/SQL pairs.
    """
    normalized = text.lower().strip()
    digest = hashlib.md5(normalized.encode("utf-8"))
    return digest.hexdigest()
def main():
    """Expand every SQL template into NL question variations and append
    prompt/SQL records to OUTPUT_FILE.

    Resumable: on restart the dedup keys of already-saved records are
    reloaded so no duplicate questions are written. Exits with status 1
    when INPUT_FILE is missing.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        sys.exit(1)

    with open(INPUT_FILE, "r") as f:
        base_templates = json.load(f)

    print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token  # required for left-padded batching

    # Model loading (AWQ version automatically handles quantization)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,  # AWQ models use float16/bfloat16 for weights
        low_cpu_mem_usage=True,
    )

    # Resume support: rebuild the dedup key set from the existing output.
    # Previously only the line count was restored, so a restarted run
    # re-saved every question as a duplicate.
    seen_hashes = set()
    total_saved = 0
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            for line in f:
                total_saved += 1
                try:
                    rec = json.loads(line)
                    user_msg = rec["prompt"][1]["content"]
                    # User message has the shape "SCHEMA: ...\nQUESTION: <q>".
                    question = user_msg.rpartition("\nQUESTION: ")[2]
                    seen_hashes.add(get_hash(question + rec["sql"]))
                except (json.JSONDecodeError, KeyError, IndexError, TypeError):
                    # Best-effort: a malformed line only loses its dedup key.
                    continue

    pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)

    # Batch processing: one expansion prompt per template, BATCH_SIZE at a time.
    for i in range(0, len(base_templates), BATCH_SIZE):
        batch = base_templates[i:i + BATCH_SIZE]
        prompts = []
        for temp in batch:
            msg = [
                {"role": "system", "content": "You output only JSON arrays."},
                {"role": "user", "content": EXPANSION_PROMPT.format(
                    count=VARIATIONS_PER_SQL,
                    schema=SCHEMA_CONTEXT[temp['domain']],
                    sql=temp['sql'])},
            ]
            prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))

        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        try:
            with torch.no_grad():
                # Increased speed: AWQ handles large batches efficiently
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens (strip the echoed prompt).
            responses = tokenizer.batch_decode(
                outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

            with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
                for idx, resp in enumerate(responses):
                    # One malformed model response must not abort the whole
                    # batch (previously json.loads raised into the outer
                    # except, discarding every response in the batch).
                    try:
                        questions_data = json.loads(extract_json_array(resp))
                    except json.JSONDecodeError:
                        continue
                    sql = batch[idx]["sql"]
                    domain = batch[idx]["domain"]
                    for item in questions_data:
                        if not isinstance(item, dict):
                            continue  # model occasionally emits bare strings
                        q = item.get("question", "")
                        if isinstance(q, str) and len(q) > 10:  # drop degenerate questions
                            q_hash = get_hash(q + sql)
                            if q_hash not in seen_hashes:
                                seen_hashes.add(q_hash)
                                record = {
                                    "prompt": [
                                        {"role": "system", "content": SYSTEM_PROMPT},
                                        {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"},
                                    ],
                                    "sql": sql,
                                }
                                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                                total_saved += 1
                                pbar.update(1)
                    out_file.flush()
        except Exception as e:
            # Best-effort loop: log and continue (e.g. a CUDA OOM on one batch
            # should not kill a multi-hour generation run).
            print(f"Batch failed: {e}")
            continue

    pbar.close()
# Script entry point (the trailing viewer-gutter artifact "|" is removed).
if __name__ == "__main__":
    main()