Spaces:
Sleeping
Sleeping
| # ====================================== | |
| # RLHF Text2SQL — FINAL WORKING VERSION | |
| # T5-small + LoRA + PPO + Execution Reward | |
| # Single-sample stable training (Mac MPS safe) | |
| # ====================================== | |
import gc
import json
import os
import random

# ---------------- SETTINGS ----------------
# These flags are exported BEFORE torch (or anything that imports torch) is
# loaded: PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when the library is first
# imported, so setting it afterwards leaves the CPU fallback disabled.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch  # noqa: E402
from peft import LoraConfig, get_peft_model  # noqa: E402
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # noqa: E402
from trl import PPOConfig, PPOTrainer  # noqa: E402
from trl.models.modeling_value_head import AutoModelForSeq2SeqLMWithValueHead  # noqa: E402

# Project-local reward function; imported last because it may import torch.
from execution_reward import execution_reward  # noqa: E402

# Prefer Apple's Metal backend when it is available, otherwise run on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# Output directory for the trained weights and tokenizer files.
os.makedirs("rlhf_text2sql_lora", exist_ok=True)
# ---------------- MODEL ----------------
# Policy: the seq2seq checkpoint with LoRA adapters on the modules named
# "q" and "v", wrapped with a value head.
# Reference: the same checkpoint loaded directly by name (no adapters).
model_name = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): T5-family tokenizers normally already define a pad token --
# confirm this override is actually required.
tokenizer.pad_token = tokenizer.eos_token

# LoRA
lora_config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q", "v"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)
base_model = get_peft_model(
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
    lora_config,
)

model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model).to(device)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_name).to(device)

# Turn the generation cache off on both the policy and the reference model.
for wrapped in (model, ref_model):
    wrapped.config.use_cache = False
# ---------------- DATA ----------------
# Load the Spider training examples. The training loop reads the "question",
# "query" and "db_id" keys of each entry.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's default locale.
with open("data/train_spider.json", encoding="utf-8") as f:
    dataset = json.load(f)
def build_prompt(example):
    """Return the model prompt for one example.

    Args:
        example: mapping that provides a "question" entry.

    Returns:
        The question prefixed with the fixed translation instruction.
    """
    question = example["question"]
    return "Translate to SQL: {}".format(question)
# ---------------- PPO ----------------
# Single-sample PPO: every optimisation step sees exactly one
# (query, response, reward) triple.
ppo_config = PPOConfig(
    batch_size=1,
    mini_batch_size=1,
    learning_rate=2e-6,
    target_kl=0.05,
    adap_kl_ctrl=True,
    init_kl_coef=0.2,
    # Gradient clipping has to be configured here: the trainer applies the
    # optimizer update inside step(), so clipping done by the caller after
    # step() returns can no longer influence that update.
    max_grad_norm=1.0,
)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# ---------------- GENERATION ----------------
def generate_sql(query_tensors):
    """Greedily decode a response for each query with the PPO policy.

    Args:
        query_tensors: list of token-id tensors, one per prompt (the training
            loop passes a single-element list).

    Returns:
        A list with one response token-id tensor per element returned by
        ``ppo_trainer.generate``.
    """
    # Deterministic decoding is a deliberate choice of this script (sampling
    # is disabled to keep training numerically stable).
    # NOTE(review): PPO is usually run with sampled responses -- confirm that
    # greedy decoding gives the intended KL behaviour.
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors,
            max_new_tokens=64,
            do_sample=False,
            # Plain greedy search. `early_stopping` is a beam-search-only
            # option, so it is not passed when num_beams is 1.
            num_beams=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Token ids are integers and cannot hold NaN/inf, so no numeric
    # sanitising is needed; just hand back a plain list.
    return list(response_tensors)
# ---------------- TRAIN ----------------
# One PPO update per step on a single randomly drawn example; the reward
# is whatever execution_reward returns for the predicted SQL.
MAX_STEPS = 1200
for step in range(MAX_STEPS):
    # pick random Spider example
    example = random.choice(dataset)
    question = example["question"]
    gold_sql = example["query"]
    db_id = example["db_id"]
    db_path = f"data/database/{db_id}/{db_id}.sqlite"

    # tokenize; the trainer expects a list of 1-D tensors, so take row 0
    enc = tokenizer(build_prompt(example), return_tensors="pt")
    query_tensor = enc.input_ids.to(device)
    query_tensors = [query_tensor[0]]

    # generate SQL
    response_tensors = generate_sql(query_tensors)
    pred_sql = tokenizer.decode(response_tensors[0], skip_special_tokens=True)

    # -------- EXECUTION REWARD --------
    reward = execution_reward(pred_sql, gold_sql, db_path)
    reward_tensor = torch.tensor([reward], dtype=torch.float32).to(device)

    # PPO update
    stats = ppo_trainer.step(query_tensors, response_tensors, [reward_tensor])
    # No gradient clipping here: step() has already applied the optimizer
    # update, so clipping afterwards would have no effect. Clipping belongs
    # in the trainer configuration (PPOConfig's max_grad_norm).

    # cleanup: drop per-step tensors and release cached device memory
    del query_tensor, response_tensors, reward_tensor
    gc.collect()
    if device == "mps":
        torch.mps.empty_cache()

    # log
    if step % 20 == 0:
        print(f"\nStep {step}/{MAX_STEPS}")
        print("DB:", db_id)
        print("Q:", question)
        print("Pred:", pred_sql)
        print("Gold:", gold_sql)
        print("Reward:", reward)
# ---------------- SAVE ----------------
# Write the trained model and the tokenizer to the same output directory.
output_dir = "rlhf_text2sql_lora"
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)
print("\nTraining complete — model saved!")