| | import gradio as gr |
| | import torch |
| | import torch.nn.functional as F |
| | import tiktoken |
| | from huggingface_hub import hf_hub_download |
| | from transformer import GPT, GPTConfig |
| |
|
| | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model_from_hf(model_id="sudhakar272/transformer_model",
                       filename="transformer_model.pt"):
    """Download a GPT checkpoint from the Hugging Face Hub and load it for inference.

    Args:
        model_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        The ``GPT`` model moved to ``device``, switched to eval mode, and with
        gradients disabled on all parameters.
    """
    checkpoint_path = hf_hub_download(repo_id=model_id, filename=filename)

    # The checkpoint's 'config' entry is passed straight to GPT(...). It is presumably
    # a pickled GPTConfig instance (TODO confirm). Since PyTorch 2.6, torch.load defaults
    # to weights_only=True, which refuses to unpickle such classes, so opt out explicitly.
    # (The weights_only keyword requires PyTorch >= 1.13.)
    # SECURITY: weights_only=False runs the pickle loader; only use trusted repositories.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = GPT(checkpoint['config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # The app only runs inference, so freeze every parameter.
    model.requires_grad_(False)

    return model
| |
|
# Load the weights once at import time; every request reuses this instance.
model = load_model_from_hf()

# eval() is the same call as train(False): it puts the module (and all
# submodules) in inference mode.
model.eval()
| |
|
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8,
                  top_k=50, block_size=1024):
    """Continue ``prompt`` with text sampled from the loaded model.

    Args:
        prompt: Text to continue. An empty prompt is seeded with the GPT-2
            end-of-text token, and that token is stripped from the output.
        max_length: Maximum number of new tokens generated per sample.
        num_samples: Number of independent continuations to draw.
        temperature: Divisor applied to the logits before softmax; must be > 0.
            Lower values make sampling more conservative.
        top_k: Sample only among the ``top_k`` most likely next tokens
            (clamped to the vocabulary size).
        block_size: Maximum number of tokens fed to the model per step. Longer
            sequences are cropped to their most recent ``block_size`` tokens.
            Assumes the model accepts this many positions; the previous
            hard-coded limit was 1024 (TODO confirm against GPTConfig).

    Returns:
        The decoded samples (prompt included), joined by a ``---`` separator.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # UI sliders may hand over floats; range() and Tensor.repeat() need ints.
    max_length = int(max_length)
    num_samples = int(num_samples)

    enc = tiktoken.get_encoding('gpt2')
    prompt_ids = enc.encode(prompt)

    # A zero-length sequence cannot be fed to the model. Seed with the
    # end-of-text token instead and skip it when decoding.
    offset = 0
    if not prompt_ids:
        prompt_ids = [enc.eot_token]
        offset = 1

    tokens = torch.tensor(prompt_ids, dtype=torch.long, device=device)
    tokens = tokens.unsqueeze(0).repeat(num_samples, 1)  # one row per sample

    with torch.no_grad():
        for _ in range(max_length):
            # Crop to the context window instead of stopping generation, so
            # long prompts still receive a continuation.
            context = tokens[:, -block_size:]
            logits = model(context)[0]
            # Keep only the last position's logits and apply temperature.
            logits = logits[:, -1, :] / temperature

            probs = F.softmax(logits, dim=-1)

            # Top-k sampling; multinomial accepts the unnormalized top-k mass.
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            ix = torch.multinomial(topk_probs, 1)
            next_token = torch.gather(topk_indices, -1, ix)

            tokens = torch.cat((tokens, next_token), dim=1)

    generated_texts = [enc.decode(row[offset:].tolist()) for row in tokens]
    return '\n\n---\n\n'.join(generated_texts)
| |
|
| | |
# Web UI. The three inputs map positionally onto
# generate_text(prompt, max_length, num_samples); any remaining generation
# parameters keep their defaults.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Good night, good night! Parting is such sweet sorrow"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Shakespeare Text Generator",
    description="Enter a prompt and the model will continue it in the style of Shakespeare.",
    # Each row is [prompt, max_length, num_samples].
    examples=[
        ["To be, or not to be: that is the question.", 100, 1],
        ["Love all, trust a few, do wrong to none.", 60, 2],
        ["It's not enough to speak, but to speak true", 50, 3],
        ["There are more things in heaven and earth, Horatio, than are dreamt of in your philosophy.", 100, 1],
        ["If you can look into the seeds of time, and say which grain will grow and which will not, speak then to me", 100, 1],
        ["Love sought is good, but given unsought is better.", 100, 1],
    ],
)


if __name__ == "__main__":
    iface.launch()