# NOTE: the three lines that used to be here ("Spaces:", "Running", "Running")
# were Hugging Face Spaces page artifacts from a web export, not Python code.
# =========================================================
# RLHF TRAINING FOR TEXT2SQL (STABLE PPO VERSION)
# =========================================================
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from peft import PeftModel
import os, sys, sqlite3, re, random

# Make the sibling execution_reward module importable when run as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from execution_reward import execution_reward, extract_tables, extract_columns

try:
    import sqlparse  # gate PPO updates on parsable SQL only
except Exception:  # pragma: no cover
    sqlparse = None

# ======================================================
# DEVICE
# ======================================================
# Silence tokenizer fork warnings when dataloader workers are used.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Preference order: Apple MPS, then CUDA, then CPU.
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ======================================================
# TRAINING SETTINGS
# ======================================================
NUM_EPOCHS = 5              # passes of the rollout loop below
LOG_EVERY = 20              # print a sample + running reward every N steps
USE_SCHEMA = True           # include DB schema text in the prompt
SCHEMA_WARMUP_EPOCHS = 0    # initial epochs that omit the schema from prompts
MAX_SCHEMA_CHARS = 1500     # character cap applied to raw schema text
MAX_OUTPUT_TOKENS = 80      # hard cap on the decoded response length
ROLLOUTS_PER_EPOCH = 2048   # rollouts sampled (with replacement) per epoch

# ======================================================
# PATHS
# ======================================================
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 🎯 FIXED: Save ONLY the best model to this exact path
RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "rlhf_t5_best")
output_dir = RL_MODEL_PATH
DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")
# 🎯 Updated to point to our newly trained t5-small SFT model
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints/sft_t5")
# Older adapter locations, tried in order if ADAPTER_PATH is missing.
FALLBACK_ADAPTER_PATH = os.path.join(PROJECT_ROOT, "models/t5_spider_sft_lora")
FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "outputs/sft_text2sql")
# 🎯 ENSURING t5-small is used
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
# ======================================================
# LOAD MODEL (LoRA)
# ======================================================
print("Loading base:", BASE_MODEL)
# Fall back to older adapter checkpoints if the primary SFT path is absent.
if not os.path.isdir(ADAPTER_PATH):
    if os.path.isdir(FALLBACK_ADAPTER_PATH):
        ADAPTER_PATH = FALLBACK_ADAPTER_PATH
    elif os.path.isdir(FALLBACK_ADAPTER_PATH_2):
        ADAPTER_PATH = FALLBACK_ADAPTER_PATH_2
print("Loading adapters:", ADAPTER_PATH)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Policy: value-headed seq2seq base model with the SFT LoRA adapters attached.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
model.pretrained_model = PeftModel.from_pretrained(model.pretrained_model, ADAPTER_PATH)
# Frozen reference copy of the same policy, used by PPO for the KL penalty.
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
ref_model.pretrained_model = PeftModel.from_pretrained(ref_model.pretrained_model, ADAPTER_PATH)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)
# Freeze base transformer weights; train LoRA adapters + value head.
for name, p in model.named_parameters():
    # Train value head
    if name.startswith("v_head"):
        p.requires_grad = True
    # Train LoRA adapters (policy learning!)
    elif "lora_" in name:
        p.requires_grad = True
    # Freeze base model
    else:
        p.requires_grad = False
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")
# KV-cache must be off for training-time forward passes.
model.config.use_cache = False
ref_model.config.use_cache = False
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# ======================================================
# DATASET
# ======================================================
print("Loading Spider subset...")
random.seed(0)
# Train on a small, stable curriculum of DBs first.
TRAIN_DBS = [
    "flight_1",
    "student_assessment",
    "store_1",
    "bike_1",
    "book_2",
    "chinook_1",
]
dataset = load_dataset("spider", split="train")
_TRAIN_DBS_SET = set(TRAIN_DBS)
dataset = dataset.filter(lambda x: x["db_id"] in _TRAIN_DBS_SET)
dataset = dataset.select(range(min(800, len(dataset))))
print("Using RLHF DBs:", TRAIN_DBS)
print("Filtered size:", len(dataset))
# One "epoch" below is this many sampled rollouts, not one dataset pass.
total_steps = ROLLOUTS_PER_EPOCH
def get_db_path(db_id):
    """Build the on-disk path to the SQLite file for a Spider database id."""
    filename = "%s.sqlite" % db_id
    return os.path.join(DB_ROOT, db_id, filename)
def get_db_schema(db_path):
    """Return a compact one-line schema string for a SQLite database.

    Format: ``"table1(colA, colB) table2(colC) "`` — one ``name(cols)``
    group per table, each followed by a space.

    Best-effort: a missing/unreadable database yields whatever was read so
    far (possibly ""), never an exception.
    """
    schema_text = ""
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()
            for (table_name,) in tables:
                # Quote the identifier so unusual table names (spaces,
                # keywords) cannot break the PRAGMA statement.
                columns = cursor.execute(
                    f'PRAGMA table_info("{table_name}");'
                ).fetchall()
                col_names = [col[1] for col in columns]
                schema_text += f"{table_name}({', '.join(col_names)}) "
        finally:
            # Close the connection even if a query above fails.
            conn.close()
    except sqlite3.Error:
        # Was a bare `except: pass`; narrowed to SQLite errors so real
        # programming errors (e.g. a non-string path) are not hidden.
        pass
    return schema_text
# ======================================================
# PROMPT
# ======================================================
# Task prefix prepended to every encoder prompt (T5-style instruction).
PREFIX = "translate English to SQL:"
def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Return *schema* as a string of at most *max_chars* characters.

    ``None`` maps to the empty string; longer inputs are truncated.
    """
    if schema is None:
        return ""
    text = str(schema)
    return text if len(text) <= max_chars else text[:max_chars]
def build_prompt(question: str, schema: str, use_schema: bool) -> str:
    """Assemble the text prompt fed to the encoder.

    When *use_schema* is true, the (character-trimmed) schema is inserted
    between the task prefix and the question.
    """
    if use_schema:
        trimmed = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
        return f"{PREFIX}\n\nSchema:\n{trimmed}\n\nQuestion:\n{question}\n\nSQL:"
    return f"{PREFIX}\n\nQuestion:\n{question}\n\nSQL:"
def encode_prompt(question, schema, use_schema):
    """Tokenize the prompt into a 1-D LongTensor on the training device.

    Token budget policy: the prefix, question and trailing "SQL:" marker
    must always fit; only the schema segment is truncated (token-wise) to
    respect the tokenizer's model_max_length, and an EOS id is appended
    when the tokenizer defines one.
    """
    # Never truncate the question; only truncate schema tokens if needed.
    if not use_schema:
        prompt = build_prompt(question, schema, use_schema=False)
        return tokenizer(prompt, return_tensors="pt", truncation=True).input_ids[0].to(device)
    schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
    # Encode the four prompt segments separately so each can be budgeted.
    prefix_schema = f"{PREFIX}\n\nSchema:\n"
    mid = "\n\nQuestion:\n"
    suffix = f"{question}\n\nSQL:"
    prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
    schema_ids = tokenizer.encode(schema, add_special_tokens=False)
    mid_ids = tokenizer.encode(mid, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)
    # NOTE(review): assumes tokenizer.model_max_length is a sane bound
    # (512 for t5-small); some tokenizers report a huge sentinel value.
    max_len = getattr(tokenizer, "model_max_length", 512)
    eos_id = tokenizer.eos_token_id
    # Reserve one slot for the EOS token appended at the end.
    max_without_eos = max_len - (1 if eos_id is not None else 0)
    # Ensure the question+SQL suffix always fits; truncate schema first.
    fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
    if fixed_len > max_without_eos:
        # Extremely rare; clip the suffix (question) only if unavoidable.
        keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
        suffix_ids = suffix_ids[:keep]
        fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
    # Whatever budget remains after the fixed segments goes to the schema.
    remaining_for_schema = max_without_eos - fixed_len
    if remaining_for_schema < 0:
        remaining_for_schema = 0
    schema_ids = schema_ids[:remaining_for_schema]
    ids = prefix_ids + schema_ids + mid_ids + suffix_ids
    ids = ids[:max_without_eos]
    if eos_id is not None:
        ids = ids + [eos_id]
    return torch.tensor(ids, dtype=torch.long).to(device)
# ======================================================
# SQL CONSTRAINED DECODING
# ======================================================
# SQL keywords whose token pieces are always permitted during generation.
SQL_KEYWORDS = [
    "select", "from", "where", "join", "inner", "left", "right",
    "full", "outer", "on", "group", "by", "order", "having",
    "limit", "distinct", "as", "and", "or", "not", "in", "is",
    "null", "like", "between", "asc", "desc", "union",
    "intersect", "except",
]
# Operator/punctuation characters permitted inside generated SQL.
SQL_OPERATORS = ["*", ",", ".", "(", ")", "=", "<", ">", "!", "+", "-", "/", "%", "_"]
| def _piece_token_str(tok: str) -> str: | |
| # T5 SentencePiece: "▁" marks a leading space; strip it for char checks. | |
| return tok.lstrip("▁") | |
def _precompute_always_allowed_token_ids():
    """Collect token ids that are always legal in generated SQL.

    Includes special tokens, whitespace pieces, purely numeric pieces,
    pieces made only of SQL operator characters, and the pieces of every
    SQL keyword (lower/UPPER/Capitalized, with and without leading space).
    """
    allowed = set()
    # Special tokens are always permitted.
    for tid in (tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id):
        if tid is not None and tid >= 0:
            allowed.add(int(tid))
    # Whitespace/newlines, in case the vocabulary has pieces for them.
    for ws in (" ", "\n", "\t"):
        allowed.update(tokenizer.encode(ws, add_special_tokens=False))
    # Operator-only and digit-only pieces are allowed broadly.
    op_chars = set("".join(SQL_OPERATORS))
    for tid in range(len(tokenizer)):
        tok = tokenizer.convert_ids_to_tokens(tid)
        if not isinstance(tok, str) or not tok:
            continue
        piece = _piece_token_str(tok)
        if not piece:
            continue
        if piece.isdigit() or all(ch in op_chars for ch in piece):
            allowed.add(tid)
    # Keyword pieces in several casings.
    for kw in SQL_KEYWORDS:
        for variant in {kw, kw.upper(), kw.capitalize()}:
            for text in (" " + variant, variant):
                allowed.update(tokenizer.encode(text, add_special_tokens=False))
    return allowed


# Computed once at import time; per-DB sets extend this base vocabulary.
ALWAYS_ALLOWED_TOKEN_IDS = _precompute_always_allowed_token_ids()
def _schema_allowed_token_ids(table_names, column_names):
    """Extend the always-allowed vocabulary with this DB's identifiers.

    Each table/column name contributes its whole form plus its
    underscore/whitespace-separated parts, in original, lower and upper
    case, each encoded with and without a leading space.
    """
    allowed = set(ALWAYS_ALLOWED_TOKEN_IDS)

    def _add_identifier(name: str):
        if not name:
            return
        # Add whole identifier and common splits.
        variants = {name, name.lower(), name.upper()}
        for part in re.split(r"[_\s]+", name):
            if part:
                variants.add(part)
        for v in variants:
            for text in (" " + v, v):
                allowed.update(tokenizer.encode(text, add_special_tokens=False))

    for identifier in list(table_names) + list(column_names):
        _add_identifier(identifier)
    return allowed
class SQLVocabularyLogitsProcessor(LogitsProcessor):
    """Mask generation logits so only an allowed SQL vocabulary survives.

    Disallowed ids receive an additive -inf bias; allowed ids are left
    untouched. The bias vector is cached and rebuilt only when the scores'
    device, dtype or vocabulary size changes.
    """

    def __init__(self, allowed_token_ids):
        # Keep only valid (non-negative) ids.
        self.allowed_token_ids = {int(i) for i in allowed_token_ids if int(i) >= 0}
        self._bias = None
        self._bias_vocab_size = None

    def _get_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return (building if stale) the cached additive bias vector."""
        vocab_size = int(scores.shape[-1])
        stale = (
            self._bias is None
            or self._bias.device != scores.device
            or self._bias.dtype != scores.dtype
            or self._bias_vocab_size != vocab_size
        )
        if stale:
            bias = torch.full((vocab_size,), float("-inf"), device=scores.device, dtype=scores.dtype)
            for tid in self.allowed_token_ids:
                if tid < vocab_size:
                    bias[tid] = 0.0
            self._bias = bias
            self._bias_vocab_size = vocab_size
        return self._bias

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        return scores + self._get_bias(scores)
# Per-path memo of (table names, column names) so each DB is scanned once.
_DB_VOCAB_CACHE = {}


def get_db_tables_columns(db_path: str):
    """Return ([table names], [column names]) for a SQLite DB, memoized.

    Internal sqlite_* tables are excluded. Unreadable databases yield
    empty lists; the (possibly empty) result is cached per path.
    """
    cached = _DB_VOCAB_CACHE.get(db_path)
    if cached is not None:
        return cached
    table_names, column_names = [], []
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        rows = cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall()
        for (tname,) in rows:
            if not tname:
                continue
            table_names.append(tname)
            try:
                for col in cur.execute(f'PRAGMA table_info("{tname}")').fetchall():
                    if col and isinstance(col[1], str):
                        column_names.append(col[1])
            except Exception:
                continue
        conn.close()
    except Exception:
        # Best-effort: missing DBs simply contribute no vocabulary.
        pass
    _DB_VOCAB_CACHE[db_path] = (table_names, column_names)
    return table_names, column_names
# ======================================================
# PPO CONFIG (stable learning)
# ======================================================
ppo_config = PPOConfig(
    learning_rate=2e-5,  # was 1e-6 → model could not move
    batch_size=8,  # better gradient estimate
    mini_batch_size=2,
    gradient_accumulation_steps=2,  # stable updates on small data
    ppo_epochs=1,
    # --- KL control (MOST IMPORTANT FIX) ---
    init_kl_coef=0.05,  # reduce punishment
    target_kl=0.15,  # relax constraint to avoid skipped updates
    adap_kl_ctrl=True,
    # --- stability ---
    cliprange=0.25,
    cliprange_value=0.25,
    whiten_rewards=True,
    kl_penalty="kl",
    max_grad_norm=1.0,
)
trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# The trainer's optimizer is also stepped manually after the supervised
# anchor step in the training loop.
optimizer = trainer.optimizer
# Provide `.device` attribute for the supervised anchor helper.
try:
    model.device = torch.device(device)
except Exception:
    pass
# ======================================================
# GENERATION (schema-constrained decoding)
# ======================================================
generation_kwargs = dict(
    max_new_tokens=64,  # 128 causes garbage SQL loops
    do_sample=True,
    temperature=0.9,  # encourage exploration
    top_p=0.95,
    top_k=100,
    repetition_penalty=1.1,  # prevents SELECT SELECT SELECT
    no_repeat_ngram_size=3,
    num_beams=1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# ======================================================
# TRAIN LOOP
# ======================================================
print("Starting RL training 🚀")
# Rollout buffers, accumulated until a full PPO batch is available.
query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
query_text_buffer = []
# Best epoch-average reward seen so far (best-only checkpointing).
best_reward = -999999
best_epoch = -1
| def _is_parsable_sql(sql: str) -> bool: | |
| s = (sql or "").strip() | |
| if not s: | |
| return False | |
| up = s.upper() | |
| if "SELECT" not in up or "FROM" not in up: | |
| return False | |
| if sqlparse is None: | |
| return True | |
| try: | |
| return bool(sqlparse.parse(s)) | |
| except Exception: | |
| return False | |
def _pad_2d(seqs, pad_id: int):
    """Right-pad 1-D LongTensors to a common length.

    Returns (ids, mask), both of shape [len(seqs), max_len] on the
    training device; mask is 1 over real tokens and 0 over padding.
    """
    width = max(int(s.numel()) for s in seqs)
    ids = torch.full((len(seqs), width), int(pad_id), dtype=torch.long, device=device)
    mask = torch.zeros((len(seqs), width), dtype=torch.long, device=device)
    for row, seq in enumerate(seqs):
        length = int(seq.numel())
        ids[row, :length] = seq.to(device)
        mask[row, :length] = 1
    return ids, mask
| def _shift_right(labels: torch.Tensor, start_id: int) -> torch.Tensor: | |
| dec = labels.clone() | |
| dec[:, 1:] = labels[:, :-1] | |
| dec[:, 0] = int(start_id) | |
| return dec | |
def safe_get_kl(stats):
    """Extract the first KL-like scalar from a PPO stats dict, else None.

    Matches any key containing "kl" (case-insensitive); values with an
    ``.item()`` method are unwrapped first. Conversion failures yield None.
    """
    if not isinstance(stats, dict):
        return None
    for key, value in stats.items():
        if "kl" not in str(key).lower():
            continue
        try:
            return float(value.item() if hasattr(value, "item") else value)
        except Exception:
            return None
    return None
def supervised_anchor_step(model, tokenizer, queries, gold_sqls, weight=0.05):
    """Accumulate a small teacher-forcing gradient on gold SQL targets.

    Gradients (scaled by *weight*) are left on the parameters; the caller
    performs the optimizer step. Returns the summed unweighted
    cross-entropy loss for logging.

    FIX: the previous version fed ``labels[:, :-1]`` directly as decoder
    input with targets ``labels[:, 1:]`` — that omits the required
    ``decoder_start_token_id``, leaves the first target token unsupervised,
    and conditions the decoder differently from generation time. We now use
    the standard seq2seq right shift (start token prepended, last label
    dropped) and supervise the full label sequence.
    """
    model.train()
    total_loss = 0.0
    # Resolve the decoder start token from the wrapped model's config,
    # falling back to pad (T5 convention).
    config = getattr(model, "pretrained_model", model).config
    start_id = getattr(config, "decoder_start_token_id", None)
    if start_id is None:
        start_id = tokenizer.pad_token_id
    for q, gold in zip(queries, gold_sqls):
        enc = tokenizer(q, return_tensors="pt", truncation=True).to(model.device)
        dec = tokenizer(text_target=gold, return_tensors="pt", truncation=True)
        labels = dec.input_ids.to(model.device)
        # Standard right shift: [start, y_0, ..., y_{n-2}] predicts labels.
        start_col = torch.full_like(labels[:, :1], int(start_id))
        decoder_input_ids = torch.cat([start_col, labels[:, :-1]], dim=1)
        target_ids = labels
        outputs = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            decoder_input_ids=decoder_input_ids,
        )
        # Value-head wrapper returns (lm_logits, loss, value).
        logits = outputs[0]
        vocab_size = logits.size(-1)
        loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            target_ids.view(-1),
            ignore_index=tokenizer.pad_token_id,
        )
        # Scale before backward so the anchor stays a mild regularizer.
        (loss * weight).backward()
        total_loss += loss.item()
    return total_loss
def _estimate_policy_entropy(query_tensors, response_tensors) -> torch.Tensor:
    """
    Returns per-sample average token entropy of the policy on the sampled
    response tokens. Used as a small bonus to reduce repetition collapse.
    """
    pad = int(tokenizer.pad_token_id)
    enc_ids, enc_mask = _pad_2d(query_tensors, pad)
    resp_ids, resp_mask = _pad_2d(response_tensors, pad)
    # Teacher-force the sampled responses through the underlying LM.
    bos = int(getattr(model.pretrained_model.config, "decoder_start_token_id", pad))
    outputs = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_mask,
        decoder_input_ids=_shift_right(resp_ids, bos),
        use_cache=False,
    )
    log_probs = torch.log_softmax(outputs.logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [B, T]
    # Average only over non-pad positions of the sampled response.
    valid_counts = resp_mask.sum(dim=-1).clamp_min(1)
    return (token_entropy * resp_mask).sum(dim=-1) / valid_counts  # [B]
| def _repeat_penalty(response_tensor: torch.Tensor) -> float: | |
| """ | |
| Penalize repetition to avoid 'SELECT SELECT SELECT' collapse. | |
| Simple heuristic: consecutive duplicate token ratio + low-unique-token ratio. | |
| """ | |
| ids = response_tensor.detach().tolist() | |
| n = len(ids) | |
| if n <= 1: | |
| return 0.0 | |
| consec_dup = 0 | |
| for i in range(1, n): | |
| if ids[i] == ids[i - 1]: | |
| consec_dup += 1 | |
| unique_ratio = len(set(ids)) / n | |
| consec_ratio = consec_dup / (n - 1) | |
| # Higher penalty when low unique + high consecutive duplicates | |
| return float(0.5 * consec_ratio + 0.5 * (1.0 - unique_ratio)) | |
def _supervised_anchor_step(query_tensors, gold_sql_texts, weight: float = 0.05) -> None:
    """
    Small teacher-forcing step on gold SQL to anchor grammar during PPO.
    Runs only if PPOTrainer exposes (accelerator, optimizer); otherwise it
    is a silent no-op.
    """
    if not gold_sql_texts:
        return
    accelerator = getattr(trainer, "accelerator", None)
    optimizer = getattr(trainer, "optimizer", None)
    if accelerator is None or optimizer is None:
        return
    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)
    # Tokenize gold SQL targets (decoder side); empty gold falls back to a
    # trivial statement so the batch stays well-formed.
    eos = tokenizer.eos_token_id
    gold_ids = []
    for sql in gold_sql_texts:
        text = (sql or "").strip() or "SELECT 1"
        ids = tokenizer.encode(text, add_special_tokens=False)[:256]
        if eos is not None:
            ids.append(int(eos))
        gold_ids.append(torch.tensor(ids, dtype=torch.long))
    dec_ids, dec_attn = _pad_2d(gold_ids, pad_id)
    # -100 masks padding positions out of the seq2seq loss.
    labels = dec_ids.clone()
    labels[dec_attn == 0] = -100
    # PEFT model forward supports labels -> returns loss
    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        labels=labels,
        use_cache=False,
    )
    loss = out.loss * float(weight)
    if hasattr(optimizer, "zero_grad"):
        optimizer.zero_grad(set_to_none=True)
    accelerator.backward(loss)
    optimizer.step()
def _curriculum_allows(gold_sql: str, epoch_num: int) -> bool:
    """Curriculum gate deciding which gold queries are eligible per epoch.

    Epoch 1: only single-table queries without joins or set operations.
    Epoch 2: joins allowed too, set operations still excluded.
    Epoch 3+: everything.
    """
    upper_sql = (gold_sql or "").upper()
    has_join = "JOIN" in upper_sql
    has_set_op = any(op in upper_sql for op in ("UNION", "INTERSECT", "EXCEPT"))
    single_table = len(extract_tables(gold_sql)) <= 1 and not has_join
    if epoch_num == 1:
        return single_table and not has_set_op
    if epoch_num == 2:
        return (single_table or has_join) and not has_set_op
    return True
# Main PPO loop: sample rollouts with replacement, gate on parsable SQL,
# score with execution_reward, and update in batches of ppo_config.batch_size.
for epoch in range(1, NUM_EPOCHS + 1):
    # Schema can be withheld for warmup epochs (SCHEMA_WARMUP_EPOCHS).
    use_schema_this_epoch = USE_SCHEMA and (epoch > SCHEMA_WARMUP_EPOCHS)
    # Per-epoch counters for the summary metrics printed below.
    epoch_reward_sum = 0
    negative_rewards = 0
    partial_rewards = 0
    correct_rewards = 0
    total_considered = 0
    valid_sql_count = 0
    exec_correct_count = 0
    table_overlap_sum = 0.0
    column_overlap_sum = 0.0
    kl_values = []
    for step in range(1, total_steps + 1):
        example = dataset[random.randrange(len(dataset))]
        question = example["question"]
        gold_sql = example["query"]
        db_id = example["db_id"]
        db_path = get_db_path(db_id)
        # NOTE: sampling-with-replacement provides more rollouts per epoch.
        schema = get_db_schema(db_path)
        question_text = build_prompt(question, schema, use_schema_this_epoch)
        query_tensor = encode_prompt(question, schema, use_schema_this_epoch)
        # ----- generate -----
        # Constrain decoding to SQL keywords + this DB's identifiers.
        table_names, column_names = get_db_tables_columns(db_path)
        allowed_ids = _schema_allowed_token_ids(table_names, column_names)
        logits_processor = LogitsProcessorList([SQLVocabularyLogitsProcessor(allowed_ids)])
        response = trainer.generate([query_tensor], logits_processor=logits_processor, **generation_kwargs)[0]
        response_tensor = response.squeeze(0)[:MAX_OUTPUT_TOKENS]
        pred_sql = tokenizer.decode(response_tensor.cpu(), skip_special_tokens=True)
        total_considered += 1
        # PPO must optimize ONLY when SQL parses successfully.
        if not _is_parsable_sql(pred_sql):
            negative_rewards += 1
            continue
        # Reject generations shorter than 6 tokens.
        if int(response_tensor.numel()) < 6:
            negative_rewards += 1
            continue
        # ----- reward -----
        reward_value = execution_reward(pred_sql, db_path, gold_sql)
        # SQL validity gate: if invalid/unparsable -> reward_value is None -> skip PPO entirely.
        if reward_value is None:
            if step % 100 == 0:
                ratio = valid_sql_count / max(total_considered, 1)
                print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
                if ratio < 0.15:
                    print("MODEL COLLAPSING")
            continue
        # Clip rewards to [-1, 1]
        reward_value = float(max(-1.0, min(1.0, reward_value)))
        # Penalize repetition in decoded output (token-level heuristic).
        reward_value = float(max(-1.0, min(1.0, reward_value - 0.2 * _repeat_penalty(response_tensor))))
        # Keep rewards on CPU for normalization; move to device only for trainer.step().
        reward_tensor = torch.tensor(reward_value, dtype=torch.float32)
        epoch_reward_sum += reward_value
        # ----- metrics -----
        # "Valid sample" means reward is not None (parsable SQL).
        valid_sql_count += 1
        pred_tables = extract_tables(pred_sql)
        gold_tables = extract_tables(gold_sql)
        pred_cols = extract_columns(pred_sql)
        gold_cols = extract_columns(gold_sql)
        if len(gold_tables) > 0:
            table_overlap_sum += len(pred_tables & gold_tables) / max(len(gold_tables), 1)
        if len(gold_cols) > 0:
            column_overlap_sum += len(pred_cols & gold_cols) / max(len(gold_cols), 1)
        # execution_reward returns 1.0 for correct execution result.
        if reward_value >= 1.0:
            exec_correct_count += 1
        if reward_value <= -1.0:
            negative_rewards += 1
        elif reward_value >= 1.0:
            correct_rewards += 1
        else:
            partial_rewards += 1
        # Train only on informative samples:
        # - invalid SQL already skipped (reward is None)
        # - very small magnitude signal skipped
        if abs(reward_value) < 0.1:
            continue
        query_buffer.append(query_tensor)
        response_buffer.append(response_tensor)
        reward_buffer.append(reward_tensor)
        gold_buffer.append(gold_sql)
        query_text_buffer.append(question_text)
        # ----- PPO update -----
        if len(query_buffer) == ppo_config.batch_size:
            # move rewards to device
            reward_buffer = [r.to(device) for r in reward_buffer]
            # run PPO step
            stats = trainer.step(query_buffer, response_buffer, reward_buffer)
            # log KL safely (no control logic)
            kl = safe_get_kl(stats)
            if kl is not None:
                kl_values.append(kl)
            # --- supervised anchor to prevent grammar collapse ---
            supervised_anchor_step(model, tokenizer, query_text_buffer, gold_buffer, weight=0.05)
            optimizer.step()
            optimizer.zero_grad()
            # reset buffers
            query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
            query_text_buffer = []
        # ----- learning ratio logging -----
        if step % 100 == 0:
            ratio = valid_sql_count / max(total_considered, 1)
            print(f"\nLearning ratio: {valid_sql_count}/{total_considered} ({ratio:.3f})")
            if ratio < 0.15:
                print("MODEL COLLAPSING")
                # Increase KL coefficient dynamically when valid_sql_rate drops.
                try:
                    if hasattr(trainer, "kl_ctl") and hasattr(trainer.kl_ctl, "value"):
                        trainer.kl_ctl.value *= 1.5
                        print(f"Increasing KL coef -> {trainer.kl_ctl.value:.4f}")
                except Exception:
                    pass
        # ----- logging -----
        if step % LOG_EVERY == 0:
            avg_reward = epoch_reward_sum / step
            print("\n---------------------------")
            print(f"Epoch {epoch}/{NUM_EPOCHS} | Step {step}/{total_steps} | Avg Reward {avg_reward:.3f}")
            print("DB:", db_id)
            print("Q:", question)
            print("SQL:", pred_sql)
            print("Reward:", reward_value)
    # epoch stats
    print(f"\nEpoch {epoch} stats:")
    print("negative:", negative_rewards)
    print("partial:", partial_rewards)
    print("correct:", correct_rewards)
    denom = max(total_considered, 1)
    print("\nEpoch metrics:")
    print(f"execution_accuracy: {exec_correct_count/denom:.3f}")
    print(f"valid_sql_rate: {valid_sql_count/denom:.3f}")
    print(f"table_match_rate: {table_overlap_sum/denom:.3f}")
    print(f"column_match_rate: {column_overlap_sum/denom:.3f}")
    print(f"avg_reward: {epoch_reward_sum/max(denom,1):.3f}")
    if kl_values:
        avg_kl = sum(kl_values) / max(len(kl_values), 1)
        print(f"avg_kl: {avg_kl:.3f}")
        # Strongly negative mean KL indicates reward hacking / collapse;
        # abort training early (this also skips saving for this epoch).
        if avg_kl < -8:
            print("\nKL collapse guard triggered (avg_kl < -8). Stopping early.")
            break
    # 🎯 FIXED: Removed the code that saved intermediate checkpoints at the end of each epoch
    # Only save if this epoch is the best one so far
    epoch_avg_reward = epoch_reward_sum / max(denom, 1)
    if epoch_avg_reward > best_reward:
        best_reward = epoch_avg_reward
        best_epoch = epoch
        print(f"\nNew best model at epoch {epoch} with reward {best_reward:.4f}")
        # 🎯 FIXED: Save directly to checkpoints/rlhf_t5_best, overwriting if needed
        os.makedirs(output_dir, exist_ok=True)
        trainer.model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
print(f"\nTraining finished.")
print(f"Best epoch: {best_epoch}")
print(f"Best reward: {best_reward:.4f}")
print(f"Best model saved at: {output_dir}")