# ProtT3_model / model / blip2.py
# (Hugging Face Hub upload residue: uploaded by yuccaaa using the
#  upload-large-folder tool, commit 4d12519 verified — kept as a comment
#  so the file remains valid Python.)
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn as nn
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from transformers import BertTokenizer, BitsAndBytesConfig
from transformers import EsmTokenizer, EsmModel
import os
from pathlib import Path # 添加到文件顶部
def get_gpu_memory(device=0):
    """Return ``(free, used, total)`` GPU memory in GiB for ``device``.

    Backed by ``torch.cuda.mem_get_info`` (device-wide ``cudaMemGetInfo``),
    so ``used`` reflects memory held by *all* processes on the GPU, not
    just the current one.

    Args:
        device: CUDA device index to query (default 0).

    Returns:
        Tuple of floats ``(free_gib, used_gib, total_gib)``.
    """
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    gib = 1024 ** 3
    free = free_bytes / gib
    total = total_bytes / gib
    return free, total - free, total
class Blip2Base(BaseModel):
    """Shared base for BLIP-2 style protein-language models.

    Provides two builders used by subclasses: the Q-Former bridge
    (:meth:`init_Qformer`) and the ESM protein encoder
    (:meth:`init_protein_encoder`).
    """

    @classmethod
    def init_Qformer(cls, model_name, num_query_token, plm_width, cross_attention_freq=2):
        """Build the Q-Former (BERT LM head with cross-attention) and its query tokens.

        Args:
            model_name: Hugging Face hub id or local path of the BERT checkpoint.
            num_query_token: number of learnable query embeddings.
            plm_width: hidden size of the protein encoder that the
                cross-attention layers attend to.
            cross_attention_freq: insert a cross-attention layer every
                ``cross_attention_freq`` transformer blocks.

        Returns:
            Tuple ``(tokenizer, Qformer, query_tokens)`` where ``query_tokens``
            is an ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``.
        """
        print(f"Loading Qformer from: {model_name}")
        # Informational only: report whether the name resolves to a local
        # path or will be fetched from the Hub (from_pretrained handles both).
        if not model_name.startswith('microsoft/') and Path(model_name).exists():
            print("Loading from local path...")
        else:
            print("Loading from Hugging Face Hub...")
        encoder_config = BertConfig.from_pretrained(model_name)
        encoder_config.encoder_width = plm_width
        # Enable cross-attention every `cross_attention_freq` blocks so the
        # Q-Former can attend to protein-encoder features.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(model_name, config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        # Initialize like BERT embeddings (normal with config's init range).
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        tokenizer = BertTokenizer.from_pretrained(model_name)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer, Qformer, query_tokens

    def init_protein_encoder(self, plm_name, load_4bit=False):
        """Load the ESM protein language model, its tokenizer, and a LayerNorm.

        Args:
            plm_name: hub id or local path of the ESM checkpoint.
            load_4bit: if True, load the model NF4-quantized via
                bitsandbytes onto a hard-coded GPU (6, spilling to 7).

        Returns:
            Tuple ``(plm_tokenizer, plm, ln_layer)``; ``plm.num_features``
            is set to the model's hidden size and ``ln_layer`` is an
            ``nn.LayerNorm`` over that width.
        """
        # Report where the checkpoint comes from. Fix: the original loaded
        # the tokenizer identically in both branches; only the message
        # depends on the local-path check, so load it once below.
        if os.path.isdir(plm_name) or os.path.exists(os.path.join(plm_name, "config.json")):
            print(f"Loading local PLM from {plm_name}")
        else:
            print(f"Loading remote PLM from {plm_name}")
        plm_tokenizer = EsmTokenizer.from_pretrained(plm_name)
        if not load_4bit:
            plm = EsmModel.from_pretrained(plm_name, add_pooling_layer=False, torch_dtype=torch.bfloat16)
        else:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
            )
            # Place the whole model on one GPU: use device 6 unless it
            # already has >1 GiB in use, then spill to device 7.
            # TODO(review): device ids 6/7 are hard-coded — make configurable.
            _, used_memory, _ = get_gpu_memory(6)
            device_map = {"": 7} if used_memory > 1 else {"": 6}
            # Fix: pass quantization only via `quantization_config`; the
            # original also passed load_in_4bit/load_in_8bit kwargs, which
            # duplicates the spec and raises on recent transformers versions.
            plm = EsmModel.from_pretrained(
                plm_name,
                add_pooling_layer=False,
                quantization_config=quant_config,
                device_map=device_map,
                torch_dtype=torch.bfloat16,
            )
        plm.num_features = plm.config.hidden_size
        ln_layer = nn.LayerNorm(plm.num_features)
        return plm_tokenizer, plm, ln_layer
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that freezes train/eval state.

    Bind this as a module's ``train`` method so subsequent ``.train()`` /
    ``.eval()`` calls become no-ops: the requested ``mode`` is discarded
    and the module itself is returned, preserving call chaining.
    """
    return self
# class LayerNorm(nn.LayerNorm):
# """Subclass torch's LayerNorm to handle fp16."""
# def forward(self, x: torch.Tensor):
# orig_type = x.dtype
# ret = super().forward(x.type(torch.float32))
# return ret.type(orig_type)