Saachi-S123 committed on
Commit
5e28d36
·
verified ·
1 Parent(s): 7ecd2ad

created code

Browse files
Files changed (1) hide show
  1. app.py +16 -0
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class MiniTransformer(nn.Module):
    """Minimal Transformer-style block: embedding -> self-attention -> 2-layer MLP.

    Pipeline: token ids -> ``nn.Embedding`` -> single ``nn.MultiheadAttention``
    (self-attention: query = key = value) -> Linear -> ReLU -> Linear projecting
    back to vocabulary logits.

    NOTE(review): there is no positional encoding here, so the attention is
    permutation-insensitive over the sequence axis — confirm that is intended.

    Args:
        vocab_size: Size of the token vocabulary (rows of the embedding table
            and width of the output logits).
        embed_dim: Embedding / attention model dimension.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        hidden_dim: Width of the feed-forward hidden layer.
        batch_first: If True, inputs/outputs are (batch, seq, embed); if False
            (the default, preserving the original behavior), the attention
            layer interprets tensors as (seq, batch, embed). The original code
            fed a (batch, seq) embedding into a batch_first=False attention,
            which transposes the meaning of the first two axes — pass
            ``batch_first=True`` to get correct (batch, seq) semantics.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, batch_first=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first is forwarded so callers can choose the layout convention;
        # the default (False) matches nn.MultiheadAttention's and the original
        # block's behavior exactly.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)
        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map integer token ids to vocabulary logits.

        Args:
            x: LongTensor of token ids; shape (batch, seq) with
                ``batch_first=True``, otherwise interpreted as (seq, batch).

        Returns:
            FloatTensor of logits with shape ``x.shape + (vocab_size,)``.
        """
        x = self.embedding(x)
        # Self-attention: the same tensor serves as query, key, and value.
        attn_output, _ = self.attention(x, x, x)
        x = self.linear1(attn_output).relu()
        return self.linear2(x)