idah4 commited on
Commit
00205bb
·
verified ·
1 Parent(s): 06e5f66

Upload ByteETM-Korean (HF inference compatible)

Browse files
Files changed (1) hide show
  1. modeling_byteetm.py +28 -0
modeling_byteetm.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import torch.nn as nn, torch.nn.functional as F, torch
3
+
4
+ class ByteETMConfig(PretrainedConfig):
5
+ model_type = "byteetm"
6
+ def __init__(self, vocab_size=258, n_embd=512, n_head=8, n_layer=6, block_size=256, **kwargs):
7
+ super().__init__(**kwargs)
8
+ self.vocab_size = vocab_size
9
+ self.n_embd = n_embd
10
+ self.n_head = n_head
11
+ self.n_layer = n_layer
12
+ self.block_size = block_size
13
+
14
+ class HFByteETM(PreTrainedModel):
15
+ config_class = ByteETMConfig
16
+ def __init__(self, config):
17
+ super().__init__(config)
18
+ from .model import ByteETM # 네가 정의한 실제 모델
19
+ self.model = ByteETM(
20
+ vocab_size=config.vocab_size,
21
+ n_embd=config.n_embd,
22
+ n_head=config.n_head,
23
+ n_layer=config.n_layer,
24
+ block_size=config.block_size
25
+ )
26
+ def forward(self, input_ids, **kwargs):
27
+ logits, _ = self.model(input_ids)
28
+ return {"logits": logits}