Dolphy-1.2-Base / modeling_dolphy.py
Sh2425's picture
Update modeling_dolphy.py
6ccae00 verified
raw
history blame
1.43 kB
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn
class DolphyBlock(nn.Module):
    """One transformer-style block: an attention stage followed by an MLP stage.

    Both stages are stand-in ``nn.Linear`` projections (hidden_size -> hidden_size);
    no real attention is computed yet.
    """

    def __init__(self, config):
        super().__init__()
        # Placeholder sub-layers until real attention / feed-forward land.
        self.attn = nn.Linear(config.hidden_size, config.hidden_size)
        self.mlp = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        """Apply the attention placeholder, then the MLP placeholder, to ``x``."""
        return self.mlp(self.attn(x))
class DolphyModel(nn.Module):
    """Backbone: token embedding -> stack of DolphyBlocks -> final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            DolphyBlock(config) for _ in range(config.num_hidden_layers)
        )
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, input_ids):
        """Embed ``input_ids`` (batch, seq), pass through every block, normalize."""
        hidden = self.embed_tokens(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)
class Dolphy1ForCausalLM(PreTrainedModel):
    """Causal-LM head on top of the DolphyModel backbone.

    Adds an untied linear LM head over the backbone's hidden states and,
    when ``labels`` are supplied, computes the standard next-token
    cross-entropy loss so the model trains under ``transformers.Trainer``
    (previously ``labels`` was silently swallowed by ``**kwargs`` and no
    loss was ever returned).
    """

    # NOTE(review): transformers expects ``_auto_class`` to be the *name* of
    # an auto class (e.g. "AutoModelForCausalLM") or None; ``True`` looks
    # wrong — confirm against the intended auto-class registration.
    _auto_class = True

    def __init__(self, config):
        super().__init__(config)
        self.model = DolphyModel(config)
        # Untied output projection to vocabulary logits; bias-free as is
        # conventional for LM heads.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the backbone and LM head.

        Args:
            input_ids: Long tensor of token ids, shape (batch, seq).
            attention_mask: Accepted for API compatibility; currently unused
                by the placeholder backbone.
            labels: Optional long tensor, same shape as ``input_ids``. When
                given, the shifted next-token cross-entropy loss is computed
                (positions with label -100 are ignored, per the transformers
                convention).

        Returns:
            CausalLMOutputWithPast with ``logits`` of shape
            (batch, seq, vocab_size), and ``loss`` when labels were given.
        """
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < t predict token t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)