# NOTE(review): the three lines below were Hugging Face Spaces page-status
# residue from scraping ("Spaces: / Sleeping / Sleeping"), not source code;
# commented out so the file parses.
"""
Gradio App for Sentence Completion
Main entry point for Hugging Face Spaces
"""
import gradio as gr
import torch
from inference import load_model, generate_text, get_device
| # Global model variable | |
| model = None | |
| device = None | |
def initialize_model(model_path=None, pretrained_model='gpt2'):
    """Load the model into the module-level globals.

    Args:
        model_path: optional path to a fine-tuned checkpoint; when None the
            pretrained weights are used as-is.
        pretrained_model: base model identifier passed through to load_model.

    Returns:
        A human-readable status string describing success or failure
        (the broad except is deliberate: the string is shown in the UI).
    """
    global model, device
    try:
        model, device = load_model(
            model_path=model_path,
            pretrained_model=pretrained_model,
        )
    except Exception as e:
        return f"Error loading model: {str(e)}"
    return f"Model loaded successfully on device: {device}"
def complete_sentence(prompt, max_tokens, top_k, temperature):
    """Generate a completion for *prompt* with the globally loaded model.

    Args:
        prompt: user-supplied text to continue.
        max_tokens: maximum number of tokens to generate.
        top_k: sample only from the K most likely tokens.
        temperature: sampling temperature (higher = more random).

    Returns:
        The generated text, or a human-readable error message string
        (errors are returned, not raised, so they render in the UI).
    """
    global model, device

    # Guard clauses: surface setup problems directly in the output box.
    if model is None:
        return "Error: Model not loaded. Please restart the app."
    if not prompt.strip():
        return "Please enter a prompt to complete."

    try:
        # The available device may change between calls; move the model if so.
        if device != get_device():
            device = get_device()
            model = model.to(device)
        return generate_text(
            prompt=prompt,
            model=model,
            max_tokens=max_tokens,
            top_k=top_k,
            temperature=temperature,
            device=device,
        )
    except Exception as e:
        return f"Error generating text: {str(e)}"
def create_interface():
    """Create and return the Gradio interface.

    Loads the model on startup (from the first existing checkpoint path,
    falling back to pretrained 'gpt2' when none is found), then builds a
    Blocks UI with a prompt box, sampling controls and an output box.

    Returns:
        The assembled gr.Blocks demo, ready for .launch().
    """
    # Fix: `import os` was previously inside the search loop, re-executing
    # the import statement on every iteration; hoisted to function scope.
    import os

    # Try to load from common checkpoint paths.
    checkpoint_paths = [
        './model/model.pth',
        'model.pt',
        'checkpoint.pth',
        'checkpoint.pt',
        'gpt_model.pth',
    ]
    # First existing checkpoint wins; None means "use pretrained weights".
    model_path = next((p for p in checkpoint_paths if os.path.exists(p)), None)

    status = initialize_model(model_path=model_path, pretrained_model='gpt2')
    print(status)  # surfaced in the Space logs for debugging startup issues

    # Create Gradio interface
    with gr.Blocks(title="Sentence Completion with GPT") as demo:
        gr.Markdown(
            """
            # Sentence Completion with GPT
            Enter a prompt and the model will complete the sentence for you.
            Adjust the parameters to control the generation behavior.
            """
        )
        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="The future of artificial intelligence is"
                )
                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max Tokens"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-K"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature"
                    )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=2):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )
        gr.Markdown(
            """
            ### Parameters:
            - **Max Tokens**: Maximum number of tokens to generate
            - **Top-K**: Sample from top K most likely tokens (lower = more focused)
            - **Temperature**: Controls randomness (lower = more deterministic, higher = more creative)
            """
        )
        # Set up the generate function
        generate_btn.click(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )
        # Also generate on Enter key press
        prompt_input.submit(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it locally without a public share link.
    app = create_interface()
    app.launch(share=False)