File size: 11,145 Bytes
2850cb6
0bf6b56
 
 
 
 
 
442107b
2850cb6
86ef0cd
 
 
 
 
442107b
86ef0cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e700f5c
86ef0cd
 
 
 
442107b
 
0bf6b56
442107b
0bf6b56
442107b
0bf6b56
86ef0cd
 
 
 
 
 
 
2850cb6
86ef0cd
 
0bf6b56
86ef0cd
 
 
 
2850cb6
86ef0cd
2850cb6
 
 
 
 
 
 
 
 
 
 
86ef0cd
2850cb6
 
 
 
 
 
 
 
 
 
 
86ef0cd
2850cb6
0bf6b56
e700f5c
86ef0cd
 
 
e700f5c
 
 
2850cb6
 
 
 
e700f5c
 
 
86ef0cd
0bf6b56
86ef0cd
 
 
0bf6b56
 
2850cb6
86ef0cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e700f5c
86ef0cd
 
2850cb6
 
86ef0cd
 
 
e700f5c
86ef0cd
 
 
e700f5c
 
 
 
 
 
 
86ef0cd
 
 
 
 
 
e700f5c
 
 
 
 
 
86ef0cd
 
 
 
 
e700f5c
86ef0cd
 
 
 
e700f5c
0bf6b56
86ef0cd
 
fb007f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# utils/visualization.py (updated for maximum contrast)
import base64
import html
from io import BytesIO

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

def _collect_token_values(explanation, explainer_type):
    """Return a ``{lookup_key: attribution}`` dict built from *explanation*.

    LIME explanations are ``(feature, weight)`` pairs; a multi-word feature
    spreads its weight evenly over its words, which are stripped of surrounding
    punctuation and lower-cased. SHAP/Captum explanations are dicts with
    ``'token'`` and ``'value'`` keys; their tokens are lower-cased and have
    WordPiece ``##`` markers removed so they match the key used for
    tokenizer output in ``create_visualization``.
    """
    token_values = {}
    if not explanation:
        return token_values
    if explainer_type == "LIME":
        for feature, weight in explanation:
            feature_tokens = feature.split()
            for token in feature_tokens:
                clean_token = token.strip('.,!?;:"()[]{}').lower()
                if clean_token:
                    token_values[clean_token] = weight / len(feature_tokens)
    elif explainer_type in ("SHAP", "Captum"):
        for item in explanation:
            if isinstance(item, dict) and 'token' in item and 'value' in item:
                # Strip '##' exactly like the lookup side, or subwords never match
                clean_token = item['token'].replace('##', '').lower()
                if clean_token:
                    token_values[clean_token] = item['value']
    return token_values


def _attribution_colors(value, strength):
    """Return ``(background_color, text_color)`` for one attributed token.

    Args:
        value: Raw attribution; negative values are drawn red, others blue.
        strength: ``abs(value)`` divided by the largest absolute attribution,
            expected in [0, 1]; selects one of four shades.
    """
    if value < 0:
        shades = ("#cc0000", "#ff4444", "#ff8888", "#ffcccc")  # very dark -> light red
    else:
        shades = ("#0000cc", "#4444ff", "#8888ff", "#ccccff")  # very dark -> light blue
    if strength > 0.8:
        index = 0
    elif strength > 0.6:
        index = 1
    elif strength > 0.4:
        index = 2
    else:
        index = 3
    # White text only on the two darkest shades; black is more legible on the light ones
    text_color = "white" if index < 2 else "black"
    return shades[index], text_color


def _neutral_token_span(label):
    """Return the grey ``<span>`` used for a token with no attribution (*label* must be escaped)."""
    return (
        '<span style="margin: 2px; padding: 4px 6px; display: inline-block; '
        'background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 4px;">'
        f'{label}</span> '
    )


def _attributed_token_span(label, color, text_color):
    """Return the colored ``<span>`` for an attributed token (*label* must be escaped)."""
    return (
        f'<span style="background-color: {color}; color: {text_color}; '
        f'border: 2px solid {color}; margin: 2px; padding: 4px 6px; '
        'border-radius: 4px; display: inline-block; font-weight: bold;">'
        f'{label}</span> '
    )


def create_visualization(text, explanation, tokenizer, explainer_type):
    """Create HTML visualization of token attributions with maximum contrast.

    Args:
        text: The input text; it is split with ``tokenizer.tokenize``.
        explanation: For ``"LIME"``, a sequence of ``(feature, weight)`` pairs;
            for ``"SHAP"``/``"Captum"``, a sequence of dicts with ``'token'``
            and ``'value'`` keys. Falsy values produce a neutral rendering.
        tokenizer: Any object with a ``tokenize(text)`` method returning strings.
        explainer_type: ``"LIME"``, ``"SHAP"`` or ``"Captum"``.

    Returns:
        An HTML string. Token text is HTML-escaped. Tokens with an attribution
        are shaded red (negative) or blue (non-negative) by relative magnitude,
        followed by a color legend. If no attribution data is usable, a plain
        token listing is returned. On any exception, the error is printed and a
        red error ``<div>`` is returned.
    """
    try:
        tokens = tokenizer.tokenize(text)
        token_values = _collect_token_values(explanation, explainer_type)

        # If no explanation data, create a neutral visualization
        if not token_values:
            html_output = '''
            <div style="font-family: monospace; line-height: 2; padding: 15px; 
                        border-radius: 5px; background-color: #f9f9f9; 
                        border: 1px solid #ddd; margin: 10px 0; color: #666;">
            <i>Explanation data not available. Showing tokenized text.</i><br>
            '''
            html_output += ''.join(
                _neutral_token_span(html.escape(token.replace("##", ""))) for token in tokens
            )
            html_output += '</div>'
            return html_output

        # Scale by the largest magnitude so the strongest token has strength 1.0
        max_abs_value = max(abs(v) for v in token_values.values())

        parts = ['''
        <div style="font-family: monospace; line-height: 2; padding: 15px; 
                    border-radius: 5px; background-color: #f9f9f9; 
                    border: 1px solid #ddd; margin: 10px 0;">
        ''']
        for token in tokens:
            display_text = token.replace('##', '')
            # Token text comes from user input: escape it before embedding in HTML
            label = html.escape(display_text)
            key = display_text.lower()
            if key in token_values:
                value = token_values[key]
                strength = abs(value) / max_abs_value if max_abs_value > 0 else 0.0
                color, text_color = _attribution_colors(value, strength)
                parts.append(_attributed_token_span(label, color, text_color))
            else:
                parts.append(_neutral_token_span(label))
        parts.append('</div>')

        # Color legend
        parts.append('''
        <div style="margin-top: 10px; font-size: 12px; color: #666;">
        <span style="background-color: #cc0000; color: white; padding: 2px 6px; border-radius: 3px; margin-right: 5px;">Strong negative</span>
        <span style="background-color: #ff8888; padding: 2px 6px; border-radius: 3px; margin-right: 5px;">Weak negative</span>
        <span style="background-color: #0000cc; color: white; padding: 2px 6px; border-radius: 3px; margin-right: 5px;">Strong positive</span>
        <span style="background-color: #8888ff; padding: 2px 6px; border-radius: 3px;">Weak positive</span>
        </div>
        ''')
        return ''.join(parts)

    except Exception as e:
        print(f"Visualization error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating visualization: {html.escape(str(e))}</div>'

def create_attribution_plot(explanation, method_name, max_features=15):
    """Render the strongest token/feature attributions as a horizontal bar chart.

    Args:
        explanation: For ``method_name == "LIME"``, a sequence of
            ``(feature, weight)`` pairs. Otherwise, a sequence of dicts with
            ``'token'`` and ``'value'`` keys; entries missing either key are
            skipped so labels and scores always stay paired.
        method_name: Explainer name; selects the input format and appears in the title.
        max_features: Maximum number of bars to draw, keeping the largest
            absolute attributions (default 15).

    Returns:
        An HTML string: an ``<img>`` tag embedding a base64 PNG, a ``<p>``
        message when there is nothing to plot, or a red error ``<div>``
        (with the message HTML-escaped) if anything fails.
    """
    fig = None
    try:
        if not explanation:
            return "<p>No explanation data available</p>"

        if method_name == "LIME":
            pairs = [(item[0], item[1]) for item in explanation]
            title = f'Top Feature Attributions ({method_name})'
        else:
            # Take token and value from the SAME dict so bars never misalign with labels
            pairs = [
                (item['token'], item['value'])
                for item in explanation
                if isinstance(item, dict) and 'token' in item and 'value' in item
            ]
            title = f'Top Token Attributions ({method_name})'

        # Keep the largest-magnitude attributions. sorted() is stable, so an
        # already-ranked list (e.g. LIME's) keeps its order.
        pairs = sorted(
            ((label, float(score)) for label, score in pairs),
            key=lambda pair: abs(pair[1]),
            reverse=True,
        )[:max_features]
        if not pairs:
            return "<p>No valid explanation data available for plotting</p>"

        features = [label for label, _ in pairs]
        scores = [score for _, score in pairs]

        fig, ax = plt.subplots(figsize=(12, 6))

        # High-contrast colors: red for negative, teal for non-negative
        colors = ['#ff6b6b' if score < 0 else '#4ecdc4' for score in scores]

        y_pos = np.arange(len(features))
        bars = ax.barh(y_pos, scores, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(features, fontsize=10)
        ax.set_xlabel('Attribution Score', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.5, linewidth=1)
        ax.grid(True, alpha=0.3, axis='x')

        # Push each value label 1% of the largest magnitude past the bar end
        offset = 0.01 * max(abs(score) for score in scores)
        for bar, score in zip(bars, scores):
            width = bar.get_width()
            if width >= 0:
                label_x, align = width + offset, 'left'
            else:
                label_x, align = width - offset, 'right'
            ax.text(label_x, bar.get_y() + bar.get_height() / 2,
                    f'{score:.4f}', ha=align, va='center',
                    fontsize=9, fontweight='bold')

        ax.set_facecolor('#f8f9fa')
        fig.patch.set_facecolor('#f8f9fa')
        fig.tight_layout()

        # Encode this specific figure (not pyplot's "current" one) as inline PNG
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')

        return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'

    except Exception as e:
        print(f"Plot error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating plot: {html.escape(str(e))}</div>'
    finally:
        # pyplot keeps every open figure alive; close it even when plotting failed
        if fig is not None:
            plt.close(fig)
    

# Bar chart of per-class prediction probabilities.
def create_confidence_chart(probabilities, class_names=None):
    """Create a bar chart showing class probabilities.

    Args:
        probabilities: Iterable of per-class probabilities (anything
            ``float()`` accepts per element, e.g. a list or 1-D NumPy array).
        class_names: Optional x-axis labels, one per probability; defaults to
            ``"Class 0"``, ``"Class 1"``, ...

    Returns:
        An HTML string: an ``<img>`` tag embedding a base64 PNG, a ``<p>``
        message when *probabilities* is empty, or a red error ``<div>`` (with
        the message HTML-escaped) if anything fails.
    """
    fig = None
    try:
        # Plain floats: format cleanly and avoid ambiguous truth tests on arrays
        probs = [float(p) for p in probabilities]
        if not probs:
            return "<p>No probability data available</p>"
        if class_names is None:
            class_names = [f"Class {i}" for i in range(len(probs))]

        palette = ['#ff6b6b', '#4ecdc4', '#45b7af', '#556270']
        # Cycle the palette explicitly so every class gets a color
        colors = [palette[i % len(palette)] for i in range(len(probs))]
        positions = range(len(probs))

        fig, ax = plt.subplots(figsize=(10, 6))
        bars = ax.bar(positions, probs, color=colors,
                      alpha=0.8, edgecolor='black', linewidth=1)

        ax.set_xlabel('Classes', fontsize=12, fontweight='bold')
        ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
        ax.set_title('Class Probability Distribution', fontsize=14, fontweight='bold')
        ax.set_xticks(positions)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_ylim(0, 1)

        # Value label just above each bar
        for bar, prob in zip(bars, probs):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{prob:.3f}', ha='center', va='bottom', fontweight='bold')

        ax.grid(True, alpha=0.3, axis='y')
        ax.set_facecolor('#f8f9fa')
        fig.patch.set_facecolor('#f8f9fa')
        fig.tight_layout()

        # Encode this specific figure (not pyplot's "current" one) as inline PNG
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')

        return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'

    except Exception as e:
        print(f"Confidence chart error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating confidence chart: {html.escape(str(e))}</div>'
    finally:
        # pyplot keeps every open figure alive; close it even when plotting failed
        if fig is not None:
            plt.close(fig)