import math
import os
import random
import re
import pathlib
from argparse import ArgumentParser
from typing import List, Dict, Optional
from dataclasses import dataclass, field

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoProcessor,
)
from datasets import load_dataset, DatasetDict
from datasets import Dataset as HFDataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

# Import BLIP2 modules
from model.blip2_stage2 import Blip2Stage2
from blip2_dna_module import Blip2DNAModule
from blip2_grpo_trainer import Blip2GRPOTrainer
from bioreason.trainer import DNALLMGRPOConfig

# Custom TrainerCallback to override the saving mechanism
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from prompt_templates import prompt_templates


# Patterns shared by the extraction helpers and the reward functions.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_CLASSIFICATION_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"[Cc]lassification:\s*(\d+)",
        r"[Cc]lass:\s*(\d+)",
        r"[Ll]abel:\s*(\d+)",
        r"[Pp]rediction:\s*(\d+)",
        r"(\d+)",  # any number
    )
)


class SaveWithPyTorchCallback(TrainerCallback):
    """Custom callback to save models with PyTorch's native save mechanism instead of safetensors"""

    def on_save(self, args, state, control, **kwargs):
        """Write ``pytorch_model.bin`` (and the LLM config, if found) for the current step.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; only ``global_step`` is read.
            control: Trainer control object; ``should_save`` is cleared after saving.
            **kwargs: Must contain ``model`` for anything to be written.

        Returns:
            The (possibly modified) ``control`` object.
        """
        model = kwargs.get("model")
        if model is None:
            # Nothing to serialise; leave ``control`` untouched instead of crashing.
            print("SaveWithPyTorchCallback: no model passed to on_save, skipping PyTorch checkpoint")
            return control

        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")

        # Unwrap a wrapper that exposes the real model as ``.module``.
        unwrapped_model = model.module if hasattr(model, "module") else model

        # Build the state dict once and reuse it for saving and for the summary below.
        state_dict = unwrapped_model.state_dict()
        torch.save(state_dict, checkpoint_path)

        # For BLIP2, save the config from the LLM component
        if hasattr(unwrapped_model, "blip2") and hasattr(unwrapped_model.blip2, "llm_model"):
            llm_model = unwrapped_model.blip2.llm_model
            if hasattr(llm_model, "config"):
                llm_model.config.save_pretrained(checkpoint_folder)
            elif hasattr(llm_model, "base_model") and hasattr(llm_model.base_model, "config"):
                llm_model.base_model.config.save_pretrained(checkpoint_folder)

        print(f"Saved model checkpoint to {checkpoint_folder}")
        lora_params = [k for k in state_dict if "lora" in k]
        print(f"Checkpoint contains {len(lora_params)} LoRA parameters")

        # Signal that we've saved
        control.should_save = False
        return control


def extract_xml_answer(text: str) -> str:
    """Return the ``<answer>`` content, else the text after ``</think>``, else the whole text.

    The result is always stripped of surrounding whitespace.
    """
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        return answer_match.group(1).strip()

    # No answer tag: fall back to whatever follows the last closing think tag.
    think_split = text.split("</think>")
    if len(think_split) > 1:
        return think_split[-1].strip()

    return text.strip()


def extract_classification_answer(text: str) -> str:
    """Extract a class label from a completion.

    Inside the ``<answer>`` tag the first matching pattern wins
    ("Classification: N", "Class: N", "Label: N", "Prediction: N", then any
    number). If nothing numeric is found the stripped answer content is
    returned; without an ``<answer>`` tag this defers to ``extract_xml_answer``.
    """
    answer_match = _ANSWER_RE.search(text)
    if not answer_match:
        return extract_xml_answer(text)

    answer_content = answer_match.group(1).strip()
    for pattern in _CLASSIFICATION_PATTERNS:
        match = pattern.search(answer_content)
        if match:
            return match.group(1)
    return answer_content


def extract_hash_answer(text: str) -> str | None:
    """Return the text after the first ``####`` marker, or None if there is no marker."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def _split_train_val(dataset, split_threshold: int, small_val_size: int,
                     val_split_ratio: float = 0.1, seed: int = 42) -> DatasetDict:
    """Split ``dataset`` into ``train`` / ``val``.

    Datasets larger than ``split_threshold`` are split with
    ``train_test_split``; its ``test`` part is exposed under the ``val`` key,
    which is the key ``main`` reads. Smaller datasets keep every row for
    training and reuse the first ``small_val_size`` rows for validation.
    """
    if len(dataset) > split_threshold:
        split = dataset.train_test_split(test_size=val_split_ratio, seed=seed)
        return DatasetDict({'train': split['train'], 'val': split['test']})
    val_rows = range(min(small_val_size, len(dataset)))
    return DatasetDict({'train': dataset, 'val': dataset.select(val_rows)})


def get_kegg_questions() -> DatasetDict:
    """Load the KEGG dataset (kept as the fallback data source).

    Each example gets a chat-style ``prompt`` with two DNA placeholders followed
    by the question text, a ``dna_sequences`` pair (reference, variant) and the
    ``answer``. If loading or mapping fails, the error is printed and empty
    ``train`` / ``val`` datasets with the same columns are returned.
    """
    try:
        data = load_dataset('wanglab/kegg', 'default')  # type: ignore
        num_dna_sequences = 2

        data = data.map(lambda x: {  # type: ignore
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        *({'type': 'dna', 'text': None} for _ in range(num_dna_sequences)),
                        {'type': 'text', 'text': x['question']},
                    ],
                },
            ],
            'dna_sequences': [x['reference_sequence'], x['variant_sequence']],
            'answer': x['answer'],
        })  # type: ignore

        return data
    except Exception as e:
        # Deliberate best-effort fallback: report the failure and hand back empty splits.
        print(f"Failed to load KEGG dataset: {e}")
        empty_data = {
            'prompt': [],
            'dna_sequences': [],
            'answer': []
        }
        dataset = HFDataset.from_dict(empty_data)
        return DatasetDict({'train': dataset, 'val': dataset})


def get_protein_classification_data(data_path: Optional[str] = None,
                                    prompt_template: Optional[str] = None,
                                    val_split_ratio: float = 0.1) -> DatasetDict:
    """Load a protein classification dataset from CSV.

    Expected columns: name,aa_seq,label,location,unique_id,pdb_hash
    (``aa_seq`` and ``label`` are required, the others are optional).

    Args:
        data_path: Path to a ``.csv`` file. When None, the KEGG dataset is used.
        prompt_template: ``str.format`` template receiving ``aa_seq``, ``name``,
            ``location`` and ``unique_id``. A default template is used when None.
        val_split_ratio: Fraction held out for validation when the dataset has
            more than 100 rows.

    Returns:
        A mapping with ``train`` and ``val`` datasets.

    Raises:
        ValueError: If ``data_path`` does not end in ``.csv``.
    """
    import pandas as pd

    if data_path is None:
        # No path provided: use the default KEGG dataset as fallback.
        return get_kegg_questions()

    if not data_path.endswith('.csv'):
        raise ValueError(f"Unsupported file format: {data_path}")
    df = pd.read_csv(data_path)

    if prompt_template is None:
        prompt_template = """
Please analyze the following protein sequence and predict its classification.

Protein sequence: {aa_seq}

Question: What is the classification of this protein sequence?

Please provide your reasoning in <think></think> tags and your final answer in <answer></answer> tags.
"""

    def process_example(row):
        prompt_text = prompt_template.format(
            aa_seq=row['aa_seq'],
            name=row.get('name', ''),
            location=row.get('location', ''),
            unique_id=row.get('unique_id', ''),
        )

        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'protein', 'text': None},  # placeholder for the protein sequence
                        {'type': 'text', 'text': prompt_text},
                    ],
                },
            ],
            'dna_sequences': [row['aa_seq']],  # the amino-acid sequence travels in the "dna_sequences" slot
            'answer': str(row['label']),  # the label is the reference answer
            'metadata': {
                'name': row.get('name', ''),
                'location': row.get('location', ''),
                'unique_id': row.get('unique_id', ''),
                'pdb_hash': row.get('pdb_hash', ''),
            }
        }

    processed_data = [process_example(row) for _, row in df.iterrows()]
    dataset = HFDataset.from_list(processed_data)

    return _split_train_val(dataset, split_threshold=100, small_val_size=10,
                            val_split_ratio=val_split_ratio)


def get_custom_protein_data_with_prompts(data_path: Optional[str] = None,
                                         prompt_templates: Optional[Dict[str, str]] = None,
                                         template_name: Optional[str] = None,
                                         max_seq_length: int = 500,
                                         val_split_ratio: float = 0.1) -> DatasetDict:
    """Load protein data from CSV and render prompts from named templates.

    Args:
        data_path: Path to the CSV file. When None, the KEGG dataset is used.
        prompt_templates: Mapping of template name to a ``str.format`` template
            receiving ``aa_seq``, ``label``, ``name`` and ``location``.
        template_name: Template to use for every row. When None, a template is
            picked at random for each row.
        max_seq_length: Sequences longer than this are cut to this length and
            suffixed with "..." in the prompt text only; the full sequence is
            still passed in ``dna_sequences``.
        val_split_ratio: Fraction held out for validation when the dataset has
            more than 50 rows.

    Returns:
        A mapping with ``train`` and ``val`` datasets.

    Raises:
        ValueError: If no templates are given or ``template_name`` is unknown.
    """
    import pandas as pd

    if data_path is None:
        return get_kegg_questions()

    if not prompt_templates:
        raise ValueError("prompt_templates must be a non-empty mapping of template name to template")
    if template_name is not None and template_name not in prompt_templates:
        raise ValueError(
            f"Unknown prompt template '{template_name}'. Available: {sorted(prompt_templates)}"
        )

    df = pd.read_csv(data_path)
    template_names = list(prompt_templates.keys())

    def process_example(row, template_name=None):
        # Use the requested template, or pick one at random per row.
        if template_name is None:
            template_name = random.choice(template_names)
        template = prompt_templates[template_name]

        sequence = row['aa_seq']
        if len(sequence) > max_seq_length:
            shown_sequence = sequence[:max_seq_length] + "..."  # truncate long sequences for display
        else:
            shown_sequence = sequence

        prompt_text = template.format(
            aa_seq=shown_sequence,
            label=row['label'],
            name=row.get('name', ''),
            location=row.get('location', ''),
        )

        # Only the first half of the rendered template is shown to the model.
        # NOTE(review): the separator literal was missing from the source (an empty
        # separator raises ValueError); '<think>' is assumed here - confirm against
        # prompt_templates. If the marker is absent the full prompt is used.
        user_prompt = prompt_text.split('<think>')[0]

        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'protein', 'text': None},
                        {'type': 'text', 'text': user_prompt},
                    ],
                },
            ],
            'dna_sequences': [row['aa_seq']],  # full sequence for the model
            'answer': str(row['label']),
            'template_used': template_name,
            'metadata': {
                'name': row.get('name', ''),
                'location': row.get('location', ''),
                'unique_id': row.get('unique_id', ''),
                'pdb_hash': row.get('pdb_hash', ''),
                'full_prompt': prompt_text,
            }
        }

    print("template_name")
    print(template_name)
    processed_data = [process_example(row, template_name) for _, row in df.iterrows()]
    dataset = HFDataset.from_list(processed_data)

    return _split_train_val(dataset, split_threshold=50, small_val_size=5,
                            val_split_ratio=val_split_ratio)


def get_gsm8k_questions(question_prompt: str) -> DatasetDict:
    """Load GSM8K with placeholder DNA inputs.

    ``question_prompt`` is accepted but not used: every example carries the same
    fixed text prompt and the same three example DNA sequences. The answer is
    taken from the text after ``####`` in the GSM8K answer field.
    """
    data = load_dataset('openai/gsm8k', 'main')  # type: ignore

    example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {
                'role': 'user',
                'content': [
                    *({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))),
                    {'type': 'text', 'text': 'Give me a short introduction to large language model.'}
                ]
            },
        ],
        'dna_sequences': list(example_dna_sequences),
        'answer': extract_hash_answer(x['answer']),
    })  # type: ignore

    return data  # type: ignore


# Reward functions

def _answer_span(response: str) -> str:
    """Return the stripped ``<answer>`` content, or the stripped response if there is no tag."""
    match = _ANSWER_RE.search(response)
    return match.group(1).strip() if match else response.strip()


def _reference_answer(answer, index: int) -> str:
    """Pick the reference answer for completion ``index`` and return it as a stripped string.

    A list is indexed per completion; a list shorter than ``index`` falls back to
    its first element; anything else (including an empty list) is stringified.
    """
    if isinstance(answer, list) and len(answer) > index:
        return str(answer[index]).strip()
    if isinstance(answer, list) and len(answer) > 0:
        return str(answer[0]).strip()
    return str(answer).strip()


def format_correct_reward_func(completions, **kwargs) -> list[float]:
    """Reward the ``<think>...</think><answer>...</answer>`` format.

    Scoring per completion: 0.5 for a think tag pair, 0.5 for an answer tag
    pair, and a further 0.5 when the four tags appear in the right order, so the
    maximum is 1.5.
    """
    responses = [completion[0]["content"] for completion in completions]
    rewards = []

    for response in responses:
        score = 0.0

        if "<think>" in response and "</think>" in response:
            score += 0.5

        if "<answer>" in response and "</answer>" in response:
            score += 0.5

        # Bonus when the tags appear in order: think block first, then answer block.
        think_start = response.find("<think>")
        think_end = response.find("</think>")
        answer_start = response.find("<answer>")
        answer_end = response.find("</answer>")

        if (think_start != -1 and think_end != -1 and
                answer_start != -1 and answer_end != -1 and
                think_start < think_end < answer_start < answer_end):
            score += 0.5

        rewards.append(score)

    return rewards


def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose answer contains the reference label.

    Both the extracted answer and the reference are lower-cased and stripped of
    every non-word character; the reward is 1.0 when the cleaned reference occurs
    in the cleaned answer and 0.0 otherwise. An empty reference scores 0.0
    (it would otherwise match every completion).
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []

    for i, response in enumerate(responses):
        extracted_answer = _answer_span(response)
        correct_answer = _reference_answer(answer, i)

        extracted_clean = re.sub(r'[^\w\d]', '', extracted_answer.lower())
        correct_clean = re.sub(r'[^\w\d]', '', correct_answer.lower())

        # Substring containment also covers exact equality.
        if correct_clean and correct_clean in extracted_clean:
            rewards.append(1.0)
        else:
            rewards.append(0.0)

    return rewards


def classification_specific_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function tailored to the protein classification task.

    Per completion (capped at 1.0):
      * +0.2 if the answer mentions a classification keyword;
      * numeric reference: +0.8 if it occurs in the answer, +0.4 if any number
        in the answer is within 1 of it;
      * text reference: +0.8 if it occurs in the answer (case-insensitive);
      * +0.2 if a ``<think>`` block with more than 20 characters is present.
    """
    responses = [completion[0]['content'] for completion in completions]
    classification_keywords = ['classification', 'class', 'category', 'type', 'function', 'family']
    rewards = []

    for i, response in enumerate(responses):
        score = 0.0
        extracted_answer = _answer_span(response)
        correct_answer = _reference_answer(answer, i)

        if any(keyword in extracted_answer.lower() for keyword in classification_keywords):
            score += 0.2

        if correct_answer.isdigit():
            if correct_answer in extracted_answer:
                score += 0.8

            # Proximity credit; int() can reject digit-like characters that isdigit() accepts.
            try:
                target = int(correct_answer)
                numbers = [int(n) for n in re.findall(r'\d+', extracted_answer)]
            except ValueError:
                numbers = []
            if numbers and min(abs(n - target) for n in numbers) <= 1:
                score += 0.4
        else:
            if correct_answer.lower() in extracted_answer.lower():
                score += 0.8

        if "<think>" in response and "</think>" in response:
            think_content = _THINK_RE.search(response)
            if think_content and len(think_content.group(1).strip()) > 20:
                score += 0.2

        rewards.append(min(score, 1.0))  # never exceed 1.0

    return rewards


def repetition_penalty_reward_func(completions, **kwargs) -> list[float]:
    """Reward low repetition in the answer text (lower repetition, higher reward).

    The word-level and sentence-level duplicate ratios are averaged and the
    reward is ``max(0, 1 - 2 * average)``. An empty answer scores 0.0.
    """
    responses = [completion[0]["content"] for completion in completions]
    rewards = []

    for response in responses:
        text_to_analyze = _answer_span(response)

        words = text_to_analyze.lower().split()
        if not words:
            rewards.append(0.0)
            continue

        repetition_rate = 1.0 - (len(set(words)) / len(words))

        sentences = [s.strip() for s in text_to_analyze.split('.') if s.strip()]
        if len(sentences) > 1:
            sentence_repetition_rate = 1.0 - (len(set(sentences)) / len(sentences))
        else:
            sentence_repetition_rate = 0.0

        overall_repetition = (repetition_rate + sentence_repetition_rate) / 2

        # The factor 2 makes the penalty more pronounced.
        rewards.append(max(0.0, 1.0 - overall_repetition * 2))

    return rewards


def combined_reward_func(prompts, completions, answer,
                         format_weight=0.3, accuracy_weight=0.5, repetition_weight=0.2,
                         **kwargs) -> list[float]:
    """Weighted combination of the format, accuracy and repetition rewards.

    Weights that do not sum to 1 are normalised (and the normalised values are
    printed).

    Raises:
        ValueError: If the weights do not sum to a positive number.
    """
    format_rewards = format_correct_reward_func(completions, **kwargs)
    accuracy_rewards = accuracy_reward_func(prompts, completions, answer, **kwargs)
    repetition_rewards = repetition_penalty_reward_func(completions, **kwargs)

    total_weight = format_weight + accuracy_weight + repetition_weight
    if total_weight <= 0:
        raise ValueError(f"Reward weights must sum to a positive value, got {total_weight}")

    # Tolerant comparison: e.g. 0.2 + 0.6 + 0.2 is not exactly 1.0 in floating point.
    if not math.isclose(total_weight, 1.0):
        format_weight /= total_weight
        accuracy_weight /= total_weight
        repetition_weight /= total_weight
        print(f"Normalized weights - Format: {format_weight:.3f}, Accuracy: {accuracy_weight:.3f}, Repetition: {repetition_weight:.3f}")

    return [
        format_weight * f_reward + accuracy_weight * a_reward + repetition_weight * r_reward
        for f_reward, a_reward, r_reward in zip(format_rewards, accuracy_rewards, repetition_rewards)
    ]


# Earlier reward functions, kept as alternatives

def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 when the extracted answer has at most four space-separated words."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>\n.*?\n</think>\n.*?\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion with ``count_xml``."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


def count_xml(text) -> float:
    """Give 0.125 each for exactly one opening and exactly one closing think tag line."""
    count = 0.0
    if text.count("<think>\n") == 1:
        count += 0.125
    if text.count("\n</think>\n") == 1:
        count += 0.125
    return count


@dataclass
class Blip2ModelConfig(ModelConfig):
    """Model configuration for the BLIP2-based GRPO script (extends TRL's ``ModelConfig``)."""

    # BLIP2 specific configuration
    model_name_or_path: str = field(default="blip2-model", metadata={"help": "Model checkpoint for weights initialization."})

    # BLIP2 Architecture parameters
    bert_name: str = field(default="/path/to/bert", metadata={"help": "BERT model for Q-former"})
    num_query_token: int = field(default=8, metadata={"help": "Number of query tokens"})
    cross_attention_freq: int = field(default=2, metadata={"help": "Cross attention frequency"})
    plm_model: str = field(default="facebook/esm2_t30_150M_UR50D", metadata={"help": "Protein language model"})
    plm_tune: str = field(default="freeze", metadata={"help": "PLM tuning strategy"})
    llm_name: str = field(default="facebook/galactica-1.3b", metadata={"help": "Language model name"})
    llm_tune: str = field(default="lora", metadata={"help": "LLM tuning strategy"})
    qformer_tune: str = field(default="train", metadata={"help": "Q-former tuning strategy"})
    peft_dir: str = field(default="", metadata={"help": "PEFT directory"})

    # LoRA parameters
    lora_r: int = field(default=8, metadata={"help": "LoRA rank"})
    lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
    lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})

    # Training parameters
    # NOTE: "enbale" is misspelled, but the name is the CLI flag and the key handed
    # to the BLIP2 model, so it is kept as is.
    enbale_gradient_checkpointing: bool = field(default=False, metadata={"help": "Enable gradient checkpointing"})
    enable_flash: bool = field(default=False, metadata={"help": "Enable flash attention"})

    # Other parameters
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Path to model cache directory."})
    sft_checkpoint: Optional[str] = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    freeze_dna_modules: bool = field(default=False, metadata={"help": "Freeze DNA/protein modules"})


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script with BLIP2.
    """
    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    data_file_paths: Optional[str] = field(
        default=None,
        metadata={"help": "Path to protein classification CSV file (format: name,aa_seq,label,location,unique_id,pdb_hash)"},
    )
    arrow_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.1,
        metadata={"help": "Ratio of validation split, default 0.1"},
    )
    reward_funcs: list[str] = field(
        # Option 1 (recommended): the combined reward function.
        default_factory=lambda: ["combined"],
        # Option 2: separate reward functions
        #   ["format_correct", "accuracy", "repetition_penalty"]
        # Option 3: protein-classification-specific reward
        #   ["format_correct", "classification_specific", "repetition_penalty"]
        metadata={"help": "List of reward functions. Available: 'combined', 'format_correct', 'accuracy', 'classification_specific', 'repetition_penalty', 'xmlcount', 'strict_format', 'less_than_4'"},
    )

    # Reward weights (used by the combined reward)
    format_weight: float = field(
        default=0.3,
        metadata={"help": "Weight for format correctness reward (used in combined reward)"}
    )
    accuracy_weight: float = field(
        default=0.5,
        metadata={"help": "Weight for accuracy reward (used in combined reward)"}
    )
    repetition_weight: float = field(
        default=0.2,
        metadata={"help": "Weight for repetition penalty reward (used in combined reward)"}
    )

    # Data processing parameters
    template_name: str = field(
        default="classification",
        metadata={"help": "Prompt template to use: 'classification', 'function_prediction', 'location_prediction'"}
    )
    max_seq_length: int = field(
        default=1000,
        metadata={"help": "Maximum protein sequence length for display in prompt"}
    )
    use_custom_prompts: bool = field(
        default=True,
        metadata={"help": "Whether to use custom protein-specific prompts"}
    )


reward_funcs_registry = {
    # Combined reward: format + accuracy + repetition
    "combined": combined_reward_func,

    # Separate reward functions
    "format_correct": format_correct_reward_func,
    "accuracy": accuracy_reward_func,
    "repetition_penalty": repetition_penalty_reward_func,
    "classification_specific": classification_specific_reward_func,

    # Earlier reward functions, kept as alternatives
    "xmlcount": xmlcount_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
}


def get_vlm_module(model_name_or_path):
    """Return the module class to use; always ``Blip2DNAModule`` in this implementation."""
    return Blip2DNAModule


def create_blip2_args_from_config(model_args):
    """Create BLIP2 args from model config"""
    # Copy the BLIP2-related fields into a plain dict, keeping their names.
    blip2_args = {
        'bert_name': model_args.bert_name,
        'num_query_token': model_args.num_query_token,
        'cross_attention_freq': model_args.cross_attention_freq,
        'plm_model': model_args.plm_model,
        'plm_tune': model_args.plm_tune,
        'llm_name': model_args.llm_name,
        'llm_tune': model_args.llm_tune,
        'qformer_tune': model_args.qformer_tune,
        'peft_dir': model_args.peft_dir,
        'lora_r': model_args.lora_r,
        'lora_alpha': model_args.lora_alpha,
        'lora_dropout': model_args.lora_dropout,
        'enbale_gradient_checkpointing': model_args.enbale_gradient_checkpointing,
        'enable_flash': model_args.enable_flash,
    }
    return blip2_args


def _prep_for_training(model, training_args):
    """
    Build the LoRA configuration for the BLIP2 model.

    ``model`` is not modified; only a ``LoraConfig`` is created from the
    ``lora_r`` / ``lora_alpha`` / ``lora_dropout`` attributes of
    ``training_args`` and returned.
    """
    # The BLIP2 model should handle its own LoRA setup
    # This is mainly for any additional preparation needed
    target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]

    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        target_modules=target_modules,
        init_lora_weights="gaussian",
        bias="none",
        task_type="CAUSAL_LM",
    )

    return lora_config


def _get_eval_split(dataset):
    """Return the validation split of ``dataset`` (``val``, ``validation`` or ``test``), or None."""
    for split_name in ("val", "validation", "test"):
        if split_name in dataset:
            return dataset[split_name]
    return None


def main(script_args, training_args, model_args):
    """Build the BLIP2 model, rewards and dataset, then run GRPO training.

    Args:
        script_args: Parsed ``GRPOScriptArguments``.
        training_args: Parsed trainer configuration.
        model_args: Parsed ``Blip2ModelConfig``.

    Raises:
        ValueError: If an unknown reward function name is requested.
    """
    print(training_args.output_dir)

    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Save in PyTorch format instead of safetensors; set before the trainer is built.
    training_args.save_safetensors = False

    # Create the BLIP2 model once: from the SFT checkpoint if given, else fresh.
    blip2_args = create_blip2_args_from_config(model_args)
    if model_args.sft_checkpoint is not None:
        print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
        model = Blip2Stage2.load_from_checkpoint(
            model_args.sft_checkpoint, strict=False, args=blip2_args, map_location='cpu'
        )
    else:
        model = Blip2Stage2(blip2_args)

    # Get reward functions with weights
    reward_funcs = []
    for func_name in script_args.reward_funcs:
        if func_name == "combined":
            # Bind the configured weights into the combined reward.
            def weighted_combined_reward(prompts, completions, answer, **kwargs):
                return combined_reward_func(
                    prompts, completions, answer,
                    format_weight=script_args.format_weight,
                    accuracy_weight=script_args.accuracy_weight,
                    repetition_weight=script_args.repetition_weight,
                    **kwargs
                )
            reward_funcs.append(weighted_combined_reward)
        elif func_name in reward_funcs_registry:
            reward_funcs.append(reward_funcs_registry[func_name])
        else:
            raise ValueError(
                f"Unknown reward function '{func_name}'. Available: {sorted(reward_funcs_registry)}"
            )

    print("reward_funcs:", [func.__name__ if hasattr(func, '__name__') else 'weighted_combined_reward' for func in reward_funcs])
    print(f"Reward weights - Format: {script_args.format_weight}, Accuracy: {script_args.accuracy_weight}, Repetition: {script_args.repetition_weight}")

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template()

    # Load dataset based on data source
    if script_args.data_file_paths and script_args.use_custom_prompts:
        print(f"Loading custom protein data from: {script_args.data_file_paths}")
        dataset = get_custom_protein_data_with_prompts(
            data_path=script_args.data_file_paths,
            prompt_templates=prompt_templates,
            template_name=script_args.template_name,
            max_seq_length=script_args.max_seq_length,
            val_split_ratio=script_args.val_split_ratio,
        )
    elif script_args.data_file_paths:
        print(f"Loading protein data from: {script_args.data_file_paths}")
        dataset = get_protein_classification_data(
            data_path=script_args.data_file_paths,
            val_split_ratio=script_args.val_split_ratio,
        )
    else:
        print("Using default KEGG dataset")
        dataset = get_kegg_questions()

    eval_split = _get_eval_split(dataset)

    print("Dataset loaded:")
    print(f"Train size: {len(dataset['train'])}")
    print(f"Val size: {len(eval_split) if eval_split is not None else 0}")

    # Print one sample
    if len(dataset['train']) > 0:
        print("\nSample data:")
        sample = dataset['train'][0]
        sequences = sample.get('dna_sequences') or ['']
        print(f"Prompt type: {type(sample.get('prompt', 'Unknown'))}")
        print(f"DNA sequences count: {len(sample.get('dna_sequences', []))}")
        print(f"Answer: {sample.get('answer', 'N/A')}")
        if 'metadata' in sample:
            print(f"Metadata: {sample['metadata']}")
        print(f"First 100 chars of sequence: {sequences[0][:100]}...")

    # Custom callback to handle saving with PyTorch's native mechanism
    custom_save_callback = SaveWithPyTorchCallback()

    # Initialize the BLIP2 GRPO trainer
    trainer = Blip2GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        eval_dataset=eval_split if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=getattr(model_args, 'attn_implementation', 'flash_attention_2'),
        torch_dtype=getattr(model_args, 'torch_dtype', 'bfloat16'),
        callbacks=[custom_save_callback],
    )

    # Train the model
    trainer.train()


if __name__ == "__main__":
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, Blip2ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Ensure we use PyTorch's save mechanism instead of safetensors
    training_args.save_safetensors = False

    main(script_args, training_args, model_args)


# Usage example:
"""
使用你的蛋白质数据进行训练:

1. 准备CSV文件,格式:name,aa_seq,label,location,unique_id,pdb_hash

2. 运行训练:
python blip2_reason.py \
    --data_file_paths /path/to/your/protein_data.csv \
    --reward_funcs combined \
    --format_weight 0.2 \
    --accuracy_weight 0.6 \
    --repetition_weight 0.2 \
    --use_custom_prompts \
    --template_name classification \
    --max_seq_length 1000 \
    --output_dir ./output \
    --per_device_train_batch_size 4 \
    --num_train_epochs 3 \
    --learning_rate 1e-5

3. 或者使用分离的奖励函数:
python blip2_reason.py \
    --data_file_paths /path/to/your/protein_data.csv \
    --reward_funcs format_correct classification_specific repetition_penalty \
    --use_custom_prompts \
    --template_name function_prediction

数据格式示例:
P0DM40,MLRVVVESASINPPLSTTPKAFVTVYFRDMMKRTRVEEGHDPIWNETLIWHLWNQPLENDSFLKVILQDSVSKKKERFIGLATVPLKRLAQRPKEVMFVRDLILLNHSMKPTNCTVTLHVAQIYDQDTEMTGNEELLGSTVNEVTQKKLMVSGLPMHRALASKPQHFQVRVKVFEARQLLGNNIKPVVKVNIADQQHLTRIKMGNNPFFNEIFFQNFHEVPAKFFEENISIEVVDSAASRSKAEIGRFQTDIGFIYHSPGHTLLRKWLGLCQRNKTTSGVRGYLKVTICALGVGDQALVDQKLPYEQNTRVQIFKSKEVPVSLAYLQFFIYCAEDLHFGTHKSATPVLEVELIGDKLRTKPQNPSDNPIWNQILTFQIQLPCLSSYIKFRVMDCSKYKCQDEIGSASLCLSQISSTGEEIQGMYSGFLPCFGPSFLTLRGGKKPPFRTSEEGTCIMDAVQHGLAYRGRIFVEIVTKIKSQQDSVMKDLSQEVTQVEMQYYRQKYGLCVIFLSCTMMPKFKDLIQFEVSMGHYGNKTDPNYKPLVSTTQYSPVIYDGTTYHYVPWYNTKPVVAVTSNWEDVGFRMNCLNLLHITRDRLKTNLDILKSIRNPRDPALLQQWEKLLKELQEDCRRPLPCMTDQPRANSLDRNKWQLRSQLLQQLAQMAKEAKPVNMVGTAKEWLHRLNAVIPEPQESLPDVLIWLMSRQQRVAYARVPAHTVLFSPAGPLSSGKFCGKIQNILLQYPEGEGQDTFPASLRVCMWLGNVKYSKNLKLLQQGSMVVYAETYENQAKTRDDWGQQGLYHCPNFSDVMGRKALPKTDFKAPPGWHWKDDWVVEPQRRLLLDIDINKSQVLEEVYENQLRNATGAWVPAAIPNTDVNGQPVEALENVKCPQGWHFKKNWIVKLNHAVDSEGWEYGVGIPPSGLPQIWNSVEKTYHSCRRRRWVRVRFRNHKELGQERSQEQETLSFLQMQDLSEEGKEGWEYGTFDSRFHLDPQPTSRFRRRCWHRQLAPNKDRGVASIFLLEGSLAVEQKDQPRKEMEKTRSWQPWKDLRHTPEDPRIPTTPFIYYILNKPHYYQLFCYIYQARNLMYNQILTFQEPFIQVVFLNHSLCTQTLRSSAAPTWSQSIIFQHLLLFEDPKDTRENPPLVVLELWQHDSRGNKILWGRSMWPPVVWLGLQDWVFTPLRWHPLVRELGEEEGEILASCELILETQKLKELHPPILSIPCKDGIYLLPKNIQPTMKMMAIEIMAWGLRNMTKVRYPQLLLECGGESLKTEPISNFQENPNFPTSTFFFTVFMPLEETHAQPLVVKVVDNQEYGQQIVVGQANIDFLQPYFCDPWSLNYTTVKLPTLSVKKPDTFLDFVYKKFWFDSSKDEEVYEEEVDWWSKLFWATGDADKSLNYNHKSYHTLKVYDCELEAVLTFKGLQDFCQTFKLYQEKPKVDSPVVGEFKGLFRIYPFPEDPEAPKPPRQFSAWPEIEDFPQMCLVRVYLIRAINLQPQDYNGLCDPYVILKLGQTKLGSRDSYYPNTLDPIFGMMYELTCNIPLEKDLEIQLFDFDLITADDEIGSTVIDLENRLLSGFGARCGLSKSYCKSGPFKWRDQMTPSYLLYRYAKQKGLPPPVFDLEGDSLYYNGETFKLQSFESAPPTYKHLGPKKERLALYILNTQGLVPEHVETRTLHSNSQPGIDQGKIQMWVDIFPKMLGPPGPQVNISPRKPKRYQLRCIIWSTAEVDLVQETFSKEKMSDIYVKGWLFGLEEDTQKTDVHYHSLTGEATFNWRFIFTMDYLTTERACVQSQKDYIWSLDPTSTKFPARLMIQIWDNDFFSPDDFLGVLELDLSDMPLPAQNIKQCSLKMMETDSKWPFTPQKRISLFKKTNVTGWWPCQVLDGDKWRLSGKVKMTLEMLSEREALIRPAGRGQSEPNQFPMLHPPERNDSFLLWYQSPIKNFCYAVCKRYRSKIICLVVTLVIGFILLNFVYSAPSYFAMNWIKPQLRLSSPIKIVNLIGTVNTSNINSSILTMEGSTYHASHVFPEAPAP,0,M,af67d99c09f74ea8af5004cc2906bbc5,d55cbc3d94bd9668d97a678b4a04176a
"""