# geevec-embeddings-1.0-lite / pseudo_moe_st_module.py
# Uploaded by geevec using huggingface_hub (commit 64253c3, verified).
import os
from typing import Any, Dict, Optional
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
class PseudoMoETransformer(nn.Module):
    """Wrap a Hugging Face ``AutoModel`` that can be switched between domains.

    The class exposes ``tokenize`` / ``forward`` / ``save`` / ``load`` and
    embedding-dimension getters. This looks like the custom-module interface
    sentence-transformers expects (NOTE(review): confirm against the
    ``modules.json`` that references this class). Before every forward pass,
    the wrapped model is switched to the requested domain through its
    ``set_domain`` method.
    """

    def __init__(
        self,
        model_name_or_path: str = None,
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        trust_remote_code: bool = True,
        default_domain: str = "general",
        **kwargs,
    ):
        """Load config, model and tokenizer from ``model_name_or_path``.

        Args:
            model_name_or_path: Hub id or local directory. If ``None``, an
                empty shell is built and nothing is loaded; only the wrapper
                options (``max_seq_length``, ``default_domain``) are set.
            max_seq_length: Truncation length for ``tokenize``. Defaults to
                the config's ``max_position_embeddings``.
            config_args: Extra kwargs for ``AutoConfig.from_pretrained``.
            model_args: Extra kwargs for ``AutoModel.from_pretrained``.
            tokenizer_args: Extra kwargs for ``AutoTokenizer.from_pretrained``.
            trust_remote_code: Forwarded to the config and model loaders.
            default_domain: Domain that ``forward`` uses when none is given.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # Set the wrapper options first, so an instance built without a model
        # path still has every attribute that forward()/tokenize() read.
        self.max_seq_length = max_seq_length
        self.default_domain = default_domain
        self._model_name_or_path = model_name_or_path
        if model_name_or_path is None:
            return
        config_kwargs = config_args or {}
        tokenizer_args = tokenizer_args or {}
        self.config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code, **config_kwargs
        )
        # Pass the config built above to the model. Otherwise AutoModel reloads
        # config.json itself and `config_args` never reach the model. Copy the
        # dict so the caller's copy is not mutated; a caller-supplied "config"
        # still takes precedence.
        model_args = dict(model_args or {})
        model_args.setdefault("config", self.config)
        self.auto_model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            **model_args,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, **tokenizer_args
        )
        self.max_seq_length = max_seq_length or self.config.max_position_embeddings

    @staticmethod
    def _build_init_kwargs(
        load_path: str,
        trust_remote_code: bool,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Translate the keyword arguments of ``load`` into constructor kwargs."""
        # Copy so popping never mutates the caller's dicts. `or {}` also covers
        # keys that are present but explicitly set to None.
        model_kwargs = dict(kwargs.get("model_kwargs") or {})
        config_kwargs = dict(kwargs.get("config_kwargs") or {})
        tokenizer_kwargs = dict(kwargs.get("tokenizer_kwargs") or {})
        init_kwargs: Dict[str, Any] = {
            "model_name_or_path": load_path,
            "model_args": model_kwargs,
            "config_args": config_kwargs,
            "tokenizer_args": tokenizer_kwargs,
            "trust_remote_code": trust_remote_code,
        }
        # `default_domain` configures this wrapper, not the underlying model.
        # The error message in forward() tells users to pass it via
        # model_kwargs, so route it to the constructor instead of AutoModel.
        if "default_domain" in model_kwargs:
            init_kwargs["default_domain"] = model_kwargs.pop("default_domain")
        return init_kwargs

    @staticmethod
    def load(
        model_name_or_path: str,
        subfolder: str = "",
        trust_remote_code: bool = True,
        **kwargs,
    ) -> "PseudoMoETransformer":
        """Build an instance from ``model_name_or_path`` (optionally a subfolder).

        ``model_kwargs``, ``config_kwargs`` and ``tokenizer_kwargs`` are read
        from ``kwargs``. A ``default_domain`` entry in ``model_kwargs`` sets
        this wrapper's default domain. All other keys in ``kwargs`` are ignored.
        """
        load_path = (
            os.path.join(model_name_or_path, subfolder) if subfolder else model_name_or_path
        )
        return PseudoMoETransformer(
            **PseudoMoETransformer._build_init_kwargs(load_path, trust_remote_code, kwargs)
        )

    def tokenize(self, texts: list, **kwargs) -> dict:
        """Tokenize ``texts`` into padded, truncated PyTorch tensors.

        Sequences are truncated to ``self.max_seq_length``. Extra ``kwargs``
        are ignored.
        """
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

    def forward(
        self,
        features: dict,
        domain: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Run the wrapped model and store ``token_embeddings`` in ``features``.

        Args:
            features: Tokenizer output. Must contain ``"input_ids"`` and may
                contain ``"attention_mask"``. It is updated in place and
                returned.
            domain: Domain to switch the model to before the call. Falls back
                to ``self.default_domain``.
            truncate_dim: If given, keep only the first ``truncate_dim``
                entries of the last axis of ``last_hidden_state``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If no domain is available, if the domain is not in
                ``config.domain_names``, or if ``truncate_dim`` is not a
                positive int.
        """
        if domain is None:
            if self.default_domain is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_domain': 'general'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, domain='general'))."
                )
            domain = self.default_domain
        # Validate explicit and default domains the same way, so a bad default
        # fails with this clear error rather than somewhere inside the model.
        # Validation is skipped if the config does not list its domains.
        valid_domains = getattr(self.config, "domain_names", None)
        if valid_domains is not None and domain not in valid_domains:
            raise ValueError(
                f"Invalid domain: {domain}. Must be one of {valid_domains}."
            )
        # Check truncate_dim before switching domain or running the model.
        # `bool` is a subclass of `int`, so reject it explicitly.
        if truncate_dim is not None and (
            isinstance(truncate_dim, bool)
            or not isinstance(truncate_dim, int)
            or truncate_dim <= 0
        ):
            raise ValueError(f"truncate_dim must be a positive integer, got: {truncate_dim}")

        self.set_domain(domain)
        attention_mask = features.get("attention_mask")
        outputs = self.auto_model(
            input_ids=features["input_ids"],
            attention_mask=attention_mask,
            return_dict=True,
        )
        token_embeddings = outputs.last_hidden_state
        if truncate_dim is not None and truncate_dim < token_embeddings.shape[-1]:
            token_embeddings = token_embeddings[..., :truncate_dim]
        features["token_embeddings"] = token_embeddings
        return features

    def get_word_embedding_dimension(self) -> int:
        """Return ``config.proj_dim``.

        NOTE(review): this assumes ``proj_dim`` equals the last axis of
        ``last_hidden_state``, and it does not reflect ``truncate_dim``.
        """
        return self.config.proj_dim

    def get_sentence_embedding_dimension(self) -> int:
        """Return ``config.proj_dim`` (same value as the word dimension)."""
        return self.config.proj_dim

    def get_max_seq_length(self) -> int:
        """Return the truncation length used by ``tokenize``."""
        return self.max_seq_length

    def save(self, output_path: str, safe_serialization: bool = True):
        """Save the wrapped model and tokenizer to ``output_path``.

        NOTE(review): ``max_seq_length`` and ``default_domain`` are not
        written, so ``load`` restores their defaults.
        """
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

    def set_domain(self, domain: str):
        """Forward ``domain`` to the wrapped model's ``set_domain``."""
        self.auto_model.set_domain(domain)

    @property
    def domain(self) -> str:
        """Return the wrapped model's ``domain`` attribute."""
        return self.auto_model.domain