"""
Evaluation and qualitative error analysis helpers for the IMDB Transformer model.

This module is separate from `c1.py` and focuses only on:
- Loading a previously trained model from disk.
- Evaluating it on an IMDB split.
- Inspecting misclassified examples for qualitative error analysis.
"""

import argparse
import os
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from c1 import (
    IMDBDataset,
    TransformerClassifier,
    preprocess_data,
    evaluate_model,
    load_imdb_texts,
)

# Keep output/checkpoint paths relative to the current working directory.
SAVE_DIR = os.path.join(".", "saved_model")
# Default path used by `load_trained_model_from_checkpoint`. Note that
# `evaluate_and_analyze_saved_model` builds a size-specific path
# (`transformer_imdb_<size>.pt`) instead when no explicit checkpoint is given.
MODEL_PATH = os.path.join(SAVE_DIR, "transformer_imdb.pt")


def _truncate_text(text: str, max_chars: int | None) -> str:
    """Return the first `max_chars` characters of `text` plus " ..." if it was cut; None disables truncation."""
    if max_chars is None or len(text) <= max_chars:
        return text
    return text[:max_chars] + " ..."


def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 5,
    batch_size: int = 64,
    max_text_chars: int | None = 1000,
) -> None:
    """
    Print concrete examples where the model makes mistakes, to understand
    *why* it fails and how to improve it.

    For each misclassified review this prints the true label, the predicted
    label, the softmax probability vector and the review text itself.

    Args:
        model: Trained classifier; it is switched to eval mode.
        texts: Raw review strings, aligned index-by-index with `labels`.
        labels: Ground-truth labels (0=neg, 1=pos).
        vocab: Token-to-index mapping used by `preprocess_data`.
        max_len: Sequence length passed to `preprocess_data`.
        device: Device on which inference runs.
        num_examples: Maximum number of misclassified examples to print.
            Values <= 0 print nothing.
        batch_size: Inference batch size (does not affect results).
        max_text_chars: Truncate printed review text to this many characters;
            None prints the full text.

    Raises:
        ValueError: If `texts` and `labels` have different lengths.

    How to read the output (practical guidance):
    - Start with the true vs. predicted label:
        - For each misclassified review, ask whether the ground-truth label
          actually matches the human-intuitive sentiment. Occasional noisy
          labels are common in IMDB-style datasets.
    - Look at the confidence vector:
        - Very confident but wrong predictions often indicate *systematic
          bias* (e.g., the model over-trusts certain keywords like "great",
          "worst").
        - Low-confidence errors may simply reflect inherently ambiguous
          reviews.
    - Scan the text content:
        - Check for **rare or domain-specific words** (brand names, slang,
          technical jargon) that might not appear often enough in training.
        - Look for **negation patterns** ("not good", "hardly bad",
          "no longer terrible") where bag-of-words style cues can mislead
          attention.
        - Notice **mixed sentiment** or **topic vs. opinion** separation
          (e.g., a long plot summary plus a brief opinion at the end).
        - Pay attention to **sarcasm and irony**, which are notoriously hard
          for models relying mostly on local lexical cues.
    - Compare several misclassified examples:
        - If you see many errors with long reviews, consider increasing
          MAX_LEN or using a deeper model.
        - If errors cluster around subtle, low-intensity sentiment, you may
          need more expressive capacity (higher d_model / more layers) or
          additional training data.

    Based on these observations you can propose targeted improvements, such as:
    - Expanding the vocabulary or switching to subword tokenization.
    - Adjusting hyperparameters (sequence length, model size).
    - Incorporating pre-trained language models for richer semantics.
    """
    if len(texts) != len(labels):
        # zip() below would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"texts and labels must have the same length "
            f"(got {len(texts)} and {len(labels)})."
        )
    if num_examples <= 0:
        return

    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    # shuffle=False keeps batches in the same order as `texts`, which the
    # running offset below relies on to recover each item's raw text.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    printed = 0
    offset = 0  # index in `texts` of the first item of the current batch
    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            batch_len = batch_seq.size(0)
            batch_texts = texts[offset:offset + batch_len]
            offset += batch_len

            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.cpu().numpy(),
                preds.cpu().numpy(),
                probs.cpu().numpy(),
            ):
                if true_y == pred_y:
                    continue

                printed += 1
                print("=" * 80)
                print(f"Misclassified example #{printed}")
                print(f"True label : {true_y} (0=neg, 1=pos)")
                print(f"Predicted label: {pred_y}")
                print(f"Model confidence (class 0, class 1): {prob_vec}")
                print("Review text:")
                print(_truncate_text(text, max_text_chars))

                if printed >= num_examples:
                    print("=" * 80)
                    print(
                        f"Displayed the first {num_examples} misclassified "
                        "examples on this split."
                    )
                    return

    if printed == 0:
        print("No misclassified examples found on this split (perfect accuracy).")


def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Load a previously trained Transformer model, along with its vocabulary and
    configuration, from the checkpoint saved by `c1.py`.

    The checkpoint is expected to be a dict with keys "vocab", "config" and
    "model_state_dict"; "config" must provide d_model, num_heads, num_layers,
    d_ff and max_len.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        model: Loaded TransformerClassifier on the requested device, in eval mode.
        vocab: Token-to-index mapping used during training.
        config: Hyperparameter/config dictionary saved in the checkpoint.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (consider weights_only=True on recent PyTorch versions).
    ckpt = torch.load(checkpoint_path, map_location=device)
    vocab: Dict[str, int] = ckpt["vocab"]
    config: Dict = ckpt["config"]

    model = TransformerClassifier(
        vocab_size=len(vocab),
        d_model=config["d_model"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        d_ff=config["d_ff"],
        max_len=config["max_len"],
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    return model, vocab, config


def evaluate_and_analyze_saved_model(
    split: str = "test",
    checkpoint_path: str | None = None,
    model_size: str = "medium",
    num_examples: int = 5,
    device: torch.device | None = None,
) -> None:
    """
    High-level helper that:
    1) Loads the trained model/vocab/config from disk.
    2) Evaluates it on the requested IMDB split and prints the metrics.
    3) Runs qualitative error analysis on that split.

    Args:
        split: IMDB split name passed to `load_imdb_texts`.
        checkpoint_path: Explicit checkpoint path; when None, the path
            `SAVE_DIR/transformer_imdb_<model_size>.pt` is used.
        model_size: Size tag used to build the default checkpoint path.
        num_examples: Number of misclassified examples to print.
        device: Target device; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if checkpoint_path is None:
        checkpoint_path = os.path.join(SAVE_DIR, f"transformer_imdb_{model_size}.pt")

    print(f"Loading trained model from: {checkpoint_path}")
    model, vocab, config = load_trained_model_from_checkpoint(
        checkpoint_path=checkpoint_path,
        device=device,
    )

    # NOTE(review): assumes the saved config records the training batch size;
    # fall back to 64 (the error-analysis default) when it does not.
    batch_size = config.get("batch_size", 64)

    print(f"Evaluating on IMDB '{split}' split...")
    texts, labels = load_imdb_texts(split=split)
    sequences = preprocess_data(texts, vocab, config["max_len"])
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    metrics = evaluate_model(model, loader, device)
    print("Evaluation metrics:", metrics)

    print("\nRunning qualitative error analysis...")
    analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=num_examples,
        batch_size=batch_size,
    )


def main():
    """
    Command-line interface for evaluation and analysis utilities.

    Example:
        # Evaluate medium model on IMDB test split and show 5 errors
        python c1_analysis.py --split test --model_size medium --num_examples 5
    """
    parser = argparse.ArgumentParser(description="IMDB Transformer evaluation and analysis utilities")
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="IMDB split to evaluate on (e.g., 'test', 'train').",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=(
            "Optional explicit checkpoint path. If provided, this overrides "
            "--model_size."
        ),
    )
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["small", "medium", "large"],
        default="medium",
        help=(
            "Model size to load from saved checkpoints. Used when --checkpoint "
            "is not provided."
        ),
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=5,
        help="Number of misclassified examples to print in error analysis.",
    )

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    evaluate_and_analyze_saved_model(
        split=args.split,
        checkpoint_path=args.checkpoint,
        model_size=args.model_size,
        num_examples=args.num_examples,
        device=device,
    )


if __name__ == "__main__":
    main()