import os
import re
import pathlib
from argparse import ArgumentParser
from typing import List, Dict, Optional
from dataclasses import dataclass, field
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoProcessor,
)
from datasets import load_dataset, DatasetDict
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
# Import BLIP2 modules
from model.blip2_stage2 import Blip2Stage2
from blip2_dna_module import Blip2DNAModule
from blip2_grpo_trainer import Blip2GRPOTrainer
from bioreason.trainer import DNALLMGRPOConfig
# Custom TrainerCallback to override the saving mechanism
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SaveWithPyTorchCallback(TrainerCallback):
"""Custom callback to save models with PyTorch's native save mechanism instead of safetensors"""
def on_save(self, args, state, control, **kwargs):
# Get the checkpoint folder
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
)
os.makedirs(checkpoint_folder, exist_ok=True)
# Save with PyTorch instead of safetensors
checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
model = kwargs.get("model")
# Get model unwrapped from accelerator etc.
unwrapped_model = model.module if hasattr(model, "module") else model
# Save using PyTorch directly
torch.save(unwrapped_model.state_dict(), checkpoint_path)
# For BLIP2, save the config from the LLM component
if hasattr(unwrapped_model, "blip2") and hasattr(unwrapped_model.blip2, "llm_model"):
if hasattr(unwrapped_model.blip2.llm_model, "config"):
unwrapped_model.blip2.llm_model.config.save_pretrained(checkpoint_folder)
elif hasattr(unwrapped_model.blip2.llm_model, "base_model") and hasattr(unwrapped_model.blip2.llm_model.base_model, "config"):
unwrapped_model.blip2.llm_model.base_model.config.save_pretrained(checkpoint_folder)
# Print info about what's being saved
print(f"Saved model checkpoint to {checkpoint_folder}")
lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k]
print(f"Checkpoint contains {len(lora_params)} LoRA parameters")
        # We've already written the checkpoint, so suppress the trainer's default (safetensors) save
        control.should_save = False
return control
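# Sketch of how a checkpoint written by this callback can be reloaded
# (illustrative only; the path and strict=False are assumptions, not part of this script):
#   model = Blip2Stage2(blip2_args)
#   state_dict = torch.load("outputs/checkpoint-500/pytorch_model.bin", map_location="cpu")
#   model.load_state_dict(state_dict, strict=False)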
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
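# Example (illustrative): extract_xml_answer("<think>steps</think><answer> BRCA1 </answer>")
# returns "BRCA1"; if no tags are present, the whole stripped text is returned.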
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
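# Example (illustrative): GSM8K gold answers end in "#### <number>", so
#   extract_hash_answer("... leaving 18 apples. #### 18")  -> "18"
#   extract_hash_answer("no marker here")                  -> None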
def get_kegg_questions() -> DatasetDict:
    data = load_dataset('wanglab/kegg', 'default')  # type: ignore
    num_dna_sequences = 2  # reference and variant sequence
data = data.map(lambda x: { # type: ignore
'prompt': [
{
'role': 'user',
'content': [
*({'type': 'dna', 'text': None} for _ in range(num_dna_sequences)),
{'type': 'text', 'text': x['question']},
],
},
],
'dna_sequences': [x['reference_sequence'], x['variant_sequence']],
'answer': x['answer'],
}) # type: ignore
return data
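# After the map, each example has roughly this shape (illustrative values):
#   {'prompt': [{'role': 'user',
#                'content': [{'type': 'dna', 'text': None},
#                            {'type': 'dna', 'text': None},
#                            {'type': 'text', 'text': '<question>'}]}],
#    'dna_sequences': ['<reference_sequence>', '<variant_sequence>'],
#    'answer': '<answer>'}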
def get_gsm8k_questions(question_prompt: str) -> DatasetDict:
    data = load_dataset('openai/gsm8k', 'main')  # type: ignore
example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
data = data.map(lambda x: { # type: ignore
'prompt': [
{
'role': 'user',
'content': [
*({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))),
                    {'type': 'text', 'text': 'Give me a short introduction to large language models.'}
]
},
],
'dna_sequences': [dna for dna in example_dna_sequences],
'answer': extract_hash_answer(x['answer']),
}) # type: ignore
return data # type: ignore
# Reward functions
def format_correct_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward function: check whether the output format is correct.
    Requires both <think>...</think> and <answer>...</answer> tags.
    """
    responses = [completion[0]["content"] for completion in completions]
    rewards = []
    for response in responses:
        score = 0.0
        # Check for the think tags
        if "<think>" in response and "</think>" in response:
            score += 0.5
        # Check for the answer tags
        if "<answer>" in response and "</answer>" in response:
            score += 0.5
        # Check that the tags appear in the correct order
        think_start = response.find("<think>")
        think_end = response.find("</think>")
        answer_start = response.find("<answer>")
        answer_end = response.find("</answer>")
        if (think_start != -1 and think_end != -1 and
                answer_start != -1 and answer_end != -1 and
                think_start < think_end < answer_start < answer_end):
            score += 0.5  # extra reward for a fully correct format
        rewards.append(score)
    return rewards
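# Scoring examples (illustrative):
#   "<think>...</think><answer>...</answer>"  -> 1.5 (both tag pairs, correct order)
#   "<answer>...</answer><think>...</think>"  -> 1.0 (both pairs, no order bonus)
#   "<think>...</think> plain text"           -> 0.5 (think tags only)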
def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Reward function: check answer accuracy.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []
    for i, response in enumerate(responses):
        # Extract the content of the <answer> tag
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if answer_match:
            extracted_answer = answer_match.group(1).strip()
        else:
            extracted_answer = response.strip()
        # Look up the reference answer
        if isinstance(answer, list) and len(answer) > i:
            correct_answer = str(answer[i]).strip()
        elif isinstance(answer, list) and len(answer) > 0:
            correct_answer = str(answer[0]).strip()
        else:
            correct_answer = str(answer).strip()
        # Compute the accuracy reward
        if correct_answer.lower() in extracted_answer.lower():
            rewards.append(1.0)  # full match
        elif any(word in extracted_answer.lower() for word in correct_answer.lower().split()):
            rewards.append(0.5)  # partial match
        else:
            rewards.append(0.0)  # no match
    return rewards
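# Matching examples (substring check first, then word overlap, case-insensitive; illustrative):
#   reference "BRCA1",          extracted "The affected gene is BRCA1."  -> 1.0
#   reference "BRCA1 deletion", extracted "a deletion was observed"      -> 0.5
#   reference "BRCA1",          extracted "no pathogenic variant"        -> 0.0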
def repetition_penalty_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward function: penalize repetition (lower repetition is better).
    Computes the fraction of repeated words and sentences in the text;
    the lower the repetition rate, the higher the reward.
    """
    responses = [completion[0]["content"] for completion in completions]
    rewards = []
    for response in responses:
        # Analyze only the <answer> section if present
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if answer_match:
            text_to_analyze = answer_match.group(1).strip()
        else:
            text_to_analyze = response.strip()
        # Tokenize and compute the word repetition rate
        words = text_to_analyze.lower().split()
        if len(words) == 0:
            rewards.append(0.0)
            continue
        unique_words = set(words)
        repetition_rate = 1.0 - (len(unique_words) / len(words))
        # Sentence repetition rate
        sentences = [s.strip() for s in text_to_analyze.split('.') if s.strip()]
        if len(sentences) > 1:
            unique_sentences = set(sentences)
            sentence_repetition_rate = 1.0 - (len(unique_sentences) / len(sentences))
        else:
            sentence_repetition_rate = 0.0
        # Overall repetition rate
        overall_repetition = (repetition_rate + sentence_repetition_rate) / 2
        # The lower the repetition, the higher the reward
        reward = max(0.0, 1.0 - overall_repetition * 2)  # factor of 2 sharpens the penalty
        rewards.append(reward)
    return rewards
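# Worked example: "the cat sat on the mat" has 6 words, 5 unique, so the word
# repetition rate is 1 - 5/6 ~= 0.167; a single sentence contributes 0.0, giving
# an overall rate of ~0.083 and a reward of max(0, 1 - 0.083 * 2) ~= 0.833.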
def combined_reward_func(prompts, completions, answer,
                         format_weight=0.3, accuracy_weight=0.5, repetition_weight=0.2,
                         **kwargs) -> list[float]:
    """
    Combined reward function: weighted sum of format, accuracy, and repetition rewards.
    """
    format_rewards = format_correct_reward_func(completions, **kwargs)
    accuracy_rewards = accuracy_reward_func(prompts, completions, answer, **kwargs)
    repetition_rewards = repetition_penalty_reward_func(completions, **kwargs)
    # Normalize the weights so they sum to 1
    total_weight = format_weight + accuracy_weight + repetition_weight
    if abs(total_weight - 1.0) > 1e-9:
        format_weight /= total_weight
        accuracy_weight /= total_weight
        repetition_weight /= total_weight
        print(f"Normalized weights - Format: {format_weight:.3f}, Accuracy: {accuracy_weight:.3f}, Repetition: {repetition_weight:.3f}")
    combined_rewards = []
    for f_reward, a_reward, r_reward in zip(format_rewards, accuracy_rewards, repetition_rewards):
        combined = (format_weight * f_reward +
                    accuracy_weight * a_reward +
                    repetition_weight * r_reward)
        combined_rewards.append(combined)
    return combined_rewards
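# Worked example with the default weights (0.3 / 0.5 / 0.2):
#   format = 1.5, accuracy = 0.5, repetition = 1.0
#   combined = 0.3 * 1.5 + 0.5 * 0.5 + 0.2 * 1.0 = 0.45 + 0.25 + 0.20 = 0.90
# Note that the format reward can reach 1.5, so the combined score is not capped at 1.0.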
# Keep some of the original reward functions as alternatives
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>\n.*?</think>\n<answer>\n.*?</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
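# Example of a matching completion (re.match without DOTALL, so each tag body
# must stay on one line; illustrative):
#   "<think>\nshort reasoning</think>\n<answer>\nfinal answer</answer>\n" -> 0.5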
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
def count_xml(text) -> float:
    count = 0.0
    if text.count("<answer>\n") == 1:
        count += 0.125
    if text.count("\n</answer>\n") == 1:
        count += 0.125
    return count
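# Example (illustrative): count_xml("<answer>\nfinal\n</answer>\n") -> 0.25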
@dataclass
class Blip2ModelConfig(ModelConfig):
# BLIP2 specific configuration
model_name_or_path: str = field(default="blip2-model", metadata={"help": "Model checkpoint for weights initialization."})
# BLIP2 Architecture parameters
bert_name: str = field(default="/path/to/bert", metadata={"help": "BERT model for Q-former"})
num_query_token: int = field(default=32, metadata={"help": "Number of query tokens"})
cross_attention_freq: int = field(default=2, metadata={"help": "Cross attention frequency"})
plm_model: str = field(default="facebook/esm2_t30_150M_UR50D", metadata={"help": "Protein language model"})
plm_tune: str = field(default="freeze", metadata={"help": "PLM tuning strategy"})
llm_name: str = field(default="facebook/galactica-1.3b", metadata={"help": "Language model name"})
llm_tune: str = field(default="lora", metadata={"help": "LLM tuning strategy"})
qformer_tune: str = field(default="train", metadata={"help": "Q-former tuning strategy"})
peft_dir: str = field(default="", metadata={"help": "PEFT directory"})
# LoRA parameters
lora_r: int = field(default=8, metadata={"help": "LoRA rank"})
lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})
# Training parameters
    enbale_gradient_checkpointing: bool = field(default=False, metadata={"help": "Enable gradient checkpointing"})  # (sic: misspelling kept for compatibility with the BLIP2 args)
enable_flash: bool = field(default=False, metadata={"help": "Enable flash attention"})
# Other parameters
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Path to model cache directory."})
    sft_checkpoint: Optional[str] = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
freeze_dna_modules: bool = field(default=False, metadata={"help": "Freeze DNA/protein modules"})
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script with BLIP2.
"""
dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    data_file_paths: Optional[str] = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    arrow_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
val_split_ratio: float = field(
default=0.0,
metadata={"help": "Ratio of validation split, default 0.0"},
)
    reward_funcs: list[str] = field(
        # Option 1: use the combined reward function (recommended)
        default_factory=lambda: ["combined"],
        # Option 2: use the three separate reward functions
        # default_factory=lambda: ["format_correct", "accuracy", "repetition_penalty"],
        # Option 3: a custom combination
        # default_factory=lambda: ["format_correct", "accuracy", "repetition_penalty", "xmlcount"],
        metadata={"help": "List of reward functions. Available: 'combined', 'format_correct', 'accuracy', 'repetition_penalty', 'xmlcount', 'strict_format', 'less_than_4'"},
    )
    # Reward function weight configuration
format_weight: float = field(
default=0.3,
metadata={"help": "Weight for format correctness reward (used in combined reward)"}
)
accuracy_weight: float = field(
default=0.5,
metadata={"help": "Weight for accuracy reward (used in combined reward)"}
)
repetition_weight: float = field(
default=0.2,
metadata={"help": "Weight for repetition penalty reward (used in combined reward)"}
)
reward_funcs_registry = {
    # New three-in-one reward function
    "combined": combined_reward_func,  # format + accuracy + repetition combined
    # Separate reward functions
    "format_correct": format_correct_reward_func,  # format correctness
    "accuracy": accuracy_reward_func,  # accuracy
    "repetition_penalty": repetition_penalty_reward_func,  # repetition penalty
    # Original reward functions (kept as alternatives)
    "xmlcount": xmlcount_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
}
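# Example invocations (hypothetical script name and flag values, for illustration;
# list-valued fields take space-separated entries under HF argument parsing):
#   python grpo_blip2.py --reward_funcs combined --format_weight 0.3 --accuracy_weight 0.5 --repetition_weight 0.2
#   python grpo_blip2.py --reward_funcs format_correct accuracy repetition_penalty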
def get_vlm_module(model_name_or_path):
# Always use BLIP2 module for this implementation
return Blip2DNAModule
def create_blip2_args_from_config(model_args):
"""Create BLIP2 args from model config"""
# Convert model config to the format expected by BLIP2
blip2_args = {
'bert_name': model_args.bert_name,
'num_query_token': model_args.num_query_token,
'cross_attention_freq': model_args.cross_attention_freq,
'plm_model': model_args.plm_model,
'plm_tune': model_args.plm_tune,
'llm_name': model_args.llm_name,
'llm_tune': model_args.llm_tune,
'qformer_tune': model_args.qformer_tune,
'peft_dir': model_args.peft_dir,
'lora_r': model_args.lora_r,
'lora_alpha': model_args.lora_alpha,
'lora_dropout': model_args.lora_dropout,
        'enbale_gradient_checkpointing': model_args.enbale_gradient_checkpointing,  # (sic, matches the field name)
'enable_flash': model_args.enable_flash,
}
return blip2_args
def _prep_for_training(model, training_args):
"""
Prepare BLIP2 model for training with LoRA.
"""
# The BLIP2 model should handle its own LoRA setup
# This is mainly for any additional preparation needed
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
lora_dropout=training_args.lora_dropout,
target_modules=target_modules,
init_lora_weights="gaussian",
bias="none",
task_type="CAUSAL_LM",
)
return lora_config
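# Sketch of how the returned config could be applied manually, assuming the LLM
# component is a standard HF causal-LM submodule and that the args object passed
# in exposes lora_r/lora_alpha/lora_dropout (as Blip2ModelConfig does). The BLIP2
# model normally performs its own LoRA setup, so this is illustrative only:
#   lora_config = _prep_for_training(model, model_args)
#   model.blip2.llm_model = get_peft_model(model.blip2.llm_model, lora_config)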
def main(script_args, training_args, model_args):
print(training_args.output_dir)
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
# Create BLIP2 model
blip2_args = create_blip2_args_from_config(model_args)
model = Blip2Stage2(blip2_args)
# Load checkpoint if specified
if model_args.sft_checkpoint is not None:
print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
if os.path.isdir(model_args.sft_checkpoint):
# Load Lightning checkpoint
checkpoint = torch.load(os.path.join(model_args.sft_checkpoint, "last.ckpt"), map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
print("Loaded Lightning checkpoint")
else:
# Load PyTorch state dict
checkpoint = torch.load(model_args.sft_checkpoint, map_location='cpu')
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
# Remove module prefix if present
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
result = model.load_state_dict(state_dict, strict=False)
print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
# Get reward functions with weights
reward_funcs = []
for func_name in script_args.reward_funcs:
if func_name == "combined":
            # Pass the configured weights to the combined reward function
def weighted_combined_reward(prompts, completions, answer, **kwargs):
return combined_reward_func(
prompts, completions, answer,
format_weight=script_args.format_weight,
accuracy_weight=script_args.accuracy_weight,
repetition_weight=script_args.repetition_weight,
**kwargs
)
reward_funcs.append(weighted_combined_reward)
else:
reward_funcs.append(reward_funcs_registry[func_name])
print("reward_funcs:", [func.__name__ if hasattr(func, '__name__') else 'weighted_combined_reward' for func in reward_funcs])
print(f"Reward weights - Format: {script_args.format_weight}, Accuracy: {script_args.accuracy_weight}, Repetition: {script_args.repetition_weight}")
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
print("using vlm module:", vlm_module_cls.__name__)
question_prompt = vlm_module_cls.get_question_template()
# Load dataset
dataset = get_kegg_questions()
print(dataset)
# Custom callback to handle saving with PyTorch's native mechanism
custom_save_callback = SaveWithPyTorchCallback()
# Initialize the BLIP2 GRPO trainer
trainer = Blip2GRPOTrainer(
model=model,
reward_funcs=reward_funcs,
args=training_args,
dna_module=vlm_module_cls(),
train_dataset=dataset['train'],
eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
attn_implementation=getattr(model_args, 'attn_implementation', 'flash_attention_2'),
torch_dtype=getattr(model_args, 'torch_dtype', 'bfloat16'),
callbacks=[custom_save_callback],
)
    # save_safetensors was already disabled in __main__ before the trainer was built
    # Train the model
    trainer.train()
if __name__ == "__main__":
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, Blip2ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
# Ensure we use PyTorch's save mechanism instead of safetensors
training_args.save_safetensors = False
main(script_args, training_args, model_args)