# utils/visualization.py (updated for maximum contrast)
"""HTML and matplotlib renderings of token attributions and class probabilities.

Every public function returns an HTML fragment as a string and reports
failures as an HTML error message instead of raising.
"""
import base64
from html import escape
from io import BytesIO

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

# (threshold, colour) ladders, checked from darkest to lightest. The intensity
# fed into them is floored at 0.6, so in practice only the darker entries are
# selected; the full ladder is kept so the thresholds stay in one place.
_NEGATIVE_SHADES = ((0.8, "#cc0000"), (0.6, "#ff4444"), (0.4, "#ff8888"))
_NEGATIVE_LIGHTEST = "#ffcccc"
_POSITIVE_SHADES = ((0.8, "#0000cc"), (0.6, "#4444ff"), (0.4, "#8888ff"))
_POSITIVE_LIGHTEST = "#ccccff"

# Legend entries: (swatch colour, caption).
_LEGEND = (
    ("#cc0000", "Strong negative"),
    ("#ff4444", "Weak negative"),
    ("#0000cc", "Strong positive"),
    ("#4444ff", "Weak positive"),
)

_MAX_PLOTTED_FEATURES = 15
_PLOT_BACKGROUND = '#f8f9fa'


def _extract_token_values(explanation, explainer_type):
    """Return a ``{normalised token: attribution}`` dict for one explanation.

    LIME explanations are read as ``(feature, weight)`` pairs; a feature made
    of several whitespace-separated words has its weight split evenly between
    them. SHAP and Captum explanations are read as dicts carrying ``'token'``
    and ``'value'`` keys. Any other explainer type, or an empty explanation,
    yields an empty dict.
    """
    token_values = {}
    if not explanation:
        return token_values

    if explainer_type == "LIME":
        for feature, weight in explanation:
            words = feature.split()
            for word in words:
                key = word.strip('.,!?;:"()[]{}').lower()
                if key:
                    token_values[key] = weight / len(words)
    elif explainer_type in ("SHAP", "Captum"):
        for item in explanation:
            if isinstance(item, dict) and 'token' in item and 'value' in item:
                # Normalise exactly like the lookup in create_visualization
                # (strip the "##" marker, lower-case) so word pieces match.
                key = item['token'].replace('##', '').lower()
                token_values[key] = item['value']
    return token_values


def _contrast_colors(value, norm_value):
    """Pick ``(background, text)`` colours for one attribution value.

    Negative values map to reds and all others to blues; the shade depends on
    ``norm_value``, the value scaled by the largest absolute attribution.
    """
    intensity = min(1.0, 0.6 + 0.4 * abs(norm_value))
    if value < 0:
        shades, lightest = _NEGATIVE_SHADES, _NEGATIVE_LIGHTEST
    else:
        shades, lightest = _POSITIVE_SHADES, _POSITIVE_LIGHTEST
    background = next(
        (shade for threshold, shade in shades if intensity > threshold),
        lightest,
    )
    text_color = "white" if intensity > 0.5 else "black"
    return background, text_color


def _figure_to_img_tag(fig, alt_text):
    """Render ``fig`` to PNG and return it as an ``<img>`` tag with a data URI."""
    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
                facecolor=fig.get_facecolor())
    encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (
        f'<img src="data:image/png;base64,{encoded}" '
        f'alt="{escape(alt_text)}" style="max-width: 100%; height: auto;">'
    )


def create_visualization(text, explanation, tokenizer, explainer_type):
    """Create HTML visualization of token attributions with maximum contrast.

    Args:
        text: The input string; it is passed to ``tokenizer.tokenize``.
        explanation: Attribution data. For ``"LIME"`` an iterable of
            ``(feature, weight)`` pairs; for ``"SHAP"`` / ``"Captum"`` an
            iterable of dicts with ``'token'`` and ``'value'`` keys.
        tokenizer: Object exposing ``tokenize(text)`` that returns string
            tokens. A ``"##"`` marker inside a token is removed for display.
        explainer_type: ``"LIME"``, ``"SHAP"`` or ``"Captum"``. Anything else
            produces the neutral rendering.

    Returns:
        An HTML fragment. Tokens with an attribution are coloured red
        (negative) or blue (otherwise) and followed by a legend. When no
        attribution could be extracted, the plain tokens are shown with a
        notice. On any exception an HTML error message is returned.
    """
    try:
        tokens = tokenizer.tokenize(text)
        token_values = _extract_token_values(explanation, explainer_type)

        # No usable attribution data: show the tokens without colouring.
        if not token_values:
            parts = [
                '<div style="padding: 10px; line-height: 2;">',
                '<p><em>Explanation data not available. '
                'Showing tokenized text.</em></p>',
            ]
            parts.extend(
                f'<span style="padding: 2px 4px; margin: 1px;">'
                f'{escape(token.replace("##", ""))}</span> '
                for token in tokens
            )
            parts.append('</div>')
            return ''.join(parts)

        # Scale by the largest magnitude so the strongest token reaches +/-1.
        values = list(token_values.values())
        max_abs_value = max(abs(min(values)), abs(max(values)))
        if max_abs_value > 0:
            normalized_values = {
                key: val / max_abs_value for key, val in token_values.items()
            }
        else:
            normalized_values = dict.fromkeys(token_values, 0)

        parts = ['<div style="padding: 10px; line-height: 2.2;">']
        for token in tokens:
            display = token.replace('##', '')
            # Token text comes from user input, so escape it before embedding.
            label = escape(display)
            key = display.lower()
            if key in normalized_values:
                background, text_color = _contrast_colors(
                    token_values[key], normalized_values[key]
                )
                parts.append(
                    f'<span style="background-color: {background}; '
                    f'color: {text_color}; padding: 2px 4px; margin: 1px; '
                    f'border-radius: 3px;">{label}</span> '
                )
            else:
                parts.append(
                    f'<span style="padding: 2px 4px; margin: 1px;">'
                    f'{label}</span> '
                )
        parts.append('</div>')

        # Colour legend.
        parts.append('<div style="margin-top: 10px; font-size: 0.9em;">')
        parts.extend(
            f'<span style="display: inline-block; width: 12px; height: 12px; '
            f'margin: 0 4px 0 10px; background-color: {swatch};"></span>'
            f'{caption}'
            for swatch, caption in _LEGEND
        )
        parts.append('</div>')
        return ''.join(parts)

    except Exception as e:
        print(f"Visualization error: {e}")
        return (
            '<div style="color: red;">'
            f'Error creating visualization: {escape(str(e))}</div>'
        )


def create_attribution_plot(explanation, method_name):
    """Create matplotlib visualization of token attributions.

    Args:
        explanation: For ``method_name == "LIME"`` an iterable of
            ``(feature, weight)`` pairs; otherwise an iterable of dicts with
            ``'token'`` and ``'value'`` keys (entries lacking either key are
            skipped so labels and scores stay paired).
        method_name: Explainer name; selects the input format and appears in
            the plot title.

    Returns:
        An ``<img>`` tag embedding a horizontal bar chart of the first 15
        entries as a base64 PNG, an HTML notice when there is nothing to
        plot, or an HTML error message if plotting fails.
    """
    try:
        if not explanation:
            return '<div><p>No explanation data available</p></div>'

        if method_name == "LIME":
            pairs = [(item[0], item[1]) for item in explanation]
            title = f'Top Feature Attributions ({method_name})'
        else:
            pairs = [
                (item['token'], item['value'])
                for item in explanation
                if isinstance(item, dict) and 'token' in item and 'value' in item
            ]
            title = f'Top Token Attributions ({method_name})'
        pairs = pairs[:_MAX_PLOTTED_FEATURES]

        if not pairs:
            return (
                '<div><p>No valid explanation data available for plotting'
                '</p></div>'
            )

        features = [label for label, _ in pairs]
        scores = [score for _, score in pairs]

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            # Red for negative attributions, teal for the rest.
            colors = ['#ff6b6b' if score < 0 else '#4ecdc4' for score in scores]

            y_pos = np.arange(len(features))
            bars = ax.barh(y_pos, scores, color=colors, alpha=0.8,
                           edgecolor='black', linewidth=0.5)

            ax.set_yticks(y_pos)
            ax.set_yticklabels(features, fontsize=10)
            ax.set_xlabel('Attribution Score', fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axvline(x=0, color='black', linestyle='-', alpha=0.5, linewidth=1)
            ax.grid(True, alpha=0.3, axis='x')

            # Place each value label just beyond the end of its bar: to the
            # right of non-negative bars, to the left of negative ones.
            positive_pad = 0.01 * max(scores)
            negative_pad = 0.01 * min(scores)
            for bar, score in zip(bars, scores):
                width = bar.get_width()
                pad = positive_pad if width >= 0 else negative_pad
                ax.text(width + pad, bar.get_y() + bar.get_height() / 2,
                        f'{score:.4f}',
                        ha='left' if width >= 0 else 'right',
                        va='center', fontsize=9, fontweight='bold')

            ax.set_facecolor(_PLOT_BACKGROUND)
            fig.patch.set_facecolor(_PLOT_BACKGROUND)
            fig.tight_layout()

            return _figure_to_img_tag(fig, title)
        finally:
            # Always release the figure, even when drawing fails part-way.
            plt.close(fig)

    except Exception as e:
        print(f"Plot error: {e}")
        return (
            '<div style="color: red;">'
            f'Error creating plot: {escape(str(e))}</div>'
        )


# utils/visualization.py (add this function)
def create_confidence_chart(probabilities, class_names=None):
    """Create a bar chart showing class probabilities.

    Args:
        probabilities: Sequence of per-class probabilities; the y-axis is
            fixed to the range 0..1.
        class_names: Optional tick labels, one per probability. Defaults to
            ``"Class 0"``, ``"Class 1"``, ...

    Returns:
        An ``<img>`` tag embedding the chart as a base64 PNG, or an HTML
        error message if the chart could not be produced.
    """
    try:
        if class_names is None:
            class_names = [f"Class {i}" for i in range(len(probabilities))]

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            positions = range(len(probabilities))
            palette = ['#ff6b6b', '#4ecdc4', '#45b7af', '#556270']
            bars = ax.bar(positions, probabilities,
                          color=palette[:len(probabilities)],
                          alpha=0.8, edgecolor='black', linewidth=1)

            ax.set_xlabel('Classes', fontsize=12, fontweight='bold')
            ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
            ax.set_title('Class Probability Distribution', fontsize=14,
                         fontweight='bold')
            ax.set_xticks(positions)
            ax.set_xticklabels(class_names, rotation=45, ha='right')
            ax.set_ylim(0, 1)

            # Print each probability just above its bar.
            for bar, prob in zip(bars, probabilities):
                ax.text(bar.get_x() + bar.get_width() / 2,
                        bar.get_height() + 0.01,
                        f'{prob:.3f}', ha='center', va='bottom',
                        fontweight='bold')

            ax.grid(True, alpha=0.3, axis='y')
            ax.set_facecolor(_PLOT_BACKGROUND)
            fig.patch.set_facecolor(_PLOT_BACKGROUND)
            fig.tight_layout()

            return _figure_to_img_tag(fig, 'Class Probability Distribution')
        finally:
            # Always release the figure, even when drawing fails part-way.
            plt.close(fig)

    except Exception as e:
        print(f"Confidence chart error: {e}")
        return (
            '<div style="color: red;">'
            f'Error creating confidence chart: {escape(str(e))}</div>'
        )