pranavupadhyaya52 committed on
Commit
a7dfca3
·
verified ·
1 Parent(s): b749cfb

Upload RockyForEmbeddings

Browse files
Files changed (3) hide show
  1. config.json +4 -0
  2. model.safetensors +1 -1
  3. modeling_rocky.py +158 -0
config.json CHANGED
@@ -1,10 +1,14 @@
1
  {
 
 
 
2
  "auto_map": {
3
  "AutoConfig": "configuration_rocky.RockyConfig",
4
  "AutoModel": "modeling_rocky.RockyForEmbeddings"
5
  },
6
  "depth": 12,
7
  "dim": 768,
 
8
  "ffn_dim": 2048,
9
  "heads": 12,
10
  "max_seq_len": 1024,
 
1
  {
2
+ "architectures": [
3
+ "RockyForEmbeddings"
4
+ ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_rocky.RockyConfig",
7
  "AutoModel": "modeling_rocky.RockyForEmbeddings"
8
  },
9
  "depth": 12,
10
  "dim": 768,
11
+ "dtype": "float32",
12
  "ffn_dim": 2048,
13
  "heads": 12,
14
  "max_seq_len": 1024,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3dbf7eb6ec4738cbb8fc6505111067e1f663b42dac277857064b84899b51bd8b
3
  size 363597664
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aad1693ebd30a454f69bd9f9b5406516afd3a9493fc8695d04d9483422b24dda
3
  size 363597664
modeling_rocky.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel
5
+ from configuration_rocky import RockyConfig
6
+
7
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-feature gain.

    Args:
        dim: Length of the gain vector ``scale`` (initialised to ones).
        eps: Constant added to the mean square before the reciprocal
            square root is taken.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``scale * x / sqrt(mean(x**2, last axis) + eps)``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.scale * x * inv_rms
16
+
17
class GELU(nn.Module):
    """Bias-free two-layer feed-forward network with a GELU activation.

    Args:
        dim: Input and output feature size.
        hidden_dim: Feature size of the intermediate layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Built in the same order as they sit in the Sequential so parameter
        # names stay ``net.0`` / ``net.2``.
        expand = nn.Linear(dim, hidden_dim, bias=False)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, dim, bias=False)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply linear -> GELU -> linear to ``x``."""
        return self.net(x)
28
+
29
class RotaryEmbedding(nn.Module):
    """Holds inverse frequencies and builds rotary position angle tables.

    Args:
        dim: Width of the angle table; frequencies are generated for every
            second index in ``range(0, dim)`` with base 10000.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def get_embed(self, seq_len, device):
        """Return the angle table for positions ``0 .. seq_len - 1``.

        Each row is ``position * inv_freq`` repeated twice along the last
        axis, so the result has ``seq_len`` rows and ``2 * len(inv_freq)``
        columns.
        """
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Outer product of positions and frequencies via broadcasting.
        angles = positions[:, None] * self.inv_freq[None, :]
        return torch.cat((angles, angles), dim=-1)
39
+
40
def rotate_half(x):
    """Swap the two halves of the last axis, negating the second half.

    With the last axis split at ``size // 2`` into ``(a, b)``, returns
    ``concat(-b, a)`` along that axis.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
44
+
45
def apply_rope(q, k, freqs_tensor):
    """Rotate ``q`` and ``k`` by the angles in ``freqs_tensor``.

    The cosine and sine tables get two leading singleton axes so they
    broadcast against the leading axes of ``q`` and ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos = freqs_tensor.cos()[None, None, :, :]
    sin = freqs_tensor.sin()[None, None, :, :]

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
51
+
52
class Attention(nn.Module):
    """Multi-head self-attention with rotary positions and cosine-similarity scores.

    Queries and keys are rotated with rotary embeddings, L2-normalised along
    the head dimension, and their dot products are multiplied by a learnable
    scalar ``temperature`` (initialised to 15.0) before the softmax.

    Args:
        dim: Model width; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        # Fail early: a non-divisible width would otherwise only surface as
        # an opaque shape error from ``.view`` inside ``forward``.
        if dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.head_dim = dim // heads

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        self.rope = RotaryEmbedding(self.head_dim)
        self.temperature = nn.Parameter(torch.tensor(15.0))

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: Input with three axes, unpacked as ``(B, T, C)``.
            mask: Optional tensor indexed as ``mask[:, None, None, :]``;
                key positions where it equals 0 are excluded from attention.
                (Presumably a ``(B, T)`` padding mask; confirm with callers.)

        Returns:
            Tensor with the same ``(B, T, C)`` shape as ``x``.
        """
        B, T, C = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.heads, self.head_dim)
        # Split into q/k/v and move the head axis ahead of the sequence axis.
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))

        rope_emb = self.rope.get_embed(T, x.device)
        q, k = apply_rope(q, k, rope_emb)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature

        if mask is not None:
            masked_keys = mask[:, None, None, :] == 0
            # Use the most negative finite value of the score dtype: a
            # hard-coded -1e9 lies outside the float16 range and makes
            # masked_fill fail when the model runs in half precision.
            attn = attn.masked_fill(masked_keys, torch.finfo(attn.dtype).min)

        attn = attn - attn.max(dim=-1, keepdim=True).values
        attn = torch.softmax(attn, dim=-1)

        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)
91
+
92
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer, then feed-forward sub-layer.

    Args:
        dim: Model width passed to the norms, attention and feed-forward.
        heads: Number of attention heads.
        ffn_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, dim, heads, ffn_dim, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = Attention(dim, heads)
        self.norm2 = RMSNorm(dim)
        self.ffn = GELU(dim, ffn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both residual sub-layers; ``mask`` is forwarded to attention."""
        attn_branch = self.attn(self.norm1(x), mask)
        x = x + self.dropout(attn_branch)

        ffn_branch = self.ffn(self.norm2(x))
        return x + self.dropout(ffn_branch)
105
+
106
class ProjectionHead(nn.Module):
    """Bias-free MLP that maps features to an L2-normalised embedding.

    Args:
        dim: Input feature size (also the hidden size).
        proj_dim: Output embedding size.
    """

    def __init__(self, dim, proj_dim=512):
        super().__init__()
        # Built in Sequential order so parameter names stay ``net.0`` / ``net.2``.
        hidden = nn.Linear(dim, dim, bias=False)
        activation = nn.GELU()
        output = nn.Linear(dim, proj_dim, bias=False)
        self.net = nn.Sequential(hidden, activation, output)

    def forward(self, x):
        """Project ``x`` and normalise the result along the last axis."""
        projected = self.net(x)
        return F.normalize(projected, dim=-1)
117
+
118
class RockyForEmbeddings(PreTrainedModel):
    """Transformer encoder that mean-pools token states into one embedding.

    Reads ``vocab_size``, ``dim``, ``heads``, ``ffn_dim``, ``depth`` and
    ``proj_dim`` from the supplied config.
    """

    config_class = RockyConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList([
            TransformerBlock(config.dim, config.heads, config.ffn_dim)
            for _ in range(config.depth)
        ])

        self.norm = RMSNorm(config.dim)
        self.projection = ProjectionHead(config.dim, config.proj_dim)

        self.post_init()

    def forward(self, input_ids, attention_mask=None, return_raw=False):
        """Embed a batch of token id sequences.

        Args:
            input_ids: Token ids fed to the embedding table (presumably
                ``(batch, seq)``; confirm with callers).
            attention_mask: Optional mask aligned with ``input_ids``;
                positions equal to 0 are ignored by attention and excluded
                from the pooled mean.
            return_raw: If true, return the pooled hidden state without
                passing it through the projection head.

        Returns:
            The pooled hidden state when ``return_raw`` is true, otherwise
            the output of the projection head.
        """
        if attention_mask is not None:
            attention_mask = attention_mask.long()

        x = self.token_emb(input_ids)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)

        if attention_mask is not None:
            # Cast the mask to the activation dtype before pooling. With an
            # integer mask, the float clamp below and the division rely on
            # type promotion, which can yield a pooled tensor whose dtype
            # differs from the projection weights in reduced precision.
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            x = x * mask
            # Clamp guards against division by zero for fully masked rows.
            pooled = x.sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
        else:
            pooled = x.mean(dim=1)

        if return_raw:
            return pooled

        return self.projection(pooled)