syko818121 committed on
Commit
628d356
·
verified ·
1 Parent(s): 2dc5b66

Upgrade: Scaled up to ~15M parameters and 256 block size.

Browse files
Files changed (1) hide show
  1. modeling_syko.py +15 -8
modeling_syko.py CHANGED
@@ -11,10 +11,10 @@ class SykoConfig(PretrainedConfig):
11
  def __init__(
12
  self,
13
  vocab_size=4096,
14
- n_embd=256,
15
- n_layer=6,
16
- n_head=8,
17
- block_size=64,
18
  dropout=0.2,
19
  **kwargs
20
  ):
@@ -45,6 +45,7 @@ class Head(nn.Module):
45
  k = self.key(x)
46
  q = self.query(x)
47
  wei = q @ k.transpose(-2, -1) * (C ** -0.5)
 
48
  wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
49
  wei = F.softmax(wei, dim=-1)
50
  wei = self.dropout(wei)
@@ -103,7 +104,6 @@ class SykoForCausalLM(PreTrainedModel):
103
  self.n_layer = config.n_layer
104
  self.dropout = config.dropout
105
 
106
- # Embedding katmanının adı 'token_embedding_table'
107
  self.token_embedding_table = nn.Embedding(self.vocab_size, self.n_embd)
108
  self.position_embedding_table = nn.Embedding(self.block_size, self.n_embd)
109
  self.blocks = nn.Sequential(*[Block(self.n_embd, self.n_head, self.block_size, self.dropout) for _ in range(self.n_layer)])
@@ -112,13 +112,11 @@ class SykoForCausalLM(PreTrainedModel):
112
 
113
  self.apply(self._init_weights)
114
 
115
- # --- YENİ EKLENEN KISIM: HF BU FONKSİYONLARI ARIYOR ---
116
  def get_input_embeddings(self):
117
  return self.token_embedding_table
118
 
119
  def set_input_embeddings(self, new_embeddings):
120
  self.token_embedding_table = new_embeddings
121
- # -----------------------------------------------------
122
 
123
  def _init_weights(self, module):
124
  if isinstance(module, nn.Linear):
@@ -133,16 +131,25 @@ class SykoForCausalLM(PreTrainedModel):
133
  B, T = idx.shape
134
  device = idx.device
135
 
 
 
 
 
 
136
  pos_emb = self.position_embedding_table(torch.arange(T, device=device))
137
  tok_emb = self.token_embedding_table(idx)
138
  x = tok_emb + pos_emb
139
 
140
  x = self.blocks(x)
141
- x = self.ln_f(x)
142
  logits = self.lm_head(x)
143
 
144
  loss = None
145
  if labels is not None:
 
 
 
 
146
  B, T, C = logits.shape
147
  logits_reshaped = logits.view(B*T, C)
148
  labels_reshaped = labels.view(B*T)
 
11
  def __init__(
12
  self,
13
  vocab_size=4096,
14
+ n_embd=384, # ARTIRILDI (Eskisi 256)
15
+ n_layer=8, # ARTIRILDI (Eskisi 6)
16
+ n_head=6, # AYARLANDI (384 / 64 = 6)
17
+ block_size=256, # ARTIRILDI (Eskisi 64) -> Daha uzun hafıza
18
  dropout=0.2,
19
  **kwargs
20
  ):
 
45
  k = self.key(x)
46
  q = self.query(x)
47
  wei = q @ k.transpose(-2, -1) * (C ** -0.5)
48
+ # Maskeleme dinamik olmalı (gelen T kadar)
49
  wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
50
  wei = F.softmax(wei, dim=-1)
51
  wei = self.dropout(wei)
 
104
  self.n_layer = config.n_layer
105
  self.dropout = config.dropout
106
 
 
107
  self.token_embedding_table = nn.Embedding(self.vocab_size, self.n_embd)
108
  self.position_embedding_table = nn.Embedding(self.block_size, self.n_embd)
109
  self.blocks = nn.Sequential(*[Block(self.n_embd, self.n_head, self.block_size, self.dropout) for _ in range(self.n_layer)])
 
112
 
113
  self.apply(self._init_weights)
114
 
 
115
  def get_input_embeddings(self):
116
  return self.token_embedding_table
117
 
118
  def set_input_embeddings(self, new_embeddings):
119
  self.token_embedding_table = new_embeddings
 
120
 
121
  def _init_weights(self, module):
122
  if isinstance(module, nn.Linear):
 
131
  B, T = idx.shape
132
  device = idx.device
133
 
134
+ # Eğer context (T), block_size'dan büyükse kırp (Safety check)
135
+ if T > self.block_size:
136
+ idx = idx[:, -self.block_size:]
137
+ T = self.block_size
138
+
139
  pos_emb = self.position_embedding_table(torch.arange(T, device=device))
140
  tok_emb = self.token_embedding_table(idx)
141
  x = tok_emb + pos_emb
142
 
143
  x = self.blocks(x)
144
+ x = self.ln1_f(x) if hasattr(self, 'ln1_f') else self.ln_f(x)
145
  logits = self.lm_head(x)
146
 
147
  loss = None
148
  if labels is not None:
149
+ # Labels da kırpılmalı eğer idx kırpıldıysa
150
+ if labels.shape[1] > T:
151
+ labels = labels[:, -T:]
152
+
153
  B, T, C = logits.shape
154
  logits_reshaped = logits.view(B*T, C)
155
  labels_reshaped = labels.view(B*T)