| """Script to evaluate a model on a validation set. Based on scripts/generate.py from open_lm repo. | |
| """ | |
| import argparse | |
| import json | |
| import re | |
| import torch | |
| from open_lm.evaluate import evaluate_loop | |
| from open_lm.data import get_data | |
| from open_lm.model import create_model | |
| from open_lm.distributed import init_distributed_device | |
| from open_lm.params import parse_args | |
| from scripts.generate_without_hf import Generator, GenerationArgs | |
def generate_model_jsonl(params):
    """Write an open_lm model-config JSON file for a given parameter count.

    Maps a known model size (millions of parameters) to its
    (hidden width, depth) pair, writes the config as a single JSON line,
    and returns the path of the written file.

    Args:
        params: Model size in millions of parameters; must be one of the
            supported sizes in the table below.

    Returns:
        Path of the JSON config file that was written (in the CWD).

    Raises:
        ValueError: If ``params`` is not one of the supported sizes.
    """
    # Supported model sizes -> (hidden_dim, n_layers).
    params_to_width_depth = {
        5: (96, 3), 7: (128, 4), 9: (160, 5), 15: (224, 6),
        22: (288, 8), 28: (320, 9), 37: (384, 10), 57: (480, 12),
        84: (576, 14), 108: (640, 15), 149: (704, 18), 220: (832, 21),
        347: (1024, 23), 455: (1120, 26), 611: (1312, 26), 901: (1504, 30),
    }
    try:
        width, depth = params_to_width_depth[params]
    except KeyError:
        # A bare KeyError gives no hint which sizes exist; name them.
        raise ValueError(
            f"Unsupported model size {params}M; expected one of "
            f"{sorted(params_to_width_depth)}"
        ) from None
    filepath = f"layers={depth}_hidden-dim={width}.json"
    data = {
        "hidden_dim": width,
        "n_layers": depth,
        "n_heads": 4,  # fixed head count for every size in the table
        "seq_len": 2048,
        "vocab_size": 50432,
        "post_embed_norm": False,
        "weight_tying": False,
        "qk_norm": True,
    }
    with open(filepath, 'w') as file:
        file.write(json.dumps(data) + '\n')
    return filepath
class ModelArgs:
    """Argument namespace for open_lm model creation and evaluation.

    Starts from open_lm's default CLI arguments (parsed from an empty
    command line) and overrides the fields this script needs.
    """

    def __init__(self, params, val_data, val_data_key):
        # Seed this namespace with every open_lm default in one shot,
        # then override the evaluation-specific fields below.
        self.__dict__.update(vars(parse_args("")))
        self.model = generate_model_jsonl(params)
        self.val_data = [val_data]
        self.val_data_key = [val_data_key]
        self.per_gpu_val_batch_size = 16
        self.vocab_size = 50432
        self.seq_len = 2048
        self.wandb = False
def main():
    """Evaluate a checkpoint on validation data, or generate text from a prompt.

    Exactly one of --val-data / --input-text should be given; with neither,
    a usage hint is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="path/to/checkpoint")
    parser.add_argument("--val-data", default="", help="Path to validation data. If empty, generate text.")
    parser.add_argument("--val-data-key", default="json.gz")
    parser.add_argument("--input-text", default="", type=str, help="Input text to generate from. If empty, evaluate on validation data.")
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    args = parser.parse_args()

    # The model size is encoded in the checkpoint path as "params=<N>".
    # Check the match explicitly: .group() on a failed search raises an
    # opaque AttributeError on None.
    match = re.search(r"params=(\d+)", args.checkpoint)
    if match is None:
        raise ValueError(
            f"checkpoint path must contain 'params=<N>', got: {args.checkpoint}"
        )
    params = int(match.group(1))

    # map_location="cpu" lets a GPU-saved checkpoint load on any host;
    # the model is moved to CUDA explicitly below.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    # Strip the DDP "module." prefix so keys match a non-distributed model.
    state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}

    model_args = ModelArgs(params=params, val_data=args.val_data, val_data_key=args.val_data_key)
    device = init_distributed_device(model_args)
    model_args.device = device
    model = create_model(model_args)
    model.load_state_dict(state_dict)
    model.eval().cuda()

    if args.val_data != "":
        data = get_data(
            model_args,
            skip_train=True,
        )
        metrics = evaluate_loop(model, data["val_list"], 0, model_args, None)
        print(metrics)
    elif args.input_text != "":
        # Half precision is used for generation only, not evaluation.
        model = model.half()
        generator = Generator(model)
        input_text = [
            args.input_text,
        ]
        output = generator.generate(
            input_text,
            GenerationArgs(args.max_gen_len, args.temperature, args.top_p),
        )
        print("".join(output))
    else:
        print("Please provide either --val-data or --input-text")
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()