Ronakparmar committed on
Commit
3d7ebe2
·
verified ·
1 Parent(s): 43bd89d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -50,32 +50,33 @@ class GPT(nn.Module):
50
  return logits
51
 
52
  def generate(self, input_ids, max_new_tokens, temperature, top_k):
53
- # Implement the text generation logic
54
- output_ids = input_ids
55
- for _ in range(max_new_tokens):
56
- logits = self.forward(output_ids[:, -1:])
57
- logits = logits / temperature
58
- probs = F.softmax(logits, dim=-1)
59
-
60
- # Ensure probs is 2D
61
- if probs.dim() == 3:
62
- probs = probs.squeeze(0) # Remove the batch dimension if it exists
63
-
64
- top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
65
-
66
- # Ensure top_k_probs is 2D
67
- if top_k_probs.dim() == 1:
68
- top_k_probs = top_k_probs.unsqueeze(0)
69
-
70
- next_token = torch.multinomial(top_k_probs, num_samples=1)
71
- next_token = top_k_indices.gather(-1, next_token)
72
-
73
- # Ensure next_token is 2D
74
- if next_token.dim() == 1:
75
- next_token = next_token.unsqueeze(0)
76
-
77
- output_ids = torch.cat([output_ids, next_token], dim=1)
78
- return output_ids
 
79
  # Initialize global variables
80
  model = None
81
  tokenizer = None
 
50
  return logits
51
 
52
  def generate(self, input_ids, max_new_tokens, temperature, top_k):
53
+ # Implement the text generation logic
54
+ output_ids = input_ids
55
+ for _ in range(max_new_tokens):
56
+ logits = self.forward(output_ids[:, -1:])
57
+ logits = logits / temperature
58
+ probs = F.softmax(logits, dim=-1)
59
+
60
+ # Ensure probs is 2D
61
+ if probs.dim() == 3:
62
+ probs = probs.squeeze(0) # Remove the batch dimension if it exists
63
+
64
+ top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
65
+
66
+ # Ensure top_k_probs is 2D
67
+ if top_k_probs.dim() == 1:
68
+ top_k_probs = top_k_probs.unsqueeze(0)
69
+
70
+ next_token = torch.multinomial(top_k_probs, num_samples=1)
71
+ next_token = top_k_indices.gather(-1, next_token)
72
+
73
+ # Ensure next_token is 2D
74
+ if next_token.dim() == 1:
75
+ next_token = next_token.unsqueeze(0)
76
+
77
+ output_ids = torch.cat([output_ids, next_token], dim=1)
78
+ return output_ids
79
+
80
  # Initialize global variables
81
  model = None
82
  tokenizer = None