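"""
Enhanced contrastive pre-training for Protein-LLM.

Trains the QFormer projection layer to align protein and text representations
with a temperature-scaled contrastive objective, plus an optional protein-text
matching (PTM) head, before downstream fine-tuning.

Example launch (single GPU; the script filename and data paths below are
illustrative, not fixed by this repo):

    python train_contrastive.py \
        --train_dataset data/train.json \
        --valid_dataset data/valid.json \
        --eval_dataset \
        --use_wandb
"""
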
import os
import time
from argparse import ArgumentParser, BooleanOptionalAction

import torch
import wandb
from datasets import load_dataset
import pytorch_lightning as pl

from bioreason.models.protein_llm import ProteinLLMModel
from bioreason.trainer.contrast_trainer_new import (
    ContrastiveTrainer,
    ContrastiveTrainingArguments,
    protein_text_collate_fn,
)
from bioreason.dataset.protein import format_protein_contrastive


def main(args):
    """
    Main function for enhanced contrastive pre-training of Protein-LLM.

    This script trains the QFormer projection layer to align protein and text
    representations using enhanced contrastive learning, with optional
    protein-text matching, before fine-tuning on downstream tasks.
    """
    # Set random seed and clear any cached GPU memory
    pl.seed_everything(args.seed)
    torch.cuda.empty_cache()
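    # "medium" lets matmuls use TF32 on Ampere+ GPUs: a speed/precision
    # trade-off that is generally safe for contrastive pre-training.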
    torch.set_float32_matmul_precision("medium")

    # Initialize wandb
    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=f"enhanced-contrastive-{args.text_model_name.split('/')[-1]}-{time.strftime('%Y%m%d-%H%M%S')}",
            config=vars(args),
        )

    print("Loading model...")

    # Load the Protein-LLM model
    model = ProteinLLMModel(
        text_model_name=args.text_model_name,
        protein_model_name=args.protein_model_name,
        qformer_model_name=args.qformer_model_name,
        cache_dir=args.cache_dir,
        max_length_protein=args.max_length_protein,
        max_length_text=args.max_length_text,
        text_model_finetune=False,  # Don't fine-tune during contrastive learning
        protein_model_finetune=False,  # Don't fine-tune during contrastive learning
        num_query_tokens=args.num_query_tokens,
    )

    print("Loading datasets...")

    # Load datasets for contrastive learning
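    # Each JSON record is expected to provide "protein" and "text" fields
    # (see the filter below); --eval_dataset toggles use of the validation set.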
| train_dataset = load_dataset("json", data_files=args.train_dataset, split="train") | |
| eval_dataset = load_dataset("json", data_files=args.valid_dataset, split="train") if args.eval_dataset else None | |
| # Format datasets for contrastive learning | |
| train_dataset = train_dataset.map(format_protein_contrastive) | |
| if eval_dataset: | |
| eval_dataset = eval_dataset.map(format_protein_contrastive) | |
| # Filter out examples without protein sequences or descriptions | |
| train_dataset = train_dataset.filter( | |
| lambda x: x["protein"] and x["text"] | |
| and len(x["protein"].strip()) > 0 | |
| and len(x["text"].strip()) > 0 | |
| ) | |
| if eval_dataset: | |
| eval_dataset = eval_dataset.filter( | |
| lambda x: x["protein"] and x["text"] | |
| and len(x["protein"].strip()) > 0 | |
| and len(x["text"].strip()) > 0 | |
| ) | |
| print(f"Training dataset size: {len(train_dataset)}") | |
| if eval_dataset: | |
| print(f"Eval dataset size: {len(eval_dataset)}") | |

    # Set up enhanced training arguments for contrastive learning
    training_args = ContrastiveTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        temperature=args.temperature,
        freeze_protein_model=args.freeze_protein_model,
        freeze_text_model=args.freeze_text_model,
        protein_weight=args.protein_weight,
        text_weight=args.text_weight,
        enable_ptm=args.enable_ptm,
        ptm_weight=args.ptm_weight,
        max_length_protein=args.max_length_protein,
        max_length_text=args.max_length_text,
        logging_steps=args.logging_steps,
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=args.eval_steps if eval_dataset else None,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=bool(eval_dataset),
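        # "eval_avg_recall_at_1" is assumed to be retrieval recall@1 averaged
        # over the protein->text and text->protein directions, as reported by
        # ContrastiveTrainer's evaluation loop.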
        metric_for_best_model="eval_avg_recall_at_1" if eval_dataset else None,
        greater_is_better=True,
        report_to=["wandb"] if args.use_wandb else [],
        warmup_steps=args.warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        fp16=args.fp16,
        bf16=args.bf16,
        dataloader_num_workers=args.num_workers,
        remove_unused_columns=False,
        seed=args.seed,
        # Distributed training settings
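        # Backbones are frozen, so all trainable (QFormer/projection) params
        # should receive gradients every step; skipping DDP's unused-parameter
        # search avoids its per-step overhead.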
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
    )
| print("Initializing enhanced trainer...") | |
| # Initialize the enhanced contrastive trainer | |
| trainer = ContrastiveTrainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, | |
| data_collator=protein_text_collate_fn, | |
| ) | |
| print("Starting enhanced contrastive training...") | |
| print(f"- Contrastive learning enabled") | |
| print(f"- Protein-text matching: {'enabled' if args.enable_ptm else 'disabled'}") | |
| print(f"- Temperature: {args.temperature}") | |
| print(f"- PTM weight: {args.ptm_weight}") | |

    # Train the model
    trainer.train()

    print("Saving final model...")

    # Save the final model
    trainer.save_model()

    # Save the projection layer weights separately for later reuse
    projection_path = os.path.join(args.output_dir, "protein_projection.pt")
    torch.save(model.protein_projection.state_dict(), projection_path)
    print(f"Saved protein projection weights to: {projection_path}")
    if args.enable_ptm and hasattr(trainer.contrastive_loss, "ptm_head"):
        ptm_head_path = os.path.join(args.output_dir, "ptm_head.pt")
        torch.save(trainer.contrastive_loss.ptm_head.state_dict(), ptm_head_path)
        print(f"Saved PTM head weights to: {ptm_head_path}")

    # Final evaluation
    if eval_dataset:
        print("Running final evaluation...")
        eval_results = trainer.evaluate()
        print(f"Final evaluation results: {eval_results}")
        if args.use_wandb:
            wandb.log({"final_eval": eval_results})

    print("Enhanced contrastive training completed!")
    if args.use_wandb:
        wandb.finish()

    return trainer
| if __name__ == "__main__": | |
| parser = ArgumentParser(description="Enhanced contrastive pre-training for Protein-LLM") | |
| # Model configuration | |
| parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B", | |
| help="Name or path to the text model") | |
| parser.add_argument("--protein_model_name", type=str, default="facebook/esm2_t6_8M_UR50D", | |
| help="Name or path to the protein model") | |
| parser.add_argument("--qformer_model_name", type=str, | |
| default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext", | |
| help="Name or path to the QFormer model") | |
| parser.add_argument("--cache_dir", type=str, default="/model-weights", | |
| help="Directory to cache downloaded models") | |
| parser.add_argument("--num_query_tokens", type=int, default=32, | |
| help="Number of query tokens in QFormer") | |
| # Dataset configuration | |
| parser.add_argument("--train_dataset", type=str, default="wanglab/protein_descriptions", | |
| help="Name of the dataset for contrastive learning") | |
| parser.add_argument("--valid_dataset", type=str, default="wanglab/protein_descriptions", | |
| help="Name of the dataset for contrastive learning") | |
| parser.add_argument("--eval_dataset", action="store_true", | |
| help="Whether to use evaluation dataset") | |
| # Training configuration | |
| parser.add_argument("--output_dir", type=str, default="./enhanced_contrastive_outputs", | |
| help="Output directory for model and logs") | |
| parser.add_argument("--num_epochs", type=int, default=10, | |
| help="Number of training epochs") | |
| parser.add_argument("--batch_size", type=int, default=32, | |
| help="Batch size per device") | |
| parser.add_argument("--learning_rate", type=float, default=1e-4, | |
| help="Learning rate") | |
| parser.add_argument("--weight_decay", type=float, default=0.01, | |
| help="Weight decay") | |
| parser.add_argument("--warmup_steps", type=int, default=1000, | |
| help="Number of warmup steps") | |
| parser.add_argument("--gradient_accumulation_steps", type=int, default=1, | |
| help="Gradient accumulation steps") | |
| # Enhanced contrastive learning specific | |
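    # The default temperature of 0.07 matches the value commonly used to
    # initialize the softmax temperature in CLIP-style contrastive training.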
| parser.add_argument("--temperature", type=float, default=0.07, | |
| help="Temperature for contrastive loss") | |
| parser.add_argument("--freeze_protein_model", action="store_true", default=True, | |
| help="Freeze protein model during training") | |
| parser.add_argument("--freeze_text_model", action="store_true", default=True, | |
| help="Freeze text model during training") | |
| parser.add_argument("--protein_weight", type=float, default=1.0, | |
| help="Weight for protein features in contrastive loss") | |
| parser.add_argument("--text_weight", type=float, default=1.0, | |
| help="Weight for text features in contrastive loss") | |
| # Protein-Text Matching (PTM) configuration | |
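    # PTM presumably adds a binary matched/mismatched classification objective
    # alongside the contrastive loss, analogous to the image-text matching
    # loss used with QFormers in BLIP-2.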
| parser.add_argument("--enable_ptm", action="store_true", default=True, | |
| help="Enable protein-text matching task") | |
| parser.add_argument("--ptm_weight", type=float, default=1.0, | |
| help="Weight for protein-text matching loss") | |
| # Data configuration | |
| parser.add_argument("--max_length_protein", type=int, default=1024, | |
| help="Maximum length for protein sequences") | |
| parser.add_argument("--max_length_text", type=int, default=512, | |
| help="Maximum length for text sequences") | |
| parser.add_argument("--num_workers", type=int, default=4, | |
| help="Number of data loading workers") | |
| # Logging and evaluation | |
| parser.add_argument("--logging_steps", type=int, default=100, | |
| help="Number of steps between logging") | |
| parser.add_argument("--eval_steps", type=int, default=500, | |
| help="Number of steps between evaluations") | |
| parser.add_argument("--save_steps", type=int, default=1000, | |
| help="Number of steps between saving checkpoints") | |
| parser.add_argument("--save_total_limit", type=int, default=3, | |
| help="Maximum number of checkpoints to keep") | |
| # Hardware configuration | |
| parser.add_argument("--fp16", action="store_true", | |
| help="Use FP16 precision") | |
| parser.add_argument("--bf16", action="store_true", | |
| help="Use BF16 precision") | |
| parser.add_argument("--seed", type=int, default=42, | |
| help="Random seed") | |
| # Wandb logging | |
| parser.add_argument("--use_wandb", action="store_true", | |
| help="Use Weights & Biases for logging") | |
| parser.add_argument("--wandb_project", type=str, default="protein-llm-enhanced-contrastive", | |
| help="Wandb project name") | |
| parser.add_argument("--wandb_entity", type=str, default=None, | |
| help="Wandb entity name") | |
| args = parser.parse_args() | |
| # Validate arguments | |
| if args.enable_ptm and not hasattr(args, 'ptm_weight'): | |
| args.ptm_weight = 1.0 | |
| # Create output directory | |
| os.makedirs(args.output_dir, exist_ok=True) | |
| # Print configuration | |
| print("=" * 50) | |
| print("Enhanced Contrastive Training Configuration:") | |
| print("=" * 50) | |
| print(f"Text model: {args.text_model_name}") | |
| print(f"Protein model: {args.protein_model_name}") | |
| print(f"QFormer model: {args.qformer_model_name}") | |
| print(f"Dataset: {args.train_dataset}") | |
| print(f"Output directory: {args.output_dir}") | |
| print(f"Batch size: {args.batch_size}") | |
| print(f"Learning rate: {args.learning_rate}") | |
| print(f"Temperature: {args.temperature}") | |
| print(f"Enable PTM: {args.enable_ptm}") | |
| print(f"PTM weight: {args.ptm_weight}") | |
| print(f"Number of epochs: {args.num_epochs}") | |
| print(f"Query tokens: {args.num_query_tokens}") | |
| print("=" * 50) | |
| # Run enhanced contrastive training | |
| trainer = main(args) |