| """ |
| Evaluation and qualitative error analysis helpers for the IMDB Transformer model. |
| |
| This module is separate from `c1.py` and focuses only on: |
| - Loading a previously trained model from disk. |
| - Evaluating it on an IMDB split. |
| - Inspecting misclassified examples for qualitative error analysis. |
| """ |
|
|
| from typing import Dict, List, Tuple |
|
|
| import argparse |
| import os |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from c1 import ( |
| IMDBDataset, |
| TransformerClassifier, |
| preprocess_data, |
| evaluate_model, |
| load_imdb_texts, |
| ) |
|
|
| |
# Directory (relative to the current working directory) holding saved checkpoints.
SAVE_DIR = os.path.join(os.curdir, "saved_model")
# Default, un-suffixed checkpoint file; used as the default path by
# load_trained_model_from_checkpoint.
MODEL_PATH = os.path.join(SAVE_DIR, "transformer_imdb.pt")
|
|
|
|
def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 5,
    batch_size: int = 64,
    max_text_chars: int | None = 1000,
) -> None:
    """
    Print misclassified examples to understand *why* the model fails.

    Reviews are scored in their original order (no shuffling). For every
    misclassified review, this prints:
    - the true label,
    - the predicted label,
    - the softmax confidence vector,
    - the review text.
    It stops once ``num_examples`` errors have been shown.

    Args:
        model: Trained classifier. It is called as ``model(batch_seq)`` and its
            output is soft-maxed over dim 1, so it is assumed to return logits
            of shape (batch, num_classes).
        texts: Raw review strings, aligned index-by-index with ``labels``.
        labels: Ground-truth labels (0=neg, 1=pos).
        vocab: Token-to-index mapping forwarded to ``preprocess_data``.
        max_len: Sequence length forwarded to ``preprocess_data``.
        device: Device the model lives on; each batch is moved there.
        num_examples: Maximum number of misclassified examples to print.
            Values <= 0 print nothing.
        batch_size: Inference batch size.
        max_text_chars: Truncate each printed review to this many characters;
            ``None`` prints the full text.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.

    How to read the output (practical guidance):
    - Start with the true vs. predicted label:
        - For each misclassified review, ask whether the ground-truth label
          actually matches the human-intuitive sentiment. Occasional noisy
          labels are common in IMDB-style datasets.
    - Look at the confidence vector:
        - Very confident but wrong predictions often indicate *systematic bias*
          (e.g., the model over-trusts certain keywords like "great", "worst").
        - Low-confidence errors may simply reflect inherently ambiguous reviews.
    - Scan the text content:
        - Check for **rare or domain-specific words** (brand names, slang,
          technical jargon) that might not appear often enough in training.
        - Look for **negation patterns** ("not good", "hardly bad", "no longer
          terrible") where bag-of-words style cues can mislead attention.
        - Notice **mixed sentiment** or **topic vs. opinion** separation
          (e.g., long plot summary plus a brief opinion at the end).
        - Pay attention to **sarcasm and irony**, which are notoriously hard
          for models relying mostly on local lexical cues.
    - Compare several misclassified examples:
        - If you see many errors with long reviews, consider increasing MAX_LEN
          or using a deeper model.
        - If errors cluster around subtle, low-intensity sentiment, you may need
          more expressive capacity (higher d_model / more layers) or additional
          training data.

    Based on these observations you can propose targeted improvements, such as:
    - Expanding the vocabulary or switching to subword tokenization.
    - Adjusting hyperparameters (sequence length, model size).
    - Incorporating pre-trained language models for richer semantics.
    """
    if len(texts) != len(labels):
        raise ValueError(
            "texts and labels must have the same length "
            f"(got {len(texts)} and {len(labels)})."
        )
    # Nothing requested: avoid printing a spurious example / "first 0" summary.
    if num_examples <= 0:
        return

    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    # shuffle=False keeps batches in the same order as `texts`, which is what
    # allows each example's raw text to be recovered via a running offset.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    printed = 0
    offset = 0  # index in `texts` of the first example of the next batch
    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_start = offset
            offset += batch_seq.size(0)

            logits = model(batch_seq.to(device))
            probs = F.softmax(logits, dim=1).cpu()
            preds = probs.argmax(dim=1)
            # Flatten so a (batch, 1) label tensor compares element-wise.
            true_labels = batch_lab.cpu().reshape(-1)

            # Visit only the examples the model got wrong, in dataset order.
            wrong_idx = torch.nonzero(preds != true_labels, as_tuple=True)[0]
            for i in wrong_idx.tolist():
                printed += 1
                text = texts[batch_start + i]
                if max_text_chars is not None and len(text) > max_text_chars:
                    text = text[:max_text_chars] + " [...]"

                print("=" * 80)
                print(f"Misclassified example #{printed}")
                print(f"True label : {true_labels[i].item()} (0=neg, 1=pos)")
                print(f"Predicted label: {preds[i].item()}")
                print(f"Model confidence (class 0, class 1): {probs[i].numpy()}")
                print("Review text:")
                print(text)

                if printed >= num_examples:
                    print("=" * 80)
                    print(
                        f"Displayed the first {num_examples} misclassified "
                        "examples on this split."
                    )
                    return

    if printed == 0:
        print("No misclassified examples found on this split (perfect accuracy).")
    else:
        # Fewer errors than requested: close the listing explicitly.
        print("=" * 80)
        print(f"Displayed all {printed} misclassified examples on this split.")
|
|
|
|
def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Load a previously trained Transformer model, its vocabulary and config.

    The checkpoint (saved by `c1.py`) must contain the entries "vocab",
    "config" and "model_state_dict". The config must provide "d_model",
    "num_heads", "num_layers", "d_ff" and "max_len"; these are used to rebuild
    the TransformerClassifier before loading the weights.

    Args:
        checkpoint_path: Path to the checkpoint file. Defaults to MODEL_PATH.
        device: Target device. Defaults to CUDA when available, else CPU.
            Checkpoint tensors are mapped onto it at load time.

    Returns:
        model: Loaded TransformerClassifier on the requested device, in eval mode.
        vocab: Token-to-index mapping used during training.
        config: Hyperparameter/config dictionary saved in the checkpoint.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not point to a file.
        KeyError: If the checkpoint lacks a required entry.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_path!r}. Train and save a model "
            "with c1.py first, or pass the correct checkpoint path."
        )

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)

    required = ("vocab", "config", "model_state_dict")
    missing = [key for key in required if key not in ckpt]
    if missing:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} is missing required entries: {missing}"
        )

    vocab: Dict[str, int] = ckpt["vocab"]
    config: Dict = ckpt["config"]

    # Rebuild the architecture from the saved hyperparameters so the state
    # dict's tensor shapes match.
    model = TransformerClassifier(
        vocab_size=len(vocab),
        d_model=config["d_model"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        d_ff=config["d_ff"],
        max_len=config["max_len"],
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    return model, vocab, config
|
|
|
|
def evaluate_and_analyze_saved_model(
    split: str = "test",
    checkpoint_path: str | None = None,
    model_size: str = "medium",
    num_examples: int = 5,
    device: torch.device | None = None,
) -> None:
    """
    Load a saved model, report metrics on an IMDB split, then show its errors.

    Steps:
        1) Resolve the checkpoint path and load model/vocab/config from it.
           An explicit ``checkpoint_path`` wins; otherwise the path is
           ``SAVE_DIR/transformer_imdb_<model_size>.pt``.
        2) Evaluate the model on ``split`` with ``evaluate_model`` and print
           the returned metrics.
        3) Print up to ``num_examples`` misclassified reviews from the same
           split for qualitative error analysis.

    Args:
        split: IMDB split name passed to ``load_imdb_texts``.
        checkpoint_path: Optional explicit checkpoint file.
        model_size: Size suffix used when ``checkpoint_path`` is None.
        num_examples: Number of misclassified examples to display.
        device: Target device; defaults to CUDA when available, else CPU.
    """
    target_device = (
        device
        if device is not None
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )

    ckpt_file = (
        checkpoint_path
        if checkpoint_path is not None
        else os.path.join(SAVE_DIR, f"transformer_imdb_{model_size}.pt")
    )

    print(f"Loading trained model from: {ckpt_file}")
    model, vocab, config = load_trained_model_from_checkpoint(
        checkpoint_path=ckpt_file,
        device=target_device,
    )

    print(f"Evaluating on IMDB '{split}' split...")
    texts, labels = load_imdb_texts(split=split)
    seq_len = config["max_len"]
    eval_loader = DataLoader(
        IMDBDataset(preprocess_data(texts, vocab, seq_len), labels),
        batch_size=config["batch_size"],
        shuffle=False,
    )
    print("Evaluation metrics:", evaluate_model(model, eval_loader, target_device))

    print("\nRunning qualitative error analysis...")
    analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=seq_len,
        device=target_device,
        num_examples=num_examples,
    )
|
|
|
|
| def main(): |
| """ |
| Command-line interface for evaluation and analysis utilities. |
| |
| Example: |
| # Evaluate medium model on IMDB test split and show 5 errors |
| python c1_analysis.py --split test --model_size medium --num_examples 5 |
| """ |
| parser = argparse.ArgumentParser(description="IMDB Transformer evaluation and analysis utilities") |
| parser.add_argument( |
| "--split", |
| type=str, |
| default="test", |
| help="IMDB split to evaluate on (e.g., 'test', 'train').", |
| ) |
| parser.add_argument( |
| "--checkpoint", |
| type=str, |
| default=None, |
| help=( |
| "Optional explicit checkpoint path. If provided, this overrides " |
| "--model_size." |
| ), |
| ) |
| parser.add_argument( |
| "--model_size", |
| type=str, |
| choices=["small", "medium", "large"], |
| default="medium", |
| help=( |
| "Model size to load from saved checkpoints. Used when --checkpoint " |
| "is not provided." |
| ), |
| ) |
| parser.add_argument( |
| "--num_examples", |
| type=int, |
| default=5, |
| help="Number of misclassified examples to print in error analysis.", |
| ) |
| args = parser.parse_args() |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| evaluate_and_analyze_saved_model( |
| split=args.split, |
| checkpoint_path=args.checkpoint, |
| model_size=args.model_size, |
| num_examples=args.num_examples, |
| device=device, |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|
|
|