FlameF0X commited on
Commit
72510fa
·
verified ·
1 Parent(s): 1b4b2d0

Create modeling_i3.py

Browse files
Files changed (1) hide show
  1. modeling_i3.py +34 -0
modeling_i3.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel
5
+ from .configuration_i3 import I3Config
6
+
7
+ # import or paste your existing i3Model and i3Block classes here
8
+ # (or import them if you split them out into another module)
9
+ from .i3_architecture import i3Model # optional if you modularize
10
+
11
+ class I3ForCausalLM(PreTrainedModel):
12
+ config_class = I3Config
13
+
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.model = i3Model(
17
+ vocab_size=config.vocab_size,
18
+ d_model=config.d_model,
19
+ n_layers=config.n_layers,
20
+ n_heads=config.n_heads,
21
+ max_seq_len=config.max_seq_len,
22
+ rank=config.rank,
23
+ d_state=config.d_state,
24
+ )
25
+
26
+ self.post_init()
27
+
28
+ def forward(self, input_ids, labels=None):
29
+ logits, loss = self.model(input_ids, labels)
30
+ return {"loss": loss, "logits": logits}
31
+
32
+ @torch.no_grad()
33
+ def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
34
+ return self.model.generate(input_ids, max_new_tokens, temperature, top_k)