import sys
import os

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch
import numpy as np
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

from FastChemTokenizerHF import FastChemTokenizerSelfies
from ChemQ3MTP import ChemQ3MTPForCausalLM
from ChemQ3MTP.rl_utils import (
    CurriculumManager,
    AdaptiveKLController,
    batch_compute_rewards,
    compute_ppo_loss,
    compute_kl_divergence,
    compute_entropy_bonus,
    compute_kl_penalty,
    selfies_to_smiles,
)


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Using device: {device}")

    # Load the SELFIES tokenizer and the pretrained checkpoint to be fine-tuned.
    tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")

    model = ChemQ3MTPForCausalLM.from_pretrained("./checkpoint-1635")
    model.tokenizer = tokenizer
    model.to(device)

    print("\n🎯 Phase 2: RL Fine-tuning with PPO + Curriculum Learning")
    model.set_mtp_training(False)  # switch off MTP training mode for the RL phase

    # Adaptive KL controller: keeps the updated policy close to the sampling policy
    # by adjusting the penalty coefficient (beta) toward target_kl.
    kl_controller = AdaptiveKLController(
        init_kl_coef=0.1,
        target_kl=0.01,
        horizon=100,
        max_kl_coef=100.0,
        ema_alpha=0.9,
        kl_penalty_cap=10.0,
    )
    model.kl_controller = kl_controller

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    curriculum = CurriculumManager(start_len=10, max_len=25, step_increase=5, steps_per_level=100)
    baseline = None  # exponential-moving-average reward baseline
    gamma = 0.95     # EMA decay factor for the baseline
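
    # Note on the schedules above (exact behaviour lives in ChemQ3MTP.rl_utils):
    #   * CurriculumManager presumably grows the generation budget from 10 to 25
    #     tokens, adding 5 tokens every 100 training steps.
    #   * AdaptiveKLController presumably raises beta when the observed KL exceeds
    #     target_kl=0.01 and lowers it otherwise, clamped to max_kl_coef=100.0.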

    # Every rollout starts from a batch of BOS tokens.
    batch_size = 4
    dummy_input = tokenizer([tokenizer.bos_token] * batch_size, return_tensors="pt", padding=True)
    input_ids = dummy_input.input_ids.to(device)

    total_steps = 2500
    # Checkpoint at 25%, 50%, 75% and 100% of training.
    checkpoint_steps = {total_steps // 4, total_steps // 2, 3 * total_steps // 4, total_steps}
    checkpoint_dir = "./ppo_checkpoints_test"
    os.makedirs(checkpoint_dir, exist_ok=True)

    for step in tqdm(range(total_steps), desc="RL Training"):
        global_step = step
        max_new_tokens = curriculum.get_max_new_tokens()

        # First rollout (treated as the "old" policy), without gradients: its
        # log-probs and action probabilities are the PPO reference.
        with torch.no_grad():
            selfies_list, old_log_probs, _, old_action_probs = model.generate_with_logprobs(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                return_probs=True,
            )
        old_log_probs = old_log_probs.detach()
        old_action_probs = old_action_probs.detach()

        # Second rollout with gradients enabled; these log-probs are the ones the
        # optimizer updates.
        selfies_list, new_log_probs, token_ids, new_action_probs = model.generate_with_logprobs(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            return_probs=True,
            tokenizer=tokenizer,
        )

        # Chemistry-aware rewards for the generated SELFIES strings.
        rewards_dict = batch_compute_rewards(
            selfies_list=selfies_list,
            reward_mode="chemq3",
        )
        rewards = rewards_dict["total_rewards"].to(device)
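
        # Note: batch_compute_rewards returns a dict; "total_rewards" drives the policy
        # update, and an optional "sa_rewards" entry (presumably synthetic accessibility)
        # is read in the logging block further below.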

        # Clipped PPO surrogate loss (clip_epsilon=0.2) against the running baseline.
        ppo_loss, advantage = compute_ppo_loss(
            old_log_probs=old_log_probs,
            new_log_probs=new_log_probs,
            rewards=rewards,
            clip_epsilon=0.2,
            baseline=baseline,
        )
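
        # For reference, compute_ppo_loss is expected to implement the standard clipped
        # surrogate objective (exact internals live in ChemQ3MTP.rl_utils):
        #     ratio = exp(new_log_probs - old_log_probs)
        #     loss  = -mean(min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A))
        # with eps = clip_epsilon and advantage A derived from rewards and the baseline.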

        # KL divergence between the sampling policy and the updated policy.
        kl_div = compute_kl_divergence(old_action_probs, new_action_probs)
        kl_mean = kl_div.mean().item()

        # Adapt beta toward the target KL, then build the (capped) penalty term.
        kl_controller.update(kl_mean, n_steps=global_step)
        beta = kl_controller()

        kl_penalty, raw_kl_penalty, kl_mean_tensor = compute_kl_penalty(
            kl_div, beta, kl_controller.kl_penalty_cap
        )
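
        # kl_penalty is the beta-weighted KL term added to the loss; raw_kl_penalty is
        # presumably the uncapped value, with kl_penalty_cap bounding the final penalty
        # so a single divergent batch cannot dominate the update.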

        logs = {
            "kl_mean": kl_mean_tensor.item(),
            "kl_beta": beta,
            "kl_penalty_raw": raw_kl_penalty.item(),
            "kl_penalty_clipped": kl_penalty.item(),
        }

        # Entropy bonus with an adaptive weight to keep exploration from collapsing.
        entropy_per_example = compute_entropy_bonus(new_action_probs)
        entropy = entropy_per_example.mean()

        adaptive_entropy_weight = model.entropy_controller.update_entropy_weight(entropy.item())
        entropy_bonus = adaptive_entropy_weight * entropy

        # Total objective: PPO loss + KL penalty - entropy bonus + tiny L2 regularization.
        total_policy_loss = ppo_loss + kl_penalty
        total_loss = total_policy_loss - entropy_bonus

        reg_loss = 1e-7 * sum(p.pow(2).sum() for p in model.parameters())
        total_loss = total_loss + reg_loss

        optimizer.zero_grad(set_to_none=True)
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update the exponential-moving-average reward baseline.
        reward_tensor = rewards.mean()
        baseline = reward_tensor if baseline is None else gamma * baseline + (1 - gamma) * reward_tensor
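
        # With gamma = 0.95 the baseline averages over roughly the last
        # 1 / (1 - gamma) = 20 batches of mean reward, which keeps the advantage
        # estimate from chasing per-batch noise.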

        curriculum.step()

        # Periodically save the model, tokenizer and training state.
        if (step + 1) in checkpoint_steps:
            checkpoint_path = os.path.join(checkpoint_dir, f"model_step_{step + 1}")
            model.save_pretrained(checkpoint_path)
            tokenizer.save_pretrained(checkpoint_path)
            torch.save({
                'step': step + 1,
                'optimizer_state_dict': optimizer.state_dict(),
                'baseline': baseline.item(),
                'curriculum_state': {
                    'current_max_len': curriculum.current_max_len,
                    'step_counter': curriculum.step_counter,
                },
            }, os.path.join(checkpoint_path, 'training_state.pt'))
            print(f"\n💾 Checkpoint saved at step {step + 1} -> {checkpoint_path}")

        # Lightweight monitoring on the first 10 samples every 50 steps.
        if step % 50 == 0:
            validity_count = 0
            for selfies in selfies_list[:10]:
                smiles = selfies_to_smiles(selfies)
                if smiles:
                    validity_count += 1
            validity_rate = validity_count / max(1, min(10, len(selfies_list)))

            # Lipinski rule-of-five score for the decodable molecules.
            lipinski_scores = []
            for selfies in selfies_list[:10]:
                smiles = selfies_to_smiles(selfies)
                mol = Chem.MolFromSmiles(smiles) if smiles else None
                if mol:
                    mw = Descriptors.MolWt(mol)
                    logp = Descriptors.MolLogP(mol)
                    hbd = Lipinski.NumHDonors(mol)
                    hba = Lipinski.NumHAcceptors(mol)
                    rules = [250 < mw <= 500, logp <= 5, hbd <= 5, hba <= 10]
                    lipinski_scores.append(sum(rules) / 4.0)
            lipinski_score = np.mean(lipinski_scores) if lipinski_scores else 0.0
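
            # Note: the classic Lipinski rule of five uses MW <= 500 with no lower bound;
            # the 250 Da floor here additionally filters out very small fragments. Each
            # molecule is scored as the fraction of the four rules it satisfies.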

            avg_sa_reward = rewards_dict["sa_rewards"].mean().item() if "sa_rewards" in rewards_dict else rewards.mean().item()

            log_line = (
                f"\n[RL Step {step}] "
                f"Loss={total_loss.item():.4f} | "
                f"Valid={validity_rate:.3f} | "
                f"Lipinski={lipinski_score:.3f} | "
                f"Reward={rewards.mean().item():.3f} | "
                f"Entropy={entropy.item():.3f} | "
                f"EntropyW={adaptive_entropy_weight:.4f} | "
                f"KL_Beta={beta:.4f} | "
                f"KL_Mean={kl_mean:.4f}"
            )
            if avg_sa_reward is not None:
                log_line += f" | SA={avg_sa_reward:.3f}"
            print(log_line)

            # Print a truncated sample so progress can be eyeballed in the console.
            sample_selfies = selfies_list[0][:100]
            sample_smiles = selfies_to_smiles(selfies_list[0]) or "Invalid"
            print(f" Sample SELFIES: {sample_selfies}")
            print(f" Sample SMILES: {sample_smiles}")

    print("🎉 Training complete!")


if __name__ == "__main__":
    main()