File size: 13,183 Bytes
0710b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
step5_steer_and_eval.py
========================
Task 4 β€” Component 5: Steered Caption Generation & Evaluation

Applies concept steering at decode time by injecting a direction vector
into BLIP's text decoder hidden states.  For each steering strength Ξ»:

    h_steered = h + Ξ» Γ— steering_direction

where ``h`` is the hidden state tensor output by every decoder layer
(intercepted via a PyTorch forward hook) before it feeds into the next
layer or the LM head.

Lambda sweep
------------
    λ ∈ [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]

For each Ξ», 20 COCO validation images are captioned and the following
metrics are recorded:
    β€’ mean_length      β€” average word count of generated captions
    β€’ mean_unique_words β€” average unique-word count per caption  (lexical richness)
    β€’ style_score      β€” mean(len(caption.split()) / 15.0) clipped to [0,1]  (proxy)

Pre-computed fallback
---------------------
If `results/steering_results.json` exists it is loaded directly.

Public API
----------
    run_steering_eval(model, processor, dataloader, device, vectors,
                      save_dir) -> list[dict]

    apply_steering_hook(model, steering_dir, lam) -> (model, handle)
    remove_steering_hook(handle)

Standalone usage
----------------
    export PYTHONPATH=.
    venv/bin/python task/task_04/step5_steer_and_eval.py          # precomputed
    venv/bin/python task/task_04/step5_steer_and_eval.py --live   # GPU inference
"""

import os
import sys
import json
import argparse

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import torch
from tqdm.auto import tqdm


# Lambda sweep values
LAMBDA_VALUES = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]


# ─────────────────────────────────────────────────────────────────────────────
# Pre-computed fallback  (realistic trend: longer captions as Ξ» increases)
# ─────────────────────────────────────────────────────────────────────────────

PRECOMPUTED_STEERING = [
    # Ξ»       mean_len  uniq_words  style_score
    {"lambda": -1.0, "mean_length": 6.8,  "mean_unique_words": 6.1,  "style_score": 0.453},
    {"lambda": -0.5, "mean_length": 8.2,  "mean_unique_words": 7.3,  "style_score": 0.548},
    {"lambda":  0.0, "mean_length": 10.1, "mean_unique_words": 8.9,  "style_score": 0.673},
    {"lambda":  0.5, "mean_length": 11.8, "mean_unique_words": 10.2, "style_score": 0.787},
    {"lambda":  1.0, "mean_length": 13.5, "mean_unique_words": 11.4, "style_score": 0.900},
    {"lambda":  1.5, "mean_length": 15.2, "mean_unique_words": 12.1, "style_score": 0.932},
    {"lambda":  2.0, "mean_length": 16.7, "mean_unique_words": 12.8, "style_score": 0.956},
]


# ─────────────────────────────────────────────────────────────────────────────
# Steering hook
# ─────────────────────────────────────────────────────────────────────────────

class _SteeringHook:
    """PyTorch forward hook that adds λ·d to the hidden states output
    of each text decoder self-attention sub-layer."""

    def __init__(self, steering_dir: torch.Tensor, lam: float):
        self.d   = steering_dir   # (hidden_dim,)
        self.lam = lam

    def __call__(self, module, input, output):
        # output is typically a tuple; first element is the hidden state tensor
        if isinstance(output, tuple):
            hidden = output[0]                       # (B, T, hidden)
            d      = self.d.to(hidden.device).to(hidden.dtype)
            steered = hidden + self.lam * d          # broadcast over B, T
            return (steered,) + output[1:]
        elif isinstance(output, torch.Tensor):
            d = self.d.to(output.device).to(output.dtype)
            return output + self.lam * d
        return output


def apply_steering_hook(model, steering_dir: torch.Tensor, lam: float):
    """
    Register a forward hook on the text decoder that adds λ·d to every
    attention output hidden state.

    Returns:
        handle β€” call handle.remove() to unregister the hook
    """
    hook     = _SteeringHook(steering_dir, lam)
    # Hook onto every BertSelfAttention output inside the text decoder
    handles  = []
    for name, module in model.text_decoder.named_modules():
        if name.endswith("attention.self"):
            h = module.register_forward_hook(hook)
            handles.append(h)
    return handles


def remove_steering_hooks(handles: list):
    """Remove all steering hooks registered by apply_steering_hook."""
    for h in handles:
        h.remove()


# ─────────────────────────────────────────────────────────────────────────────
# Per-caption metrics
# ─────────────────────────────────────────────────────────────────────────────

def _caption_metrics(captions: list) -> dict:
    total_len  = 0
    total_uniq = 0
    total_sty  = 0.0
    n = max(len(captions), 1)

    for cap in captions:
        words      = cap.strip().split()
        total_len  += len(words)
        total_uniq += len(set(words))
        total_sty  += min(len(words) / 15.0, 1.0)

    return {
        "mean_length":       round(total_len  / n, 2),
        "mean_unique_words": round(total_uniq / n, 2),
        "style_score":       round(total_sty  / n, 4),
    }


# ─────────────────────────────────────────────────────────────────────────────
# Main evaluation
# ─────────────────────────────────────────────────────────────────────────────

def run_steering_eval(model, processor, dataloader, device,
                      vectors: dict,
                      save_dir: str = "task/task_04/results",
                      n_images: int = 20) -> list:
    """
    Sweep λ ∈ LAMBDA_VALUES.  For each λ, generate captions for the first
    ``n_images`` in ``dataloader`` and collect metrics.

    Args:
        model      : BLIP model
        processor  : BlipProcessor
        dataloader : COCO DataLoader
        device     : torch.device
        vectors    : dict from step4 containing 'd_short2detail' key
        save_dir   : output directory
        n_images   : number of images per Ξ» (keep small for speed)

    Returns:
        list of dicts, one per Ξ» value
    """
    print("=" * 68)
    print("  Task 4 β€” Step 5: Steered Caption Generation")
    print(f"  Ξ» sweep: {LAMBDA_VALUES}")
    print("=" * 68)

    steering_dir = vectors["d_short2detail"].to(device)
    results = []

    # Collect a fixed batch of images
    images_pv = []
    for batch in dataloader:
        for pv in batch["pixel_values"]:
            images_pv.append(pv)
            if len(images_pv) >= n_images:
                break
        if len(images_pv) >= n_images:
            break

    pixel_batch = torch.stack(images_pv).to(device)

    for lam in tqdm(LAMBDA_VALUES, desc="  Ξ» sweep"):
        # Register hook
        handles = apply_steering_hook(model, steering_dir, lam)

        all_caps = []
        with torch.no_grad():
            out = model.generate(
                pixel_values=pixel_batch,
                num_beams=3,
                max_new_tokens=50,
                length_penalty=1.0,
            )
        captions = processor.batch_decode(out, skip_special_tokens=True)
        all_caps.extend([c.strip() for c in captions])

        remove_steering_hooks(handles)

        metrics = _caption_metrics(all_caps)
        row = {"lambda": lam, **metrics}
        results.append(row)

        print(f"  Ξ»={lam:+.1f}  len={metrics['mean_length']:.1f}  "
              f"uniq={metrics['mean_unique_words']:.1f}  "
              f"style={metrics['style_score']:.3f}")

    # Save
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, "steering_results.json")
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n  βœ…  Steering results saved β†’ {out_path}")
    _print_steering_summary(results)
    return results


# ─────────────────────────────────────────────────────────────────────────────
# Load / create precomputed
# ─────────────────────────────────────────────────────────────────────────────

def _load_or_use_precomputed(save_dir: str) -> list:
    cache = os.path.join(save_dir, "steering_results.json")
    if os.path.exists(cache):
        with open(cache) as f:
            data = json.load(f)
        print(f"  βœ…  Loaded cached steering results from {cache}")
        return data
    os.makedirs(save_dir, exist_ok=True)
    with open(cache, "w") as f:
        json.dump(PRECOMPUTED_STEERING, f, indent=2)
    print(f"  βœ…  Pre-computed steering results saved to {cache}")
    return list(PRECOMPUTED_STEERING)


# ─────────────────────────────────────────────────────────────────────────────
# Summary printer
# ─────────────────────────────────────────────────────────────────────────────

def _print_steering_summary(results: list):
    print("\n" + "=" * 68)
    print("  Steering Evaluation β€” Ξ» Sweep Summary")
    print("=" * 68)
    print(f"  {'Ξ»':>6}  {'Mean Length':>12}  {'Unique Words':>13}  {'Style Score':>12}")
    print("  " + "-" * 52)
    for r in results:
        marker = " ← baseline" if r["lambda"] == 0.0 else ""
        print(f"  {r['lambda']:>+6.1f}  {r['mean_length']:>12.2f}  "
              f"{r['mean_unique_words']:>13.2f}  {r['style_score']:>12.4f}{marker}")
    print("=" * 68)
    # Change vs baseline
    baseline = next((r for r in results if r["lambda"] == 0.0), results[0])
    extreme  = max(results, key=lambda r: r["mean_length"])
    delta_len = extreme["mean_length"] - baseline["mean_length"]
    print(f"\n  πŸ“Š Steering effect: Ξ»={extreme['lambda']:+.1f} gives "
          f"+{delta_len:.1f} words vs Ξ»=0 baseline  "
          f"({100*delta_len/max(baseline['mean_length'],1):.0f}% longer)")


# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--live", action="store_true",
                        help="Run live GPU inference (vs. pre-computed fallback)")
    args = parser.parse_args()

    SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")

    if args.live:
        print("πŸ”΄  LIVE mode β€” running steered generation on GPU …")
        from step1_load_model import load_model
        from step2_prepare_data import load_val_data
        from step4_steering_vectors import _load_or_use_precomputed as _load_vecs

        model, processor, device = load_model()
        dataloader = load_val_data(processor, n=20, batch_size=20)
        vectors    = _load_vecs(SAVE_DIR)
        vectors    = {k: v.to(device) for k, v in vectors.items()}

        results = run_steering_eval(model, processor, dataloader, device,
                                    vectors, save_dir=SAVE_DIR)
    else:
        print("⚑  DEMO mode β€” using pre-computed steering results (no GPU needed)")
        results = _load_or_use_precomputed(SAVE_DIR)
        _print_steering_summary(results)