File size: 4,845 Bytes
0ede4e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Gradio App for Sentence Completion
Main entry point for Hugging Face Spaces
"""

import os

import gradio as gr
import torch

from inference import load_model, generate_text, get_device


# Global model variable
model = None
device = None


def initialize_model(model_path=None, pretrained_model='gpt2'):
    """Load the model into the module-level ``model``/``device`` globals.

    Args:
        model_path: Optional path to a fine-tuned checkpoint. ``None`` means
            load only the pretrained base weights.
        pretrained_model: Name of the pretrained base model to load.

    Returns:
        A human-readable status string reporting success (including the
        device used) or the failure reason.
    """
    global model, device
    try:
        model, device = load_model(
            model_path=model_path,
            pretrained_model=pretrained_model,
        )
    except Exception as e:
        # Surface the failure as a status string instead of crashing startup.
        return f"Error loading model: {str(e)}"
    return f"Model loaded successfully on device: {device}"


def complete_sentence(prompt, max_tokens, top_k, temperature):
    """Generate a completion of *prompt* with the globally loaded model.

    Args:
        prompt: User-supplied text to continue.
        max_tokens: Maximum number of tokens to generate.
        top_k: Sample only from the top-k most likely next tokens.
        temperature: Sampling temperature (higher = more random).

    Returns:
        The generated text on success, otherwise a human-readable error
        message string (Gradio displays whatever string is returned).
    """
    global model, device

    if model is None:
        return "Error: Model not loaded. Please restart the app."

    # Guard against None as well as empty/whitespace-only input: a None
    # prompt would otherwise raise AttributeError on .strip() outside the
    # try block below.
    if not prompt or not prompt.strip():
        return "Please enter a prompt to complete."

    try:
        # Re-sync the model to the current device in case it changed
        # (e.g. GPU became available/unavailable) since the last call.
        if device != get_device():
            device = get_device()
            model = model.to(device)

        # Generate completion
        return generate_text(
            prompt=prompt,
            model=model,
            max_tokens=max_tokens,
            top_k=top_k,
            temperature=temperature,
            device=device
        )
    except Exception as e:
        return f"Error generating text: {str(e)}"


def create_interface():
    """Build and return the Gradio UI for sentence completion.

    Loads the model once at construction time: tries a list of common
    checkpoint filenames and falls back to the pretrained 'gpt2' weights
    when none exists on disk. The Generate button and the prompt box's
    Enter key are both wired to complete_sentence().

    Returns:
        The assembled gr.Blocks app, ready for .launch().
    """
    # Candidate checkpoint locations, checked in order; first hit wins.
    checkpoint_paths = [
        './model/model.pth',
        'model.pt',
        'checkpoint.pth',
        'checkpoint.pt',
        'gpt_model.pth',
    ]
    # None (no checkpoint found) makes initialize_model() use pretrained
    # weights only. `os` is imported at module level, not inside the loop.
    model_path = next((p for p in checkpoint_paths if os.path.exists(p)), None)

    status = initialize_model(model_path=model_path, pretrained_model='gpt2')
    print(status)

    # Create Gradio interface
    with gr.Blocks(title="Sentence Completion with GPT") as demo:
        gr.Markdown(
            """
            # Sentence Completion with GPT
            
            Enter a prompt and the model will complete the sentence for you.
            Adjust the parameters to control the generation behavior.
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="The future of artificial intelligence is"
                )

                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max Tokens"
                    )

                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-K"
                    )

                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature"
                    )

                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=2):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )

        gr.Markdown(
            """
            ### Parameters:
            - **Max Tokens**: Maximum number of tokens to generate
            - **Top-K**: Sample from top K most likely tokens (lower = more focused)
            - **Temperature**: Controls randomness (lower = more deterministic, higher = more creative)
            """
        )

        # Button click and Enter key in the prompt box both trigger generation.
        generate_btn.click(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )

        prompt_input.submit(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )

    return demo


if __name__ == "__main__":
    # Build the UI and serve it locally without a public share link.
    app = create_interface()
    app.launch(share=False)