"""GRPO fine-tuning script for a protein-sequence LLM (ProteinLLMModel).

Loads the ``wanglab/kegg`` dataset, wraps each example as a chat prompt with
protein-sequence placeholders, optionally restores an SFT checkpoint, applies
LoRA to the text model, and trains with ``DNALLMGRPOTrainer`` using a
configurable registry of reward functions.

NOTE(review): several XML-style tag literals (``<answer>``, ``<reasoning>``,
``<think>``) were emptied out of this file by an earlier export step
(``text.split("")`` would raise ``ValueError``). They have been reconstructed
below; every such site is flagged with a NOTE(review) comment — confirm the
exact tags against the prompts used during SFT.
"""

import os
import pathlib
import re
from argparse import ArgumentParser
from dataclasses import dataclass, field
from itertools import islice, zip_longest  # required by repeatness_reward
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from datasets import DatasetDict, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    get_cosine_schedule_with_warmup,
)
# Custom TrainerCallback machinery used to override the saving mechanism.
from transformers import TrainerCallback, TrainerControl, TrainerState
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import (
    GRPOConfig,
    GRPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_peft_config,
)
# from unsloth import FastLanguageModel, is_bfloat16_supported

from bioreason.dna_modules import NucleotideDNAModule
from bioreason.models.dl.processing_dl import DLProcessor
from bioreason.models.dna_llm import DNALLMModel
from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer
from bioreason.models.protein_llm import ProteinLLMModel
from bioreason.trainer import DNALLMGRPOConfig, DNALLMGRPOTrainer

register_evo2_tokenizer()


class SaveWithPyTorchCallback(TrainerCallback):
    """Save checkpoints with ``torch.save`` instead of safetensors.

    The composite Protein/DNA-LLM models wrap several sub-models and have no
    single top-level ``config``, so the default safetensors save path does not
    apply cleanly; this callback writes the full state dict directly and then
    tells the trainer that saving has already been handled.
    """

    def on_save(self, args, state, control, **kwargs):
        # Checkpoint folder follows the HF convention: "checkpoint-<step>".
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)

        # Save with PyTorch instead of safetensors.
        checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        model = kwargs.get("model")
        # Unwrap from accelerator/DDP wrappers if present.
        unwrapped_model = model.module if hasattr(model, "module") else model
        torch.save(unwrapped_model.state_dict(), checkpoint_path)

        # The composite model has no direct config attribute, so persist the
        # text sub-model's config so the checkpoint can be re-instantiated.
        if hasattr(unwrapped_model, "text_model"):
            if hasattr(unwrapped_model.text_model, "config"):
                unwrapped_model.text_model.config.save_pretrained(checkpoint_folder)
            # PEFT-wrapped models keep the config on base_model.
            elif hasattr(unwrapped_model.text_model, "base_model") and hasattr(
                unwrapped_model.text_model.base_model, "config"
            ):
                unwrapped_model.text_model.base_model.config.save_pretrained(
                    checkpoint_folder
                )

        # Print info about what's being saved.
        print(f"Saved model checkpoint to {checkpoint_folder}")
        lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k]
        print(f"Checkpoint contains {len(lora_params)} LoRA parameters")

        # Signal that we've saved; suppress the trainer's own save.
        control.should_save = False
        return control


def _get_target_modules(model: ProteinLLMModel):
    """Collect the linear-layer leaf names of the text model for LoRA.

    Returns every unique ``nn.Linear`` leaf name except the output head
    (``lm_head``), plus a list of attention/MLP projection names commonly
    found in transformer blocks.
    """
    target_modules = []
    seen_names = set()
    # BUGFIX: the composite model exposes ``text_model`` (used everywhere
    # else in this file); ``model.text`` does not exist.
    for name, module in model.text_model.named_modules():
        if isinstance(module, torch.nn.Linear):
            target_name = name.split(".")[-1]  # leaf module name only
            # Skip the output head but include all other linear layers.
            if target_name != "lm_head" and target_name not in seen_names:
                target_modules.append(target_name)
                seen_names.add(target_name)

    # Add attention-specific layer names commonly found in transformers.
    attention_patterns = [
        "q_proj", "k_proj", "v_proj", "o_proj", "out_proj",
        "query", "key", "value",
        "gate_proj", "up_proj", "down_proj",
    ]
    for pattern in attention_patterns:
        if pattern not in seen_names:
            target_modules.append(pattern)

    # Return all unique layer names to apply LoRA to all layers.
    return list(target_modules)


def extract_xml_answer(text: str) -> str:
    """Return the (stripped) text inside the last ``<answer>...</answer>`` pair.

    If no tags are present, the whole stripped text is returned.
    """
    # NOTE(review): tag literals reconstructed — the exported file had
    # text.split(""), which raises ValueError on an empty separator.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    """Return the text after a GSM8K-style ``####`` marker, or None."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_kegg_questions() -> DatasetDict:
    """Load wanglab/kegg and map each row into a chat-style GRPO example.

    Each example gets a user prompt containing protein placeholders followed
    by the question text, the protein sequences themselves, and the answer.
    """
    data = load_dataset('wanglab/kegg', 'default')  # type: ignore

    # Example protein sequences attached to every row.
    example_protein_sequences = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRISSKLLERGKTHYPPHTMVGTGVLVTKMRVAGQEPDVQGPHAGIVVQGAGDAPVVVKPVVEMLNRMVVVVSGSAAPVVVNNNNNGAAAAAAA",
        "MSQVQVQVQNQALNTLVKQLGRVLLQGKGRPPLQGFRIIEQNGGDSPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP",
    ]
    num_protein_sequences = 2

    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {
                'role': 'user',
                'content': [
                    # One placeholder entry per protein sequence.
                    *({'type': 'protein', 'text': None} for _ in range(num_protein_sequences)),
                    {'type': 'text', 'text': x['question']},
                ],
            },
        ],
        # Use the first two example protein sequences.
        'protein_sequences': [example_protein_sequences[0], example_protein_sequences[1]],
        'answer': x['answer'],
    })  # type: ignore
    return data


# uncomment middle messages for 1-shot prompting

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the extracted answer text contains the gold answer."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # extracted_responses = [r.lower().replace("answer:", "").strip() for r in extracted_responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    # NOTE(review): zips against answer[0]; if answer[0] is a string this
    # iterates its characters — verify the batch structure of `answer`.
    return [2.0 if a.lower() in r.lower() else 0.0
            for r, a in zip(extracted_responses, answer[0])]


def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the extracted answer is at most 4 words long."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a strict XML format."""
    # NOTE(review): tags reconstructed from the garbled pattern
    # r"^\n.*?\n\n.*?\n$" — confirm against the SFT prompt template.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that loosely checks for reasoning/answer sections."""
    # NOTE(review): tags reconstructed from the garbled pattern r".*?\s*.*?".
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    """Partial-credit score (0.125 per tag) for well-placed answer tags."""
    count = 0.0
    # NOTE(review): tag literals reconstructed — the exported file counted
    # bare "\n" / "\n\n" occurrences, which cannot have been the intent.
    if text.count("<answer>\n") == 1:
        count += 0.125
    if text.count("\n</answer>\n") == 1:
        count += 0.125
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply count_xml to each completion's content."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


# Additional reward functions
def repeatness_reward(s: str):
    """Score text repetitiveness; higher return value means less repetition.

    Builds a suffix array and sums LCP lengths; the reward is
    1 - 2 * sum(lcp) / (n * (n + 1)), so heavily repetitive text scores low.
    """
    def ranks(l):
        index = {v: i for i, v in enumerate(sorted(set(l)))}
        return [index[v] for v in l]

    def suffixArray(s):
        # Prefix-doubling suffix array construction.
        line = ranks(s)
        n, k, ans, sa = len(s), 1, line, [0] * len(s)
        while k < n - 1:
            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
            ans, k = line, k << 1
        for i, k in enumerate(ans):
            sa[k] = i
        return ans, sa

    def lcp(arr, suffixArr, inv_suff):
        # Kasai's algorithm for longest-common-prefix lengths.
        n, ans, k = len(arr), [0] * len(arr), 0
        for i in range(n):
            if inv_suff[i] == n - 1:
                k = 0
                continue
            j = suffixArr[inv_suff[i] + 1]
            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
                k += 1
            ans[inv_suff[i]] = k
            if k > 0:
                k -= 1
        return ans

    arr = [ord(i) for i in s]
    n = len(arr)
    if n <= 1:
        return 0
    c, sa = suffixArray(arr)
    cnt = sum(lcp(arr, sa, c))
    return 1 - cnt * 2 / (n * (n + 1))


def format_reward(predict_str: str) -> float:
    """Strict format reward: require <think>...</think><answer>...</answer>
    with nothing extraneous in between.
    """
    # NOTE(review): tags reconstructed from the garbled pattern
    # r'^.*?\s*\s*.*?\s*$'.
    pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>\s*$'
    return 1.0 if re.fullmatch(pattern, predict_str.strip(), re.DOTALL) else 0.0


def acc_reward(predict_str: str, ground_truth) -> float:
    """Accuracy reward: the <answer> content must match ground_truth exactly
    (with numeric coercion when ground_truth is a number).
    """
    # NOTE(review): tags reconstructed — the [^<] character class in the
    # original garbled pattern implies a closing tag followed.
    match = re.search(r'<answer>\s*([^<]*?)\s*</answer>', predict_str)
    if not match:
        return 0.0
    answer_content = match.group(1).strip()

    # Handle the different possible types of ground_truth.
    if isinstance(ground_truth, str):
        return 1.0 if answer_content == ground_truth else 0.0
    elif isinstance(ground_truth, (int, float)):
        try:
            # Try a numeric comparison first.
            return 1.0 if float(answer_content) == float(ground_truth) else 0.0
        except ValueError:
            # Fall back to string comparison when conversion fails.
            return 1.0 if answer_content == str(ground_truth) else 0.0
    else:
        # Any other type: compare as strings.
        return 1.0 if answer_content == str(ground_truth) else 0.0


# Wrappers adapting the above to the reward-function interface.
def repeatness_reward_func(completions, **kwargs) -> list[float]:
    """Wrapper for repeatness_reward over a batch of completions."""
    responses = [completion[0]['content'] for completion in completions]
    return [repeatness_reward(r) for r in responses]


def format_reward_func(completions, **kwargs) -> list[float]:
    """Wrapper for format_reward over a batch of completions."""
    responses = [completion[0]['content'] for completion in completions]
    return [format_reward(r) for r in responses]


def acc_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Wrapper for acc_reward that normalizes the `answer` batch structure."""
    responses = [completion[0]['content'] for completion in completions]

    # Debug info.
    print(f"DEBUG acc_reward_func - answer type: {type(answer)}, answer: {answer}")

    # Following the pattern elsewhere in this file, `answer` may be nested.
    try:
        if isinstance(answer, list) and len(answer) > 0:
            # answer[0] being a list means batched ground truths.
            if isinstance(answer[0], list):
                ground_truths = answer[0]
            else:
                # A single value: use the same ground truth for all responses.
                ground_truths = [answer[0]] * len(responses)
        else:
            # Unexpected format: return all zeros.
            print(f"DEBUG: Unexpected answer format, returning zeros")
            return [0.0] * len(responses)
    except (IndexError, TypeError) as e:
        print(f"DEBUG: Error processing answer: {e}, returning zeros")
        return [0.0] * len(responses)

    print(f"DEBUG: ground_truths: {ground_truths}")

    # Score each response, padding with the first ground truth if needed.
    rewards = []
    for i, response in enumerate(responses):
        if i < len(ground_truths):
            reward = acc_reward(response, ground_truths[i])
            print(f"DEBUG: response {i}: '{response[:100]}...', ground_truth: '{ground_truths[i]}', reward: {reward}")
        else:
            # ground_truths is too short: fall back to the first value.
            reward = acc_reward(response, ground_truths[0] if ground_truths else "")
            print(f"DEBUG: response {i} (fallback): reward: {reward}")
        rewards.append(reward)
    return rewards


# Format into conversation
def make_conversation(example):
    # NOTE(review): SYSTEM_PROMPT is not defined anywhere in this file —
    # calling this function will raise NameError unless it is defined
    # elsewhere. Confirm before use.
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }


def make_conversation_image(example):
    """Build a minimal image-only user prompt."""
    return {
        "prompt": [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                ],
            },
        ],
    }


@dataclass
class GRPOModelConfig(ModelConfig):
    """Model arguments for GRPO training of the protein LLM."""

    model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for LLM weights initialization."})
    protein_model_name_or_path: str = field(default="esm2_t33_650M_UR50D", metadata={"help": "Model checkpoint for ESM-2 protein weights initialization."})
    cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
    max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."})
    max_length_protein: int = field(default=800, metadata={"help": "Maximum length of protein sequences (number of amino acids)."})
    sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    lora_r: int = field(default=32, metadata={"help": "LoRA R value."})
    lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."})
    lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
    lora_modules_to_save: Optional[list[str]] = field(
        default_factory=lambda: ["embed_tokens", "lm_head"],
        metadata={"help": "Model layers to unfreeze & train with LoRA."},
    )
    # Updated: Renamed `freeze_dna_modules` to `freeze_protein_model`
    freeze_protein_model: bool = field(default=True, metadata={"help": "Whether to freeze the ESM-2 protein model during training."})
    num_query_tokens: int = field(default=32, metadata={"help": "Number of query tokens for QFormer."})
    qformer_num_layers: int = field(default=6, metadata={"help": "Number of layers in QFormer."})
    qformer_num_heads: int = field(default=8, metadata={"help": "Number of attention heads in QFormer."})
    qformer_dropout: float = field(default=0.1, metadata={"help": "Dropout rate for QFormer."})


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.
    """

    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    data_file_paths: str = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        # Default reward-function list, including the three newer functions.
        default_factory=lambda: ["repeatness", "format", "acc", "xmlcount", "soft_format"],
        metadata={"help": "List of reward functions. Possible values: 'repeatness', 'format', 'acc', 'xmlcount', 'soft_format', 'strict_format', 'less_than_4', 'correctness'"},
    )


# Maps reward-function names (CLI values) to implementations.
reward_funcs_registry = {
    # "accuracy": accuracy_reward,
    # "format": format_reward,
    "repeatness": repeatness_reward_func,
    "format": format_reward_func,
    "acc": acc_reward_func,
    "xmlcount": xmlcount_reward_func,
    "soft_format": soft_format_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
    "correctness": correctness_reward_func,
}


def get_vlm_module(model_name_or_path):
    """Pick the sequence-processing module class for the given LLM name."""
    if any(mini_name in model_name_or_path.lower() for mini_name in ["qwen", "smol"]):
        # Prefer a dedicated protein module when available.
        try:
            from bioreason.protein_modules import ProteinModule
            return ProteinModule
        except ImportError:
            # Fall back to the DNA module if no protein module exists.
            print("Warning: Using NucleotideDNAModule for protein processing. Consider creating a dedicated ProteinModule.")
            return NucleotideDNAModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")


def _prep_for_training(model: ProteinLLMModel, model_args, protein_model_finetune: bool = False) -> LoraConfig:
    """Prepare a ProteinLLMModel for training.

    Optionally freezes the protein encoder, wraps the text model with LoRA,
    and marks the QFormer projection as trainable. Returns the LoraConfig.
    """
    # Freeze protein encoder parameters unless finetuning was requested.
    if not protein_model_finetune:
        for param in model.protein_model.parameters():
            param.requires_grad = False
        print("Frozen protein model parameters")
    else:
        print("Protein model parameters will be finetuned")

    # Get target modules for LoRA.
    target_modules = _get_target_modules(model)
    print(f"LoRA target modules: {target_modules}")

    lora_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        target_modules=target_modules,
        init_lora_weights="gaussian",
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Prepare the text model for (k-bit) training and wrap with LoRA.
    model.text_model = prepare_model_for_kbit_training(model.text_model)
    model.text_model = get_peft_model(model.text_model, lora_config)

    # Make the QFormer projection layer trainable.
    for param in model.protein_projection.parameters():
        param.requires_grad = True
    print("QFormer projection layer set as trainable")

    # Print trainable-parameter info.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

    return lora_config


######################################################################
######################################################################


def main(script_args, training_args, model_args):
    """Build the model, optionally restore an SFT checkpoint, and run GRPO."""
    print(training_args.output_dir)
    # pl.seed_everything(args.seed)
    # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    print("Initializing ProteinLLMModel...")
    model = ProteinLLMModel(
        text_model_name=model_args.model_name_or_path,
        protein_model_name=model_args.protein_model_name_or_path,
        biomedbert_model_name=getattr(model_args, 'biomedbert_model_name', "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"),
        cache_dir=model_args.cache_dir,
        max_length_text=model_args.max_length_text,
        max_length_protein=model_args.max_length_protein,
        text_model_finetune=True,
        # BUGFIX: the config field is `freeze_protein_model`
        # (`freeze_protein_modules` does not exist).
        protein_model_finetune=not model_args.freeze_protein_model,
        biomedbert_finetune=getattr(model_args, 'biomedbert_finetune', True),  # whether BiomedBERT is finetuned
        # Q-Former parameter (simplified since BiomedBERT is used directly).
        qformer_num_query_tokens=getattr(model_args, 'qformer_num_query_tokens', 8),
    )

    # Load an SFT checkpoint if one was provided.
    if model_args.sft_checkpoint is not None:
        print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")

        # A directory means PEFT format; a file means a PyTorch state dict.
        is_directory = os.path.isdir(model_args.sft_checkpoint)
        if is_directory:
            # PEFT checkpoint directory: load properly with PEFT.
            from peft import PeftModel

            print("Loading as PEFT checkpoint directory")
            model.text_model = PeftModel.from_pretrained(
                model.text_model,
                model_args.sft_checkpoint,
                is_trainable=True
            )

            # Verify loaded adapters.
            print("Loaded LoRA adapters:", model.text_model.active_adapter)

            # Optional: merge weights into the base model.
            print("Merging SFT LoRA weights into base model...")
            model.text_model = model.text_model.merge_and_unload()
            print("Successfully merged SFT knowledge into base model")
        else:
            # PyTorch state-dict file.
            print("Loading as PyTorch state dict file")
            checkpoint = torch.load(model_args.sft_checkpoint, map_location="cpu")

            # Normalize Lightning/accelerate key prefixes.
            def new_key(k):
                # BUGFIX: prefix was "=model." with k[6:] (off by one plus a
                # stray '='); strip the "model." prefix instead.
                if k.startswith("model."):
                    return k[len("model."):]
                elif k.startswith("_forward_module."):
                    return k[len("_forward_module."):]
                else:
                    return k

            # BUGFIX: the normalized dict was bound to `magic` while the rest
            # of this branch read the undefined name `state_dict`.
            if "state_dict" in checkpoint:
                state_dict = {new_key(k): v for k, v in checkpoint["state_dict"].items()}
            elif "module" in checkpoint:
                state_dict = {new_key(k): v for k, v in checkpoint["module"].items()}
            elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
                # Direct state dict — the checkpoint itself is the state dict.
                print("Detected direct state dict format")
                state_dict = {new_key(k): v for k, v in checkpoint.items()}
            else:
                raise ValueError(f"Unsupported checkpoint format: {model_args.sft_checkpoint}")

            # Handle prefix mapping for different model architectures.
            lora_prefix = any("lora" in key for key in state_dict.keys())

            if lora_prefix:
                print("Detected LoRA weights in state dict")
                # Prepare the model for LoRA training first so the LoRA keys
                # exist in the model's state dict.
                # BUGFIX: flag must be the negation of freeze_protein_model.
                _prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)

                # Print diagnostic info.
                model_keys = set(model.state_dict().keys())
                checkpoint_keys = set(state_dict.keys())
                print(f"Model has {len(model_keys)} keys")
                print(f"Checkpoint has {len(checkpoint_keys)} keys")

                # Intelligent key mapping for different prefixes.
                new_state_dict = {}
                for k, v in state_dict.items():
                    # Handle different common prefix patterns.
                    if "base_model.model" in k and k not in model_keys:
                        new_k = k.replace("text_model.base_model.model", "text_model")
                        if new_k in model_keys:
                            new_state_dict[new_k] = v
                            continue

                    # Try removing/adding prefixes.
                    if k.startswith("text_model.") and k not in model_keys:
                        new_k = "text_model.base_model.model." + k[len("text_model."):]
                        if new_k in model_keys:
                            new_state_dict[new_k] = v
                            continue

                    # Keep the original key.
                    new_state_dict[k] = v

                state_dict = new_state_dict
                print(f"After key mapping: {len(state_dict)} keys")

                # Load with missing/unexpected keys allowed.
                result = model.load_state_dict(state_dict, strict=False)
                if len(result.unexpected_keys) > 0:
                    print(f"Sample unexpected keys: {result.unexpected_keys[:5]}")
                if len(result.missing_keys) > 0:
                    print(f"Sample missing keys: {result.missing_keys[:5]}")
                print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
            else:
                print("Standard weights detected - loading before LoRA setup")
                # Handle shared-memory issue for embedding weights.
                for key in list(state_dict.keys()):
                    if 'lm_head.weight' in key:
                        state_dict[key] = state_dict[key].clone()

                # Load weights before setting up LoRA.
                result = model.load_state_dict(state_dict, strict=False)
                print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")

                # Now prepare for LoRA training.
                # BUGFIX: flag must be the negation of freeze_protein_model.
                _prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)
    else:
        # No checkpoint: just prepare for training.
        _prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)

    # Resolve reward functions from the registry.
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", [func.__name__ for func in reward_funcs])

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template()

    dataset = get_kegg_questions()
    # dataset = get_gsm8k_questions(question_prompt)
    print(dataset)
    # print('ITEM ONE OF THE DATASET', dataset['train'][0])

    # Custom callback so saving uses PyTorch's native mechanism.
    custom_save_callback = SaveWithPyTorchCallback()

    # Initialize the GRPO trainer with the custom callback.
    trainer = DNALLMGRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        # NOTE(review): assumes the dataset has a 'val' split — confirm.
        eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        callbacks=[custom_save_callback],  # add our custom callback
    )

    # Save in PyTorch format instead of safetensors.
    training_args.save_safetensors = False

    # Train and push the model to the Hub.
    # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    #     trainer.train(resume_from_checkpoint=True)
    # else:
    #     trainer.train()
    trainer.train()


if __name__ == "__main__":
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

    parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Ensure we use PyTorch's save mechanism instead of safetensors.
    training_args.save_safetensors = False

    main(script_args, training_args, model_args)

    # parser.add_argument("--wandb_project", type=str, default="dna-text-finetune")
    # parser.add_argument("--wandb_entity", type=str, default="adibvafa")
    # args = parser.parse_args()