parser.add_argument("--freeze_protein_model", action="store_true", default=True, # help="Freeze protein model during training") # parser.add_argument("--freeze_text_model", action="store_true", default=True, # help="Freeze text model during training") # parser.add_argument("--protein_weight", type=float, default=1.0, # help="Weight for protein features in contrastive loss") # parser.add_argument("--text_weight", type=float, default=1.0, # help="Weight for text features in contrastive loss") # # Data configuration # parser.add_argument("--max_length_protein", type=int, default=1024, # help="Maximum length for protein sequences") # parser.add_argument("--max_length_text", type=int, default=512, # help="Maximum length for text sequences") # parser.add_argument("--num_workers", type=int, default=4, # help="Number of data loading workers") # # Logging and evaluation # parser.add_argument("--logging_steps", type=int, default=100, # help="Number of steps between logging") # parser.add_argument("--eval_steps", type=int, default=500, # help="Number of steps between evaluations") # parser.add_argument("--save_steps", type=int, default=1000, # help="Number of steps between saving checkpoints") # parser.add_argument("--save_total_limit", type=int, default=3, # help="Maximum number of checkpoints to keep") # # Hardware configuration # parser.add_argument("--fp16", action="store_true", # help="Use FP16 precision") # parser.add_argument("--bf16", action="store_true", # help="Use BF16 precision") # parser.add_argument("--seed", type=int, default=42, # help="Random seed") # # Wandb logging # parser.add_argument("--use_wandb", action="store_true", # help="Use Weights & Biases for logging") # parser.add_argument("--wandb_project", type=str, default="protein-llm-contrastive", # help="Wandb project name") # parser.add_argument("--wandb_entity", type=str, default=None, # help="Wandb entity name") # args = parser.parse_args() # # Create output directory # os.makedirs(args.output_dir, exist_ok=True) # # Run contrastive training # trainer = main(args) import os import time from argparse import ArgumentParser from functools import partial import torch import wandb from datasets import load_dataset from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.strategies import DeepSpeedStrategy from bioreason.models.protein_llm import ProteinLLMModel from bioreason.trainer.contrast_trainer_new import ( ContrastiveTrainer, ContrastiveTrainingArguments, protein_text_collate_fn, ) from bioreason.dataset.protein import format_protein_contrastive def main(args): """ Main function for enhanced contrastive pre-training of Protein-LLM. This script trains the QFormer projection layer to align protein and text representations using enhanced contrastive learning with optional protein-text matching before fine-tuning. 
""" # Set random seed pl.seed_everything(args.seed) torch.cuda.empty_cache() torch.set_float32_matmul_precision("medium") # Initialize wandb if args.use_wandb: wandb.init( project=args.wandb_project, entity=args.wandb_entity, name=f"enhanced-contrastive-{args.text_model_name.split('/')[-1]}-{time.strftime('%Y%m%d-%H%M%S')}", config=vars(args) ) print("Loading model...") # Load the Protein-LLM model model = ProteinLLMModel( text_model_name=args.text_model_name, protein_model_name=args.protein_model_name, qformer_model_name=args.qformer_model_name, cache_dir=args.cache_dir, max_length_protein=args.max_length_protein, max_length_text=args.max_length_text, text_model_finetune=False, # Don't fine-tune during contrastive learning protein_model_finetune=False, # Don't fine-tune during contrastive learning num_query_tokens=args.num_query_tokens, ) print("Loading datasets...") # Load datasets for contrastive learning train_dataset = load_dataset("json", data_files=args.train_dataset, split="train") eval_dataset = load_dataset("json", data_files=args.valid_dataset, split="train") if args.eval_dataset else None # Format datasets for contrastive learning train_dataset = train_dataset.map(format_protein_contrastive) if eval_dataset: eval_dataset = eval_dataset.map(format_protein_contrastive) # Filter out examples without protein sequences or descriptions train_dataset = train_dataset.filter( lambda x: x["protein"] and x["text"] and len(x["protein"].strip()) > 0 and len(x["text"].strip()) > 0 ) if eval_dataset: eval_dataset = eval_dataset.filter( lambda x: x["protein"] and x["text"] and len(x["protein"].strip()) > 0 and len(x["text"].strip()) > 0 ) print(f"Training dataset size: {len(train_dataset)}") if eval_dataset: print(f"Eval dataset size: {len(eval_dataset)}") # Setup enhanced training arguments for contrastive learning training_args = ContrastiveTrainingArguments( output_dir=args.output_dir, num_train_epochs=args.num_epochs, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, learning_rate=args.learning_rate, weight_decay=args.weight_decay, temperature=args.temperature, freeze_protein_model=args.freeze_protein_model, freeze_text_model=args.freeze_text_model, protein_weight=args.protein_weight, text_weight=args.text_weight, enable_ptm=args.enable_ptm, ptm_weight=args.ptm_weight, max_length_protein=args.max_length_protein, max_length_text=args.max_length_text, logging_steps=args.logging_steps, eval_strategy="steps" if eval_dataset else "no", eval_steps=args.eval_steps if eval_dataset else None, save_steps=args.save_steps, save_total_limit=args.save_total_limit, load_best_model_at_end=True if eval_dataset else False, metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None, greater_is_better=True, report_to=["wandb"] if args.use_wandb else [], warmup_steps=args.warmup_steps, gradient_accumulation_steps=args.gradient_accumulation_steps, fp16=args.fp16, bf16=args.bf16, dataloader_num_workers=args.num_workers, remove_unused_columns=False, seed=args.seed, # Distributed training settings ddp_find_unused_parameters=False, dataloader_pin_memory=True, ) print("Initializing enhanced trainer...") # Initialize the enhanced contrastive trainer trainer = ContrastiveTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=protein_text_collate_fn, ) print("Starting enhanced contrastive training...") print(f"- Contrastive learning enabled") print(f"- Protein-text matching: {'enabled' if args.enable_ptm else 
'disabled'}") print(f"- Temperature: {args.temperature}") print(f"- PTM weight: {args.ptm_weight}") # Train the model trainer.train() print("Saving final model...") # Save the final model trainer.save_model() # Save projection layer weights and PTM head if enabled projection_path = os.path.join(args.output_dir, "protein_projection.pt") torch.save(model.protein_projection.state_dict(), projection_path) print(f"Saved protein projection weights to: {projection_path}") if args.enable_ptm and hasattr(trainer.contrastive_loss, 'ptm_head'): ptm_head_path = os.path.join(args.output_dir, "ptm_head.pt") torch.save(trainer.contrastive_loss.ptm_head.state_dict(), ptm_head_path) print(f"Saved PTM head weights to: {ptm_head_path}") # Final evaluation if eval_dataset: print("Running final evaluation...") eval_results = trainer.evaluate() print(f"Final evaluation results: {eval_results}") if args.use_wandb: wandb.log({"final_eval": eval_results}) print("Enhanced contrastive training completed!") if args.use_wandb: wandb.finish() return trainer if __name__ == "__main__": parser = ArgumentParser(description="Enhanced contrastive pre-training for Protein-LLM") # Model configuration parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B", help="Name or path to the text model") parser.add_argument("--protein_model_name", type=str, default="facebook/esm2_t6_8M_UR50D", help="Name or path to the protein model") parser.add_argument("--qformer_model_name", type=str, default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext", help="Name or path to the QFormer model") parser.add_argument("--cache_dir", type=str, default="/model-weights", help="Directory to cache downloaded models") parser.add_argument("--num_query_tokens", type=int, default=32, help="Number of query tokens in QFormer") # Dataset configuration parser.add_argument("--train_dataset", type=str, default="wanglab/protein_descriptions", help="Name of the dataset for contrastive learning") parser.add_argument("--valid_dataset", type=str, default="wanglab/protein_descriptions", help="Name of the dataset for contrastive learning") parser.add_argument("--eval_dataset", action="store_true", help="Whether to use evaluation dataset") # Training configuration parser.add_argument("--output_dir", type=str, default="./enhanced_contrastive_outputs", help="Output directory for model and logs") parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs") parser.add_argument("--batch_size", type=int, default=32, help="Batch size per device") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay") parser.add_argument("--warmup_steps", type=int, default=1000, help="Number of warmup steps") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps") # Enhanced contrastive learning specific parser.add_argument("--temperature", type=float, default=0.07, help="Temperature for contrastive loss") parser.add_argument("--freeze_protein_model", action="store_true", default=True, help="Freeze protein model during training") parser.add_argument("--freeze_text_model", action="store_true", default=True, help="Freeze text model during training") parser.add_argument("--protein_weight", type=float, default=1.0, help="Weight for protein features in contrastive loss") parser.add_argument("--text_weight", type=float, default=1.0, help="Weight for text features in 
contrastive loss") # Protein-Text Matching (PTM) configuration parser.add_argument("--enable_ptm", action="store_true", default=True, help="Enable protein-text matching task") parser.add_argument("--ptm_weight", type=float, default=1.0, help="Weight for protein-text matching loss") # Data configuration parser.add_argument("--max_length_protein", type=int, default=1024, help="Maximum length for protein sequences") parser.add_argument("--max_length_text", type=int, default=512, help="Maximum length for text sequences") parser.add_argument("--num_workers", type=int, default=4, help="Number of data loading workers") # Logging and evaluation parser.add_argument("--logging_steps", type=int, default=100, help="Number of steps between logging") parser.add_argument("--eval_steps", type=int, default=500, help="Number of steps between evaluations") parser.add_argument("--save_steps", type=int, default=1000, help="Number of steps between saving checkpoints") parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep") # Hardware configuration parser.add_argument("--fp16", action="store_true", help="Use FP16 precision") parser.add_argument("--bf16", action="store_true", help="Use BF16 precision") parser.add_argument("--seed", type=int, default=42, help="Random seed") # Wandb logging parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases for logging") parser.add_argument("--wandb_project", type=str, default="protein-llm-enhanced-contrastive", help="Wandb project name") parser.add_argument("--wandb_entity", type=str, default=None, help="Wandb entity name") args = parser.parse_args() # Validate arguments if args.enable_ptm and not hasattr(args, 'ptm_weight'): args.ptm_weight = 1.0 # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Print configuration print("=" * 50) print("Enhanced Contrastive Training Configuration:") print("=" * 50) print(f"Text model: {args.text_model_name}") print(f"Protein model: {args.protein_model_name}") print(f"QFormer model: {args.qformer_model_name}") print(f"Dataset: {args.train_dataset}") print(f"Output directory: {args.output_dir}") print(f"Batch size: {args.batch_size}") print(f"Learning rate: {args.learning_rate}") print(f"Temperature: {args.temperature}") print(f"Enable PTM: {args.enable_ptm}") print(f"PTM weight: {args.ptm_weight}") print(f"Number of epochs: {args.num_epochs}") print(f"Query tokens: {args.num_query_tokens}") print("=" * 50) # Run enhanced contrastive training trainer = main(args)