raul3820 committed · Commit 46fac09 · 1 Parent(s): d279f37

Fix head_mask documentation errors in model classes


Added missing head_mask parameter documentation to:
- BertHashModel.forward
- BertHashForMaskedLM.forward
- BertHashForSequenceClassification.forward

This resolves the warnings emitted by transformers at model load time about the head_mask parameter missing from the forward docstrings.
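
For reference, a minimal sketch of how the now-documented parameter is typically passed to the model (the local checkpoint path and the choice of masked head are illustrative assumptions, not part of this commit):

    from transformers import AutoModel, AutoTokenizer
    import torch

    # Assumed: the repository files live in the current directory and expose the custom BertHash code.
    tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)
    model = AutoModel.from_pretrained(".", trust_remote_code=True)

    inputs = tokenizer("This is an example sentence", return_tensors="pt")

    # head_mask uses 1 to keep a head and 0 to mask it; shape (num_layers, num_heads).
    # Here the first head of the first layer is masked purely as an example.
    head_mask = torch.ones(model.config.num_hidden_layers, model.config.num_attention_heads)
    head_mask[0, 0] = 0.0

    with torch.no_grad():
        outputs = model(**inputs, head_mask=head_mask)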

Files changed (2)
  1. modeling_bert_hash.py +22 -0
  2. test.py +0 -68
modeling_bert_hash.py CHANGED
@@ -232,6 +232,14 @@ class BertHashModel(BertPreTrainedModel):
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
@@ -432,6 +440,13 @@ class BertHashForMaskedLM(BertPreTrainedModel):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+        head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
        """

        return_dict = (
@@ -553,6 +568,13 @@ class BertHashForSequenceClassification(BertPreTrainedModel):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
test.py DELETED
@@ -1,68 +0,0 @@
- from transformers import AutoTokenizer, AutoModel
- import torch
- import os
- import sys
- import io
- import tempfile
- import shutil
-
-
- # Mean Pooling - Take attention mask into account for correct averaging
- def meanpooling(output, mask):
-     embeddings = output[
-         0
-     ]  # First element of model_output contains all token embeddings
-     mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
-     return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
-
-
- # Sentences we want sentence embeddings for
- sentences = ["This is an example sentence", "Each sentence is converted"]
-
- # Load model from local repository (current directory)
- local_model_path = os.getcwd()  # Current directory contains the model files
-
- print(f"Loading model from local path: {local_model_path}")
- # Suppress all output during model loading (including progress bars to stdout and stderr)
- # Save original file descriptors
- orig_stdout = os.dup(1)
- orig_stderr = os.dup(2)
- null_fd = os.open(os.devnull, os.O_WRONLY | os.O_CREAT | os.O_TRUNC)
- # Redirect stdout and stderr to null
- os.dup2(null_fd, 1)
- os.dup2(null_fd, 2)
- try:
-     tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
-     model = AutoModel.from_pretrained(local_model_path, trust_remote_code=True)
- finally:
-     # Restore stdout and stderr
-     os.dup2(orig_stdout, 1)
-     os.dup2(orig_stderr, 2)
-     os.close(null_fd)
-     os.close(orig_stdout)
-     os.close(orig_stderr)
-
- print(f"Model loaded successfully!")
-
- # Set model to evaluation mode
- model.eval()
-
- # Tokenize sentences
- inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
-
- # Add token_type_ids for transformers 5.x compatibility
- if "token_type_ids" not in inputs or inputs["token_type_ids"] is None:
-     batch_size = inputs["input_ids"].size(0)
-     seq_length = inputs["input_ids"].size(1)
-     inputs["token_type_ids"] = torch.zeros(batch_size, seq_length, dtype=torch.long)
-
- # Compute token embeddings
- with torch.no_grad():
-     output = model(**inputs)
-
- # Perform pooling. In this case, mean pooling.
- embeddings = meanpooling(output, inputs["attention_mask"])
-
- print("Sentence embeddings:")
- print(embeddings)
- print(f"\nEmbeddings shape: {embeddings.shape}")