|
|
import os |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from .config import ASSETS_DIR |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_figure(filename):
    """Persist the active matplotlib figure to the shared assets directory.

    The figure is tightened, written at 300 DPI with a tight bounding box,
    displayed interactively, and then closed to release its resources.

    Parameters
    ----------
    filename : str
        Name of the image file to create inside ``ASSETS_DIR``.
    """
    output_path = os.path.join(ASSETS_DIR, filename)

    # Tighten spacing before writing so labels are not clipped.
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f" Artifact Saved: {output_path}")

    # Show for interactive sessions, then close to free the figure.
    plt.show()
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_learning_curves(metrics, version="v1"):
    """
    Standardized learning curve plotter for loss and validation metrics.

    Parameters
    ----------
    metrics : dict
        Expects "train_steps" and "train_loss" sequences; optionally
        "eval_steps" and "eval_loss" for the validation curve.
    version : str
        Model version tag used in the chart title and output filename.
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))

    sns.lineplot(x=metrics["train_steps"], y=metrics["train_loss"],
                 label='Training Loss', color='#4E79A7', linewidth=2.5)

    # Use .get() so a metrics dict without eval data doesn't raise KeyError
    # (the original indexed "eval_loss"/"eval_steps" directly).
    if metrics.get("eval_loss"):
        sns.lineplot(x=metrics.get("eval_steps"), y=metrics["eval_loss"],
                     label='Validation Loss', color='#E15759', linewidth=2.5, marker='o')

    plt.title(f'Learning Curve: LyricLoop {version.upper()}', fontsize=16, fontweight='bold', pad=15)
    plt.xlabel('Training Steps')
    plt.ylabel('Loss')
    plt.legend(frameon=True, fancybox=True, framealpha=0.9)

    save_figure(f"eval_loss_curve_{version}.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_token_heatmap(token_conf_pairs, title="Confidence Heatmap", filename="heatmap.png"):
    """Render tokens as text whose background shade encodes model confidence.

    Parameters
    ----------
    token_conf_pairs : sequence of (str, float)
        Token text paired with its confidence score.
    title : str
        Heading drawn above the token grid.
    filename : str
        Output image name passed to ``save_figure``.

    Returns
    -------
    float
        Mean confidence across all tokens (0 for an empty input).
    """
    figure = plt.figure(figsize=(10, 4))
    axis = figure.add_axes([0, 0, 1, 1])
    axis.axis('off')

    cursor_x, cursor_y = 0.02, 0.85
    row_step = 0.12
    scores = [pair[1] for pair in token_conf_pairs]
    avg_conf = np.mean(scores) if scores else 0

    axis.text(0.02, 0.95, f"{title} (Avg: {avg_conf:.2%})",
              fontsize=12, fontweight='bold', transform=axis.transAxes)

    for token, confidence in token_conf_pairs:
        # Bucket confidence into green / amber / red backgrounds.
        if confidence > 0.7:
            shade = '#aaffaa'
        elif confidence > 0.3:
            shade = '#ffeeba'
        else:
            shade = '#ffcccc'

        display = token.replace('\n', '↵ ')
        # Rough width estimate in axes coordinates (monospace font).
        width_est = len(display) * 0.015

        # Wrap to the next row when the token would spill past the margin.
        if cursor_x + width_est > 0.95:
            cursor_x = 0.02
            cursor_y -= row_step

        axis.text(cursor_x, cursor_y, display,
                  bbox=dict(facecolor=shade, edgecolor='none', pad=2, alpha=0.8),
                  fontfamily='monospace', fontsize=10, transform=axis.transAxes)
        cursor_x += width_est + 0.005

    save_figure(filename)
    return avg_conf
|
|
|
|
|
def plot_confidence_summary(genres, scores, title="Confidence Summary", filename="conf_summary.png"):
    """Grouped bar chart comparing average confidence across genres.

    Parameters
    ----------
    genres : sequence of str
        Category labels for the x-axis.
    scores : list or dict
        Either a single list of per-genre values, or a mapping of
        series-label -> list of per-genre values.
    title, filename : str
        Chart title and output image name.
    """
    plt.figure(figsize=(11, 6))
    positions = np.arange(len(genres))
    bar_width = 0.35
    palette = ['#A0A0A0', '#4E79A7', '#E15759']

    # Normalize the input: a bare list becomes a one-entry mapping
    # drawn with wider bars.
    if isinstance(scores, list):
        series = {"Model Output": scores}
        bar_width = 0.5
    else:
        series = scores

    # Only series whose length matches the genre axis are drawable.
    usable = {name: vals for name, vals in series.items() if len(vals) == len(genres)}

    for idx, (name, vals) in enumerate(usable.items()):
        # Center grouped bars around each tick; a lone series sits on the tick.
        if len(usable) > 1:
            shift = (idx - (len(usable) - 1) / 2) * bar_width
        else:
            shift = 0
        drawn = plt.bar(positions + shift, vals, bar_width, label=name,
                        color=palette[idx % 3], edgecolor='black', alpha=0.8)

        # Annotate each bar with its numeric value.
        for rect in drawn:
            height = rect.get_height()
            plt.text(rect.get_x() + rect.get_width() / 2., height + 0.02, f'{height:.2f}',
                     ha='center', va='bottom', fontweight='bold', fontsize=9)

    plt.title(title, fontsize=16, fontweight='bold')
    plt.ylabel('Average Confidence Score')
    plt.xticks(positions, genres)
    plt.ylim(0, 1.1)
    if len(usable) > 1:
        plt.legend(loc='lower right')
    plt.grid(axis='y', linestyle='--', alpha=0.3)
    save_figure(filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_perplexity(genres, scores_dict, title="Model Perplexity", filename="perplexity.png", use_log=False):
    """Global plotter for perplexity scores with support for log-scaling.

    Parameters
    ----------
    genres : sequence of str
        Category labels for the x-axis.
    scores_dict : dict
        Mapping of series-label -> list of perplexity values (one per genre).
        A single-entry dict is drawn unlabelled; multiple entries are drawn
        as grouped bars with a legend.
    title, filename : str
        Chart title and output image name.
    use_log : bool
        When True, the y-axis uses a logarithmic scale.
    """
    plt.figure(figsize=(10, 6))
    if use_log:
        plt.yscale('log')

    x = np.arange(len(genres))

    if len(scores_dict) == 1:
        # Single series: no grouping or legend needed.
        # (Removed dead `label`/`bars` assignments from the original.)
        values = next(iter(scores_dict.values()))
        plt.bar(genres, values, color='#A0A0A0', edgecolor='black', alpha=0.8)
    else:
        width = 0.35
        comp_colors = ['#A0A0A0', '#4E79A7']
        for i, (label, values) in enumerate(scores_dict.items()):
            # Center the grouped bars around each genre tick.
            offset = (i - (len(scores_dict) - 1) / 2) * width
            plt.bar(x + offset, values, width, label=label, color=comp_colors[i % 2], edgecolor='black')

    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('Perplexity (Lower is Better)', fontsize=12)
    plt.xticks(x, genres)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    if len(scores_dict) > 1:
        plt.legend()

    save_figure(filename)