Clem27-Assistants committed
Commit 567c20a · verified · Parent: e3f390d

Upload 2 files

Files changed (2)
  1. configuration_sora.py +21 -0
  2. modeling_sora.py +56 -0
configuration_sora.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import PretrainedConfig
+
+ class SoraConfig(PretrainedConfig):
+     model_type = "sora_slm"
+
+     def __init__(
+         self,
+         vocab_size=2628,
+         hidden_size=512,
+         num_layers=8,
+         num_heads=8,
+         max_position_embeddings=512,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.max_position_embeddings = max_position_embeddings
+
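Not part of the commit itself, but a minimal sketch of how the configuration above could be exercised locally; the ./sora-slm output directory is an assumption, not something defined in this repo:

from configuration_sora import SoraConfig

# Build a config with the defaults declared above and round-trip it through disk:
# save_pretrained() writes config.json, from_pretrained() reads it back.
config = SoraConfig()
config.save_pretrained("./sora-slm")              # hypothetical local directory

reloaded = SoraConfig.from_pretrained("./sora-slm")
print(reloaded.model_type, reloaded.vocab_size)   # sora_slm 2628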
modeling_sora.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutput
+ from .configuration_sora import SoraConfig
+
+ class SoraForSLM(PreTrainedModel, GenerationMixin):
+     config_class = SoraConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+         self.layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(
+                 d_model=config.hidden_size,
+                 nhead=config.num_heads,
+                 dim_feedforward=config.hidden_size * 4,
+                 batch_first=True,
+                 activation="gelu"
+             ) for _ in range(config.num_layers)
+         ])
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     def prepare_inputs_for_generation(self, input_ids, **kwargs):
+         return {"input_ids": input_ids}
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+         # Compute position indices
+         seq_length = input_ids.size(1)
+         positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
+
+         # Token + position embeddings
+         x = self.embeddings(input_ids) + self.position_embeddings(positions)
+
+         # Pass through the layers (no attention mask is applied, so attention is bidirectional)
+         for layer in self.layers:
+             x = layer(x)
+
+         logits = self.lm_head(x)
+
+         loss = None
+         if labels is not None:
+             # Shift for causal-LM training: predict token t+1 from positions up to t
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+         return CausalLMOutput(loss=loss, logits=logits)
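Again outside the commit itself, a minimal usage sketch under two assumptions: the two modules are importable as a package named sora (the relative import in modeling_sora.py needs a package context, e.g. via trust_remote_code once uploaded), and the dummy batch below is purely illustrative. It also notes how a causal mask could be added, since the forward pass above runs the encoder layers unmasked and therefore lets each position attend to future tokens:

import torch
from sora.configuration_sora import SoraConfig   # "sora" package name is an assumption
from sora.modeling_sora import SoraForSLM

config = SoraConfig()
model = SoraForSLM(config).eval()

# Dummy batch: 1 sequence of 8 token ids drawn from the 2628-token vocabulary.
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids)
print(out.logits.shape)   # torch.Size([1, 8, 2628])
print(out.loss)           # scalar cross-entropy over the shifted targets

# Greedy decoding via GenerationMixin; with a recent transformers release this should
# run because prepare_inputs_for_generation is defined (no KV cache, full recompute).
generated = model.generate(input_ids, max_new_tokens=5, do_sample=False)

# Possible causal variant: build an upper-triangular -inf mask and pass it to each
# layer inside forward(), i.e. x = layer(x, src_mask=causal_mask)
causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(input_ids.size(1))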