File size: 2,251 Bytes
cd40de0
 
 
 
 
 
baa7748
cd40de0
 
 
93fa830
b6894c7
 
93fa830
cd40de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1835a4f
94f40ab
cd40de0
 
 
 
a749d20
 
cd40de0
 
 
 
 
b6894c7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import os
from typing import Union, List, Dict, Tuple

import torch
from sentence_transformers import models
from transformers import AutoModel


class EmbeddingModel(models.Transformer):
    """sentence-transformers ``Transformer`` module pinned to the
    ``lamarr-llm-development/elbedding`` checkpoint.

    The reported ``model_name_or_path`` is fixed: a caller-supplied value is
    discarded so the saved config always names the pinned checkpoint.
    """

    def __init__(self, *args, **kwargs):
        # Pin the checkpoint name and drop any caller-supplied override so
        # it cannot reach the parent constructor via **kwargs.
        self.model_name_or_path = "lamarr-llm-development/elbedding"
        kwargs.pop("model_name_or_path", None)
        super().__init__(*args, **kwargs)

    def tokenize(
        self,
        texts: Union[List[str], List[Dict], List[Tuple[str, str]]],
        padding: Union[str, bool] = True,
    ) -> Dict[str, torch.Tensor]:
        """Tokenizes a text and maps tokens to token-ids.

        Args:
            texts: Either plain strings, single-entry dicts mapping a text
                key to a string, or (text, text) pairs tokenized as two
                parallel batches.
            padding: Accepted for interface compatibility with the parent
                class. NOTE(review): this argument is currently ignored —
                tokenization always uses ``padding="max_length"`` below;
                confirm that is intentional before wiring it through.

        Returns:
            Dict of tokenizer output tensors; for dict inputs it also
            contains a ``"text_keys"`` list with each item's key.
        """
        output = {}
        if isinstance(texts[0], str):
            # Plain strings: append EOS so the embedding model sees a
            # terminated sequence, then tokenize as a single batch.
            texts = [x + self.tokenizer.eos_token for x in texts]
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            # Single-entry dicts: remember each key, tokenize the values.
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            # Pairs: split into two parallel batches (sentence-pair input).
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        output.update(
            self.tokenizer(
                *to_tokenize,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
        )

        # this is specific to OpenGPT-X model
        output.pop("token_type_ids", None)

        return output

    def get_config_dict(self) -> Dict[str, str]:
        """Return the config persisted to ``sentence_bert_config.json``."""
        return {"model_name_or_path": self.model_name_or_path}

    def save(self, save_dir: str, **kwargs) -> None:
        """Save model weights, tokenizer, and module config to *save_dir*.

        Extra ``**kwargs`` are accepted for interface compatibility and
        ignored.
        """
        self.auto_model.save_pretrained(save_dir, safe_serialization=True)
        self.tokenizer.save_pretrained(save_dir)

        with open(os.path.join(save_dir, "sentence_bert_config.json"), "w+") as f:
            json.dump(self.get_config_dict(), f, indent=4)

    @classmethod
    def load(cls, **kwargs) -> "EmbeddingModel":
        """Construct a new instance from keyword arguments.

        A classmethod (rather than the previous staticmethod hard-coding
        ``EmbeddingModel``) so subclasses load as their own type; call sites
        using ``EmbeddingModel.load(...)`` are unaffected.
        """
        return cls(**kwargs)