# ======================================
# RLHF Text2SQL — FINAL WORKING VERSION
# T5-small + LoRA + PPO + Execution Reward
# Single-sample stable training (Mac MPS safe)
# ======================================
import gc, json, os, random

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# classic PPOTrainer API (trl < 0.12); the value-head wrapper is exported at top level
from trl import AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer

from execution_reward import execution_reward
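# The reward function lives in src/execution_reward.py. A minimal sketch of the
# contract assumed here (run both queries against the SQLite DB and compare
# result sets; the real implementation may differ):
#
#   import sqlite3
#   def execution_reward(pred_sql, gold_sql, db_path):
#       try:
#           with sqlite3.connect(db_path) as conn:
#               pred = conn.execute(pred_sql).fetchall()
#               gold = conn.execute(gold_sql).fetchall()
#           return 1.0 if set(pred) == set(gold) else 0.0
#       except Exception:
#           return -1.0  # unexecutable SQL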
# ---------------- SETTINGS ----------------
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
os.makedirs("rlhf_text2sql_lora", exist_ok=True)
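# Reproducibility (optional): the training loop samples examples with
# random.choice, so seed both RNGs for repeatable runs.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)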
# ---------------- MODEL ----------------
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# T5 tokenizers ship a dedicated <pad> token; only fall back to EOS if a
# checkpoint somehow lacks one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],  # T5 attention query/value projections
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
base_model = get_peft_model(base_model, lora_config)
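# Show how many parameters LoRA actually trains (PEFT helper).
base_model.print_trainable_parameters()

# Wrap the LoRA model with a value head for PPO; the plain frozen copy below
# (no LoRA) anchors the KL penalty.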
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model).to(device)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name).to(device)
model.config.use_cache = False
ref_model.config.use_cache = False
# ---------------- DATA ----------------
with open("data/train_spider.json") as f:
    dataset = json.load(f)
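# Fail fast if the Spider JSON is missing the fields used below.
assert {"question", "query", "db_id"} <= dataset[0].keys(), "unexpected Spider schema"
print(f"Loaded {len(dataset)} training examples")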
def build_prompt(example):
    return f"Translate to SQL: {example['question']}"
# ---------------- PPO ----------------
ppo_config = PPOConfig(
    batch_size=1,
    mini_batch_size=1,
    learning_rate=2e-6,
    target_kl=0.05,      # early-stop a PPO step if the KL estimate runs away
    adap_kl_ctrl=True,   # adaptive KL coefficient
    init_kl_coef=0.2,
    max_grad_norm=1.0,   # clip gradients inside the PPO optimizer step
)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# ---------------- GENERATION ----------------
def generate_sql(query_tensors):
    # Greedy decoding: sampling from unstable logits is what tends to blow up
    # on MPS, so keep generation deterministic.
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors,
            max_new_tokens=64,
            do_sample=False,  # critical: disable sampling
            num_beams=1,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Responses are integer token ids, so no NaN scrubbing is needed here.
    return response_tensors
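# Optional smoke test (an assumption, not part of training): generate once on
# the first example so setup errors surface before the long loop starts.
_smoke_ids = tokenizer(build_prompt(dataset[0]), return_tensors="pt").input_ids.to(device)
_smoke_out = generate_sql([_smoke_ids[0]])[0]
print("Smoke-test SQL:", tokenizer.decode(_smoke_out, skip_special_tokens=True))
del _smoke_ids, _smoke_out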
# ---------------- TRAIN ----------------
MAX_STEPS = 1200
for step in range(MAX_STEPS):
    # pick a random Spider example
    example = random.choice(dataset)
    question = example["question"]
    gold_sql = example["query"]
    db_id = example["db_id"]
    db_path = f"data/database/{db_id}/{db_id}.sqlite"

    # tokenize
    enc = tokenizer(build_prompt(example), return_tensors="pt")
    query_tensor = enc.input_ids.to(device)
    query_tensors = [query_tensor[0]]

    # generate SQL
    response_tensors = generate_sql(query_tensors)
    pred_sql = tokenizer.decode(response_tensors[0], skip_special_tokens=True)
    # -------- EXECUTION REWARD --------
    reward = execution_reward(pred_sql, gold_sql, db_path)
    # PPOTrainer expects one scalar reward tensor per sample
    reward_tensor = torch.tensor(float(reward), dtype=torch.float32, device=device)

    # PPO update
    stats = ppo_trainer.step(query_tensors, response_tensors, [reward_tensor])
    # Gradient clipping is handled inside ppo_trainer.step() via max_grad_norm
    # in PPOConfig; clipping here, after the optimizer step, would be a no-op.

    # cleanup (MPS memory is tight)
    del query_tensor, response_tensors, reward_tensor
    gc.collect()
    if device == "mps":
        torch.mps.empty_cache()
    # log
    if step % 20 == 0:
        print(f"\nStep {step}/{MAX_STEPS}")
        print("DB:", db_id)
        print("Q:", question)
        print("Pred:", pred_sql)
        print("Gold:", gold_sql)
        print("Reward:", reward)
        print("KL:", stats["objective/kl"])
# ---------------- SAVE ----------------
model.save_pretrained("rlhf_text2sql_lora")
tokenizer.save_pretrained("rlhf_text2sql_lora")
print("\nTraining complete — model saved!")