| | import torch |
| |
|
| | |
| | from builtin_architecture import make_model |
| | import os |
| | import sys |
| | import time |
| | from dataset import dataset, get_train_dataset |
| | import torch.nn.functional as F |
| |
|
# Training run whose best checkpoint this script samples from.
EXPERIMENT_DIRECTORY = "runs/code-decoder-v10-vanilla-smaller-batchfirst"

# Sampling runs on CPU. The original code auto-detected MPS
# ("mps" if torch.backends.mps.is_available() else "cpu") and then
# unconditionally overwrote the result with "cpu", so the detection was a
# dead store; the effective device has always been CPU.
device = "cpu"
| |
|
| | |
# Build the decoder architecture and restore the best checkpoint's weights.
net = make_model()
net.to(device)

checkpoint_path = os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "best.pt")
# weights_only=True restricts torch.load to tensor data (no arbitrary
# pickled objects), which is the safe way to load a state dict.
state_dict = torch.load(checkpoint_path, weights_only=True)
net.load_state_dict(state_dict)
| |
|
| |
|
# Sanity-check the restored weights: report any parameter containing NaNs.
for param_name, tensor in net.named_parameters():
    if torch.isnan(tensor).any():
        print(f"NaN found in {param_name}")

# Gradients are normally absent straight after loading a state dict, so this
# pass only fires if some backward call has already populated .grad.
for param_name, tensor in net.named_parameters():
    grad = tensor.grad
    if grad is not None and torch.isnan(grad).any():
        print(f"NaN found in gradients of {param_name}")
| |
|
| |
|
# Tokenizer special ids: padding uses id 0; no separator token is defined.
pad_token_id = 0
sep_token_id = None

# Cap on how many tokens the sampling loop below will emit.
max_length = 100

# Prompt is read interactively from stdin.
input_text = input("Prompt: ")
| |
|
| |
|
# Round-trip the prompt through the tokenizer as a sanity check:
# encode -> tensor -> decode, printing the shape and the decoded text.
encoded_prompt = dataset.manager.encode(input_text)
input_ids = torch.tensor(encoded_prompt, dtype=int)
print(input_ids.shape)

# NOTE(review): attention_mask is built but never used later — presumably
# left over from an earlier masked-forward path; verify before removing.
attention_mask = dataset.manager.attention_mask(input_ids.squeeze(0)).to(device)

generated_text = dataset.manager.decode(input_ids)
print(generated_text)

# The encoded prompt is then discarded: generation starts from a single
# random token id drawn from [0, 199) rather than from the prompt.
generated_text = ""
input_ids = torch.randint(199, (1, 1), dtype=torch.long).to(device)
| |
|
# Autoregressive top-k sampling loop.
net.eval()
temp = 1.0  # softmax temperature; 1.0 samples from the model's distribution

for _ in range(max_length):
    with torch.no_grad():
        output = net(input_ids)
        # NOTE(review): output[-1] assumes the final slice of the model
        # output holds the next-token logits — confirm against the model's
        # output layout (the run directory name says "batchfirst").
        logits = F.log_softmax(output[-1], dim=-1)
        # BUG FIX: the original passed temperature-divided *log*-probs
        # (all <= 0) straight to torch.multinomial, which requires
        # non-negative weights and raises on an all-non-positive row.
        # Exponentiating fixes that, and exp(log_softmax(x) / T) is
        # proportional to softmax(x / T), i.e. exact temperature sampling.
        word_weights = logits.div(temp).exp().cpu()

    # Restrict sampling to the k most likely tokens (clamped to vocab size).
    top_k = 10
    vocab_size = word_weights.size(0)
    top_k = min(top_k, vocab_size)
    top_probs, top_indices = torch.topk(word_weights, k=top_k)

    # With a single candidate take it greedily; otherwise sample from the
    # (unnormalized) top-k weights — multinomial renormalizes internally.
    if top_probs.size(0) == 1:
        word_idx = top_indices[0]
    else:
        sampled_idx = torch.multinomial(top_probs, 1).item()
        word_idx = top_indices[sampled_idx]

    print(word_idx)
    predicted_token = dataset.manager.decode(word_idx.item())
    print(predicted_token, end=" ")
    generated_text += predicted_token

    # Debug spam deliberately left in while sampling is being tuned.
    print("Word Weights:", word_weights)
    print("Top Probabilities:", top_probs)
    print("Top Indices:", top_indices)

    # Append the sampled token and feed the extended sequence back in.
    word_tensor = torch.tensor([[word_idx]], dtype=torch.long).to(device)
    input_ids = torch.cat([input_ids, word_tensor], dim=1)
| |
|
# Report the full sample and persist it for later inspection.
print("\nGenerated text:", generated_text)
with open("output.txt", "w+") as outfile:
    outfile.write(generated_text)
| |
|