lucweber committed on
Commit
22e477f
·
verified ·
1 Parent(s): f84cc1f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +13 -6
model.py CHANGED
@@ -88,26 +88,33 @@ class CausalLMForRegression(nn.Module):
88
 
89
  @classmethod
90
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
91
- # make sure hidden states are returned
92
- kwargs.setdefault("output_hidden_states", True)
93
 
94
  base_model = Qwen3ForCausalLM.from_pretrained(
95
- pretrained_model_name_or_path, *model_args, **kwargs
 
 
 
 
96
  )
97
 
98
  instance = cls.__new__(cls)
99
  nn.Module.__init__(instance)
 
100
  instance.model = base_model
101
- instance.regression_head = nn.Linear(base_model.config.hidden_size, 1)
 
102
 
103
- head_path = os.path.join(pretrained_model_name_or_path, "regression_head.bin")
 
104
  if os.path.exists(head_path):
105
  instance.regression_head.load_state_dict(
106
  torch.load(head_path, map_location="cpu")
107
  )
108
  else:
109
  print("No regression head found – initialising randomly.")
110
- instance._keys_to_ignore_on_save = []
111
  return instance
112
 
113
  @torch.no_grad()
 
88
 
89
  @classmethod
90
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
91
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
92
+ config.output_hidden_states = True
93
 
94
  base_model = Qwen3ForCausalLM.from_pretrained(
95
+ pretrained_model_name_or_path,
96
+ *model_args,
97
+ config=config,
98
+ **{k: v for k, v in kwargs.items()
99
+ if k not in config.to_dict()}
100
  )
101
 
102
  instance = cls.__new__(cls)
103
  nn.Module.__init__(instance)
104
+
105
  instance.model = base_model
106
+ instance.regression_head = nn.Linear(config.hidden_size, 1)
107
+ instance._keys_to_ignore_on_save = []
108
 
109
+ head_path = os.path.join(pretrained_model_name_or_path,
110
+ "regression_head.bin")
111
  if os.path.exists(head_path):
112
  instance.regression_head.load_state_dict(
113
  torch.load(head_path, map_location="cpu")
114
  )
115
  else:
116
  print("No regression head found – initialising randomly.")
117
+
118
  return instance
119
 
120
  @torch.no_grad()