File size: 2,251 Bytes
cd40de0 baa7748 cd40de0 93fa830 b6894c7 93fa830 cd40de0 1835a4f 94f40ab cd40de0 a749d20 cd40de0 b6894c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import json
import os
from typing import Union, List, Dict, Tuple
import torch
from sentence_transformers import models
from transformers import AutoModel
class EmbeddingModel(models.Transformer):
    """Sentence-Transformers wrapper for the OpenGPT-X "elbedding" model.

    Customizes tokenization (EOS-terminated inputs, fixed max-length
    padding, no ``token_type_ids``) and pins the model id recorded in the
    saved ``sentence_bert_config.json`` to the Hugging Face hub name.
    """

    def __init__(self, *args, **kwargs):
        # Pin the recorded model id; drop any caller-supplied keyword so it
        # cannot override it.
        # NOTE(review): the pinned id is only used by get_config_dict()/
        # save(); it is never forwarded to super().__init__, so the model
        # actually loaded comes from the remaining positional/keyword args.
        # Confirm this asymmetry is intended.
        self.model_name_or_path = "lamarr-llm-development/elbedding"
        kwargs.pop("model_name_or_path", None)
        super().__init__(*args, **kwargs)

    def tokenize(
        self,
        texts: Union[List[str], List[Dict], List[Tuple[str, str]]],
        padding: Union[str, bool] = True,
    ) -> Dict[str, torch.Tensor]:
        """Tokenizes a text and maps tokens to token-ids.

        Args:
            texts: One of three shapes —
                * plain strings (each gets the EOS token appended),
                * single-item dicts ``{text_key: text}`` (keys are returned
                  under ``"text_keys"``),
                * ``(text_a, text_b)`` pairs, tokenized as two batches.
            padding: Accepted for interface compatibility but ignored; the
                tokenizer is always invoked with ``padding="max_length"``.

        Returns:
            Dict of tokenizer tensors (``input_ids``, ``attention_mask``,
            …), plus ``"text_keys"`` for dict-shaped input.
        """
        output = {}
        if isinstance(texts[0], str):
            # NOTE(review): only plain-string inputs get the EOS token
            # appended; dict and pair inputs do not — confirm intended.
            texts = [x + self.tokenizer.eos_token for x in texts]
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                # Each dict is expected to hold exactly one key/text pair;
                # only the first item is consumed.
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            # Pair input: split into two parallel batches for the
            # tokenizer's (text, text_pair) calling convention.
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]
        output.update(
            self.tokenizer(
                *to_tokenize,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
        )
        # this is specific to OpenGPT-X model
        output.pop("token_type_ids", None)
        return output

    def get_config_dict(self) -> Dict[str, str]:
        """Return the config persisted to sentence_bert_config.json."""
        return {"model_name_or_path": self.model_name_or_path}

    def save(self, save_dir: str, **kwargs) -> None:
        """Save model weights (safetensors), tokenizer, and config to save_dir."""
        self.auto_model.save_pretrained(save_dir, safe_serialization=True)
        self.tokenizer.save_pretrained(save_dir)
        config_path = os.path.join(save_dir, "sentence_bert_config.json")
        # "w" (not "w+") — the file is only written, never read back here.
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.get_config_dict(), f, indent=4)

    @staticmethod
    def load(**kwargs) -> "EmbeddingModel":
        """Factory used by sentence-transformers to reconstruct the module."""
        return EmbeddingModel(**kwargs)
|