File size: 4,725 Bytes
64253c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from typing import Any, Dict, Optional
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig


class PseudoMoETransformer(nn.Module):
    """Wrap a Hugging Face ``AutoModel`` that exposes switchable domains.

    The wrapper loads a config, model and tokenizer from the same location,
    selects a domain on the underlying model before every forward pass (via
    the model's ``set_domain`` method) and returns per-token embeddings in
    the ``features`` dict under the ``"token_embeddings"`` key.
    """

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        trust_remote_code: bool = True,
        default_domain: str = "general",
        **kwargs,
    ):
        """Load config, model and tokenizer from ``model_name_or_path``.

        Args:
            model_name_or_path: Location handed to the ``from_pretrained``
                loaders. When ``None`` nothing is loaded and only
                ``default_domain`` is stored on the instance.
            max_seq_length: Maximum token length used by :meth:`tokenize`.
                Falls back to ``config.max_position_embeddings`` when falsy.
            config_args: Extra keyword arguments for
                ``AutoConfig.from_pretrained``.
            model_args: Extra keyword arguments for
                ``AutoModel.from_pretrained``.
            tokenizer_args: Extra keyword arguments for
                ``AutoTokenizer.from_pretrained``.
            trust_remote_code: Forwarded to the config and model loaders.
            default_domain: Domain used by :meth:`forward` when the caller
                does not pass one.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # Pure constructor state: record it before the early return so that a
        # placeholder instance still carries the attribute.
        self.default_domain = default_domain

        if model_name_or_path is None:
            return

        model_args = model_args or {}
        config_kwargs = config_args or {}
        tokenizer_args = tokenizer_args or {}

        # NOTE(review): this config object is kept for metadata lookups only;
        # it is not handed to AutoModel.from_pretrained below, so config_args
        # may not influence the loaded model -- confirm this is intended.
        self.config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code, **config_kwargs
        )
        self.auto_model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            **model_args,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, **tokenizer_args
        )
        self.max_seq_length = max_seq_length or self.config.max_position_embeddings
        self._model_name_or_path = model_name_or_path

    @staticmethod
    def load(
        model_name_or_path: str,
        subfolder: str = "",
        trust_remote_code: bool = True,
        **kwargs,
    ) -> "PseudoMoETransformer":
        """Build a :class:`PseudoMoETransformer` from a saved location.

        Args:
            model_name_or_path: Base path or identifier of the model.
            subfolder: Optional sub-directory joined onto the base path.
            trust_remote_code: Forwarded to the constructor.
            **kwargs: Recognised keys are ``model_kwargs``, ``config_kwargs``,
                ``tokenizer_kwargs``, ``max_seq_length`` and
                ``default_domain``; any other key is ignored.

        Returns:
            The newly constructed module.
        """
        load_path = os.path.join(model_name_or_path, subfolder) if subfolder else model_name_or_path

        # ``or {}`` also covers a caller passing an explicit ``None``.
        model_kwargs = kwargs.get("model_kwargs") or {}
        config_kwargs = kwargs.get("config_kwargs") or {}
        tokenizer_kwargs = kwargs.get("tokenizer_kwargs") or {}

        # Only forward the optional constructor arguments that were actually
        # supplied, so the constructor defaults stay in charge otherwise.
        init_kwargs: Dict[str, Any] = {}
        if "max_seq_length" in kwargs:
            init_kwargs["max_seq_length"] = kwargs["max_seq_length"]
        if "default_domain" in kwargs:
            init_kwargs["default_domain"] = kwargs["default_domain"]
        elif "default_domain" in model_kwargs:
            # The error message raised by forward() tells users to set the
            # default domain through model_kwargs, so honour that here.
            # NOTE(review): the key is left in model_kwargs, so the underlying
            # model loader still receives it exactly as before -- confirm
            # whether it should be stripped instead.
            init_kwargs["default_domain"] = model_kwargs["default_domain"]

        return PseudoMoETransformer(
            model_name_or_path=load_path,
            model_args=model_kwargs,
            config_args=config_kwargs,
            tokenizer_args=tokenizer_kwargs,
            trust_remote_code=trust_remote_code,
            **init_kwargs,
        )

    def tokenize(self, texts: list, **kwargs) -> dict:
        """Tokenize ``texts`` with padding and truncation to ``max_seq_length``.

        Args:
            texts: Inputs handed to the tokenizer as-is.
            **kwargs: Accepted and ignored.

        Returns:
            The tokenizer output, requested as PyTorch tensors.
        """
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

    def forward(
        self,
        features: dict,
        domain: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Run the underlying model and attach token embeddings to ``features``.

        Args:
            features: Must contain ``"input_ids"``; ``"attention_mask"`` is
                used when present. The dict is updated in place.
            domain: Domain to activate for this call. Falls back to
                ``default_domain`` when ``None``. An explicitly passed domain
                must be listed in ``config.domain_names``.
            truncate_dim: When given, keep only the first ``truncate_dim``
                entries along the last axis of the token embeddings.
            **kwargs: Accepted and ignored.

        Returns:
            The same ``features`` dict with ``"token_embeddings"`` added.

        Raises:
            ValueError: If no domain can be resolved, if an explicit domain is
                not in ``config.domain_names``, or if ``truncate_dim`` is not
                a positive integer.
        """
        if domain is None:
            if self.default_domain is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_domain': 'general'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, domain='general'))."
                )
            domain = self.default_domain
        elif domain not in self.config.domain_names:
            raise ValueError(
                f"Invalid domain: {domain}. Must be one of {self.config.domain_names}."
            )
        self.set_domain(domain)

        input_ids = features["input_ids"]
        attention_mask = features.get("attention_mask")

        outputs = self.auto_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        token_embeddings = outputs.last_hidden_state
        if truncate_dim is not None:
            # bool is a subclass of int; without the explicit exclusion
            # ``truncate_dim=True`` would silently truncate to one dimension.
            is_int = isinstance(truncate_dim, int) and not isinstance(truncate_dim, bool)
            if not is_int or truncate_dim <= 0:
                raise ValueError(f"truncate_dim must be a positive integer, got: {truncate_dim}")
            if truncate_dim < token_embeddings.shape[-1]:
                token_embeddings = token_embeddings[..., :truncate_dim]

        # ``attention_mask`` (when present) is already stored in ``features``
        # under the key it was read from, so only the embeddings are added.
        features["token_embeddings"] = token_embeddings

        return features

    def get_word_embedding_dimension(self) -> int:
        """Return ``config.proj_dim``."""
        return self.config.proj_dim

    def get_sentence_embedding_dimension(self) -> int:
        """Return ``config.proj_dim``."""
        return self.config.proj_dim

    def get_max_seq_length(self) -> int:
        """Return the maximum sequence length used by :meth:`tokenize`."""
        return self.max_seq_length

    def save(self, output_path: str, safe_serialization: bool = True):
        """Save the underlying model and the tokenizer to ``output_path``."""
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

    def set_domain(self, domain: str):
        """Delegate domain selection to the underlying model."""
        self.auto_model.set_domain(domain)

    @property
    def domain(self) -> str:
        """The ``domain`` attribute of the underlying model."""
        return self.auto_model.domain