Update custom_transformer.py
Browse files
- custom_transformer.py +2 -79
custom_transformer.py
CHANGED
|
@@ -33,89 +33,12 @@ class CGETransformer(Transformer):
|
|
| 33 |
config_keys: list[str] = ["max_seq_length", "do_lower_case"]
|
| 34 |
save_in_root: bool = True
|
| 35 |
|
| 36 |
-
|
| 37 |
-
# self,
|
| 38 |
-
# model_name_or_path: str,
|
| 39 |
-
# max_seq_length: int | None = None,
|
| 40 |
-
# model_args: dict[str, Any] | None = None,
|
| 41 |
-
# tokenizer_args: dict[str, Any] | None = None,
|
| 42 |
-
# config_args: dict[str, Any] | None = None,
|
| 43 |
-
# cache_dir: str | None = None,
|
| 44 |
-
# do_lower_case: bool = False,
|
| 45 |
-
# tokenizer_name_or_path: str | None = None,
|
| 46 |
-
# backend: str = "torch",
|
| 47 |
-
# **kwargs
|
| 48 |
-
# ) -> None:
|
| 49 |
-
# super().__init__(model_name_or_path, **kwargs)
|
| 50 |
-
# self.do_lower_case = do_lower_case
|
| 51 |
-
# self.backend = backend
|
| 52 |
-
# if model_args is None:
|
| 53 |
-
# model_args = {}
|
| 54 |
-
# if tokenizer_args is None:
|
| 55 |
-
# tokenizer_args = {}
|
| 56 |
-
# if config_args is None:
|
| 57 |
-
# config_args = {}
|
| 58 |
-
|
| 59 |
-
# config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
|
| 60 |
-
# self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
|
| 61 |
-
|
| 62 |
-
# # Get the signature of the auto_model's forward method to pass only the expected arguments from `features`,
|
| 63 |
-
# # plus some common values like "input_ids", "attention_mask", etc.
|
| 64 |
-
# model_forward_params = list(inspect.signature(self.auto_model.forward).parameters)
|
| 65 |
-
# self.model_forward_params = set(model_forward_params) | {
|
| 66 |
-
# "input_ids",
|
| 67 |
-
# "attention_mask",
|
| 68 |
-
# "token_type_ids",
|
| 69 |
-
# "inputs_embeds",
|
| 70 |
-
# }
|
| 71 |
-
|
| 72 |
-
# if max_seq_length is not None and "model_max_length" not in tokenizer_args:
|
| 73 |
-
# tokenizer_args["model_max_length"] = max_seq_length
|
| 74 |
-
# self.tokenizer = AutoTokenizer.from_pretrained(
|
| 75 |
-
# tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
|
| 76 |
-
# cache_dir=cache_dir,
|
| 77 |
-
# **tokenizer_args,
|
| 78 |
-
# )
|
| 79 |
-
|
| 80 |
-
# # No max_seq_length set. Try to infer from model
|
| 81 |
-
# if max_seq_length is None:
|
| 82 |
-
# if (
|
| 83 |
-
# hasattr(self.auto_model, "config")
|
| 84 |
-
# and hasattr(self.auto_model.config, "max_position_embeddings")
|
| 85 |
-
# and hasattr(self.tokenizer, "model_max_length")
|
| 86 |
-
# ):
|
| 87 |
-
# max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)
|
| 88 |
-
|
| 89 |
-
# self.max_seq_length = max_seq_length
|
| 90 |
-
|
| 91 |
-
# if tokenizer_name_or_path is not None:
|
| 92 |
-
# self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
|
| 93 |
def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
|
| 94 |
-
"""
|
| 95 |
-
Forward pass through the transformer model.
|
| 96 |
-
|
| 97 |
-
This method processes the input features through the underlying transformers model
|
| 98 |
-
and returns the token embeddings along with any other relevant outputs.
|
| 99 |
-
|
| 100 |
-
Notes:
|
| 101 |
-
- Only passes arguments that are expected by the underlying transformer model
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
features (dict[str, torch.Tensor]): Input features dictionary containing at least
|
| 105 |
-
'input_ids' and 'attention_mask'. May also contain other tensors required by
|
| 106 |
-
the underlying transformer model.
|
| 107 |
-
**kwargs: Additional keyword arguments to pass to the underlying transformer model.
|
| 108 |
-
|
| 109 |
-
Returns:
|
| 110 |
-
dict[str, torch.Tensor]: Updated features dictionary containing the input features, plus:
|
| 111 |
-
- 'token_embeddings': Token-level embeddings from the transformer model
|
| 112 |
-
- 'attention_mask': Possibly modified attention mask if using PeftModel with prompt learning
|
| 113 |
-
- 'all_layer_embeddings': If the model outputs hidden states, contains embeddings from all layers
|
| 114 |
-
"""
|
| 115 |
trans_features = {key: value for key, value in features.items() if key in self.model_forward_params}
|
| 116 |
|
| 117 |
outputs = self.auto_model(**trans_features, **kwargs, return_dict=True)
|
| 118 |
-
|
| 119 |
sentence_embedding = outputs["sentence_embedding"]
|
| 120 |
features["sentence_embedding"] = sentence_embedding
|
| 121 |
|
|
|
|
| 33 |
config_keys: list[str] = ["max_seq_length", "do_lower_case"]
|
| 34 |
save_in_root: bool = True
|
| 35 |
|
| 36 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
trans_features = {key: value for key, value in features.items() if key in self.model_forward_params}
|
| 39 |
|
| 40 |
outputs = self.auto_model(**trans_features, **kwargs, return_dict=True)
|
| 41 |
+
|
| 42 |
sentence_embedding = outputs["sentence_embedding"]
|
| 43 |
features["sentence_embedding"] = sentence_embedding
|
| 44 |
|