File size: 1,126 Bytes
7df6a88 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | import io
import json
import numpy as np
import sys
import os
from datasets import load_dataset
from hoho2025.metric_helper import hss
import sklearn_submission
# Smoke-test driver: run the sklearn wireframe predictor on the first few
# samples of the training split and print the HSS metric per sample plus the
# average over all scored samples.
#
# Usage: python <this script> [NUM_SAMPLES]   (NUM_SAMPLES defaults to 10)
import traceback
from itertools import islice

# Optional first CLI argument overrides how many samples are evaluated.
NUM_SAMPLES = int(sys.argv[1]) if len(sys.argv) > 1 else 10

print("Loading dataset...")
# NOTE(review): trust_remote_code=True lets `datasets` execute code shipped
# with the dataset repository; keep it only for a repository you trust.
dataset = load_dataset(
    'usm3d/hoho22k_2026_trainval',
    split='train',
    streaming=True,
    trust_remote_code=True,
)

# The split is streamed, so take exactly the first NUM_SAMPLES records
# (islice stops without pulling an extra record from the stream).
samples = list(islice(dataset, max(NUM_SAMPLES, 0)))

scores = []
for idx, sample in enumerate(samples):
    print(f"Testing sample {idx}")
    gt_v = sample.get('wf_vertices')
    gt_e = sample.get('wf_edges')
    # Check ground truth before predicting so no time is spent on a sample
    # that cannot be scored.
    if gt_v is None or gt_e is None:
        print(f"Skipping sample {idx} due to missing ground truth.")
        continue
    try:
        pred_v, pred_e = sklearn_submission.predict_wireframe_sklearn(sample)
    except Exception as e:
        # Best effort: report the failure (with traceback, so the bug can be
        # located) and fall back to a trivial two-vertex, one-edge wireframe
        # so the sample is still scored instead of aborting the whole run.
        print(f"Error on sample {idx}: {e}")
        traceback.print_exc()
        pred_v, pred_e = np.zeros((2, 3)), [(0, 1)]
    res = hss(pred_v, pred_e, gt_v, gt_e)
    scores.append(res.hss)
    print(f"Sample {idx} HSS: {res.hss:.4f}")

if scores:
    print(f"Average HSS: {sum(scores) / len(scores):.4f}")
else:
    print("No valid scores.")
|