Clemylia committed · verified
Commit c1e244c · 1 Parent(s): 7008b84

Update modeling_sora.py

Files changed (1):
  1. modeling_sora.py +10 -15
modeling_sora.py CHANGED
@@ -11,20 +11,17 @@ class SoraForSLM(PreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
-
-        # Using TransformerEncoderLayer, but we apply a causal mask manually
+
         self.layers = nn.ModuleList([
             nn.TransformerEncoderLayer(
                 d_model=config.hidden_size,
                 nhead=config.num_heads,
                 dim_feedforward=config.hidden_size * 4,
                 batch_first=True,
-                activation="gelu",
-                norm_first=True
+                activation="gelu"
             ) for _ in range(config.num_layers)
         ])
-
-        self.ln_f = nn.LayerNorm(config.hidden_size)
+
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
 
@@ -35,18 +32,17 @@ class SoraForSLM(PreTrainedModel, GenerationMixin):
         return {"input_ids": input_ids}
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        # Compute the positions
         seq_length = input_ids.size(1)
-        # Build the causal (triangular) mask
-        causal_mask = torch.triu(torch.ones(seq_length, seq_length, device=input_ids.device), diagonal=1).bool()
-
         positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
+
+        # Embeddings
         x = self.embeddings(input_ids) + self.position_embeddings(positions)
 
+        # Pass through the layers (no mask, to avoid any conflict)
         for layer in self.layers:
-            # Apply the mask to keep the model from seeing the future
-            x = layer(x, src_mask=causal_mask)
-
-        x = self.ln_f(x)
+            x = layer(x)
+
         logits = self.lm_head(x)
 
         loss = None
@@ -57,5 +53,4 @@ class SoraForSLM(PreTrainedModel, GenerationMixin):
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
-        return CausalLMOutput(loss=loss, logits=logits)
-
+        return CausalLMOutput(loss=loss, logits=logits)
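For reference, here is a minimal sketch of how the post-commit model might be exercised end to end. The SoraConfig stand-in, its default values, and the import path are assumptions for illustration; only the attribute names the module reads from `config` (vocab_size, hidden_size, max_position_embeddings, num_heads, num_layers) come from the diff itself. Note that after this commit the forward pass runs the nn.TransformerEncoderLayer stack without a src_mask, so self-attention is bidirectional.

# Minimal usage sketch (hypothetical): SoraConfig and its defaults are
# illustrative stand-ins; only the attribute names match modeling_sora.py.
import torch
from transformers import PretrainedConfig
from modeling_sora import SoraForSLM  # import path assumed for this sketch

class SoraConfig(PretrainedConfig):
    model_type = "sora"  # hypothetical model_type

    def __init__(self, vocab_size=1000, hidden_size=64,
                 max_position_embeddings=128, num_heads=4,
                 num_layers=2, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size  # must be divisible by num_heads
        self.max_position_embeddings = max_position_embeddings
        self.num_heads = num_heads
        self.num_layers = num_layers
        super().__init__(**kwargs)

config = SoraConfig()
model = SoraForSLM(config)

# Batch of 2 sequences, 16 tokens each; passing labels=input_ids triggers
# the shifted cross-entropy computed inside forward().
input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.logits.shape)  # torch.Size([2, 16, 1000])
print(out.loss)          # scalar next-token loss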
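A note on the sketch: labels=input_ids works because forward() shifts the logits and labels internally (the shift_logits / shift_labels visible in the last hunk) before applying the cross-entropy, which is the usual Hugging Face causal-LM convention, and the result comes back as a CausalLMOutput.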