lucweber commited on
Commit
814735f
·
verified ·
1 Parent(s): 6166137

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +30 -23
model.py CHANGED
@@ -1,15 +1,14 @@
1
  import os
2
  from typing import Optional
3
  from transformers import Qwen3ForCausalLM, AutoTokenizer, AutoConfig
 
4
  import torch
5
  import torch.nn as nn
6
 
7
- from transformers.models.qwen3 import Qwen3Config
8
-
9
 
10
  # Define a custom model that wraps a causal LM and adds a regression head
11
  class CausalLMForRegression(nn.Module):
12
- config_class = Qwen3Config
13
  base_model_prefix = "model"
14
 
15
  def __init__(self, model_name):
@@ -88,35 +87,43 @@ class CausalLMForRegression(nn.Module):
88
 
89
  @classmethod
90
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
91
- config = kwargs.pop("config", None)
92
- if config is None:
93
- from transformers import AutoConfig
94
- config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
95
- config.output_hidden_states = True
96
-
97
- from transformers import Qwen3ForCausalLM
98
- base_model = Qwen3ForCausalLM.from_pretrained(
99
  pretrained_model_name_or_path,
100
  *model_args,
101
- config=config,
102
- **kwargs
 
103
  )
104
 
105
- instance = cls.__new__(cls)
106
- nn.Module.__init__(instance)
107
- instance.model = base_model
108
- instance.regression_head = nn.Linear(config.hidden_size, 1)
109
- instance._keys_to_ignore_on_save = []
 
 
 
 
 
 
 
 
 
 
110
 
111
- print(pretrained_model_name_or_path)
112
- head_path = os.path.join(pretrained_model_name_or_path, "regression_head.bin")
113
  if os.path.exists(head_path):
114
- instance.regression_head.load_state_dict(
115
  torch.load(head_path, map_location="cpu")
116
  )
117
  else:
118
- print("No regression head found – initialising randomly.")
119
- return instance
 
120
 
121
  @torch.no_grad()
122
  def generate(self, *args, **kwargs):
 
1
  import os
2
  from typing import Optional
3
  from transformers import Qwen3ForCausalLM, AutoTokenizer, AutoConfig
4
+ from huggingface_hub import hf_hub_download
5
  import torch
6
  import torch.nn as nn
7
 
 
 
8
 
9
  # Define a custom model that wraps a causal LM and adds a regression head
10
  class CausalLMForRegression(nn.Module):
11
+ config_class = Qwen3ForCausalLM.config_class
12
  base_model_prefix = "model"
13
 
14
  def __init__(self, model_name):
 
87
 
88
  @classmethod
89
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
90
+ cfg = kwargs.pop("config", None)
91
+ if cfg is None:
92
+ cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
93
+ cfg.output_hidden_states = True
94
+
95
+ backbone = Qwen3ForCausalLM.from_pretrained(
 
 
96
  pretrained_model_name_or_path,
97
  *model_args,
98
+ config=cfg,
99
+ trust_remote_code=False,
100
+ **kwargs
101
  )
102
 
103
+ if os.path.isdir(pretrained_model_name_or_path):
104
+ head_path = os.path.join(pretrained_model_name_or_path,
105
+ "regression_head.bin")
106
+ else:
107
+ head_path = hf_hub_download(
108
+ repo_id=pretrained_model_name_or_path,
109
+ filename="regression_head.bin",
110
+ repo_type="model"
111
+ )
112
+
113
+ inst = cls.__new__(cls)
114
+ nn.Module.__init__(inst)
115
+ inst.model = backbone
116
+ inst.regression_head = nn.Linear(cfg.hidden_size, 1)
117
+ inst._keys_to_ignore_on_save = []
118
 
 
 
119
  if os.path.exists(head_path):
120
+ inst.regression_head.load_state_dict(
121
  torch.load(head_path, map_location="cpu")
122
  )
123
  else:
124
+ print("'regression_head.bin' not found – initialising randomly.")
125
+
126
+ return inst
127
 
128
  @torch.no_grad()
129
  def generate(self, *args, **kwargs):