# text2sql-demo / src / train_rl.py
# tjhalanigrid's picture
# Add src folder
# dc59b01
# =========================================================
# RLHF TRAINING FOR TEXT2SQL (STABLE PPO VERSION)
# =========================================================
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from peft import PeftModel
import os, sys, sqlite3, re, random
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from execution_reward import execution_reward, extract_tables, extract_columns
try:
import sqlparse # gate PPO updates on parsable SQL only
except Exception: # pragma: no cover
sqlparse = None
# ======================================================
# DEVICE
# ======================================================
# Default TOKENIZERS_PARALLELISM to "false" unless the environment already sets it.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Backend preference order: MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)
# ======================================================
# TRAINING SETTINGS
# ======================================================
# Number of passes of the outer training loop (epochs are counted from 1).
NUM_EPOCHS = 5
# A sample rollout is printed whenever the step index is a multiple of this value.
LOG_EVERY = 20
# Master switch for putting the database schema into the prompt.
USE_SCHEMA = True
# The schema is only used when the epoch number is greater than this value.
SCHEMA_WARMUP_EPOCHS = 0
# Character cap applied to the schema text before it is placed in the prompt.
MAX_SCHEMA_CHARS = 1500
# Generated responses are sliced to at most this many tokens before decoding.
MAX_OUTPUT_TOKENS = 80
# Number of rollout steps per epoch (assigned to total_steps further down).
ROLLOUTS_PER_EPOCH = 2048
# ======================================================
# PATHS
# ======================================================
# Repository root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Only the best model is saved, always to this exact path.
RL_MODEL_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "rlhf_t5_best")
output_dir = RL_MODEL_PATH
# Root of the per-database folders; get_db_path appends <db_id>/<db_id>.sqlite.
DB_ROOT = os.path.join(PROJECT_ROOT, "data/database")
# Preferred SFT adapter directory, followed by two fallbacks that are tried in
# order when the preferred directory does not exist.
ADAPTER_PATH = os.path.join(PROJECT_ROOT, "checkpoints/sft_t5")
FALLBACK_ADAPTER_PATH = os.path.join(PROJECT_ROOT, "models/t5_spider_sft_lora")
FALLBACK_ADAPTER_PATH_2 = os.path.join(PROJECT_ROOT, "outputs/sft_text2sql")
# Base checkpoint name; defaults to t5-small and can be overridden through the
# BASE_MODEL environment variable.
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
# ======================================================
# LOAD MODEL (LoRA)
# ======================================================
print("Loading base:", BASE_MODEL)

# When the preferred adapter directory is missing, switch to the first fallback
# directory that exists; if neither exists ADAPTER_PATH is left unchanged.
if not os.path.isdir(ADAPTER_PATH):
    ADAPTER_PATH = next(
        (
            path
            for path in (FALLBACK_ADAPTER_PATH, FALLBACK_ADAPTER_PATH_2)
            if os.path.isdir(path)
        ),
        ADAPTER_PATH,
    )
print("Loading adapters:", ADAPTER_PATH)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Policy: value-head wrapper around the base model, with the adapters loaded
# into the wrapped model.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
model.pretrained_model = PeftModel.from_pretrained(model.pretrained_model, ADAPTER_PATH)

# Reference model: built the same way, put in eval mode and fully frozen.
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(BASE_MODEL).to(device)
ref_model.pretrained_model = PeftModel.from_pretrained(ref_model.pretrained_model, ADAPTER_PATH)
ref_model.eval()
for ref_param in ref_model.parameters():
    ref_param.requires_grad_(False)

# Only the value head ("v_head*") and the LoRA adapter weights ("lora_" in the
# parameter name) stay trainable; every other parameter is frozen.
for param_name, param in model.named_parameters():
    param.requires_grad = param_name.startswith("v_head") or ("lora_" in param_name)

trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
total = sum(param.numel() for param in model.parameters())
print(f"Trainable params: {trainable}/{total} ({100*trainable/total:.2f}%)")

model.config.use_cache = False
ref_model.config.use_cache = False

# Reuse EOS as the padding token when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
# ======================================================
# DATASET
# ======================================================
print("Loading Spider subset...")
random.seed(0)

# Databases whose training examples are used for the RL rollouts.
TRAIN_DBS = ["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1"]
_TRAIN_DBS_SET = set(TRAIN_DBS)

# Keep only examples from the selected databases, then cap the subset at 800 rows.
dataset = load_dataset("spider", split="train").filter(
    lambda row: row["db_id"] in _TRAIN_DBS_SET
)
subset_size = min(800, len(dataset))
dataset = dataset.select(range(subset_size))

print("Using RLHF DBs:", TRAIN_DBS)
print("Filtered size:", len(dataset))

# Each epoch runs this many rollout steps, sampled with replacement.
total_steps = ROLLOUTS_PER_EPOCH
# ======================================================
# DB UTILITIES
# ======================================================
def get_db_path(db_id):
    """Return the SQLite file path for *db_id*: ``<DB_ROOT>/<db_id>/<db_id>.sqlite``."""
    sqlite_name = f"{db_id}.sqlite"
    return os.path.join(DB_ROOT, db_id, sqlite_name)
def get_db_schema(db_path):
    """Return a one-line textual schema for the SQLite database at *db_path*.

    Every table listed in ``sqlite_master`` is rendered as
    ``table(col1, col2, ...) `` (note the trailing space) and the pieces are
    concatenated in the order SQLite returns them.

    The lookup is best-effort: when the database cannot be opened or queried,
    a warning is printed and whatever was collected so far (possibly the empty
    string) is returned. The connection is always closed.
    """
    parts = []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        for (table_name,) in tables:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier; otherwise names with spaces, quotes or keywords fail.
            quoted = '"' + str(table_name).replace('"', '""') + '"'
            columns = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
            col_names = [col[1] for col in columns]
            parts.append(f"{table_name}({', '.join(col_names)}) ")
    except (sqlite3.Error, OSError, TypeError, ValueError) as exc:
        print(f"Warning: could not read schema from {db_path}: {exc}")
    finally:
        if conn is not None:
            conn.close()
    return "".join(parts)
# ======================================================
# PROMPT
# ======================================================
# Task prefix placed at the start of every prompt.
PREFIX = "translate English to SQL:"


def trim_schema(schema: str, max_chars: int = 1200) -> str:
    """Return *schema* as a string cut to at most *max_chars* characters.

    ``None`` becomes the empty string; any other value is converted with ``str``.
    """
    text = "" if schema is None else str(schema)
    return text[:max_chars] if len(text) > max_chars else text
def build_prompt(question: str, schema: str, use_schema: bool) -> str:
    """Assemble the text prompt: prefix, optional schema section, question, ``SQL:`` cue.

    Sections are separated by a blank line. The schema section is present only
    when *use_schema* is true and is first trimmed to ``MAX_SCHEMA_CHARS``.
    """
    sections = [PREFIX]
    if use_schema:
        trimmed = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
        sections.append(f"Schema:\n{trimmed}")
    sections.append(f"Question:\n{question}")
    sections.append("SQL:")
    return "\n\n".join(sections)
def encode_prompt(question, schema, use_schema):
    """Tokenize the prompt into a 1-D ``torch.long`` tensor on ``device``.

    Without schema the full prompt goes through the tokenizer with
    ``truncation=True``. With schema the prompt is tokenized in four pieces
    (prefix, schema, separator, question + ``SQL:``) and fitted into the
    tokenizer's ``model_max_length`` (512 if the attribute is missing), keeping
    one slot for EOS when the tokenizer has one. The question piece is clipped
    only when it cannot fit next to the fixed pieces; the schema tokens get
    whatever room is left.
    """
    if not use_schema:
        plain_prompt = build_prompt(question, schema, use_schema=False)
        encoded = tokenizer(plain_prompt, return_tensors="pt", truncation=True)
        return encoded.input_ids[0].to(device)

    def _ids(text):
        return tokenizer.encode(text, add_special_tokens=False)

    clipped_schema = trim_schema(schema, max_chars=MAX_SCHEMA_CHARS)
    head_ids = _ids(f"{PREFIX}\n\nSchema:\n")
    schema_ids = _ids(clipped_schema)
    mid_ids = _ids("\n\nQuestion:\n")
    tail_ids = _ids(f"{question}\n\nSQL:")

    eos_id = tokenizer.eos_token_id
    budget = getattr(tokenizer, "model_max_length", 512)
    if eos_id is not None:
        budget -= 1  # reserve one position for the EOS token

    scaffold_len = len(head_ids) + len(mid_ids)
    if scaffold_len + len(tail_ids) > budget:
        # The question piece alone overflows: clip it to the space that is left.
        tail_ids = tail_ids[: max(0, budget - scaffold_len)]

    schema_room = max(0, budget - scaffold_len - len(tail_ids))
    token_ids = (head_ids + schema_ids[:schema_room] + mid_ids + tail_ids)[:budget]
    if eos_id is not None:
        token_ids = token_ids + [eos_id]
    return torch.tensor(token_ids, dtype=torch.long).to(device)
# ======================================================
# SQL CONSTRAINED DECODING
# ======================================================
# Lower-case SQL keywords. _precompute_always_allowed_token_ids tokenizes each
# one in lower, UPPER and Capitalized form (with and without a leading space)
# and whitelists the resulting token ids.
SQL_KEYWORDS = [
    "select", "from", "where", "join", "inner", "left", "right",
    "full", "outer", "on", "group", "by", "order", "having",
    "limit", "distinct", "as", "and", "or", "not", "in", "is",
    "null", "like", "between", "asc", "desc", "union",
    "intersect", "except",
]
# Operator/punctuation characters. A vocabulary piece made up only of these
# characters is always allowed during constrained decoding.
SQL_OPERATORS = ["*", ",", ".", "(", ")", "=", "<", ">", "!", "+", "-", "/", "%", "_"]
def _piece_token_str(tok: str) -> str:
# T5 SentencePiece: "▁" marks a leading space; strip it for char checks.
return tok.lstrip("▁")
def _precompute_always_allowed_token_ids():
    """Build the set of token ids that constrained decoding allows for every database.

    The set contains: the pad/eos/unk ids (when defined and non-negative), the
    pieces produced by encoding a space, newline and tab, every vocabulary
    piece that consists solely of ``SQL_OPERATORS`` characters or solely of
    digits, and the pieces of every ``SQL_KEYWORDS`` entry in lower, upper and
    capitalized form, encoded both with and without a leading space.
    """

    def _encode(text):
        return tokenizer.encode(text, add_special_tokens=False)

    permitted = {
        int(tid)
        for tid in (tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id)
        if tid is not None and tid >= 0
    }

    for blank in (" ", "\n", "\t"):
        permitted.update(_encode(blank))

    # Scan the vocabulary once for pure-operator and pure-digit pieces.
    operator_chars = set("".join(SQL_OPERATORS))
    for tid in range(len(tokenizer)):
        token = tokenizer.convert_ids_to_tokens(tid)
        if not isinstance(token, str) or not token:
            continue
        piece = _piece_token_str(token)
        if not piece:
            continue
        if piece.isdigit() or all(ch in operator_chars for ch in piece):
            permitted.add(tid)

    for keyword in SQL_KEYWORDS:
        for spelling in (keyword, keyword.upper(), keyword.capitalize()):
            permitted.update(_encode(" " + spelling))
            permitted.update(_encode(spelling))
    return permitted


# Computed once at import time and copied by _schema_allowed_token_ids.
ALWAYS_ALLOWED_TOKEN_IDS = _precompute_always_allowed_token_ids()
def _schema_allowed_token_ids(table_names, column_names):
    """Return the always-allowed token ids extended with the pieces of the given identifiers.

    Each non-empty table or column name contributes the name itself, its lower
    and upper-case forms, and the parts obtained by splitting on underscores
    and whitespace. Every form is encoded with and without a leading space.
    """
    permitted = set(ALWAYS_ALLOWED_TOKEN_IDS)
    for identifier in list(table_names) + list(column_names):
        if not identifier:
            continue
        spellings = {identifier, identifier.lower(), identifier.upper()}
        spellings.update(part for part in re.split(r"[_\s]+", identifier) if part)
        for spelling in spellings:
            permitted.update(tokenizer.encode(" " + spelling, add_special_tokens=False))
            permitted.update(tokenizer.encode(spelling, add_special_tokens=False))
    return permitted
class SQLVocabularyLogitsProcessor(LogitsProcessor):
    """Logits processor that adds ``-inf`` to every token id outside an allowed set.

    The additive bias vector is built lazily and rebuilt whenever the scores'
    device, dtype or vocabulary size differs from the cached one.
    """

    def __init__(self, allowed_token_ids):
        # Negative ids are discarded.
        self.allowed_token_ids = {int(i) for i in allowed_token_ids if int(i) >= 0}
        self._bias = None
        self._bias_vocab_size = None

    def _get_bias(self, scores: torch.Tensor) -> torch.Tensor:
        """Return the cached bias vector matching *scores*, rebuilding it when stale."""
        vocab_size = int(scores.shape[-1])
        cached = self._bias
        stale = (
            cached is None
            or cached.device != scores.device
            or cached.dtype != scores.dtype
            or self._bias_vocab_size != vocab_size
        )
        if stale:
            mask = torch.full(
                (vocab_size,), float("-inf"), device=scores.device, dtype=scores.dtype
            )
            # Allowed ids beyond the score width are ignored.
            in_range = [tid for tid in self.allowed_token_ids if tid < vocab_size]
            if in_range:
                index = torch.tensor(in_range, dtype=torch.long, device=scores.device)
                mask[index] = 0.0
            self._bias = mask
            self._bias_vocab_size = vocab_size
        return self._bias

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        bias = self._get_bias(scores)
        return scores + bias
# db_path -> (table names, column names); filled by get_db_tables_columns.
_DB_VOCAB_CACHE = {}


def get_db_tables_columns(db_path: str):
    """Return ``(tables, columns)`` name lists for the SQLite database at *db_path*.

    Internal ``sqlite_%`` tables are excluded. Columns of all tables are
    returned in one flat list. Results are memoized per path, including the
    empty result produced when the database cannot be read.

    The lookup is best-effort: a table whose column listing fails is kept in
    ``tables`` but contributes no columns, and a database-level failure prints
    a warning and yields whatever was collected. The connection is always
    closed.
    """
    if db_path in _DB_VOCAB_CACHE:
        return _DB_VOCAB_CACHE[db_path]

    tables, cols = [], []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        rows = cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchall()
        for (tname,) in rows:
            if not tname:
                continue
            tables.append(tname)
            # Double any embedded quote so the identifier stays well-formed.
            quoted = '"' + str(tname).replace('"', '""') + '"'
            try:
                info = cur.execute(f"PRAGMA table_info({quoted})").fetchall()
            except sqlite3.Error:
                continue
            cols.extend(row[1] for row in info if row and isinstance(row[1], str))
    except (sqlite3.Error, OSError, TypeError, ValueError) as exc:
        print(f"Warning: could not read tables/columns from {db_path}: {exc}")
    finally:
        if conn is not None:
            conn.close()

    _DB_VOCAB_CACHE[db_path] = (tables, cols)
    return tables, cols
# ======================================================
# PPO CONFIG (stable learning)
# ======================================================
# PPO hyperparameters. The training loop buffers samples and calls
# trainer.step once batch_size of them have been collected.
# The trailing remarks are the original author's tuning notes.
ppo_config = PPOConfig(
    learning_rate=2e-5, # author note: was 1e-6, with which the model could not move
    batch_size=8, # author note: better gradient estimate
    mini_batch_size=2,
    gradient_accumulation_steps=2, # author note: stable updates on small data
    ppo_epochs=1,
    # --- KL control (flagged by the author as the most important fix) ---
    init_kl_coef=0.05, # author note: reduce punishment
    target_kl=0.15, # author note: relax constraint to avoid skipped updates
    adap_kl_ctrl=True,
    # --- stability ---
    cliprange=0.25,
    cliprange_value=0.25,
    whiten_rewards=True,
    kl_penalty="kl",
    max_grad_norm=1.0,
)
trainer = PPOTrainer(config=ppo_config, model=model, ref_model=ref_model, tokenizer=tokenizer)

# The training loop steps this optimizer itself after the supervised anchor pass.
optimizer = trainer.optimizer

# supervised_anchor_step reads `model.device`; attach it on a best-effort basis.
try:
    setattr(model, "device", torch.device(device))
except Exception:
    pass
# ======================================================
# GENERATION (schema-constrained decoding)
# ======================================================
# Sampling settings forwarded to trainer.generate for every rollout.
# The trailing remarks are the original author's tuning notes.
generation_kwargs = {
    "max_new_tokens": 64,  # author note: 128 causes garbage SQL loops
    "do_sample": True,
    "temperature": 0.9,  # author note: encourage exploration
    "top_p": 0.95,
    "top_k": 100,
    "repetition_penalty": 1.1,  # author note: prevents SELECT SELECT SELECT
    "no_repeat_ngram_size": 3,
    "num_beams": 1,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}
# ======================================================
# TRAIN LOOP
# ======================================================
print("Starting RL training 🚀")

# Samples collected between PPO updates; each buffer is an independent list.
query_buffer = []
response_buffer = []
reward_buffer = []
gold_buffer = []
query_text_buffer = []

# Best epoch-average reward seen so far and the epoch that produced it.
best_reward = -999999
best_epoch = -1
def _is_parsable_sql(sql: str) -> bool:
s = (sql or "").strip()
if not s:
return False
up = s.upper()
if "SELECT" not in up or "FROM" not in up:
return False
if sqlparse is None:
return True
try:
return bool(sqlparse.parse(s))
except Exception:
return False
def _pad_2d(seqs, pad_id: int):
    """Right-pad the tensors in *seqs* into one ``(len(seqs), max_len)`` long tensor.

    Returns ``(padded, mask)``, both on ``device``; ``mask`` holds 1 at real
    token positions and 0 at padding.
    """
    lengths = [int(seq.numel()) for seq in seqs]
    width = max(lengths)
    shape = (len(seqs), width)
    padded = torch.full(shape, int(pad_id), dtype=torch.long, device=device)
    mask = torch.zeros(shape, dtype=torch.long, device=device)
    for row, (seq, length) in enumerate(zip(seqs, lengths)):
        padded[row, :length] = seq.to(device)
        mask[row, :length] = 1
    return padded, mask
def _shift_right(labels: torch.Tensor, start_id: int) -> torch.Tensor:
dec = labels.clone()
dec[:, 1:] = labels[:, :-1]
dec[:, 0] = int(start_id)
return dec
def safe_get_kl(stats):
    """Return the value of the first key in *stats* whose name contains "kl", as a float.

    Matching is case-insensitive and follows the dict's iteration order.
    Returns ``None`` when *stats* is not a dict, when no key matches, or when
    the matched value cannot be converted to ``float``.
    """
    if not isinstance(stats, dict):
        return None
    # str(None) contains no "kl", so None is a safe "not found" marker.
    kl_key = next((key for key in stats if "kl" in str(key).lower()), None)
    if kl_key is None:
        return None
    value = stats[kl_key]
    try:
        return float(value.item() if hasattr(value, "item") else value)
    except Exception:
        return None
def supervised_anchor_step(model, tokenizer, queries, gold_sqls, weight=0.05):
    """Accumulate teacher-forcing gradients on gold SQL to anchor the policy.

    For every ``(query text, gold SQL)`` pair the model is run with the gold
    target as decoder input and ``weight * cross_entropy`` is back-propagated.
    Gradients are only accumulated here; the caller is responsible for
    ``optimizer.step()`` and ``optimizer.zero_grad()``.

    The decoder input is the target shifted right behind the decoder start
    token (``decoder_start_token_id`` from the wrapped model's config, falling
    back to the pad id), and the loss covers every target token. This is the
    encoder-decoder convention; the previous causal-LM style shift never fed
    the start token and never trained the first target token.

    Args:
        model: callable returning a tuple whose first item is the logits;
            must expose ``train()`` and ``device``.
        tokenizer: tokenizer used for both prompts and targets.
        queries: prompt texts.
        gold_sqls: gold SQL strings, aligned with *queries*.
        weight: scale applied to each sample's loss before ``backward``.

    Returns:
        Sum of the unweighted per-sample losses as a Python float.
    """
    model.train()
    pad_id = tokenizer.pad_token_id
    inner_config = getattr(getattr(model, "pretrained_model", model), "config", None)
    start_id = getattr(inner_config, "decoder_start_token_id", None)
    if start_id is None:
        start_id = pad_id

    total_loss = 0.0
    for q, gold in zip(queries, gold_sqls):
        enc = tokenizer(q, return_tensors="pt", truncation=True).to(model.device)
        dec = tokenizer(text_target=gold, return_tensors="pt", truncation=True)
        labels = dec.input_ids.to(model.device)

        # Teacher forcing: [start, y_0, ..., y_{n-2}] predicts [y_0, ..., y_{n-1}].
        start_col = labels.new_full((labels.size(0), 1), int(start_id))
        decoder_input_ids = torch.cat([start_col, labels[:, :-1]], dim=1)

        outputs = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            decoder_input_ids=decoder_input_ids,
        )
        logits = outputs[0]
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=pad_id,
        )
        (loss * weight).backward()
        total_loss += loss.item()
    return total_loss
@torch.no_grad()
def _estimate_policy_entropy(query_tensors, response_tensors) -> torch.Tensor:
    """
    Return, per sample, the policy's mean token entropy over the response positions.

    Queries and responses are padded into batches, the responses are shifted
    right behind the decoder start token, and the entropy of the predicted
    distribution is averaged over the non-pad response positions. Runs without
    gradient tracking. Output shape: [B].
    """
    pad_id = int(tokenizer.pad_token_id)
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)
    dec_ids, dec_attn = _pad_2d(response_tensors, pad_id)

    policy = model.pretrained_model
    start_id = int(getattr(policy.config, "decoder_start_token_id", pad_id))
    logits = policy(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        decoder_input_ids=_shift_right(dec_ids, start_id),
        use_cache=False,
    ).logits

    log_probs = torch.log_softmax(logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [B, T]

    # Mean over real response tokens only; the clamp avoids division by zero.
    valid_tokens = dec_attn.sum(dim=-1).clamp_min(1)
    return (token_entropy * dec_attn).sum(dim=-1) / valid_tokens  # [B]
def _repeat_penalty(response_tensor: torch.Tensor) -> float:
"""
Penalize repetition to avoid 'SELECT SELECT SELECT' collapse.
Simple heuristic: consecutive duplicate token ratio + low-unique-token ratio.
"""
ids = response_tensor.detach().tolist()
n = len(ids)
if n <= 1:
return 0.0
consec_dup = 0
for i in range(1, n):
if ids[i] == ids[i - 1]:
consec_dup += 1
unique_ratio = len(set(ids)) / n
consec_ratio = consec_dup / (n - 1)
# Higher penalty when low unique + high consecutive duplicates
return float(0.5 * consec_ratio + 0.5 * (1.0 - unique_ratio))
def _supervised_anchor_step(query_tensors, gold_sql_texts, weight: float = 0.05) -> None:
    """
    Run one weighted teacher-forcing update on gold SQL using the trainer's optimizer.

    Does nothing when there are no gold texts or when the trainer exposes no
    accelerator/optimizer. Empty gold strings are replaced by "SELECT 1";
    targets are cut to 256 tokens before EOS is appended, and padded positions
    are excluded from the loss via the -100 label.
    """
    if not gold_sql_texts:
        return
    accelerator = getattr(trainer, "accelerator", None)
    optimizer = getattr(trainer, "optimizer", None)
    if accelerator is None or optimizer is None:
        return

    pad_id = int(tokenizer.pad_token_id)
    eos_id = tokenizer.eos_token_id
    enc_ids, enc_attn = _pad_2d(query_tensors, pad_id)

    # Decoder-side targets built from the gold SQL strings.
    targets = []
    for text in gold_sql_texts:
        sql = (text or "").strip() or "SELECT 1"
        token_ids = tokenizer.encode(sql, add_special_tokens=False)[:256]
        if eos_id is not None:
            token_ids = token_ids + [int(eos_id)]
        targets.append(torch.tensor(token_ids, dtype=torch.long))

    dec_ids, dec_attn = _pad_2d(targets, pad_id)
    labels = dec_ids.masked_fill(dec_attn == 0, -100)

    # Passing labels makes the wrapped model return the loss directly.
    out = model.pretrained_model(
        input_ids=enc_ids,
        attention_mask=enc_attn,
        labels=labels,
        use_cache=False,
    )
    loss = out.loss * float(weight)

    if hasattr(optimizer, "zero_grad"):
        optimizer.zero_grad(set_to_none=True)
    accelerator.backward(loss)
    optimizer.step()
def _curriculum_allows(gold_sql: str, epoch_num: int) -> bool:
    """Decide whether an example with this gold SQL may be used in epoch *epoch_num*.

    Epoch 1 admits single-table queries without JOIN or set operations, epoch 2
    additionally admits JOIN queries, and any other epoch admits everything.
    Set operations are UNION / INTERSECT / EXCEPT.
    """
    upper_sql = (gold_sql or "").upper()
    has_join = "JOIN" in upper_sql
    has_set_op = "UNION" in upper_sql or "INTERSECT" in upper_sql or "EXCEPT" in upper_sql
    single_table = len(extract_tables(gold_sql)) <= 1 and not has_join

    if epoch_num not in (1, 2):
        return True
    if has_set_op:
        return False
    if epoch_num == 1:
        return single_table
    return single_table or has_join
# Per-database memoization. Schema text and the constrained-decoding processor
# depend only on the SQLite file, so they are built once per database instead
# of re-querying SQLite and re-tokenizing identifiers on every rollout.
_SCHEMA_TEXT_CACHE = {}
_LOGITS_PROCESSOR_CACHE = {}


def _report_learning_ratio(step, valid_count, considered_count):
    """Every 100 steps print the valid-SQL ratio and react to a collapse.

    When the ratio drops below 0.15 the trainer's KL coefficient is multiplied
    by 1.5 (best effort). The loop calls this on every exit path of a step, so
    the guard also fires on steps whose sample was rejected.
    """
    if step % 100 != 0:
        return
    ratio = valid_count / max(considered_count, 1)
    print(f"\nLearning ratio: {valid_count}/{considered_count} ({ratio:.3f})")
    if ratio < 0.15:
        print("MODEL COLLAPSING")
        # Increase KL coefficient dynamically when valid_sql_rate drops.
        try:
            if hasattr(trainer, "kl_ctl") and hasattr(trainer.kl_ctl, "value"):
                trainer.kl_ctl.value *= 1.5
                print(f"Increasing KL coef -> {trainer.kl_ctl.value:.4f}")
        except Exception:
            pass


for epoch in range(1, NUM_EPOCHS + 1):
    use_schema_this_epoch = USE_SCHEMA and (epoch > SCHEMA_WARMUP_EPOCHS)

    # Per-epoch counters.
    epoch_reward_sum = 0
    negative_rewards = 0
    partial_rewards = 0
    correct_rewards = 0
    total_considered = 0
    valid_sql_count = 0
    exec_correct_count = 0
    table_overlap_sum = 0.0
    column_overlap_sum = 0.0
    kl_values = []

    for step in range(1, total_steps + 1):
        # NOTE: sampling-with-replacement provides more rollouts per epoch.
        example = dataset[random.randrange(len(dataset))]
        question = example["question"]
        gold_sql = example["query"]
        db_id = example["db_id"]
        db_path = get_db_path(db_id)

        if db_path not in _SCHEMA_TEXT_CACHE:
            _SCHEMA_TEXT_CACHE[db_path] = get_db_schema(db_path)
        schema = _SCHEMA_TEXT_CACHE[db_path]

        question_text = build_prompt(question, schema, use_schema_this_epoch)
        query_tensor = encode_prompt(question, schema, use_schema_this_epoch)

        # ----- generate -----
        if db_path not in _LOGITS_PROCESSOR_CACHE:
            table_names, column_names = get_db_tables_columns(db_path)
            allowed_ids = _schema_allowed_token_ids(table_names, column_names)
            _LOGITS_PROCESSOR_CACHE[db_path] = LogitsProcessorList(
                [SQLVocabularyLogitsProcessor(allowed_ids)]
            )
        logits_processor = _LOGITS_PROCESSOR_CACHE[db_path]

        response = trainer.generate(
            [query_tensor], logits_processor=logits_processor, **generation_kwargs
        )[0]
        response_tensor = response.squeeze(0)[:MAX_OUTPUT_TOKENS]
        pred_sql = tokenizer.decode(response_tensor.cpu(), skip_special_tokens=True)
        total_considered += 1

        # PPO must optimize ONLY when SQL parses successfully.
        if not _is_parsable_sql(pred_sql):
            negative_rewards += 1
            _report_learning_ratio(step, valid_sql_count, total_considered)
            continue

        # Reject generations shorter than 6 tokens.
        if int(response_tensor.numel()) < 6:
            negative_rewards += 1
            _report_learning_ratio(step, valid_sql_count, total_considered)
            continue

        # ----- reward -----
        reward_value = execution_reward(pred_sql, db_path, gold_sql)

        # A reward of None means the sample is skipped for PPO entirely.
        if reward_value is None:
            _report_learning_ratio(step, valid_sql_count, total_considered)
            continue

        # Clip rewards to [-1, 1]
        reward_value = float(max(-1.0, min(1.0, reward_value)))
        # Penalize repetition in decoded output (token-level heuristic).
        reward_value = float(
            max(-1.0, min(1.0, reward_value - 0.2 * _repeat_penalty(response_tensor)))
        )
        # Keep rewards on CPU here; they are moved to the device just before trainer.step().
        reward_tensor = torch.tensor(reward_value, dtype=torch.float32)
        epoch_reward_sum += reward_value

        # ----- metrics -----
        # "Valid sample" means the reward is not None.
        valid_sql_count += 1
        pred_tables = extract_tables(pred_sql)
        gold_tables = extract_tables(gold_sql)
        pred_cols = extract_columns(pred_sql)
        gold_cols = extract_columns(gold_sql)
        if len(gold_tables) > 0:
            table_overlap_sum += len(pred_tables & gold_tables) / max(len(gold_tables), 1)
        if len(gold_cols) > 0:
            column_overlap_sum += len(pred_cols & gold_cols) / max(len(gold_cols), 1)

        # A reward that is still >= 1.0 after the penalty counts as execution-correct.
        if reward_value >= 1.0:
            exec_correct_count += 1

        if reward_value <= -1.0:
            negative_rewards += 1
        elif reward_value >= 1.0:
            correct_rewards += 1
        else:
            partial_rewards += 1

        # Train only on informative samples: very small reward magnitudes are skipped.
        if abs(reward_value) < 0.1:
            _report_learning_ratio(step, valid_sql_count, total_considered)
            continue

        query_buffer.append(query_tensor)
        response_buffer.append(response_tensor)
        reward_buffer.append(reward_tensor)
        gold_buffer.append(gold_sql)
        query_text_buffer.append(question_text)

        # ----- PPO update -----
        if len(query_buffer) == ppo_config.batch_size:
            # move rewards to device
            reward_buffer = [r.to(device) for r in reward_buffer]

            # run PPO step
            stats = trainer.step(query_buffer, response_buffer, reward_buffer)

            # log KL safely (no control logic)
            kl = safe_get_kl(stats)
            if kl is not None:
                kl_values.append(kl)

            # --- supervised anchor to prevent grammar collapse ---
            # The anchor only accumulates gradients; they are applied here.
            supervised_anchor_step(model, tokenizer, query_text_buffer, gold_buffer, weight=0.05)
            optimizer.step()
            optimizer.zero_grad()

            # reset buffers
            query_buffer, response_buffer, reward_buffer, gold_buffer = [], [], [], []
            query_text_buffer = []

        # ----- learning ratio logging -----
        _report_learning_ratio(step, valid_sql_count, total_considered)

        # ----- logging -----
        # Reached only for samples that passed every gate above.
        if step % LOG_EVERY == 0:
            avg_reward = epoch_reward_sum / step
            print("\n---------------------------")
            print(f"Epoch {epoch}/{NUM_EPOCHS} | Step {step}/{total_steps} | Avg Reward {avg_reward:.3f}")
            print("DB:", db_id)
            print("Q:", question)
            print("SQL:", pred_sql)
            print("Reward:", reward_value)

    # epoch stats
    print(f"\nEpoch {epoch} stats:")
    print("negative:", negative_rewards)
    print("partial:", partial_rewards)
    print("correct:", correct_rewards)

    denom = max(total_considered, 1)
    epoch_avg_reward = epoch_reward_sum / max(denom, 1)
    print("\nEpoch metrics:")
    print(f"execution_accuracy: {exec_correct_count/denom:.3f}")
    print(f"valid_sql_rate: {valid_sql_count/denom:.3f}")
    print(f"table_match_rate: {table_overlap_sum/denom:.3f}")
    print(f"column_match_rate: {column_overlap_sum/denom:.3f}")
    print(f"avg_reward: {epoch_avg_reward:.3f}")

    if kl_values:
        avg_kl = sum(kl_values) / max(len(kl_values), 1)
        print(f"avg_kl: {avg_kl:.3f}")
        if avg_kl < -8:
            print("\nKL collapse guard triggered (avg_kl < -8). Stopping early.")
            break

    # No intermediate checkpoints: save only when this epoch is the best so far,
    # overwriting the previous best in output_dir.
    if epoch_avg_reward > best_reward:
        best_reward = epoch_avg_reward
        best_epoch = epoch
        print(f"\nNew best model at epoch {epoch} with reward {best_reward:.4f}")
        os.makedirs(output_dir, exist_ok=True)
        trainer.model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

print("\nTraining finished.")
print(f"Best epoch: {best_epoch}")
print(f"Best reward: {best_reward:.4f}")
print(f"Best model saved at: {output_dir}")