Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python | |
| """LLM2Vec-style bidirectional adaptation of ProGen2 — training entrypoint. | |
| Launch: | |
| srun torchrun --standalone --nproc_per_node=4 pretrain.py [args] | |
| Single-GPU / smoke also works without torchrun (falls back to rank 0). | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| import torch.distributed as dist | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| from transformers import AutoTokenizer, get_cosine_schedule_with_warmup | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from src.bidir_progen import load_bidir_progen # noqa: E402 | |
| from src.data import ProteinSeqDataset, MNTPCollator, CleanCollator, load_sequences # noqa: E402 | |
def parse_args():
    """Build the training CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding model/run, two-stage (MNTP -> SimCSE),
        data, LoRA, contrastive and optimisation settings.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model and run identity.
    add("--model-name", default="hugohrban/progen2-base")
    add("--objective", default="joint", choices=["mntp", "simcse", "joint"])
    add("--output-dir", required=True)

    # Two-stage LLM2Vec recipe: the SimCSE stage resumes the MNTP adapter and
    # can force dropout on to create positive pairs.
    add("--init-adapter", default=None,
        help="LoRA adapter dir to resume (SimCSE stage starts from MNTP)")
    add("--simcse-dropout", type=float, default=None,
        help="force all dropout to this prob (SimCSE positive-pair augmentation)")

    # Data source.
    add("--hf-dataset", default=None, help="HF dataset id of protein seqs")
    add("--hf-config", default=None)
    add("--text-column", default="sequence")

    # Every remaining option is a plain typed scalar; declared in CLI order.
    typed_scalars = (
        # data
        ("--num-sequences", int, 2000),
        ("--max-length", int, 256),
        ("--mlm-probability", float, 0.15),
        # LoRA
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
        ("--lora-dropout", float, 0.05),
        # contrastive
        ("--simcse-weight", float, 0.1),
        ("--temperature", float, 0.05),
        # optimisation / bookkeeping
        ("--per-device-batch-size", int, 8),
        ("--gradient-accumulation-steps", int, 1),
        ("--lr", float, 1e-4),
        ("--weight-decay", float, 0.01),
        ("--warmup-steps", int, 10),
        ("--max-steps", int, 100),
        ("--logging-steps", int, 1),
        ("--save-steps", int, 100000),
        ("--seed", int, 0),
    )
    for flag, kind, default in typed_scalars:
        add(flag, type=kind, default=default)

    return parser.parse_args()
def setup_dist():
    """Initialise torch.distributed when launched with more than one process.

    Reads the launcher's ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` environment
    variables (as set by torchrun). Multi-process runs use the NCCL backend
    when CUDA is available and fall back to gloo on CPU-only hosts, since NCCL
    requires GPUs. Any other launch runs single-process as rank 0.

    Returns:
        ``(rank, local_rank, world_size, distributed)`` where ``distributed``
        is True only when a process group was initialised.
    """
    use_cuda = torch.cuda.is_available()
    if "RANK" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1:
        local_rank = int(os.environ["LOCAL_RANK"])
        if use_cuda:
            # Bind this process to its GPU *before* creating the NCCL group so
            # collectives (including the final barrier) use the right device.
            torch.cuda.set_device(local_rank)
        dist.init_process_group("nccl" if use_cuda else "gloo")
        return dist.get_rank(), local_rank, dist.get_world_size(), True
    if use_cuda:
        torch.cuda.set_device(0)
    return 0, 0, 1, False
def is_main(rank):
    """Return True for the process that owns logging and saving (rank 0)."""
    main_rank = 0
    return rank == main_rank
def log(rank, msg):
    """Print *msg* with an ``[HH:MM:SS]`` prefix, from the main rank only."""
    if not is_main(rank):
        return
    stamp = time.strftime('%H:%M:%S')
    print(f"[{stamp}] {msg}", flush=True)
def _infinite_batches(loader, sampler):
    """Yield batches from *loader* forever, starting a new pass when exhausted.

    A DistributedSampler only reshuffles when its epoch changes, so the epoch
    is advanced before every pass after the first. The caller must ensure the
    loader is non-empty, otherwise this generator never yields.
    """
    epoch = 0
    while True:
        if sampler is not None and epoch > 0:
            sampler.set_epoch(epoch)
        yield from loader
        epoch += 1


def _save_adapter(core, tokenizer, out_dir, rank):
    """Write the LoRA adapter and tokenizer to *out_dir* (main rank only)."""
    if not is_main(rank):
        return
    os.makedirs(out_dir, exist_ok=True)
    core.model.save_pretrained(out_dir)  # saves LoRA adapter
    tokenizer.save_pretrained(out_dir)
    log(rank, f"saved adapter + tokenizer to {out_dir}")


def main():
    """Train the bidirectional ProGen2 LoRA adapter and save it.

    Steps:
      1. Parse the CLI arguments and set up optional torch.distributed.
      2. Load the bidirectionally patched ProGen2 with LoRA, and the tokenizer.
      3. Build the protein-sequence loader: masked (MNTP/joint) or clean
         (SimCSE) batches.
      4. Run ``--max-steps`` optimizer steps with gradient accumulation,
         grad-norm clipping at 1.0 and a cosine warmup schedule.
      5. Save the adapter and tokenizer to ``--output-dir``. When
         ``--save-steps`` > 0, also save to ``checkpoint-<step>``
         sub-directories along the way.

    ``--logging-steps`` / ``--save-steps`` of 0 disable logging / periodic
    checkpoints.

    Raises:
        ValueError: if no sequences are loaded, if the dataset cannot fill one
            per-rank batch, or if ``--gradient-accumulation-steps`` < 1.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    rank, local_rank, world_size, distributed = setup_dist()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
    dtype = torch.bfloat16 if use_cuda else torch.float32
    log(rank, f"world_size={world_size} device={device} dtype={dtype} objective={args.objective}")

    accum = args.gradient_accumulation_steps
    if accum < 1:
        raise ValueError(f"--gradient-accumulation-steps must be >= 1, got {accum}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.convert_ids_to_tokens(0)

    model, info = load_bidir_progen(
        args.model_name, args.objective, args.lora_r, args.lora_alpha,
        args.lora_dropout, args.simcse_weight, args.temperature, dtype=dtype,
        init_adapter=args.init_adapter, attn_dropout=args.simcse_dropout,
    )
    log(rank, f"bidirectional patch: {info['patched_layers']} layers; "
              f"lora targets: {info['lora_targets']}; "
              f"resumed_adapter={info['resumed_adapter']}; dropout_set={info['dropout_set']}")
    model.to(device)
    if is_main(rank):
        model.model.print_trainable_parameters()

    seqs = load_sequences(args.num_sequences, args.hf_dataset, args.hf_config,
                          args.text_column, seed=args.seed)
    if len(seqs) == 0:
        # Fail with a clear message instead of an IndexError on seqs[0].
        raise ValueError("load_sequences returned no sequences; "
                         "check --hf-dataset / --hf-config / --num-sequences")
    log(rank, f"loaded {len(seqs)} sequences (e.g. len={len(seqs[0])})")
    dataset = ProteinSeqDataset(seqs, tokenizer, max_length=args.max_length)
    # SimCSE stage = clean (unmasked) input; MNTP/joint = BERT-style masking.
    collator = (CleanCollator(tokenizer) if args.objective == "simcse"
                else MNTPCollator(tokenizer, mlm_probability=args.mlm_probability))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=args.seed) if distributed else None
    loader = DataLoader(dataset, batch_size=args.per_device_batch_size,
                        sampler=sampler, shuffle=sampler is None,
                        collate_fn=collator, drop_last=True)
    if len(loader) == 0:
        # With drop_last=True a dataset smaller than one per-rank batch yields
        # nothing; without this check it surfaces as a bare StopIteration.
        raise ValueError(
            f"dataset of {len(dataset)} sequences cannot fill one batch of "
            f"{args.per_device_batch_size} per rank across {world_size} rank(s)")

    if distributed:
        # DDP requires device_ids=None for CPU-resident modules (gloo runs).
        model = DDP(model, device_ids=[local_rank] if use_cuda else None,
                    find_unused_parameters=True)
    core = model.module if distributed else model

    # Collect trainable params once; reused by AdamW and by grad clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.weight_decay)
    sched = get_cosine_schedule_with_warmup(optim, args.warmup_steps, args.max_steps)

    model.train()
    batches = _infinite_batches(loader, sampler)
    t0 = time.time()
    for step in range(1, args.max_steps + 1):
        optim.zero_grad(set_to_none=True)
        accum_logs = {}
        loss_sum = 0.0
        for _ in range(accum):
            batch = {k: v.to(device) for k, v in next(batches).items()}
            out = model(**batch)
            loss = out["loss"]
            # Scale so accumulated grads equal the mean over micro-batches.
            (loss / accum).backward()
            loss_sum += loss.item()
            for k, v in out["logs"].items():
                accum_logs[k] = accum_logs.get(k, 0.0) + v.item()
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optim.step()
        sched.step()

        if args.logging_steps > 0 and step % args.logging_steps == 0:
            # All logged values (including loss) are means over micro-batches.
            parts = " ".join(f"{k}={v / accum:.4f}" for k, v in accum_logs.items())
            sps = step / (time.time() - t0)
            log(rank, f"step {step}/{args.max_steps} loss={loss_sum / accum:.4f} {parts} "
                      f"lr={sched.get_last_lr()[0]:.2e} {sps:.2f} step/s")
        if args.save_steps > 0 and step % args.save_steps == 0 and step < args.max_steps:
            # Intermediate checkpoint; the final adapter is written below.
            _save_adapter(core, tokenizer,
                          os.path.join(args.output_dir, f"checkpoint-{step}"), rank)

    _save_adapter(core, tokenizer, args.output_dir, rank)
    if distributed:
        dist.barrier()
        dist.destroy_process_group()
# Run training only when executed as a script (torchrun / python), not on import.
if __name__ == "__main__":
    main()