mathminakshi committed
Commit 6141866 · verified · 1 Parent(s): 43ea451

Create app.py

Files changed (1):
  app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ import streamlit as st
+ import torch
+ import tiktoken
+ from src.model import GPT, GPTConfig
+
+ @st.cache_resource
+ def get_model():
+     """Load the trained GPT model."""
+     model = GPT(GPTConfig())
+     # Load the checkpoint from the Hugging Face Hub instead of a local file
+     model_repo = 'YOUR_USERNAME/YOUR_MODEL_REPO'
+     checkpoint = torch.hub.load_state_dict_from_url(
+         f'https://huggingface.co/{model_repo}/resolve/main/final_best_model.pth',
+         map_location='cpu'
+     )
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+     return model
+
+ def generate_text(prompt, max_tokens=500, temperature=0.8, top_k=40):
+     """Generate text based on the prompt."""
+     # Encode the prompt with the GPT-2 BPE tokenizer
+     enc = tiktoken.get_encoding('gpt2')
+     input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
+
+     # Get cached model
+     model = get_model()
+
+     with torch.no_grad():
+         output_sequence = []
+         progress_bar = st.progress(0)
+
+         for i in range(max_tokens):
+             progress_bar.progress(i / max_tokens)
+
+             # Get next-token predictions and scale by temperature
+             outputs = model(input_ids)
+             logits = outputs.logits[:, -1, :] / temperature
+
+             # Apply top-k filtering: drop everything below the k-th largest logit
+             if top_k > 0:
+                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                 logits[indices_to_remove] = float('-inf')
+
+             # Sample from the filtered distribution
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+
+             # Append to output and extend the context
+             output_sequence.append(next_token.item())
+             input_ids = torch.cat([input_ids, next_token], dim=1)
+
+             # Stop if we generate an EOS token (50256 is GPT-2's <|endoftext|>)
+             if next_token.item() == 50256:
+                 break
+
+         progress_bar.progress(1.0)
+
+     generated_text = enc.decode(output_sequence)
+     return prompt + generated_text
+
+ def main():
+     st.title("GPT Text Generator")
+     st.write("Enter a prompt to generate text using GPT-2.")
+
+     # Sidebar for generation parameters
+     st.sidebar.header("Generation Parameters")
+     max_tokens = st.sidebar.slider(
+         "Max Tokens",
+         min_value=1,
+         max_value=1000,
+         value=100,
+         help="Maximum number of tokens to generate"
+     )
+
+     temperature = st.sidebar.slider(
+         "Temperature",
+         min_value=0.1,
+         max_value=2.0,
+         value=0.8,
+         help="Higher values make the output more random"
+     )
+
+     top_k = st.sidebar.slider(
+         "Top-K",
+         min_value=1,
+         max_value=100,
+         value=40,
+         help="Limits the number of tokens to choose from"
+     )
+
+     prompt = st.text_area(
+         "Enter your prompt:",
+         height=100,
+         placeholder="Once upon a time..."
+     )
+
+     if st.button("Generate"):
+         if prompt:
+             with st.spinner("Generating text..."):
+                 generated_text = generate_text(
+                     prompt=prompt,
+                     max_tokens=max_tokens,
+                     temperature=temperature,
+                     top_k=top_k
+                 )
+             st.write("### Generated Text:")
+             st.write(generated_text)
+         else:
+             st.warning("Please enter a prompt first!")
+
+ if __name__ == "__main__":
+     main()
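
Note: the download in get_model() assumes final_best_model.pth sits at the root of the model repo (the repo id is still a placeholder). A minimal alternative sketch for the same step using huggingface_hub's hf_hub_download, with the 'model_state_dict' key carried over from the diff:

    import torch
    from huggingface_hub import hf_hub_download

    def load_checkpoint(repo_id='YOUR_USERNAME/YOUR_MODEL_REPO',
                        filename='final_best_model.pth'):
        # hf_hub_download fetches the file once and caches it locally
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        checkpoint = torch.load(local_path, map_location='cpu')
        # weights are stored under 'model_state_dict', as in get_model()
        return checkpoint['model_state_dict']

Once the placeholders are filled in, the app can be launched with "streamlit run app.py".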