mathminakshi committed on
Commit
1a1e2b6
·
verified ·
1 Parent(s): 2a44c04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -16,9 +16,7 @@ def get_model():
16
 
17
  def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
- # Encode the prompt
20
- device = next(model.parameters()).device
21
-
22
  # Tokenize prompt with special token handling
23
  enc = tiktoken.get_encoding("gpt2")
24
  input_ids = torch.tensor(enc.encode(prompt, allowed_special={'<|endoftext|>'})).unsqueeze(0).to(device)
@@ -28,6 +26,7 @@ def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
28
 
29
  # Get cached model
30
  model = get_model()
 
31
 
32
  with torch.no_grad():
33
  output_sequence = []
 
16
 
17
  def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
+
 
 
20
  # Tokenize prompt with special token handling
21
  enc = tiktoken.get_encoding("gpt2")
22
  input_ids = torch.tensor(enc.encode(prompt, allowed_special={'<|endoftext|>'})).unsqueeze(0).to(device)
 
26
 
27
  # Get cached model
28
  model = get_model()
29
+ device = next(model.parameters()).device
30
 
31
  with torch.no_grad():
32
  output_sequence = []