lucweber commited on
Commit
e346e0a
·
verified ·
1 Parent(s): 1965f5e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +5 -0
model.py CHANGED
@@ -4,9 +4,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  import torch.nn as nn
6
 
 
 
7
 
8
  # Define a custom model that wraps a causal LM and adds a regression head
9
  class CausalLMForRegression(nn.Module):
 
 
 
10
  def __init__(self, model_name):
11
  super().__init__()
12
  # Load the causal LM with hidden states enabled
 
4
  import torch
5
  import torch.nn as nn
6
 
7
+ from transformers.models.qwen3 import Qwen3Config
8
+
9
 
10
  # Define a custom model that wraps a causal LM and adds a regression head
11
  class CausalLMForRegression(nn.Module):
12
+ config_class = Qwen3Config
13
+ base_model_prefix = "model"
14
+
15
  def __init__(self, model_name):
16
  super().__init__()
17
  # Load the causal LM with hidden states enabled