OliverPerrin committed on
Commit
40ccedf
·
1 Parent(s): 3ca077f

Demo: Update Gradio UI with metrics, visualizations, and clean tabs

Browse files
Files changed (1) hide show
  1. scripts/demo_gradio.py +152 -67
scripts/demo_gradio.py CHANGED
@@ -1,10 +1,10 @@
1
  """
2
  Gradio demo for LexiMind multi-task NLP model.
3
 
4
- Provides a simple web interface for the three core tasks:
5
  - Summarization: Generates concise summaries of input text
6
- - Emotion Detection: Identifies emotional content with confidence scores
7
- - Topic Classification: Categorizes text into predefined topics
8
 
9
  Author: Oliver Perrin
10
  Date: 2025-12-04
@@ -19,7 +19,6 @@ from pathlib import Path
19
  import gradio as gr
20
 
21
  # --------------- Path Setup ---------------
22
- # Ensure local src package is importable when running script directly
23
 
24
  SCRIPT_DIR = Path(__file__).resolve().parent
25
  PROJECT_ROOT = SCRIPT_DIR.parent
@@ -40,13 +39,24 @@ logger = get_logger(__name__)
40
  OUTPUTS_DIR = PROJECT_ROOT / "outputs"
41
  EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
42
 
43
- SAMPLE_TEXT = (
44
- "Artificial intelligence is rapidly transforming technology. "
45
- "Machine learning algorithms process vast amounts of data, identifying "
46
- "patterns with unprecedented accuracy. From healthcare to finance, AI is "
47
- "revolutionizing industries worldwide. However, ethical considerations "
48
- "around privacy and bias remain critical challenges."
49
- )
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # --------------- Pipeline Management ---------------
52
 
@@ -61,7 +71,6 @@ def get_pipeline():
61
 
62
  checkpoint_path = Path("checkpoints/best.pt")
63
 
64
- # Download from HuggingFace Hub if checkpoint doesn't exist locally
65
  if not checkpoint_path.exists():
66
  checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
67
  hf_hub_download(
@@ -82,99 +91,175 @@ def get_pipeline():
82
  # --------------- Core Functions ---------------
83
 
84
 
85
- def analyze(text: str) -> str:
86
- """
87
- Run all three tasks on input text.
88
-
89
- Returns markdown-formatted results for display in Gradio.
90
- """
91
  if not text or not text.strip():
92
- return "Please enter some text to analyze."
93
 
94
  try:
95
  pipe = get_pipeline()
96
 
97
- # Run each task
98
  summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"
99
  emotions = pipe.predict_emotions([text], threshold=0.5)[0]
100
  topic = pipe.predict_topics([text])[0]
101
 
102
- # Format emotion results
103
  if emotions.labels:
104
- emotion_str = ", ".join(
105
- f"{lbl} ({score:.1%})"
106
  for lbl, score in zip(emotions.labels, emotions.scores, strict=True)
107
  )
108
  else:
109
  emotion_str = "No strong emotions detected"
110
 
111
- return f"""## Summary
112
- {summary}
113
 
114
- ## Detected Emotions
115
- {emotion_str}
116
 
117
- ## Topic
118
- {topic.label} ({topic.confidence:.1%})
119
- """
120
  except Exception as e:
121
  logger.error("Analysis failed: %s", e, exc_info=True)
122
- return f"Error: {e}"
123
 
124
 
125
- def get_metrics() -> str:
126
- """Load evaluation metrics from JSON and format as markdown tables."""
127
  if not EVAL_REPORT_PATH.exists():
128
- return "No evaluation report found. Run `scripts/evaluate.py` first."
129
 
130
  try:
131
  with open(EVAL_REPORT_PATH) as f:
132
  r = json.load(f)
133
 
134
- # Build overall metrics table
135
- lines = [
136
- "## Model Performance\n",
137
- "| Task | Metric | Score |",
138
- "|------|--------|-------|",
139
- f"| Summarization | ROUGE-Like | {r['summarization']['rouge_like']:.4f} |",
140
- f"| Summarization | BLEU | {r['summarization']['bleu']:.4f} |",
141
- f"| Emotion | F1 Macro | {r['emotion']['f1_macro']:.4f} |",
142
- f"| Topic | Accuracy | {r['topic']['accuracy']:.4f} |",
143
- "",
144
- "### Topic Classification Details\n",
145
- "| Label | Precision | Recall | F1 |",
146
- "|-------|-----------|--------|-----|",
147
- ]
148
-
149
- # Add per-class metrics
150
- for label, metrics in r["topic"]["classification_report"].items():
151
- if isinstance(metrics, dict) and "precision" in metrics:
152
- lines.append(
153
- f"| {label} | {metrics['precision']:.3f} | "
154
- f"{metrics['recall']:.3f} | {metrics['f1-score']:.3f} |"
155
- )
156
 
157
- return "\n".join(lines)
 
 
 
 
 
 
 
 
158
  except Exception as e:
159
  return f"Error loading metrics: {e}"
160
 
161
 
162
  # --------------- Gradio Interface ---------------
163
 
164
- with gr.Blocks(title="LexiMind Demo") as demo:
 
 
 
 
165
  gr.Markdown(
166
- "# LexiMind NLP Demo\n"
167
- "Multi-task model: summarization, emotion detection, topic classification."
 
 
 
 
 
168
  )
169
 
170
- with gr.Tab("Analyze"):
171
- text_input = gr.Textbox(label="Input Text", lines=6, value=SAMPLE_TEXT)
172
- analyze_btn = gr.Button("Analyze", variant="primary")
173
- output = gr.Markdown(label="Results")
174
- analyze_btn.click(fn=analyze, inputs=text_input, outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- with gr.Tab("Metrics"):
177
- gr.Markdown(get_metrics())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
 
180
  # --------------- Entry Point ---------------
 
1
  """
2
  Gradio demo for LexiMind multi-task NLP model.
3
 
4
+ Showcases the model's capabilities across three tasks:
5
  - Summarization: Generates concise summaries of input text
6
+ - Emotion Detection: Multi-label emotion classification
7
+ - Topic Classification: Categorizes text into news topics
8
 
9
  Author: Oliver Perrin
10
  Date: 2025-12-04
 
19
  import gradio as gr
20
 
21
  # --------------- Path Setup ---------------
 
22
 
23
  SCRIPT_DIR = Path(__file__).resolve().parent
24
  PROJECT_ROOT = SCRIPT_DIR.parent
 
39
  OUTPUTS_DIR = PROJECT_ROOT / "outputs"
40
  EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
41
 
42
+ SAMPLE_TEXTS = [
43
+ (
44
+ "Artificial intelligence is rapidly transforming technology. "
45
+ "Machine learning algorithms process vast amounts of data, identifying "
46
+ "patterns with unprecedented accuracy. From healthcare to finance, AI is "
47
+ "revolutionizing industries worldwide."
48
+ ),
49
+ (
50
+ "The team's incredible comeback in the final quarter left fans in tears of joy. "
51
+ "After trailing by 20 points, they scored three consecutive touchdowns to secure "
52
+ "their first championship victory in over a decade."
53
+ ),
54
+ (
55
+ "Global markets tumbled today as investors reacted to rising inflation concerns. "
56
+ "The Federal Reserve hinted at potential interest rate hikes, sending shockwaves "
57
+ "through technology and banking sectors."
58
+ ),
59
+ ]
60
 
61
  # --------------- Pipeline Management ---------------
62
 
 
71
 
72
  checkpoint_path = Path("checkpoints/best.pt")
73
 
 
74
  if not checkpoint_path.exists():
75
  checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
76
  hf_hub_download(
 
91
  # --------------- Core Functions ---------------
92
 
93
 
94
+ def analyze(text: str) -> tuple[str, str, str]:
95
+ """Run all three tasks and return formatted results."""
 
 
 
 
96
  if not text or not text.strip():
97
+ return "Enter text above", "", ""
98
 
99
  try:
100
  pipe = get_pipeline()
101
 
102
+ # Run tasks
103
  summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"
104
  emotions = pipe.predict_emotions([text], threshold=0.5)[0]
105
  topic = pipe.predict_topics([text])[0]
106
 
107
+ # Format emotions
108
  if emotions.labels:
109
+ emotion_str = " • ".join(
110
+ f"**{lbl}** ({score:.0%})"
111
  for lbl, score in zip(emotions.labels, emotions.scores, strict=True)
112
  )
113
  else:
114
  emotion_str = "No strong emotions detected"
115
 
116
+ # Format topic
117
+ topic_str = f"**{topic.label}** ({topic.confidence:.0%})"
118
 
119
+ return summary, emotion_str, topic_str
 
120
 
 
 
 
121
  except Exception as e:
122
  logger.error("Analysis failed: %s", e, exc_info=True)
123
+ return f"Error: {e}", "", ""
124
 
125
 
126
+ def load_metrics() -> str:
127
+ """Load evaluation metrics and format as markdown."""
128
  if not EVAL_REPORT_PATH.exists():
129
+ return "No evaluation report found."
130
 
131
  try:
132
  with open(EVAL_REPORT_PATH) as f:
133
  r = json.load(f)
134
 
135
+ return f"""
136
+ ### Overall Performance
137
+
138
+ | Task | Metric | Score |
139
+ |------|--------|-------|
140
+ | **Emotion** | F1 Macro | **{r["emotion"]["f1_macro"]:.1%}** |
141
+ | **Topic** | Accuracy | **{r["topic"]["accuracy"]:.1%}** |
142
+ | **Summarization** | ROUGE-Like | {r["summarization"]["rouge_like"]:.1%} |
143
+ | **Summarization** | BLEU | {r["summarization"]["bleu"]:.1%} |
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ ### Topic Classification (per-class)
146
+
147
+ | Category | Precision | Recall | F1 |
148
+ |----------|-----------|--------|-----|
149
+ | Business | {r["topic"]["classification_report"]["Business"]["precision"]:.1%} | {r["topic"]["classification_report"]["Business"]["recall"]:.1%} | {r["topic"]["classification_report"]["Business"]["f1-score"]:.1%} |
150
+ | Sci/Tech | {r["topic"]["classification_report"]["Sci/Tech"]["precision"]:.1%} | {r["topic"]["classification_report"]["Sci/Tech"]["recall"]:.1%} | {r["topic"]["classification_report"]["Sci/Tech"]["f1-score"]:.1%} |
151
+ | Sports | {r["topic"]["classification_report"]["Sports"]["precision"]:.1%} | {r["topic"]["classification_report"]["Sports"]["recall"]:.1%} | {r["topic"]["classification_report"]["Sports"]["f1-score"]:.1%} |
152
+ | World | {r["topic"]["classification_report"]["World"]["precision"]:.1%} | {r["topic"]["classification_report"]["World"]["recall"]:.1%} | {r["topic"]["classification_report"]["World"]["f1-score"]:.1%} |
153
+ """
154
  except Exception as e:
155
  return f"Error loading metrics: {e}"
156
 
157
 
158
  # --------------- Gradio Interface ---------------
159
 
160
+ with gr.Blocks(
161
+ title="LexiMind Demo",
162
+ theme=gr.themes.Soft(),
163
+ css=".output-box { min-height: 80px; }",
164
+ ) as demo:
165
  gr.Markdown(
166
+ """
167
+ # 🧠 LexiMind
168
+ ### Multi-Task Transformer for Document Analysis
169
+
170
+ A custom encoder-decoder Transformer trained on summarization, emotion detection,
171
+ and topic classification. Built from scratch with PyTorch.
172
+ """
173
  )
174
 
175
+ # --------------- Try It Tab ---------------
176
+ with gr.Tab("🚀 Try It"):
177
+ with gr.Row():
178
+ with gr.Column(scale=2):
179
+ text_input = gr.Textbox(
180
+ label="Input Text",
181
+ lines=5,
182
+ placeholder="Enter text to analyze...",
183
+ value=SAMPLE_TEXTS[0],
184
+ )
185
+ with gr.Row():
186
+ analyze_btn = gr.Button("Analyze", variant="primary", scale=2)
187
+ gr.Examples(
188
+ examples=[[t] for t in SAMPLE_TEXTS],
189
+ inputs=text_input,
190
+ label="Examples",
191
+ )
192
+
193
+ with gr.Column(scale=2):
194
+ summary_out = gr.Textbox(label="📝 Summary", lines=3, elem_classes="output-box")
195
+ emotion_out = gr.Markdown(label="😊 Emotions")
196
+ topic_out = gr.Markdown(label="📂 Topic")
197
+
198
+ analyze_btn.click(
199
+ fn=analyze,
200
+ inputs=text_input,
201
+ outputs=[summary_out, emotion_out, topic_out],
202
+ )
203
+
204
+ # --------------- Metrics Tab ---------------
205
+ with gr.Tab("📊 Metrics"):
206
+ gr.Markdown(load_metrics())
207
+ gr.Markdown("### Confusion Matrix")
208
+ gr.Image(str(OUTPUTS_DIR / "topic_confusion_matrix.png"), label="Topic Classification")
209
+
210
+ # --------------- Architecture Tab ---------------
211
+ with gr.Tab("🔧 Architecture"):
212
+ gr.Markdown(
213
+ """
214
+ ### Model Architecture
215
+
216
+ - **Base**: Custom Transformer (encoder-decoder)
217
+ - **Initialized from**: FLAN-T5-base weights
218
+ - **Encoder**: 6 layers, 768 hidden dim, 12 attention heads
219
+ - **Decoder**: 6 layers with cross-attention
220
+ - **Task Heads**: Classification heads for emotion/topic
221
+
222
+ ### Training
223
+
224
+ - **Optimizer**: AdamW with cosine LR schedule
225
+ - **Mixed Precision**: bfloat16 with TF32
226
+ - **Compilation**: torch.compile with inductor backend
227
+ """
228
+ )
229
+ with gr.Row():
230
+ gr.Image(
231
+ str(OUTPUTS_DIR / "attention_visualization.png"),
232
+ label="Self-Attention Pattern",
233
+ )
234
+ gr.Image(
235
+ str(OUTPUTS_DIR / "positional_encoding_heatmap.png"),
236
+ label="Positional Encodings",
237
+ )
238
 
239
+ # --------------- About Tab ---------------
240
+ with gr.Tab("ℹ️ About"):
241
+ gr.Markdown(
242
+ """
243
+ ### About LexiMind
244
+
245
+ LexiMind is a multi-task NLP model designed to demonstrate end-to-end
246
+ machine learning engineering skills:
247
+
248
+ - **Custom Transformer** implementation from scratch
249
+ - **Multi-task learning** with shared encoder
250
+ - **Production-ready** inference pipeline
251
+ - **Comprehensive evaluation** with multiple metrics
252
+
253
+ ### Links
254
+
255
+ - 🔗 [GitHub Repository](https://github.com/OliverPerrin/LexiMind)
256
+ - 🤗 [HuggingFace Space](https://huggingface.co/spaces/OliverPerrin/LexiMind)
257
+
258
+ ### Author
259
+
260
+ **Oliver Perrin** - Machine Learning Engineer
261
+ """
262
+ )
263
 
264
 
265
  # --------------- Entry Point ---------------