import torch
import torch.nn as nn
from typing import Optional, Union, List, Dict
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
AutoModelForMaskedLM
)
from .base_tokenizer import BaseSequenceTokenizer
presets = {
'GLM2-150': 'tattabio/gLM2_150M',
'GLM2-650': 'tattabio/gLM2_650M',
'GLM2-GAIA': 'tattabio/gLM2_650M_embed'
}
class GLMTokenizerWrapper(BaseSequenceTokenizer):
    """Tokenizer wrapper for gLM2 that prepends the strand token ``<+>`` to every sequence.

    gLM2 expects each sequence to begin with a strand marker; this wrapper adds it
    transparently and applies sensible batching defaults.
    """

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)
        # Strand-direction marker required at the start of each gLM2 input.
        self.plus_token = "<+>"
        if self.plus_token not in self.tokenizer.vocab:
            print(f"Warning: Token '{self.plus_token}' not found in GLM tokenizer vocabulary.")

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Tokenize one sequence or a batch, each prefixed with the ``<+>`` strand token.

        Defaults (overridable via kwargs): PyTorch tensors, pad to the longest
        sequence in the batch, and add special tokens.
        """
        seq_list = [sequences] if isinstance(sequences, str) else sequences
        defaults = {
            'return_tensors': 'pt',
            'padding': 'longest',
            'add_special_tokens': True,
        }
        for option, value in defaults.items():
            kwargs.setdefault(option, value)
        prefixed = [f"{self.plus_token}{seq}" for seq in seq_list]
        return self.tokenizer(prefixed, **kwargs)
class gLM2ForEmbedding(nn.Module):
    """Thin adapter exposing a gLM2 backbone as an embedding model.

    Loads the remote-code gLM2 model and returns its last hidden state;
    ``output_attentions``/``output_hidden_states`` are accepted for API
    compatibility but not supported (``token_type_ids`` and extra kwargs
    are accepted and ignored for the same reason).
    """

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        self.glm2 = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = False,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return the backbone's last hidden state for the given tokens.

        Raises:
            ValueError: if both output_attentions and output_hidden_states are truthy.
        """
        # Was an `assert`, which is silently stripped under `python -O`;
        # an explicit exception keeps the contract enforced in all modes.
        if output_attentions and output_hidden_states:
            raise ValueError(
                "output_attentions=True and output_hidden_states=True are not supported by gLM2ForEmbedding."
            )
        out = self.glm2(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return out.last_hidden_state
class gLM2GAIAForEmbedding(nn.Module):
    """Adapter for the GAIA embed checkpoint: unwraps the inner gLM2 backbone.

    The GAIA checkpoint wraps a gLM2 model in an embedding head; this class
    loads the wrapper but forwards through the bare backbone, returning its
    last hidden state. ``output_attentions``/``output_hidden_states``/
    ``token_type_ids`` are accepted for API compatibility only.
    """

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        self.glm2_embed = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)
        # The checkpoint nests the backbone under `.glm2`; keep a direct handle.
        self.glm2 = self.glm2_embed.glm2

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = False,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return the backbone's last hidden state for the given tokens.

        Raises:
            ValueError: if both output_attentions and output_hidden_states are truthy.
        """
        # Was an `assert` (stripped under `-O`) and the message named the wrong
        # class (copy-paste from gLM2ForEmbedding); both fixed here.
        if output_attentions and output_hidden_states:
            raise ValueError(
                "output_attentions=True and output_hidden_states=True are not supported by gLM2GAIAForEmbedding."
            )
        out = self.glm2(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return out.last_hidden_state
def get_glm2_tokenizer(preset: str, model_path: str = None):
    """Load the gLM2 tokenizer for ``preset`` (or an explicit ``model_path``), wrapped for strand-token handling."""
    path = model_path or presets[preset]
    raw_tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    return GLMTokenizerWrapper(raw_tokenizer)
def build_glm2_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Build a gLM2 model and matching tokenizer for inference.

    Args:
        preset: Key into ``presets`` ('GLM2-150', 'GLM2-650', 'GLM2-GAIA').
        masked_lm: If True, load the masked-LM head instead of an embedding wrapper.
        dtype: Optional torch dtype for the model weights.
        model_path: Explicit HF model path; overrides the preset lookup.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    model_path = model_path or presets[preset]
    if masked_lm:
        model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    elif preset == "GLM2-GAIA":
        # GAIA checkpoint nests the backbone; use the dedicated wrapper.
        model = gLM2GAIAForEmbedding(model_path, dtype=dtype).eval()
    else:
        model = gLM2ForEmbedding(model_path, dtype=dtype).eval()
    # Bug fix: the tokenizer previously ignored an explicit model_path override
    # (it was loaded from the preset only), so model and tokenizer could diverge.
    tokenizer = get_glm2_tokenizer(preset, model_path)
    return model, tokenizer
def get_glm2_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None):
    """Build a gLM2 model with a task head (plus tokenizer) for fine-tuning.

    Args:
        preset: Key into ``presets`` ('GLM2-150', 'GLM2-650', 'GLM2-GAIA').
        tokenwise: If True, use a token-classification head; else sequence classification.
        num_labels: Number of output labels for the classification head.
        hybrid: If True, return the bare backbone (no head) for custom heads.
        dtype: Optional torch dtype for the model weights.
        model_path: Explicit HF model path; overrides the preset lookup.

    Returns:
        ``(model, tokenizer)``. NOTE(review): the model is returned in eval mode
        despite being "for training" — presumably callers call ``.train()``.
    """
    model_path = model_path or presets[preset]
    if hybrid:
        model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    elif tokenwise:
        model = AutoModelForTokenClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    # Bug fix: the tokenizer previously ignored an explicit model_path override,
    # so a custom checkpoint would be paired with the preset's tokenizer.
    tokenizer = get_glm2_tokenizer(preset, model_path)
    return model, tokenizer
if __name__ == '__main__':
    # Smoke test; run with: py -m src.protify.base_models.glm
    demo_sequence = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'
    glm_model, glm_tokenizer = build_glm2_model('GLM2-650')
    print(glm_model)
    print(glm_tokenizer)
    print(glm_tokenizer(demo_sequence))
|