Spaces:
Sleeping
Sleeping
| """Visualization module for emotion analysis results.""" | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import Dict, Optional | |
| from datetime import datetime | |
class EmotionVisualizer:
    """Generate matplotlib bar charts for emotion-classifier output.

    Every public ``create_*`` method draws one figure and saves it as a PNG
    (300 dpi, white background). It then either shows the figure with
    ``plt.show()`` or closes it. When no ``save_path`` is given, the file is
    written into ``output_dir`` under a timestamped name.
    """

    # Color used for any emotion label missing from ``emotion_colors``.
    _FALLBACK_COLOR = '#808080'

    def __init__(self, output_dir: Optional[Path] = None):
        """
        Initialize the visualizer.

        Args:
            output_dir: Directory to save visualization files (``Path`` or
                string). Defaults to ``visualizations`` relative to the
                current working directory. Missing parent directories are
                created.
        """
        self.output_dir = Path(output_dir) if output_dir else Path("visualizations")
        # parents=True so a nested output directory works on first use.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Global pyplot / seaborn styling (affects every figure in the process).
        plt.style.use('default')
        sns.set_palette("husl")

        # Fixed per-emotion colors so an emotion looks the same in every chart.
        self.emotion_colors = {
            'joy': '#FFD700',       # Gold
            'sadness': '#4169E1',   # Royal Blue
            'anger': '#DC143C',     # Crimson
            'fear': '#9370DB',      # Medium Purple
            'love': '#FF69B4',      # Hot Pink
            'surprise': '#FF8C00',  # Dark Orange
        }

    def create_emotion_bar_chart(self,
                                 probabilities: Dict[str, float],
                                 text: str = "",
                                 save_path: Optional[Path] = None,
                                 show_chart: bool = True,
                                 primary_emotion: Optional[str] = None) -> Path:
        """
        Create a bar chart showing emotion probabilities.

        Args:
            probabilities: Dict of emotion -> probability, plotted in dict order.
            text: Input text; its first 60 characters appear in the title.
            save_path: Where to save the chart. Defaults to a timestamped
                ``emotion_analysis_*.png`` inside ``output_dir``.
            show_chart: Call ``plt.show()`` after saving; otherwise the
                figure is closed.
            primary_emotion: Explicit primary emotion to override the argmax.
                It must be a key of ``probabilities``.

        Returns:
            Path to the saved chart image.

        Raises:
            ValueError: If ``probabilities`` is empty.
            KeyError: If ``primary_emotion`` is not a key of ``probabilities``.
        """
        if not probabilities:
            raise ValueError("No probabilities provided for the chart")

        emotions = list(probabilities.keys())
        probs = list(probabilities.values())
        # Resolve the primary emotion before creating the figure, so a bad
        # key fails without leaving a half-built figure behind.
        if primary_emotion is None:
            primary_emotion = max(probabilities, key=probabilities.get)
        primary_prob = probabilities[primary_emotion]

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            bars = ax.bar(emotions, probs, color=self._colors_for(emotions),
                          alpha=0.8, edgecolor='black', linewidth=1)
            ax.set_title(f'Emotion Analysis Results\n"{self._title_snippet(text, 60)}"',
                         fontsize=16, fontweight='bold', pad=20)
            ax.set_xlabel('Emotions', fontsize=14, fontweight='bold')
            ax.set_ylabel('Probability', fontsize=14, fontweight='bold')
            self._label_bars(ax, bars, probs, fontsize=12)

            # 20% headroom above the tallest bar for its percentage label.
            # Fall back to [0, 1] when every probability is 0, because a
            # zero-height range is singular.
            ax.set_ylim(0, max(probs) * 1.2 or 1.0)
            ax.tick_params(axis='x', labelrotation=45, labelsize=12)
            ax.tick_params(axis='y', labelsize=12)
            ax.grid(axis='y', alpha=0.3, linestyle='--')

            ax.text(0.02, 0.98, f'Primary: {primary_emotion.upper()} ({primary_prob:.1%})',
                    transform=ax.transAxes, fontsize=14, fontweight='bold',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7),
                    verticalalignment='top')

            save_path = self._save_figure(fig, save_path, "emotion_analysis")
        except Exception:
            plt.close(fig)  # don't leak the figure in pyplot's registry
            raise
        self._show_or_close(fig, show_chart)
        return save_path

    def create_comparison_chart(self,
                                results_list: list,
                                titles: list = None,
                                save_path: Optional[Path] = None,
                                show_chart: bool = True) -> Path:
        """
        Create a side-by-side comparison chart for multiple predictions.

        Args:
            results_list: Prediction results. Each must provide
                ``'probabilities'`` (emotion -> probability), ``'emotion'``
                and ``'confidence'``.
            titles: Subplot titles. Missing entries (``None`` or a list
                shorter than ``results_list``) default to ``"Sample N"``.
            save_path: Where to save the chart. Defaults to a timestamped
                ``emotion_comparison_*.png`` inside ``output_dir``.
            show_chart: Call ``plt.show()`` after saving; otherwise the
                figure is closed.

        Returns:
            Path to the saved chart image.

        Raises:
            ValueError: If ``results_list`` is empty.
        """
        if not results_list:
            raise ValueError("No results provided for comparison")

        # The first result fixes the x-axis categories. Later results are
        # aligned by emotion name (an absent emotion plots as 0), so a
        # different dict order can't put a bar under the wrong label.
        emotions = list(results_list[0]['probabilities'].keys())
        colors = self._colors_for(emotions)
        n_samples = len(results_list)
        titles = list(titles or [])
        titles += [f"Sample {i+1}" for i in range(len(titles), n_samples)]

        # squeeze=False always yields a 2-D array, even for a single subplot.
        fig, axes = plt.subplots(1, n_samples, figsize=(6 * n_samples, 6), squeeze=False)
        try:
            fig.suptitle('Emotion Analysis Comparison', fontsize=18, fontweight='bold')
            for result, title, ax in zip(results_list, titles, axes[0]):
                probs = [result['probabilities'].get(emo, 0.0) for emo in emotions]
                bars = ax.bar(emotions, probs, color=colors, alpha=0.8,
                              edgecolor='black', linewidth=1)
                self._label_bars(ax, bars, probs, fontsize=10)
                ax.set_title(f'{title}\n🎯 {result["emotion"].upper()} ({result["confidence"]:.1%})')
                ax.set_ylim(0, 1)
                ax.tick_params(axis='x', rotation=45)
                ax.grid(axis='y', alpha=0.3, linestyle='--')

            save_path = self._save_figure(fig, save_path, "emotion_comparison")
        except Exception:
            plt.close(fig)  # don't leak the figure in pyplot's registry
            raise
        self._show_or_close(fig, show_chart)
        return save_path

    def create_detailed_analysis_chart(self,
                                       result: dict,
                                       text: str = "",
                                       save_path: Optional[Path] = None,
                                       show_chart: bool = True) -> Path:
        """
        Create a detailed single bar chart with a fixed 0-100% y-axis.

        A warning is printed when the probabilities do not sum to 1 (within
        0.001). The displayed total is annotated in the lower-right corner.

        Args:
            result: Prediction result containing ``'probabilities'``
                (emotion -> probability).
            text: Input text; its first 80 characters appear in the title.
            save_path: Where to save the chart. Defaults to a timestamped
                ``detailed_analysis_*.png`` inside ``output_dir``.
            show_chart: Call ``plt.show()`` after saving; otherwise the
                figure is closed.

        Returns:
            Path to the saved chart image.

        Raises:
            ValueError: If ``result['probabilities']`` is empty.
        """
        probabilities = result['probabilities']
        if not probabilities:
            raise ValueError("No probabilities provided for the chart")

        emotions = list(probabilities.keys())
        probs = list(probabilities.values())

        # Verify probabilities sum to 100%
        total_prob = sum(probs)
        if abs(total_prob - 1.0) > 0.001:
            print(f"⚠️ Warning: Probabilities sum to {total_prob:.4f}, not 1.0")

        max_emotion = max(probabilities, key=probabilities.get)
        max_prob = probabilities[max_emotion]

        fig, ax = plt.subplots(figsize=(14, 8))
        try:
            bars = ax.bar(emotions, probs, color=self._colors_for(emotions), alpha=0.8,
                          edgecolor='black', linewidth=2, width=0.6)
            self._label_bars(ax, bars, probs, fontsize=14)

            ax.set_title(f'Emotion Analysis Results\n"{self._title_snippet(text, 80)}"',
                         fontsize=18, fontweight='bold', pad=20)
            ax.set_xlabel('Emotions', fontsize=16, fontweight='bold')
            ax.set_ylabel('Probability', fontsize=16, fontweight='bold')

            # Fixed 0-100% axis so charts for different inputs are comparable.
            ax.set_ylim(0, 1.0)
            ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
            ax.tick_params(axis='x', labelrotation=0, labelsize=13)
            plt.setp(ax.get_xticklabels(), fontweight='bold')
            ax.tick_params(axis='y', labelsize=12)
            ax.grid(axis='y', alpha=0.3, linestyle='--')

            ax.text(0.02, 0.98, f'Primary Emotion: {max_emotion.upper()} ({max_prob:.1%})',
                    transform=ax.transAxes, fontsize=16, fontweight='bold',
                    bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8),
                    verticalalignment='top')
            ax.text(0.98, 0.02, f'Total: {total_prob:.1%}',
                    transform=ax.transAxes, fontsize=12,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7),
                    horizontalalignment='right', verticalalignment='bottom')

            save_path = self._save_figure(fig, save_path, "detailed_analysis")
        except Exception:
            plt.close(fig)  # don't leak the figure in pyplot's registry
            raise
        self._show_or_close(fig, show_chart)
        return save_path

    # ------------------------------------------------------------------ helpers

    def _colors_for(self, emotions: list) -> list:
        """Return the configured color for each emotion (gray if unknown)."""
        return [self.emotion_colors.get(emo, self._FALLBACK_COLOR) for emo in emotions]

    @staticmethod
    def _title_snippet(text: str, limit: int) -> str:
        """Truncate *text* to *limit* chars (adding '...') and escape '$'.

        Matplotlib treats text with paired '$' as mathtext. Escaping keeps
        arbitrary user input rendering literally instead of failing to parse.
        """
        text = text or ""
        snippet = text[:limit] + ("..." if len(text) > limit else "")
        return snippet.replace("$", r"\$")

    @staticmethod
    def _label_bars(ax, bars, probs: list, fontsize: int) -> None:
        """Write each probability as a percentage just above its bar."""
        for bar, prob in zip(bars, probs):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{prob:.1%}', ha='center', va='bottom',
                    fontweight='bold', fontsize=fontsize)

    def _resolve_save_path(self, save_path: Optional[Path], prefix: str) -> Path:
        """Return *save_path*, or a non-existing timestamped PNG in ``output_dir``.

        The name is ``{prefix}_{YYYYmmdd_HHMMSS}.png``. If that file already
        exists (two charts in the same second), a ``_1``, ``_2``, ... suffix
        is added instead of overwriting it.
        """
        if save_path is not None:
            return save_path
        self.output_dir.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        candidate = self.output_dir / f"{prefix}_{stamp}.png"
        counter = 1
        while candidate.exists():
            candidate = self.output_dir / f"{prefix}_{stamp}_{counter}.png"
            counter += 1
        return candidate

    def _save_figure(self, fig, save_path: Optional[Path], prefix: str) -> Path:
        """Lay out and save *fig*; return the path that was written."""
        fig.tight_layout()
        save_path = self._resolve_save_path(save_path, prefix)
        fig.savefig(save_path, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
        return save_path

    @staticmethod
    def _show_or_close(fig, show_chart: bool) -> None:
        """Display the figure via ``plt.show()``, or close it to free memory."""
        if show_chart:
            plt.show()
        else:
            plt.close(fig)
def create_quick_chart(probabilities: dict, text: str = "", show: bool = True) -> Path:
    """Render a one-off emotion bar chart using a default-configured visualizer.

    Args:
        probabilities: Mapping of emotion label -> probability.
        text: Input text shown (truncated) in the chart title.
        show: Forwarded as ``show_chart``. Display the chart after saving.

    Returns:
        Path of the saved chart image.
    """
    return EmotionVisualizer().create_emotion_bar_chart(
        probabilities,
        text,
        show_chart=show,
    )
if __name__ == "__main__":
    # Demo: chart the (hard-coded) probabilities for a clearly joyful sentence.
    demo_text = "I'm feeling so happy today, everything is wonderful!"
    demo_probabilities = dict(
        joy=0.7,
        sadness=0.1,
        anger=0.08,
        fear=0.06,
        love=0.04,
        surprise=0.02,
    )
    demo_path = EmotionVisualizer().create_emotion_bar_chart(demo_probabilities, demo_text)
    print(f"📊 Chart saved to: {demo_path}")