jgs-430 committed
Commit b4ce828 · 1 Parent(s): c60a23b

updated my_model.py

Files changed (1): my_model.py +5 -15
my_model.py CHANGED
@@ -6,32 +6,22 @@ class StoryPointIncrementModel(nn.Module):
     """
     A custom model wrapper designed to load and use the weights of a fine-tuned
     Transformer model for regression (story point prediction).
-
-    The missing/unexpected keys error indicates the checkpoint contains the full
-    Transformer structure. We redefine the model to match that structure.
     """
 
-    def __init__(self, model_name="prajjwal1/bert-tiny", num_labels=1):
+    # CRITICAL FIX: Add cache_dir argument to __init__ and set a default to None
+    def __init__(self, model_name="prajjwal1/bert-tiny", num_labels=1, cache_dir=None):
         super().__init__()
+
         # Load the configuration of a small BERT-like model as a base template.
-        # The actual weights from model.safetensors will be loaded into this structure.
-        config = AutoConfig.from_pretrained(model_name)
+        # PASS cache_dir to from_pretrained to prevent permission errors
+        config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
 
         # We load the base encoder (up to the pooler)
         self.encoder = AutoModel.from_config(config)
 
-        # The unexpected keys suggest the saved model structure includes a pooler layer.
-        # We define a custom regressor head that will be matched by `load_state_dict`
-        # (or at least provide a place for the final linear layer if it was saved
-        # under a different name than the original checkpoint).
-        # We will manually map the final linear layer if necessary.
-
         # A simple linear layer for regression (predicting a single story point value)
         self.regressor = nn.Linear(config.hidden_size, num_labels)
 
-        # A custom property to track if the loading was successful
-        self.loaded_safetensors_keys = False
-
     def forward(self, input_ids, attention_mask):
         # Pass the tokenized inputs through the Transformer encoder
         outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
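
For context, a minimal usage sketch of the patched class. It assumes the standard transformers/torch imports at the top of my_model.py, a hypothetical writable cache directory /tmp/hf_cache (the typical reason for threading cache_dir through, e.g. on a Hugging Face Space where the default ~/.cache is not writable), and that forward() ends by returning the regressor output, since the hunk cuts off before the return statement; none of this is part of the commit itself.

import torch
from transformers import AutoTokenizer

from my_model import StoryPointIncrementModel

# Hypothetical writable cache directory; adjust to your environment.
CACHE_DIR = "/tmp/hf_cache"

# __init__ now forwards cache_dir to AutoConfig.from_pretrained,
# so the config download lands in the writable directory.
model = StoryPointIncrementModel(cache_dir=CACHE_DIR)
model.eval()

# Tokenizer matching the base template model; also pointed at the same cache.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny", cache_dir=CACHE_DIR)
inputs = tokenizer(
    "As a user, I want to export my reports to CSV.",
    return_tensors="pt",
    truncation=True,
    padding=True,
)

with torch.no_grad():
    # Assumes forward() returns the regressor output (a single story point value);
    # the diff hunk ends before the method's return statement.
    story_points = model(inputs["input_ids"], inputs["attention_mask"])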