OliverPerrin commited on
Commit
fd09961
·
1 Parent(s): 185b05e

Added gradio demo interface

Browse files
Files changed (1) hide show
  1. scripts/demo_gradio.py +425 -0
scripts/demo_gradio.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio Demo interface for LexiMind NLP pipeline.
3
+ Showcases summarization, emotion detection, and topic prediction.
4
+ """
5
+ import json
6
+ import sys
7
+ from io import StringIO
8
+ from pathlib import Path
9
+ from typing import Iterable, Sequence
10
+ import gradio as gr
11
+ from gradio.themes import Soft
12
+ import matplotlib.pyplot as plt
13
+ import pandas as pd
14
+ import seaborn as sns
15
+ import torch
16
+ from matplotlib.figure import Figure
17
+
18
+ # Add project root to the path, going up two folder levels from this file
19
+ project_root = Path(__file__).parent.parent
20
+ sys.path.insert(0, str(project_root))
21
+
22
+ from src.inference.factory import create_inference_pipeline
23
+ from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
24
+ from src.utils.logging import configure_logging, get_logger
25
+
26
+ configure_logging()
27
+ logger = get_logger(__name__)
28
+
29
+ _pipeline: InferencePipeline | None = None # Global pipeline instance
30
+ _label_metadata = None # Cached label metadata
31
+
32
+
33
+ def get_pipeline() -> InferencePipeline:
34
+ """Lazy Loading and Caching the inference pipeline"""
35
+ global _pipeline, _label_metadata
36
+ if _pipeline is None:
37
+ try:
38
+ logger.info("Loading inference pipeline...")
39
+ pipeline, label_metadata = create_inference_pipeline(
40
+ tokenizer_dir="data/tokenization",
41
+ checkpoint_path="checkpoints/best.pt",
42
+ labels_path="data/labels.json",
43
+ )
44
+ _pipeline = pipeline
45
+ _label_metadata = label_metadata
46
+ logger.info("Pipeline loaded successfully")
47
+ except Exception as e:
48
+ logger.error(f"Failed to load pipeline: {e}")
49
+ raise RuntimeError("Could not initialize inference pipeline. Check logs for details.")
50
+ return _pipeline
51
+
52
+ def count_tokens(text: str) -> str:
53
+ """Count tokens in the input text."""
54
+ if not text:
55
+ return "Tokens: 0"
56
+ try:
57
+ pipeline = get_pipeline()
58
+ token_count = len(pipeline.tokenizer.encode(text))
59
+ return f"Tokens: {token_count}"
60
+ except Exception as e:
61
+ logger.error(f"Token counting error: {e}")
62
+ return "Token count unavailable"
63
+
64
+ def map_compression_to_length(compression: int, max_model_length: int = 512):
65
+ """
66
+ Map Compression slider (20-80%) to max summary length.
67
+ Higher compression = shorter summary output.
68
+ """
69
+ # Invert, 20% compression = 80% of max length
70
+ ratio = (100 - compression) / 100
71
+ return int(ratio * max_model_length)
72
+
73
+ def predict(text: str, compression: int):
74
+ """
75
+ Main predcition function for the Gradio interface.
76
+ Args:
77
+ text: Text to process
78
+ compression: Compression percentage (20-80)
79
+ Returns:
80
+ Tuple of (summary_html, emotion_plot, topic_output, attention_fig, download_data)
81
+ """
82
+ if not text or not text.strip():
83
+ return ("Please enter some text to analyze.",
84
+ None,
85
+ "No topic prediction available",
86
+ None,
87
+ None)
88
+ try:
89
+ pipeline = get_pipeline()
90
+ max_len = map_compression_to_length(compression)
91
+ logger.info(f"Generating summary with max length of {max_len}")
92
+
93
+ # Get the predictions
94
+ summary = pipeline.summarize([text], max_length=max_len)[0]
95
+ emotions = pipeline.predict_emotions([text])[0]
96
+ topic = pipeline.predict_topics([text])[0]
97
+
98
+ summary_html = format_summary(text, summary)
99
+ emotion_plot = create_emotion_plot(emotions)
100
+ topic_output = format_topic(topic)
101
+ attention_fig = create_attention_heatmap(text, summary, pipeline)
102
+ download_data = prepare_download(text, summary, emotions, topic)
103
+
104
+ return summary_html, emotion_plot, topic_output, attention_fig, download_data
105
+
106
+ except Exception as e:
107
+ logger.error(f"Prediction error: {e}", exc_info=True)
108
+ error_msg = "Prediction failed. Check logs for details."
109
+ return error_msg, None, "Error", None, None
110
+
111
+ def format_summary(original: str, summary:str) ->str:
112
+ """Format original and summary text for display"""
113
+ html = f"""
114
+ <div style="padding: 10px; border-radius: 5px;">
115
+ <h3>Original Text</h3>
116
+ <p style="background-color: #f0f0f0; padding: 10px; border-radius: 3px;">
117
+ {original}
118
+ </p>
119
+ <h3>Summary</h3>
120
+ <p style="background-color: #e6f3ff; padding: 10px; border-radius: 3px;">
121
+ {summary}
122
+ </p>
123
+ </div>
124
+ """
125
+ return html
126
+
127
+ def create_emotion_plot(emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]) -> Figure | None:
128
+ """
129
+ Create bar plot for emotion predictions.
130
+ Args:
131
+ emotions: Dict with 'labels' and 'scores' keys
132
+ """
133
+ if isinstance(emotions, EmotionPrediction):
134
+ labels = emotions.labels
135
+ scores = emotions.scores
136
+ else:
137
+ labels = list(emotions.get("labels", []))
138
+ scores = list(emotions.get("scores", []))
139
+
140
+ if not labels or not scores:
141
+ return None
142
+
143
+ df = pd.DataFrame({
144
+ "Emotion": labels,
145
+ "Probability": scores,
146
+ })
147
+ fig, ax = plt.subplots(figsize=(8, 5))
148
+ colors = sns.color_palette("Set2", len(labels))
149
+ bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
150
+ ax.set_xlabel("Probability", fontsize=12)
151
+ ax.set_ylabel("Emotion", fontsize=12)
152
+ ax.set_title("Emotion Detection Results", fontsize=14, fontweight="bold")
153
+ ax.set_xlim(0, 1)
154
+ for bar in bars:
155
+ width = bar.get_width()
156
+ ax.text(
157
+ width,
158
+ bar.get_y() + bar.get_height() / 2,
159
+ f"{width:.2%}",
160
+ ha="left",
161
+ va="center",
162
+ fontsize=10,
163
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
164
+ )
165
+ plt.tight_layout()
166
+ return fig
167
+
168
+ def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
169
+ """
170
+ Format topic prediction output.
171
+
172
+ Args:
173
+ topic: Dict with 'label' and 'score' keys
174
+ """
175
+ if isinstance(topic, TopicPrediction):
176
+ label = topic.label
177
+ score = topic.confidence
178
+ else:
179
+ label = str(topic.get("label", "Unknown"))
180
+ score = float(topic.get("score", 0.0))
181
+ output = f"""
182
+ ### Predicted Topic
183
+
184
+ **{label}**
185
+
186
+ Confidence: {score:.2%}
187
+ """
188
+ return output
189
+
190
+ def _clean_tokens(tokens: Iterable[str]) -> list[str]:
191
+ cleaned: list[str] = []
192
+ for token in tokens:
193
+ item = token.replace("Ġ", " ").replace("▁", " ")
194
+ cleaned.append(item.strip() if item.strip() else token)
195
+ return cleaned
196
+
197
+
198
+ def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
199
+ """Generate a seaborn heatmap of decoder cross-attention averaged over heads."""
200
+ if not summary:
201
+ return None
202
+ try:
203
+ batch = pipeline.preprocessor.batch_encode([text])
204
+ batch = pipeline._batch_to_device(batch)
205
+ src_ids = batch.input_ids
206
+ src_mask = batch.attention_mask
207
+ encoder_mask = None
208
+ if src_mask is not None:
209
+ encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
210
+
211
+ with torch.inference_mode():
212
+ memory = pipeline.model.encoder(src_ids, mask=encoder_mask)
213
+ target_enc = pipeline.tokenizer.batch_encode([summary])
214
+ target_ids = target_enc["input_ids"].to(pipeline.device)
215
+ target_mask = target_enc["attention_mask"].to(pipeline.device)
216
+ target_len = int(target_mask.sum().item())
217
+ decoder_inputs = pipeline.tokenizer.prepare_decoder_inputs(target_ids)
218
+ decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
219
+ target_ids = target_ids[:, :target_len]
220
+ memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
221
+ _, attn_list = pipeline.model.decoder(
222
+ decoder_inputs,
223
+ memory,
224
+ memory_mask=memory_mask,
225
+ collect_attn=True,
226
+ )
227
+ if not attn_list:
228
+ return None
229
+ cross_attn = attn_list[-1]["cross"] # (B, heads, T, S)
230
+ attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
231
+
232
+ source_len = batch.lengths[0]
233
+ attn_matrix = attn_matrix[:target_len, :source_len]
234
+
235
+ source_ids = src_ids[0, :source_len].tolist()
236
+ target_id_list = target_ids[0].tolist()
237
+
238
+ special_ids = {
239
+ pipeline.tokenizer.pad_token_id,
240
+ pipeline.tokenizer.bos_token_id,
241
+ pipeline.tokenizer.eos_token_id,
242
+ }
243
+ keep_indices = [index for index, token_id in enumerate(target_id_list) if token_id not in special_ids]
244
+ if not keep_indices:
245
+ return None
246
+
247
+ pruned_matrix = attn_matrix[keep_indices, :]
248
+ tokenizer_impl = pipeline.tokenizer.tokenizer
249
+ convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
250
+ if convert_tokens is None:
251
+ logger.warning("Tokenizer does not expose convert_ids_to_tokens; skipping attention heatmap.")
252
+ return None
253
+
254
+ summary_tokens_raw = convert_tokens([target_id_list[index] for index in keep_indices])
255
+ source_tokens_raw = convert_tokens(source_ids)
256
+
257
+ summary_tokens = _clean_tokens(summary_tokens_raw)
258
+ source_tokens = _clean_tokens(source_tokens_raw)
259
+
260
+ height = max(4.0, 0.4 * len(summary_tokens))
261
+ width = max(6.0, 0.4 * len(source_tokens))
262
+ fig, ax = plt.subplots(figsize=(width, height))
263
+ sns.heatmap(
264
+ pruned_matrix,
265
+ cmap="mako",
266
+ xticklabels=source_tokens,
267
+ yticklabels=summary_tokens,
268
+ ax=ax,
269
+ cbar_kws={"label": "Attention"},
270
+ )
271
+ ax.set_xlabel("Input Tokens")
272
+ ax.set_ylabel("Summary Tokens")
273
+ ax.set_title("Cross-Attention (decoder last layer)")
274
+ ax.tick_params(axis="x", rotation=90)
275
+ ax.tick_params(axis="y", rotation=0)
276
+ fig.tight_layout()
277
+ return fig
278
+
279
+ except Exception as exc:
280
+ logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
281
+ return None
282
+
283
+
284
+ def prepare_download(
285
+ text: str,
286
+ summary: str,
287
+ emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
288
+ topic: TopicPrediction | dict[str, float | str],
289
+ ) -> str:
290
+ """Prepare JSON data for download."""
291
+ if isinstance(emotions, EmotionPrediction):
292
+ emotion_payload = {
293
+ "labels": list(emotions.labels),
294
+ "scores": list(emotions.scores),
295
+ }
296
+ else:
297
+ emotion_payload = emotions
298
+
299
+ if isinstance(topic, TopicPrediction):
300
+ topic_payload = {
301
+ "label": topic.label,
302
+ "confidence": topic.confidence,
303
+ }
304
+ else:
305
+ topic_payload = topic
306
+
307
+ data = {
308
+ "original_text": text,
309
+ "summary": summary,
310
+ "emotions": emotion_payload,
311
+ "topic": topic_payload,
312
+ }
313
+ return json.dumps(data, indent=2)
314
+
315
+ # Sample data for the demo
316
+ SAMPLE_TEXT = """
317
+ Artificial intelligence is rapidly transforming the technology landscape.
318
+ Machine learning algorithms are now capable of processing vast amounts of data,
319
+ identifying patterns, and making predictions with unprecedented accuracy.
320
+ From healthcare diagnostics to financial forecasting, AI applications are
321
+ revolutionizing industries worldwide. However, ethical considerations around
322
+ privacy, bias, and transparency remain critical challenges that must be addressed
323
+ as these technologies continue to evolve.
324
+ """
325
+
326
+ def create_interface():
327
+ with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
328
+ gr.Markdown("""
329
+ # LexiMind NLP Pipeline Demo
330
+
331
+ **Full pipleine for text summarization, emotion detection, and topic prediction.**
332
+
333
+ Enter text below and adjust compressoin to see the results.
334
+ """)
335
+ with gr.Row():
336
+ # Left column - Input
337
+ with gr.Column(scale=1):
338
+ gr.Markdown("### Input")
339
+ input_text = gr.Textbox(
340
+ label="Enter text",
341
+ placeholder="Paste or type your text here...",
342
+ lines=10,
343
+ value=SAMPLE_TEXT
344
+ )
345
+ token_count = gr.Textbox(
346
+ label="Token Count",
347
+ value="Tokens: 0",
348
+ interactive=False
349
+ )
350
+ compression = gr.Slider(
351
+ minimum=20,
352
+ maximum=80,
353
+ value=50,
354
+ step=5,
355
+ label="Compression %",
356
+ info="Higher = shorter summary"
357
+ )
358
+ predict_btn = gr.Button("🚀 Analyze", variant="primary", size="lg")
359
+ # Right column - Outputs
360
+ with gr.Column(scale=2):
361
+ gr.Markdown("### Result")
362
+ with gr.Tabs():
363
+ with gr.TabItem("Summary"):
364
+ summary_output = gr.HTML(label="Summary")
365
+ with gr.TabItem("Emotions"):
366
+ emotion_output = gr.Plot(label="Emotion Analysis")
367
+ with gr.TabItem("Topic"):
368
+ topic_output = gr.Markdown(label="Topic Prediction")
369
+ with gr.TabItem("Attention Heatmap"):
370
+ attention_output = gr.Plot(label="Attention Weights")
371
+ gr.Markdown("*Visualizes which parts of the input the model focused on.*")
372
+ # Download section
373
+ gr.Markdown("### Export Results")
374
+ download_data = gr.Textbox(visible=False)
375
+ download_btn = gr.DownloadButton(
376
+ "Download Results (JSON)",
377
+ visible=True
378
+ )
379
+ # Event Handlers
380
+ input_text.change(
381
+ fn=count_tokens,
382
+ inputs=[input_text],
383
+ outputs=[token_count]
384
+ )
385
+ predict_btn.click(
386
+ fn=predict,
387
+ inputs=[input_text, compression],
388
+ outputs=[summary_output, emotion_output, topic_output, attention_output, download_data]
389
+ ).then(
390
+ fn=lambda x: gr.DownloadButton("Download Results (JSON)", value=x, visible=True),
391
+ inputs=[download_data],
392
+ outputs=[download_btn]
393
+ )
394
+ # Examples
395
+ gr.Examples(
396
+ examples=[
397
+ [SAMPLE_TEXT, 50],
398
+ ["Climate change poses significant risks to global ecosystems. Rising temperatures, melting ice caps, and extreme weather events are becoming more frequent. Scientists urge immediate action to reduce carbon emissions and transition to renewable energy sources.", 40],
399
+ ],
400
+ inputs=[input_text, compression],
401
+ label="Try these examples:"
402
+ )
403
+ return demo
404
+
405
+ if __name__ == "__main__":
406
+ try:
407
+ # Pre-load pipeline
408
+ get_pipeline()
409
+
410
+ # Create and launch interface
411
+ demo = create_interface()
412
+ demo.queue() # Enable queuing for better responsiveness
413
+ demo.launch(
414
+ share=True,
415
+ server_name="0.0.0.0",
416
+ server_port=7860,
417
+ show_error=True
418
+ )
419
+
420
+ except Exception as e:
421
+ logger.error(f"Failed to launch demo: {e}", exc_info=True)
422
+ print(f"Error: {e}")
423
+ sys.exit(1)
424
+
425
+