yaml-bert / yaml_bert /visualize.py
vimalk78's picture
Initial app: Gradio missing-field suggester (v6.1 model)
222a479 verified
Raw
History Blame Contribute Delete
4.49 kB
from __future__ import annotations
from typing import Any
import matplotlib
matplotlib.use("Agg") # non-interactive backend
import matplotlib.pyplot as plt
import torch
def plot_training_loss(
    losses: list[float],
    output_path: str = "training_loss.png",
    title: str = "YAML-BERT Training Loss",
) -> None:
    """Plot training loss curve over epochs and save it as an image.

    Args:
        losses: One loss value per epoch. The x axis numbers epochs
            starting at 1.
        output_path: File path handed to ``Figure.savefig`` (saved at 150 dpi).
        title: Title drawn above the chart.
    """
    epochs = list(range(1, len(losses) + 1))

    figure, axis = plt.subplots(figsize=(10, 6))
    axis.plot(epochs, losses, marker="o", linewidth=2)
    axis.set_title(title)
    axis.set_xlabel("Epoch")
    axis.set_ylabel("Loss")
    axis.grid(True, alpha=0.3)

    figure.tight_layout()
    figure.savefig(output_path, dpi=150)
    # Close explicitly so repeated calls don't keep figures alive in pyplot.
    plt.close(figure)
    print(f"Training loss plot saved: {output_path}")
def plot_accuracy(
    results: dict[str, float],
    output_path: str = "accuracy.png",
    title: str = "YAML-BERT Key Prediction Accuracy",
) -> None:
    """Plot top-1 and top-5 prediction accuracy as a bar chart.

    Args:
        results: Must contain ``"top1_accuracy"`` and ``"top5_accuracy"``.
            These are rendered with a percent format, so fractions in
            [0, 1] are expected. An optional ``"total_masked"`` count
            (default 0) is printed as a footnote below the axes.
        output_path: File path handed to ``Figure.savefig`` (saved at 150 dpi).
        title: Title drawn above the chart.

    Raises:
        KeyError: If either accuracy key is missing from ``results``.
    """
    labels: list[str] = ["Top-1", "Top-5"]
    values: list[float] = [results["top1_accuracy"], results["top5_accuracy"]]

    fig, ax = plt.subplots(figsize=(6, 5))
    bars = ax.bar(labels, values, color=["steelblue", "coral"], width=0.5)
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    # Value labels sit 0.02 above each bar. With a hard ceiling of 1.0, a
    # near-perfect accuracy would push its label out of the axes and into
    # the title. Leave headroom above 1.0.
    ax.set_ylim(0, 1.1)

    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.02,
            f"{val:.1%}",
            ha="center",
            fontsize=14,
            fontweight="bold",
        )

    total: float = results.get("total_masked", 0)
    # Footnote positioned in axes coordinates, just below the x axis.
    ax.text(
        0.5, -0.1,
        f"Total masked positions: {int(total)}",
        ha="center",
        transform=ax.transAxes,
        fontsize=10,
        color="gray",
    )

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Accuracy plot saved: {output_path}")
def plot_embedding_similarity(
    results: list[dict[str, Any]],
    output_path: str = "embedding_similarity.png",
    title: str = "Tree Position Embedding Similarity",
) -> None:
    """Plot cosine similarity between same-key embeddings at different tree positions.

    Args:
        results: One entry per comparison. Each entry must provide:
            ``"key"``; ``"cosine_similarity"``; and ``"position_a"`` and
            ``"position_b"``, dicts that each carry ``"depth"`` and
            ``"parent_key"``. These fields build a three-line tick label
            for each bar.
        output_path: File path handed to ``Figure.savefig`` (saved at 150 dpi).
        title: Title drawn above the chart.
    """
    labels: list[str] = []
    similarities: list[float] = []
    for r in results:
        pa: dict[str, Any] = r["position_a"]
        pb: dict[str, Any] = r["position_b"]
        labels.append(
            f"{r['key']}\n"
            f"d={pa['depth']},p={pa['parent_key']}\n"
            f"vs d={pb['depth']},p={pb['parent_key']}"
        )
        similarities.append(r["cosine_similarity"])

    positions = range(len(results))
    # Widen the figure with the number of comparisons so tick labels don't collide.
    fig, ax = plt.subplots(figsize=(max(8, len(results) * 3), 6))
    bars = ax.bar(positions, similarities, color="steelblue")
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title(title)
    # Cosine similarity lies in [-1, 1]. The extra 0.1 margin keeps value
    # labels for similarities near +/-1 inside the axes.
    ax.set_ylim(-1.1, 1.1)
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)

    for bar, sim in zip(bars, similarities):
        height = bar.get_height()
        # Bars for negative similarities extend downward from 0. Put their
        # label just below the tip (top-aligned) rather than inside the bar.
        if height >= 0:
            y, va = height + 0.02, "baseline"
        else:
            y, va = height - 0.02, "top"
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            y,
            f"{sim:.3f}",
            ha="center",
            va=va,
            fontsize=10,
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Embedding similarity plot saved: {output_path}")
def plot_attention_patterns(
    attention_weights: torch.Tensor,
    token_labels: list[str],
    output_path: str = "attention_patterns.png",
    title: str = "Attention Patterns",
) -> None:
    """Plot attention heatmaps for each head, side by side, and save the figure.

    Args:
        attention_weights: (num_heads, seq_len, seq_len). It may be on any
            device and may still require grad; it is detached and copied
            to CPU before plotting.
        token_labels: Labels for each position in the sequence. They are
            used for both the x (rotated) and y tick labels of every head.
        output_path: File path handed to ``Figure.savefig`` (saved at 150 dpi).
        title: Figure-level title above all heads.

    The color scale is fixed to [0, 1] for every head, so heads are visually
    comparable. Presumably the weights are post-softmax probabilities;
    values outside [0, 1] saturate the colormap.
    """
    # Convert once for all heads:
    # - detach() avoids the error .numpy() raises on tensors that require grad;
    # - cpu() handles GPU tensors;
    # - float() handles dtypes numpy lacks (e.g. bfloat16).
    weights = attention_weights.detach().cpu().float().numpy()
    num_heads: int = weights.shape[0]

    # squeeze=False always returns a 2-D array of Axes, so a single head
    # needs no special-casing.
    fig, axes = plt.subplots(
        1, num_heads, figsize=(6 * num_heads, 5), squeeze=False
    )
    ticks = range(len(token_labels))
    for head_idx, ax in enumerate(axes[0]):
        ax.imshow(weights[head_idx], cmap="Blues", vmin=0, vmax=1)
        ax.set_title(f"Head {head_idx}")
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(token_labels, rotation=45, ha="right", fontsize=7)
        ax.set_yticklabels(token_labels, fontsize=7)

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150)
    plt.close(fig)
    print(f"Attention pattern plot saved: {output_path}")