| | import os |
| | import re |
| | import pathlib |
| | from argparse import ArgumentParser |
| | from typing import List, Dict, Optional |
| | from dataclasses import dataclass, field |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from torch.optim import AdamW |
| | from torch.utils.data import DataLoader, Dataset |
| | from transformers import get_cosine_schedule_with_warmup, AutoTokenizer |
| |
|
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModelForCausalLM, |
| | AutoModelForMaskedLM, |
| | AutoProcessor, |
| | ) |
| |
|
| | from datasets import load_dataset, DatasetDict |
| |
|
| | from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training |
| | from transformers import BitsAndBytesConfig |
| |
|
| | import pytorch_lightning as pl |
| | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
| | from pytorch_lightning.loggers import WandbLogger |
| |
|
| | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config |
| |
|
| | |
| | from model.blip2_stage2 import Blip2Stage2 |
| | from blip2_dna_module import Blip2DNAModule |
| | from blip2_grpo_trainer import Blip2GRPOTrainer |
| | from bioreason.trainer import DNALLMGRPOConfig |
| |
|
| | |
| | from transformers import TrainerCallback, TrainerState, TrainerControl |
| | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR |
| |
|
| | from prompt_templates import prompt_templates |
| |
|
class SaveWithPyTorchCallback(TrainerCallback):
    """Trainer callback that writes checkpoints with ``torch.save``.

    On every save event the model's full ``state_dict`` is written to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/pytorch_model.bin``
    rather than relying on the safetensors serializer.
    """

    def on_save(self, args, state, control, **kwargs):
        """Serialize the current model and clear ``control.should_save``.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; only ``global_step`` is read.
            control: Trainer control object; its ``should_save`` flag is set
                to ``False`` before it is returned.
            **kwargs: Must carry the model under the ``"model"`` key.

        Returns:
            The ``control`` object that was passed in.

        Raises:
            ValueError: If no model was supplied in ``kwargs``.
        """
        model = kwargs.get("model")
        if model is None:
            # Fail with an explicit message instead of an AttributeError on None.
            raise ValueError("SaveWithPyTorchCallback.on_save requires a 'model' keyword argument")

        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")

        # Distributed wrappers expose the real model as ``.module``.
        unwrapped_model = getattr(model, "module", model)

        # Build the state dict once and reuse it for saving and for the
        # LoRA summary below.
        state_dict = unwrapped_model.state_dict()
        torch.save(state_dict, checkpoint_path)

        # Store the language-model config next to the weights when the model
        # exposes one (directly, or through a wrapping ``base_model``).
        llm_model = getattr(getattr(unwrapped_model, "blip2", None), "llm_model", None)
        if llm_model is not None:
            if hasattr(llm_model, "config"):
                llm_model.config.save_pretrained(checkpoint_folder)
            elif hasattr(llm_model, "base_model") and hasattr(llm_model.base_model, "config"):
                llm_model.base_model.config.save_pretrained(checkpoint_folder)

        print(f"Saved model checkpoint to {checkpoint_folder}")
        lora_param_count = sum("lora" in key for key in state_dict)
        print(f"Checkpoint contains {lora_param_count} LoRA parameters")

        control.should_save = False
        return control
| |
|
def extract_xml_answer(text: str) -> str:
    """Return the answer portion of a model response.

    Resolution order:

    1. the content of the first ``<answer>...</answer>`` pair;
    2. otherwise, everything after the last ``</think>`` tag;
    3. otherwise, the whole text.

    The result is always stripped of surrounding whitespace.
    """
    tagged = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tagged is not None:
        return tagged.group(1).strip()

    # rpartition yields an empty separator when "</think>" is absent.
    _, closing_tag, tail = text.rpartition("</think>")
    if closing_tag:
        return tail.strip()
    return text.strip()
| |
|
def extract_classification_answer(text: str) -> str:
    """Extract a class label from a model response.

    When the response contains an ``<answer>...</answer>`` pair, the answer
    text is scanned for a number, preferring one introduced by
    ``Classification:``, ``Class:``, ``Label:`` or ``Prediction:`` and
    falling back to the first run of digits. If no number is present the
    stripped answer text is returned. Without answer tags the result of
    ``extract_xml_answer`` is returned.
    """
    tagged = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tagged is None:
        return extract_xml_answer(text)

    answer_content = tagged.group(1).strip()

    # Ordered from most to least specific; the first pattern that matches wins.
    label_patterns = (
        r"[Cc]lassification:\s*(\d+)",
        r"[Cc]lass:\s*(\d+)",
        r"[Ll]abel:\s*(\d+)",
        r"[Pp]rediction:\s*(\d+)",
        r"(\d+)",
    )
    candidates = (re.search(pattern, answer_content) for pattern in label_patterns)
    first_hit = next((hit for hit in candidates if hit), None)
    if first_hit is None:
        return answer_content
    return first_hit.group(1)
| |
|
| | def extract_hash_answer(text: str) -> str | None: |
| | if "####" not in text: |
| | return None |
| | return text.split("####")[1].strip() |
| |
|
def get_kegg_questions() -> Dataset:
    """Load the ``wanglab/kegg`` dataset in the prompt format used here.

    Kept as the fallback data source for the protein loaders.

    Every example is mapped to:

    * ``prompt`` - a single user turn holding two ``dna`` placeholders
      followed by the question text;
    * ``dna_sequences`` - ``[reference_sequence, variant_sequence]``;
    * ``answer`` - the example's original answer.

    Returns:
        The mapped dataset returned by ``load_dataset(...).map(...)``. If
        loading or mapping fails, the error is printed and a plain dict with
        empty ``train`` and ``val`` datasets is returned instead.
    """
    num_dna_sequences = 2

    def to_prompt_example(example):
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        *({'type': 'dna', 'text': None} for _ in range(num_dna_sequences)),
                        {'type': 'text', 'text': example['question']},
                    ],
                },
            ],
            'dna_sequences': [example['reference_sequence'], example['variant_sequence']],
            'answer': example['answer'],
        }

    try:
        data = load_dataset('wanglab/kegg', 'default')
        return data.map(to_prompt_example)
    except Exception as e:
        # Deliberate best-effort: report the failure and hand back empty splits.
        print(f"Failed to load KEGG dataset: {e}")

        # Aliased so the local import does not shadow the module-level
        # ``Dataset`` name inside this function.
        from datasets import Dataset as HFDataset
        empty_data = {
            'prompt': [],
            'dna_sequences': [],
            'answer': [],
        }
        dataset = HFDataset.from_dict(empty_data)
        return {'train': dataset, 'val': dataset}
| |
|
def get_protein_classification_data(data_path: str = None, prompt_template: str = None) -> Dataset:
    """Load a protein classification CSV and convert it to prompt examples.

    Expected CSV columns: ``name,aa_seq,label,location,unique_id,pdb_hash``.
    ``aa_seq`` and ``label`` are required; the others default to ``''``.

    Args:
        data_path: Path to a ``.csv`` file. When ``None`` the KEGG fallback
            from ``get_kegg_questions`` is returned instead.
        prompt_template: ``str.format`` template that may reference
            ``aa_seq``, ``name``, ``location`` and ``unique_id``. A built-in
            classification prompt is used when ``None``.

    Returns:
        A mapping with ``train`` and ``val`` splits. For more than 100 rows
        this is the ``DatasetDict`` from ``train_test_split`` (10% held out,
        seed 42), which also keeps its ``test`` key; otherwise it is a plain
        dict whose ``val`` split is the first ``min(10, len)`` training rows.

    Raises:
        ValueError: If ``data_path`` does not end in ``.csv``.
    """
    import pandas as pd
    from datasets import Dataset as HFDataset

    if data_path is None:
        return get_kegg_questions()

    if not data_path.endswith('.csv'):
        raise ValueError(f"Unsupported file format: {data_path}")
    df = pd.read_csv(data_path)

    if prompt_template is None:
        prompt_template = """
Please analyze the following protein sequence and predict its classification.

Protein sequence: <protein>{aa_seq}</protein>

Question: What is the classification of this protein sequence?

Please provide your reasoning in <think></think> tags and your final answer in <answer></answer> tags.
"""

    def process_example(row):
        prompt_text = prompt_template.format(
            aa_seq=row['aa_seq'],
            name=row.get('name', ''),
            location=row.get('location', ''),
            unique_id=row.get('unique_id', ''),
        )
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'protein', 'text': None},
                        {'type': 'text', 'text': prompt_text},
                    ],
                },
            ],
            # The key is named 'dna_sequences' for the trainer even though
            # the payload here is an amino-acid sequence.
            'dna_sequences': [row['aa_seq']],
            'answer': str(row['label']),
            'metadata': {
                'name': row.get('name', ''),
                'location': row.get('location', ''),
                'unique_id': row.get('unique_id', ''),
                'pdb_hash': row.get('pdb_hash', ''),
            },
        }

    processed_data = [process_example(row) for _, row in df.iterrows()]
    dataset = HFDataset.from_list(processed_data)

    if len(dataset) > 100:
        splits = dataset.train_test_split(test_size=0.1, seed=42)
        # train_test_split names the held-out split 'test'; expose it as
        # 'val' too so both branches offer the same keys to callers.
        splits['val'] = splits['test']
        return splits

    return {
        'train': dataset,
        'val': dataset.select(range(min(10, len(dataset)))),
    }
| |
|
def get_custom_protein_data_with_prompts(data_path: str = None,
                                         prompt_templates: Dict[str, str] = None,
                                         template_name: Optional[str] = None) -> Dataset:
    """Load a protein CSV and build prompts from a set of named templates.

    Args:
        data_path: Path to a CSV file with at least ``aa_seq`` and ``label``
            columns. When ``None`` the KEGG fallback from
            ``get_kegg_questions`` is returned instead.
        prompt_templates: Mapping of template name to ``str.format``
            template. Templates may reference ``aa_seq`` (truncated to 500
            characters plus ``...`` when longer), ``label``, ``name`` and
            ``location``.
        template_name: Name of the template applied to every row. When
            ``None`` a template is drawn at random (unseeded) for each row.

    Returns:
        A mapping with ``train`` and ``val`` splits. For more than 50 rows
        this is the ``DatasetDict`` from ``train_test_split`` (10% held out,
        seed 42), which also keeps its ``test`` key; otherwise it is a plain
        dict whose ``val`` split is the first ``min(5, len)`` training rows.

    Raises:
        ValueError: If ``prompt_templates`` is missing or empty.
        KeyError: If ``template_name`` is not a key of ``prompt_templates``.
    """
    import random

    import pandas as pd
    from datasets import Dataset as HFDataset

    if data_path is None:
        return get_kegg_questions()

    if not prompt_templates:
        raise ValueError("prompt_templates must be a non-empty mapping of template name to template text")
    if template_name is not None and template_name not in prompt_templates:
        raise KeyError(
            f"Unknown template_name {template_name!r}; available: {sorted(prompt_templates)}"
        )

    df = pd.read_csv(data_path)
    available_templates = list(prompt_templates)

    def process_example(row, chosen_name=None):
        if chosen_name is None:
            chosen_name = random.choice(available_templates)
        template = prompt_templates[chosen_name]

        sequence = row['aa_seq']
        # Only the prompt text is shortened; the model still receives the
        # full sequence through 'dna_sequences'.
        shown_sequence = sequence[:500] + "..." if len(sequence) > 500 else sequence
        prompt_text = template.format(
            aa_seq=shown_sequence,
            label=row['label'],
            name=row.get('name', ''),
            location=row.get('location', ''),
        )

        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': [
                        {'type': 'protein', 'text': None},
                        # Only the text before the first '<protein>' marker is
                        # sent; the complete prompt is kept in the metadata.
                        {'type': 'text', 'text': prompt_text.split('<protein>')[0]},
                    ],
                },
            ],
            'dna_sequences': [sequence],
            'answer': str(row['label']),
            'template_used': chosen_name,
            'metadata': {
                'name': row.get('name', ''),
                'location': row.get('location', ''),
                'unique_id': row.get('unique_id', ''),
                'pdb_hash': row.get('pdb_hash', ''),
                'full_prompt': prompt_text,
            },
        }

    print("template_name")
    print(template_name)
    processed_data = [process_example(row, template_name) for _, row in df.iterrows()]

    dataset = HFDataset.from_list(processed_data)

    if len(dataset) > 50:
        splits = dataset.train_test_split(test_size=0.1, seed=42)
        # train_test_split names the held-out split 'test'; expose it as
        # 'val' too so both branches offer the same keys to callers.
        splits['val'] = splits['test']
        return splits

    return {
        'train': dataset,
        'val': dataset.select(range(min(5, len(dataset)))),
    }
| |
|
def get_gsm8k_questions(question_prompt: str) -> Dataset:
    """Load ``openai/gsm8k`` (``main`` config) with placeholder DNA inputs.

    Every example receives the same three hard-coded DNA sequences and a
    fixed prompt text; ``answer`` is the text after the ``####`` marker of the
    original answer, as returned by ``extract_hash_answer``.

    ``question_prompt`` is accepted but not used by this function.
    """
    example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]

    def build_example(example):
        # One 'dna' placeholder per example sequence, then the text turn.
        content = [{'type': 'dna', 'text': None} for _ in example_dna_sequences]
        content.append(
            {'type': 'text', 'text': 'Give me a short introduction to large language model.'}
        )
        return {
            'prompt': [
                {
                    'role': 'user',
                    'content': content,
                },
            ],
            'dna_sequences': list(example_dna_sequences),
            'answer': extract_hash_answer(example['answer']),
        }

    return load_dataset('openai/gsm8k', 'main').map(build_example)
| |
|
| | |
def format_correct_reward_func(completions, **kwargs) -> list[float]:
    """Score how closely each completion follows the think/answer format.

    Each completion earns:

    * 0.5 when both ``<think>`` and ``</think>`` are present;
    * 0.5 when both ``<answer>`` and ``</answer>`` are present;
    * 0.5 more when the first occurrences of all four tags appear in the
      order ``<think>``, ``</think>``, ``<answer>``, ``</answer>``.

    The maximum score is therefore 1.5.

    Args:
        completions: Sequence whose items expose the response text at
            ``item[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    tags = ("<think>", "</think>", "<answer>", "</answer>")
    rewards = []

    for completion in completions:
        response = completion[0]["content"]
        think_open, think_close, answer_open, answer_close = (
            response.find(tag) for tag in tags
        )
        has_think = think_open != -1 and think_close != -1
        has_answer = answer_open != -1 and answer_close != -1

        score = 0.0
        if has_think:
            score += 0.5
        if has_answer:
            score += 0.5
        if has_think and has_answer and think_open < think_close < answer_open < answer_close:
            score += 0.5
        rewards.append(score)

    return rewards
| |
|
def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score each completion by whether it contains the reference answer.

    The candidate is the content of the first ``<answer>...</answer>`` pair,
    or the whole response when the tags are absent. Scoring:

    * 1.0 - the reference is contained in the candidate. Plain integer
      references must match a whole number of the candidate (``1`` does not
      match ``10``); every other reference is compared after lower-casing
      and removing all non-word characters.
    * 0.5 - no full match, but at least one word of the reference occurs as
      a whole word in the candidate.
    * 0.0 - otherwise. An empty reference only matches an empty candidate.

    Args:
        prompts: Unused; present for the reward-function call signature.
        completions: Sequence whose items expose the response text at
            ``item[0]['content']``.
        answer: Either one reference per completion (list or tuple; the
            first entry is reused when the sequence is too short) or a
            single value applied to every completion.
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    rewards = []

    for i, completion in enumerate(completions):
        response = completion[0]['content']

        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        extracted_answer = (answer_match.group(1) if answer_match else response).strip()

        if isinstance(answer, (list, tuple)) and len(answer) > i:
            correct_answer = str(answer[i]).strip()
        elif isinstance(answer, (list, tuple)) and len(answer) > 0:
            correct_answer = str(answer[0]).strip()
        else:
            correct_answer = str(answer).strip()

        extracted_clean = re.sub(r'\W', '', extracted_answer.lower())
        correct_clean = re.sub(r'\W', '', correct_answer.lower())

        if not correct_clean:
            # A substring test against '' is always true, so handle the
            # empty reference explicitly.
            rewards.append(1.0 if not extracted_clean else 0.0)
            continue

        if correct_answer.isdecimal():
            # Compare integer labels as whole numbers, not as substrings.
            predicted_numbers = {int(number) for number in re.findall(r'\d+', extracted_answer)}
            full_match = int(correct_answer) in predicted_numbers
        else:
            full_match = correct_clean in extracted_clean

        if full_match:
            rewards.append(1.0)
            continue

        # Partial credit is computed on the un-collapsed text: once
        # whitespace is removed the reference can no longer be split into words.
        reference_words = re.findall(r'\w+', correct_answer.lower())
        extracted_words = set(re.findall(r'\w+', extracted_answer.lower()))
        if any(word in extracted_words for word in reference_words):
            rewards.append(0.5)
        else:
            rewards.append(0.0)

    return rewards
| |
|
def classification_specific_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function tailored to the protein classification task.

    The candidate is the content of the first ``<answer>...</answer>`` pair,
    or the whole response when the tags are absent. Components (the sum is
    capped at 1.0):

    * +0.2 when the candidate mentions a classification keyword;
    * integer reference: +0.8 when a whole number in the candidate equals
      it, and +0.4 when the closest number in the candidate is within 1;
    * other non-empty reference: +0.8 when it occurs in the candidate
      (case-insensitive substring);
    * +0.2 when a ``<think>...</think>`` section holds more than 20
      characters of text.

    Args:
        prompts: Unused; present for the reward-function call signature.
        completions: Sequence whose items expose the response text at
            ``item[0]['content']``.
        answer: Either one reference per completion (list; the first entry is
            reused when the list is too short) or a single value applied to
            every completion.
        **kwargs: Ignored.

    Returns:
        One score in ``[0.0, 1.0]`` per completion, in input order.
    """
    classification_keywords = ('classification', 'class', 'category', 'type', 'function', 'family')
    rewards = []

    for i, completion in enumerate(completions):
        response = completion[0]['content']
        score = 0.0

        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        extracted_answer = (answer_match.group(1) if answer_match else response).strip()
        extracted_lower = extracted_answer.lower()

        if isinstance(answer, list) and len(answer) > i:
            correct_answer = str(answer[i]).strip()
        elif isinstance(answer, list) and len(answer) > 0:
            correct_answer = str(answer[0]).strip()
        else:
            correct_answer = str(answer).strip()

        if any(keyword in extracted_lower for keyword in classification_keywords):
            score += 0.2

        if correct_answer.isdecimal():
            # isdecimal() guarantees int() succeeds, so no exception handling
            # is needed; numbers are compared whole so '1' never matches '10'.
            target = int(correct_answer)
            predicted = [int(number) for number in re.findall(r'\d+', extracted_answer)]
            if target in predicted:
                score += 0.8
            if predicted and min(abs(value - target) for value in predicted) <= 1:
                score += 0.4
        elif correct_answer and correct_answer.lower() in extracted_lower:
            # The non-empty check stops an empty reference from matching
            # every candidate.
            score += 0.8

        think_content = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        if think_content and len(think_content.group(1).strip()) > 20:
            score += 0.2

        rewards.append(min(score, 1.0))

    return rewards
| |
|
def repetition_penalty_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that repeat themselves less.

    The analyzed text is the content of the first ``<answer>...</answer>``
    pair, or the whole response when the tags are absent. Two repetition
    rates are averaged:

    * word level: ``1 - unique_words / words`` over lower-cased,
      whitespace-separated tokens;
    * sentence level: ``1 - unique_sentences / sentences`` over non-empty
      ``.``-separated pieces (0 when there is at most one sentence).

    The reward is ``max(0, 1 - 2 * average)``; text without any word
    scores 0.0.

    Args:
        completions: Sequence whose items expose the response text at
            ``item[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        One reward per completion, in input order.
    """
    rewards = []

    for completion in completions:
        response = completion[0]["content"]

        tagged = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        text_to_analyze = (tagged.group(1) if tagged else response).strip()

        words = text_to_analyze.lower().split()
        if not words:
            rewards.append(0.0)
            continue

        repetition_rate = 1.0 - (len(set(words)) / len(words))

        stripped_parts = (part.strip() for part in text_to_analyze.split('.'))
        sentences = [part for part in stripped_parts if part]
        if len(sentences) > 1:
            sentence_repetition_rate = 1.0 - (len(set(sentences)) / len(sentences))
        else:
            sentence_repetition_rate = 0.0

        overall_repetition = (repetition_rate + sentence_repetition_rate) / 2
        rewards.append(max(0.0, 1.0 - overall_repetition * 2))

    return rewards
| |
|
def combined_reward_func(prompts, completions, answer,
                         format_weight=0.3, accuracy_weight=0.5, repetition_weight=0.2,
                         **kwargs) -> list[float]:
    """Weighted combination of the format, accuracy and repetition rewards.

    Args:
        prompts: Forwarded to ``accuracy_reward_func``.
        completions: Forwarded to all three component reward functions.
        answer: Forwarded to ``accuracy_reward_func``.
        format_weight: Weight of ``format_correct_reward_func``.
        accuracy_weight: Weight of ``accuracy_reward_func``.
        repetition_weight: Weight of ``repetition_penalty_reward_func``.
        **kwargs: Forwarded to the component reward functions.

    Returns:
        One combined score per completion. Weights that do not sum to 1 are
        rescaled so that they do (and the rescaled values are printed).

    Raises:
        ValueError: If the three weights sum to zero.
    """
    import math

    format_rewards = format_correct_reward_func(completions, **kwargs)
    accuracy_rewards = accuracy_reward_func(prompts, completions, answer, **kwargs)
    repetition_rewards = repetition_penalty_reward_func(completions, **kwargs)

    total_weight = format_weight + accuracy_weight + repetition_weight
    if total_weight == 0:
        raise ValueError("format_weight + accuracy_weight + repetition_weight must not be zero")

    # Tolerant comparison: sums such as 0.2 + 0.6 + 0.2 may differ from 1.0
    # by a rounding error, which must not trigger a rescale (and a print) on
    # every call.
    if not math.isclose(total_weight, 1.0):
        format_weight /= total_weight
        accuracy_weight /= total_weight
        repetition_weight /= total_weight
        print(f"Normalized weights - Format: {format_weight:.3f}, Accuracy: {accuracy_weight:.3f}, Repetition: {repetition_weight:.3f}")

    return [
        format_weight * f_reward + accuracy_weight * a_reward + repetition_weight * r_reward
        for f_reward, a_reward, r_reward in zip(format_rewards, accuracy_rewards, repetition_rewards)
    ]
| |
|
| | |
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions whose extracted answer is short.

    The answer is obtained with ``extract_xml_answer`` and split on single
    spaces; four pieces or fewer earns 0.5, anything longer earns 0.0.
    """
    rewards = []
    for completion in completions:
        extracted = extract_xml_answer(completion[0]['content'])
        piece_count = len(extracted.split(' '))
        rewards.append(0.5 if piece_count <= 4 else 0.0)
    return rewards
| |
|
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the strict line-based think format.

    A completion scores 0.5 when it starts with ``<think>`` on its own line,
    closes the section with ``</think>`` on its own line, and is followed by
    text that ends in a newline; otherwise it scores 0.0.

    The pattern is matched with ``re.DOTALL`` so the reasoning and the text
    after it may span several lines; without the flag ``.`` cannot cross a
    newline and multi-line reasoning could never match.

    Args:
        completions: Sequence whose items expose the response text at
            ``item[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    pattern = re.compile(r"^<think>\n.*?\n</think>\n.*?\n$", re.DOTALL)
    return [
        0.5 if pattern.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]
| |
|
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return the ``count_xml`` score of every completion's response text."""
    scores = []
    for completion in completions:
        scores.append(count_xml(completion[0]["content"]))
    return scores
| |
|
def count_xml(text) -> float:
    """Score the presence of well-placed think tags in ``text``.

    Adds 0.125 when ``"<think>\\n"`` occurs exactly once and another 0.125
    when ``"\\n</think>\\n"`` occurs exactly once.
    """
    markers = ("<think>\n", "\n</think>\n")
    # Start from 0.0 so the result is a float even when nothing matches.
    return sum((0.125 for marker in markers if text.count(marker) == 1), 0.0)
| |
|
@dataclass
class Blip2ModelConfig(ModelConfig):
    """Model-side options for the BLIP2 GRPO script.

    Extends ``ModelConfig`` with the BLIP2 component settings, the LoRA
    hyper-parameters and a few training switches. Most fields are copied
    into the BLIP2 argument dict by ``create_blip2_args_from_config``.
    """

    # Checkpoint identifier used for weight initialization.
    model_name_or_path: str = field(default="blip2-model", metadata={"help": "Model checkpoint for weights initialization."})

    # BLIP2 components: Q-former backbone, protein language model and LLM,
    # each with its tuning strategy.
    bert_name: str = field(default="/path/to/bert", metadata={"help": "BERT model for Q-former"})
    num_query_token: int = field(default=8, metadata={"help": "Number of query tokens"})
    cross_attention_freq: int = field(default=2, metadata={"help": "Cross attention frequency"})
    plm_model: str = field(default="facebook/esm2_t30_150M_UR50D", metadata={"help": "Protein language model"})
    plm_tune: str = field(default="freeze", metadata={"help": "PLM tuning strategy"})
    llm_name: str = field(default="facebook/galactica-1.3b", metadata={"help": "Language model name"})
    llm_tune: str = field(default="lora", metadata={"help": "LLM tuning strategy"})
    qformer_tune: str = field(default="train", metadata={"help": "Q-former tuning strategy"})
    peft_dir: str = field(default="", metadata={"help": "PEFT directory"})

    # LoRA hyper-parameters.
    lora_r: int = field(default=8, metadata={"help": "LoRA rank"})
    lora_alpha: int = field(default=16, metadata={"help": "LoRA alpha"})
    lora_dropout: float = field(default=0.1, metadata={"help": "LoRA dropout"})

    # Training switches. NOTE: 'enbale' is a misspelling of 'enable', but the
    # field is read under this exact name by create_blip2_args_from_config,
    # so it must not be renamed in isolation.
    enbale_gradient_checkpointing: bool = field(default=False, metadata={"help": "Enable gradient checkpointing"})
    enable_flash: bool = field(default=False, metadata={"help": "Enable flash attention"})

    # Paths and freezing options. The two path fields default to None
    # despite the plain ``str`` annotation.
    cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
    sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    freeze_dna_modules: bool = field(default=False, metadata={"help": "Freeze DNA/protein modules"})
| |
|
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script with BLIP2.

    Covers data selection, the choice and weighting of reward functions, and
    prompt construction for the protein data loaders.
    """
    # Data selection. NOTE: dataset_name, arrow_cache_dir and val_split_ratio
    # are declared here but not read by the code visible in this file.
    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    data_file_paths: str = field(
        default=None,
        metadata={"help": "Path to protein classification CSV file (format: name,aa_seq,label,location,unique_id,pdb_hash)"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.1,
        metadata={"help": "Ratio of validation split, default 0.1"},
    )
    # Names are resolved through reward_funcs_registry in main(); "combined"
    # is wrapped there so that the weights below are applied.
    reward_funcs: list[str] = field(
        default_factory=lambda: ["combined"],
        metadata={"help": "List of reward functions. Available: 'combined', 'format_correct', 'accuracy', 'classification_specific', 'repetition_penalty', 'xmlcount', 'strict_format', 'less_than_4'"},
    )

    # Weights of the three components of the "combined" reward.
    format_weight: float = field(
        default=0.3,
        metadata={"help": "Weight for format correctness reward (used in combined reward)"}
    )
    accuracy_weight: float = field(
        default=0.5,
        metadata={"help": "Weight for accuracy reward (used in combined reward)"}
    )
    repetition_weight: float = field(
        default=0.2,
        metadata={"help": "Weight for repetition penalty reward (used in combined reward)"}
    )

    # Prompt construction. NOTE: max_seq_length is declared here but not read
    # by the code visible in this file.
    template_name: str = field(
        default="classification",
        metadata={"help": "Prompt template to use: 'classification', 'function_prediction', 'location_prediction'"}
    )
    max_seq_length: int = field(
        default=1000,
        metadata={"help": "Maximum protein sequence length for display in prompt"}
    )
    use_custom_prompts: bool = field(
        default=True,
        metadata={"help": "Whether to use custom protein-specific prompts"}
    )
| |
|
# Maps the names accepted by the ``reward_funcs`` script argument to the
# reward callables defined in this file.
reward_funcs_registry = {
    # Weighted mix of format, accuracy and repetition rewards. main() does
    # not look this entry up; it wraps combined_reward_func itself.
    "combined": combined_reward_func,

    # Individual components, usable on their own.
    "format_correct": format_correct_reward_func,
    "accuracy": accuracy_reward_func,
    "repetition_penalty": repetition_penalty_reward_func,
    "classification_specific": classification_specific_reward_func,

    # Tag-count, strict-format and answer-length rewards.
    "xmlcount": xmlcount_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
}
| |
|
def get_vlm_module(model_name_or_path):
    """Return the module class used to drive the model.

    ``model_name_or_path`` is currently ignored: ``Blip2DNAModule`` is
    returned for every input.
    """
    return Blip2DNAModule
| |
|
def create_blip2_args_from_config(model_args):
    """Collect the BLIP2 constructor settings from the model config.

    Args:
        model_args: Object exposing every attribute listed below.

    Returns:
        A dict mapping each setting name to the value of the attribute of
        the same name on ``model_args``, in the order listed.
    """
    forwarded_fields = (
        'bert_name',
        'num_query_token',
        'cross_attention_freq',
        'plm_model',
        'plm_tune',
        'llm_name',
        'llm_tune',
        'qformer_tune',
        'peft_dir',
        'lora_r',
        'lora_alpha',
        'lora_dropout',
        # Spelled this way on the config class as well.
        'enbale_gradient_checkpointing',
        'enable_flash',
    )
    return {name: getattr(model_args, name) for name in forwarded_fields}
| |
|
def _prep_for_training(model, training_args):
    """Build the LoRA configuration used to fine-tune the BLIP2 model.

    ``model`` is accepted but not used; only a configuration object is
    created, nothing is applied to the model here.

    Args:
        model: Unused.
        training_args: Object exposing ``lora_r``, ``lora_alpha`` and
            ``lora_dropout``.

    Returns:
        A ``LoraConfig`` for causal language modelling that targets the
        attention projections and the feed-forward layers named below.
    """
    return LoraConfig(
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        init_lora_weights="gaussian",
        bias="none",
    )
| |
|
def main(script_args, training_args, model_args):
    """Build the model, rewards and datasets, then run GRPO training.

    Args:
        script_args: Parsed ``GRPOScriptArguments`` (data source, reward
            function names and weights, prompt template).
        training_args: Parsed trainer configuration; ``save_safetensors`` is
            forced to ``False`` because checkpoints are written by
            ``SaveWithPyTorchCallback``.
        model_args: Parsed ``Blip2ModelConfig``.

    Raises:
        KeyError: If a name in ``script_args.reward_funcs`` is not present in
            ``reward_funcs_registry``.
    """
    print(training_args.output_dir)
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    # Disable safetensors before the trainer is constructed.
    training_args.save_safetensors = False

    # Build the model exactly once: from the SFT checkpoint when one is
    # given, otherwise from scratch.
    blip2_args = create_blip2_args_from_config(model_args)
    if model_args.sft_checkpoint is not None:
        print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
        model = Blip2Stage2.load_from_checkpoint(model_args.sft_checkpoint, strict=False, args=blip2_args, map_location='cpu')
    else:
        model = Blip2Stage2(blip2_args)

    # Resolve reward functions; "combined" is wrapped so that it uses the
    # weights supplied on the command line.
    reward_funcs = []
    for func_name in script_args.reward_funcs:
        if func_name == "combined":
            def weighted_combined_reward(prompts, completions, answer, **kwargs):
                return combined_reward_func(
                    prompts, completions, answer,
                    format_weight=script_args.format_weight,
                    accuracy_weight=script_args.accuracy_weight,
                    repetition_weight=script_args.repetition_weight,
                    **kwargs
                )
            reward_funcs.append(weighted_combined_reward)
        else:
            reward_funcs.append(reward_funcs_registry[func_name])

    print("reward_funcs:", [func.__name__ if hasattr(func, '__name__') else 'weighted_combined_reward' for func in reward_funcs])
    print(f"Reward weights - Format: {script_args.format_weight}, Accuracy: {script_args.accuracy_weight}, Repetition: {script_args.repetition_weight}")

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    # NOTE(review): the template is fetched but not used below - confirm
    # whether it should be passed to the data loaders.
    question_prompt = vlm_module_cls.get_question_template()

    # Pick the data source.
    if script_args.data_file_paths and script_args.use_custom_prompts:
        print(f"Loading custom protein data from: {script_args.data_file_paths}")
        dataset = get_custom_protein_data_with_prompts(
            data_path=script_args.data_file_paths,
            prompt_templates=prompt_templates,
            template_name=script_args.template_name
        )
    elif script_args.data_file_paths:
        print(f"Loading protein data from: {script_args.data_file_paths}")
        dataset = get_protein_classification_data(
            data_path=script_args.data_file_paths
        )
    else:
        print("Using default KEGG dataset")
        dataset = get_kegg_questions()

    # The loaders do not agree on the held-out split's name
    # (train_test_split produces 'test'), so accept the common spellings.
    eval_split_name = next(
        (name for name in ("val", "validation", "test") if name in dataset), None
    )
    eval_split = dataset[eval_split_name] if eval_split_name is not None else None

    print("Dataset loaded:")
    print(f"Train size: {len(dataset['train'])}")
    print(f"Val size: {len(eval_split) if eval_split is not None else 0}")

    if len(dataset['train']) > 0:
        print("\nSample data:")
        sample = dataset['train'][0]
        print(f"Prompt type: {type(sample.get('prompt', 'Unknown'))}")
        print(f"DNA sequences count: {len(sample.get('dna_sequences', []))}")
        print(f"Answer: {sample.get('answer', 'N/A')}")
        if 'metadata' in sample:
            print(f"Metadata: {sample['metadata']}")
        # Fall back to an empty string so an empty sequence list cannot
        # raise IndexError.
        first_sequence = (sample.get('dna_sequences') or [''])[0]
        print(f"First 100 chars of sequence: {first_sequence[:100]}...")

    custom_save_callback = SaveWithPyTorchCallback()

    trainer = Blip2GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        eval_dataset=eval_split if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=getattr(model_args, 'attn_implementation', 'flash_attention_2'),
        torch_dtype=getattr(model_args, 'torch_dtype', 'bfloat16'),
        callbacks=[custom_save_callback],
    )

    trainer.train()
| |
|
if __name__ == "__main__":
    # Show which GPUs this process is allowed to use.
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    # One parser for all three argument groups: script, training and model.
    parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, Blip2ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Checkpoints are written with torch.save by SaveWithPyTorchCallback, so
    # safetensors serialization is switched off.
    training_args.save_safetensors = False

    main(script_args, training_args, model_args)
| |
|
| | |
| | """ |
| | 使用你的蛋白质数据进行训练: |
| | |
| | 1. 准备CSV文件,格式:name,aa_seq,label,location,unique_id,pdb_hash |
| | |
| | 2. 运行训练: |
| | python blip2_reason.py \ |
| | --data_file_paths /path/to/your/protein_data.csv \ |
| | --reward_funcs combined \ |
| | --format_weight 0.2 \ |
| | --accuracy_weight 0.6 \ |
| | --repetition_weight 0.2 \ |
| | --use_custom_prompts \ |
| | --template_name classification \ |
| | --max_seq_length 1000 \ |
| | --output_dir ./output \ |
| | --per_device_train_batch_size 4 \ |
| | --num_train_epochs 3 \ |
| | --learning_rate 1e-5 |
| | |
| | 3. 或者使用分离的奖励函数: |
| | python blip2_reason.py \ |
| | --data_file_paths /path/to/your/protein_data.csv \ |
| | --reward_funcs format_correct classification_specific repetition_penalty \ |
| | --use_custom_prompts \ |
| | --template_name function_prediction |
| | |
| | 数据格式示例: |
| | P0DM40,MLRVVVESASINPPLSTTPKAFVTVYFRDMMKRTRVEEGHDPIWNETLIWHLWNQPLENDSFLKVILQDSVSKKKERFIGLATVPLKRLAQRPKEVMFVRDLILLNHSMKPTNCTVTLHVAQIYDQDTEMTGNEELLGSTVNEVTQKKLMVSGLPMHRALASKPQHFQVRVKVFEARQLLGNNIKPVVKVNIADQQHLTRIKMGNNPFFNEIFFQNFHEVPAKFFEENISIEVVDSAASRSKAEIGRFQTDIGFIYHSPGHTLLRKWLGLCQRNKTTSGVRGYLKVTICALGVGDQALVDQKLPYEQNTRVQIFKSKEVPVSLAYLQFFIYCAEDLHFGTHKSATPVLEVELIGDKLRTKPQNPSDNPIWNQILTFQIQLPCLSSYIKFRVMDCSKYKCQDEIGSASLCLSQISSTGEEIQGMYSGFLPCFGPSFLTLRGGKKPPFRTSEEGTCIMDAVQHGLAYRGRIFVEIVTKIKSQQDSVMKDLSQEVTQVEMQYYRQKYGLCVIFLSCTMMPKFKDLIQFEVSMGHYGNKTDPNYKPLVSTTQYSPVIYDGTTYHYVPWYNTKPVVAVTSNWEDVGFRMNCLNLLHITRDRLKTNLDILKSIRNPRDPALLQQWEKLLKELQEDCRRPLPCMTDQPRANSLDRNKWQLRSQLLQQLAQMAKEAKPVNMVGTAKEWLHRLNAVIPEPQESLPDVLIWLMSRQQRVAYARVPAHTVLFSPAGPLSSGKFCGKIQNILLQYPEGEGQDTFPASLRVCMWLGNVKYSKNLKLLQQGSMVVYAETYENQAKTRDDWGQQGLYHCPNFSDVMGRKALPKTDFKAPPGWHWKDDWVVEPQRRLLLDIDINKSQVLEEVYENQLRNATGAWVPAAIPNTDVNGQPVEALENVKCPQGWHFKKNWIVKLNHAVDSEGWEYGVGIPPSGLPQIWNSVEKTYHSCRRRRWVRVRFRNHKELGQERSQEQETLSFLQMQDLSEEGKEGWEYGTFDSRFHLDPQPTSRFRRRCWHRQLAPNKDRGVASIFLLEGSLAVEQKDQPRKEMEKTRSWQPWKDLRHTPEDPRIPTTPFIYYILNKPHYYQLFCYIYQARNLMYNQILTFQEPFIQVVFLNHSLCTQTLRSSAAPTWSQSIIFQHLLLFEDPKDTRENPPLVVLELWQHDSRGNKILWGRSMWPPVVWLGLQDWVFTPLRWHPLVRELGEEEGEILASCELILETQKLKELHPPILSIPCKDGIYLLPKNIQPTMKMMAIEIMAWGLRNMTKVRYPQLLLECGGESLKTEPISNFQENPNFPTSTFFFTVFMPLEETHAQPLVVKVVDNQEYGQQIVVGQANIDFLQPYFCDPWSLNYTTVKLPTLSVKKPDTFLDFVYKKFWFDSSKDEEVYEEEVDWWSKLFWATGDADKSLNYNHKSYHTLKVYDCELEAVLTFKGLQDFCQTFKLYQEKPKVDSPVVGEFKGLFRIYPFPEDPEAPKPPRQFSAWPEIEDFPQMCLVRVYLIRAINLQPQDYNGLCDPYVILKLGQTKLGSRDSYYPNTLDPIFGMMYELTCNIPLEKDLEIQLFDFDLITADDEIGSTVIDLENRLLSGFGARCGLSKSYCKSGPFKWRDQMTPSYLLYRYAKQKGLPPPVFDLEGDSLYYNGETFKLQSFESAPPTYKHLGPKKERLALYILNTQGLVPEHVETRTLHSNSQPGIDQGKIQMWVDIFPKMLGPPGPQVNISPRKPKRYQLRCIIWSTAEVDLVQETFSKEKMSDIYVKGWLFGLEEDTQKTDVHYHSLTGEATFNWRFIFTMDYLTTERACVQSQKDYIWSLDPTSTKFPARLMIQIWDNDFFSPDDFLGVLELDLSDMPLPAQNIKQCSLKMMETDSKWPFTPQKRISLFKKTNVTGWWPCQVLDGDKWRLSGKVKMTLEMLSEREALIRPAGRGQSEPNQFPMLHPPERNDSFLLWYQSPIKNFCYAVCKRYRSKIICLVVTLVIGFILLNFVYSAPSYFAMNW
IKPQLRLSSPIKIVNLIGTVNTSNINSSILTMEGSTYHASHVFPEAPAP,0,M,af67d99c09f74ea8af5004cc2906bbc5,d55cbc3d94bd9668d97a678b4a04176a |
| | """ |