# Page metadata from the file viewer: Spaces status "Sleeping"; file size 4,262 bytes; revision dc59b01.
# ======================================
# RLHF Text2SQL — FINAL WORKING VERSION
# T5-small + LoRA + PPO + Execution Reward
# Single-sample stable training (Mac MPS safe)
# ======================================
from execution_reward import execution_reward
import os, gc, json, random, torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from trl import PPOTrainer, PPOConfig
from trl.models.modeling_value_head import AutoModelForSeq2SeqLMWithValueHead
from peft import LoraConfig, get_peft_model
# ---------------- SETTINGS ----------------
# Runtime switches: allow CPU fallback for ops missing on MPS and silence the
# tokenizers fork warning.
# NOTE(review): PYTORCH_ENABLE_MPS_FALLBACK is set here after `torch` has
# already been imported; it may need to be set before the import to take
# effect -- confirm on the target machine.
os.environ.update(
    {
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Prefer Apple's Metal backend when present, otherwise run on the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)

# Output directory for the trained adapter and tokenizer files.
os.makedirs("rlhf_text2sql_lora", exist_ok=True)
# ---------------- MODEL ----------------
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# T5 tokenizers define their own <pad> token, which the model also uses as the
# decoder start token. Overwriting it with EOS is a decoder-only (GPT-style)
# workaround, so only fall back to EOS when no pad token exists at all.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# LoRA: train low-rank adapters on the attention query/value projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
base_model = get_peft_model(base_model, lora_config)

# Policy = LoRA-wrapped model plus a value head; the reference model is loaded
# straight from the base checkpoint (no adapters) for the KL penalty.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model).to(device)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name).to(device)

# Disable the KV cache on both models.
model.config.use_cache = False
ref_model.config.use_cache = False
# ---------------- DATA ----------------
# Load the training examples once. The encoding is pinned to UTF-8 so that
# parsing does not depend on the platform's default locale.
with open("data/train_spider.json", encoding="utf-8") as f:
    dataset = json.load(f)
def build_prompt(example: dict) -> str:
    """Return the instruction prompt fed to the model for one example.

    Args:
        example: Mapping with a ``"question"`` key holding the natural
            language question.

    Returns:
        The question prefixed with ``"Translate to SQL: "``.

    Raises:
        KeyError: If ``example`` has no ``"question"`` key.
    """
    return f"Translate to SQL: {example['question']}"
# ---------------- PPO ----------------
# Single-sample PPO: one query/response pair per optimisation step.
ppo_config = PPOConfig(
    batch_size=1,
    mini_batch_size=1,
    learning_rate=2e-6,
    # NOTE(review): in TRL's PPOConfig the adaptive-KL target appears to be
    # `target`, while `target_kl` is an early-stopping threshold -- confirm
    # which of the two was intended here.
    target_kl=0.05,
    adap_kl_ctrl=True,
    init_kl_coef=0.2,
    # Clip gradients inside the trainer's optimisation step. Clipping from the
    # outside after `ppo_trainer.step()` returns is too late to have an effect.
    max_grad_norm=1.0,
)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# ---------------- GENERATION ----------------
def generate_sql(query_tensors):
    """Greedy-decode one response per query with the PPO policy model.

    Args:
        query_tensors: Tokenised prompts, passed straight through to
            ``ppo_trainer.generate`` (the training loop passes a list of
            1-D id tensors).

    Returns:
        A list with one tensor of generated token ids per query.
    """
    # Generation is only used to collect rollouts, so no gradients are needed.
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors,
            max_new_tokens=64,
            # Greedy decoding: no sampling, single beam. `early_stopping` is a
            # beam-search option and has no effect here, so it is not passed.
            do_sample=False,
            num_beams=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # The outputs are integer token ids, which cannot hold NaN/inf, so no
    # nan_to_num pass is needed; return them as a plain list.
    return list(response_tensors)
# ---------------- TRAIN ----------------
MAX_STEPS = 1200

for step in range(MAX_STEPS):
    # Pick a random training example and locate its SQLite database.
    example = random.choice(dataset)
    question = example["question"]
    gold_sql = example["query"]
    db_id = example["db_id"]
    db_path = f"data/database/{db_id}/{db_id}.sqlite"

    # Tokenise the prompt; the trainer takes a list of 1-D query tensors.
    enc = tokenizer(build_prompt(example), return_tensors="pt")
    query_tensors = [enc.input_ids[0].to(device)]

    # Roll out the current policy and decode its SQL prediction.
    response_tensors = generate_sql(query_tensors)
    pred_sql = tokenizer.decode(response_tensors[0], skip_special_tokens=True)

    # -------- EXECUTION REWARD --------
    reward = execution_reward(pred_sql, gold_sql, db_path)
    reward_tensor = torch.tensor([reward], dtype=torch.float32).to(device)

    # PPO update. The trainer runs backward and the optimiser step internally,
    # so gradient clipping has to be configured via PPOConfig(max_grad_norm=...).
    # Calling clip_grad_norm_ after step() returns has no effect, so it is not
    # done here. The returned stats are not used and therefore not kept alive.
    ppo_trainer.step(query_tensors, response_tensors, [reward_tensor])

    # Release per-step tensors eagerly to keep memory flat on small devices.
    del enc, query_tensors, response_tensors, reward_tensor
    gc.collect()
    if device == "mps":
        torch.mps.empty_cache()

    # Periodic progress log.
    if step % 20 == 0:
        print(f"\nStep {step}/{MAX_STEPS}")
        print("DB:", db_id)
        print("Q:", question)
        print("Pred:", pred_sql)
        print("Gold:", gold_sql)
        print("Reward:", reward)
# ---------------- SAVE ----------------
# Write the trained model and its tokenizer side by side.
output_dir = "rlhf_text2sql_lora"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print("\nTraining complete — model saved!")