import json
import os
from typing import Union, List, Dict, Tuple

import torch
from sentence_transformers import models
from transformers import AutoModel


class EmbeddingModel(models.Transformer):
    """`models.Transformer` subclass with a fixed model id and custom tokenization.

    The instance always reports ``"lamarr-llm-development/elbedding"`` as its
    ``model_name_or_path`` and persists only that value in its module config.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the underlying transformer module.

        Positional arguments and all keyword arguments except
        ``model_name_or_path`` are forwarded unchanged to the base class.

        NOTE(review): the ``model_name_or_path`` keyword is dropped, presumably
        because the value stored by ``get_config_dict`` would otherwise collide
        with the path the loader passes positionally -- confirm against the
        sentence-transformers loading code in use.
        """
        # Assigned before the base initialiser runs, exactly as in the
        # original statement order.
        self.model_name_or_path = "lamarr-llm-development/elbedding"
        forwarded = {
            key: value
            for key, value in kwargs.items()
            if key != "model_name_or_path"
        }
        super().__init__(*args, **forwarded)

    def tokenize(
        self,
        texts: Union[List[str], List[Dict], List[Tuple[str, str]]],
        padding: Union[str, bool] = True,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of texts and map the tokens to token ids.

        The type of the first element decides how the batch is interpreted:

        * ``str``: every text gets the tokenizer's EOS token appended.
        * ``dict``: the first ``(key, value)`` item of each dict is used; the
          values are tokenized and the keys are returned under ``"text_keys"``.
        * anything else: elements are treated as pairs and split into two
          parallel batches (index 0 and index 1).

        Args:
            texts: Non-empty batch in one of the three shapes above.
            padding: Accepted for signature compatibility; the tokenizer call
                below always uses ``padding="max_length"``.

        Returns:
            The tokenizer output (PyTorch tensors, padded/truncated to 512
            tokens) without ``"token_type_ids"``, plus ``"text_keys"`` for
            dict inputs.
        """
        sample = texts[0]
        output = {}

        if isinstance(sample, str):
            # Only plain-string inputs receive the EOS suffix.
            eos = self.tokenizer.eos_token
            batches = [[text + eos for text in texts]]
        elif isinstance(sample, dict):
            pairs = [next(iter(entry.items())) for entry in texts]
            output["text_keys"] = [key for key, _ in pairs]
            batches = [[value for _, value in pairs]]
        else:
            batches = [
                [pair[0] for pair in texts],
                [pair[1] for pair in texts],
            ]

        encoded = self.tokenizer(
            *batches,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        output.update(encoded)

        # this is specific to OpenGPT-X model
        if "token_type_ids" in output:
            del output["token_type_ids"]
        return output

    def get_config_dict(self) -> dict[str, str]:
        """Return the module configuration: just the model identifier."""
        return dict(model_name_or_path=self.model_name_or_path)

    def save(self, save_dir: str, **kwargs) -> None:
        """Write model weights, tokenizer files and module config to ``save_dir``.

        The weights are saved with ``safe_serialization=True`` and the config
        goes to ``sentence_bert_config.json``. Extra keyword arguments are
        accepted but not used.
        """
        self.auto_model.save_pretrained(save_dir, safe_serialization=True)
        self.tokenizer.save_pretrained(save_dir)

        config_path = os.path.join(save_dir, "sentence_bert_config.json")
        with open(config_path, "w+") as config_file:
            json.dump(self.get_config_dict(), config_file, indent=4)

    @staticmethod
    def load(**kwargs) -> "EmbeddingModel":
        """Build an ``EmbeddingModel`` from keyword arguments only."""
        return EmbeddingModel(**kwargs)