ProtTale-demo / model /blip2.py
Mulah's picture
Initial ProtTale Space
7c15d15
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn as nn
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from transformers import BertTokenizer, BitsAndBytesConfig
from transformers import EsmTokenizer, EsmModel
try:
from esm.models.esmc import ESMC
except Exception:
# some older installs expose it under esm.models.esmc.esmc
try:
from esm.models import esmc as _esmc_mod
ESMC = _esmc_mod.ESMC
except Exception as e:
raise ImportError(
"Cannot import ESMC. Make sure `pip install esm` succeeded "
"and esm>=2.x is installed. Original error: %r" % e
)
from esm.sdk.api import LogitsConfig
def get_gpu_memory(device=0):
# t = torch.cuda.get_device_properties(device).total_memory
# r = torch.cuda.memory_reserved(device)
# a = torch.cuda.memory_allocated(device)
# f = r-a # free inside reserved
free, total = torch.cuda.mem_get_info(device)
free = free / (1024 ** 3)
total = total / (1024 ** 3)
return free, total-free, total
class Blip2Base(BaseModel):
# @classmethod
# def init_tokenizer(cls):
# tokenizer = BertTokenizer.from_pretrained('./bert_pretrained/')
# tokenizer.add_special_tokens({"bos_token": "[DEC]"})
# return tokenizer
@classmethod
def init_Qformer(cls, model_name, num_query_token, plm_width, cross_attention_freq=2):
assert model_name == 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'
print("bert load microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
encoder_config = BertConfig.from_pretrained(model_name)
encoder_config.encoder_width = plm_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel.from_pretrained(model_name, config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
tokenizer = BertTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
return tokenizer, Qformer, query_tokens
def init_protein_encoder(self, plm_name, load_4bit=False, device=None):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
"""
Create a protein encoder + tokenizer + LayerNorm.
Supported Encoders:
1. ESM2 (HuggingFace transformers): plm_name starts with 'facebook/esm2'
- Uses EsmTokenizer and EsmModel from transformers
- Examples: 'facebook/esm2_t30_150M_UR50D', 'facebook/esm2_t33_650M_UR50D'
2. ESM-C (official ESM package): plm_name starts with 'esmc_'
- Uses ESMC from esm.models.esmc
- Examples: 'esmc_300m', 'esmc_600m'
Args:
plm_name (str): Model name/identifier
load_4bit (bool): Whether to use 4-bit quantization (ESM2 only)
device (str): Target device
Returns:
tuple: (plm_tokenizer, plm_module, ln_layer)
- plm_tokenizer: Tokenizer function or object
- plm_module: The encoder model
- ln_layer: LayerNorm layer for encoder output
"""
# ---------- Case A: ESM-2 (HF transformers) ----------
if str(plm_name).startswith("facebook/esm2"):
plm_tokenizer = EsmTokenizer.from_pretrained(plm_name)
if not load_4bit:
plm = EsmModel.from_pretrained(
plm_name,
add_pooling_layer=False,
torch_dtype=torch.bfloat16,
).to(device)
else:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
# Automatic device selection for 4-bit quantization
# Use CUDA_VISIBLE_DEVICES or default to first available device
import os
visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
device_id = int(visible_devices.split(',')[0])
device_map = {"": device_id}
plm = EsmModel.from_pretrained(
plm_name,
add_pooling_layer=False,
quantization_config=quant_config,
load_in_4bit=True,
load_in_8bit=False,
device_map=device_map,
torch_dtype=torch.bfloat16,
)
plm.num_features = plm.config.hidden_size
ln_layer = nn.LayerNorm(plm.num_features)
return plm_tokenizer, plm, ln_layer
# ---------- Case B: ESM-C (official esm package) ----------
elif str(plm_name).startswith("esmc_"):
esmc = ESMC.from_pretrained(plm_name).to(device)
esmc.eval()
# tokenizer shim: return python lists (no tensors here)
def esmc_tokenizer(batch_seqs, *args, **kwargs):
"""
ESM-C tokenizer returns python lists; we intentionally avoid tensors here.
Collate function will handle truncation/padding/tensorization.
"""
if isinstance(batch_seqs, str):
batch_seqs = [batch_seqs]
toks = esmc.tokenizer(batch_seqs) # no return_tensors
# unify to HF-like mapping keys
return {"input_ids": toks["input_ids"]}
class ESMCWrapper(nn.Module):
"""Expose HF-like forward that returns .last_hidden_state [B, L, D]."""
def __init__(self, model):
super().__init__()
self.model = model
# probe hidden size (fallback if missing)
dim = getattr(getattr(model, "config", None), "hidden_size", None)
if dim is None:
with torch.no_grad():
probe = model.tokenizer(["M"])["input_ids"]
# make tensor on device for probing embeddings shape
probe_ids = torch.tensor(probe, dtype=torch.long, device=device)
if probe_ids.dim() == 1:
probe_ids = probe_ids.unsqueeze(0)
emb = self._forward_embeddings(probe_ids)
dim = emb.shape[-1]
self.num_features = dim
def _forward_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Use logits(..., return_embeddings=True) if available; otherwise fallback to __call__.
Returns per-residue embeddings [B, L, D].
"""
try:
out = self.model.logits(
input_ids,
LogitsConfig(sequence=True, return_embeddings=True),
)
emb = out.embeddings
except Exception:
out = self.model(input_ids)
emb = out.embeddings
if emb.dim() == 2:
emb = emb.unsqueeze(0)
return emb
def forward(self, input_ids, attention_mask=None, **kwargs):
emb = self._forward_embeddings(input_ids) # [B, L, D]
class _Out:
pass
ret = _Out()
ret.last_hidden_state = emb
return ret
plm = ESMCWrapper(esmc).to(device)
ln_layer = nn.LayerNorm(plm.num_features)
return esmc_tokenizer, plm, ln_layer
else:
raise ValueError(f"Unknown PLM name: {plm_name}")
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self