Elvis-t9 committed on
Commit
9e91b68
·
verified ·
1 Parent(s): 0a89990

Update custom_transformer.py

Browse files
Files changed (1) hide show
  1. custom_transformer.py +2 -79
custom_transformer.py CHANGED
@@ -33,89 +33,12 @@ class CGETransformer(Transformer):
33
  config_keys: list[str] = ["max_seq_length", "do_lower_case"]
34
  save_in_root: bool = True
35
 
36
- # def __init__(
37
- # self,
38
- # model_name_or_path: str,
39
- # max_seq_length: int | None = None,
40
- # model_args: dict[str, Any] | None = None,
41
- # tokenizer_args: dict[str, Any] | None = None,
42
- # config_args: dict[str, Any] | None = None,
43
- # cache_dir: str | None = None,
44
- # do_lower_case: bool = False,
45
- # tokenizer_name_or_path: str | None = None,
46
- # backend: str = "torch",
47
- # **kwargs
48
- # ) -> None:
49
- # super().__init__(model_name_or_path, **kwargs)
50
- # self.do_lower_case = do_lower_case
51
- # self.backend = backend
52
- # if model_args is None:
53
- # model_args = {}
54
- # if tokenizer_args is None:
55
- # tokenizer_args = {}
56
- # if config_args is None:
57
- # config_args = {}
58
-
59
- # config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
60
- # self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
61
-
62
- # # Get the signature of the auto_model's forward method to pass only the expected arguments from `features`,
63
- # # plus some common values like "input_ids", "attention_mask", etc.
64
- # model_forward_params = list(inspect.signature(self.auto_model.forward).parameters)
65
- # self.model_forward_params = set(model_forward_params) | {
66
- # "input_ids",
67
- # "attention_mask",
68
- # "token_type_ids",
69
- # "inputs_embeds",
70
- # }
71
-
72
- # if max_seq_length is not None and "model_max_length" not in tokenizer_args:
73
- # tokenizer_args["model_max_length"] = max_seq_length
74
- # self.tokenizer = AutoTokenizer.from_pretrained(
75
- # tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
76
- # cache_dir=cache_dir,
77
- # **tokenizer_args,
78
- # )
79
-
80
- # # No max_seq_length set. Try to infer from model
81
- # if max_seq_length is None:
82
- # if (
83
- # hasattr(self.auto_model, "config")
84
- # and hasattr(self.auto_model.config, "max_position_embeddings")
85
- # and hasattr(self.tokenizer, "model_max_length")
86
- # ):
87
- # max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)
88
-
89
- # self.max_seq_length = max_seq_length
90
-
91
- # if tokenizer_name_or_path is not None:
92
- # self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
93
  def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
94
- """
95
- Forward pass through the transformer model.
96
-
97
- This method processes the input features through the underlying transformers model
98
- and returns the token embeddings along with any other relevant outputs.
99
-
100
- Notes:
101
- - Only passes arguments that are expected by the underlying transformer model
102
-
103
- Args:
104
- features (dict[str, torch.Tensor]): Input features dictionary containing at least
105
- 'input_ids' and 'attention_mask'. May also contain other tensors required by
106
- the underlying transformer model.
107
- **kwargs: Additional keyword arguments to pass to the underlying transformer model.
108
-
109
- Returns:
110
- dict[str, torch.Tensor]: Updated features dictionary containing the input features, plus:
111
- - 'token_embeddings': Token-level embeddings from the transformer model
112
- - 'attention_mask': Possibly modified attention mask if using PeftModel with prompt learning
113
- - 'all_layer_embeddings': If the model outputs hidden states, contains embeddings from all layers
114
- """
115
  trans_features = {key: value for key, value in features.items() if key in self.model_forward_params}
116
 
117
  outputs = self.auto_model(**trans_features, **kwargs, return_dict=True)
118
-
119
  sentence_embedding = outputs["sentence_embedding"]
120
  features["sentence_embedding"] = sentence_embedding
121
 
 
33
  config_keys: list[str] = ["max_seq_length", "do_lower_case"]
34
  save_in_root: bool = True
35
 
36
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  trans_features = {key: value for key, value in features.items() if key in self.model_forward_params}
39
 
40
  outputs = self.auto_model(**trans_features, **kwargs, return_dict=True)
41
+
42
  sentence_embedding = outputs["sentence_embedding"]
43
  features["sentence_embedding"] = sentence_embedding
44