Chaitanya Sagar Gurujula
commited on
Commit
·
0e19c73
1
Parent(s):
8275a34
fixed generate method
Browse files- src/model.py +4 -1
src/model.py
CHANGED
|
@@ -197,7 +197,10 @@ class GPT(nn.Module):
|
|
| 197 |
def generate(self, input_ids, max_length=50,eos_token_id=None):
|
| 198 |
generated_tokens = []
|
| 199 |
current_ids = input_ids
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
| 201 |
for _ in range(max_length):
|
| 202 |
# Forward pass to get logits
|
| 203 |
logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)
|
|
|
|
| 197 |
def generate(self, input_ids, max_length=50,eos_token_id=None):
|
| 198 |
generated_tokens = []
|
| 199 |
current_ids = input_ids
|
| 200 |
+
|
| 201 |
+
# 🔥 Infer device from input_ids
|
| 202 |
+
device = input_ids.device
|
| 203 |
+
|
| 204 |
for _ in range(max_length):
|
| 205 |
# Forward pass to get logits
|
| 206 |
logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)
|