File size: 4,487 Bytes
from __future__ import annotations
from typing import Any
import matplotlib
matplotlib.use("Agg") # non-interactive backend
import matplotlib.pyplot as plt
import torch
def plot_training_loss(
    losses: list[float],
    output_path: str = "training_loss.png",
    title: str = "YAML-BERT Training Loss",
) -> None:
    """Render the per-epoch training loss as a line chart and save it.

    Args:
        losses: One loss value per epoch, in training order. The first
            entry is plotted at epoch 1.
        output_path: Destination handed to ``Figure.savefig``.
        title: Text shown above the chart.
    """
    figure, axis = plt.subplots(figsize=(10, 6))

    # Epochs are numbered from 1 on the x axis, not from 0.
    epoch_numbers = range(1, len(losses) + 1)
    axis.plot(epoch_numbers, losses, marker="o", linewidth=2)

    axis.set(xlabel="Epoch", ylabel="Loss", title=title)
    axis.grid(True, alpha=0.3)

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    # pyplot keeps every figure registered until it is closed explicitly.
    plt.close(figure)

    print(f"Training loss plot saved: {output_path}")
def plot_accuracy(
    results: dict[str, float],
    output_path: str = "accuracy.png",
    title: str = "YAML-BERT Key Prediction Accuracy",
) -> None:
    """Draw top-1 and top-5 accuracy as a two-bar chart and save it.

    Args:
        results: Mapping that must contain ``"top1_accuracy"`` and
            ``"top5_accuracy"``; both are drawn on a 0-1 axis and captioned
            as percentages. ``"total_masked"`` is optional and is shown as
            a footnote (0 when absent).
        output_path: Destination handed to ``Figure.savefig``.
        title: Text shown above the chart.

    Raises:
        KeyError: If either accuracy key is missing from ``results``.
    """
    accuracy_by_label: dict[str, float] = {
        "Top-1": results["top1_accuracy"],
        "Top-5": results["top5_accuracy"],
    }

    figure, axis = plt.subplots(figsize=(6, 5))
    rects = axis.bar(
        list(accuracy_by_label),
        list(accuracy_by_label.values()),
        color=["steelblue", "coral"],
        width=0.5,
    )
    axis.set_ylabel("Accuracy")
    axis.set_title(title)
    axis.set_ylim(0, 1)

    # Percentage caption, centred horizontally just above each bar.
    for rect, accuracy in zip(rects, accuracy_by_label.values()):
        x_center = rect.get_x() + rect.get_width() / 2
        y_caption = rect.get_height() + 0.02
        axis.text(
            x_center,
            y_caption,
            f"{accuracy:.1%}",
            ha="center",
            fontsize=14,
            fontweight="bold",
        )

    masked_count: float = results.get("total_masked", 0)
    # Footnote positioned in axes coordinates: centred, slightly below the
    # bottom edge of the plotting area.
    axis.text(
        0.5,
        -0.1,
        f"Total masked positions: {int(masked_count)}",
        ha="center",
        transform=axis.transAxes,
        fontsize=10,
        color="gray",
    )

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    plt.close(figure)

    print(f"Accuracy plot saved: {output_path}")
def plot_embedding_similarity(
    results: list[dict[str, Any]],
    output_path: str = "embedding_similarity.png",
    title: str = "Tree Position Embedding Similarity",
) -> None:
    """Plot cosine similarity between same-key embeddings at different tree positions.

    One bar is drawn per entry of ``results`` on a fixed -1..1 axis, with the
    numeric value captioned just beyond the tip of each bar.

    Args:
        results: One dict per comparison. Each must provide ``"key"``,
            ``"cosine_similarity"``, and ``"position_a"`` / ``"position_b"``
            dicts that each provide ``"depth"`` and ``"parent_key"``.
        output_path: Destination handed to ``Figure.savefig``.
        title: Text shown above the chart.

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
    """
    # Three-line tick label: the key, then the two positions being compared.
    labels: list[str] = [
        f"{r['key']}\n"
        f"d={r['position_a']['depth']},p={r['position_a']['parent_key']}\n"
        f"vs d={r['position_b']['depth']},p={r['position_b']['parent_key']}"
        for r in results
    ]
    similarities: list[float] = [r["cosine_similarity"] for r in results]

    # Widen the figure with the number of bars so the multi-line tick labels
    # have room; never go narrower than 8 inches.
    fig, ax = plt.subplots(figsize=(max(8, len(results) * 3), 6))
    bars = ax.bar(range(len(results)), similarities, color="steelblue")
    ax.set_xticks(range(len(results)))
    ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title(title)
    ax.set_ylim(-1, 1)
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)

    for bar, sim in zip(bars, similarities):
        # Push the caption away from zero. A negative bar hangs below the
        # axis, so its caption must sit under the tip (anchored by its top
        # edge) rather than 0.02 above it, where it would land on the bar.
        is_negative = sim < 0
        offset = -0.02 if is_negative else 0.02
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + offset,
            f"{sim:.3f}",
            ha="center",
            va="top" if is_negative else "baseline",
            fontsize=10,
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Embedding similarity plot saved: {output_path}")
def plot_attention_patterns(
    attention_weights: torch.Tensor,
    token_labels: list[str],
    output_path: str = "attention_patterns.png",
    title: str = "Attention Patterns",
) -> None:
    """Plot attention heatmaps for each head, side by side, and save the figure.

    Args:
        attention_weights: (num_heads, seq_len, seq_len). May live on any
            device and may still be attached to the autograd graph; each
            head is detached, cast to float32 and copied to the CPU before
            plotting.
        token_labels: labels for each position in the sequence
        output_path: Destination handed to ``Figure.savefig``.
        title: Figure-level title drawn above all heads.
    """
    num_heads: int = attention_weights.shape[0]
    # squeeze=False always returns a 2-D array of axes, so a single head
    # needs no special-casing; row 0 holds one axes per head.
    fig, axes = plt.subplots(
        1, num_heads, figsize=(6 * num_heads, 5), squeeze=False
    )
    tick_positions = range(len(token_labels))
    for head_idx, ax in enumerate(axes[0]):
        # detach(): Tensor.numpy() raises on tensors that require grad.
        # float(): NumPy has no bfloat16, so reduced-precision weights must
        # be widened before conversion.
        weights: torch.Tensor = attention_weights[head_idx].detach().float().cpu()
        # Fixed 0..1 colour scale keeps heads visually comparable.
        ax.imshow(weights.numpy(), cmap="Blues", vmin=0, vmax=1)
        ax.set_title(f"Head {head_idx}")
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(token_labels, rotation=45, ha="right", fontsize=7)
        ax.set_yticklabels(token_labels, fontsize=7)
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Attention pattern plot saved: {output_path}")
|