|
|
import math |
|
|
import argparse |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from model.gpt_model import GPTModel |
|
|
from data.dataset import TextDataset |
|
|
from data import utils |
|
|
|
|
|
def _build_arg_parser():
    """Return the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate a trained OpenGPT model on a validation set.")
    parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint (.pt file).")
    parser.add_argument("--config", type=str, required=True, help="Path to the model config file (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to the trained tokenizer (.json or directory).")
    # Optional; defaults to the previous hard-coded value of 1. Values > 1 require
    # the dataset items to be stackable by the default DataLoader collate function.
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Number of sequences per evaluation batch (default: 1).")
    return parser


def _sum_token_nll(model, loader, device, ignore_index=-100):
    """Return ``(total_nll, n_tokens)`` accumulated over every batch in *loader*.

    Each batch is an ``(inputs, targets)`` pair. ``model(inputs)`` is assumed to
    return logits whose last dimension is the number of classes (vocabulary);
    TODO confirm against GPTModel.forward. Target positions equal to
    *ignore_index* contribute neither loss nor a token count, matching
    ``F.cross_entropy``'s ``ignore_index`` semantics.
    """
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)

            # reshape (not view) so non-contiguous logits are handled, and take the
            # class count from the logits themselves instead of trusting the config.
            logits = outputs.reshape(-1, outputs.size(-1))
            flat_targets = targets.reshape(-1)

            loss = F.cross_entropy(logits, flat_targets,
                                   ignore_index=ignore_index, reduction="sum")
            total_loss += loss.item()
            # Count only the positions that actually contributed to the summed loss.
            total_tokens += int((flat_targets != ignore_index).sum().item())
    return total_loss, total_tokens


def main():
    """Evaluate a checkpoint on the validation data and print its perplexity.

    The model hyper-parameters are read from the ``model`` section of
    ``--config``. The data settings (``valid_path``, falling back to
    ``train_path``, and ``block_size``) are read from its ``data`` section.
    Weights are restored from ``--model``. The printed value is
    ``exp(total NLL / number of scored target tokens)``.

    Exits with an error message when no data path is configured or when the
    evaluation set yields no target tokens.
    """
    import sys  # used only to route warnings to stderr

    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    data_conf = config.get("data", {})

    # vocab_size is mandatory; the remaining hyper-parameters fall back to these defaults.
    vocab_size = model_conf["vocab_size"]
    max_pos = model_conf.get("max_position_embeddings", 512)
    hidden_dim = model_conf.get("embedding_dim", 768)
    n_layers = model_conf.get("n_layers", 12)
    n_heads = model_conf.get("n_heads", 12)
    dropout = model_conf.get("dropout", 0.0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTModel(vocab_size=vocab_size, max_position_embeddings=max_pos,
                     n_layers=n_layers, n_heads=n_heads, hidden_dim=hidden_dim,
                     dropout=dropout).to(device)
    utils.load_checkpoint(model, optimizer=None, filepath=args.model, device=device)
    # Set eval mode after restoring weights so the restore path cannot leave
    # dropout (or other train-only behaviour) enabled.
    model.eval()

    valid_path = data_conf.get("valid_path")
    if valid_path is None:
        valid_path = data_conf.get("train_path")
        if valid_path is None:
            raise SystemExit("Config 'data' section defines neither 'valid_path' nor 'train_path'.")
        # Still allowed, but the score no longer reflects held-out data, so say so.
        print(f"Warning: no data.valid_path in config; evaluating on train_path {valid_path!r}.",
              file=sys.stderr)
    block_size = data_conf.get("block_size", 128)

    dataset = TextDataset(valid_path, args.tokenizer, block_size)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    total_loss, total_tokens = _sum_token_nll(model, loader, device)
    if total_tokens == 0:
        raise SystemExit("Validation set produced no target tokens; cannot compute perplexity.")

    avg_nll = total_loss / total_tokens
    try:
        perplexity = math.exp(avg_nll)
    except OverflowError:
        # exp() overflows a double once avg_nll exceeds ~709.78; report infinity instead of crashing.
        perplexity = float("inf")

    print(f"Validation Perplexity: {perplexity:.4f}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the evaluation and hand main()'s return value
    # (None, i.e. exit status 0) to the interpreter as the process exit status.
    raise SystemExit(main())
|
|
|