import os

import torch
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from ..hellaswag import render_example, iterate_examples, get_most_likely_row
from ..ModelGPT2 import GPT, log_file

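# NOTE (assumption): for multi-GPU evaluation this script is expected to be
# launched with torchrun, e.g. (the module path depends on the package layout):
#   torchrun --standalone --nproc_per_node=8 -m <package>.eval_hellaswag
# A plain `python` launch falls back to single-device evaluation below.
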
# Set up DDP (distributed data parallel) if launched with torchrun,
# which sets the RANK, LOCAL_RANK, and WORLD_SIZE environment variables.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    assert torch.cuda.is_available(), "DDP evaluation requires CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # rank 0 does the printing and logging
else:
    # Vanilla, non-DDP run.
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    # Autodetect the best available device. This block must stay inside the
    # else branch so it does not clobber the per-rank device set above.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

# torch.autocast only distinguishes "cuda" and "cpu" here; MPS takes the CPU path.
device_type = "cuda" if device.startswith("cuda") else "cpu"

# Seed for reproducibility.
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

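# NOTE: the relative checkpoint path below resolves against the current
# working directory, not against this file's location.
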
checkpoint_path = '../log/model_final.pt'
if master_process:
    print(f"Loading checkpoint from {checkpoint_path}")

# NOTE: on PyTorch >= 2.6, torch.load may need weights_only=False here,
# since the checkpoint stores a config object alongside the weights.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Rebuild the model from the saved config. The vocab size is padded to 50304
# (a multiple of 128) to match the training-time configuration.
model_config = checkpoint['config']
model_config.vocab_size = 50304
model = GPT(model_config)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
# Move the model to its device first, and wrap in DDP only when
# actually running distributed.
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
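
# Optional (assumption): on PyTorch 2.x the model could be compiled for faster
# evaluation, though compilation may interfere with this eval path in some setups:
# model = torch.compile(model)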
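# NOTE (assumption): the bfloat16 autocast used below expects hardware with
# native bf16 support (e.g. Ampere or newer NVIDIA GPUs); on older devices,
# torch.float16 or plain float32 may be needed instead.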
def evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process):
    """Evaluate the model on the HellaSwag validation set and log the accuracy."""
    num_correct_norm = 0
    num_total = 0
    for i, example in enumerate(iterate_examples("val")):
        # Under DDP, shard the examples round-robin across ranks.
        if ddp and i % ddp_world_size != ddp_rank:
            continue
        # Render the example into token rows (one per candidate ending),
        # with a mask marking the completion region of each row.
        _, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)
        # Score all candidate rows and pick the one with the lowest
        # length-normalized loss over the completion region.
        with torch.no_grad():
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, _ = model(tokens)
            pred_norm = get_most_likely_row(tokens, mask, logits)
        num_total += 1
        num_correct_norm += int(pred_norm == label)
    # Aggregate the per-rank counts across all processes.
    if ddp:
        num_total = torch.tensor(num_total, dtype=torch.long, device=device)
        num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
        dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
        dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
        num_total = num_total.item()
        num_correct_norm = num_correct_norm.item()
    acc_norm = num_correct_norm / num_total
    if master_process:
        print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
        with open(log_file, "a") as f:
            f.write(f"Final HellaSwag accuracy: {acc_norm:.4f}\n")


evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process)
if ddp:
    destroy_process_group()