udedhia commited on
Commit
edb1d9f
Β·
verified Β·
1 Parent(s): adbfd78

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +432 -0
app.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, glob
2
+ import gradio as gr
3
+ import plotly.graph_objects as go
4
+ import plotly.express as px
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ # ── constants ─────────────────────────────────────────────────────────────
9
+ IMPUTE_FIELDS = [
10
+ 'recovered_material', 'recovered_object_type', 'recovered_condition',
11
+ 'recovered_period', 'recovered_description'
12
+ ]
13
+ FIELD_LABELS = {
14
+ 'recovered_material': 'Material',
15
+ 'recovered_object_type': 'Object Type',
16
+ 'recovered_condition': 'Condition',
17
+ 'recovered_period': 'Period',
18
+ 'recovered_description': 'Description',
19
+ }
20
+ METRICS = {
21
+ 'exact_match': 'Exact Match',
22
+ 'fuzzy_token_sort': 'Fuzzy Match',
23
+ 'semantic_sim': 'Semantic Similarity',
24
+ 'top3_match': 'Top-3 Match',
25
+ 'bleu': 'BLEU (description only)',
26
+ }
27
+ COLORS = ['#7D3A10', '#2d6a4f', '#1848A0', '#e9c46a', '#993556']
28
+
29
+ # ── load all eval jsons ───────────────────────────────────────────────────
30
+ def get_eval_files():
31
+ return sorted(glob.glob('*.json') + glob.glob('eval_results*.json'))
32
+
33
+ def load_eval(path):
34
+ with open(path) as f:
35
+ return json.load(f)
36
+
37
+ def friendly_name(path):
38
+ n = os.path.basename(path).replace('eval_results','').replace('.json','').strip('_- ')
39
+ return n if n else os.path.basename(path)
40
+
41
+ eval_files = get_eval_files()
42
+ eval_data = {friendly_name(f): load_eval(f) for f in eval_files}
43
+
44
+ # ── TAB 1: metrics dashboard ──────────────────────────────────────────────
45
+ def make_bar_chart(selected_runs, metric):
46
+ if not selected_runs:
47
+ return go.Figure()
48
+ fig = go.Figure()
49
+ for i, run in enumerate(selected_runs):
50
+ if run not in eval_data: continue
51
+ data = eval_data[run]['summary']
52
+ fields = list(data.keys())
53
+ vals = [data[f].get(metric, 0) for f in fields]
54
+ labels = [FIELD_LABELS.get(f, f) for f in fields]
55
+ fig.add_trace(go.Bar(
56
+ name=run, x=labels, y=vals,
57
+ marker_color=COLORS[i % len(COLORS)],
58
+ text=[f'{v:.1%}' for v in vals],
59
+ textposition='outside',
60
+ ))
61
+ fig.update_layout(
62
+ barmode='group',
63
+ yaxis=dict(range=[0,1.15], tickformat='.0%', title='Score', gridcolor='#eee'),
64
+ xaxis_title='Field',
65
+ plot_bgcolor='white',
66
+ paper_bgcolor='white',
67
+ font=dict(family='Georgia, serif', size=13),
68
+ legend=dict(orientation='h', y=1.12),
69
+ margin=dict(t=60, b=40, l=40, r=20),
70
+ height=420,
71
+ )
72
+ return fig
73
+
74
+ def make_radar(selected_runs):
75
+ if not selected_runs:
76
+ return go.Figure()
77
+ cats = ['Exact Match','Fuzzy Match','Semantic Sim','Top-3 Match']
78
+ metric_keys = ['exact_match','fuzzy_token_sort','semantic_sim','top3_match']
79
+ fig = go.Figure()
80
+ for i, run in enumerate(selected_runs):
81
+ if run not in eval_data: continue
82
+ data = eval_data[run]['summary']
83
+ vals = []
84
+ for mk in metric_keys:
85
+ field_vals = [data[f].get(mk, 0) for f in IMPUTE_FIELDS if mk in data.get(f,{})]
86
+ vals.append(np.mean(field_vals) if field_vals else 0)
87
+ fig.add_trace(go.Scatterpolar(
88
+ r=vals + [vals[0]],
89
+ theta=cats + [cats[0]],
90
+ name=run,
91
+ line_color=COLORS[i % len(COLORS)],
92
+ fill='toself', fillcolor=COLORS[i % len(COLORS)],
93
+ opacity=0.2,
94
+ ))
95
+ fig.update_layout(
96
+ polar=dict(radialaxis=dict(range=[0,1], tickformat='.0%')),
97
+ font=dict(family='Georgia, serif', size=12),
98
+ height=380,
99
+ margin=dict(t=40, b=40),
100
+ paper_bgcolor='white',
101
+ )
102
+ return fig
103
+
104
+ def make_summary_table(selected_runs):
105
+ if not selected_runs:
106
+ return pd.DataFrame()
107
+ rows = []
108
+ for run in selected_runs:
109
+ if run not in eval_data: continue
110
+ summary = eval_data[run]['summary']
111
+ for field, stats in summary.items():
112
+ row = {'Run': run, 'Field': FIELD_LABELS.get(field, field)}
113
+ for mk, ml in METRICS.items():
114
+ row[ml] = f"{stats.get(mk, 0):.1%}" if mk in stats else 'β€”'
115
+ rows.append(row)
116
+ return pd.DataFrame(rows)
117
+
118
+ # ── TAB 2: artifact deep dive ─────────────────────────────────────────────
119
+ def make_confusion(run, field):
120
+ if not run or run not in eval_data: return go.Figure()
121
+ results = eval_data[run].get('results', {}).get(field, [])
122
+ if not results: return go.Figure()
123
+ gts = [r['gt'][:35] for r in results]
124
+ preds = [str(r['pred'])[:35] for r in results]
125
+ labels = sorted(set(gts) | set(preds))
126
+ n = len(labels)
127
+ idx = {l: i for i, l in enumerate(labels)}
128
+ mat = np.zeros((n,n), dtype=int)
129
+ for g, p in zip(gts, preds):
130
+ if g in idx and p in idx:
131
+ mat[idx[g]][idx[p]] += 1
132
+ fig = go.Figure(go.Heatmap(
133
+ z=mat, x=labels, y=labels,
134
+ colorscale='YlOrRd',
135
+ text=mat, texttemplate='%{text}',
136
+ ))
137
+ fig.update_layout(
138
+ xaxis_title='Predicted', yaxis_title='Ground Truth',
139
+ height=max(380, n*28),
140
+ font=dict(family='Georgia, serif', size=11),
141
+ margin=dict(t=20, b=80, l=120, r=20),
142
+ paper_bgcolor='white',
143
+ )
144
+ return fig
145
+
146
+ def make_scatter(run, field):
147
+ if not run or run not in eval_data: return go.Figure()
148
+ results = eval_data[run].get('results', {}).get(field, [])
149
+ if not results: return go.Figure()
150
+ x = [r.get('fuzzy_token_sort', 0) for r in results]
151
+ y = [r.get('semantic_sim', 0) for r in results]
152
+ em = [r.get('exact_match', False) for r in results]
153
+ hover = [f"<b>{r['label']}</b><br>GT: {r['gt'][:50]}<br>PRED: {str(r['pred'])[:50]}" for r in results]
154
+ colors_pt = ['#2d6a4f' if e else '#e76f51' for e in em]
155
+ fig = go.Figure(go.Scatter(
156
+ x=x, y=y, mode='markers',
157
+ marker=dict(color=colors_pt, size=9, opacity=0.75, line=dict(width=0.5, color='white')),
158
+ text=hover, hoverinfo='text',
159
+ ))
160
+ fig.add_shape(type='line', x0=0,y0=0,x1=1,y1=1, line=dict(dash='dot', color='#aaa', width=1))
161
+ fig.update_layout(
162
+ xaxis=dict(title='Fuzzy match', range=[0,1.05], gridcolor='#eee'),
163
+ yaxis=dict(title='Semantic similarity', range=[0,1.05], gridcolor='#eee'),
164
+ height=360,
165
+ plot_bgcolor='white', paper_bgcolor='white',
166
+ font=dict(family='Georgia, serif', size=12),
167
+ margin=dict(t=20, b=40),
168
+ )
169
+ return fig
170
+
171
+ def make_error_table(run, field):
172
+ if not run or run not in eval_data: return pd.DataFrame()
173
+ results = eval_data[run].get('results', {}).get(field, [])
174
+ errors = [r for r in results if not r.get('exact_match', False)]
175
+ rows = []
176
+ for r in errors:
177
+ rows.append({
178
+ 'Label': r['label'],
179
+ 'Class': r.get('item_class',''),
180
+ 'Project': r.get('project','')[:40],
181
+ 'GT': r['gt'][:60],
182
+ 'Predicted':str(r['pred'])[:60],
183
+ 'Sem Sim': f"{r.get('semantic_sim',0):.2f}",
184
+ 'Fuzzy': f"{r.get('fuzzy_token_sort',0):.2f}",
185
+ })
186
+ return pd.DataFrame(rows)
187
+
188
+ # ── TAB 3: per-artifact browser ───────────────────────────────────────────
189
+ def get_all_artifacts(run, field, only_errors):
190
+ if not run or run not in eval_data: return [], []
191
+ results = eval_data[run].get('results', {}).get(field, [])
192
+ if only_errors:
193
+ results = [r for r in results if not r.get('exact_match', False)]
194
+ choices = [f"{r['label']} | {r.get('item_class','')} | {r.get('project','')[:30]}" for r in results]
195
+ return choices, results
196
+
197
+ _artifact_cache = {}
198
+
199
+ def search_artifacts(run, field, only_errors, query):
200
+ choices, results = get_all_artifacts(run, field, only_errors)
201
+ _artifact_cache['results'] = results
202
+ _artifact_cache['choices'] = choices
203
+ if query:
204
+ filtered = [(c, r) for c, r in zip(choices, results)
205
+ if query.lower() in c.lower() or query.lower() in r['gt'].lower()]
206
+ choices = [x[0] for x in filtered]
207
+ _artifact_cache['results'] = [x[1] for x in filtered]
208
+ _artifact_cache['choices'] = choices
209
+ return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
210
+
211
+ def show_artifact_card(selection):
212
+ if not selection or 'results' not in _artifact_cache:
213
+ return '<p>Select an artifact above</p>'
214
+ choices = _artifact_cache['choices']
215
+ results = _artifact_cache['results']
216
+ if selection not in choices:
217
+ return '<p>Not found</p>'
218
+ r = results[choices.index(selection)]
219
+
220
+ em = r.get('exact_match', False)
221
+ fuzz = r.get('fuzzy_token_sort', 0)
222
+ sem = r.get('semantic_sim', 0)
223
+ top3 = r.get('top3', [])
224
+ bleu = r.get('bleu', None)
225
+ gt = r['gt']
226
+ pred = str(r['pred'])
227
+ field = list(eval_data[list(eval_data.keys())[0]]['results'].keys())[0]
228
+
229
+ status_color = '#2d6a4f' if em else '#e76f51'
230
+ status_text = 'Exact match' if em else 'No exact match'
231
+
232
+ top3_html = ''
233
+ if top3:
234
+ top3_html = '<div style="margin-top:0.5rem"><b>Top-3 candidates:</b> ' + \
235
+ ' Β· '.join(f'<span style="background:#f5f0e8;padding:2px 6px;border-radius:3px">{c}</span>' for c in top3) + '</div>'
236
+
237
+ bleu_html = f'<span style="margin-left:1rem">BLEU: <b>{bleu:.3f}</b></span>' if bleu is not None else ''
238
+
239
+ html = f'''
240
+ <div style="font-family: Georgia, serif; padding: 1.2rem; border: 1px solid #ddd; border-radius: 8px; background: white">
241
+ <div style="display:flex; justify-content:space-between; align-items:flex-start; margin-bottom:1rem">
242
+ <div>
243
+ <h2 style="margin:0; font-size:1.4rem; color:#18100A">{r["label"]}</h2>
244
+ <p style="margin:0.2rem 0 0; color:#666; font-style:italic">
245
+ {r.get("item_class","")} Β· {r.get("project","")}
246
+ </p>
247
+ </div>
248
+ <span style="background:{status_color}; color:white; padding:4px 12px; border-radius:4px; font-size:0.85rem">
249
+ {status_text}
250
+ </span>
251
+ </div>
252
+
253
+ <table style="width:100%; border-collapse:collapse; font-size:0.92rem">
254
+ <tr style="background:#f5f0e8">
255
+ <th style="padding:0.5rem 1rem; text-align:left; border-bottom:2px solid #ddd; width:120px">Field</th>
256
+ <th style="padding:0.5rem 1rem; text-align:left; border-bottom:2px solid #ddd">Value</th>
257
+ </tr>
258
+ <tr style="border-bottom:1px solid #eee">
259
+ <td style="padding:0.6rem 1rem; color:#7D3A10; font-weight:bold">Ground Truth</td>
260
+ <td style="padding:0.6rem 1rem; color:#1a1a1a">{gt}</td>
261
+ </tr>
262
+ <tr style="background:#fffdf7; border-bottom:1px solid #eee">
263
+ <td style="padding:0.6rem 1rem; color:#2d6a4f; font-weight:bold">Predicted</td>
264
+ <td style="padding:0.6rem 1rem; color:#1a1a1a">{pred}</td>
265
+ </tr>
266
+ </table>
267
+
268
+ {top3_html}
269
+
270
+ <div style="margin-top:1rem; padding:0.8rem; background:#f9f9f9; border-radius:4px; font-size:0.85rem; color:#444">
271
+ <b>Scores:</b>
272
+ Fuzzy match: <b>{fuzz:.2f}</b>
273
+ <span style="margin-left:1rem">Semantic similarity: <b>{sem:.2f}</b></span>
274
+ {bleu_html}
275
+ </div>
276
+ </div>
277
+ '''
278
+ return html
279
+
280
+ # ── build app ─────────────────────────────────────────────────────────────
281
+ run_names = list(eval_data.keys())
282
+ field_choices = [(FIELD_LABELS[f], f) for f in IMPUTE_FIELDS]
283
+
284
+ css = '''
285
+ .gr-button-primary { background: #7D3A10 !important; }
286
+ h1, h2, h3 { font-family: Georgia, serif !important; }
287
+ '''
288
+
289
+ with gr.Blocks(
290
+ title='ArchAIa Imputation Eval',
291
+ theme=gr.themes.Base(font=[gr.themes.GoogleFont('Source Serif 4'), 'Georgia', 'serif']),
292
+ css=css,
293
+ ) as demo:
294
+
295
+ gr.Markdown("""
296
+ # ArchAIa β€” Field Imputation Evaluation Dashboard
297
+ **CMU Language Technologies Institute Β· April 2026**
298
+
299
+ Evaluation of a multimodal RAG pipeline (DINOv2 + MiniLM + GPT-4o) for filling missing metadata fields
300
+ in archaeological artifacts from the OpenContext database.
301
+ Compare results across different retrieval settings (top-15 vs top-50 neighbors).
302
+ """)
303
+
304
+ with gr.Tabs():
305
+
306
+ # ── TAB 1: metrics overview ────────────────────────────────────────
307
+ with gr.Tab('Metrics Overview'):
308
+ with gr.Row():
309
+ run_selector = gr.CheckboxGroup(
310
+ choices=run_names,
311
+ value=run_names,
312
+ label='Select eval runs to compare',
313
+ )
314
+ metric_radio = gr.Radio(
315
+ choices=list(METRICS.items()),
316
+ value='exact_match',
317
+ label='Metric',
318
+ )
319
+
320
+ bar_chart = gr.Plot(label='Per-field scores by run')
321
+
322
+ with gr.Row():
323
+ radar_chart = gr.Plot(label='Overall radar (mean across fields)')
324
+ with gr.Column():
325
+ gr.Markdown("### Summary table")
326
+ summary_table = gr.Dataframe(label='', wrap=True)
327
+
328
+ def update_overview(runs, metric):
329
+ return (
330
+ make_bar_chart(runs, metric),
331
+ make_radar(runs),
332
+ make_summary_table(runs),
333
+ )
334
+
335
+ run_selector.change(update_overview, [run_selector, metric_radio], [bar_chart, radar_chart, summary_table])
336
+ metric_radio.change(update_overview, [run_selector, metric_radio], [bar_chart, radar_chart, summary_table])
337
+
338
+ # ── TAB 2: field deep dive ─────────────────────────────────────────
339
+ with gr.Tab('Field Deep Dive'):
340
+ gr.Markdown("Inspect per-artifact predictions for a specific field and run.")
341
+ with gr.Row():
342
+ dd_run = gr.Dropdown(choices=run_names, value=run_names[0], label='Eval run')
343
+ dd_field = gr.Dropdown(choices=field_choices, value='recovered_material', label='Field')
344
+
345
+ with gr.Row():
346
+ scatter = gr.Plot(label='Fuzzy match vs Semantic similarity (green = exact match)')
347
+ conf_m = gr.Plot(label='Confusion matrix')
348
+
349
+ gr.Markdown("### Errors only")
350
+ error_table = gr.Dataframe(label='Artifacts where exact match failed', wrap=True)
351
+
352
+ def update_deepdive(run, field):
353
+ return (
354
+ make_scatter(run, field),
355
+ make_confusion(run, field),
356
+ make_error_table(run, field),
357
+ )
358
+
359
+ dd_run.change(update_deepdive, [dd_run, dd_field], [scatter, conf_m, error_table])
360
+ dd_field.change(update_deepdive, [dd_run, dd_field], [scatter, conf_m, error_table])
361
+
362
+ # ── TAB 3: artifact browser ────────────────────────────────────────
363
+ with gr.Tab('Artifact Browser'):
364
+ gr.Markdown("Browse individual artifact predictions. Filter by run, field, and correct/incorrect.")
365
+ with gr.Row():
366
+ ab_run = gr.Dropdown(choices=run_names, value=run_names[0], label='Eval run')
367
+ ab_field = gr.Dropdown(choices=field_choices, value='recovered_material', label='Field')
368
+ ab_errors = gr.Checkbox(value=False, label='Show errors only')
369
+ ab_query = gr.Textbox(label='Search by label or ground truth', placeholder='e.g. Batch 5')
370
+
371
+ ab_select = gr.Dropdown(label='Select artifact', choices=[], interactive=True)
372
+ ab_search = gr.Button('Search / Refresh', variant='primary')
373
+ ab_card = gr.HTML('<p style="color:#aaa">Search for artifacts above</p>')
374
+
375
+ ab_search.click(
376
+ search_artifacts,
377
+ inputs=[ab_run, ab_field, ab_errors, ab_query],
378
+ outputs=[ab_select],
379
+ )
380
+ ab_select.change(show_artifact_card, inputs=[ab_select], outputs=[ab_card])
381
+
382
+ # ── TAB 4: about ───────────────────────────────────────────────────
383
+ with gr.Tab('About'):
384
+ gr.Markdown("""
385
+ ## Pipeline Architecture
386
+
387
+ **Encoding:** Each v4 artifact is encoded as a 1408-dim vector by concatenating:
388
+ - DINOv2 ViT-L/14 image embedding (1024-dim) from the artifact's photograph
389
+ - all-MiniLM-L6-v2 text embedding (384-dim) from concatenated metadata fields
390
+
391
+ **Index:** FAISS flat index (IndexFlatIP) built on the 85% train split of v4 (19,215 artifacts).
392
+ The remaining 15% (3,392 artifacts) are held out as the eval set.
393
+
394
+ **Retrieval:** For each eval artifact, the top-N most similar artifacts are retrieved from the index,
395
+ filtered to only those that have the target field populated.
396
+
397
+ **Generation:** GPT-4o receives the artifact image + available fields + up to N retrieved neighbors
398
+ as structured JSON context, plus a constrained vocabulary derived from the train split.
399
+
400
+ ## Eval Setup
401
+
402
+ - 85/15 stratified split of v4 by `(project_label, item_class_label)`
403
+ - 100 artifacts sampled per field from the eval split
404
+ - Each field evaluated independently β€” the target field is blanked and predicted
405
+ - **Runs compared:** top-15 neighbors vs top-50 neighbors passed to GPT-4o
406
+
407
+ ## Metrics
408
+
409
+ | Metric | Description |
410
+ |---|---|
411
+ | Exact Match | Strict case-insensitive string equality |
412
+ | Fuzzy Match | Token sort ratio (handles word order variation) |
413
+ | Semantic Similarity | Cosine similarity of sentence embeddings |
414
+ | Top-3 Match | Ground truth appears in model's top-3 candidates |
415
+ | BLEU | N-gram overlap β€” description field only |
416
+
417
+ Urmi Dedhia Β· CMU Β· April 2026 Β· ArchAIa Project
418
+ """)
419
+
420
+ # load defaults on start
421
+ demo.load(
422
+ lambda: update_overview(run_names, 'exact_match'),
423
+ outputs=[bar_chart, radar_chart, summary_table]
424
+ )
425
+ demo.load(
426
+ lambda: update_deepdive(run_names[0], 'recovered_material'),
427
+ outputs=[scatter, conf_m, error_table]
428
+ )
429
+
430
+ if __name__ == '__main__':
431
+ demo.launch()
432
+ EOF