Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| from train_get2_8_init import GPT, GPTConfig, generate_text, TrainingConfig | |
| from huggingface_hub import hf_hub_download | |
| from torch.serialization import add_safe_globals | |
# Classes the weights-only unpickler used by torch.load is allowed to rebuild.
# GPTConfig is listed presumably because the checkpoint stores a GPTConfig
# instance next to the weights (TODO confirm against the training script).
_SAFE_CHECKPOINT_GLOBALS = [GPTConfig]
add_safe_globals(_SAFE_CHECKPOINT_GLOBALS)
def load_trained_model(
    repo_id="padmanabhbosamia/Short_Shakesphere",
    filename="best_model_compressed.pt",
):
    """Build the GPT model and load its trained weights from the Hugging Face Hub.

    Architecture hyper-parameters (block size, layers, heads, embedding width,
    dropout) are taken from ``TrainingConfig`` so the freshly built ``GPT``
    matches the shapes stored in the checkpoint. The weights are read from
    the ``'model_state_dict'`` entry of the downloaded checkpoint.

    Args:
        repo_id: Hub repository that holds the checkpoint. Defaults to the
            repository this app was originally written against.
        filename: Name of the checkpoint file inside ``repo_id``.

    Returns:
        The ``GPT`` model, moved to ``TrainingConfig().device`` and switched
        to eval mode.
    """
    config = TrainingConfig()
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout,
    )
    model = GPT(model_config)

    # os.getenv returns None when HF_TOKEN is unset; hf_hub_download then
    # falls back to its default credential lookup (fine for public repos).
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        token=os.getenv("HF_TOKEN"),
    )

    # Ask for the restricted weights-only unpickler explicitly. It is only the
    # default from torch 2.6; on 2.4/2.5 the default is a full pickle load of
    # a file fetched from the network, which can execute arbitrary code and
    # would make the module-level add_safe_globals allow-list meaningless.
    checkpoint = torch.load(
        model_path,
        map_location=config.device,
        weights_only=True,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(config.device)
    model.eval()  # disable dropout for inference
    return model
def create_gradio_interface():
    """Load the trained model once and wrap it in a Gradio text-generation UI.

    The interface takes three inputs (prompt text, maximum length, sampling
    temperature) and shows the string returned by ``generate_text``.

    Returns:
        A ``gr.Interface`` ready to be launched.
    """
    model = load_trained_model()

    def predict(prompt, max_length, temperature=0.7):
        """Generate a continuation of *prompt* with the loaded model.

        Args:
            prompt: Text typed by the user.
            max_length: Value of the "Maximum Length" slider.
            temperature: Value of the temperature slider.

        Returns:
            Whatever ``generate_text`` returns (shown in the output textbox).

        Raises:
            gr.Error: If the prompt is empty.
        """
        if not prompt:
            # Nothing to condition on; generation presumably fails on an empty
            # token sequence, so show a clear message in the UI instead.
            raise gr.Error("Please enter a prompt.")
        # Gradio sliders may deliver numbers as floats (e.g. 100.0). The value
        # is presumably used as a token count inside generate_text (TODO
        # confirm), so hand it an int.
        max_length = int(round(max_length))
        # Inference never needs gradients; disabling autograd avoids building
        # a graph and saves memory (harmless if generate_text already does so).
        with torch.no_grad():
            return generate_text(model, prompt, max_length, temperature)

    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(
                lines=3,
                label="Enter your prompt",
                placeholder="Start typing here...",
            ),
            gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature (Higher = more creative)",
            ),
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom GPT Text Generator (124M) based on Shakespeare",
        description="A GPT-style language model trained on custom data by Shakespeare with 124M parameters",
    )
    return interface
# Script entry point (this is how Hugging Face Spaces runs the app): build the
# interface and start the Gradio server; importing the module does neither.
if __name__ == "__main__":
    create_gradio_interface().launch()