import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

# Device setup and the GPT-2 tokenizer from the Hugging Face Hub
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained("gpt2")
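# GPT-2's tokenizer ships without a pad token, so the EOS token is reused for padding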
tokenizer.pad_token = tokenizer.eos_token

# Define the SmolLM model (a simplified Transformer language model)
class SmolLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len):
        super(SmolLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
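        # Learned absolute position embeddings (zero-initialized); inputs longer than max_seq_len are unsupported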
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_embedding[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x)
        return self.fc_out(x)

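# A minimal sanity-check sketch (hypothetical values, not part of the app):
#   dummy = torch.randint(0, 50257, (1, 16))
#   logits = SmolLM(50257, 512, 8, 4, 128)(dummy)   # -> (1, 16, 50257) per-position logits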

def load_model():
    checkpoint_path = 'final_checkpoint.pth'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Architecture the checkpoint was trained with (for reference):
    # embed_dim=512, num_heads=8, num_layers=4, max_seq_len=128, vocab_size=50257
    # The checkpoint pickles the whole model object, so the SmolLM class
    # above must be defined for torch.load to reconstruct it.
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model = model.to(device)
    model.eval()  # Set to evaluation mode

    # Disable gradient computation for inference
    for param in model.parameters():
        param.requires_grad = False

    return model

model = load_model()


# Redundant with model.eval() above, but kept as an explicit safeguard
model.train(False)

def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
    # Autoregressive sampling: append one sampled token at a time
    # (a single argmax pass would not actually continue the prompt)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    samples = []
    with torch.no_grad():
        for _ in range(int(num_samples)):
            generated = input_ids.clone()
            for _ in range(int(max_length)):
                # Stay inside the model's 128-token positional window
                logits = model(generated[:, -128:])[:, -1, :] / max(float(temperature), 1e-5)
                probs = F.softmax(logits, dim=-1)
                generated = torch.cat([generated, torch.multinomial(probs, 1)], dim=1)
            samples.append(tokenizer.decode(generated[0], skip_special_tokens=True))
    return "\n\n".join(samples)
  
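# Quick local check (hypothetical prompt), e.g.:
#   print(generate_text("To be, or not to be", max_length=20, num_samples=1))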

# Create Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
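        # Assumption: expose the temperature argument generate_text already accepts (range is a UI choice)
        gr.Slider(minimum=0.1, maximum=1.5, value=0.8, step=0.1, label="Temperature"),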
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakesphere Text Generator",
    description="Enter text for Shakesphere way of text and continue the same",
    examples=[
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1, 0.8],
        ["Love all, trust a few, do wrong to none.", 60, 2, 0.8],
        ["It's not enough to speak, but to speak true", 50, 3, 0.8],
        ["To be, or not to be: that is the question.", 100, 1, 0.8],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1, 0.8],
        ["Love sought is good, but given unsought is better.", 100, 1, 0.8],
    ]
)

if __name__ == "__main__":
    iface.launch()