|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from torch import nn |
|
|
|
|
|
class DolphyBlock(nn.Module):
    """One Dolphy layer: a linear "attention" stage followed by a linear
    MLP stage.

    NOTE(review): both stages are plain ``nn.Linear`` with no nonlinearity
    or residual in between, so the composition is itself affine — presumably
    intentional for this toy model.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        # Attribute names (`attn`, `mlp`) are part of the state-dict key
        # layout; keep them stable for checkpoint compatibility.
        self.attn = nn.Linear(hidden, hidden)
        self.mlp = nn.Linear(hidden, hidden)

    def forward(self, x):
        """Apply the attention projection, then the MLP projection."""
        return self.mlp(self.attn(x))
|
|
|
|
|
class DolphyModel(nn.Module):
    """Dolphy backbone: token embedding -> stack of DolphyBlocks -> LayerNorm.

    Reads ``vocab_size``, ``hidden_size`` and ``num_hidden_layers`` from the
    config object passed to the constructor.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names (`embed_tokens`, `layers`, `norm`) are state-dict
        # keys; keep them stable for checkpoint compatibility.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        blocks = (DolphyBlock(config) for _ in range(config.num_hidden_layers))
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, input_ids):
        """Embed ``input_ids``, run every block in order, then normalize."""
        hidden = self.embed_tokens(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)
|
|
|
|
|
class Dolphy1ForCausalLM(PreTrainedModel):
    """Causal language model: DolphyModel backbone plus an untied LM head.

    Returns a ``CausalLMOutputWithPast`` so it plugs into the standard
    transformers generation/training interfaces.
    """

    # transformers expects `_auto_class` to be the *name* of the Auto class
    # this model registers under (used as a string by register_for_auto_class
    # and push_to_hub custom-code handling). A bare `True` breaks that logic.
    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config):
        super().__init__(config)
        self.model = DolphyModel(config)
        # No bias on the LM head, per common decoder-only convention.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the backbone and LM head.

        Args:
            input_ids: integer token ids, shape (batch, seq).
            attention_mask: accepted for interface compatibility but unused —
                the backbone has no attention masking. NOTE(review): confirm
                this is intentional.
            labels: optional target ids, shape (batch, seq). When given, the
                standard causal-LM shifted cross-entropy loss is computed
                (token t predicts token t+1), so the model works with
                transformers Trainer. When omitted, ``loss`` is None, matching
                the previous behavior.

        Returns:
            CausalLMOutputWithPast with ``logits`` (batch, seq, vocab) and
            optionally ``loss``.
        """
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that position t is scored against token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)