import streamlit as st
import torch
import tiktoken
from model import GPT, GPTConfig
from transformers import GPT2LMHeadModel

@st.cache_resource
def get_model():
    """Build the GPT model and load its trained weights from the Hugging Face Hub.

    Cached by Streamlit so the download and construction happen only once
    per server process.
    """
    repo_id = 'mathminakshi/custom_gpt2'
    weights_url = f'https://huggingface.co/{repo_id}/resolve/main/best_model.pth'
    # Fetch the checkpoint from the Hub instead of relying on a local file.
    checkpoint = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
    gpt = GPT(GPTConfig())
    gpt.load_state_dict(checkpoint['model_state_dict'])
    gpt.eval()
    return gpt

def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
    """Generate a continuation of *prompt* using the cached GPT model.

    Args:
        prompt: Text to condition generation on.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Softmax temperature; higher values flatten the
            distribution (more random). Must be > 0 (the UI slider
            enforces a 0.1 minimum).
        top_k: Keep only the k most likely tokens at each step; 0 disables
            the filter.

    Returns:
        The original prompt followed by the generated text (EOS marker
        excluded).
    """
    eos_token_id = 50256  # GPT-2 <|endoftext|>

    # Encode the prompt into a (1, seq_len) tensor of token ids.
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)

    # Get cached model (loaded once per process by st.cache_resource).
    model = get_model()

    with torch.no_grad():
        output_sequence = []
        progress_bar = st.progress(0)

        for i in range(max_tokens):
            progress_bar.progress(i / max_tokens)

            # Get logits for the last position; scale by temperature.
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :] / temperature

            # Apply top-k filtering: mask everything strictly below the
            # k-th largest logit so sampling is restricted to the top k.
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Sample the next token from the filtered distribution.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # BUG FIX: check for EOS *before* appending, so the literal
            # '<|endoftext|>' marker is never decoded into the output.
            if next_token.item() == eos_token_id:
                break

            # Append to output and extend the context window.
            output_sequence.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

    progress_bar.progress(1.0)
    generated_text = enc.decode(output_sequence)
    return prompt + generated_text

def main():
    """Render the Streamlit UI and drive text generation."""
    st.title("GPT Text Generator")
    st.write("Enter a prompt to generate text using GPT-2.")

    # All generation parameters live in the sidebar.
    st.sidebar.header("Generation Parameters")
    slider = st.sidebar.slider
    max_tokens = slider(
        "Max Tokens",
        min_value=1,
        max_value=1000,
        value=100,
        help="Maximum number of tokens to generate",
    )
    temperature = slider(
        "Temperature",
        min_value=0.1,
        max_value=2.0,
        value=0.8,
        help="Higher values make the output more random",
    )
    top_k = slider(
        "Top-K",
        min_value=1,
        max_value=100,
        value=40,
        help="Limits the number of tokens to choose from",
    )

    prompt = st.text_area(
        "Enter your prompt:",
        height=100,
        placeholder="Once upon a time...",
    )

    # Guard clauses: do nothing until the button is pressed, and require
    # a non-empty prompt before generating.
    if not st.button("Generate"):
        return
    if not prompt:
        st.warning("Please enter a prompt first!")
        return

    with st.spinner("Generating text..."):
        result = generate_text(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        st.write("### Generated Text:")
        st.write(result)

if __name__ == "__main__":
    main()