| import os |
| import re |
|
|
| import pathlib |
| from argparse import ArgumentParser |
| from typing import List, Dict, Optional |
| from dataclasses import dataclass, field |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader, Dataset |
| from transformers import get_cosine_schedule_with_warmup, AutoTokenizer |
|
|
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| AutoModelForMaskedLM, |
| AutoProcessor, |
| ) |
|
|
| from datasets import load_dataset, DatasetDict |
|
|
| from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training |
| from transformers import BitsAndBytesConfig |
|
|
| import pytorch_lightning as pl |
| from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
| from pytorch_lightning.loggers import WandbLogger |
|
|
| from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config |
| |
|
|
| from bioreason.models.dna_llm import DNALLMModel |
| from bioreason.dna_modules import NucleotideDNAModule |
| from bioreason.models.dl.processing_dl import DLProcessor |
| from bioreason.trainer import DNALLMGRPOTrainer, DNALLMGRPOConfig |
| from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer |
| register_evo2_tokenizer() |
|
|
| |
| from transformers import TrainerCallback, TrainerState, TrainerControl |
| from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR |
|
|
class SaveWithPyTorchCallback(TrainerCallback):
    """Custom callback to save models with PyTorch's native save mechanism instead of safetensors.

    On every ``on_save`` event it writes ``torch.save(model.state_dict())`` to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/pytorch_model.bin``.
    If the model has a ``text_model`` attribute, it also writes that text
    model's config (or its ``base_model`` config) into the same folder.
    """

    def on_save(self, args, state, control, **kwargs):
        """Write ``pytorch_model.bin`` (and the text-model config) for the current step.

        Args:
            args: training arguments; only ``args.output_dir`` is read.
            state: trainer state; ``global_step`` names the folder and
                ``is_world_process_zero`` (if present) gates the write.
            control: trainer control object; ``should_save`` is reset to False.
            **kwargs: must contain ``model`` for anything to be written.

        Returns:
            The (mutated) ``control`` object.
        """
        model = kwargs.get("model")
        # Only one process writes. Several ranks calling torch.save on the same
        # path at once can leave a truncated or corrupted file. With no model
        # there is nothing to save.
        if model is None or not getattr(state, "is_world_process_zero", True):
            control.should_save = False
            return control

        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")

        # Unwrap DataParallel/DDP-style wrappers that expose the real model as `.module`.
        unwrapped_model = model.module if hasattr(model, "module") else model

        # Build the state dict once; it is reused below for the LoRA key count.
        state_dict = unwrapped_model.state_dict()
        torch.save(state_dict, checkpoint_path)

        # Save the text model's config next to the weights. If the object has no
        # `config` of its own, fall back to `base_model.config` (for a PEFT wrapper).
        text_model = getattr(unwrapped_model, "text_model", None)
        if text_model is not None:
            if hasattr(text_model, "config"):
                text_model.config.save_pretrained(checkpoint_folder)
            elif hasattr(text_model, "base_model") and hasattr(text_model.base_model, "config"):
                text_model.base_model.config.save_pretrained(checkpoint_folder)

        print(f"Saved model checkpoint to {checkpoint_folder}")
        lora_params = [k for k in state_dict if "lora" in k]
        print(f"Checkpoint contains {len(lora_params)} LoRA parameters")

        control.should_save = False
        return control
|
|
| def _get_target_modules(model: DNALLMModel): |
| |
| target_modules = [] |
|
|
| |
| seen_names = set() |
| for name, module in model.text.named_modules(): |
| if isinstance(module, torch.nn.Linear): |
| names = name.split(".") |
| target_name = names[-1] |
|
|
| |
| if target_name != "lm_head" and target_name not in seen_names: |
| target_modules.append(target_name) |
| seen_names.add(target_name) |
|
|
| |
| attention_patterns = [ |
| "q_proj", |
| "k_proj", |
| "v_proj", |
| "out_proj", |
| "query", |
| "key", |
| "value", |
| ] |
| for pattern in attention_patterns: |
| if pattern not in seen_names: |
| target_modules.append(pattern) |
|
|
| |
| return list(target_modules) |
|
|
|
|
def extract_xml_answer(text: str) -> str:
    """Return whatever follows the last ``</think>`` tag, with surrounding whitespace removed.

    When the tag does not occur, the whole input is returned (stripped).
    """
    _, _, tail = text.rpartition("</think>")
    return tail.strip()
|
|
| def extract_hash_answer(text: str) -> str | None: |
| if "####" not in text: |
| return None |
| return text.split("####")[1].strip() |
|
|
def get_kegg_questions(dataset_name: str = "wanglab/kegg") -> "DatasetDict":
    """Load the KEGG QA dataset and add chat-style ``prompt``/``dna_sequences``/``answer`` columns.

    Each example becomes a single user turn. The turn holds one
    ``{'type': 'dna', 'text': None}`` placeholder per DNA sequence, followed by
    the question text. ``dna_sequences`` lists ``reference_sequence`` then
    ``variant_sequence``, in the same order as the placeholders. Presumably the
    DNA module pairs them positionally; confirm in the DNA module.

    Args:
        dataset_name: Hugging Face dataset id loaded with config ``"default"``.

    Returns:
        The mapped dataset as returned by ``load_dataset(...).map(...)``.
    """
    data = load_dataset(dataset_name, "default")

    def _to_prompt(example):
        dna_sequences = [example["reference_sequence"], example["variant_sequence"]]
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        # One placeholder per sequence keeps the counts in sync.
                        *({"type": "dna", "text": None} for _ in dna_sequences),
                        {"type": "text", "text": example["question"]},
                    ],
                },
            ],
            "dna_sequences": dna_sequences,
            "answer": example["answer"],
        }

    return data.map(_to_prompt)
|
|
| |
def get_gsm8k_questions(question_prompt: str) -> Dataset:
    """Load GSM8K and give every row the same fixed DNA+text chat prompt.

    ``question_prompt`` is accepted but not used. Every row gets three example
    DNA placeholders, the same three example sequences and a fixed placeholder
    text. Only ``answer`` varies per row: it is the part of the original answer
    after ``####`` (see ``extract_hash_answer``).
    """
    data = load_dataset('openai/gsm8k', 'main')

    demo_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]

    def _format_row(row):
        content = [{'type': 'dna', 'text': None} for _ in demo_sequences]
        content.append({'type': 'text', 'text': 'Give me a short introduction to large language model.'})
        return {
            'prompt': [{'role': 'user', 'content': content}],
            'dna_sequences': list(demo_sequences),
            'answer': extract_hash_answer(row['answer']),
        }

    data = data.map(_format_row)
    return data
|
|
def get_gsm8k_questions_old(question_prompt: str) -> Dataset:
    """Load GSM8K and build a DNA+text chat prompt from each row's question.

    The text part is ``question_prompt.format(Question=<row question>)``. It is
    preceded by three DNA placeholders matching three fixed example sequences.
    ``answer`` is the part of the original answer after ``####``.
    """
    data = load_dataset('openai/gsm8k', 'main')

    demo_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]

    def _format_row(row):
        content = [{'type': 'dna', 'text': None} for _ in demo_sequences]
        content.append({'type': 'text', 'text': question_prompt.format(Question=row['question'])})
        return {
            'prompt': [{'role': 'user', 'content': content}],
            'dna_sequences': list(demo_sequences),
            'answer': extract_hash_answer(row['answer']),
        }

    data = data.map(_format_row)
    return data
|
|
| |
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 per completion whose extracted answer contains its reference answer.

    The comparison is a case-insensitive substring test on the text after the
    last ``</think>`` (see ``extract_xml_answer``).

    Args:
        prompts: chat prompts; only ``prompts[0][-1]['content']`` is read, for logging.
        completions: one entry per completion; ``completion[0]['content']`` is the text.
        answer: reference answers, expected to be aligned one-to-one with
            ``completions``. TRL passes extra dataset columns to reward
            functions as per-completion lists.

    Returns:
        One float per completion: 2.0 on a match, else 0.0. A ``None``
        reference answer scores 0.0.
    """
    responses = [completion[0]['content'] for completion in completions]
    if not responses:
        return []
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]

    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    # Pair each response with its own reference answer. The previous
    # zip(..., answer[0]) iterated over the characters of the first answer string.
    return [
        2.0 if a is not None and a.lower() in r.lower() else 0.0
        for r, a in zip(extracted_responses, answer)
    ]
|
|
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the post-``</think>`` answer has at most four space-separated pieces.

    Pieces are counted with ``str.split(' ')``, so consecutive spaces count as
    extra (empty) pieces. Other completions score 0.0.
    """
    rewards = []
    for completion in completions:
        final_text = extract_xml_answer(completion[0]['content'])
        piece_count = len(final_text.split(' '))
        rewards.append(0.5 if piece_count <= 4 else 0.0)
    return rewards
|
|
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format.

    The whole completion must be: ``<think>`` and a newline, the reasoning, a
    newline, ``</think>`` and a newline, the answer, then a trailing newline.

    ``re.DOTALL`` lets ``.`` match newlines, so multi-line reasoning and answers
    are accepted. Without it, only single-line bodies could ever match.

    Returns:
        0.5 per matching completion, else 0.0.
    """
    pattern = re.compile(r"^<think>\n.*?\n</think>\n.*?\n$", re.DOTALL)
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if pattern.match(r) else 0.0 for r in responses]
|
|
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format.

    The completion must start with a ``<think>...</think>`` block. Anything may
    follow it. ``re.DOTALL`` lets the reasoning inside the block span several
    lines.

    Returns:
        0.5 per matching completion, else 0.0.
    """
    pattern = re.compile(r"<think>.*?</think>\s*.*?", re.DOTALL)
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if pattern.match(r) else 0.0 for r in responses]
|
|
def count_xml(text) -> float:
    """Score think-tag layout: 0.125 each for exactly one ``<think>\\n`` and exactly one ``\\n</think>\\n``.

    The result is one of 0.0, 0.125 or 0.25.
    """
    has_single_open = text.count("<think>\n") == 1
    has_single_close = text.count("\n</think>\n") == 1
    return 0.125 * has_single_open + 0.125 * has_single_close
|
|
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_xml`` to the text of each completion (``completion[0]["content"]``)."""
    scores = []
    for completion in completions:
        scores.append(count_xml(completion[0]["content"]))
    return scores
|
|
| |
| def make_conversation(example): |
| return { |
| "prompt": [ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": example["problem"]}, |
| ], |
| } |
|
|
def make_conversation_image(example):
    """Return a prompt holding one user turn with a single image placeholder.

    ``example`` is accepted for ``map``-style use but is not read.
    """
    image_turn = {"role": "user", "content": [{"type": "image"}]}
    return {"prompt": [image_turn]}
|
|
@dataclass
class GRPOModelConfig(ModelConfig):
    """Model and LoRA options for GRPO training, extending TRL's ``ModelConfig``.

    ``main`` in this script reads ``dna_model_name_or_path``, ``max_length_dna``
    and ``freeze_dna_modules``. Previously they were not declared, so ``main``
    failed with AttributeError. They are appended after the existing fields so
    the positional field order is unchanged.

    The protein / Q-Former / projector fields are kept for CLI backward
    compatibility. Nothing in this script reads them.
    """

    model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for LLM weights initialization."})
    protein_model_name_or_path: str = field(default="esm2_t33_650M_UR50D", metadata={"help": "Model checkpoint for ESM-2 protein weights initialization."})
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Path to model cache directory."})
    max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."})
    max_length_protein: int = field(default=800, metadata={"help": "Maximum length of protein sequences (number of amino acids)."})
    sft_checkpoint: Optional[str] = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    lora_r: int = field(default=32, metadata={"help": "LoRA R value."})
    lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."})
    lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
    lora_modules_to_save: Optional[list[str]] = field(
        default_factory=lambda: ["embed_tokens", "lm_head"],
        metadata={"help": "Model layers to unfreeze & train with LoRA."},
    )

    freeze_protein_model: bool = field(default=True, metadata={"help": "Whether to freeze the ESM-2 protein model during training."})
    num_query_tokens: int = field(default=32, metadata={"help": "The number of query tokens used by the Q-Former to summarize protein features. These tokens will be injected into the LLM input."})

    projector_hidden_size: int = field(default=1280, metadata={"help": "Hidden size of the projector layer. It should match the ESM-2's output hidden size."})
    projector_output_size: int = field(default=1024, metadata={"help": "Output size of the projector layer. It should match the LLM's hidden size."})

    freeze_projector: bool = field(default=False, metadata={"help": "Whether to freeze the projector layer during training."})

    # DNA encoder options consumed by main(). NOTE(review): the defaults below
    # are placeholders mirroring the text/protein settings — confirm against the
    # intended DNA encoder before relying on them.
    dna_model_name_or_path: str = field(
        default="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species",
        metadata={"help": "Model checkpoint for DNA encoder weights initialization."},
    )
    max_length_dna: int = field(default=800, metadata={"help": "Maximum length of DNA sequences."})
    freeze_dna_modules: bool = field(default=True, metadata={"help": "Whether to freeze the DNA encoder during training."})
|
|
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    ``main`` reads only ``reward_funcs``. Each entry must be a key of
    ``reward_funcs_registry``; an unknown name raises KeyError when ``main``
    builds the reward list.
    """
    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    data_file_paths: Optional[str] = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    arrow_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        default_factory=lambda: ["xmlcount", "soft_format", "strict_format", "less_than_4", "correctness"],
        # 'accuracy' was previously listed here but has no registry entry.
        metadata={"help": "List of reward functions. Possible values: 'xmlcount', 'soft_format', 'strict_format', 'less_than_4', 'correctness'"},
    )
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
# Lookup table from the names accepted in `--reward_funcs` to the reward callables
# that `main` hands to the trainer, in the order given on the command line.
reward_funcs_registry = dict(
    xmlcount=xmlcount_reward_func,
    soft_format=soft_format_reward_func,
    strict_format=strict_format_reward_func,
    less_than_4=less_than_4_reward_func,
    correctness=correctness_reward_func,
)
|
|
def get_vlm_module(model_name_or_path):
    """Pick the DNA module class for a text-model checkpoint name.

    Names containing ``qwen`` or ``smol`` (case-insensitive) map to
    ``NucleotideDNAModule``. Any other name raises ``ValueError``.
    """
    lowered = model_name_or_path.lower()
    supported_families = ("qwen", "smol")
    if not any(family in lowered for family in supported_families):
        raise ValueError(f"Unsupported model: {model_name_or_path}")
    return NucleotideDNAModule
| |
| def _get_target_modules(model): |
| |
| target_modules = [] |
|
|
| |
| seen_names = set() |
| for name, module in model.text_model.named_modules(): |
| if isinstance(module, torch.nn.Linear): |
| names = name.split(".") |
| target_name = names[-1] |
|
|
| |
| if target_name != "lm_head" and target_name not in seen_names: |
| target_modules.append(target_name) |
| seen_names.add(target_name) |
|
|
| |
| attention_patterns = [ |
| "q_proj", |
| "k_proj", |
| "v_proj", |
| "out_proj", |
| "query", |
| "key", |
| "value", |
| ] |
| for pattern in attention_patterns: |
| if pattern not in seen_names: |
| target_modules.append(pattern) |
|
|
| |
| return list(target_modules) |
|
|
|
|
def _prep_for_training(model, training_args, dna_model_finetune: bool = False) -> LoraConfig:
    """Decide which parts of ``model`` train and wrap its text model with LoRA.

    Steps, in order:
      1. Unless ``dna_model_finetune`` is True, set ``requires_grad = False`` on
         every parameter of ``model.dna_model``.
      2. Build a ``LoraConfig`` from ``training_args.lora_r`` / ``lora_alpha`` /
         ``lora_dropout``, targeting the names from ``_get_target_modules``.
      3. Pass ``model.text_model`` through ``prepare_model_for_kbit_training``,
         then ``get_peft_model``, storing each result on ``model.text_model``.
      4. Set ``requires_grad = True`` on every parameter of ``model.dna_projection``.

    Args:
        model: object exposing ``dna_model``, ``text_model`` and ``dna_projection``.
        training_args: object with ``lora_r``, ``lora_alpha`` and ``lora_dropout``
            (``main`` passes its model config here).
        dna_model_finetune: when True, leave the DNA encoder's ``requires_grad``
            flags untouched.

    Returns:
        The ``LoraConfig`` that was applied.
    """
    if not dna_model_finetune:
        for frozen_param in model.dna_model.parameters():
            frozen_param.requires_grad = False

    lora_targets = _get_target_modules(model)

    lora_config = LoraConfig(
        task_type="CAUSAL_LM",
        bias="none",
        init_lora_weights="gaussian",
        target_modules=lora_targets,
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
    )

    model.text_model = prepare_model_for_kbit_training(model.text_model)
    model.text_model = get_peft_model(model.text_model, lora_config)

    # The projection between DNA and text embeddings is always trained.
    for projection_param in model.dna_projection.parameters():
        projection_param.requires_grad = True

    return lora_config
|
|
def _strip_checkpoint_prefix(key):
    """Remove a leading ``model.`` or ``_forward_module.`` wrapper prefix from a state-dict key."""
    # The old check tested for "=model.", a typo, so "model." was never stripped.
    for prefix in ("model.", "_forward_module."):
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


def _extract_state_dict(checkpoint, checkpoint_path):
    """Return the weights stored in ``checkpoint`` with wrapper prefixes stripped from the keys.

    Weights are taken from ``checkpoint["state_dict"]``, ``checkpoint["module"]``,
    or the checkpoint itself when it is a dict with only string keys. Anything
    else raises ``ValueError``.
    """
    if "state_dict" in checkpoint:
        raw = checkpoint["state_dict"]
    elif "module" in checkpoint:
        raw = checkpoint["module"]
    elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
        print("Detected direct state dict format")
        raw = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
    return {_strip_checkpoint_prefix(k): v for k, v in raw.items()}


def _align_keys_with_model(state_dict, model_keys):
    """Rename keys between ``text_model.`` and ``text_model.base_model.model.`` layouts to match ``model_keys``.

    A key that already exists in ``model_keys``, or that matches in neither
    layout, is kept unchanged.
    """
    aligned = {}
    for k, v in state_dict.items():
        if "base_model.model" in k and k not in model_keys:
            new_k = k.replace("text_model.base_model.model", "text_model")
            if new_k in model_keys:
                aligned[new_k] = v
                continue
        if k.startswith("text_model.") and k not in model_keys:
            new_k = "text_model.base_model.model." + k[len("text_model."):]
            if new_k in model_keys:
                aligned[new_k] = v
                continue
        aligned[k] = v
    return aligned


def _load_sft_checkpoint(model, model_args, dna_model_finetune):
    """Load ``model_args.sft_checkpoint`` into ``model`` and attach LoRA via ``_prep_for_training`` exactly once."""
    checkpoint_path = model_args.sft_checkpoint
    print(f"Loading SFT checkpoint from {checkpoint_path}")

    if os.path.isdir(checkpoint_path):
        from peft import PeftModel

        print("Loading as PEFT checkpoint directory")
        model.text_model = PeftModel.from_pretrained(
            model.text_model,
            checkpoint_path,
            is_trainable=True,
        )
        print("Loaded LoRA adapters:", model.text_model.active_adapter)

        print("Merging SFT LoRA weights into base model...")
        model.text_model = model.text_model.merge_and_unload()
        print("Successfully merged SFT knowledge into base model")
        # After merge_and_unload no adapters remain; attach fresh ones for GRPO.
        _prep_for_training(model, model_args, dna_model_finetune=dna_model_finetune)
        return

    print("Loading as PyTorch state dict file")
    # map_location="cpu" avoids placing every tensor on the device it was saved from.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    magic = _extract_state_dict(checkpoint, checkpoint_path)

    if any("lora" in key for key in magic):
        print("Detected LoRA weights in state dict")
        # Wrap with PEFT first so the checkpoint's LoRA keys have matching parameters.
        _prep_for_training(model, model_args, dna_model_finetune=dna_model_finetune)

        model_keys = set(model.state_dict().keys())
        print(f"Model has {len(model_keys)} keys")
        print(f"Checkpoint has {len(magic)} keys")

        magic = _align_keys_with_model(magic, model_keys)
        print(f"After key mapping: {len(magic)} keys")

        result = model.load_state_dict(magic, strict=False)
        if len(result.unexpected_keys) > 0:
            print(f"Sample unexpected keys: {result.unexpected_keys[:5]}")
        if len(result.missing_keys) > 0:
            print(f"Sample missing keys: {result.missing_keys[:5]}")
        print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
        return

    print("Standard weights detected - remapping keys")
    # PEFT is not applied yet, so the model still uses the plain `text_model.` layout.
    # Previously every key was forced into the `text_model.base_model.model.` layout,
    # so the text weights were silently reported as unexpected and never loaded.
    model_keys = set(model.state_dict().keys())
    magic = _align_keys_with_model(magic, model_keys)

    # Give lm_head its own copy of the tensor. Presumably this avoids sharing
    # storage with tied embedding weights; confirm.
    for key in list(magic.keys()):
        if 'lm_head.weight' in key:
            magic[key] = magic[key].clone()

    result = model.load_state_dict(magic, strict=False)
    print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")

    _prep_for_training(model, model_args, dna_model_finetune=dna_model_finetune)


def main(script_args, training_args, model_args):
    """Build the DNA-LLM, optionally load an SFT checkpoint, and run GRPO training.

    Args:
        script_args: parsed ``GRPOScriptArguments``; only ``reward_funcs`` is read.
        training_args: trainer configuration passed to ``DNALLMGRPOTrainer``.
        model_args: parsed ``GRPOModelConfig`` with model, LoRA and checkpoint options.
    """
    print(training_args.output_dir)

    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Freezing the DNA encoder means *not* fine-tuning it. The constructor and
    # _prep_for_training both take a "finetune" flag, so both get the negation.
    # Previously _prep_for_training received freeze_dna_modules directly (inverted).
    dna_model_finetune = not model_args.freeze_dna_modules

    model = DNALLMModel(
        text_model_name=model_args.model_name_or_path,
        dna_model_name=model_args.dna_model_name_or_path,
        cache_dir=model_args.cache_dir,
        max_length_text=model_args.max_length_text,
        max_length_dna=model_args.max_length_dna,
        text_model_finetune=True,
        dna_model_finetune=dna_model_finetune,
        debug=False,
    )

    if model_args.sft_checkpoint is not None:
        _load_sft_checkpoint(model, model_args, dna_model_finetune)
    else:
        _prep_for_training(model, model_args, dna_model_finetune=dna_model_finetune)

    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)

    dataset = get_kegg_questions()
    print(dataset)

    custom_save_callback = SaveWithPyTorchCallback()

    trainer = DNALLMGRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        callbacks=[custom_save_callback],
    )

    # Disable safetensors serialization for Trainer checkpoints.
    training_args.save_safetensors = False

    trainer.train()
|
|
|
|
if __name__ == "__main__":
    # Log which GPUs are visible to this process.
    visible_gpus = os.environ.get("CUDA_VISIBLE_DEVICES")
    print(f"CUDA_VISIBLE_DEVICES: {visible_gpus}")

    cli_parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, GRPOModelConfig))
    parsed_configs = cli_parser.parse_args_and_config()
    script_args, training_args, model_args = parsed_configs

    # Disable safetensors serialization for Trainer checkpoints.
    training_args.save_safetensors = False

    main(script_args, training_args, model_args)
| |
| |
| |
|
|
| |
|
|