File size: 4,409 Bytes
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Per-sample inspection — print N validation predictions vs. ground truth.

Usage:
    python -m scripts.inspect_predictions \\
        --config configs/base.yaml \\
        --weights models/v1.0.0/model.h5 \\
        --tokenizer-dir models/v1.0.0 \\
        --n-samples 25

This script answers the diagnostic question: *when the model is wrong, **how**
is it wrong?* Each row shows the image filename, the predicted caption, one
reference caption, sentence-level BLEU-4 / ROUGE-L, the prediction length,
the longest repeated-token run, and a set of failure flags
(``empty`` / ``very_short`` / ``repetitive`` / ``under_length``).

Output also lands as ``diagnostics.jsonl`` so the same data can be loaded
into pandas / DuckDB for ad-hoc grouping.
"""

from __future__ import annotations

import random
from pathlib import Path
from typing import cast

import click

from captioning.config import load_config
from captioning.data import load_coco_annotations, make_image_level_splits
from captioning.evaluation import (
    diagnose_many,
    format_diagnostic_row,
    write_diagnostics_jsonl,
)
from captioning.inference import CaptionPredictor
from captioning.inference.predictor import DecodeStrategy
from captioning.preprocessing import preprocess_caption
from captioning.utils import configure_logging, get_logger, set_global_seed

log = get_logger(__name__)


@click.command()
@click.option(
    "--config", "config_path", required=True, type=click.Path(exists=True, path_type=Path)
)
@click.option("--weights", required=True, type=click.Path(exists=True, path_type=Path))
@click.option("--tokenizer-dir", required=True, type=click.Path(exists=True, path_type=Path))
@click.option("--n-samples", type=int, default=25, help="Number of random val samples to inspect.")
@click.option(
    "--decode-strategy",
    type=click.Choice(["greedy", "beam"]),
    default=None,
    help="Override decode strategy for this inspection.",
)
@click.option("--beam-width", type=int, default=None)
@click.option(
    "--output",
    "output_path",
    type=click.Path(path_type=Path),
    default=None,
    help="Optional path to write diagnostics.jsonl. Defaults to printing only.",
)
@click.option(
    "--seed",
    type=int,
    default=None,
    help="RNG seed for sample selection (defaults to config.train.seed).",
)
def main(
    config_path: Path,
    weights: Path,
    tokenizer_dir: Path,
    n_samples: int,
    decode_strategy: str | None,
    beam_width: int | None,
    output_path: Path | None,
    seed: int | None,
) -> None:
    """Sample N val images, run inference, and print prediction diagnostics."""
    configure_logging()
    config = load_config(config_path)
    set_global_seed(config.train.seed)

    df = load_coco_annotations(
        base_path=config.data.base_path,
        annotations_filename=config.data.annotations_filename,
        images_subdir=config.data.images_subdir,
        sample_size=config.data.sample_size,
        seed=config.train.seed,
        caption_preprocessor=preprocess_caption,
    )
    _, _, val_imgs, val_caps = make_image_level_splits(
        df, train_fraction=config.data.train_val_split, seed=config.train.seed
    )
    refs_by_image: dict[str, list[str]] = {}
    for img, cap in zip(val_imgs, val_caps, strict=True):
        refs_by_image.setdefault(img, []).append(cap)

    rng = random.Random(seed if seed is not None else config.train.seed)
    picks = rng.sample(sorted(refs_by_image.keys()), k=min(n_samples, len(refs_by_image)))

    effective_strategy = decode_strategy or config.serve.decode_strategy
    effective_beam_width = beam_width if beam_width is not None else config.serve.beam_width
    predictor = CaptionPredictor.from_artifacts(
        weights_path=weights,
        tokenizer_dir=tokenizer_dir,
        config=config,
        decode_strategy=cast("DecodeStrategy", effective_strategy),
        beam_width=effective_beam_width,
    )
    predictor.warmup()

    predictions = [predictor.predict_path(p) for p in picks]
    references = [refs_by_image[p] for p in picks]
    diagnostics = diagnose_many(picks, predictions, references)

    for d in diagnostics:
        click.echo(format_diagnostic_row(d))
        click.echo("-" * 80)

    if output_path is not None:
        write_diagnostics_jsonl(diagnostics, output_path)
        click.echo(f"Wrote diagnostics: {output_path}")


if __name__ == "__main__":
    main()