Monimoy commited on
Commit
4cebbcb
·
verified ·
1 Parent(s): 8dc6d91

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +123 -0
  2. requirements.txt +5 -0
  3. smollm_checkpoint.pth +3 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import tiktoken
5
+ from transformer import GPT, GPTConfig # Import your model class
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from transformers import AutoTokenizer
10
+ from tqdm import tqdm
11
+ import os
12
+
13
+ # Load the model from Hugging Face Hub
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+
16
+ # Define the SmolLM2-135M model (a simplified version of a Transformer)
17
+ class SmolLM(nn.Module):
18
+ def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len):
19
+ super(SmolLM, self).__init__()
20
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
21
+ self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
22
+ self.layers = nn.ModuleList([
23
+ nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
24
+ for _ in range(num_layers)
25
+ ])
26
+ self.fc_out = nn.Linear(embed_dim, vocab_size)
27
+
28
+ def forward(self, x):
29
+ seq_len = x.size(1)
30
+ x = self.embedding(x) + self.pos_embedding[:, :seq_len, :]
31
+ for layer in self.layers:
32
+ x = layer(x)
33
+ return self.fc_out(x)
34
+
35
+
36
+ def load_model():
37
+ checkpoint_path = 'smollm_checkpoint.pth'
38
+ embed_dim = 512
39
+ num_heads = 8
40
+ num_layers = 4
41
+ max_seq_len = 128
42
+ vocab_size = 50257
43
+
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ model = SmolLM(vocab_size, embed_dim, num_heads, num_layers, max_seq_len).to(device)
46
+
47
+ model.load_state_dict(torch.load(checkpoint_path))
48
+ #checkpoint = torch.load(checkpoint_path, map_location=device)
49
+ #config = checkpoint['config']
50
+ #model = GPT(config)
51
+ #model.load_state_dict(checkpoint['model_state_dict'])
52
+ model.to(device)
53
+ model.eval() # Set to evaluation mode
54
+
55
+ # Disable gradient computation
56
+ for param in model.parameters():
57
+ param.requires_grad = False
58
+
59
+ return model
60
+
61
+ model = load_model()
62
+
63
+
64
+ # Force model to stay in eval mode
65
+ model.train(False)
66
+
67
+ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
68
+ enc = tiktoken.get_encoding('gpt2')
69
+ tokens = enc.encode(prompt)
70
+ tokens = torch.tensor(tokens, dtype=torch.long)
71
+ tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
72
+ tokens = tokens.to(device)
73
+
74
+ with torch.no_grad():
75
+ for _ in range(max_length):
76
+ if tokens.size(1) >= 1024: # GPT context length
77
+ break
78
+
79
+ logits = model(tokens)[0]
80
+ logits = logits[:, -1, :]
81
+ #logits = logits[:, -1, :] / temperature
82
+ probs = F.softmax(logits, dim=-1)
83
+
84
+ # Top-k sampling
85
+ topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
86
+ ix = torch.multinomial(topk_probs, 1)
87
+ next_token = torch.gather(topk_indices, -1, ix)
88
+
89
+ tokens = torch.cat((tokens, next_token), dim=1)
90
+
91
+ # Remove special token check entirely
92
+ # Just generate for the specified length or until context limit
93
+
94
+ generated_texts = []
95
+ for i in range(num_samples):
96
+ text = enc.decode(tokens[i].tolist())
97
+ generated_texts.append(text)
98
+
99
+ return '\n\n---\n\n'.join(generated_texts)
100
+
101
+ # Create Gradio interface
102
+ iface = gr.Interface(
103
+ fn=generate_text,
104
+ inputs=[
105
+ gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
106
+ gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
107
+ gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
108
+ ],
109
+ outputs=gr.Textbox(label="Generated Text"),
110
+ title="Shakesphere Text Generator",
111
+ description="Enter text for Shakesphere way of text and continue the same",
112
+ examples=[
113
+ ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
114
+ ["Love all, trust a few, do wrong to none.", 60, 2],
115
+ ["It's not enough to speak, but to speak true", 50, 3],
116
+ ["To be, or not to be: that is the question.", 100, 1],
117
+ ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
118
+ ["Love sought is good, but given unsought is better.", 100, 1],
119
+ ]
120
+ )
121
+
122
+ if __name__ == "__main__":
123
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ tiktoken
4
+ transformers
5
+ huggingface_hub
smollm_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2b086a3dc93275f59df03e5132454955c4b3231f50e32508817c4ea9d502bb8
3
+ size 256773738