mathminakshi commited on
Commit
628e455
·
verified ·
1 Parent(s): 1a1e2b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -17,6 +17,11 @@ def get_model():
17
  def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
 
 
 
 
 
 
20
  # Tokenize prompt with special token handling
21
  enc = tiktoken.get_encoding("gpt2")
22
  input_ids = torch.tensor(enc.encode(prompt, allowed_special={'<|endoftext|>'})).unsqueeze(0).to(device)
@@ -24,10 +29,6 @@ def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
24
  # Get end token id
25
  end_token = enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]
26
 
27
- # Get cached model
28
- model = get_model()
29
- device = next(model.parameters()).device
30
-
31
  with torch.no_grad():
32
  output_sequence = []
33
  progress_bar = st.progress(0)
 
17
  def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
 
20
+
21
+
22
+ # Get cached model
23
+ model = get_model()
24
+ device = next(model.parameters()).device
25
  # Tokenize prompt with special token handling
26
  enc = tiktoken.get_encoding("gpt2")
27
  input_ids = torch.tensor(enc.encode(prompt, allowed_special={'<|endoftext|>'})).unsqueeze(0).to(device)
 
29
  # Get end token id
30
  end_token = enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]
31
 
 
 
 
 
32
  with torch.no_grad():
33
  output_sequence = []
34
  progress_bar = st.progress(0)