OliverPerrin committed on
Commit
4bda87e
·
1 Parent(s): 5c8762c

Fix Gradio demo metrics display and visualization script MLflow URI

Browse files

- Update demo to show training metrics from history JSON
- Add text-based loss/accuracy progress display
- Fix MLflow tracking URI to use SQLite database
- Update Architecture tab with correct training data info
- Add dev checkpoint note to demo header

scripts/demo_gradio.py CHANGED
@@ -135,6 +135,8 @@ with gr.Blocks(title="LexiMind") as demo:
135
 
136
  A custom 272M parameter encoder-decoder model trained jointly on three NLP tasks.
137
  Built from scratch in PyTorch, initialized from FLAN-T5-base weights.
 
 
138
  """
139
  )
140
 
@@ -176,7 +178,7 @@ with gr.Blocks(title="LexiMind") as demo:
176
  gr.Markdown("### Training Results")
177
 
178
  # Load metrics from training history
179
- metrics_md = "| Task | Metric | Score |\n|------|--------|-------|\n"
180
  if TRAINING_HISTORY_PATH.exists():
181
  with open(TRAINING_HISTORY_PATH) as f:
182
  history = json.load(f)
@@ -185,13 +187,60 @@ with gr.Blocks(title="LexiMind") as demo:
185
  if val_keys:
186
  latest = sorted(val_keys)[-1]
187
  val = history[latest]
188
- metrics_md += f"| Topic Classification | Accuracy | **{val.get('topic_accuracy', 0):.1%}** |\n"
189
- metrics_md += f"| Emotion Detection | F1 Score | {val.get('emotion_f1', 0):.1%} |\n"
190
- metrics_md += f"| Summarization | ROUGE-like | {val.get('summarization_rouge_like', 0):.1%} |\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- gr.Markdown(metrics_md)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- gr.Markdown("### Training Visualizations")
195
  with gr.Row():
196
  loss_curve = OUTPUTS_DIR / "training_loss_curve.png"
197
  if loss_curve.exists():
@@ -232,15 +281,15 @@ with gr.Blocks(title="LexiMind") as demo:
232
  |------|--------|---------------|
233
  | **Summarization** | Seq2seq generation | Cross-entropy + label smoothing |
234
  | **Emotion** | 28-class multi-label | Binary cross-entropy |
235
- | **Topic** | 4-class single-label | Cross-entropy |
236
 
237
  ### Training Data
238
 
239
  | Dataset | Task | Size |
240
  |---------|------|------|
241
- | CNN/DailyMail | Summarization | ~100K |
242
  | GoEmotions | Emotion | ~43K |
243
- | AG News | Topic | ~120K |
244
 
245
  ### Links
246
 
 
135
 
136
  A custom 272M parameter encoder-decoder model trained jointly on three NLP tasks.
137
  Built from scratch in PyTorch, initialized from FLAN-T5-base weights.
138
+
139
+ *Currently running development checkpoint - quality improves with longer training.*
140
  """
141
  )
142
 
 
178
  gr.Markdown("### Training Results")
179
 
180
  # Load metrics from training history
181
+ metrics_html = ""
182
  if TRAINING_HISTORY_PATH.exists():
183
  with open(TRAINING_HISTORY_PATH) as f:
184
  history = json.load(f)
 
187
  if val_keys:
188
  latest = sorted(val_keys)[-1]
189
  val = history[latest]
190
+ epoch_num = latest.split("_")[-1]
191
+
192
+ topic_acc = val.get('topic_accuracy', 0)
193
+ emotion_f1 = val.get('emotion_f1', 0)
194
+ rouge = val.get('summarization_rouge_like', 0)
195
+ total_loss = val.get('total_loss', 0)
196
+
197
+ metrics_html = f"""
198
+ | Task | Metric | Score |
199
+ |------|--------|-------|
200
+ | **Topic Classification** | Accuracy | **{topic_acc:.1%}** |
201
+ | **Emotion Detection** | F1 Score | {emotion_f1:.1%} |
202
+ | **Summarization** | ROUGE-like | {rouge:.1%} |
203
+ | **Total** | Val Loss | {total_loss:.3f} |
204
+
205
+ *Results from epoch {epoch_num}*
206
+ """
207
+ else:
208
+ metrics_html = "*No training history found. Run training first.*"
209
 
210
+ gr.Markdown(metrics_html)
211
+
212
+ gr.Markdown("### Training Progress")
213
+
214
+ # Generate inline training chart using history data
215
+ if TRAINING_HISTORY_PATH.exists():
216
+ with open(TRAINING_HISTORY_PATH) as f:
217
+ history = json.load(f)
218
+
219
+ train_keys = sorted([k for k in history.keys() if k.startswith("train_epoch")])
220
+ val_keys = sorted([k for k in history.keys() if k.startswith("val_epoch")])
221
+
222
+ if train_keys and val_keys:
223
+ epochs = list(range(1, len(train_keys) + 1))
224
+ train_loss = [history[k]["total_loss"] for k in train_keys]
225
+ val_loss = [history[k]["total_loss"] for k in val_keys[:len(train_keys)]]
226
+ topic_acc = [history[k].get("topic_accuracy", 0) * 100 for k in val_keys[:len(train_keys)]]
227
+
228
+ # Create a simple text-based progress display
229
+ progress_md = "**Loss Curve:**\n```\n"
230
+ for i, (tl, vl) in enumerate(zip(train_loss, val_loss)):
231
+ bar_len = int((1 - vl/max(val_loss)) * 20) + 1
232
+ progress_md += f"Epoch {i+1}: Train={tl:.3f} Val={vl:.3f} {'█' * bar_len}\n"
233
+ progress_md += "```\n\n"
234
+
235
+ progress_md += "**Topic Accuracy:**\n```\n"
236
+ for i, acc in enumerate(topic_acc):
237
+ bar_len = int(acc / 5)
238
+ progress_md += f"Epoch {i+1}: {acc:.1f}% {'█' * bar_len}\n"
239
+ progress_md += "```"
240
+
241
+ gr.Markdown(progress_md)
242
 
243
+ # Show visualization images if they exist
244
  with gr.Row():
245
  loss_curve = OUTPUTS_DIR / "training_loss_curve.png"
246
  if loss_curve.exists():
 
281
  |------|--------|---------------|
282
  | **Summarization** | Seq2seq generation | Cross-entropy + label smoothing |
283
  | **Emotion** | 28-class multi-label | Binary cross-entropy |
284
+ | **Topic** | 7-class single-label | Cross-entropy |
285
 
286
  ### Training Data
287
 
288
  | Dataset | Task | Size |
289
  |---------|------|------|
290
+ | BookSum + arXiv | Summarization | ~40K |
291
  | GoEmotions | Emotion | ~43K |
292
+ | 20 Newsgroups + Gutenberg | Topic | ~3.4K |
293
 
294
  ### Links
295
 
scripts/visualize_training.py CHANGED
@@ -127,7 +127,8 @@ def get_mlflow_client():
127
  raise ImportError("MLflow not installed. Install with: pip install mlflow")
128
  import mlflow
129
  import mlflow.tracking
130
- mlflow.set_tracking_uri(f"file://{MLRUNS_DIR}")
 
131
  return mlflow.tracking.MlflowClient()
132
 
133
 
 
127
  raise ImportError("MLflow not installed. Install with: pip install mlflow")
128
  import mlflow
129
  import mlflow.tracking
130
+ # Use SQLite database (same as trainer.py)
131
+ mlflow.set_tracking_uri("sqlite:///mlruns.db")
132
  return mlflow.tracking.MlflowClient()
133
 
134