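"""
Contrastive pre-training script for the Protein-LLM model.

Trains the QFormer projection layer to align protein and text representations
with a contrastive objective, plus an optional protein-text matching (PTM)
head, before fine-tuning on downstream tasks.
"""
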
import os
import time
from argparse import ArgumentParser, BooleanOptionalAction

import torch
import wandb
from datasets import load_dataset
import pytorch_lightning as pl

from bioreason.models.protein_llm import ProteinLLMModel
from bioreason.trainer.contrast_trainer_new import (
    ContrastiveTrainer,
    ContrastiveTrainingArguments,
    protein_text_collate_fn,
)
from bioreason.dataset.protein import format_protein_contrastive


def main(args):
"""
Main function for enhanced contrastive pre-training of Protein-LLM.
This script trains the QFormer projection layer to align protein and text representations
using enhanced contrastive learning with optional protein-text matching before fine-tuning.
"""
# Set random seed
pl.seed_everything(args.seed)
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")
# Initialize wandb
if args.use_wandb:
wandb.init(
project=args.wandb_project,
entity=args.wandb_entity,
name=f"enhanced-contrastive-{args.text_model_name.split('/')[-1]}-{time.strftime('%Y%m%d-%H%M%S')}",
config=vars(args)
)
print("Loading model...")
# Load the Protein-LLM model
model = ProteinLLMModel(
text_model_name=args.text_model_name,
protein_model_name=args.protein_model_name,
qformer_model_name=args.qformer_model_name,
cache_dir=args.cache_dir,
max_length_protein=args.max_length_protein,
max_length_text=args.max_length_text,
text_model_finetune=False, # Don't fine-tune during contrastive learning
protein_model_finetune=False, # Don't fine-tune during contrastive learning
num_query_tokens=args.num_query_tokens,
)
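    # Both encoders are loaded with fine-tuning disabled; during this stage only
    # the QFormer projection (and, if enabled, the PTM head) is meant to be trained.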
print("Loading datasets...")
# Load datasets for contrastive learning
train_dataset = load_dataset("json", data_files=args.train_dataset, split="train")
eval_dataset = load_dataset("json", data_files=args.valid_dataset, split="train") if args.eval_dataset else None
# Format datasets for contrastive learning
train_dataset = train_dataset.map(format_protein_contrastive)
if eval_dataset:
eval_dataset = eval_dataset.map(format_protein_contrastive)
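    # The mapped examples are expected to expose "protein" and "text" fields,
    # which the filtering below relies on.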
# Filter out examples without protein sequences or descriptions
train_dataset = train_dataset.filter(
lambda x: x["protein"] and x["text"]
and len(x["protein"].strip()) > 0
and len(x["text"].strip()) > 0
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
lambda x: x["protein"] and x["text"]
and len(x["protein"].strip()) > 0
and len(x["text"].strip()) > 0
)
print(f"Training dataset size: {len(train_dataset)}")
if eval_dataset:
print(f"Eval dataset size: {len(eval_dataset)}")
# Setup enhanced training arguments for contrastive learning
training_args = ContrastiveTrainingArguments(
output_dir=args.output_dir,
num_train_epochs=args.num_epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
temperature=args.temperature,
freeze_protein_model=args.freeze_protein_model,
freeze_text_model=args.freeze_text_model,
protein_weight=args.protein_weight,
text_weight=args.text_weight,
enable_ptm=args.enable_ptm,
ptm_weight=args.ptm_weight,
max_length_protein=args.max_length_protein,
max_length_text=args.max_length_text,
logging_steps=args.logging_steps,
eval_strategy="steps" if eval_dataset else "no",
eval_steps=args.eval_steps if eval_dataset else None,
save_steps=args.save_steps,
save_total_limit=args.save_total_limit,
load_best_model_at_end=True if eval_dataset else False,
metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None,
greater_is_better=True,
report_to=["wandb"] if args.use_wandb else [],
warmup_steps=args.warmup_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
fp16=args.fp16,
bf16=args.bf16,
dataloader_num_workers=args.num_workers,
remove_unused_columns=False,
seed=args.seed,
# Distributed training settings
ddp_find_unused_parameters=False,
dataloader_pin_memory=True,
)
print("Initializing enhanced trainer...")
# Initialize the enhanced contrastive trainer
trainer = ContrastiveTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=protein_text_collate_fn,
)
print("Starting enhanced contrastive training...")
print(f"- Contrastive learning enabled")
print(f"- Protein-text matching: {'enabled' if args.enable_ptm else 'disabled'}")
print(f"- Temperature: {args.temperature}")
print(f"- PTM weight: {args.ptm_weight}")
# Train the model
trainer.train()
print("Saving final model...")
# Save the final model
trainer.save_model()
# Save projection layer weights and PTM head if enabled
projection_path = os.path.join(args.output_dir, "protein_projection.pt")
torch.save(model.protein_projection.state_dict(), projection_path)
print(f"Saved protein projection weights to: {projection_path}")
if args.enable_ptm and hasattr(trainer.contrastive_loss, 'ptm_head'):
ptm_head_path = os.path.join(args.output_dir, "ptm_head.pt")
torch.save(trainer.contrastive_loss.ptm_head.state_dict(), ptm_head_path)
print(f"Saved PTM head weights to: {ptm_head_path}")
# Final evaluation
if eval_dataset:
print("Running final evaluation...")
eval_results = trainer.evaluate()
print(f"Final evaluation results: {eval_results}")
if args.use_wandb:
wandb.log({"final_eval": eval_results})
print("Enhanced contrastive training completed!")
if args.use_wandb:
wandb.finish()
    return trainer


if __name__ == "__main__":
parser = ArgumentParser(description="Enhanced contrastive pre-training for Protein-LLM")
# Model configuration
parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B",
help="Name or path to the text model")
parser.add_argument("--protein_model_name", type=str, default="facebook/esm2_t6_8M_UR50D",
help="Name or path to the protein model")
parser.add_argument("--qformer_model_name", type=str,
default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
help="Name or path to the QFormer model")
parser.add_argument("--cache_dir", type=str, default="/model-weights",
help="Directory to cache downloaded models")
parser.add_argument("--num_query_tokens", type=int, default=32,
help="Number of query tokens in QFormer")
# Dataset configuration
parser.add_argument("--train_dataset", type=str, default="wanglab/protein_descriptions",
help="Name of the dataset for contrastive learning")
parser.add_argument("--valid_dataset", type=str, default="wanglab/protein_descriptions",
help="Name of the dataset for contrastive learning")
parser.add_argument("--eval_dataset", action="store_true",
help="Whether to use evaluation dataset")
# Training configuration
parser.add_argument("--output_dir", type=str, default="./enhanced_contrastive_outputs",
help="Output directory for model and logs")
parser.add_argument("--num_epochs", type=int, default=10,
help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size per device")
parser.add_argument("--learning_rate", type=float, default=1e-4,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.01,
help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=1000,
help="Number of warmup steps")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
help="Gradient accumulation steps")
# Enhanced contrastive learning specific
parser.add_argument("--temperature", type=float, default=0.07,
help="Temperature for contrastive loss")
parser.add_argument("--freeze_protein_model", action="store_true", default=True,
help="Freeze protein model during training")
parser.add_argument("--freeze_text_model", action="store_true", default=True,
help="Freeze text model during training")
parser.add_argument("--protein_weight", type=float, default=1.0,
help="Weight for protein features in contrastive loss")
parser.add_argument("--text_weight", type=float, default=1.0,
help="Weight for text features in contrastive loss")
# Protein-Text Matching (PTM) configuration
parser.add_argument("--enable_ptm", action="store_true", default=True,
help="Enable protein-text matching task")
parser.add_argument("--ptm_weight", type=float, default=1.0,
help="Weight for protein-text matching loss")
# Data configuration
parser.add_argument("--max_length_protein", type=int, default=1024,
help="Maximum length for protein sequences")
parser.add_argument("--max_length_text", type=int, default=512,
help="Maximum length for text sequences")
parser.add_argument("--num_workers", type=int, default=4,
help="Number of data loading workers")
# Logging and evaluation
parser.add_argument("--logging_steps", type=int, default=100,
help="Number of steps between logging")
parser.add_argument("--eval_steps", type=int, default=500,
help="Number of steps between evaluations")
parser.add_argument("--save_steps", type=int, default=1000,
help="Number of steps between saving checkpoints")
parser.add_argument("--save_total_limit", type=int, default=3,
help="Maximum number of checkpoints to keep")
# Hardware configuration
parser.add_argument("--fp16", action="store_true",
help="Use FP16 precision")
parser.add_argument("--bf16", action="store_true",
help="Use BF16 precision")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
# Wandb logging
parser.add_argument("--use_wandb", action="store_true",
help="Use Weights & Biases for logging")
parser.add_argument("--wandb_project", type=str, default="protein-llm-enhanced-contrastive",
help="Wandb project name")
parser.add_argument("--wandb_entity", type=str, default=None,
help="Wandb entity name")
args = parser.parse_args()
# Validate arguments
if args.enable_ptm and not hasattr(args, 'ptm_weight'):
args.ptm_weight = 1.0
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Print configuration
print("=" * 50)
print("Enhanced Contrastive Training Configuration:")
print("=" * 50)
print(f"Text model: {args.text_model_name}")
print(f"Protein model: {args.protein_model_name}")
print(f"QFormer model: {args.qformer_model_name}")
print(f"Dataset: {args.train_dataset}")
print(f"Output directory: {args.output_dir}")
print(f"Batch size: {args.batch_size}")
print(f"Learning rate: {args.learning_rate}")
print(f"Temperature: {args.temperature}")
print(f"Enable PTM: {args.enable_ptm}")
print(f"PTM weight: {args.ptm_weight}")
print(f"Number of epochs: {args.num_epochs}")
print(f"Query tokens: {args.num_query_tokens}")
print("=" * 50)
# Run enhanced contrastive training
trainer = main(args)
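    # Example launch (illustrative; the dataset paths below are placeholders):
    #   python train_contrastive.py \
    #       --train_dataset /path/to/train.json \
    #       --valid_dataset /path/to/valid.json \
    #       --eval_dataset --use_wandb --bf16 \
    #       --output_dir ./enhanced_contrastive_outputs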