File size: 3,308 Bytes
b84d85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import sys
import argparse
import logging
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from codsworth.config import CodsworthConfig
from codsworth.model import CodsworthTransformer
from codsworth.tokenizer import Tokenizer
from codsworth.utils import setup_logging, load_checkpoint, get_device
from codsworth.train.dataset import CodsworthDataset, CodsworthDataLoader
from codsworth.eval.evaluator import Evaluator


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected.

    Returns:
        The ``argparse.Namespace`` holding ``checkpoint``, ``tokenizer``,
        ``eval_files``, ``output``, ``context_length``, ``batch_size``,
        ``max_batches``, ``device``, ``dtype`` and ``log_level``.
    """
    parser = argparse.ArgumentParser(description="Evaluate Codsworth model")

    # Inputs / outputs.
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to tokenizer")
    parser.add_argument("--eval_files", type=str, nargs="+", default=["data/val/*.txt"])
    parser.add_argument("--output", type=str, default=None, help="Output file for results")

    # Evaluation sizing.
    parser.add_argument("--context_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_batches", type=int, default=None, help="Max batches to evaluate")

    # Runtime placement and precision. The help text lists every key that the
    # script's dtype lookup table understands (it previously omitted fp16).
    parser.add_argument("--device", type=str, default=None, help="Device to use")
    parser.add_argument("--dtype", type=str, default="bf16", help="Data type (bf16/fp32/fp16)")

    parser.add_argument("--log_level", type=str, default="INFO")

    return parser.parse_args(argv)


def main():
    """Evaluate a Codsworth checkpoint on the eval files and report metrics.

    Parses the command-line arguments, loads the tokenizer and the model
    checkpoint, builds the evaluation dataset and loader, runs
    ``Evaluator.evaluate`` and logs the ``loss``, ``perplexity`` and
    ``total_tokens`` entries of its result. When ``--output`` is given the
    result is additionally written to that path as JSON.
    """
    args = parse_args()

    logger = setup_logging(log_level=args.log_level)

    device = get_device() if args.device is None else torch.device(args.device)
    logger.info(f"Using device: {device}")

    tokenizer = Tokenizer.load(args.tokenizer)
    logger.info(f"Loaded tokenizer from {args.tokenizer}")

    dtype_map = {"bf16": torch.bfloat16, "fp32": torch.float32, "fp16": torch.float16}
    dtype = dtype_map.get(args.dtype)
    if dtype is None:
        # Keep the historical bf16 fallback, but make it visible: a typo such
        # as "float32" would otherwise silently evaluate in bf16.
        logger.warning(
            f"Unknown dtype '{args.dtype}', falling back to bf16 "
            f"(supported: {', '.join(dtype_map)})"
        )
        dtype = torch.bfloat16

    config = CodsworthConfig(context_length=args.context_length)
    model = CodsworthTransformer(config)

    # The return value of load_checkpoint is not used by this script.
    load_checkpoint(model, args.checkpoint, device=device)
    logger.info(f"Loaded checkpoint from {args.checkpoint}")

    model = model.to(device=device, dtype=dtype)
    model.eval()

    eval_dataset = CodsworthDataset(
        file_paths=args.eval_files,
        tokenizer=tokenizer,
        context_length=args.context_length,
        shuffle=False,
    )

    eval_loader = CodsworthDataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
    )

    logger.info(f"Eval dataset size: {len(eval_dataset)}")

    evaluator = Evaluator(model, tokenizer, config, device)

    logger.info("Starting evaluation...")
    results = evaluator.evaluate(eval_loader, max_batches=args.max_batches)

    logger.info("\n=== Evaluation Results ===")
    logger.info(f"Loss: {results['loss']:.4f}")
    logger.info(f"Perplexity: {results['perplexity']:.4f}")
    logger.info(f"Tokens evaluated: {results['total_tokens']:,}")

    if args.output:
        import json

        # Create the destination directory first so a finished evaluation is
        # not lost to a missing parent directory.
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {args.output}")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()