Commit ·
f4beaf9
1
Parent(s): b0d3660
add model
Browse files- config.json +1 -1
- rita_modeling.py +42 -9
config.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "
|
| 3 |
"architectures": [
|
| 4 |
"RITAModel"
|
| 5 |
],
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "Seledorn/RITA_l",
|
| 3 |
"architectures": [
|
| 4 |
"RITAModel"
|
| 5 |
],
|
rita_modeling.py
CHANGED
|
@@ -13,6 +13,7 @@ from transformers.modeling_outputs import (
|
|
| 13 |
BaseModelOutputWithPastAndCrossAttentions,
|
| 14 |
CausalLMOutputWithCrossAttentions,
|
| 15 |
CausalLMOutputWithPast,
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
from transformers.modeling_utils import PreTrainedModel
|
|
@@ -222,18 +223,50 @@ class RITAModel(PreTrainedModel):
|
|
| 222 |
self.final_norm = nn.LayerNorm(config.d_model)
|
| 223 |
self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
|
| 224 |
|
| 225 |
-
def forward(
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
for layer in self.layers:
|
| 230 |
-
x = layer(x, attn_mask=
|
| 231 |
x = self.final_norm(x) # N x L x D
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
#Some common HF functions.
|
| 239 |
def get_input_embeddings(self):
|
|
|
|
| 13 |
BaseModelOutputWithPastAndCrossAttentions,
|
| 14 |
CausalLMOutputWithCrossAttentions,
|
| 15 |
CausalLMOutputWithPast,
|
| 16 |
+
CausalLMOutput,
|
| 17 |
)
|
| 18 |
|
| 19 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
| 223 |
self.final_norm = nn.LayerNorm(config.d_model)
|
| 224 |
self.projector = nn.Linear(config.d_model, config.vocab_size, bias = False)
|
| 225 |
|
| 226 |
+
def forward(
|
| 227 |
+
self,
|
| 228 |
+
input_ids=None,
|
| 229 |
+
past_key_values=None, # NOT USED
|
| 230 |
+
attention_mask=None,
|
| 231 |
+
token_type_ids=None, # NOT USED
|
| 232 |
+
position_ids=None, # NOT USED
|
| 233 |
+
head_mask=None, # NOT USED
|
| 234 |
+
inputs_embeds=None,
|
| 235 |
+
encoder_hidden_states=None, # NOT USED
|
| 236 |
+
encoder_attention_mask=None, # NOT USED
|
| 237 |
+
labels=None,
|
| 238 |
+
use_cache=None, # NOT USED
|
| 239 |
+
output_attentions=None, # NOT USED
|
| 240 |
+
output_hidden_states=None, # NOT USED
|
| 241 |
+
return_dict=None # NOT USED
|
| 242 |
+
) -> torch.FloatTensor:
|
| 243 |
+
|
| 244 |
+
if inputs_embeds == None:
|
| 245 |
+
x = self.embedding(input_ids) # N x L x D
|
| 246 |
+
else:
|
| 247 |
+
x = inputs_embeds
|
| 248 |
+
if attention_mask == None:
|
| 249 |
+
attention_mask = (torch.triu(torch.ones(input_ids.size(1), input_ids.size(1))) == 0).transpose(0, 1).contiguous().to(input_ids.device)
|
| 250 |
for layer in self.layers:
|
| 251 |
+
x = layer(x, attn_mask=attention_mask)
|
| 252 |
x = self.final_norm(x) # N x L x D
|
| 253 |
|
| 254 |
+
logits = self.projector(x)
|
| 255 |
+
loss = None
|
| 256 |
+
if labels is not None:
|
| 257 |
+
# Shift so that tokens < n predict n
|
| 258 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 259 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 260 |
+
# Flatten the tokens
|
| 261 |
+
loss_fct = CrossEntropyLoss()
|
| 262 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 263 |
+
|
| 264 |
+
return CausalLMOutput(
|
| 265 |
+
loss=loss,
|
| 266 |
+
logits=logits,
|
| 267 |
+
hidden_states=x,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
|
| 271 |
#Some common HF functions.
|
| 272 |
def get_input_embeddings(self):
|