# NOTE: This file was uploaded via the Hugging Face "upload-large-folder" tool
# (user: shahidul034, commit: c7a6fe6). The upload UI's page text was pasted
# above the module docstring; it is preserved here as a comment so the file
# remains valid Python.
"""
Evaluation and qualitative error analysis helpers for the IMDB Transformer model.
This module is separate from `c1.py` and focuses only on:
- Loading a previously trained model from disk.
- Evaluating it on an IMDB split.
- Inspecting misclassified examples for qualitative error analysis.
"""
from typing import Dict, List, Tuple
import argparse
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from c1 import (
IMDBDataset,
TransformerClassifier,
preprocess_data,
evaluate_model,
load_imdb_texts,
)
# Keep output/checkpoint paths relative to the current working directory.
SAVE_DIR = os.path.join(".", "saved_model")  # directory where c1.py writes checkpoints
MODEL_PATH = os.path.join(SAVE_DIR, "transformer_imdb.pt")  # default (size-agnostic) checkpoint path
def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 5,
) -> None:
    """
    Inspect concrete examples where the model makes mistakes to understand
    *why* it fails and how to improve it.

    Args:
        model: Trained classifier; must accept a batch of token-id tensors
            and return per-class logits of shape (batch, 2).
        texts: Raw review strings (same order as `labels`).
        labels: Ground-truth labels (0=neg, 1=pos).
        vocab: Token-to-index mapping used during training.
        max_len: Maximum sequence length for preprocessing.
        device: Device to run inference on.
        num_examples: Maximum number of misclassified examples to print.

    How to read the output (practical guidance):
    - Start with the true vs. predicted label:
      - For each misclassified review, ask whether the ground-truth label
        actually matches the human-intuitive sentiment. Occasional noisy
        labels are common in IMDB-style datasets.
    - Look at the confidence vector:
      - Very confident but wrong predictions often indicate *systematic bias*
        (e.g., the model over-trusts certain keywords like "great", "worst").
      - Low-confidence errors may simply reflect inherently ambiguous reviews.
    - Scan the text content:
      - Check for **rare or domain-specific words** (brand names, slang,
        technical jargon) that might not appear often enough in training.
      - Look for **negation patterns** ("not good", "hardly bad", "no longer
        terrible") where bag-of-words style cues can mislead attention.
      - Notice **mixed sentiment** or **topic vs. opinion** separation
        (e.g., long plot summary plus a brief opinion at the end).
      - Pay attention to **sarcasm and irony**, which are notoriously hard
        for models relying mostly on local lexical cues.
    - Compare several misclassified examples:
      - If you see many errors with long reviews, consider increasing MAX_LEN
        or using a deeper model.
      - If errors cluster around subtle, low-intensity sentiment, you may need
        more expressive capacity (higher d_model / more layers) or additional
        training data.

    Based on these observations you can propose targeted improvements, such as:
    - Expanding the vocabulary or switching to subword tokenization.
    - Adjusting hyperparameters (sequence length, model size).
    - Incorporating pre-trained language models for richer semantics.
    """
    model.eval()
    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    printed = 0
    # Running offset into `texts`; more robust than batch_idx * batch_size
    # (correct even for a ragged final batch or a custom batch sampler).
    offset = 0
    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            batch_texts = texts[offset:offset + batch_seq.size(0)]
            offset += batch_seq.size(0)
            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.cpu().numpy(),
                preds.cpu().numpy(),
                probs.cpu().numpy(),
            ):
                if true_y != pred_y:
                    printed += 1
                    print("=" * 80)
                    print(f"Misclassified example #{printed}")
                    print(f"True label : {true_y} (0=neg, 1=pos)")
                    print(f"Predicted label: {pred_y}")
                    print(f"Model confidence (class 0, class 1): {prob_vec}")
                    # Bug fix: previously the review text was never shown,
                    # making the qualitative analysis described in the
                    # docstring impossible. Print a truncated snippet so long
                    # reviews do not flood the console.
                    snippet = text if len(text) <= 500 else text[:500] + " ..."
                    print(f"Review text: {snippet}")
                    if printed >= num_examples:
                        print("=" * 80)
                        print(
                            f"Displayed the first {num_examples} misclassified "
                            "examples on this split."
                        )
                        return
    if printed == 0:
        print("No misclassified examples found on this split (perfect accuracy).")
def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Restore a TransformerClassifier, together with its vocabulary and
    training configuration, from a checkpoint produced by `c1.py`.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint file.
        device: Target device; when omitted, CUDA is used if available.

    Returns:
        model: The loaded TransformerClassifier, in eval mode on `device`.
        vocab: Token-to-index mapping used during training.
        config: Hyperparameter/config dictionary stored in the checkpoint.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location ensures tensors land on the requested device even when the
    # checkpoint was written on a different one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    token_to_idx: Dict[str, int] = checkpoint["vocab"]
    hparams: Dict = checkpoint["config"]
    # Rebuild the architecture from the saved hyperparameters before loading
    # the weights, so the state dict shapes line up.
    classifier = TransformerClassifier(
        vocab_size=len(token_to_idx),
        d_model=hparams["d_model"],
        num_heads=hparams["num_heads"],
        num_layers=hparams["num_layers"],
        d_ff=hparams["d_ff"],
        max_len=hparams["max_len"],
    ).to(device)
    classifier.load_state_dict(checkpoint["model_state_dict"])
    classifier.eval()
    return classifier, token_to_idx, hparams
def evaluate_and_analyze_saved_model(
    split: str = "test",
    checkpoint_path: str | None = None,
    model_size: str = "medium",
    num_examples: int = 5,
    device: torch.device | None = None,
) -> None:
    """
    End-to-end helper: load a trained checkpoint, evaluate it on an IMDB
    split, then print a qualitative error analysis for that split.

    Args:
        split: IMDB split name, e.g. "test" or "train".
        checkpoint_path: Explicit checkpoint path; when None, a path is
            derived from `model_size` inside SAVE_DIR.
        model_size: Size suffix used to locate the default checkpoint.
        num_examples: Number of misclassified examples to display.
        device: Inference device; CUDA is auto-selected when available.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fall back to the conventional per-size checkpoint name.
    resolved_ckpt = (
        checkpoint_path
        if checkpoint_path is not None
        else os.path.join(SAVE_DIR, f"transformer_imdb_{model_size}.pt")
    )
    print(f"Loading trained model from: {resolved_ckpt}")
    model, vocab, config = load_trained_model_from_checkpoint(
        checkpoint_path=resolved_ckpt,
        device=device,
    )
    print(f"Evaluating on IMDB '{split}' split...")
    texts, labels = load_imdb_texts(split=split)
    eval_loader = DataLoader(
        IMDBDataset(preprocess_data(texts, vocab, config["max_len"]), labels),
        batch_size=config["batch_size"],
        shuffle=False,
    )
    print("Evaluation metrics:", evaluate_model(model, eval_loader, device))
    print("\nRunning qualitative error analysis...")
    analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=num_examples,
    )
def main():
    """
    Command-line interface for evaluation and analysis utilities.

    Example:
        # Evaluate medium model on IMDB test split and show 5 errors
        python c1_analysis.py --split test --model_size medium --num_examples 5
    """
    parser = argparse.ArgumentParser(
        description="IMDB Transformer evaluation and analysis utilities"
    )
    # CLI surface mirrors the parameters of evaluate_and_analyze_saved_model.
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="IMDB split to evaluate on (e.g., 'test', 'train').",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help=(
            "Optional explicit checkpoint path. If provided, this overrides "
            "--model_size."
        ),
    )
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["small", "medium", "large"],
        default="medium",
        help=(
            "Model size to load from saved checkpoints. Used when --checkpoint "
            "is not provided."
        ),
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=5,
        help="Number of misclassified examples to print in error analysis.",
    )
    cli_args = parser.parse_args()
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    evaluate_and_analyze_saved_model(
        split=cli_args.split,
        checkpoint_path=cli_args.checkpoint,
        model_size=cli_args.model_size,
        num_examples=cli_args.num_examples,
        device=run_device,
    )


if __name__ == "__main__":
    main()