File size: 8,884 Bytes
7c15d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn as nn

from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from transformers import BertTokenizer, BitsAndBytesConfig
from transformers import EsmTokenizer, EsmModel
try:
    from esm.models.esmc import ESMC
except Exception:
    # some older installs expose it under esm.models.esmc.esmc
    try:
        from esm.models import esmc as _esmc_mod
        ESMC = _esmc_mod.ESMC
    except Exception as e:
        raise ImportError(
            "Cannot import ESMC. Make sure `pip install esm` succeeded "
            "and esm>=2.x is installed. Original error: %r" % e
        )

from esm.sdk.api import LogitsConfig

def get_gpu_memory(device=0):
    """Report the memory usage of a CUDA device in GiB.

    Args:
        device: Device handed to ``torch.cuda.mem_get_info`` (defaults to 0).

    Returns:
        tuple: ``(free, used, total)`` in GiB, where ``used`` is computed as
        ``total - free``.
    """
    bytes_per_gib = 1024 ** 3
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    free_gib = free_bytes / bytes_per_gib
    total_gib = total_bytes / bytes_per_gib
    return free_gib, total_gib - free_gib, total_gib


class Blip2Base(BaseModel):
    """Base class providing construction helpers for the Q-Former and the
    protein language model (PLM) encoder."""

    @classmethod
    def init_Qformer(cls, model_name, num_query_token, plm_width, cross_attention_freq=2):
        """
        Build the Q-Former, its tokenizer and the learnable query tokens.

        Args:
            model_name (str): Checkpoint name; must be
                'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'.
            num_query_token (int): Number of learnable query tokens.
            plm_width (int): Stored as ``encoder_width`` in the Q-Former config
                (the PLM feature width the cross-attention attends to).
            cross_attention_freq (int): Stored as ``cross_attention_freq`` in
                the Q-Former config.

        Returns:
            tuple: (tokenizer, Qformer, query_tokens)
                - tokenizer: BertTokenizer with "[DEC]" registered as bos token
                - Qformer: BertLMHeadModel loaded from ``model_name``
                - query_tokens: nn.Parameter of shape
                  [1, num_query_token, hidden_size], normally initialised with
                  std ``initializer_range``

        Raises:
            AssertionError: If ``model_name`` is not the supported checkpoint.
        """
        # Only this checkpoint is supported; kept as an assert so callers see
        # the same exception type as before.
        assert model_name == 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'
        print("bert load microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

        encoder_config = BertConfig.from_pretrained(model_name)
        encoder_config.encoder_width = plm_width
        # enable cross-attention; with the default frequency of 2 a
        # cross-attention layer is inserted every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token

        Qformer = BertLMHeadModel.from_pretrained(model_name, config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        tokenizer = BertTokenizer.from_pretrained(model_name)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer, Qformer, query_tokens

    def init_protein_encoder(self, plm_name, load_4bit=False, device=None):
        """
        Create a protein encoder + tokenizer + LayerNorm.

        Supported Encoders:
          1. ESM2 (HuggingFace transformers): plm_name starts with 'facebook/esm2'
             - Uses EsmTokenizer and EsmModel from transformers
             - Examples: 'facebook/esm2_t30_150M_UR50D', 'facebook/esm2_t33_650M_UR50D'

          2. ESM-C (official ESM package): plm_name starts with 'esmc_'
             - Uses ESMC from esm.models.esmc
             - Examples: 'esmc_300m', 'esmc_600m'

        Args:
            plm_name (str): Model name/identifier
            load_4bit (bool): Whether to use 4-bit quantization (ESM2 only)
            device (str): Target device; defaults to "cuda" when available,
                otherwise "cpu"

        Returns:
            tuple: (plm_tokenizer, plm_module, ln_layer)
                - plm_tokenizer: Tokenizer function or object
                - plm_module: The encoder model
                - ln_layer: LayerNorm layer for encoder output

        Raises:
            ValueError: If ``plm_name`` matches none of the supported prefixes.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # ---------- Case A: ESM-2 (HF transformers) ----------
        if str(plm_name).startswith("facebook/esm2"):
            plm_tokenizer = EsmTokenizer.from_pretrained(plm_name)

            if not load_4bit:
                plm = EsmModel.from_pretrained(
                    plm_name,
                    add_pooling_layer=False,
                    torch_dtype=torch.bfloat16,
                ).to(device)
            else:
                quant_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    load_in_8bit=False,
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                # Pick the CUDA ordinal from the requested device. Do NOT read
                # CUDA_VISIBLE_DEVICES here: inside the process the visible
                # devices are re-indexed from 0, so its raw entries are not
                # valid ordinals (and may be UUIDs rather than integers).
                target = torch.device(device)
                if target.type == "cuda" and target.index is not None:
                    device_index = target.index
                else:
                    device_index = torch.cuda.current_device()

                # The quantization settings are carried by quantization_config
                # only; also passing load_in_4bit/load_in_8bit as kwargs is
                # rejected by recent transformers versions.
                plm = EsmModel.from_pretrained(
                    plm_name,
                    add_pooling_layer=False,
                    quantization_config=quant_config,
                    device_map={"": device_index},
                    torch_dtype=torch.bfloat16,
                )

            plm.num_features = plm.config.hidden_size
            ln_layer = nn.LayerNorm(plm.num_features)
            return plm_tokenizer, plm, ln_layer

        # ---------- Case B: ESM-C (official esm package) ----------
        elif str(plm_name).startswith("esmc_"):
            esmc = ESMC.from_pretrained(plm_name).to(device)
            esmc.eval()

            # tokenizer shim: return python lists (no tensors here)
            def esmc_tokenizer(batch_seqs, *args, **kwargs):
                """
                ESM-C tokenizer returns python lists; we intentionally avoid tensors here.
                Collate function will handle truncation/padding/tensorization.
                Extra positional/keyword arguments are accepted and ignored.
                """
                if isinstance(batch_seqs, str):
                    batch_seqs = [batch_seqs]
                toks = esmc.tokenizer(batch_seqs)  # no return_tensors
                # unify to HF-like mapping keys
                return {"input_ids": toks["input_ids"]}

            class ESMCWrapper(nn.Module):
                """Expose HF-like forward that returns .last_hidden_state [B, L, D]."""
                def __init__(self, model):
                    super().__init__()
                    self.model = model
                    # probe hidden size (fallback if missing)
                    dim = getattr(getattr(model, "config", None), "hidden_size", None)
                    if dim is None:
                        # no config.hidden_size: run a one-residue sequence
                        # through the model and read the embedding width
                        with torch.no_grad():
                            probe = model.tokenizer(["M"])["input_ids"]
                            # make tensor on device for probing embeddings shape
                            probe_ids = torch.tensor(probe, dtype=torch.long, device=device)
                            if probe_ids.dim() == 1:
                                probe_ids = probe_ids.unsqueeze(0)
                            emb = self._forward_embeddings(probe_ids)
                            dim = emb.shape[-1]
                    self.num_features = dim

                def _forward_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
                    """
                    Use logits(..., return_embeddings=True) if available; otherwise fallback to __call__.
                    Returns per-residue embeddings [B, L, D].
                    """
                    try:
                        out = self.model.logits(
                            input_ids,
                            LogitsConfig(sequence=True, return_embeddings=True),
                        )
                        emb = out.embeddings
                    except Exception:
                        # deliberate best-effort: any failure of the logits()
                        # path falls back to calling the model directly
                        out = self.model(input_ids)
                        emb = out.embeddings
                    if emb.dim() == 2:
                        # add the missing batch dimension
                        emb = emb.unsqueeze(0)
                    return emb

                def forward(self, input_ids, attention_mask=None, **kwargs):
                    """Return an object whose ``last_hidden_state`` holds the
                    embeddings; ``attention_mask`` and extra kwargs are ignored."""
                    emb = self._forward_embeddings(input_ids)  # [B, L, D]
                    class _Out:
                        pass
                    ret = _Out()
                    ret.last_hidden_state = emb
                    return ret

            plm = ESMCWrapper(esmc).to(device)
            ln_layer = nn.LayerNorm(plm.num_features)
            return esmc_tokenizer, plm, ln_layer

        else:
            raise ValueError(f"Unknown PLM name: {plm_name}")


def disabled_train(self, mode=True):
    """No-op stand-in for a model's ``train`` method.

    Assign this over ``model.train`` so that later train/eval switches leave
    the model's mode untouched; ``mode`` is accepted only for signature
    compatibility and is ignored. Returns ``self`` unchanged.
    """
    return self