|
|
|
|
|
|
| import os
|
| import torch
|
| from tqdm import tqdm
|
| from FastChemTokenizerHF import FastChemTokenizerSelfies
|
| from ChemQ3MTP import ChemQ3MTP, CurriculumManager
|
|
|
|
|
def _save_checkpoint(model, tokenizer, optimizer, curriculum, baseline,
                     completed_steps, checkpoint_dir):
    """Persist model, tokenizer and training state after ``completed_steps``.

    The model and tokenizer are written with their ``save_pretrained``
    methods into ``<checkpoint_dir>/model_step_<completed_steps>``; the
    optimizer state, the scalar reward baseline and the curriculum counters
    are stored next to them in ``training_state.pt`` via ``torch.save``.
    """
    target_dir = os.path.join(checkpoint_dir, f"model_step_{completed_steps}")
    model.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)

    training_state = {
        'step': completed_steps,
        'optimizer_state_dict': optimizer.state_dict(),
        # Stored as a plain Python number rather than a tensor.
        'baseline': baseline.item(),
        'curriculum_state': {
            'current_max_len': curriculum.current_max_len,
            'step_counter': curriculum.step_counter,
        },
    }
    torch.save(training_state, os.path.join(target_dir, 'training_state.pt'))
    print(f"\nπΎ Checkpoint saved at step {completed_steps} -> {target_dir}")


def _report_progress(step, loss, ppo_result):
    """Print one metrics line plus the first generated sample for ``step``."""
    header = f"\n[RL Step {step}] "
    metrics = [
        f"Loss={loss.item():.4f}",
        f"Valid={ppo_result['validity_rate']:.3f}",
        f"Lipinski={ppo_result['lipinski_score']:.3f}",
        f"Reward={ppo_result['avg_reward']:.3f}",
        f"Entropy={ppo_result['entropy']:.3f}",
        f"EntropyW={ppo_result['entropy_weight']:.4f}",
    ]
    # The SA metric is only reported when the PPO step supplies it.
    sa_reward = ppo_result.get("avg_sa_reward")
    if sa_reward is not None:
        metrics.append(f"SA={sa_reward:.3f}")
    print(header + " | ".join(metrics))

    # Show the first sequence of the batch; SELFIES is truncated to 100
    # characters and a falsy SMILES entry is displayed as "Invalid".
    first_selfies = ppo_result['generated_selfies'][0][:100]
    first_smiles = ppo_result['generated_smiles'][0] or "Invalid"
    print(f" Sample SELFIES: {first_selfies}")
    print(f" Sample SMILES: {first_smiles}")


def main():
    """Run the PPO fine-tuning loop on a pretrained ChemQ3MTP checkpoint.

    Steps performed:

    1. Load the tokenizer from ``../selftok_core`` and the model from
       ``./model_step_7000``, attach the tokenizer to the model and move the
       model to CUDA when available (CPU otherwise).
    2. Call ``model.set_mtp_training(False)`` and create an AdamW optimizer
       (lr ``5e-6``) plus a ``CurriculumManager`` that supplies the
       ``max_new_tokens`` value used at every step.
    3. For 1000 steps: sample a rollout without gradients, call
       ``model.ppo_step`` with the rollout's log-probs/probs, back-propagate
       the returned loss with gradient-norm clipping at 1.0, update an
       exponential moving average of the average reward (decay 0.95) that is
       passed back as ``baseline``, and advance the curriculum.
    4. Save a checkpoint under ``./ppo_checkpoints_test`` after each quarter
       of the run and print metrics every 50 steps.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"π Using device: {device}")

    # --- Tokenizer and pretrained model ---
    tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")
    model = ChemQ3MTP.from_pretrained("./model_step_7000")
    model.tokenizer = tokenizer
    model.to(device)

    # --- RL setup ---
    print("\nπ― Phase 2: RL Fine-tuning with PPO + Curriculum Learning")
    model.set_mtp_training(False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    curriculum = CurriculumManager(
        start_len=10,
        max_len=35,
        step_increase=5,
        steps_per_level=70,
    )
    baseline = None  # becomes a tensor after the first PPO step
    gamma = 0.95     # decay of the moving-average reward baseline

    # --- Prompt batch: one BOS token per sequence ---
    batch_size = 4
    bos_prompts = [tokenizer.bos_token] * batch_size
    encoded = tokenizer(bos_prompts, return_tensors="pt", padding=True)
    input_ids = encoded.input_ids.to(device)

    # --- Schedule and checkpoint locations ---
    total_steps = 1000
    # Quarter marks of the run: total/4, total/2, 3*total/4 and total.
    checkpoint_steps = {(k * total_steps) // 4 for k in range(1, 5)}
    checkpoint_dir = "./ppo_checkpoints_test"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Sampling settings are identical for every rollout.
    sampling_kwargs = {
        "temperature": 1.0,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": True,
        "return_probs": True,
    }

    for step in tqdm(range(total_steps), desc="RL Training"):
        completed_steps = step + 1
        max_new_tokens = curriculum.get_max_new_tokens()

        # Rollout: no gradients are tracked while sampling.
        with torch.no_grad():
            _, old_log_probs, _, old_action_probs = model.generate_with_logprobs(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                **sampling_kwargs,
            )
            old_log_probs = old_log_probs.detach()
            old_action_probs = old_action_probs.detach()

        # PPO step; the rollout statistics are passed in as fixed tensors.
        ppo_result = model.ppo_step(
            input_ids=input_ids,
            old_log_probs=old_log_probs,
            old_action_probs=old_action_probs,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            entropy_weight=0.01,
            clip_epsilon=0.2,
            baseline=baseline,
            reward_mode="sa",
        )

        # Optimizer update with gradient-norm clipping.
        loss = ppo_result['loss']
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Exponential moving average of the average reward.
        reward_tensor = torch.tensor(ppo_result['avg_reward'], device=device)
        if baseline is None:
            baseline = reward_tensor
        else:
            baseline = gamma * baseline + (1 - gamma) * reward_tensor

        curriculum.step()

        if completed_steps in checkpoint_steps:
            _save_checkpoint(
                model, tokenizer, optimizer, curriculum, baseline,
                completed_steps, checkpoint_dir,
            )

        if step % 50 == 0:
            _report_progress(step, loss, ppo_result)

    print("π Training complete!")
|
|
|
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|
|