Hugging Face Space (status: Sleeping) — Gradio text-generation demo.
# --- Model setup ----------------------------------------------------------
import torch
import gradio as gr
from transformers import pipeline  # NOTE(review): unused in this file — confirm before removing
from gpt import GPTLanguageModel

# Pick the device first so the checkpoint tensors are mapped straight onto it
# instead of being loaded on CPU and copied afterwards.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model1 = GPTLanguageModel()
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# state dict — confirm the checkpoint really matches GPTLanguageModel, since a
# partial load would produce garbage output with no error.
model1.load_state_dict(
    torch.load("gpt_model.pth", map_location=torch.device(device)),
    strict=False,
)
model1 = model1.to(device)
# Inference mode: disables dropout (and batch-norm updates, if any). Without
# this, sampling quality is degraded by training-time stochasticity.
model1.eval()
# --- Character vocabulary (must match the training script exactly) --------
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))
vocab_size = len(chars)

# Forward and reverse lookup tables: character <-> integer token id.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to its list of integer token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a list of integer token ids back to a string."""
    return ''.join(itos[i] for i in l)
| # Define a text generation function that accepts context | |
| def generate_text(context_text, max_length=50): | |
| context = torch.tensor([encode(context_text)], dtype=torch.long, device=device) | |
| generated_ids = model1.generate(context, max_new_tokens=max_length) | |
| generated_text = decode(generated_ids[0].tolist()) | |
| return generated_text | |
| # Create a Gradio interface with context as input | |
| iface = gr.Interface( | |
| fn=generate_text, | |
| inputs=[ | |
| gr.Textbox(lines=2, placeholder="Enter context here..."), | |
| gr.Slider(minimum=10, maximum=500, label="Max Length") | |
| ], | |
| outputs="text", | |
| title="Text Generation with Context", | |
| description="Provide some context and generate text based on it." | |
| ) | |
| iface.launch() | |