Alon Albalak
major update: data saved to hf, user sessions maintain separation, fixed scoring bug
36d5e94
"""Statistics calculation and visualization functionality"""
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Dict, Optional, Tuple
class StatisticsCalculator:
    """Handles statistical calculations and visualization generation."""

    def get_score_statistics(
        self, scores: Optional[List[float]]
    ) -> Tuple[Optional[float], Optional[float]]:
        """Return the mean and standard deviation of ``scores``.

        Args:
            scores: Numeric scores (list, tuple, or 1-D NumPy array).

        Returns:
            ``(mean, std)`` from ``np.mean`` / ``np.std``. ``np.std`` uses its
            default ``ddof=0``, i.e. the population standard deviation.
            Returns ``(None, None)`` when ``scores`` is ``None`` or empty.
        """
        # Explicit length check instead of truthiness: `not array` raises
        # ValueError for NumPy arrays / pandas Series with more than one element.
        if scores is None or len(scores) == 0:
            return None, None
        return np.mean(scores), np.std(scores)

    def create_violin_plot(self, prompt_results: List[Dict], user_score: float,
                           user_tokens: int):
        """Create horizontal violin plots stacked vertically by token count.

        Args:
            prompt_results: Result dicts. Each must provide
                ``"cosine_distance"`` and ``"num_user_tokens"``.
            user_score: The current user's score. It is drawn as a red star on
                the row whose token count equals ``user_tokens``.
            user_tokens: The current user's token count. Its row is drawn in a
                distinct color.

        Returns:
            A matplotlib Figure with one row per token count (1-5) that has
            data. When no result has such a token count, the figure only
            shows a "No data available" message.
        """
        token_counts = [1, 2, 3, 4, 5]

        # Keep only the token counts that actually have scores.
        token_data = []
        for token_count in token_counts:
            token_scores = [r["cosine_distance"] for r in prompt_results
                            if r["num_user_tokens"] == token_count]
            if token_scores:
                token_data.append((token_count, token_scores))

        if not token_data:
            fig, ax = plt.subplots(figsize=(10, 4))
            ax.text(0.5, 0.5, 'No data available for visualization',
                    ha='center', va='center', transform=ax.transAxes, fontsize=14)
            ax.set_title('Score Distribution by Token Count', fontsize=14, fontweight='bold')
            return fig

        n_rows = len(token_data)
        # Reserve a fixed headroom (in inches) above the first row for the
        # suptitle. A purely fractional top margin shrinks with the figure, so
        # with only one or two rows the title overlapped the first violin.
        title_band_in = 0.6
        fig_height = 1.0 * n_rows + title_band_in
        # squeeze=False always yields a 2-D array, so a single row needs no
        # special case.
        fig, axes_grid = plt.subplots(n_rows, 1, figsize=(10, fig_height),
                                      sharex=True, squeeze=False)
        axes = axes_grid[:, 0]

        for ax, (token_count, scores) in zip(axes, token_data):
            is_user_row = token_count == user_tokens

            # Horizontal violin centred at y=0.
            parts = ax.violinplot([scores], positions=[0], vert=False,
                                  showmeans=True, showextrema=True)

            # Reuse the existing gradient colors; the user's row is highlighted.
            color = '#667eea' if is_user_row else '#764ba2'
            for body in parts['bodies']:
                body.set_facecolor(color)
                body.set_alpha(0.7)
                body.set_edgecolor('black')
                body.set_linewidth(1)

            # Mark the user's own score on their token-count row.
            if is_user_row:
                ax.scatter(user_score, 0, color='red', s=150, zorder=5,
                           marker='*', label=f'Your Score: {user_score:.3f}')
                ax.legend(loc='upper right')

            # Per-row styling: the y-label carries the token count and sample size.
            ax.set_ylabel(f'{token_count} token{"s" if token_count != 1 else ""}\n(n={len(scores)})',
                          fontsize=11, fontweight='bold')
            ax.set_yticks([])
            ax.grid(True, alpha=0.3, axis='x')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(False)
            # Identical y-limits keep the violins visually comparable.
            ax.set_ylim(-0.4, 0.4)

        # The x-axis is shared, so only the bottom row gets the label.
        axes[-1].set_xlabel('Creativity Score (Cosine Distance)', fontsize=12)

        fig.suptitle('Score Distribution by Token Count', fontsize=14, fontweight='bold', y=0.98)
        # Use this Figure's own layout methods instead of pyplot's global
        # "current figure". plt.tight_layout() acts on plt.gcf(), which need
        # not be this figure when several sessions build plots concurrently.
        fig.tight_layout()
        fig.subplots_adjust(top=1 - title_band_in / fig_height)
        return fig

    def calculate_session_ranking_stats(self, session_results: List[Dict],
                                        data_manager, scorer) -> Dict:
        """Summarize how a session's attempts rank against all stored results.

        For each attempt, the stored results for the same prompt and LLM
        partial response are obtained via
        ``data_manager.filter_results_by_partial_response``. The attempt's
        ``cosine_distance`` is then ranked with
        ``scorer.calculate_rank_and_percentile``, which is given the
        attempt's ``num_user_tokens``.

        Args:
            session_results: Result dicts for this session. Each needs
                ``"prompt"``, ``"llm_partial_response"``,
                ``"cosine_distance"``, ``"num_user_tokens"`` and
                ``"timestamp"``.
            data_manager: Provides ``get_results()`` and
                ``filter_results_by_partial_response(results, prompt, partial)``.
            scorer: Provides
                ``calculate_rank_and_percentile(score, results, num_tokens)``,
                which returns ``(rank, percentile)``.

        Returns:
            A dict with these keys:

            - ``best_rank``: smallest rank among ranked attempts, or None.
            - ``best_percentile``: largest percentile, or None.
            - ``average_percentile``: mean percentile; 0.0 if nothing was ranked.
            - ``total_ranked_attempts``: number of attempts that were ranked.
            - ``ranking_trend``: ``"up"``, ``"down"`` or ``"stable"``. With at
              least 4 ranked attempts, the mean percentile of the second half
              is compared to the first half, using a 10-point dead band.
            - ``recent_percentiles``: one entry per ranked attempt, in session
              order, with ``percentile``, ``rank``, ``total`` (number of stored
              results with the attempt's token count) and ``timestamp``.
        """
        ranking_stats = {
            "best_rank": None,
            "best_percentile": None,
            "average_percentile": 0.0,
            "total_ranked_attempts": 0,
            "ranking_trend": "stable",  # up, down, stable
            "recent_percentiles": []
        }

        if not session_results:
            return ranking_stats

        # Fetched only once we know there is something to rank.
        all_results = data_manager.get_results()

        percentiles = []
        ranks = []

        for result in session_results:
            # All stored results for this specific prompt / partial response.
            prompt_results = data_manager.filter_results_by_partial_response(
                all_results, result["prompt"], result["llm_partial_response"]
            )
            token_count = result["num_user_tokens"]

            # The scorer is given the token count, and "total" counts only
            # results with that token count. The "need at least 2 results to
            # rank" guard must therefore apply to that same group, not to every
            # token count for the prompt.
            same_token_total = sum(
                1 for r in prompt_results if r["num_user_tokens"] == token_count
            )
            if same_token_total < 2:
                continue

            rank, percentile = scorer.calculate_rank_and_percentile(
                result["cosine_distance"], prompt_results, token_count
            )
            # Truthiness check on rank kept as before (skips None and 0).
            if not rank or percentile is None:
                continue

            percentiles.append(percentile)
            ranks.append(rank)
            ranking_stats["recent_percentiles"].append({
                "percentile": percentile,
                "rank": rank,
                "total": same_token_total,
                "timestamp": result["timestamp"]
            })

        if not percentiles:
            return ranking_stats

        ranking_stats["total_ranked_attempts"] = len(percentiles)
        ranking_stats["average_percentile"] = sum(percentiles) / len(percentiles)
        # Higher percentile is better; a lower rank number is better.
        ranking_stats["best_percentile"] = max(percentiles)
        ranking_stats["best_rank"] = min(ranks)

        # Trend: compare the mean percentile of the first half with the second.
        if len(percentiles) >= 4:
            mid_point = len(percentiles) // 2
            first_half_avg = sum(percentiles[:mid_point]) / mid_point
            second_half_avg = sum(percentiles[mid_point:]) / (len(percentiles) - mid_point)
            if second_half_avg > first_half_avg + 10:
                ranking_stats["ranking_trend"] = "up"
            elif second_half_avg < first_half_avg - 10:
                ranking_stats["ranking_trend"] = "down"

        return ranking_stats