# utils/visualization.py (updated for maximum contrast)
import base64
import html
from io import BytesIO

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
def create_visualization(text, explanation, tokenizer, explainer_type):
    """Create an HTML visualization of token attributions with maximum contrast.

    Args:
        text: Raw input text; it is re-tokenized with ``tokenizer.tokenize``.
        explanation: For ``"LIME"``, an iterable of ``(feature, weight)`` pairs
            whose feature strings may contain several words. For ``"SHAP"`` and
            ``"Captum"``, an iterable of dicts with ``'token'`` and ``'value'``
            keys; items of any other shape are ignored.
        tokenizer: Object exposing ``tokenize(text)`` returning a list of token
            strings. WordPiece ``##`` markers are stripped for display and for
            matching tokens against the explanation.
        explainer_type: ``"LIME"``, ``"SHAP"`` or ``"Captum"``.

    Returns:
        An HTML string. Tokens with a non-zero attribution are highlighted red
        (negative) or blue (positive), in a darker shade when the magnitude
        exceeds half of the largest absolute attribution, followed by a colour
        legend. When no attribution data can be extracted, the plain tokenized
        text is returned. On any failure, a short HTML error message is returned
        instead of raising. All user-provided text is HTML-escaped.
    """
    try:
        tokens = tokenizer.tokenize(text)
        token_values = _collect_token_values(explanation, explainer_type)

        # Without attribution data, fall back to showing the plain tokens.
        if not token_values:
            return _render_plain_tokens(tokens)

        # Colour strength is relative to the largest absolute attribution.
        max_abs_value = max(abs(v) for v in token_values.values())

        parts = ['<div style="font-family: sans-serif; line-height: 2.2; font-size: 16px;">']
        for token in tokens:
            display = html.escape(token.replace("##", ""))
            value = token_values.get(_normalize_token(token))
            colors = None if value is None else _attribution_colors(value, max_abs_value)
            if colors is None:
                parts.append(f"{display} ")
                continue
            background, foreground = colors
            parts.append(
                f'<span style="background-color: {background}; color: {foreground}; '
                f'padding: 2px 4px; margin: 1px; border-radius: 3px;" '
                f'title="{value:.4f}">{display}</span> '
            )
        parts.append("</div>")

        # Colour legend.
        parts.append('<div style="font-family: sans-serif; margin-top: 12px; font-size: 13px;">')
        for swatch, label in _ATTRIBUTION_LEGEND:
            parts.append(
                f'<span style="display: inline-block; width: 14px; height: 14px; '
                f'background: {swatch}; vertical-align: middle; margin: 0 4px 0 12px;"></span>'
                f'{label}'
            )
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        print(f"Visualization error: {e}")
        return f'<p style="color: #cc0000;">Error creating visualization: {html.escape(str(e))}</p>'


# (swatch colour, label) pairs shown in the legend; the colours match _attribution_colors.
_ATTRIBUTION_LEGEND = (
    ("#cc0000", "Strong negative"),
    ("#ff4444", "Weak negative"),
    ("#0000cc", "Strong positive"),
    ("#4444ff", "Weak positive"),
)


def _normalize_token(token):
    """Return ``token`` without WordPiece ``##`` markers, lower-cased, for matching."""
    return str(token).replace("##", "").lower()


def _collect_token_values(explanation, explainer_type):
    """Map normalized token text to its float attribution score for the given explainer."""
    token_values = {}
    if explanation is None:
        return token_values
    if explainer_type == "LIME":
        # A LIME feature may hold several words; split its weight evenly among them.
        for feature, weight in explanation:
            words = str(feature).split()
            for word in words:
                key = _normalize_token(word.strip('.,!?;:"()[]{}'))
                if key:
                    token_values[key] = float(weight) / len(words)
    elif explainer_type in ("SHAP", "Captum"):
        for item in explanation:
            if isinstance(item, dict) and 'token' in item and 'value' in item:
                # Normalize exactly as the rendering lookup does so "##" subwords match.
                key = _normalize_token(item['token'])
                if key:
                    token_values[key] = float(item['value'])
    return token_values


def _attribution_colors(value, max_abs_value):
    """Return ``(background, text)`` colours for an attribution, or None when it is zero.

    Negative scores map to red and positive scores to blue. A score whose magnitude
    exceeds half of ``max_abs_value`` gets the darker "strong" shade.
    """
    if value == 0 or max_abs_value <= 0:
        return None
    strong = abs(value) / max_abs_value > 0.5
    if value < 0:
        return ("#cc0000" if strong else "#ff4444"), "white"
    return ("#0000cc" if strong else "#4444ff"), "white"


def _render_plain_tokens(tokens):
    """Return the tokenized text as unstyled, escaped HTML with a short notice."""
    words = " ".join(html.escape(token.replace("##", "")) for token in tokens)
    return (
        '<div style="font-family: sans-serif;">'
        '<p style="color: #666666;">Explanation data not available. Showing tokenized text.</p>'
        f'<div style="line-height: 2.2; font-size: 16px;">{words}</div>'
        '</div>'
    )
def create_attribution_plot(explanation, method_name, top_k=15):
    """Render the strongest attributions as a horizontal bar chart inside an HTML ``<img>``.

    Args:
        explanation: For ``"LIME"``, a sequence of ``(feature, weight)`` pairs.
            For any other ``method_name`` (SHAP/Captum), a sequence of dicts with
            ``'token'`` and ``'value'`` keys; items missing either key are skipped,
            so tokens and scores always stay paired.
        method_name: Explainer name; selects the input format and appears in the title.
        top_k: Maximum number of bars (default 15). The entries with the largest
            absolute scores are kept, and the strongest is drawn at the top.
            ``None`` keeps every entry.

    Returns:
        An HTML ``<img>`` tag embedding the chart as a base64 PNG, or a short HTML
        message when there is nothing to plot or plotting fails.
    """
    try:
        if not explanation:
            return "<p>No explanation data available</p>"

        if method_name == "LIME":
            pairs = [(item[0], float(item[1])) for item in explanation]
            title = f'Top Feature Attributions ({method_name})'
        else:
            # Filter on both keys at once so a token is never paired with another item's score.
            pairs = [
                (item['token'], float(item['value']))
                for item in explanation
                if isinstance(item, dict) and 'token' in item and 'value' in item
            ]
            title = f'Top Token Attributions ({method_name})'

        # "Top" means largest magnitude; the sort is stable, so pre-sorted input keeps its order.
        pairs = sorted(pairs, key=lambda pair: abs(pair[1]), reverse=True)[:top_k]
        if not pairs:
            return "<p>No valid explanation data available for plotting</p>"

        features = [str(feature) for feature, _ in pairs]
        scores = [score for _, score in pairs]

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            colors = ['#ff6b6b' if score < 0 else '#4ecdc4' for score in scores]
            y_pos = np.arange(len(features))
            bars = ax.barh(y_pos, scores, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)

            ax.set_yticks(y_pos)
            ax.set_yticklabels(features, fontsize=10)
            ax.invert_yaxis()  # strongest attribution on top
            ax.set_xlabel('Attribution Score', fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axvline(x=0, color='black', linestyle='-', alpha=0.5, linewidth=1)
            ax.grid(True, alpha=0.3, axis='x')

            # Offset value labels away from the bar end, scaled by the largest magnitude.
            pad = 0.01 * max(abs(score) for score in scores)
            for bar, score in zip(bars, scores):
                width = bar.get_width()
                label_x_pos = width + pad if width >= 0 else width - pad
                ax.text(label_x_pos, bar.get_y() + bar.get_height() / 2,
                        f'{score:.4f}', ha='left' if width >= 0 else 'right', va='center',
                        fontsize=9, fontweight='bold')

            ax.set_facecolor('#f8f9fa')
            fig.patch.set_facecolor('#f8f9fa')
            fig.tight_layout()

            # Save this specific figure, not whatever pyplot considers "current".
            buf = BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
            img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        finally:
            plt.close(fig)  # release the figure even if drawing fails

        return (f'<img src="data:image/png;base64,{img_str}" alt="{html.escape(title)}" '
                f'style="max-width: 100%; height: auto;"/>')
    except Exception as e:
        print(f"Plot error: {e}")
        return f'<p style="color: #cc0000;">Error creating plot: {html.escape(str(e))}</p>'
# Class-probability (model confidence) bar chart.
def create_confidence_chart(probabilities, class_names=None):
    """Render class probabilities as a bar chart embedded in an HTML ``<img>``.

    Args:
        probabilities: Iterable of per-class probabilities, each convertible with
            ``float``. They are plotted against a fixed 0-1 probability scale.
        class_names: Optional labels, one per class. Missing labels default to
            ``"Class {i}"`` and surplus labels are ignored, so tick labels always
            match the number of bars.

    Returns:
        An HTML ``<img>`` tag embedding the chart as a base64 PNG, or a short HTML
        message when there are no probabilities or rendering fails.
    """
    try:
        probs = [float(p) for p in probabilities]
        if not probs:
            return "<p>No probability data available</p>"

        names = [] if class_names is None else [str(name) for name in class_names][:len(probs)]
        names += [f"Class {i}" for i in range(len(names), len(probs))]

        # Cycle the palette explicitly so any number of classes gets a colour.
        palette = ['#ff6b6b', '#4ecdc4', '#45b7af', '#556270']
        colors = [palette[i % len(palette)] for i in range(len(probs))]
        positions = range(len(probs))

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            bars = ax.bar(positions, probs, color=colors, alpha=0.8, edgecolor='black', linewidth=1)

            ax.set_xlabel('Classes', fontsize=12, fontweight='bold')
            ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
            ax.set_title('Class Probability Distribution', fontsize=14, fontweight='bold')
            ax.set_xticks(positions)
            ax.set_xticklabels(names, rotation=45, ha='right')
            # Head-room above 1.0 keeps labels of near-certain classes inside the axes;
            # ticks stay on the 0-1 probability scale.
            ax.set_ylim(0, 1.1)
            ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

            for bar, prob in zip(bars, probs):
                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                        f'{prob:.3f}', ha='center', va='bottom', fontweight='bold')

            ax.grid(True, alpha=0.3, axis='y')
            ax.set_facecolor('#f8f9fa')
            fig.patch.set_facecolor('#f8f9fa')
            fig.tight_layout()

            # Save this specific figure, not whatever pyplot considers "current".
            buf = BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
            img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        finally:
            plt.close(fig)  # release the figure even if drawing fails

        return (f'<img src="data:image/png;base64,{img_str}" alt="Class probability distribution" '
                f'style="max-width: 100%; height: auto;"/>')
    except Exception as e:
        print(f"Confidence chart error: {e}")
        return f'<p style="color: #cc0000;">Error creating confidence chart: {html.escape(str(e))}</p>'