MaxiiMin committed
Commit 2b372b6 · verified · 1 Parent(s): c0cb7df

Update modeling_challenger.py
Files changed (1):
  1. modeling_challenger.py +13 -107
modeling_challenger.py CHANGED
@@ -2,10 +2,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from .configuration_challenger import ChallengerConfig

-
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-8):
         super().__init__()
@@ -19,95 +19,15 @@ class RMSNorm(nn.Module):
         output = self._norm(x.float())
         return (output * self.weight.float()).type_as(x)

-def _to_fp8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    scale = finfo.max / x.abs().max().clamp(min=1e-12)
-    x_f8 = (x * scale).clamp(finfo.min, finfo.max).to(dtype)
-    return x_f8, scale.reciprocal().float()  # inverse for _scaled_mm
-
-class _FP8Matmul(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, w, out_dtype=torch.bfloat16):
-        x_f8, x_inv = _to_fp8(x)
-        w_f8, w_inv = _to_fp8(w)
-
-        y = torch._scaled_mm(  # row-major A × col-major B
-            x_f8, w_f8.t(),
-            out_dtype=out_dtype,
-            scale_a=x_inv, scale_b=w_inv,
-            use_fast_accum=True,
-        )
-        ctx.save_for_backward(x_f8, w_f8, x_inv, w_inv)
-        ctx.out_dtype = out_dtype
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        x_f8, w_f8, x_inv, w_inv = ctx.saved_tensors
-        g_f8, g_inv = _to_fp8(grad_out, dtype=torch.float8_e5m2)
-
-        # ---- dx = grad_out @ w ------------------------------------------
-        # A = g_f8                   (row-major, (N, out))
-        # B = w_f8.T.contiguous().T  (col-major, (out, in))
-        dx = torch._scaled_mm(
-            g_f8,
-            w_f8.t().contiguous().t(),
-            out_dtype=ctx.out_dtype,
-            scale_a=g_inv, scale_b=w_inv,
-            use_fast_accum=False,
-        )
-
-        # ---- dw = x.T @ grad_out ----------------------------------------
-        # A = x_f8.T.contiguous()    (row-major, (in, N))
-        # B = g_f8.T.contiguous().T  (col-major, (N, out))
-        dw = torch._scaled_mm(
-            x_f8.t().contiguous(),
-            g_f8.t().contiguous().t(),
-            out_dtype=torch.float32,
-            scale_a=x_inv, scale_b=g_inv,
-            use_fast_accum=False,
-        ).t()  # bring back to (out, in)
-
-        return dx, dw, None  # no grad for out_dtype
-
-# Convenience alias, identical signature to torch.mm
-fp8_mm = _FP8Matmul.apply
-
-# ---- drop-in Linear ----------------------------------------------------------
-class FP8Linear(torch.nn.Module):
-    """Same signature as nn.Linear but weight-stationary FP8 matmul."""
-    def __init__(self, in_features, out_features, bias=False):
-        super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
-        torch.nn.init.trunc_normal_(self.weight, std=0.02)
-        if bias:
-            self.bias = torch.nn.Parameter(torch.zeros(out_features))
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        """
-        Accepts x of shape (..., in_features) - any leading dims.
-        Flattens to 2-D, does the FP8 matmul, then restores the shape.
-        """
-        orig_shape = x.shape[:-1]       # e.g. (B, T)
-        x2d = x.view(-1, x.shape[-1])   # (N, in_features)
-        y2d = fp8_mm(x2d, self.weight)  # (N, out_features)
-        if self.bias is not None:
-            y2d = y2d + self.bias
-        y = y2d.view(*orig_shape, self.weight.size(0))
-        return y
-
 class CausalSelfAttention(nn.Module):

     def __init__(self, config):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
-        self.c_attn = FP8Linear(config.n_embd, 3 * config.n_embd)
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
         # output projection
-        self.c_proj = FP8Linear(config.n_embd, config.n_embd)
-        self.c_proj.NANOGPT_SCALE_INIT = 1
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
         # regularization
         self.n_head = config.n_head
         self.n_embd = config.n_embd
@@ -132,10 +52,9 @@ class MLP(nn.Module):

     def __init__(self, config):
         super().__init__()
-        self.c_fc = FP8Linear(config.n_embd, 8 * config.n_embd)
+        self.c_fc = nn.Linear(config.n_embd, 8 * config.n_embd, bias=False)
         self.gelu = nn.SiLU()
-        self.c_proj = FP8Linear(4 * config.n_embd, config.n_embd)
-        self.c_proj.NANOGPT_SCALE_INIT = 1
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

     def forward(self, x):
         x, y = self.c_fc(x).split(x.size(-1) * 4, dim=2)
@@ -166,22 +85,7 @@ class GPT(nn.Module):
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
             ln_f = RMSNorm(config.n_embd),
         ))
-        self.lm_head = FP8Linear(config.n_embd, config.vocab_size, bias=False)
-        self.lm_head.NANOGPT_SCALE_INIT = 1
-
-        # init params
-        self.apply(self._init_weights)
-
-    def _init_weights(self, module):
-        if isinstance(module, FP8Linear):
-            std = 0.02
-            if hasattr(module, 'NANOGPT_SCALE_INIT'):
-                std *= (2 * self.config.n_layer) ** -0.5
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

     def forward(self, idx, targets=None):
         x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
@@ -196,8 +100,9 @@ class GPT(nn.Module):
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
         return logits, loss

-class ChallengerModel(PreTrainedModel):
+class ChallengerForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = ChallengerConfig
+    _keys_to_ignore_on_load_unexpected = [r"past_key_values"]

     def __init__(self, config):
         super().__init__(config)
@@ -205,6 +110,7 @@ class ChallengerModel(PreTrainedModel):

     def forward(self, input_ids, labels=None):
         logits, loss = self.model(input_ids, labels)
-        if labels is not None:
-            return {"loss": loss, "logits": logits}
-        return {"logits": logits}
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits
+        )
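
The snippet below is not part of the commit; it is a minimal caller-side sketch of what the new base classes and return type mean for users. The repository id, tokenizer, and prompt are placeholders, and it assumes the Hub config maps ChallengerForCausalLM via auto_map so that trust_remote_code=True can resolve the custom class.

# Hypothetical usage sketch; repo id and prompt are placeholders, not taken from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "MaxiiMin/Challenger"  # placeholder repository id
tok = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

ids = tok("hello world", return_tensors="pt").input_ids
with torch.no_grad():
    # ChallengerForCausalLM.forward(input_ids, labels) now returns a
    # CausalLMOutputWithCrossAttentions, so callers read attributes
    # instead of indexing the old dict return value.
    out = model(input_ids=ids, labels=ids)
print(out.loss, out.logits.shape)

Because GenerationMixin is now mixed in, generate() is also exposed on the class, though the forward signature in this commit only accepts input_ids and labels.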