Spaces:
Sleeping
Sleeping
"""
step6_visualize.py
==================
Task 4 — Component 6: Generate publication-quality figures.

Produces three figures from the Task 4 results:

Figure 1: diversity_histogram.png
    Histogram of per-image diversity scores (200 images).
    "Diverse" (>0.75) and "Repetitive" (<0.40) zones are shaded.

Figure 2: diverse_vs_repetitive.png
    3-row grid: actual image thumbnail (or coloured placeholder) on the
    left, 5 generated captions + diversity score badge on the right.
    Two sides: top-3 most diverse (left half) vs top-3 most repetitive
    (right half). Thumbnails loaded from results/images/img_{id}.jpg.

Figure 3: steering_lambda_sweep.png
    Dual-axis line chart: λ on the x-axis, mean caption length (left y-axis)
    and mean unique word count (right y-axis). λ=0 baseline annotated.

Public API
----------
    plot_diversity_histogram(records, save_dir) -> str (path)
    plot_diverse_vs_repetitive(records, save_dir) -> str
    plot_steering_lambda_sweep(steering_results, save_dir) -> str
    visualize_all(records, steering_results, save_dir) -> dict[str, str]

Standalone usage
----------------
    export PYTHONPATH=.
    venv/bin/python task/task_04/step6_visualize.py
"""
import os
import sys
import json

# Prepend the directory three levels above this file to the import path. For
# task/task_04/step6_visualize.py that is presumably the repository root
# (the usage line runs from there with PYTHONPATH=.) — confirm if moved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import numpy as np
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported: every
# figure in this module is written to a PNG file, so no display is required.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D

# ─────────────────────────────────────────────────────────────────────────────
# Colour palette
# ─────────────────────────────────────────────────────────────────────────────
C_DIVERSE = "#4C72B0"     # blue   - "diverse" zone/threshold, diverse-side badges
C_MEDIUM = "#55A868"      # green  - histogram bars
C_REPETITIVE = "#C44E52"  # red    - "repetitive" zone/threshold, repetitive-side badges
C_LAMBDA = "#DD8452"      # orange - mean caption length line in the lambda sweep
C_UNIQ = "#8172B2"        # purple - mean unique word count line in the lambda sweep
# ─────────────────────────────────────────────────────────────────────────────
# Figure 1 — Diversity histogram
# ─────────────────────────────────────────────────────────────────────────────
def plot_diversity_histogram(records: list,
                             save_dir: str = "task/task_04/results",
                             *,
                             repetitive_below: float = 0.40,
                             diverse_above: float = 0.75,
                             n_bins: int = 20) -> str:
    """
    Plot a histogram of per-image diversity scores and save it as a PNG.

    The score axis is fixed to [0, 1] and two zones are shaded:
      "Repetitive"  score < ``repetitive_below``  -> shaded red
      "Diverse"     score > ``diverse_above``     -> shaded blue
    A dotted line marks the mean score (omitted when there are no scores),
    and the number of images in each zone is printed inside the zone.

    Args:
        records: sequence of dicts, each with a ``"diversity_score"`` key.
        save_dir: output directory; created if it does not exist.
        repetitive_below: upper bound of the repetitive zone (default 0.40).
        diverse_above: lower bound of the diverse zone (default 0.75).
        n_bins: number of equal-width bins spanning [0, 1]. The default of 20
            (width 0.05) puts both default thresholds on bin edges. Scores
            outside [0, 1] are not binned.

    Returns:
        Path of the saved ``diversity_histogram.png`` (joined onto save_dir).

    Raises:
        ValueError: if the thresholds are not ordered within [0, 1], or if
            ``n_bins`` < 1.
    """
    if not 0.0 <= repetitive_below <= diverse_above <= 1.0:
        raise ValueError(
            "expected 0 <= repetitive_below <= diverse_above <= 1, got "
            f"{repetitive_below!r} and {diverse_above!r}")
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins!r}")

    os.makedirs(save_dir, exist_ok=True)
    scores = [r["diversity_score"] for r in records]

    fig, ax = plt.subplots(figsize=(9, 5))

    # Shaded zones (labels render as "<0.40" / ">0.75" with the defaults)
    ax.axvspan(0.0, repetitive_below, alpha=0.12, color=C_REPETITIVE,
               label=f"Repetitive zone (<{repetitive_below:.2f})")
    ax.axvspan(diverse_above, 1.00, alpha=0.12, color=C_DIVERSE,
               label=f"Diverse zone (>{diverse_above:.2f})")

    # Histogram. Bin edges are pinned to the fixed [0, 1] x-range instead of
    # the data's min/max, so bar widths are comparable between runs and
    # bars do not straddle the shaded zone boundaries.
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ax.hist(scores, bins=bin_edges, color=C_MEDIUM, edgecolor="white",
            linewidth=0.6, alpha=0.9, label="Image count")

    # Threshold lines
    ax.axvline(repetitive_below, color=C_REPETITIVE, linewidth=1.8, linestyle="--")
    ax.axvline(diverse_above, color=C_DIVERSE, linewidth=1.8, linestyle="--")

    # Mean line -- np.mean([]) warns and yields nan, so skip it when empty.
    if scores:
        mean_score = float(np.mean(scores))
        ax.axvline(mean_score, color="#333", linewidth=1.4, linestyle=":",
                   label=f"Mean = {mean_score:.3f}")

    ax.set_xlabel("Diversity Score (unique n-grams / total n-grams)", fontsize=12)
    ax.set_ylabel("Number of Images", fontsize=12)
    ax.set_title("Caption Diversity Distribution Across COCO Images\n"
                 "(5 nucleus-sampled captions per image, p=0.9)",
                 fontsize=13, fontweight="bold", pad=10)
    ax.legend(fontsize=10)
    ax.yaxis.set_minor_locator(mticker.AutoMinorLocator())
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.set_xlim(0, 1)

    # Zone counts, centred horizontally inside each shaded zone.
    n_rep = sum(1 for s in scores if s < repetitive_below)
    n_div = sum(1 for s in scores if s > diverse_above)
    y_text = ax.get_ylim()[1] * 0.85
    ax.text(repetitive_below / 2, y_text, f"{n_rep} images\n(repetitive)",
            ha="center", va="top", color=C_REPETITIVE, fontsize=9)
    ax.text((diverse_above + 1.0) / 2, y_text, f"{n_div} images\n(diverse)",
            ha="center", va="top", color=C_DIVERSE, fontsize=9)

    fig.tight_layout()
    path = os.path.join(save_dir, "diversity_histogram.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" OK Saved: {path}")
    return path
# ─────────────────────────────────────────────────────────────────────────────
# Figure 2 — Diverse vs Repetitive (image thumbnail + captions)
# ─────────────────────────────────────────────────────────────────────────────
def _load_thumb_array(img_id: int, save_dir: str):
    """
    Load ``<save_dir>/images/img_<img_id>.jpg`` as an RGB numpy array.

    Returns:
        ``np.array`` of the RGB-converted image, or None when the file does
        not exist or cannot be read. None tells the caller to draw a
        placeholder instead of a thumbnail.
    """
    path = os.path.join(save_dir, "images", f"img_{img_id}.jpg")
    if not os.path.isfile(path):
        return None
    try:
        # Imported lazily: Pillow is only needed when a thumbnail actually
        # exists, so the placeholder fallback works without it installed.
        from PIL import Image as PILImage
        with PILImage.open(path) as im:
            return np.array(im.convert("RGB"))
    except Exception as exc:  # best-effort: any failure falls back to a placeholder
        print(f"  WARN could not load thumbnail {path}: {exc}")
        return None
def plot_diverse_vs_repetitive(records: list,
                               save_dir: str = "task/task_04/results") -> str:
    """
    Save a 3-row image/caption grid contrasting caption-diversity extremes.

    Each row shows ``[thumbnail | up to 5 captions]`` for one image. The left
    half holds the highest-scoring (most diverse) images, the right half the
    lowest-scoring (most repetitive) ones. Records whose caption lists are
    identical are shown only once per side.

    Thumbnails come from ``<save_dir>/images/img_<image_id>.jpg`` (generated
    by step3). When a file is missing or unreadable, a coloured placeholder
    labelled with the image id and score is drawn instead.

    Args:
        records: sequence of dicts with ``"image_id"``, ``"diversity_score"``
            and ``"captions"`` (list of str); ``"category"`` is optional.
        save_dir: output directory; created if it does not exist.

    Returns:
        Path of the saved ``diverse_vs_repetitive.png`` (joined onto save_dir).
    """
    import textwrap  # local: only this figure wraps caption text

    os.makedirs(save_dir, exist_ok=True)

    n_rows = 3            # images shown per side
    max_captions = 5      # captions listed per image
    wrap_width = 52       # max characters per wrapped caption line
    line_step = 0.13      # preferred vertical step per text line (axes fraction)
    caption_gap = 0.04    # extra vertical gap after each caption (axes fraction)
    top_y = 0.97          # y of the first caption line (axes fraction)

    def _get_top_unique(recs, reverse=True, n=n_rows):
        """Return up to `n` records ordered by score, skipping repeated caption sets."""
        ranked = sorted(recs, key=lambda rec: rec["diversity_score"], reverse=reverse)
        seen = set()
        picked = []
        for rec in ranked:
            key = tuple(rec["captions"])
            if key in seen:
                continue
            seen.add(key)
            picked.append(rec)
            if len(picked) == n:
                break
        return picked

    def _hide_axes_keep_face(ax):
        """Remove ticks and spines but keep the axes background visible.

        ``ax.axis("off")`` also stops the background patch from being drawn,
        which made ``set_facecolor`` on these cells have no visible effect.
        """
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    diverse = _get_top_unique(records, reverse=True)
    repetitive = _get_top_unique(records, reverse=False)

    fig = plt.figure(figsize=(17, 12), facecolor="#F4F4F4")
    # 4 columns: [img_div | cap_div | img_rep | cap_rep]
    gs = GridSpec(n_rows, 4, figure=fig,
                  hspace=0.60, wspace=0.10,
                  left=0.03, right=0.97,
                  top=0.90, bottom=0.03,
                  width_ratios=[1, 2.2, 1, 2.2])

    def _render_row(row, rec, img_col, cap_col, badge_color):
        """Draw the thumbnail (or placeholder) and caption cells for one record."""
        arr = _load_thumb_array(rec["image_id"], save_dir)
        score = rec["diversity_score"]
        cat = rec.get("category", "")

        # Image cell
        ax_img = fig.add_subplot(gs[row, img_col])
        if arr is not None:
            ax_img.axis("off")
            ax_img.imshow(arr, aspect="auto", interpolation="bilinear")
        else:
            _hide_axes_keep_face(ax_img)
            ax_img.set_facecolor(badge_color)
            ax_img.text(0.5, 0.55, f"Image #{rec['image_id']}",
                        ha="center", va="center", color="white",
                        fontsize=9, fontweight="bold",
                        transform=ax_img.transAxes)
            ax_img.text(0.5, 0.35, f"{score:.3f}",
                        ha="center", va="center", color="white",
                        fontsize=9, transform=ax_img.transAxes)
        # Omit the "[category]" suffix rather than showing empty brackets.
        title = f"Score: {score:.3f}" + (f" [{cat}]" if cat else "")
        ax_img.set_title(title, fontsize=8, color=badge_color,
                         pad=3, fontweight="bold")

        # Caption cell
        ax_cap = fig.add_subplot(gs[row, cap_col])
        _hide_axes_keep_face(ax_cap)
        ax_cap.set_facecolor("#FAFAFA")

        # Wrap on whitespace only (no hyphen or long-word splitting). Words
        # are re-joined first so runs of whitespace collapse to one space.
        wrapped = [textwrap.wrap(" ".join(cap.split()), width=wrap_width,
                                 break_long_words=False, break_on_hyphens=False)
                   for cap in rec["captions"][:max_captions]]
        n_lines = sum(len(lines) for lines in wrapped)
        # Text is not clipped to its axes: if the wrapped captions would run
        # past the bottom of the cell (into the next row), shrink the line
        # step to fit. Short caption sets keep the preferred spacing.
        usable = top_y - caption_gap * max(len(wrapped) - 1, 0)
        step = min(line_step, usable / n_lines) if n_lines else line_step

        y = top_y
        for ci, lines in enumerate(wrapped, 1):
            number = f"{ci}. "
            for li, text in enumerate(lines):
                # Continuation lines are indented to align under the text
                # that follows the number (the font is monospace).
                prefix = number if li == 0 else " " * len(number)
                ax_cap.text(0.03, y, prefix + text,
                            transform=ax_cap.transAxes,
                            fontsize=7.8, color="#333", va="top",
                            fontfamily="monospace")
                y -= step
            y -= caption_gap

    for row, rec in enumerate(diverse):
        _render_row(row, rec, img_col=0, cap_col=1, badge_color=C_DIVERSE)
    for row, rec in enumerate(repetitive):
        _render_row(row, rec, img_col=2, cap_col=3, badge_color=C_REPETITIVE)

    # Column headers
    fig.text(0.28, 0.945, f"Top-{n_rows} Most DIVERSE Images",
             ha="center", va="bottom", fontsize=13, fontweight="bold",
             color=C_DIVERSE)
    fig.text(0.75, 0.945, f"Top-{n_rows} Most REPETITIVE Images",
             ha="center", va="bottom", fontsize=13, fontweight="bold",
             color=C_REPETITIVE)

    # Centre divider
    sep = Line2D([0.505, 0.505], [0.02, 0.94],
                 transform=fig.transFigure,
                 color="#BBBBBB", linewidth=1.5, linestyle="--")
    fig.add_artist(sep)

    fig.suptitle("Caption Style Extremes — COCO Validation Set\n"
                 "(5 nucleus-sampled captions per image, top_p=0.9)",
                 fontsize=13, fontweight="bold", y=0.995)

    path = os.path.join(save_dir, "diverse_vs_repetitive.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" OK Saved: {path}")
    return path
# ─────────────────────────────────────────────────────────────────────────────
# Figure 3 — Lambda sweep chart
# ─────────────────────────────────────────────────────────────────────────────
def plot_steering_lambda_sweep(steering_results: list,
                               save_dir: str = "task/task_04/results") -> str:
    """
    Save a dual-axis line chart of the concept-steering lambda sweep.

    x-axis: steering strength lambda. Left y-axis: mean caption length
    (words). Right y-axis: mean unique word count. A dashed vertical line
    marks the lambda=0 baseline. When there are at least two points, a
    double-headed arrow spans the length curve from the smallest to the
    largest lambda.

    Args:
        steering_results: non-empty sequence of dicts with ``"lambda"``,
            ``"mean_length"`` and ``"mean_unique_words"`` keys, in any order
            (points are sorted by lambda before plotting).
        save_dir: output directory; created if it does not exist.

    Returns:
        Path of the saved ``steering_lambda_sweep.png`` (joined onto save_dir).

    Raises:
        ValueError: if ``steering_results`` is empty.
    """
    if not steering_results:
        raise ValueError("plot_steering_lambda_sweep: steering_results is empty")
    os.makedirs(save_dir, exist_ok=True)

    # Sort by lambda so lines are drawn left-to-right and the first/last
    # points used for the arrow really are the extremes of the sweep.
    points = sorted(steering_results, key=lambda r: r["lambda"])
    lambdas = [r["lambda"] for r in points]
    lengths = [r["mean_length"] for r in points]
    uniq = [r["mean_unique_words"] for r in points]

    fig, ax1 = plt.subplots(figsize=(9, 5.5))

    # Left axis: mean length
    ax1.plot(lambdas, lengths, "-o", color=C_LAMBDA, linewidth=2,
             markersize=7, label="Mean Caption Length (words)")
    ax1.set_xlabel("Steering Strength (lambda)", fontsize=12)
    ax1.set_ylabel("Mean Caption Length (words)", color=C_LAMBDA, fontsize=12)
    ax1.tick_params(axis="y", labelcolor=C_LAMBDA)

    # Right axis: unique words
    ax2 = ax1.twinx()
    ax2.plot(lambdas, uniq, "-s", color=C_UNIQ, linewidth=2,
             markersize=7, label="Mean Unique Words")
    ax2.set_ylabel("Mean Unique Word Count", color=C_UNIQ, fontsize=12)
    ax2.tick_params(axis="y", labelcolor=C_UNIQ)

    # Baseline. The label is offset in points rather than data units, so it
    # sits just right of the line whatever the lambda range is.
    ax1.axvline(0.0, color="#999", linewidth=1.4, linestyle="--")
    ax1.annotate("lambda=0\nbaseline", xy=(0.0, max(lengths) * 0.97),
                 xytext=(4, 0), textcoords="offset points",
                 annotation_clip=False, color="#777", fontsize=8.5)

    # Double-headed arrow annotation (meaningless for a single point)
    if len(points) > 1:
        ax1.annotate("", xy=(lambdas[-1], lengths[-1] + 0.4),
                     xytext=(lambdas[0], lengths[0] + 0.4),
                     arrowprops=dict(arrowstyle="<->", color="#555", lw=1.2))
        mid_x = (lambdas[0] + lambdas[-1]) / 2
        ax1.text(mid_x, lengths[-1] + 0.6, "steering effect on length",
                 ha="center", fontsize=8.5, color="#555")

    # Combined legend for both y-axes
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, fontsize=9, loc="upper left")

    ax1.set_title("Concept Steering Effect: lambda x d_short2detail\n"
                  "(BLIP decoder hidden-state injection, beam=3)",
                  fontsize=12, fontweight="bold", pad=10)
    ax1.grid(linestyle="--", alpha=0.35)
    fig.tight_layout()

    path = os.path.join(save_dir, "steering_lambda_sweep.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" OK Saved: {path}")
    return path
# ─────────────────────────────────────────────────────────────────────────────
# Master
# ─────────────────────────────────────────────────────────────────────────────
def visualize_all(records: list, steering_results: list,
                  save_dir: str = "task/task_04/results") -> dict:
    """
    Generate all three Task 4 figures into ``save_dir``.

    Args:
        records: per-image diversity records, passed to
            plot_diversity_histogram and plot_diverse_vs_repetitive.
        steering_results: lambda-sweep results, passed to
            plot_steering_lambda_sweep.
        save_dir: output directory shared by all figures.

    Returns:
        dict with keys 'histogram', 'extremes', 'lambda_sweep' mapping to the
        saved file paths. Each path is joined onto ``save_dir``, so it is
        absolute only when ``save_dir`` is (the default is relative).
    """
    print("=" * 62)
    print(" Task 4 -- Step 6: Generate Visualizations")
    print("=" * 62)
    paths = {
        "histogram": plot_diversity_histogram(records, save_dir),
        "extremes": plot_diverse_vs_repetitive(records, save_dir),
        "lambda_sweep": plot_steering_lambda_sweep(steering_results, save_dir),
    }
    print(f"\n {len(paths)} figures saved to: {save_dir}")
    return paths
# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")

    # Diversity records: use the cached diversity_results.json when present,
    # otherwise fall back to step3's precomputed data.
    div_cache = os.path.join(SAVE_DIR, "diversity_results.json")
    if os.path.exists(div_cache):
        # Explicit UTF-8 so caption text decodes identically on every
        # platform (open() otherwise uses the locale's preferred encoding).
        with open(div_cache, encoding="utf-8") as f:
            records = json.load(f)
        print(f" Loaded diversity results from {div_cache}")
    else:
        from step3_diversity_analysis import _make_precomputed
        records = _make_precomputed()

    # Steering sweep: cached steering_results.json, else step5's precomputed table.
    steer_cache = os.path.join(SAVE_DIR, "steering_results.json")
    if os.path.exists(steer_cache):
        with open(steer_cache, encoding="utf-8") as f:
            steering_results = json.load(f)
        print(f" Loaded steering results from {steer_cache}")
    else:
        from step5_steer_and_eval import PRECOMPUTED_STEERING
        steering_results = PRECOMPUTED_STEERING

    paths = visualize_all(records, steering_results, SAVE_DIR)
    print("\n All figures generated.")
    for name, p in paths.items():
        print(f" {name:14}: {p}")