yagizdevre committed on
Commit
a78f7b3
·
1 Parent(s): 1f55a56
__pycache__/attn.cpython-312.pyc CHANGED
Binary files a/__pycache__/attn.cpython-312.pyc and b/__pycache__/attn.cpython-312.pyc differ
 
__pycache__/mlp.cpython-312.pyc CHANGED
Binary files a/__pycache__/mlp.cpython-312.pyc and b/__pycache__/mlp.cpython-312.pyc differ
 
mlp.py CHANGED
@@ -1,11 +1,11 @@
1
  import torch.nn as nn
2
  from torch.nn import functional as F
3
-
4
  class MLP(nn.Module):
5
  def __init__(self, config, dtype=None):
6
  # https://arxiv.org/pdf/2002.05202
7
  super().__init__()
8
- torch_dtype = getattr(torch, config.torch_dtype, torch.float32) # Use config dtype
9
  dtype = dtype if dtype is not None else torch_dtype
10
  self.hidden_size = config.n_embd
11
  self.intermediate_size = config.n_embd * config.mlp_scale
 
1
  import torch.nn as nn
2
  from torch.nn import functional as F
3
+ import torch
4
  class MLP(nn.Module):
5
  def __init__(self, config, dtype=None):
6
  # https://arxiv.org/pdf/2002.05202
7
  super().__init__()
8
+ torch_dtype = config.torch_dtype
9
  dtype = dtype if dtype is not None else torch_dtype
10
  self.hidden_size = config.n_embd
11
  self.intermediate_size = config.n_embd * config.mlp_scale