# FirstLLM / app.py
# Author: Krishnakanth1993 — Initial commit (0ede4e9)
"""
Gradio App for Sentence Completion
Main entry point for Hugging Face Spaces
"""
import os

import gradio as gr
import torch

from inference import load_model, generate_text, get_device
# Global model variable
# Populated by initialize_model(); read/updated by complete_sentence().
model = None  # the loaded language model (None until initialize_model succeeds)
device = None  # device the model currently lives on (as returned by load_model)
def initialize_model(model_path=None, pretrained_model='gpt2'):
    """Load the model into the module-level globals.

    Args:
        model_path: Optional path to a local checkpoint; falls back to the
            pretrained weights when None.
        pretrained_model: Name of the pretrained model to use as a base.

    Returns:
        A human-readable status string (success or error) rather than
        raising, so the startup code can surface failures in the UI.
    """
    global model, device
    try:
        model, device = load_model(model_path=model_path, pretrained_model=pretrained_model)
    except Exception as exc:
        # Report the failure as text instead of crashing app startup.
        return f"Error loading model: {str(exc)}"
    return f"Model loaded successfully on device: {device}"
def complete_sentence(prompt, max_tokens, top_k, temperature):
    """Generate a sentence completion for *prompt*.

    Args:
        prompt: User-entered text to complete.
        max_tokens: Maximum number of tokens to generate (Gradio sliders
            may deliver this as a float, so it is coerced to int below).
        top_k: Sample only from the top-K most likely tokens (also coerced).
        temperature: Sampling temperature controlling randomness.

    Returns:
        The generated text, or a user-facing error string — this function
        never raises, because its return value is shown directly in the UI.
    """
    global model, device
    if model is None:
        return "Error: Model not loaded. Please restart the app."
    if not prompt.strip():
        return "Please enter a prompt to complete."
    try:
        # Ensure device is current (e.g. GPU became available/unavailable).
        if device != get_device():
            device = get_device()
            model = model.to(device)
        # Generate completion. Coerce slider values: Gradio sliders can emit
        # floats, but token counts must be ints for indexing/loop bounds.
        generated_text = generate_text(
            prompt=prompt,
            model=model,
            max_tokens=int(max_tokens),
            top_k=int(top_k),
            temperature=float(temperature),
            device=device
        )
        return generated_text
    except Exception as e:
        # Broad catch is intentional: any generation failure is reported
        # to the user as text rather than breaking the interface.
        return f"Error generating text: {str(e)}"
def create_interface():
    """Create and return the Gradio Blocks interface.

    On creation this loads the model once: it searches a few conventional
    checkpoint paths and falls back to the pretrained 'gpt2' weights when
    no checkpoint file exists. The load status is printed to the console.

    Returns:
        The assembled ``gr.Blocks`` demo, ready to ``.launch()``.
    """
    # Candidate checkpoint locations, checked in order.
    # NOTE: `import os` was previously re-executed inside the loop body on
    # every iteration; it now lives at module level.
    checkpoint_paths = [
        './model/model.pth',
        'model.pt',
        'checkpoint.pth',
        'checkpoint.pt',
        'gpt_model.pth',
    ]
    # First existing path wins; None means "use pretrained weights only".
    model_path = next((p for p in checkpoint_paths if os.path.exists(p)), None)
    status = initialize_model(model_path=model_path, pretrained_model='gpt2')
    print(status)

    # Create Gradio interface
    with gr.Blocks(title="Sentence Completion with GPT") as demo:
        gr.Markdown(
            """
            # Sentence Completion with GPT
            Enter a prompt and the model will complete the sentence for you.
            Adjust the parameters to control the generation behavior.
            """
        )
        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="The future of artificial intelligence is"
                )
                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max Tokens"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-K"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature"
                    )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=2):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )
        gr.Markdown(
            """
            ### Parameters:
            - **Max Tokens**: Maximum number of tokens to generate
            - **Top-K**: Sample from top K most likely tokens (lower = more focused)
            - **Temperature**: Controls randomness (lower = more deterministic, higher = more creative)
            """
        )
        # Set up the generate function (button click and Enter key share
        # the same handler and inputs).
        generate_btn.click(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )
        prompt_input.submit(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )
    return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(share=False)