lucweber committed on
Commit
1965f5e
·
verified ·
1 Parent(s): 58f73dc

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -14
model.py CHANGED
@@ -82,29 +82,35 @@ class CausalLMForRegression(nn.Module):
82
  return tokenizer
83
 
84
  @classmethod
85
- def from_pretrained(cls, output_dir):
86
- from_local = os.path.exists(output_dir)
87
- loading_kwargs = {"use_safetensors": False} if from_local else {}
88
-
89
- model = AutoModelForCausalLM.from_pretrained(output_dir, **loading_kwargs)
90
 
91
- # Explicitly enable `output_hidden_states` after loading
92
- model.config.output_hidden_states = True
 
 
 
93
 
94
  # Create an uninitialized instance of CausalLMForRegression
95
  instance = cls.__new__(cls)
96
  nn.Module.__init__(instance)
 
 
 
 
 
97
  instance._keys_to_ignore_on_save = []
98
- instance.model = model
99
 
100
  # Load the regression head separately
101
- instance.regression_head = nn.Linear(model.config.hidden_size, 1)
102
- try:
103
- regression_head_path = os.path.join(output_dir, "regression_head.bin")
104
- state = torch.load(regression_head_path, map_location="cpu")
 
105
  instance.regression_head.load_state_dict(state)
106
- except FileNotFoundError:
107
- print(f"No regression head found. Initializing with random weights!")
108
  return instance
109
 
110
  @torch.no_grad()
 
82
  return tokenizer
83
 
84
  @classmethod
85
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
86
+
87
+ kwargs.setdefault("output_hidden_states", True)
 
 
88
 
89
+ base_model = AutoModelForCausalLM.from_pretrained(
90
+ pretrained_model_name_or_path,
91
+ *model_args,
92
+ **kwargs
93
+ )
94
 
95
  # Create an uninitialized instance of CausalLMForRegression
96
  instance = cls.__new__(cls)
97
  nn.Module.__init__(instance)
98
+
99
+ instance.model = base_model
100
+ instance.regression_head = nn.Linear(
101
+ base_model.config.hidden_size, 1
102
+ )
103
  instance._keys_to_ignore_on_save = []
 
104
 
105
  # Load the regression head separately
106
+ head_path = os.path.join(
107
+ pretrained_model_name_or_path, "regression_head.bin"
108
+ )
109
+ if os.path.exists(head_path):
110
+ state = torch.load(head_path, map_location="cpu")
111
  instance.regression_head.load_state_dict(state)
112
+ else:
113
+ print("No regression head found initialising randomly.")
114
  return instance
115
 
116
  @torch.no_grad()