| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import torch |
|
|
|
|
def plot_training_loss(
    losses: list[float],
    output_path: str = "training_loss.png",
    title: str = "YAML-BERT Training Loss",
) -> None:
    """Save a line chart of training loss, one point per epoch.

    Epochs are numbered from 1, so ``losses[i]`` is drawn at x = i + 1.

    Args:
        losses: One loss value per epoch, in epoch order.
        output_path: Destination image file (written at 150 dpi).
        title: Chart title.
    """
    epochs = range(1, len(losses) + 1)

    figure, axis = plt.subplots(figsize=(10, 6))
    axis.plot(epochs, losses, marker="o", linewidth=2)
    axis.set(xlabel="Epoch", ylabel="Loss", title=title)
    axis.grid(True, alpha=0.3)

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    # Release the figure explicitly; pyplot otherwise keeps it alive.
    plt.close(figure)
    print(f"Training loss plot saved: {output_path}")
|
|
|
|
def plot_accuracy(
    results: dict[str, float],
    output_path: str = "accuracy.png",
    title: str = "YAML-BERT Key Prediction Accuracy",
) -> None:
    """Plot top-1 and top-5 prediction accuracy as a bar chart.

    Args:
        results: Metrics mapping. Must contain ``"top1_accuracy"`` and
            ``"top5_accuracy"``, which are drawn on a 0-1 axis and labelled as
            percentages. If it also contains ``"total_masked"``, that count is
            printed as a footnote under the chart.
        output_path: File the figure is saved to (150 dpi).
        title: Chart title.

    Raises:
        KeyError: If either accuracy key is missing from ``results``.
    """
    labels: list[str] = ["Top-1", "Top-5"]
    # float() so tensor / NumPy scalars format cleanly with ``:.1%``.
    values: list[float] = [
        float(results["top1_accuracy"]),
        float(results["top5_accuracy"]),
    ]

    fig, ax = plt.subplots(figsize=(6, 5))
    bars = ax.bar(labels, values, color=["steelblue", "coral"], width=0.5)
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    # Leave headroom above 1.0: value labels sit 0.02 above each bar, so with
    # a hard 0-1 limit a (near-)perfect bar's label would overflow the axes
    # and collide with the title. Ticks stay on the meaningful 0-1 range.
    ax.set_ylim(0, 1.12)
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.02,
            f"{val:.1%}",
            ha="center",
            fontsize=14,
            fontweight="bold",
        )

    # Only annotate the sample count when the caller supplied it; a
    # defaulted "0" would wrongly suggest that nothing was evaluated.
    if "total_masked" in results:
        ax.text(
            0.5, -0.1,
            f"Total masked positions: {int(results['total_masked'])}",
            ha="center",
            transform=ax.transAxes,  # axes-fraction coords: just below the plot
            fontsize=10,
            color="gray",
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Accuracy plot saved: {output_path}")
|
|
|
|
def plot_embedding_similarity(
    results: list[dict[str, Any]],
    output_path: str = "embedding_similarity.png",
    title: str = "Tree Position Embedding Similarity",
) -> None:
    """Plot cosine similarity between same-key embeddings at different tree positions.

    Args:
        results: One entry per comparison. Each entry is read for ``"key"``,
            ``"cosine_similarity"``, and ``"position_a"`` / ``"position_b"``,
            each of which must provide ``"depth"`` and ``"parent_key"``.
        output_path: File the figure is saved to (150 dpi).
        title: Chart title.

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
    """
    labels: list[str] = []
    similarities: list[float] = []

    for r in results:
        pa: dict[str, Any] = r["position_a"]
        pb: dict[str, Any] = r["position_b"]
        label: str = (
            f"{r['key']}\n"
            f"d={pa['depth']},p={pa['parent_key']}\n"
            f"vs d={pb['depth']},p={pb['parent_key']}"
        )
        labels.append(label)
        # float() so tensor / NumPy scalars plot and format like plain floats.
        similarities.append(float(r["cosine_similarity"]))

    positions = range(len(results))
    fig, ax = plt.subplots(figsize=(max(8, len(results) * 3), 6))
    bars = ax.bar(positions, similarities, color="steelblue")
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title(title)
    # Cosine similarity lies in [-1, 1]; pad the limits so labels on bars at
    # (or, through float rounding, marginally beyond) +/-1 stay inside.
    ax.set_ylim(-1.15, 1.15)
    ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)

    for bar, sim in zip(bars, similarities):
        # Place each label just beyond the bar's tip on its outward side.
        # Negative bars grow downward, so a fixed "+0.02" would put their
        # label inside the bar.
        if sim >= 0:
            y, va = sim + 0.02, "bottom"
        else:
            y, va = sim - 0.02, "top"
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            y,
            f"{sim:.3f}",
            ha="center",
            va=va,
            fontsize=10,
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Embedding similarity plot saved: {output_path}")
|
|
|
|
def plot_attention_patterns(
    attention_weights: torch.Tensor,
    token_labels: list[str],
    output_path: str = "attention_patterns.png",
    title: str = "Attention Patterns",
) -> None:
    """Plot attention heatmaps for each head, side by side.

    Args:
        attention_weights: (num_heads, seq_len, seq_len) tensor, or a single
            (seq_len, seq_len) map, which is treated as one head. It is
            detached and copied to CPU float32 before plotting, so it may
            require grad, live on another device, or use a reduced-precision
            dtype. The colour scale is fixed to [0, 1].
        token_labels: Labels for each position in the sequence (used for both
            axes).
        output_path: File the figure is saved to (150 dpi).
        title: Figure-level title.

    Raises:
        ValueError: If ``attention_weights`` is neither 2-D nor 3-D.
    """
    if attention_weights.dim() == 2:
        attention_weights = attention_weights.unsqueeze(0)
    if attention_weights.dim() != 3:
        raise ValueError(
            "attention_weights must have shape (num_heads, seq_len, seq_len) "
            f"or (seq_len, seq_len), got {tuple(attention_weights.shape)}"
        )

    # detach(): Tensor.numpy() refuses tensors that require grad (e.g. weights
    # captured outside torch.no_grad()). float(): NumPy has no bfloat16.
    weights = attention_weights.detach().cpu().float().numpy()
    num_heads: int = weights.shape[0]

    # squeeze=False always returns a 2-D array of Axes, so one head needs no
    # special-casing.
    fig, axes = plt.subplots(
        1, num_heads, figsize=(6 * num_heads, 5), squeeze=False
    )

    ticks = range(len(token_labels))
    for head_idx, ax in enumerate(axes[0]):
        ax.imshow(weights[head_idx], cmap="Blues", vmin=0, vmax=1)
        ax.set_title(f"Head {head_idx}")
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(token_labels, rotation=45, ha="right", fontsize=7)
        ax.set_yticklabels(token_labels, fontsize=7)

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Attention pattern plot saved: {output_path}")
|
|