Sh2425 committed
Commit e65adfe · verified · 1 Parent(s): 4e34a33

Update modeling_dolphy.py

Files changed (1)
  1. modeling_dolphy.py +8 -32
modeling_dolphy.py CHANGED
@@ -1,40 +1,16 @@
 from transformers import PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from torch import nn
-
-class DolphyBlock(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.attn = nn.Linear(config.hidden_size, config.hidden_size)  # placeholder
-        self.mlp = nn.Linear(config.hidden_size, config.hidden_size)  # placeholder
-
-    def forward(self, x):
-        x = self.attn(x)
-        x = self.mlp(x)
-        return x
-
-class DolphyModel(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = nn.ModuleList([DolphyBlock(config) for _ in range(config.num_hidden_layers)])
-        self.norm = nn.LayerNorm(config.hidden_size)
-
-    def forward(self, input_ids):
-        x = self.embed_tokens(input_ids)
-        for layer in self.layers:
-            x = layer(x)
-        return self.norm(x)
+from transformers.models.qwen3.modeling_qwen3 import Qwen3Model, Qwen3Config
+import torch.nn as nn
 
 class Dolphy1ForCausalLM(PreTrainedModel):
-    _auto_class = True
+    config_class = Qwen3Config
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = DolphyModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.model = Qwen3Model(config)
+
+        # If your router was saved as part of the model, this will load it automatically.
+        # No need to redefine or reattach anything here.
 
     def forward(self, input_ids, attention_mask=None, **kwargs):
-        hidden_states = self.model(input_ids)
-        logits = self.lm_head(hidden_states)
-        return CausalLMOutputWithPast(logits=logits)
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
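
For reference, a short save/load sketch (not part of the commit). The removed _auto_class = True line set a transformers-internal attribute; the documented way to expose a custom class on the Hub is register_for_auto_class, which makes save_pretrained() copy the defining module next to the weights and record an auto_map entry in config.json. The tiny config and the save path below are hypothetical.

from transformers import AutoModelForCausalLM, Qwen3Config
from modeling_dolphy import Dolphy1ForCausalLM  # the class from this commit

# Authoring side: register the class so save_pretrained() ships the custom code.
Dolphy1ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
config = Qwen3Config(  # deliberately tiny; a real checkpoint defines its own config.json
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2, vocab_size=1024,
)
Dolphy1ForCausalLM(config).save_pretrained("dolphy-checkpoint")  # hypothetical path

# Consuming side: trust_remote_code opts in to executing the checkpoint's code.
model = AutoModelForCausalLM.from_pretrained("dolphy-checkpoint", trust_remote_code=True)

One caveat worth noting: after this commit, forward() delegates straight to Qwen3Model, so it returns the backbone's hidden states (BaseModelOutputWithPast) rather than logits. Callers expecting CausalLMOutputWithPast from a *ForCausalLM class, generate() included, would still need an lm_head (or a tied-embedding projection) on top.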