dyra1222 commited on
Commit
e700f5c
·
1 Parent(s): 286850b

font color change

Browse files
Files changed (2) hide show
  1. app.py +96 -22
  2. utils/visualization.py +50 -24
app.py CHANGED
@@ -109,38 +109,53 @@ def predict_and_explain(text, model_choice, explainer_choice):
109
  return result, visualization_html, plot_html, explanation
110
 
111
  # Create Gradio interface
 
 
 
 
112
  with gr.Blocks(
113
  title="Explainability Sandbox for Transformers",
114
  css="""
115
  footer {visibility: hidden}
116
  .gradio-container {max-width: 1200px !important}
 
 
 
 
117
  """
118
  ) as demo:
119
  gr.Markdown("""
120
- # 🔍 Explainability Sandbox for Transformers
121
- *Explore how transformer models make decisions with various explanation methods.*
 
 
122
  """)
123
 
124
  with gr.Row():
125
  with gr.Column(scale=1):
126
- gr.Markdown("### Input Settings")
 
 
127
  text_input = gr.Textbox(
128
  label="Input Text",
129
  lines=5,
130
  placeholder="Enter text to analyze...",
131
- value="The movie was fantastic with great acting and an engaging plot."
 
132
  )
133
  model_choice = gr.Dropdown(
134
  choices=list(MODELS.keys()),
135
  label="Model",
136
- value="BERT Base (English)"
 
137
  )
138
  explainer_choice = gr.Radio(
139
  choices=["LIME", "SHAP", "Captum"],
140
  label="Explanation Method",
141
- value="LIME"
 
142
  )
143
- analyze_btn = gr.Button("Analyze", variant="primary")
144
 
145
  gr.Markdown("""
146
  ---
@@ -157,20 +172,34 @@ with gr.Blocks(
157
  """)
158
 
159
  with gr.Column(scale=2):
160
- gr.Markdown("### Results")
161
- output_text = gr.Textbox(label="Prediction Result")
 
 
 
 
 
162
 
163
- gr.Markdown("#### Token Attributions")
 
 
 
164
  output_vis = gr.HTML(label="Visualization")
165
 
166
- gr.Markdown("#### Attribution Plot")
 
 
167
  output_plot = gr.HTML()
168
 
169
- gr.Markdown("#### Explanation Data")
 
 
170
  explanation_output = gr.JSON(label="Detailed Data")
171
 
172
  # Examples
173
- gr.Markdown("### Examples")
 
 
174
  gr.Examples(
175
  examples=[
176
  ["This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.", "BERT Base (English)", "LIME"],
@@ -180,13 +209,14 @@ with gr.Blocks(
180
  inputs=[text_input, model_choice, explainer_choice],
181
  outputs=[output_text, output_vis, output_plot, explanation_output],
182
  fn=predict_and_explain,
183
- cache_examples=False
 
184
  )
185
 
186
  # Footer with model card info
187
  gr.Markdown("---")
188
  gr.Markdown("""
189
- ### 📖 Model Card & Ethical Considerations
190
  For details about model capabilities, limitations, and ethical considerations,
191
  please see our [Model Card](https://huggingface.co/docs/hub/model-cards).
192
 
@@ -196,12 +226,56 @@ with gr.Blocks(
196
  - Be cautious about over-interpreting individual token attributions
197
  - Models may reflect biases present in training data
198
  """)
199
-
200
- analyze_btn.click(
201
- fn=predict_and_explain,
202
- inputs=[text_input, model_choice, explainer_choice],
203
- outputs=[output_text, output_vis, output_plot, explanation_output]
204
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  if __name__ == "__main__":
207
- demo.launch(share=False) # Set to False for Hugging Face Spaces
 
109
  return result, visualization_html, plot_html, explanation
110
 
111
  # Create Gradio interface
112
+ # app.py (updated with better styling)
113
+ # ... [previous imports and code] ...
114
+
115
+ # Create Gradio interface with better styling
116
  with gr.Blocks(
117
  title="Explainability Sandbox for Transformers",
118
  css="""
119
  footer {visibility: hidden}
120
  .gradio-container {max-width: 1200px !important}
121
+ .gradio-button {background: linear-gradient(45deg, #4ecdc4, #556270) !important; color: white !important;}
122
+ .gradio-button:hover {background: linear-gradient(45deg, #45b7af, #44505d) !important;}
123
+ .gradio-radio-item {padding: 8px 12px !important; border-radius: 5px !important;}
124
+ .gradio-radio-item.selected {background: #4ecdc4 !important; color: white !important;}
125
  """
126
  ) as demo:
127
  gr.Markdown("""
128
+ <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;">
129
+ <h1 style="margin: 0; font-size: 2.5em;">🔍 Explainability Sandbox for Transformers</h1>
130
+ <p style="margin: 10px 0 0 0; font-size: 1.2em; opacity: 0.9;">Explore how transformer models make decisions with various explanation methods</p>
131
+ </div>
132
  """)
133
 
134
  with gr.Row():
135
  with gr.Column(scale=1):
136
+ gr.Markdown("""
137
+ ### ⚙️ Input Settings
138
+ """)
139
  text_input = gr.Textbox(
140
  label="Input Text",
141
  lines=5,
142
  placeholder="Enter text to analyze...",
143
+ value="The movie was fantastic with great acting and an engaging plot.",
144
+ elem_classes=["input-box"]
145
  )
146
  model_choice = gr.Dropdown(
147
  choices=list(MODELS.keys()),
148
  label="Model",
149
+ value="BERT Base (English)",
150
+ elem_classes=["dropdown"]
151
  )
152
  explainer_choice = gr.Radio(
153
  choices=["LIME", "SHAP", "Captum"],
154
  label="Explanation Method",
155
+ value="LIME",
156
+ elem_classes=["radio-group"]
157
  )
158
+ analyze_btn = gr.Button("Analyze Text", variant="primary", elem_classes=["primary-button"])
159
 
160
  gr.Markdown("""
161
  ---
 
172
  """)
173
 
174
  with gr.Column(scale=2):
175
+ gr.Markdown("""
176
+ ### 📊 Results
177
+ """)
178
+ output_text = gr.Textbox(
179
+ label="Prediction Result",
180
+ elem_classes=["result-box"]
181
+ )
182
 
183
+ gr.Markdown("""
184
+ #### 🎨 Token Attributions
185
+ *Darker colors indicate stronger influence*
186
+ """)
187
  output_vis = gr.HTML(label="Visualization")
188
 
189
+ gr.Markdown("""
190
+ #### 📈 Attribution Plot
191
+ """)
192
  output_plot = gr.HTML()
193
 
194
+ gr.Markdown("""
195
+ #### 🔍 Explanation Data
196
+ """)
197
  explanation_output = gr.JSON(label="Detailed Data")
198
 
199
  # Examples
200
+ gr.Markdown("""
201
+ ### 🚀 Quick Examples
202
+ """)
203
  gr.Examples(
204
  examples=[
205
  ["This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.", "BERT Base (English)", "LIME"],
 
209
  inputs=[text_input, model_choice, explainer_choice],
210
  outputs=[output_text, output_vis, output_plot, explanation_output],
211
  fn=predict_and_explain,
212
+ cache_examples=False,
213
+ label="Click any example to try it out:"
214
  )
215
 
216
  # Footer with model card info
217
  gr.Markdown("---")
218
  gr.Markdown("""
219
+ ### 📋 Model Card & Ethical Considerations
220
  For details about model capabilities, limitations, and ethical considerations,
221
  please see our [Model Card](https://huggingface.co/docs/hub/model-cards).
222
 
 
226
  - Be cautious about over-interpreting individual token attributions
227
  - Models may reflect biases present in training data
228
  """)
229
+
230
+ # Add custom CSS for better styling
231
+ custom_css = """
232
+ .input-box textarea {
233
+ border-radius: 8px !important;
234
+ border: 2px solid #e1e5e9 !important;
235
+ padding: 12px !important;
236
+ }
237
+
238
+ .dropdown select {
239
+ border-radius: 8px !important;
240
+ border: 2px solid #e1e5e9 !important;
241
+ padding: 10px !important;
242
+ }
243
+
244
+ .radio-group .gr-radio-item {
245
+ border: 2px solid #e1e5e9 !important;
246
+ border-radius: 8px !important;
247
+ margin: 5px !important;
248
+ padding: 10px 15px !important;
249
+ }
250
+
251
+ .radio-group .gr-radio-item.selected {
252
+ background: #4ecdc4 !important;
253
+ color: white !important;
254
+ border-color: #3bb5ad !important;
255
+ }
256
+
257
+ .primary-button {
258
+ border-radius: 8px !important;
259
+ padding: 12px 24px !important;
260
+ font-weight: bold !important;
261
+ font-size: 16px !important;
262
+ }
263
+
264
+ .result-box input {
265
+ border-radius: 8px !important;
266
+ border: 2px solid #4ecdc4 !important;
267
+ background-color: #f8f9fa !important;
268
+ font-weight: bold !important;
269
+ padding: 12px !important;
270
+ }
271
+
272
+ .gr-box {
273
+ border-radius: 10px !important;
274
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
275
+ }
276
+ """
277
+
278
+ demo.css = custom_css
279
 
280
  if __name__ == "__main__":
281
+ demo.launch(share=False)
utils/visualization.py CHANGED
@@ -1,4 +1,4 @@
1
- # utils/visualization.py (updated)
2
  import matplotlib.pyplot as plt
3
  import matplotlib.colors as mcolors
4
  import base64
@@ -6,7 +6,7 @@ from io import BytesIO
6
  import numpy as np
7
 
8
  def create_visualization(text, explanation, tokenizer, explainer_type):
9
- """Create HTML visualization of token attributions"""
10
  try:
11
  # Tokenize the text
12
  tokens = tokenizer.tokenize(text)
@@ -41,7 +41,7 @@ def create_visualization(text, explanation, tokenizer, explainer_type):
41
  <i>Explanation data not available. Showing tokenized text.</i><br>
42
  '''
43
  for token in tokens:
44
- html_output += f'<span style="margin: 2px; padding: 4px 6px; display: inline-block;">{token.replace("##", "")}</span> '
45
  html_output += '</div>'
46
  return html_output
47
 
@@ -68,22 +68,40 @@ def create_visualization(text, explanation, tokenizer, explainer_type):
68
  value = token_values[clean_token]
69
  norm_value = normalized_values[clean_token]
70
 
71
- # Determine color based on value (red for negative, blue for positive)
72
  if value < 0:
73
- intensity = min(0.9, abs(norm_value))
74
- color = f"rgba(255, 87, 87, {intensity})"
75
- border = "1px solid rgba(255, 0, 0, 0.3)"
 
 
 
 
 
76
  else:
77
- intensity = min(0.9, norm_value)
78
- color = f"rgba(92, 167, 255, {intensity})"
79
- border = "1px solid rgba(0, 0, 255, 0.3)"
 
 
 
 
 
80
 
81
- html_output += f'<span style="background-color: {color}; border: {border}; margin: 2px; padding: 4px 6px; border-radius: 4px; display: inline-block;">{token.replace("##", "")}</span> '
82
  else:
83
- html_output += f'<span style="margin: 2px; padding: 4px 6px; display: inline-block;">{token.replace("##", "")}</span> '
84
 
85
  html_output += '</div>'
86
 
 
 
 
 
 
 
 
 
87
  return html_output
88
 
89
  except Exception as e:
@@ -91,7 +109,7 @@ def create_visualization(text, explanation, tokenizer, explainer_type):
91
  return f'<div style="color: red; padding: 10px;">Error creating visualization: {str(e)}</div>'
92
 
93
  def create_attribution_plot(explanation, method_name):
94
- """Create matplotlib visualization of token attributions"""
95
  try:
96
  if not explanation:
97
  return "<p>No explanation data available</p>"
@@ -112,40 +130,48 @@ def create_attribution_plot(explanation, method_name):
112
  if not features or not scores:
113
  return "<p>No valid explanation data available for plotting</p>"
114
 
115
- # Create plot
116
  fig, ax = plt.subplots(figsize=(12, 6))
117
 
118
- # Create colors based on values
119
- colors = ['red' if score < 0 else 'blue' for score in scores]
120
 
121
  # Create horizontal bar chart
122
  y_pos = np.arange(len(features))
123
- bars = ax.barh(y_pos, scores, color=colors, alpha=0.7)
124
 
125
  # Customize plot
126
  ax.set_yticks(y_pos)
127
- ax.set_yticklabels(features)
128
- ax.set_xlabel('Attribution Score')
129
- ax.set_title(title)
130
- ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)
 
 
 
131
 
132
  # Add value labels on bars
133
  for i, (bar, score) in enumerate(zip(bars, scores)):
134
  width = bar.get_width()
135
  label_x_pos = width + (0.01 * max(scores) if width >= 0 else 0.01 * min(scores))
136
  ax.text(label_x_pos, bar.get_y() + bar.get_height()/2,
137
- f'{score:.4f}', ha='left' if width >= 0 else 'right', va='center')
 
 
 
 
 
138
 
139
  plt.tight_layout()
140
 
141
  # Convert to HTML
142
  buf = BytesIO()
143
- plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
144
  buf.seek(0)
145
  img_str = base64.b64encode(buf.read()).decode('utf-8')
146
  plt.close(fig)
147
 
148
- return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%;">'
149
 
150
  except Exception as e:
151
  print(f"Plot error: {e}")
 
1
+ # utils/visualization.py (updated with better colors)
2
  import matplotlib.pyplot as plt
3
  import matplotlib.colors as mcolors
4
  import base64
 
6
  import numpy as np
7
 
8
  def create_visualization(text, explanation, tokenizer, explainer_type):
9
+ """Create HTML visualization of token attributions with better color contrast"""
10
  try:
11
  # Tokenize the text
12
  tokens = tokenizer.tokenize(text)
 
41
  <i>Explanation data not available. Showing tokenized text.</i><br>
42
  '''
43
  for token in tokens:
44
+ html_output += f'<span style="margin: 2px; padding: 4px 6px; display: inline-block; background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 4px;">{token.replace("##", "")}</span> '
45
  html_output += '</div>'
46
  return html_output
47
 
 
68
  value = token_values[clean_token]
69
  norm_value = normalized_values[clean_token]
70
 
71
+ # Determine color based on value with better contrast
72
  if value < 0:
73
+ # Negative values: red scale with good contrast
74
+ intensity = min(0.95, 0.3 + 0.7 * abs(norm_value)) # Ensure minimum darkness
75
+ red = int(255 * intensity)
76
+ green = int(200 * (1 - intensity))
77
+ blue = int(200 * (1 - intensity))
78
+ color = f"rgb({red}, {green}, {blue})"
79
+ text_color = "white" if intensity > 0.6 else "black"
80
+ border = f"2px solid rgb({min(255, red+30)}, {max(0, green-30)}, {max(0, blue-30)})"
81
  else:
82
+ # Positive values: blue scale with good contrast
83
+ intensity = min(0.95, 0.3 + 0.7 * norm_value) # Ensure minimum darkness
84
+ red = int(200 * (1 - intensity))
85
+ green = int(200 * (1 - intensity))
86
+ blue = int(255 * intensity)
87
+ color = f"rgb({red}, {green}, {blue})"
88
+ text_color = "white" if intensity > 0.6 else "black"
89
+ border = f"2px solid rgb({max(0, red-30)}, {max(0, green-30)}, {min(255, blue+30)})"
90
 
91
+ html_output += f'<span style="background-color: {color}; color: {text_color}; border: {border}; margin: 2px; padding: 4px 6px; border-radius: 4px; display: inline-block; font-weight: bold;">{token.replace("##", "")}</span> '
92
  else:
93
+ html_output += f'<span style="margin: 2px; padding: 4px 6px; display: inline-block; background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 4px;">{token.replace("##", "")}</span> '
94
 
95
  html_output += '</div>'
96
 
97
+ # Add color legend
98
+ html_output += '''
99
+ <div style="margin-top: 10px; font-size: 12px; color: #666;">
100
+ <span style="background-color: rgb(240, 150, 150); padding: 2px 6px; border: 1px solid #d88; border-radius: 3px; margin-right: 10px;">Negative impact</span>
101
+ <span style="background-color: rgb(150, 150, 240); padding: 2px 6px; border: 1px solid #88d; border-radius: 3px;">Positive impact</span>
102
+ </div>
103
+ '''
104
+
105
  return html_output
106
 
107
  except Exception as e:
 
109
  return f'<div style="color: red; padding: 10px;">Error creating visualization: {str(e)}</div>'
110
 
111
  def create_attribution_plot(explanation, method_name):
112
+ """Create matplotlib visualization of token attributions with better colors"""
113
  try:
114
  if not explanation:
115
  return "<p>No explanation data available</p>"
 
130
  if not features or not scores:
131
  return "<p>No valid explanation data available for plotting</p>"
132
 
133
+ # Create plot with better colors
134
  fig, ax = plt.subplots(figsize=(12, 6))
135
 
136
+ # Create colors based on values - using darker, more saturated colors
137
+ colors = ['#ff6b6b' if score < 0 else '#4ecdc4' for score in scores] # Coral red and teal
138
 
139
  # Create horizontal bar chart
140
  y_pos = np.arange(len(features))
141
+ bars = ax.barh(y_pos, scores, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)
142
 
143
  # Customize plot
144
  ax.set_yticks(y_pos)
145
+ ax.set_yticklabels(features, fontsize=10)
146
+ ax.set_xlabel('Attribution Score', fontsize=12, fontweight='bold')
147
+ ax.set_title(title, fontsize=14, fontweight='bold')
148
+ ax.axvline(x=0, color='black', linestyle='-', alpha=0.5, linewidth=1)
149
+
150
+ # Add grid for better readability
151
+ ax.grid(True, alpha=0.3, axis='x')
152
 
153
  # Add value labels on bars
154
  for i, (bar, score) in enumerate(zip(bars, scores)):
155
  width = bar.get_width()
156
  label_x_pos = width + (0.01 * max(scores) if width >= 0 else 0.01 * min(scores))
157
  ax.text(label_x_pos, bar.get_y() + bar.get_height()/2,
158
+ f'{score:.4f}', ha='left' if width >= 0 else 'right', va='center',
159
+ fontsize=9, fontweight='bold')
160
+
161
+ # Set background color
162
+ ax.set_facecolor('#f8f9fa')
163
+ fig.patch.set_facecolor('#f8f9fa')
164
 
165
  plt.tight_layout()
166
 
167
  # Convert to HTML
168
  buf = BytesIO()
169
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
170
  buf.seek(0)
171
  img_str = base64.b64encode(buf.read()).decode('utf-8')
172
  plt.close(fig)
173
 
174
+ return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'
175
 
176
  except Exception as e:
177
  print(f"Plot error: {e}")