# File size: 3,308 Bytes | b84d85a
import os
import sys
import argparse
import logging
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from codsworth.config import CodsworthConfig
from codsworth.model import CodsworthTransformer
from codsworth.tokenizer import Tokenizer
from codsworth.utils import setup_logging, load_checkpoint, get_device
from codsworth.train.dataset import CodsworthDataset, CodsworthDataLoader
from codsworth.eval.evaluator import Evaluator
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, which matches the original behavior.

    Returns:
        An ``argparse.Namespace`` with the attributes ``checkpoint``,
        ``tokenizer``, ``eval_files``, ``output``, ``context_length``,
        ``batch_size``, ``max_batches``, ``device``, ``dtype`` and
        ``log_level``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing,
            a value fails type conversion, or ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="Evaluate Codsworth model")

    # Required inputs.
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to tokenizer")

    # Evaluation data and result destination.
    parser.add_argument(
        "--eval_files",
        type=str,
        nargs="+",
        default=["data/val/*.txt"],
        help="Evaluation files passed to the dataset",
    )
    parser.add_argument("--output", type=str, default=None, help="Output file for results")

    # Evaluation sizing.
    parser.add_argument(
        "--context_length",
        type=int,
        default=2048,
        help="Context length used for the model config and the dataset",
    )
    parser.add_argument("--batch_size", type=int, default=16, help="Evaluation batch size")
    parser.add_argument("--max_batches", type=int, default=None, help="Max batches to evaluate")

    # Runtime selection. The help text lists fp16 because main() maps it
    # to torch.float16 alongside bf16 and fp32.
    parser.add_argument("--device", type=str, default=None, help="Device to use")
    parser.add_argument("--dtype", type=str, default="bf16", help="Data type (bf16/fp16/fp32)")
    parser.add_argument("--log_level", type=str, default="INFO", help="Logging level")

    return parser.parse_args(argv)
def main():
    """Evaluate a Codsworth checkpoint and report the resulting metrics.

    Steps, in order: parse the CLI arguments, set up logging, pick the
    device, load the tokenizer, build the model and load the checkpoint,
    move the model to the chosen device/dtype and switch it to eval mode,
    build the evaluation dataset and loader, run the evaluator, log the
    ``loss``, ``perplexity`` and ``total_tokens`` entries of the results,
    and, when ``--output`` is given, write the results to that file as JSON.
    """
    args = parse_args()
    logger = setup_logging(log_level=args.log_level)

    # An explicit --device wins; otherwise defer to the project helper.
    device = get_device() if args.device is None else torch.device(args.device)
    logger.info(f"Using device: {device}")

    tokenizer = Tokenizer.load(args.tokenizer)
    logger.info(f"Loaded tokenizer from {args.tokenizer}")

    dtype_map = {"bf16": torch.bfloat16, "fp32": torch.float32, "fp16": torch.float16}
    dtype = dtype_map.get(args.dtype)
    if dtype is None:
        # Keep the original bf16 fallback, but make it visible instead of
        # silently evaluating in a precision the user did not ask for.
        logger.warning(
            f"Unknown dtype {args.dtype!r}; falling back to bf16 "
            f"(supported: {', '.join(dtype_map)})"
        )
        dtype = torch.bfloat16

    config = CodsworthConfig(context_length=args.context_length)
    model = CodsworthTransformer(config)
    # The helper's return value is not needed here; presumably it loads the
    # checkpoint state into `model` in place (confirm in codsworth.utils).
    load_checkpoint(model, args.checkpoint, device=device)
    logger.info(f"Loaded checkpoint from {args.checkpoint}")

    model = model.to(device=device, dtype=dtype)
    model.eval()

    eval_dataset = CodsworthDataset(
        file_paths=args.eval_files,
        tokenizer=tokenizer,
        context_length=args.context_length,
        shuffle=False,
    )
    eval_loader = CodsworthDataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
    )
    logger.info(f"Eval dataset size: {len(eval_dataset)}")

    evaluator = Evaluator(model, tokenizer, config, device)
    logger.info("Starting evaluation...")
    results = evaluator.evaluate(eval_loader, max_batches=args.max_batches)

    logger.info("\n=== Evaluation Results ===")
    logger.info(f"Loss: {results['loss']:.4f}")
    logger.info(f"Perplexity: {results['perplexity']:.4f}")
    logger.info(f"Tokens evaluated: {results['total_tokens']:,}")

    if args.output:
        import json

        # Create the destination directory first so a finished evaluation
        # is not lost to a missing folder.
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {args.output}")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()