eumora-api / backend /src /visualize.py
VivDubs's picture
refactor: move backend files into backend/ directory
9eb5faa
Raw
History Blame Contribute Delete
10.7 kB
"""Visualization module for emotion analysis results."""
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
class EmotionVisualizer:
    """Generate visualizations for emotion analysis results.

    Every ``create_*`` method renders one figure, saves it as a 300-dpi PNG
    (to ``save_path`` or to a uniquely named file in ``output_dir``), then
    either shows or closes the figure, and returns the path it was saved to.
    """

    # Bar colour used for emotion labels that have no entry in emotion_colors.
    _DEFAULT_COLOR = '#808080'

    def __init__(self, output_dir: Optional[Path] = None):
        """
        Initialize the visualizer.

        Args:
            output_dir: Directory to save visualization files (a ``str`` is
                also accepted). Falls back to ``visualizations`` when falsy.
                The directory, including missing parents, is created.
        """
        self.output_dir = Path(output_dir or "visualizations")
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Set style (these calls change matplotlib/seaborn module-wide settings)
        plt.style.use('default')
        sns.set_palette("husl")
        # Emotion colors for consistency across all charts
        self.emotion_colors = {
            'joy': '#FFD700',       # Gold
            'sadness': '#4169E1',   # Royal Blue
            'anger': '#DC143C',     # Crimson
            'fear': '#9370DB',      # Medium Purple
            'love': '#FF69B4',      # Hot Pink
            'surprise': '#FF8C00',  # Dark Orange
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters, plus "..." if it was longer."""
        return text[:limit] + ("..." if len(text) > limit else "")

    @staticmethod
    def _aligned_probs(probabilities: Dict[str, float], emotions: list) -> list:
        """Return the probabilities in ``emotions`` order; absent labels become 0.0."""
        return [probabilities.get(emo, 0.0) for emo in emotions]

    def _colors_for(self, emotions: list) -> list:
        """Return the configured bar colour for each emotion label."""
        return [self.emotion_colors.get(emo, self._DEFAULT_COLOR) for emo in emotions]

    def _default_save_path(self, prefix: str) -> Path:
        """Build ``<prefix>_<YYYYmmdd_HHMMSS>_<random>.png`` inside ``output_dir``.

        The random suffix keeps charts produced within the same second (e.g.
        concurrent requests) from silently overwriting each other.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return self.output_dir / f"{prefix}_{timestamp}_{uuid.uuid4().hex[:8]}.png"

    @staticmethod
    def _label_bars(ax, bars, probs: list, fontsize: int) -> None:
        """Write each bar's probability as a bold percentage just above the bar."""
        for bar, prob in zip(bars, probs):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{prob:.1%}', ha='center', va='bottom',
                    fontweight='bold', fontsize=fontsize)

    def _finalize(self, fig, save_path: Optional[Path], prefix: str,
                  show_chart: bool) -> Path:
        """Save ``fig``, optionally show it, release it, and return the save path."""
        if save_path is None:
            save_path = self._default_save_path(prefix)
        keep_open = False
        try:
            fig.savefig(save_path, dpi=300, bbox_inches='tight',
                        facecolor='white', edgecolor='none')
            if show_chart:
                plt.show()
                # In interactive mode show() does not block; closing now would
                # destroy the window the user just opened.
                keep_open = plt.isinteractive()
        finally:
            # On non-interactive backends (e.g. Agg in a server) show() is a
            # no-op, so without this every call would leak a figure.
            if not keep_open:
                plt.close(fig)
        return save_path

    # ------------------------------------------------------------------
    # Public chart builders
    # ------------------------------------------------------------------
    def create_emotion_bar_chart(self,
                                 probabilities: Dict[str, float],
                                 text: str = "",
                                 save_path: Optional[Path] = None,
                                 show_chart: bool = True,
                                 primary_emotion: Optional[str] = None) -> Path:
        """
        Create a bar chart showing emotion probabilities.

        Args:
            probabilities: Dict of emotion -> probability
            text: Input text (first 60 characters are shown in the title)
            save_path: Where to save the chart; defaults to a unique file in
                ``output_dir``
            show_chart: Whether to display the chart
            primary_emotion: Explicit primary emotion to override argmax

        Returns:
            Path to saved chart image

        Raises:
            ValueError: If ``probabilities`` is empty.
            KeyError: If ``primary_emotion`` is not a key of ``probabilities``.
        """
        if not probabilities:
            raise ValueError("No probabilities provided")
        emotions = list(probabilities)
        probs = list(probabilities.values())
        # Resolve the primary emotion before allocating a figure so that a bad
        # override cannot leak one.
        if primary_emotion is None:
            max_emotion = max(probabilities, key=probabilities.get)
        else:
            max_emotion = primary_emotion
        max_prob = probabilities[max_emotion]

        fig, ax = plt.subplots(figsize=(12, 8))
        bars = ax.bar(emotions, probs, color=self._colors_for(emotions),
                      alpha=0.8, edgecolor='black', linewidth=1)

        ax.set_title(f'Emotion Analysis Results\n"{self._truncate(text, 60)}"',
                     fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel('Emotions', fontsize=14, fontweight='bold')
        ax.set_ylabel('Probability', fontsize=14, fontweight='bold')
        self._label_bars(ax, bars, probs, fontsize=12)

        # 20% headroom for the labels; all-zero input would otherwise produce
        # the degenerate limits (0, 0).
        ax.set_ylim(0, max(probs) * 1.2 or 1.0)
        plt.setp(ax.get_xticklabels(), rotation=45, fontsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(axis='y', alpha=0.3, linestyle='--')

        ax.text(0.02, 0.98, f'Primary: {max_emotion.upper()} ({max_prob:.1%})',
                transform=ax.transAxes, fontsize=14, fontweight='bold',
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7),
                verticalalignment='top')
        fig.tight_layout()
        return self._finalize(fig, save_path, "emotion_analysis", show_chart)

    def create_comparison_chart(self,
                                results_list: list,
                                titles: Optional[list] = None,
                                save_path: Optional[Path] = None,
                                show_chart: bool = True) -> Path:
        """
        Create a comparison chart for multiple predictions, one subplot each.

        Args:
            results_list: List of prediction results; each is read for
                'probabilities' (emotion -> probability), 'emotion' and
                'confidence'
            titles: Titles for each prediction; missing entries default to
                "Sample N"
            save_path: Where to save the chart; defaults to a unique file in
                ``output_dir``
            show_chart: Whether to display the chart

        Returns:
            Path to saved chart image

        Raises:
            ValueError: If ``results_list`` is empty.
        """
        if not results_list:
            raise ValueError("No results provided for comparison")
        n_samples = len(results_list)

        # Shared x-axis: union of labels in first-seen order. Values are looked
        # up by label, so dicts with a different key order still line up.
        emotions = list(dict.fromkeys(
            emo for result in results_list for emo in result['probabilities']))
        colors = self._colors_for(emotions)

        given = list(titles) if titles is not None else []
        titles = given + [f"Sample {i+1}" for i in range(len(given), n_samples)]

        # Read every required key before a figure exists, so malformed input
        # raises without leaking a figure.
        panels = [
            (self._aligned_probs(result['probabilities'], emotions),
             f'{title}\n🎯 {result["emotion"].upper()} ({result["confidence"]:.1%})')
            for result, title in zip(results_list, titles)
        ]

        # squeeze=False always yields a 2-D array, even for a single subplot.
        fig, axes = plt.subplots(1, n_samples, figsize=(6*n_samples, 6), squeeze=False)
        fig.suptitle('Emotion Analysis Comparison', fontsize=18, fontweight='bold')
        for (probs, header), ax in zip(panels, axes[0]):
            bars = ax.bar(emotions, probs, color=colors, alpha=0.8,
                          edgecolor='black', linewidth=1)
            self._label_bars(ax, bars, probs, fontsize=10)
            ax.set_title(header)
            ax.set_ylim(0, 1)
            ax.tick_params(axis='x', rotation=45)
            ax.grid(axis='y', alpha=0.3, linestyle='--')
        fig.tight_layout()
        return self._finalize(fig, save_path, "emotion_comparison", show_chart)

    def create_detailed_analysis_chart(self,
                                       result: dict,
                                       text: str = "",
                                       save_path: Optional[Path] = None,
                                       show_chart: bool = True) -> Path:
        """
        Create a simplified detailed analysis chart with only the bar chart.

        The primary emotion shown is the argmax of ``result['probabilities']``.
        A warning is printed when the probabilities do not sum to 1 (+/-0.001).

        Args:
            result: Prediction result dictionary (reads 'probabilities')
            text: Input text (first 80 characters are shown in the title)
            save_path: Where to save the chart; defaults to a unique file in
                ``output_dir``
            show_chart: Whether to display the chart

        Returns:
            Path to saved chart image

        Raises:
            ValueError: If ``result['probabilities']`` is empty.
        """
        probabilities = result['probabilities']
        if not probabilities:
            raise ValueError("No probabilities provided")
        emotions = list(probabilities)
        probs = list(probabilities.values())

        # Verify probabilities sum to 100%
        total_prob = sum(probs)
        if abs(total_prob - 1.0) > 0.001:
            print(f"⚠️ Warning: Probabilities sum to {total_prob:.4f}, not 1.0")
        max_emotion = max(probabilities, key=probabilities.get)
        max_prob = probabilities[max_emotion]

        fig, ax = plt.subplots(figsize=(14, 8))
        bars = ax.bar(emotions, probs, color=self._colors_for(emotions), alpha=0.8,
                      edgecolor='black', linewidth=2, width=0.6)
        self._label_bars(ax, bars, probs, fontsize=14)

        ax.set_title(f'Emotion Analysis Results\n"{self._truncate(text, 80)}"',
                     fontsize=18, fontweight='bold', pad=20)
        ax.set_xlabel('Emotions', fontsize=16, fontweight='bold')
        ax.set_ylabel('Probability', fontsize=16, fontweight='bold')

        # Fixed 0-100% y-axis with percentage tick labels
        ax.set_ylim(0, 1.0)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
        plt.setp(ax.get_xticklabels(), rotation=0, fontsize=13, fontweight='bold')
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(axis='y', alpha=0.3, linestyle='--')

        ax.text(0.02, 0.98, f'Primary Emotion: {max_emotion.upper()} ({max_prob:.1%})',
                transform=ax.transAxes, fontsize=16, fontweight='bold',
                bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8),
                verticalalignment='top')
        # Show the probability total so a mis-normalized result is visible
        ax.text(0.98, 0.02, f'Total: {total_prob:.1%}',
                transform=ax.transAxes, fontsize=12,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7),
                horizontalalignment='right', verticalalignment='bottom')
        fig.tight_layout()
        return self._finalize(fig, save_path, "detailed_analysis", show_chart)
def create_quick_chart(probabilities: dict, text: str = "", show: bool = True) -> Path:
    """Render ``probabilities`` as a bar chart using a default-configured visualizer.

    Thin convenience wrapper: builds a fresh ``EmotionVisualizer`` and forwards
    to ``create_emotion_bar_chart``, mapping ``show`` onto ``show_chart``.
    Returns the path of the saved image.
    """
    return EmotionVisualizer().create_emotion_bar_chart(
        probabilities,
        text,
        show_chart=show,
    )
if __name__ == "__main__":
    # Demo: chart a hand-written probability distribution for one sentence.
    demo_probs = dict(
        joy=0.7,
        sadness=0.1,
        anger=0.08,
        fear=0.06,
        love=0.04,
        surprise=0.02,
    )
    demo_text = "I'm feeling so happy today, everything is wonderful!"
    saved_to = EmotionVisualizer().create_emotion_bar_chart(demo_probs, demo_text)
    print(f"📊 Chart saved to: {saved_to}")