Commit: Update logging

Files changed:
- bad_gpt.py (+3 −1)
- main.py (+7 −4)
bad_gpt.py (changed)

@@ -57,7 +57,9 @@ class BadGPTModel(nn.Module):
     # generate new tokens in the next sentence
     def generate(self, idx: torch.Tensor, max_new_tokens: int):
         for _ in range(max_new_tokens):
-
+            # Log progress so I don't go insane
+            if _ % 16 == 0:
+                logger.debug(f'Iteration {_} of {max_new_tokens}')
             # Crop out the last block_size tokens
             cropped_idx = idx[:, -self.block_size:]
             logits = self(cropped_idx)
main.py (changed)

@@ -1,6 +1,8 @@
-import torch
-
 from bad_gpt import BadGPT
+import torch
+import logging
+logging.basicConfig()
+logger = logging.getLogger('bad_gpt').getChild(__name__)
 
 # HYPERPARAMETERS #
 ### Impacts performance ###

@@ -19,6 +21,7 @@ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 if __name__ == '__main__':
+    logging.getLogger('bad_gpt').setLevel(logging.DEBUG)
     bad_gpt = BadGPT(
         device=DEVICE,
         batch_size=BATCH_SIZE,

@@ -32,7 +35,7 @@ if __name__ == '__main__':
         lr=LEARNING_RATE
     )
 
-
+    logger.info("Generating response...")
     resp = bad_gpt.generate(
         'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?', 256)
-
+    logger.info("Response:\n" + resp)