import os
from typing import Any, Dict, Optional

from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


class PseudoMoETransformer(nn.Module):
    """Sentence-transformers-style module wrapping a domain-switchable HF model.

    The wrapped model is loaded with ``AutoModel`` and must expose
    ``set_domain(domain)`` and a ``domain`` attribute. Those are the only calls
    this wrapper makes on it beyond the standard forward/save API. Before every
    forward pass the requested (or default) domain is activated on it.
    """

    def __init__(
        self,
        model_name_or_path: str = None,
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        trust_remote_code: bool = True,
        default_domain: str = "general",
        **kwargs,
    ):
        """Load config, model and tokenizer from ``model_name_or_path``.

        Args:
            model_name_or_path: Hub id or local directory. If ``None``, an empty
                shell is returned that has no config/model/tokenizer attributes.
            max_seq_length: Tokenizer truncation length. Defaults to
                ``config.max_position_embeddings`` if present, otherwise the
                tokenizer's ``model_max_length``.
            config_args: Extra kwargs for ``AutoConfig.from_pretrained``. They
                also take effect on the model, because the resulting config is
                handed to ``AutoModel``.
            model_args: Extra kwargs for ``AutoModel.from_pretrained``. A
                ``default_domain`` key, if present, is consumed here and
                overrides the ``default_domain`` argument.
            tokenizer_args: Extra kwargs for ``AutoTokenizer.from_pretrained``.
            trust_remote_code: Forwarded to config and model loading.
            default_domain: Domain used by ``forward`` when none is given.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        if model_name_or_path is None:
            # Uninitialised shell: no config / auto_model / tokenizer attributes.
            return

        # Copy so that popping keys below never mutates the caller's dict.
        model_args = dict(model_args or {})
        config_kwargs = config_args or {}
        tokenizer_args = tokenizer_args or {}

        # forward()'s error message tells users to pass
        # model_kwargs={'default_domain': ...}. Honour that here rather than
        # leaking an unknown kwarg into AutoModel.from_pretrained.
        default_domain = model_args.pop("default_domain", default_domain)

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code, **config_kwargs
        )
        # Without this, AutoModel reloads its own config and config_args would
        # only affect self.config, not the instantiated model.
        model_args.setdefault("config", self.config)
        self.auto_model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            **model_args,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, **tokenizer_args
        )
        self.max_seq_length = (
            max_seq_length
            or getattr(self.config, "max_position_embeddings", None)
            or self.tokenizer.model_max_length
        )
        self._model_name_or_path = model_name_or_path
        self.default_domain = default_domain

    @staticmethod
    def load(
        model_name_or_path: str,
        subfolder: str = "",
        trust_remote_code: bool = True,
        **kwargs,
    ) -> "PseudoMoETransformer":
        """Build a module from a saved directory (optionally a ``subfolder`` of it).

        Recognised ``kwargs``: ``model_kwargs``, ``config_kwargs`` and
        ``tokenizer_kwargs``. Each is a dict (or ``None``) that is forwarded to
        the corresponding ``*_args`` constructor parameter. Other kwargs are
        ignored.
        """
        load_path = (
            os.path.join(model_name_or_path, subfolder) if subfolder else model_name_or_path
        )
        return PseudoMoETransformer(
            model_name_or_path=load_path,
            model_args=kwargs.get("model_kwargs") or {},
            config_args=kwargs.get("config_kwargs") or {},
            tokenizer_args=kwargs.get("tokenizer_kwargs") or {},
            trust_remote_code=trust_remote_code,
        )

    def tokenize(self, texts: list, **kwargs) -> dict:
        """Tokenize ``texts`` into padded/truncated PyTorch tensors.

        Truncation uses ``self.max_seq_length``. ``kwargs`` are accepted and
        ignored.
        """
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

    def forward(
        self,
        features: dict,
        domain: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Encode ``features`` under ``domain`` and add ``token_embeddings``.

        Args:
            features: Must contain ``input_ids``. ``attention_mask`` is optional.
                The dict is updated in place and returned.
            domain: Domain to activate. Falls back to ``self.default_domain``.
            truncate_dim: If given, keep only the first ``truncate_dim``
                features of the last axis of the hidden states.

        Raises:
            ValueError: if ``truncate_dim`` is not a positive int, no domain can
                be resolved, or the domain is not in ``config.domain_names``
                (when the config defines that attribute).
        """
        # Validate cheap arguments before running the (expensive) encoder.
        # bool is an int subclass; reject it so True doesn't mean "keep 1 dim".
        if truncate_dim is not None and (
            isinstance(truncate_dim, bool)
            or not isinstance(truncate_dim, int)
            or truncate_dim <= 0
        ):
            raise ValueError(f"truncate_dim must be a positive integer, got: {truncate_dim}")

        if domain is None:
            if self.default_domain is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_domain': 'general'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, domain='general'))."
                )
            domain = self.default_domain

        # Validate explicit and default domains alike. If the config has no
        # domain_names, defer to the model's own set_domain.
        domain_names = getattr(self.config, "domain_names", None)
        if domain_names is not None and domain not in domain_names:
            raise ValueError(
                f"Invalid domain: {domain}. Must be one of {domain_names}."
            )
        self.set_domain(domain)

        input_ids = features["input_ids"]
        attention_mask = features.get("attention_mask")
        outputs = self.auto_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        token_embeddings = outputs.last_hidden_state
        if truncate_dim is not None and truncate_dim < token_embeddings.shape[-1]:
            token_embeddings = token_embeddings[..., :truncate_dim]

        features["token_embeddings"] = token_embeddings
        if attention_mask is not None:
            features["attention_mask"] = attention_mask
        return features

    def get_word_embedding_dimension(self) -> int:
        """Return ``config.proj_dim``.

        NOTE(review): this assumes the width of ``last_hidden_state`` equals
        ``proj_dim`` — confirm against the remote model code.
        """
        return self.config.proj_dim

    def get_sentence_embedding_dimension(self) -> int:
        """Return ``config.proj_dim`` (same value as the word-embedding dimension)."""
        return self.config.proj_dim

    def get_max_seq_length(self) -> int:
        """Return the truncation length used by ``tokenize``."""
        return self.max_seq_length

    def save(self, output_path: str, safe_serialization: bool = True):
        """Save the wrapped model and tokenizer to ``output_path``."""
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

    def set_domain(self, domain: str):
        """Activate ``domain`` on the wrapped model."""
        self.auto_model.set_domain(domain)

    @property
    def domain(self) -> str:
        """Domain currently active on the wrapped model."""
        return self.auto_model.domain