import torch

import os

from builtin_architecture import make_model
from dataset import dataset, get_train_dataset, get_dataloader
import torch.nn.functional as F
from tqdm import trange

EXPERIMENT_DIRECTORY = "runs/code-decoder-v23-mega"

# Prefer Apple's MPS backend when available; the assignment below then
# deliberately overrides it so evaluation runs on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"


def evaluate_topk(model, start_sequence, amt=10, k=20, temperature=1.0, device="cpu"):
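    """Sample `amt` new tokens autoregressively with top-k sampling.

    At each step the logits at the last position are temperature-scaled,
    truncated to the `k` highest-scoring tokens, renormalized with softmax,
    and the next token is drawn from that distribution.
    """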
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="topk"):
            seq = generated_sequence
            results = model(seq, transpose=True)
            results = results.transpose(0, 1)

            # Logits at the last position only.
            logits = results.reshape(-1, results.size(-1))[-1]
            logits = logits / temperature

            # Keep the k most likely tokens and renormalize over them.
            top_k_values, top_k_indices = torch.topk(logits, k)
            top_k_probs = F.softmax(top_k_values, dim=-1)

            sampled_index = torch.multinomial(top_k_probs, 1).item()
            next_token = top_k_indices[sampled_index].unsqueeze(0)

            generated_sequence = torch.cat(
                (generated_sequence, next_token.unsqueeze(0)), dim=1
            )

    return generated_sequence


def evaluate_topp(model, start_sequence, amt=10, p=0.9, temperature=1.0, device="cpu"):
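    """Sample `amt` new tokens autoregressively with nucleus (top-p) sampling.

    At each step the smallest set of tokens whose cumulative probability
    exceeds `p` is kept, renormalized, and sampled from.
    """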
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="topp"):
            seq = generated_sequence
            results = model(seq, transpose=True)
            results = results.transpose(0, 1)

            logits = results.reshape(-1, results.size(-1))[-1]
            logits = logits / temperature

            probs = F.softmax(logits, dim=-1)

            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Smallest prefix of tokens whose cumulative probability exceeds p.
            cutoff_idx = torch.where(cumulative_probs > p)[0][0] + 1
            top_p_probs = sorted_probs[:cutoff_idx]
            top_p_indices = sorted_indices[:cutoff_idx]

            # Renormalize the truncated distribution before sampling.
            top_p_probs = top_p_probs / top_p_probs.sum()

            sampled_index = torch.multinomial(top_p_probs, 1).item()
            next_token = top_p_indices[sampled_index].unsqueeze(0)

            generated_sequence = torch.cat(
                (generated_sequence, next_token.unsqueeze(0)), dim=1
            )

    return generated_sequence


def evaluate_beam(model, start_sequence, k=2, amt=10, temperature=0.8, device="cpu"):
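    """Generate `amt` tokens with a simple beam search of width `k`.

    Beams are scored by the sum of temperature-scaled log-probabilities of
    their tokens; the highest-scoring beam is returned.
    """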
    generated_sequence = start_sequence.clone().to(device)

    model.eval()

    current_beams = generated_sequence.expand(k, -1)
    current_beam_scores = torch.zeros(k, device=device)

    with torch.no_grad():
        for step in trange(amt, leave=False, dynamic_ncols=True, desc="beam"):
            all_candidates = []

            # On the first step every beam is an identical copy of the prompt;
            # expanding just one of them avoids k duplicate candidates.
            for i in range(1 if step == 0 else k):
                seq = current_beams[i].unsqueeze(0)
                results = model(seq, transpose=True)
                results = results.transpose(0, 1)

                # Score with log-probabilities so beam scores are comparable
                # across steps (raw logits are not normalized).
                logits = results[:, -1, :] / temperature
                log_probs = F.log_softmax(logits, dim=-1)
                topk_values, topk_indices = torch.topk(log_probs, k)

                for j in range(k):
                    candidate = torch.cat(
                        (seq, topk_indices[:, j].unsqueeze(0)), dim=1
                    )
                    score = current_beam_scores[i] + topk_values[:, j]
                    all_candidates.append((candidate, score))

            # Keep the k best continuations across all beams.
            all_candidates.sort(key=lambda x: x[1], reverse=True)
            top_candidates = all_candidates[:k]

            current_beams = torch.cat([candidate for candidate, _ in top_candidates])
            current_beam_scores = torch.tensor(
                [score.item() for _, score in top_candidates], device=device
            )

    return current_beams[0]


def evaluate(model, start_sequence, amt=10):
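    """Greedy decoding: repeatedly append the argmax token, `amt` times."""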
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False):
            seq = generated_sequence
            results = model(seq, transpose=True)
            results = results.transpose(0, 1)

            # Greedy: take the argmax at the last position.
            next_token = torch.argmax(results.reshape(-1, results.size(-1)), dim=1)[
                -1
            ].unsqueeze(0)

            generated_sequence = torch.cat(
                (generated_sequence, next_token.unsqueeze(0)), dim=1
            )

    return generated_sequence


def tester_exactly_like_trainingmanager_please_please_work(model, rawbatch):
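    """Reproduce the training manager's accuracy computation on `rawbatch`.

    Prints the fraction of positions where the argmax prediction matches the
    shifted-by-one label and returns (predictions, labels), both flattened.
    """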
    # Labels are the inputs shifted left by one token.
    labels = rawbatch[:, 1:].contiguous()
    batch = rawbatch[:, :-1].contiguous()
    results = model(batch, transpose=True)
    results = results.transpose(0, 1)
    print(
        torch.sum(
            torch.argmax(results.reshape(-1, results.size(-1)), dim=1)
            == labels.reshape(-1)
        )
        / len(labels.reshape(-1))
    )
    return torch.argmax(results.reshape(-1, results.size(-1)), dim=1), labels.reshape(
        -1
    )


def tester_exactly_like_trainingmanager_only_last_please_work(model, rawbatch):
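    """Same accuracy check as above, restricted to the last sequence of the batch."""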
    labels = rawbatch[:, 1:].contiguous()
    batch = rawbatch[:, :-1].contiguous()

    # Keep only the last sequence in the batch.
    batch = batch[-1].unsqueeze(0)
    labels = labels[-1].unsqueeze(0)

    results = model(batch, transpose=True)
    results = results.transpose(0, 1)
    print(
        torch.sum(
            torch.argmax(results.reshape(-1, results.size(-1)), dim=1)
            == labels.reshape(-1)
        )
        / len(labels.reshape(-1))
    )
    return torch.argmax(results.reshape(-1, results.size(-1)), dim=1), labels.reshape(
        -1
    )


def compute_entropy(logits):
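    """Mean Shannon entropy (in nats) of the distributions defined by `logits`."""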
    # Pair softmax with log_softmax for numerical stability: probs.log()
    # underflows to -inf for tiny probabilities, and 0 * -inf = nan.
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy.mean().item()


def main():
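    """Load the latest checkpoint, run NaN/accuracy sanity checks, and
    generate a beam-search completion from a training prompt."""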
    net = make_model()
    net.to(device)
    print(os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "latest.pt"))
    net.load_state_dict(
        torch.load(
            os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "latest.pt"), weights_only=True
        )
    )

    # Sanity-check the checkpoint: NaNs in weights or stale gradients would
    # explain degenerate generations.
    for name, param in net.named_parameters():
        if torch.isnan(param).any():
            print(f"NaN found in {name}")
    for name, param in net.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"NaN found in gradients of {name}")

    loader = get_dataloader(get_train_dataset())
    # Derive a reproducible RNG seed from whatever string the user types.
    torch.random.manual_seed(sum(ord(c) for c in input("seed? ")))

    for data in loader:
        batch, attn_mask = data

        print(
            tester_exactly_like_trainingmanager_please_please_work(net, rawbatch=batch)
        )
        print("pretty please")

        print(
            tester_exactly_like_trainingmanager_only_last_please_work(
                net, rawbatch=batch
            )
        )
        print("please please please")

        # Shift by one: the model predicts token t+1 from tokens up to t.
        labels = batch[:, 1:].contiguous()
        batch = batch[:, :-1].contiguous()

        # Use the first sequence in the batch, truncated to a 100-token prompt.
        batch = batch[0]
        labels = labels[0]

        batch = batch[:100]
        labels = labels[:100]
        print("Getting first 100 tokens for batch and labels")

        print(batch)
        print(dataset.manager.decode(batch))
        print("batch ^ labels v")
        print(dataset.manager.decode(labels))
        print("that's inp I guess ^^")
        with torch.no_grad():
            # Call the model the same way as the other helpers in this file
            # so the (batch, seq, vocab) indexing below holds.
            logits = net(batch.unsqueeze(0), transpose=True).transpose(0, 1)
            entropy = compute_entropy(logits[:, -1, :])

        print(f"Entropy of last token: {entropy:.4f}")

        print("USING BEAM")
        result = evaluate_beam(net, batch.unsqueeze(0), amt=100, k=3)

        result = dataset.manager.decode(result)
        batch_str = dataset.manager.decode(batch)

        # Print the prompt and the generated continuation separately.
        result = f"<data>\n{batch_str}</data>\n{result[len(batch_str):]}"

        print(result)

        # Only evaluate on the first batch.
        break


if __name__ == "__main__":
    main()