Spaces:
Sleeping
Sleeping
OliverPerrin
committed on
Commit
·
4bda87e
1
Parent(s):
5c8762c
Fix Gradio demo metrics display and visualization script MLflow URI
Browse files
- Update demo to show training metrics from history JSON
- Add text-based loss/accuracy progress display
- Fix MLflow tracking URI to use SQLite database
- Update Architecture tab with correct training data info
- Add dev checkpoint note to demo header
- scripts/demo_gradio.py +58 -9
- scripts/visualize_training.py +2 -1
scripts/demo_gradio.py
CHANGED
|
@@ -135,6 +135,8 @@ with gr.Blocks(title="LexiMind") as demo:
|
|
| 135 |
|
| 136 |
A custom 272M parameter encoder-decoder model trained jointly on three NLP tasks.
|
| 137 |
Built from scratch in PyTorch, initialized from FLAN-T5-base weights.
|
|
|
|
|
|
|
| 138 |
"""
|
| 139 |
)
|
| 140 |
|
|
@@ -176,7 +178,7 @@ with gr.Blocks(title="LexiMind") as demo:
|
|
| 176 |
gr.Markdown("### Training Results")
|
| 177 |
|
| 178 |
# Load metrics from training history
|
| 179 |
-
|
| 180 |
if TRAINING_HISTORY_PATH.exists():
|
| 181 |
with open(TRAINING_HISTORY_PATH) as f:
|
| 182 |
history = json.load(f)
|
|
@@ -185,13 +187,60 @@ with gr.Blocks(title="LexiMind") as demo:
|
|
| 185 |
if val_keys:
|
| 186 |
latest = sorted(val_keys)[-1]
|
| 187 |
val = history[latest]
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
-
|
| 195 |
with gr.Row():
|
| 196 |
loss_curve = OUTPUTS_DIR / "training_loss_curve.png"
|
| 197 |
if loss_curve.exists():
|
|
@@ -232,15 +281,15 @@ with gr.Blocks(title="LexiMind") as demo:
|
|
| 232 |
|------|--------|---------------|
|
| 233 |
| **Summarization** | Seq2seq generation | Cross-entropy + label smoothing |
|
| 234 |
| **Emotion** | 28-class multi-label | Binary cross-entropy |
|
| 235 |
-
| **Topic** |
|
| 236 |
|
| 237 |
### Training Data
|
| 238 |
|
| 239 |
| Dataset | Task | Size |
|
| 240 |
|---------|------|------|
|
| 241 |
-
|
|
| 242 |
| GoEmotions | Emotion | ~43K |
|
| 243 |
-
|
|
| 244 |
|
| 245 |
### Links
|
| 246 |
|
|
|
|
| 135 |
|
| 136 |
A custom 272M parameter encoder-decoder model trained jointly on three NLP tasks.
|
| 137 |
Built from scratch in PyTorch, initialized from FLAN-T5-base weights.
|
| 138 |
+
|
| 139 |
+
*Currently running development checkpoint - quality improves with longer training.*
|
| 140 |
"""
|
| 141 |
)
|
| 142 |
|
|
|
|
| 178 |
gr.Markdown("### Training Results")
|
| 179 |
|
| 180 |
# Load metrics from training history
|
| 181 |
+
metrics_html = ""
|
| 182 |
if TRAINING_HISTORY_PATH.exists():
|
| 183 |
with open(TRAINING_HISTORY_PATH) as f:
|
| 184 |
history = json.load(f)
|
|
|
|
| 187 |
if val_keys:
|
| 188 |
latest = sorted(val_keys)[-1]
|
| 189 |
val = history[latest]
|
| 190 |
+
epoch_num = latest.split("_")[-1]
|
| 191 |
+
|
| 192 |
+
topic_acc = val.get('topic_accuracy', 0)
|
| 193 |
+
emotion_f1 = val.get('emotion_f1', 0)
|
| 194 |
+
rouge = val.get('summarization_rouge_like', 0)
|
| 195 |
+
total_loss = val.get('total_loss', 0)
|
| 196 |
+
|
| 197 |
+
metrics_html = f"""
|
| 198 |
+
| Task | Metric | Score |
|
| 199 |
+
|------|--------|-------|
|
| 200 |
+
| **Topic Classification** | Accuracy | **{topic_acc:.1%}** |
|
| 201 |
+
| **Emotion Detection** | F1 Score | {emotion_f1:.1%} |
|
| 202 |
+
| **Summarization** | ROUGE-like | {rouge:.1%} |
|
| 203 |
+
| **Total** | Val Loss | {total_loss:.3f} |
|
| 204 |
+
|
| 205 |
+
*Results from epoch {epoch_num}*
|
| 206 |
+
"""
|
| 207 |
+
else:
|
| 208 |
+
metrics_html = "*No training history found. Run training first.*"
|
| 209 |
|
| 210 |
+
gr.Markdown(metrics_html)
|
| 211 |
+
|
| 212 |
+
gr.Markdown("### Training Progress")
|
| 213 |
+
|
| 214 |
+
# Generate inline training chart using history data
|
| 215 |
+
if TRAINING_HISTORY_PATH.exists():
|
| 216 |
+
with open(TRAINING_HISTORY_PATH) as f:
|
| 217 |
+
history = json.load(f)
|
| 218 |
+
|
| 219 |
+
train_keys = sorted([k for k in history.keys() if k.startswith("train_epoch")])
|
| 220 |
+
val_keys = sorted([k for k in history.keys() if k.startswith("val_epoch")])
|
| 221 |
+
|
| 222 |
+
if train_keys and val_keys:
|
| 223 |
+
epochs = list(range(1, len(train_keys) + 1))
|
| 224 |
+
train_loss = [history[k]["total_loss"] for k in train_keys]
|
| 225 |
+
val_loss = [history[k]["total_loss"] for k in val_keys[:len(train_keys)]]
|
| 226 |
+
topic_acc = [history[k].get("topic_accuracy", 0) * 100 for k in val_keys[:len(train_keys)]]
|
| 227 |
+
|
| 228 |
+
# Create a simple text-based progress display
|
| 229 |
+
progress_md = "**Loss Curve:**\n```\n"
|
| 230 |
+
for i, (tl, vl) in enumerate(zip(train_loss, val_loss)):
|
| 231 |
+
bar_len = int((1 - vl/max(val_loss)) * 20) + 1
|
| 232 |
+
progress_md += f"Epoch {i+1}: Train={tl:.3f} Val={vl:.3f} {'█' * bar_len}\n"
|
| 233 |
+
progress_md += "```\n\n"
|
| 234 |
+
|
| 235 |
+
progress_md += "**Topic Accuracy:**\n```\n"
|
| 236 |
+
for i, acc in enumerate(topic_acc):
|
| 237 |
+
bar_len = int(acc / 5)
|
| 238 |
+
progress_md += f"Epoch {i+1}: {acc:.1f}% {'█' * bar_len}\n"
|
| 239 |
+
progress_md += "```"
|
| 240 |
+
|
| 241 |
+
gr.Markdown(progress_md)
|
| 242 |
|
| 243 |
+
# Show visualization images if they exist
|
| 244 |
with gr.Row():
|
| 245 |
loss_curve = OUTPUTS_DIR / "training_loss_curve.png"
|
| 246 |
if loss_curve.exists():
|
|
|
|
| 281 |
|------|--------|---------------|
|
| 282 |
| **Summarization** | Seq2seq generation | Cross-entropy + label smoothing |
|
| 283 |
| **Emotion** | 28-class multi-label | Binary cross-entropy |
|
| 284 |
+
| **Topic** | 7-class single-label | Cross-entropy |
|
| 285 |
|
| 286 |
### Training Data
|
| 287 |
|
| 288 |
| Dataset | Task | Size |
|
| 289 |
|---------|------|------|
|
| 290 |
+
| BookSum + arXiv | Summarization | ~40K |
|
| 291 |
| GoEmotions | Emotion | ~43K |
|
| 292 |
+
| 20 Newsgroups + Gutenberg | Topic | ~3.4K |
|
| 293 |
|
| 294 |
### Links
|
| 295 |
|
scripts/visualize_training.py
CHANGED
|
@@ -127,7 +127,8 @@ def get_mlflow_client():
|
|
| 127 |
raise ImportError("MLflow not installed. Install with: pip install mlflow")
|
| 128 |
import mlflow
|
| 129 |
import mlflow.tracking
|
| 130 |
-
|
|
|
|
| 131 |
return mlflow.tracking.MlflowClient()
|
| 132 |
|
| 133 |
|
|
|
|
| 127 |
raise ImportError("MLflow not installed. Install with: pip install mlflow")
|
| 128 |
import mlflow
|
| 129 |
import mlflow.tracking
|
| 130 |
+
# Use SQLite database (same as trainer.py)
|
| 131 |
+
mlflow.set_tracking_uri("sqlite:///mlruns.db")
|
| 132 |
return mlflow.tracking.MlflowClient()
|
| 133 |
|
| 134 |
|