mathminakshi committed on
Commit
2a44c04
·
verified ·
1 Parent(s): e9f0135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -17,8 +17,14 @@ def get_model():
17
  def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
  # Encode the prompt
20
- enc = tiktoken.get_encoding('gpt2')
21
- input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
 
 
 
 
 
 
22
 
23
  # Get cached model
24
  model = get_model()
@@ -31,7 +37,7 @@ def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
31
  progress_bar.progress(i / max_tokens)
32
 
33
  # Get predictions
34
- outputs = model(input_ids)
35
  logits = outputs.logits[:, -1, :] / temperature
36
 
37
  # Apply top-k filtering
 
17
  def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
  # Encode the prompt
20
+ device = next(model.parameters()).device
21
+
22
+ # Tokenize prompt with special token handling
23
+ enc = tiktoken.get_encoding("gpt2")
24
+ input_ids = torch.tensor(enc.encode(prompt, allowed_special={'<|endoftext|>'})).unsqueeze(0).to(device)
25
+
26
+ # Get end token id
27
+ end_token = enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]
28
 
29
  # Get cached model
30
  model = get_model()
 
37
  progress_bar.progress(i / max_tokens)
38
 
39
  # Get predictions
40
+ logits,_ = model(input_ids)
41
  logits = outputs.logits[:, -1, :] / temperature
42
 
43
  # Apply top-k filtering