# elbedding / embedding_model.py
# Provenance: Hugging Face repo file, last updated by vijusudhi
# ("Update embedding_model.py", commit a749d20, verified).
import json
import os
from typing import Union, List, Dict, Tuple
import torch
from sentence_transformers import models
from transformers import AutoModel
class EmbeddingModel(models.Transformer):
    """Sentence-Transformers ``Transformer`` module pinned to the
    ``lamarr-llm-development/elbedding`` checkpoint, with custom tokenization
    (EOS appending for plain strings, fixed 512-token max-length padding).
    """

    def __init__(self, *args, **kwargs):
        # Pin the model identity: any caller-supplied ``model_name_or_path``
        # is discarded so this module always reports the elbedding checkpoint
        # in its saved config.
        self.model_name_or_path = "lamarr-llm-development/elbedding"
        kwargs.pop("model_name_or_path", None)
        super().__init__(*args, **kwargs)

    def tokenize(
        self,
        texts: Union[List[str], List[Dict], List[Tuple[str, str]]],
        padding: Union[str, bool] = True,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of texts and map tokens to token ids.

        Three input layouts are supported, dispatched on the type of the
        first element (``texts`` must therefore be non-empty):

        * list of ``str`` — the tokenizer's EOS token is appended to each
          text before tokenization,
        * list of single-item ``dict`` — the sole key of each dict is
          collected under ``output["text_keys"]`` and the value tokenized,
        * list of 2-tuples — tokenized as sentence pairs.

        NOTE: the ``padding`` argument is accepted for interface
        compatibility but ignored — the tokenizer is always called with
        ``padding="max_length"`` and ``max_length=512``.

        Returns a dict of batched ``torch.Tensor`` tokenizer outputs
        (plus ``text_keys`` for the dict layout).
        """
        output = {}
        if isinstance(texts[0], str):
            # Plain strings get an explicit end-of-sequence marker.
            # (The dict/tuple branches do not append EOS — presumably
            # intentional; TODO confirm.)
            texts = [x + self.tokenizer.eos_token for x in texts]
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                # Each dict is expected to hold exactly one key/text pair;
                # only the first item is consumed.
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            # Sentence pairs: split into two parallel batches so the
            # tokenizer receives (batch1, batch2) as text/text_pair.
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]
        output.update(
            self.tokenizer(
                *to_tokenize,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
        )
        # This model's tokenizer emits token_type_ids that the OpenGPT-X
        # model does not consume; drop them if present.
        output.pop("token_type_ids", None)
        return output

    def get_config_dict(self) -> Dict[str, str]:
        """Return the config persisted to ``sentence_bert_config.json``."""
        return {"model_name_or_path": self.model_name_or_path}

    def save(self, save_dir: str, **kwargs) -> None:
        """Save model weights (safetensors), tokenizer, and SBERT config
        into ``save_dir``. Extra ``kwargs`` are accepted and ignored.
        """
        self.auto_model.save_pretrained(save_dir, safe_serialization=True)
        self.tokenizer.save_pretrained(save_dir)
        # "w" (not "w+"): the file is write-only here; be explicit about
        # encoding so the JSON config is portable.
        config_path = os.path.join(save_dir, "sentence_bert_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.get_config_dict(), f, indent=4)

    @staticmethod
    def load(**kwargs) -> "EmbeddingModel":
        """Construct an :class:`EmbeddingModel`, forwarding ``kwargs``."""
        return EmbeddingModel(**kwargs)