File size: 3,603 Bytes
6141866
 
 
fc6aa6c
6141866
 
 
 
 
 
 
77c1734
e9f0135
6141866
 
 
77c1734
6141866
1a1e2b6
628e455
 
 
 
 
2a44c04
 
 
 
 
 
6141866
 
 
 
 
 
 
 
 
2a44c04
ead2b35
6141866
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import torch
import tiktoken
from model import GPT, GPTConfig
from transformers import GPT2LMHeadModel

@st.cache_resource
def get_model(model_path='mathminakshi/custom_gpt2'):
    """Load the trained GPT model from the Hugging Face Hub.

    The result is cached by Streamlit (per ``model_path``) so the weights
    are only downloaded once per session/server.

    Args:
        model_path: Hugging Face repo id containing ``model.pth``.
            Defaults to the checkpoint the app was built around, so
            existing callers are unaffected.

    Returns:
        A ``GPT`` instance with the downloaded weights, in eval mode on CPU.
    """
    model = GPT(GPTConfig())
    # Load from the Hugging Face Hub instead of a local file.
    url = f'https://huggingface.co/{model_path}/resolve/main/model.pth'
    state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
    """Generate a continuation of *prompt* with the cached GPT model.

    Samples autoregressively one token at a time, updating a Streamlit
    progress bar as it goes.

    Args:
        prompt: Seed text to continue.
        max_tokens: Maximum number of tokens to sample.
        temperature: Softmax temperature; lower values are more deterministic.
        top_k: Keep only the k most likely tokens at each step (0 disables
            top-k filtering).

    Returns:
        The original prompt concatenated with the generated text.
    """
    # Get cached model and run on whatever device its parameters live on.
    model = get_model()
    device = next(model.parameters()).device

    # Tokenize prompt; allow the EOS marker to appear literally in the prompt.
    enc = tiktoken.get_encoding("gpt2")
    input_ids = torch.tensor(
        enc.encode(prompt, allowed_special={'<|endoftext|>'})
    ).unsqueeze(0).to(device)

    # Derive the end-of-text token id from the encoding (50256 for gpt2).
    end_token = enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]

    with torch.no_grad():
        output_sequence = []
        progress_bar = st.progress(0)

        for i in range(max_tokens):
            progress_bar.progress(i / max_tokens)

            # Next-token logits for the last position, scaled by temperature.
            logits, _ = model(input_ids)
            logits = logits[:, -1, :] / temperature

            # Top-k filtering: mask every logit below the k-th best.
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Sample from the filtered distribution.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append sampled token and feed it back in.
            output_sequence.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop on end-of-text. BUGFIX: was a hard-coded 50256 even though
            # end_token is computed above — use the derived id for consistency.
            if next_token.item() == end_token:
                break

    progress_bar.progress(1.0)
    generated_text = enc.decode(output_sequence)
    return prompt + generated_text

def main():
    """Render the Streamlit UI and wire it to the text generator."""
    st.title("GPT Text Generator")
    st.write("Enter a prompt to generate text using GPT-2.")

    # Generation controls live in the sidebar.
    st.sidebar.header("Generation Parameters")
    token_limit = st.sidebar.slider(
        "Max Tokens",
        min_value=1,
        max_value=1000,
        value=100,
        help="Maximum number of tokens to generate"
    )
    temp_value = st.sidebar.slider(
        "Temperature",
        min_value=0.1,
        max_value=2.0,
        value=0.8,
        help="Higher values make the output more random"
    )
    k_value = st.sidebar.slider(
        "Top-K",
        min_value=1,
        max_value=100,
        value=40,
        help="Limits the number of tokens to choose from"
    )

    user_prompt = st.text_area(
        "Enter your prompt:",
        height=100,
        placeholder="Once upon a time..."
    )

    # Guard clauses: nothing to do until the button is pressed with a prompt.
    if not st.button("Generate"):
        return
    if not user_prompt:
        st.warning("Please enter a prompt first!")
        return

    with st.spinner("Generating text..."):
        result = generate_text(
            prompt=user_prompt,
            max_tokens=token_limit,
            temperature=temp_value,
            top_k=k_value
        )
        st.write("### Generated Text:")
        st.write(result)

# Run the Streamlit app when this file is executed as a script.
if __name__ == "__main__":
    main()