import os
import re
import pathlib
from argparse import ArgumentParser
from typing import List, Dict, Optional
from dataclasses import dataclass, field
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoProcessor,
)
from datasets import load_dataset, DatasetDict
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
# Import BLIP2 modules
from model.blip2_stage2 import Blip2Stage2
from blip2_dna_module import Blip2DNAModule
from blip2_grpo_trainer import Blip2GRPOTrainer
from bioreason.trainer import DNALLMGRPOConfig
# Custom TrainerCallback to override the saving mechanism
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from prompt_templates import prompt_templates
class SaveWithPyTorchCallback(TrainerCallback):
    """Trainer callback that writes checkpoints with ``torch.save`` instead of safetensors.

    On each save event the unwrapped model's ``state_dict`` is written to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/pytorch_model.bin``.
    """

    def on_save(self, args, state, control, **kwargs):
        """Write the model weights (and the LLM config, when present) for the current step.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; only ``global_step`` is read.
            control: Trainer control object; ``should_save`` is cleared on it.
            **kwargs: Callback keyword arguments; the ``model`` entry is used.

        Returns:
            The ``control`` object. ``should_save`` is set to ``False`` after a
            successful write; it is left untouched when no model was supplied.
        """
        model = kwargs.get("model")
        if model is None:
            # Without a model there is nothing to serialise; report it instead
            # of failing with an AttributeError on ``None``.
            print("SaveWithPyTorchCallback: no model passed to on_save; skipping custom save")
            return control

        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")

        # Strip a wrapper that exposes the real model as ``.module``.
        unwrapped_model = model.module if hasattr(model, "module") else model

        # Build the state dict once and reuse it for both saving and reporting.
        state_dict = unwrapped_model.state_dict()
        torch.save(state_dict, checkpoint_path)

        # For BLIP2, save the config from the LLM component (directly, or via
        # its ``base_model`` when the LLM itself carries no ``config``).
        llm_model = getattr(getattr(unwrapped_model, "blip2", None), "llm_model", None)
        if llm_model is not None:
            if hasattr(llm_model, "config"):
                llm_model.config.save_pretrained(checkpoint_folder)
            elif hasattr(llm_model, "base_model") and hasattr(llm_model.base_model, "config"):
                llm_model.base_model.config.save_pretrained(checkpoint_folder)

        print(f"Saved model checkpoint to {checkpoint_folder}")
        lora_param_count = sum(1 for key in state_dict if "lora" in key)
        print(f"Checkpoint contains {lora_param_count} LoRA parameters")

        # Signal that we've saved
        control.should_save = False
        return control
def extract_xml_answer(text: str) -> str:
    """Return the final answer embedded in a model completion.

    Resolution order:
      1. the content of the first ``<answer>...</answer>`` pair;
      2. otherwise, everything after the last ``</think>`` tag;
      3. otherwise, the whole text.

    The result is always stripped of surrounding whitespace.
    """
    # The tag literals are required: a tag-less pattern matches the empty
    # string at position 0 and makes this function return "" for every input.
    answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if answer_match:
        return answer_match.group(1).strip()

    # No answer tag: fall back to whatever follows the reasoning block.
    _, closing_tag, after_think = text.rpartition("</think>")
    if closing_tag:
        return after_think.strip()

    return text.strip()
def extract_classification_answer(text: str) -> str:
    """Extract a classification label from a model completion.

    When an ``<answer>...</answer>`` pair is present, its content is searched
    for a numeric label: first after a ``Classification:``, ``Class:``,
    ``Label:`` or ``Prediction:`` prefix, then as any run of digits. The first
    hit is returned as a string; if no number is found the stripped answer
    content is returned unchanged. Without an answer tag the result of
    ``extract_xml_answer(text)`` is returned.
    """
    # The tag literals are required: a tag-less pattern always matches "".
    answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if not answer_match:
        return extract_xml_answer(text)

    answer_content = answer_match.group(1).strip()
    # Ordered from most to least specific; the bare-digits pattern is last so
    # that an explicit "Label: N" wins over an unrelated number.
    classification_patterns = (
        r"[Cc]lassification:\s*(\d+)",
        r"[Cc]lass:\s*(\d+)",
        r"[Ll]abel:\s*(\d+)",
        r"[Pp]rediction:\s*(\d+)",
        r"(\d+)",
    )
    for pattern in classification_patterns:
        match = re.search(pattern, answer_content)
        if match:
            return match.group(1)
    return answer_content
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
def get_kegg_questions() -> Dataset:
    """Load ``wanglab/kegg`` and reshape each row into the GRPO prompt format (fallback loader).

    Every example gets a single user turn containing two ``dna`` placeholders
    followed by the question text, a ``dna_sequences`` pair
    (reference, variant) and the ``answer`` column.

    If loading or mapping fails for any reason, the error is printed and a
    dict with empty ``train`` and ``val`` datasets is returned instead.
    """
    # One placeholder per sequence supplied in 'dna_sequences' below.
    num_dna_sequences = 2

    def to_grpo_example(x):
        dna_slots = [{'type': 'dna', 'text': None} for _ in range(num_dna_sequences)]
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': dna_slots + [{'type': 'text', 'text': x['question']}],
                },
            ],
            'dna_sequences': [x['reference_sequence'], x['variant_sequence']],
            'answer': x['answer'],
        }

    try:
        raw = load_dataset('wanglab/kegg', 'default')  # type: ignore
        return raw.map(to_grpo_example)  # type: ignore
    except Exception as e:
        print(f"Failed to load KEGG dataset: {e}")
        # Return an empty dataset structure with the expected columns.
        from datasets import Dataset
        empty = Dataset.from_dict({'prompt': [], 'dna_sequences': [], 'answer': []})
        return {'train': empty, 'val': empty}
def get_protein_classification_data(data_path: Optional[str] = None, prompt_template: Optional[str] = None) -> Dataset:
    """Load a protein-classification CSV and convert it into GRPO examples.

    Expected CSV columns: ``name, aa_seq, label, location, unique_id, pdb_hash``
    (``aa_seq`` and ``label`` are required; the others default to ``''``).

    Args:
        data_path: Path to a ``.csv`` file. When ``None`` the KEGG fallback
            from ``get_kegg_questions()`` is returned.
        prompt_template: ``str.format`` template that may reference
            ``{aa_seq}``, ``{name}``, ``{location}`` and ``{unique_id}``.
            A default classification prompt is used when ``None``.

    Returns:
        A mapping with ``train`` and ``val`` splits. With more than 100 rows
        this is the ``DatasetDict`` from ``train_test_split`` (so it also keeps
        its ``test`` key); otherwise ``val`` is the first up-to-10 rows of the
        training data.

    Raises:
        ValueError: If ``data_path`` does not end with ``.csv``.
    """
    if data_path is None:
        # No path supplied: use the default KEGG dataset as a fallback.
        return get_kegg_questions()
    # Validate before the heavy imports and file I/O below.
    if not data_path.endswith('.csv'):
        raise ValueError(f"Unsupported file format: {data_path}")

    import pandas as pd
    from datasets import Dataset as HFDataset

    df = pd.read_csv(data_path)

    if prompt_template is None:
        # The tag names must appear in the prompt: the reward functions in
        # this file score completions on <think>/<answer> tags.
        prompt_template = """
Please analyze the following protein sequence and predict its classification.
Protein sequence: {aa_seq}
Question: What is the classification of this protein sequence?
Please provide your reasoning in <think></think> tags and your final answer in <answer></answer> tags.
"""

    def process_example(row):
        prompt_text = prompt_template.format(
            aa_seq=row['aa_seq'],
            name=row.get('name', ''),
            location=row.get('location', ''),
            unique_id=row.get('unique_id', ''),
        )
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'protein', 'text': None},  # protein sequence placeholder
                        {'type': 'text', 'text': prompt_text},
                    ],
                },
            ],
            # The amino-acid sequence travels in the "dna_sequences" column.
            'dna_sequences': [row['aa_seq']],
            'answer': str(row['label']),
            'metadata': {
                'name': row.get('name', ''),
                'location': row.get('location', ''),
                'unique_id': row.get('unique_id', ''),
                'pdb_hash': row.get('pdb_hash', ''),
            },
        }

    dataset = HFDataset.from_list([process_example(row) for _, row in df.iterrows()])

    if len(dataset) > 100:
        splits = dataset.train_test_split(test_size=0.1, seed=42)
        # train_test_split names the held-out part 'test'; expose it as 'val'
        # too so both branches offer the same key to callers.
        splits['val'] = splits['test']
        return splits

    # Small dataset: reuse the first few training rows for validation.
    return {
        'train': dataset,
        'val': dataset.select(range(min(10, len(dataset)))),
    }
def get_custom_protein_data_with_prompts(data_path: Optional[str] = None,
                                         prompt_templates: Optional[Dict[str, str]] = None,
                                         template_name: Optional[str] = None,
                                         split_marker: str = "<think>") -> Dataset:
    """Load a protein CSV and build GRPO examples from named prompt templates.

    Args:
        data_path: Path to a CSV readable by ``pandas.read_csv``. When ``None``
            the KEGG fallback from ``get_kegg_questions()`` is returned.
        prompt_templates: Mapping of template name to a ``str.format`` template
            that may reference ``{aa_seq}``, ``{label}``, ``{name}`` and
            ``{location}``.
        template_name: Template to use for every row. When ``None``, a
            module-level ``script_args.template_name`` is used if such a global
            exists; otherwise a template is picked at random per row.
        split_marker: Only the part of the formatted prompt before the first
            occurrence of this marker is shown to the model; the full prompt is
            kept in ``metadata['full_prompt']``. An empty marker keeps the
            whole prompt.
            NOTE(review): the separator literal was empty in the previous
            revision (``str.split('')`` raises ValueError); ``"<think>"`` is
            the presumed intent -- confirm against the ``prompt_templates``
            module.

    Returns:
        A mapping with ``train`` and ``val`` splits. With more than 50 rows
        this is the ``DatasetDict`` from ``train_test_split`` (so it also keeps
        its ``test`` key); otherwise ``val`` is the first up-to-5 rows.

    Raises:
        ValueError: If ``prompt_templates`` is missing or empty.
        KeyError: If ``template_name`` is not a key of ``prompt_templates``.
    """
    if data_path is None:
        return get_kegg_questions()
    if not prompt_templates:
        raise ValueError("prompt_templates must be a non-empty mapping of template name to template text")

    if template_name is None:
        # Backward compatibility: this function used to read the template name
        # from a module-level `script_args` (bound only when run as a script).
        template_name = getattr(globals().get("script_args"), "template_name", None)
    if template_name is not None and template_name not in prompt_templates:
        raise KeyError(
            f"Unknown prompt template {template_name!r}; available: {sorted(prompt_templates)}"
        )

    import random
    import pandas as pd
    from datasets import Dataset as HFDataset

    df = pd.read_csv(data_path)

    def process_example(row, chosen_name=None):
        # Pick a random template when none was requested.
        if chosen_name is None:
            chosen_name = random.choice(list(prompt_templates))
        template = prompt_templates[chosen_name]

        sequence = row['aa_seq']
        # Only the prompt text is truncated; the model still receives the
        # complete sequence through 'dna_sequences'.
        shown_sequence = (sequence[:500] + "...") if len(sequence) > 500 else sequence
        prompt_text = template.format(
            aa_seq=shown_sequence,
            label=row['label'],
            name=row.get('name', ''),
            location=row.get('location', ''),
        )
        visible_prompt = prompt_text.split(split_marker)[0] if split_marker else prompt_text
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'protein', 'text': None},
                        {'type': 'text', 'text': visible_prompt},
                    ],
                },
            ],
            'dna_sequences': [sequence],
            'answer': str(row['label']),
            'template_used': chosen_name,
            'metadata': {
                'name': row.get('name', ''),
                'location': row.get('location', ''),
                'unique_id': row.get('unique_id', ''),
                'pdb_hash': row.get('pdb_hash', ''),
                'full_prompt': prompt_text,
            },
        }

    print("template_name")
    print(template_name)
    dataset = HFDataset.from_list(
        [process_example(row, template_name) for _, row in df.iterrows()]
    )

    if len(dataset) > 50:
        splits = dataset.train_test_split(test_size=0.1, seed=42)
        # train_test_split names the held-out part 'test'; expose it as 'val'
        # too so both branches offer the same key to callers.
        splits['val'] = splits['test']
        return splits

    return {
        'train': dataset,
        'val': dataset.select(range(min(5, len(dataset)))),
    }
def get_gsm8k_questions(question_prompt: str) -> Dataset:
    """Load ``openai/gsm8k`` (``main`` config) and reshape it into the GRPO prompt format.

    Every example receives the same three hard-coded DNA sequences, one
    ``dna`` placeholder per sequence and a fixed text request; the ``answer``
    is the part of the GSM8K answer after ``####`` (via
    ``extract_hash_answer``). ``question_prompt`` is accepted but not used.
    """
    placeholder_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]

    def to_grpo_example(x):
        dna_slots = [{'type': 'dna', 'text': None} for _ in placeholder_sequences]
        text_slot = {'type': 'text', 'text': 'Give me a short introduction to large language model.'}
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': dna_slots + [text_slot],
                },
            ],
            'dna_sequences': list(placeholder_sequences),
            'answer': extract_hash_answer(x['answer']),
        }

    gsm8k = load_dataset('openai/gsm8k', 'main')  # type: ignore
    return gsm8k.map(to_grpo_example)  # type: ignore
# Reward functions
def format_correct_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the ``<think>...</think><answer>...</answer>`` layout.

    Scoring per completion (maximum 1.5):
      * +0.5 if both ``<think>`` and ``</think>`` occur;
      * +0.5 if both ``<answer>`` and ``</answer>`` occur;
      * +0.5 if the first occurrences appear in the order
        ``<think>`` < ``</think>`` < ``<answer>`` < ``</answer>``.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    rewards = []
    for completion in completions:
        response = completion[0]["content"]
        # str.find returns -1 for a missing tag, so one lookup per tag serves
        # both the presence checks and the ordering check.
        think_start = response.find("<think>")
        think_end = response.find("</think>")
        answer_start = response.find("<answer>")
        answer_end = response.find("</answer>")

        score = 0.0
        if think_start != -1 and think_end != -1:
            score += 0.5
        if answer_start != -1 and answer_end != -1:
            score += 0.5
        # Extra reward when the whole layout is in the right order.
        if (min(think_start, think_end, answer_start, answer_end) != -1
                and think_start < think_end < answer_start < answer_end):
            score += 0.5
        rewards.append(score)
    return rewards
def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose answer matches the reference label.

    The candidate answer is the content of ``<answer>...</answer>`` (or the
    whole completion when the tag is absent). Both candidate and reference are
    lower-cased and reduced to word characters before comparison.

    Rewards:
      * 1.0 -- the normalised reference occurs in the normalised candidate;
      * 0.5 -- only some word of the reference occurs in the candidate;
      * 0.0 -- otherwise.

    NOTE(review): matching is substring-based, so a reference of ``"1"`` also
    matches a candidate of ``"10"``; this is unchanged from the previous
    revision.

    Args:
        prompts: Unused.
        completions: Sequence of completions; text is read from
            ``completion[0]['content']``.
        answer: Reference answer(s). A list is indexed per completion (falling
            back to its first element when too short); any other value is used
            for every completion.
        **kwargs: Ignored.

    Returns:
        One reward per completion, in input order.
    """
    rewards = []
    for i, completion in enumerate(completions):
        response = completion[0]['content']
        # The tag literals are required: a tag-less pattern always matches "".
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        extracted_answer = (answer_match.group(1) if answer_match else response).strip()

        if isinstance(answer, list) and len(answer) > i:
            correct_answer = str(answer[i]).strip()
        elif isinstance(answer, list) and len(answer) > 0:
            correct_answer = str(answer[0]).strip()
        else:
            correct_answer = str(answer).strip()

        extracted_clean = re.sub(r'[^\w\d]', '', extracted_answer.lower())
        correct_clean = re.sub(r'[^\w\d]', '', correct_answer.lower())

        if not correct_clean:
            # "" is a substring of everything; only an empty candidate may
            # match an empty reference.
            rewards.append(1.0 if not extracted_clean else 0.0)
            continue

        if correct_clean in extracted_clean:
            rewards.append(1.0)  # full match
            continue

        # Tokenise the reference BEFORE whitespace is removed; splitting the
        # already-squashed string yields a single token and makes the
        # partial-match branch unreachable.
        correct_words = re.findall(r'\w+', correct_answer.lower())
        if any(word in extracted_clean for word in correct_words):
            rewards.append(0.5)  # partial match
        else:
            rewards.append(0.0)  # no match
    return rewards
def classification_specific_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function tailored to the protein-classification task.

    The candidate answer is the content of ``<answer>...</answer>`` (or the
    whole completion when the tag is absent). The score is the sum of the
    parts below, capped at 1.0:

      * +0.2 if the candidate mentions a classification keyword;
      * numeric reference: +0.8 if it occurs in the candidate, and +0.4 if
        some number in the candidate is within 1 of it;
      * textual reference: +0.8 on a case-insensitive substring match;
      * +0.2 if a ``<think>...</think>`` block holds more than 20 characters.

    Args:
        prompts: Unused.
        completions: Sequence of completions; text is read from
            ``completion[0]['content']``.
        answer: Reference answer(s). A list is indexed per completion (falling
            back to its first element when too short); any other value is used
            for every completion.
        **kwargs: Ignored.

    Returns:
        One reward in ``[0.0, 1.0]`` per completion, in input order.
    """
    classification_keywords = ('classification', 'class', 'category', 'type', 'function', 'family')
    rewards = []
    for i, completion in enumerate(completions):
        response = completion[0]['content']
        score = 0.0

        # The tag literals are required: a tag-less pattern always matches "".
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        extracted_answer = (answer_match.group(1) if answer_match else response).strip()

        if isinstance(answer, list) and len(answer) > i:
            correct_answer = str(answer[i]).strip()
        elif isinstance(answer, list) and len(answer) > 0:
            correct_answer = str(answer[0]).strip()
        else:
            correct_answer = str(answer).strip()

        lowered_answer = extracted_answer.lower()
        if any(keyword in lowered_answer for keyword in classification_keywords):
            score += 0.2

        if correct_answer.isdigit():
            if correct_answer in extracted_answer:
                score += 0.8
            # Proximity bonus. str.isdigit() accepts characters int() rejects
            # (e.g. superscript digits), hence the narrow ValueError guard.
            try:
                target = int(correct_answer)
                distances = [abs(int(num) - target) for num in re.findall(r'\d+', extracted_answer)]
            except ValueError:
                distances = []
            if distances and min(distances) <= 1:
                score += 0.4
        elif correct_answer.lower() in lowered_answer:
            score += 0.8

        # Reward a non-trivial reasoning block.
        think_content = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        if think_content and len(think_content.group(1).strip()) > 20:
            score += 0.2

        rewards.append(min(score, 1.0))  # never exceed 1.0
    return rewards
def repetition_penalty_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions with little repetition (lower repetition, higher reward).

    The analysed text is the content of ``<answer>...</answer>`` (or the whole
    completion when the tag is absent). Two rates are averaged:

      * word repetition: ``1 - unique_words / words`` on the lower-cased,
        whitespace-split text;
      * sentence repetition: the same ratio on ``'.'``-separated sentences
        (0 when there is at most one sentence).

    The reward is ``max(0, 1 - 2 * average_rate)``; empty text scores 0.0.

    Args:
        completions: Sequence of completions; text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        One reward in ``[0.0, 1.0]`` per completion, in input order.
    """
    rewards = []
    for completion in completions:
        response = completion[0]["content"]
        # The tag literals are required: a tag-less pattern always matches "",
        # which would make every completion look empty.
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        text_to_analyze = (answer_match.group(1) if answer_match else response).strip()

        words = text_to_analyze.lower().split()
        if not words:
            rewards.append(0.0)
            continue

        repetition_rate = 1.0 - (len(set(words)) / len(words))

        sentences = [s.strip() for s in text_to_analyze.split('.') if s.strip()]
        if len(sentences) > 1:
            sentence_repetition_rate = 1.0 - (len(set(sentences)) / len(sentences))
        else:
            sentence_repetition_rate = 0.0

        overall_repetition = (repetition_rate + sentence_repetition_rate) / 2
        # The factor 2 makes the penalty more pronounced.
        rewards.append(max(0.0, 1.0 - overall_repetition * 2))
    return rewards
def combined_reward_func(prompts, completions, answer,
                         format_weight=0.3, accuracy_weight=0.5, repetition_weight=0.2,
                         **kwargs) -> list[float]:
    """Weighted combination of the format, accuracy and repetition rewards.

    Args:
        prompts: Forwarded to ``accuracy_reward_func``.
        completions: Completions scored by all three component functions.
        answer: Reference answer(s), forwarded to ``accuracy_reward_func``.
        format_weight: Weight of ``format_correct_reward_func``.
        accuracy_weight: Weight of ``accuracy_reward_func``.
        repetition_weight: Weight of ``repetition_penalty_reward_func``.
        **kwargs: Forwarded to every component function.

    Returns:
        One combined reward per completion. Weights are rescaled to sum to 1
        when they do not already do so (within floating-point tolerance).

    Raises:
        ValueError: If the weights do not sum to a positive number.
    """
    import math

    total_weight = format_weight + accuracy_weight + repetition_weight
    if total_weight <= 0:
        raise ValueError(
            f"Reward weights must sum to a positive number, got {total_weight}"
        )
    # Tolerant comparison: sums such as 0.2 + 0.6 + 0.2 may differ from 1.0 by
    # one ulp, which would otherwise trigger a pointless renormalisation (and
    # a log line) on every call.
    if not math.isclose(total_weight, 1.0):
        format_weight /= total_weight
        accuracy_weight /= total_weight
        repetition_weight /= total_weight
        print(f"Normalized weights - Format: {format_weight:.3f}, Accuracy: {accuracy_weight:.3f}, Repetition: {repetition_weight:.3f}")

    format_rewards = format_correct_reward_func(completions, **kwargs)
    accuracy_rewards = accuracy_reward_func(prompts, completions, answer, **kwargs)
    repetition_rewards = repetition_penalty_reward_func(completions, **kwargs)

    return [
        format_weight * f_reward + accuracy_weight * a_reward + repetition_weight * r_reward
        for f_reward, a_reward, r_reward in zip(format_rewards, accuracy_rewards, repetition_rewards)
    ]
# Legacy reward functions, kept as alternatives
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Reward short answers: 0.5 when the extracted answer has at most 4 space-separated parts.

    The answer is obtained with ``extract_xml_answer`` from
    ``completion[0]['content']``. Splitting is on a single space, so an empty
    answer counts as one part and earns the reward.
    """
    rewards = []
    for completion in completions:
        extracted = extract_xml_answer(completion[0]['content'])
        part_count = len(extracted.split(' '))
        rewards.append(0.5 if part_count <= 4 else 0.0)
    return rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format.

    A completion earns 0.5 when it consists of a ``<think>`` line, a single
    reasoning line, a ``</think>`` line and a single trailing line, each
    terminated by a newline; otherwise it earns 0.0. ``.`` does not match
    newlines here, so multi-line reasoning does not qualify.

    NOTE(review): the tag literals were missing from the pattern in the
    previous revision, which left only the newline skeleton
    (``^\\n.*?\\n\\n.*?\\n$``). ``<think>``/``</think>`` is the smallest
    restoration consistent with that skeleton -- confirm whether an
    ``<answer>`` pair was also intended.
    """
    pattern = re.compile(r"^<think>\n.*?\n</think>\n.*?\n$")
    return [
        0.5 if pattern.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's ``completion[0]["content"]`` with ``count_xml``."""
    scores = []
    for completion in completions:
        scores.append(count_xml(completion[0]["content"]))
    return scores
def count_xml(text) -> float:
    """Give partial credit for well-placed reasoning tags.

    Adds 0.125 when ``"<think>\\n"`` occurs exactly once and another 0.125
    when ``"\\n</think>\\n"`` occurs exactly once (maximum 0.25).

    NOTE(review): the tag literals were missing in the previous revision,
    which reduced both checks to counting newlines. ``<think>``/``</think>``
    is the presumed intent given the surviving newline placement -- confirm.
    """
    count = 0.0
    if text.count("<think>\n") == 1:
        count += 0.125
    if text.count("\n</think>\n") == 1:
        count += 0.125
    return count
@dataclass
class Blip2ModelConfig(ModelConfig):
    # Model-side options for the BLIP2 GRPO script, layered on top of the
    # inherited ModelConfig fields. Each field's "help" metadata is its
    # user-facing description.

    # BLIP2 specific configuration
    model_name_or_path: str = field(default="blip2-model", metadata={"help": "Model checkpoint for weights initialization."})
    # BLIP2 Architecture parameters
    # (bert_name's default is a placeholder path and has to be overridden.)
    bert_name: str = field(default="/path/to/bert", metadata={"help": "BERT model for Q-former"})
    num_query_token: int = field(default=8, metadata={"help": "Number of query tokens"})
    cross_attention_freq: int = field(default=2, metadata={"help": "Cross attention frequency"})
    plm_model: str = field(default="facebook/esm2_t30_150M_UR50D", metadata={"help": "Protein language model"})
    plm_tune: str = field(default="freeze", metadata={"help": "PLM tuning strategy"})
    llm_name: str = field(default="facebook/galactica-1.3b", metadata={"help": "Language model name"})
    llm_tune: str = field(default="lora", metadata={"help": "LLM tuning strategy"})
    qformer_tune: str = field(default="train", metadata={"help": "Q-former tuning strategy"})
    peft_dir: str = field(default="", metadata={"help": "PEFT directory"})
    # LoRA parameters
    lora_r: int = field(default=8, metadata={"help": "LoRA rank"})
    lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
    lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})
    # Training parameters
    # NOTE: "enbale" is a misspelling of "enable", but the name is forwarded
    # verbatim as a key by create_blip2_args_from_config, so renaming it here
    # alone would break that mapping.
    enbale_gradient_checkpointing: bool = field(default=False, metadata={"help": "Enable gradient checkpointing"})
    enable_flash: bool = field(default=False, metadata={"help": "Enable flash attention"})
    # Other parameters
    # cache_dir and freeze_dna_modules are not read anywhere in the visible
    # part of this file -- presumably consumed elsewhere; verify before removing.
    cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
    # When set, main() loads the model from this checkpoint via
    # Blip2Stage2.load_from_checkpoint.
    sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    freeze_dna_modules: bool = field(default=False, metadata={"help": "Freeze DNA/protein modules"})
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script with BLIP2.

    Covers data-source selection, the list of reward functions, the weights
    used by the "combined" reward, and prompt options.
    """
    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    # When set, main() loads this CSV instead of the default KEGG dataset.
    data_file_paths: str = field(
        default=None,
        metadata={"help": "Path to protein classification CSV file (format: name,aa_seq,label,location,unique_id,pdb_hash)"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    # NOTE(review): not referenced in the visible part of this file; the data
    # loaders here hard-code test_size=0.1 -- confirm whether it is used elsewhere.
    val_split_ratio: float = field(
        default=0.1,
        metadata={"help": "Ratio of validation split, default 0.1"},
    )
    reward_funcs: list[str] = field(
        # Option 1: use the combined reward function (recommended)
        default_factory=lambda: ["combined"],
        # Option 2: use the separate reward functions
        # default_factory=lambda: ["format_correct", "accuracy", "repetition_penalty"],
        # Option 3: use the protein-classification-specific reward
        # default_factory=lambda: ["format_correct", "classification_specific", "repetition_penalty"],
        metadata={"help": "List of reward functions. Available: 'combined', 'format_correct', 'accuracy', 'classification_specific', 'repetition_penalty', 'xmlcount', 'strict_format', 'less_than_4'"},
    )
    # Reward weight configuration (forwarded to combined_reward_func by main())
    format_weight: float = field(
        default=0.3,
        metadata={"help": "Weight for format correctness reward (used in combined reward)"}
    )
    accuracy_weight: float = field(
        default=0.5,
        metadata={"help": "Weight for accuracy reward (used in combined reward)"}
    )
    repetition_weight: float = field(
        default=0.2,
        metadata={"help": "Weight for repetition penalty reward (used in combined reward)"}
    )
    # Data processing parameters
    template_name: str = field(
        default="classification",
        metadata={"help": "Prompt template to use: 'classification', 'function_prediction', 'location_prediction'"}
    )
    # NOTE(review): not referenced in the visible part of this file;
    # get_custom_protein_data_with_prompts truncates the displayed sequence at
    # a hard-coded 500 characters -- confirm whether this field is used elsewhere.
    max_seq_length: int = field(
        default=1000,
        metadata={"help": "Maximum protein sequence length for display in prompt"}
    )
    # Selects get_custom_protein_data_with_prompts over
    # get_protein_classification_data in main() when a data file is given.
    use_custom_prompts: bool = field(
        default=True,
        metadata={"help": "Whether to use custom protein-specific prompts"}
    )
# Maps the names accepted by the `reward_funcs` script argument to the reward
# callables defined in this file.
reward_funcs_registry = {
    # New three-in-one reward function
    "combined": combined_reward_func,  # format + accuracy + repetition combined
    # Separate reward functions
    "format_correct": format_correct_reward_func,  # format correctness
    "accuracy": accuracy_reward_func,  # answer accuracy
    "repetition_penalty": repetition_penalty_reward_func,  # repetition penalty
    "classification_specific": classification_specific_reward_func,  # protein-classification specific
    # Original reward functions (kept as alternatives)
    "xmlcount": xmlcount_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
}
def get_vlm_module(model_name_or_path):
    """Return the module class used to drive the model.

    ``model_name_or_path`` is accepted but ignored: this implementation always
    selects ``Blip2DNAModule``.
    """
    selected_module_cls = Blip2DNAModule
    return selected_module_cls
def create_blip2_args_from_config(model_args):
    """Copy the BLIP2-related attributes of ``model_args`` into a plain dict.

    The dict keys equal the attribute names, in the order listed below.
    A missing attribute raises ``AttributeError``.
    """
    forwarded_fields = (
        'bert_name',
        'num_query_token',
        'cross_attention_freq',
        'plm_model',
        'plm_tune',
        'llm_name',
        'llm_tune',
        'qformer_tune',
        'peft_dir',
        'lora_r',
        'lora_alpha',
        'lora_dropout',
        'enbale_gradient_checkpointing',  # spelling matches the config field
        'enable_flash',
    )
    return {name: getattr(model_args, name) for name in forwarded_fields}
def _prep_for_training(model, training_args):
    """Build the LoRA configuration for training.

    Reads ``lora_r``, ``lora_alpha`` and ``lora_dropout`` from
    ``training_args`` and returns a ``LoraConfig`` targeting the attention
    projections and feed-forward layers listed below. ``model`` is accepted
    but not modified or read here.
    """
    lora_settings = dict(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
        init_lora_weights="gaussian",
        bias="none",
        task_type="CAUSAL_LM",
    )
    return LoraConfig(**lora_settings)
def main(script_args, training_args, model_args):
    """Build the BLIP2 model, reward functions and dataset, then run GRPO training.

    Args:
        script_args: Parsed ``GRPOScriptArguments`` (data source, reward
            function names and weights, prompt template name).
        training_args: Trainer configuration; ``save_safetensors`` is forced
            to ``False`` on it.
        model_args: Parsed ``Blip2ModelConfig``.

    Raises:
        KeyError: If an unknown reward function name is requested, or if
            evaluation is enabled but the dataset has no ``val``,
            ``validation`` or ``test`` split.
    """
    print(training_args.output_dir)
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Select PyTorch-format checkpoints before the trainer is constructed so
    # the setting is already in place when the trainer reads its arguments.
    training_args.save_safetensors = False

    # Build the model exactly once: from the SFT checkpoint when one is given,
    # from scratch otherwise (avoids instantiating a model that is discarded).
    blip2_args = create_blip2_args_from_config(model_args)
    if model_args.sft_checkpoint is not None:
        print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
        model = Blip2Stage2.load_from_checkpoint(
            model_args.sft_checkpoint, strict=False, args=blip2_args, map_location='cpu'
        )
    else:
        model = Blip2Stage2(blip2_args)

    # The combined reward needs the configured weights bound in. A named
    # function (not functools.partial) keeps a ``__name__`` for reporting.
    def weighted_combined_reward(prompts, completions, answer, **kwargs):
        return combined_reward_func(
            prompts, completions, answer,
            format_weight=script_args.format_weight,
            accuracy_weight=script_args.accuracy_weight,
            repetition_weight=script_args.repetition_weight,
            **kwargs
        )

    reward_funcs = []
    for func_name in script_args.reward_funcs:
        if func_name == "combined":
            reward_funcs.append(weighted_combined_reward)
        elif func_name in reward_funcs_registry:
            reward_funcs.append(reward_funcs_registry[func_name])
        else:
            raise KeyError(
                f"Unknown reward function {func_name!r}; available: {sorted(reward_funcs_registry)}"
            )
    print("reward_funcs:", [getattr(func, '__name__', 'weighted_combined_reward') for func in reward_funcs])
    print(f"Reward weights - Format: {script_args.format_weight}, Accuracy: {script_args.accuracy_weight}, Repetition: {script_args.repetition_weight}")

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    # NOTE(review): the template is fetched but not used below; the call is
    # kept in case it matters -- confirm and remove.
    question_prompt = vlm_module_cls.get_question_template()

    # Load dataset based on data source
    if script_args.data_file_paths and script_args.use_custom_prompts:
        print(f"Loading custom protein data from: {script_args.data_file_paths}")
        dataset = get_custom_protein_data_with_prompts(
            data_path=script_args.data_file_paths,
            prompt_templates=prompt_templates,
            template_name=script_args.template_name
        )
    elif script_args.data_file_paths:
        print(f"Loading protein data from: {script_args.data_file_paths}")
        dataset = get_protein_classification_data(
            data_path=script_args.data_file_paths
        )
    else:
        print("Using default KEGG dataset")
        dataset = get_kegg_questions()

    print("Dataset loaded:")
    print(f"Train size: {len(dataset['train'])}")
    print(f"Val size: {len(dataset.get('val', []))}")

    # Print one sample for a quick sanity check.
    if len(dataset['train']) > 0:
        print("\nSample data:")
        sample = dataset['train'][0]
        print(f"Prompt type: {type(sample.get('prompt', 'Unknown'))}")
        print(f"DNA sequences count: {len(sample.get('dna_sequences', []))}")
        print(f"Answer: {sample.get('answer', 'N/A')}")
        if 'metadata' in sample:
            print(f"Metadata: {sample['metadata']}")
        # ``or ['']`` also covers an empty list, which would otherwise raise
        # IndexError on ``[0]``.
        sequences = sample.get('dna_sequences') or ['']
        print(f"First 100 chars of sequence: {sequences[0][:100]}...")

    # Resolve the evaluation split only when evaluation is enabled. The
    # loaders in this file name the held-out split 'val' in some branches and
    # 'test' (train_test_split output) in others, so accept the usual aliases.
    eval_dataset = None
    if training_args.eval_strategy != "no":
        for split_name in ("val", "validation", "test"):
            if split_name in dataset:
                eval_dataset = dataset[split_name]
                break
        else:
            raise KeyError(
                "Evaluation is enabled but the dataset has no 'val', 'validation' or 'test' split"
            )

    # Custom callback to handle saving with PyTorch's native mechanism
    custom_save_callback = SaveWithPyTorchCallback()

    # Initialize the BLIP2 GRPO trainer
    trainer = Blip2GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
        attn_implementation=getattr(model_args, 'attn_implementation', 'flash_attention_2'),
        torch_dtype=getattr(model_args, 'torch_dtype', 'bfloat16'),
        callbacks=[custom_save_callback],
    )

    # Train the model
    trainer.train()
if __name__ == "__main__":
    # Show which GPUs this process may use (prints None when the variable is unset).
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    # Parse command line / config into the three argument dataclasses; the
    # result tuple follows the order in which the classes are listed here.
    parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, Blip2ModelConfig))
    # NOTE: these names are bound at module scope. get_custom_protein_data_with_prompts
    # refers to a module-level `script_args`, so do not rename it without
    # updating that function.
    script_args, training_args, model_args = parser.parse_args_and_config()
    # Ensure we use PyTorch's save mechanism instead of safetensors
    training_args.save_safetensors = False
    main(script_args, training_args, model_args)
# Usage example:
"""
使用你的蛋白质数据进行训练:
1. 准备CSV文件,格式:name,aa_seq,label,location,unique_id,pdb_hash
2. 运行训练:
python blip2_reason.py \
--data_file_paths /path/to/your/protein_data.csv \
--reward_funcs combined \
--format_weight 0.2 \
--accuracy_weight 0.6 \
--repetition_weight 0.2 \
--use_custom_prompts \
--template_name classification \
--max_seq_length 1000 \
--output_dir ./output \
--per_device_train_batch_size 4 \
--num_train_epochs 3 \
--learning_rate 1e-5
3. 或者使用分离的奖励函数:
python blip2_reason.py \
--data_file_paths /path/to/your/protein_data.csv \
--reward_funcs format_correct classification_specific repetition_penalty \
--use_custom_prompts \
--template_name function_prediction
数据格式示例:
P0DM40,MLRVVVESASINPPLSTTPKAFVTVYFRDMMKRTRVEEGHDPIWNETLIWHLWNQPLENDSFLKVILQDSVSKKKERFIGLATVPLKRLAQRPKEVMFVRDLILLNHSMKPTNCTVTLHVAQIYDQDTEMTGNEELLGSTVNEVTQKKLMVSGLPMHRALASKPQHFQVRVKVFEARQLLGNNIKPVVKVNIADQQHLTRIKMGNNPFFNEIFFQNFHEVPAKFFEENISIEVVDSAASRSKAEIGRFQTDIGFIYHSPGHTLLRKWLGLCQRNKTTSGVRGYLKVTICALGVGDQALVDQKLPYEQNTRVQIFKSKEVPVSLAYLQFFIYCAEDLHFGTHKSATPVLEVELIGDKLRTKPQNPSDNPIWNQILTFQIQLPCLSSYIKFRVMDCSKYKCQDEIGSASLCLSQISSTGEEIQGMYSGFLPCFGPSFLTLRGGKKPPFRTSEEGTCIMDAVQHGLAYRGRIFVEIVTKIKSQQDSVMKDLSQEVTQVEMQYYRQKYGLCVIFLSCTMMPKFKDLIQFEVSMGHYGNKTDPNYKPLVSTTQYSPVIYDGTTYHYVPWYNTKPVVAVTSNWEDVGFRMNCLNLLHITRDRLKTNLDILKSIRNPRDPALLQQWEKLLKELQEDCRRPLPCMTDQPRANSLDRNKWQLRSQLLQQLAQMAKEAKPVNMVGTAKEWLHRLNAVIPEPQESLPDVLIWLMSRQQRVAYARVPAHTVLFSPAGPLSSGKFCGKIQNILLQYPEGEGQDTFPASLRVCMWLGNVKYSKNLKLLQQGSMVVYAETYENQAKTRDDWGQQGLYHCPNFSDVMGRKALPKTDFKAPPGWHWKDDWVVEPQRRLLLDIDINKSQVLEEVYENQLRNATGAWVPAAIPNTDVNGQPVEALENVKCPQGWHFKKNWIVKLNHAVDSEGWEYGVGIPPSGLPQIWNSVEKTYHSCRRRRWVRVRFRNHKELGQERSQEQETLSFLQMQDLSEEGKEGWEYGTFDSRFHLDPQPTSRFRRRCWHRQLAPNKDRGVASIFLLEGSLAVEQKDQPRKEMEKTRSWQPWKDLRHTPEDPRIPTTPFIYYILNKPHYYQLFCYIYQARNLMYNQILTFQEPFIQVVFLNHSLCTQTLRSSAAPTWSQSIIFQHLLLFEDPKDTRENPPLVVLELWQHDSRGNKILWGRSMWPPVVWLGLQDWVFTPLRWHPLVRELGEEEGEILASCELILETQKLKELHPPILSIPCKDGIYLLPKNIQPTMKMMAIEIMAWGLRNMTKVRYPQLLLECGGESLKTEPISNFQENPNFPTSTFFFTVFMPLEETHAQPLVVKVVDNQEYGQQIVVGQANIDFLQPYFCDPWSLNYTTVKLPTLSVKKPDTFLDFVYKKFWFDSSKDEEVYEEEVDWWSKLFWATGDADKSLNYNHKSYHTLKVYDCELEAVLTFKGLQDFCQTFKLYQEKPKVDSPVVGEFKGLFRIYPFPEDPEAPKPPRQFSAWPEIEDFPQMCLVRVYLIRAINLQPQDYNGLCDPYVILKLGQTKLGSRDSYYPNTLDPIFGMMYELTCNIPLEKDLEIQLFDFDLITADDEIGSTVIDLENRLLSGFGARCGLSKSYCKSGPFKWRDQMTPSYLLYRYAKQKGLPPPVFDLEGDSLYYNGETFKLQSFESAPPTYKHLGPKKERLALYILNTQGLVPEHVETRTLHSNSQPGIDQGKIQMWVDIFPKMLGPPGPQVNISPRKPKRYQLRCIIWSTAEVDLVQETFSKEKMSDIYVKGWLFGLEEDTQKTDVHYHSLTGEATFNWRFIFTMDYLTTERACVQSQKDYIWSLDPTSTKFPARLMIQIWDNDFFSPDDFLGVLELDLSDMPLPAQNIKQCSLKMMETDSKWPFTPQKRISLFKKTNVTGWWPCQVLDGDKWRLSGKVKMTLEMLSEREALIRPAGRGQSEPNQFPMLHPPERNDSFLLWYQSPIKNFCYAVCKRYRSKIICLVVTLVIGFILLNFVYSAPSYFAMNWIKPQ
LRLSSPIKIVNLIGTVNTSNINSSILTMEGSTYHASHVFPEAPAP,0,M,af67d99c09f74ea8af5004cc2906bbc5,d55cbc3d94bd9668d97a678b4a04176a
"""