import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from .config import ASSETS_DIR
# -------------------------
# Visualization Utilities
# -------------------------
def save_figure(filename):
    """Save the current matplotlib figure with consistent professional settings.

    Args:
        filename: Bare file name (e.g. "plot.png"); joined onto ASSETS_DIR.

    Side effects:
        Creates ASSETS_DIR if missing, writes a 300-DPI image, prints the
        saved path, displays the figure, then closes it to free memory.
    """
    # Robustness fix: a fresh checkout may not have the assets directory yet,
    # and plt.savefig does not create parent directories.
    os.makedirs(ASSETS_DIR, exist_ok=True)
    path = os.path.join(ASSETS_DIR, filename)
    # Ensure layout does not clip labels (bbox_inches='tight' below also trims margins).
    plt.tight_layout()
    # High resolution for documentation and reports.
    plt.savefig(path, dpi=300, bbox_inches='tight')
    print(f" Artifact Saved: {path}")
    plt.show()
    plt.close()
# -------------------------
# Training Diagnostics
# -------------------------
def plot_learning_curves(metrics, version="v1"):
    """Plot training (and optional validation) loss against training steps.

    Args:
        metrics: Dict with "train_steps"/"train_loss" sequences; optionally
            "eval_steps"/"eval_loss" for the validation curve.
        version: Model version tag used in the title and output filename.
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    # Training Loss
    sns.lineplot(x=metrics["train_steps"], y=metrics["train_loss"],
                 label='Training Loss', color='#4E79A7', linewidth=2.5)
    # Bug fix: use .get() so a metrics dict without eval entries no longer
    # raises KeyError (e.g. runs where evaluation was disabled).
    if metrics.get("eval_loss"):
        sns.lineplot(x=metrics["eval_steps"], y=metrics["eval_loss"],
                     label='Validation Loss', color='#E15759', linewidth=2.5, marker='o')
    plt.title(f'Learning Curve: LyricLoop {version.upper()}', fontsize=16, fontweight='bold', pad=15)
    plt.xlabel('Training Steps')
    plt.ylabel('Loss')
    plt.legend(frameon=True, fancybox=True, framealpha=0.9)
    save_figure(f"eval_loss_curve_{version}.png")
# -------------------------
# Confidence & Interpretability
# -------------------------
def plot_token_heatmap(token_conf_pairs, title="Confidence Heatmap", filename="heatmap.png"):
    """Render tokens as text whose background color encodes model confidence.

    Args:
        token_conf_pairs: Iterable of (token_text, confidence) pairs.
        title: Heading drawn above the token grid.
        filename: Output image name passed to save_figure.

    Returns:
        Mean confidence over all tokens (0 when the input is empty).
    """
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')

    confidences = [pair[1] for pair in token_conf_pairs]
    avg_conf = np.mean(confidences) if confidences else 0
    ax.text(0.02, 0.95, f"{title} (Avg: {avg_conf:.2%})",
            fontsize=12, fontweight='bold', transform=ax.transAxes)

    cursor_x, cursor_y = 0.02, 0.85
    row_step = 0.12
    for token, conf in token_conf_pairs:
        # Traffic-light scale: green = high, orange = medium, red = low.
        if conf > 0.7:
            bg = '#aaffaa'
        elif conf > 0.3:
            bg = '#ffeeba'
        else:
            bg = '#ffcccc'
        rendered = token.replace('\n', '↵ ')
        # Rough monospace width estimate in axes coordinates.
        token_w = len(rendered) * 0.015
        # Wrap to the next row before running off the right edge.
        if cursor_x + token_w > 0.95:
            cursor_x = 0.02
            cursor_y -= row_step
        ax.text(cursor_x, cursor_y, rendered,
                bbox=dict(facecolor=bg, edgecolor='none', pad=2, alpha=0.8),
                fontfamily='monospace', fontsize=10, transform=ax.transAxes)
        cursor_x += token_w + 0.005
    save_figure(filename)
    return avg_conf
def plot_confidence_summary(genres, scores, title="Confidence Summary", filename="conf_summary.png"):
    """Grouped bar chart comparing average confidence across genres.

    Args:
        genres: Category labels for the x-axis.
        scores: Either a flat list of per-genre scores, or a dict mapping
            series label -> list of per-genre scores.
        title: Chart title.
        filename: Output image name passed to save_figure.
    """
    plt.figure(figsize=(11, 6))
    x = np.arange(len(genres))
    width = 0.35
    palette = ['#A0A0A0', '#4E79A7', '#E15759']  # grey, blue, red
    if isinstance(scores, list):
        # Single unlabeled series: widen the bar and use a generic label.
        scores_dict = {"Model Output": scores}
        width = 0.5
    else:
        scores_dict = scores
    # Only series whose length matches the genre count can be plotted.
    active_scores = {k: v for k, v in scores_dict.items() if len(v) == len(genres)}
    # Bug fix: mismatched series were previously dropped silently; surface it.
    for dropped in scores_dict.keys() - active_scores.keys():
        print(f"Warning: skipping series '{dropped}' (length != number of genres)")
    for i, (label, values) in enumerate(active_scores.items()):
        # Center the group of bars around each x position.
        offset = (i - (len(active_scores) - 1) / 2) * width if len(active_scores) > 1 else 0
        bars = plt.bar(x + offset, values, width, label=label,
                       color=palette[i % len(palette)], edgecolor='black', alpha=0.8)
        # Annotate each bar with its numeric value.
        for bar in bars:
            h = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., h + 0.02, f'{h:.2f}',
                     ha='center', va='bottom', fontweight='bold', fontsize=9)
    plt.title(title, fontsize=16, fontweight='bold')
    plt.ylabel('Average Confidence Score')
    plt.xticks(x, genres)
    plt.ylim(0, 1.1)  # headroom for the value labels above 1.0-ish bars
    if len(active_scores) > 1:
        plt.legend(loc='lower right')
    plt.grid(axis='y', linestyle='--', alpha=0.3)
    save_figure(filename)
# -------------------------
# Performance Comparison
# -------------------------
def plot_perplexity(genres, scores_dict, title="Model Perplexity", filename="perplexity.png", use_log=False):
    """Bar chart of perplexity per genre for one or more model variants.

    Args:
        genres: Category labels for the x-axis.
        scores_dict: Mapping of series label -> list of per-genre perplexities.
        title: Chart title.
        filename: Output image name passed to save_figure.
        use_log: Log-scale the y-axis (useful when a baseline dwarfs the
            fine-tuned scores).
    """
    plt.figure(figsize=(10, 6))
    if use_log:
        plt.yscale('log')
    x = np.arange(len(genres))
    comp_colors = ['#A0A0A0', '#4E79A7']  # grey for Baseline, blue for Fine-Tuned
    if len(scores_dict) == 1:
        # Bug fix: plot against numeric positions (not the genre strings) so
        # the later plt.xticks(x, genres) labels the correct bars; also drops
        # the unused label/bars locals of the original.
        values = next(iter(scores_dict.values()))
        plt.bar(x, values, color='#A0A0A0', edgecolor='black', alpha=0.8)
    else:
        width = 0.35
        for i, (label, values) in enumerate(scores_dict.items()):
            # Center the group of bars around each x position.
            offset = (i - (len(scores_dict) - 1) / 2) * width
            plt.bar(x + offset, values, width, label=label,
                    color=comp_colors[i % len(comp_colors)], edgecolor='black')
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('Perplexity (Lower is Better)', fontsize=12)
    plt.xticks(x, genres)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    if len(scores_dict) > 1:
        plt.legend()
    save_figure(filename)