import io
import itertools
import json
import os
import sys

import numpy as np
from datasets import load_dataset
from hoho2025 import example_solutions
from hoho2025.metric_helper import hss

import script


# Hugging Face dataset repository that the samples are streamed from.
DATASET_NAME = 'usm3d/hoho22k_2026_trainval'
# Number of leading samples of the stream that are evaluated by default.
MAX_SAMPLES = 3


def average_score(scores):
    """Return the arithmetic mean of *scores*, or 0 when *scores* is empty."""
    return sum(scores) / len(scores) if scores else 0


def main(max_samples=MAX_SAMPLES):
    """Compare baseline and filtered wireframe predictions on a few samples.

    Streams the first *max_samples* samples of ``DATASET_NAME`` and, for each
    sample that carries ``wf_vertices`` and ``wf_edges``, scores the output of
    ``example_solutions.predict_wireframe`` (baseline) and
    ``script.predict_wireframe_safely`` (filtered) with ``hss``. A per-sample
    comparison and the average filtered HSS are printed to stdout.

    Args:
        max_samples: Upper bound on the number of streamed samples to test.

    Returns:
        The list of filtered HSS scores, one per sample that was scored.
    """
    print(f"Loading dataset from Hugging Face streaming ({DATASET_NAME})...")
    # NOTE(review): trust_remote_code=True lets the dataset repository run its
    # own loading code locally; keep it only while that repository is trusted.
    dataset = load_dataset(DATASET_NAME, split='train', streaming=True, trust_remote_code=True)

    scores = []
    # islice stops after max_samples items, so no extra sample is pulled from
    # the stream merely to find out that the limit has been reached.
    for idx, sample in enumerate(itertools.islice(dataset, max_samples)):
        order_id = sample.get('order_id', str(idx))
        print(f"\n--- Testing order_id: {order_id} ---")

        gt_v = sample.get('wf_vertices')
        gt_e = sample.get('wf_edges')
        if gt_v is None or gt_e is None:
            # Nothing to score against, so do not spend time on predictions.
            print("Missing ground truth for this sample.")
            continue

        base_v, base_e = example_solutions.predict_wireframe(sample)
        pred_v, pred_e, _ = script.predict_wireframe_safely(sample)

        base_res = hss(base_v, base_e, gt_v, gt_e)
        res = hss(pred_v, pred_e, gt_v, gt_e)
        scores.append(res.hss)

        print(f"BASELINE Predict -> Vertices: {len(base_v)} | Edges: {len(base_e)} | HSS: {base_res.hss:.4f}")
        print(f"FILTERED Predict -> Vertices: {len(pred_v)} | Edges: {len(pred_e)} | HSS: {res.hss:.4f}")

    avg_score = average_score(scores)
    print(f"\nAverage FILTERED HSS Score on subset: {avg_score:.4f}")
    return scores


if __name__ == "__main__":
    main()