import io
import itertools
import json
import os
import sys
import tarfile
import zipfile

import numpy as np
from datasets import load_dataset
from hoho2025.metric_helper import hss

import script


# Local parquet shard to evaluate. Defaults to the original hard-coded
# location; set the HOHO_PARQUET environment variable to use another file.
PARQUET_PATH = os.environ.get(
    "HOHO_PARQUET", "/tmp/data/data/train-00000-of-00002.parquet"
)

print("Loading dataset from local parquet...")
dataset = load_dataset('parquet', data_files={"train": PARQUET_PATH})
print(f"Loaded {len(dataset['train'])} examples.")

# Number of leading training samples to evaluate (a quick smoke-test subset).
# Override with the HOHO_MAX_SAMPLES environment variable.
MAX_SAMPLES = int(os.environ.get("HOHO_MAX_SAMPLES", "5"))

# HSS value of every sample that was successfully scored.
scores = []


def _load_ground_truth(sample):
    """Read the ground-truth wireframe stored in ``sample['data']``.

    ``sample['data']`` is opened as the bytes of a zip archive, and the
    members ``gt_vertices.npy`` and ``gt_edges.npy`` are loaded with
    ``np.load``.

    Returns:
        ``(gt_vertices, gt_edges)``; either element is ``None`` when the
        corresponding member is absent from the archive.

    Raises:
        Whatever opening/reading the archive raises (e.g. ``KeyError`` when
        the sample has no ``'data'`` field, ``zipfile.BadZipFile`` for
        bytes that are not a valid zip).
    """
    gt_vertices = None
    gt_edges = None
    with zipfile.ZipFile(io.BytesIO(sample['data']), "r") as zf:
        members = set(zf.namelist())  # list the archive once, not per member
        if 'gt_vertices.npy' in members:
            gt_vertices = np.load(io.BytesIO(zf.read('gt_vertices.npy')))
        if 'gt_edges.npy' in members:
            gt_edges = np.load(io.BytesIO(zf.read('gt_edges.npy')))
    return gt_vertices, gt_edges


# islice stops after MAX_SAMPLES rows instead of decoding one extra row
# only to break out of the loop.
for idx, sample in enumerate(itertools.islice(dataset['train'], MAX_SAMPLES)):
    order_id = sample.get('order_id', str(idx))
    print(f"\n--- Testing order_id: {order_id} ---")

    # Load GT before predicting so the predictor is not run on samples
    # that cannot be scored anyway.
    try:
        gt_v, gt_e = _load_ground_truth(sample)
    except Exception as e:
        # Per-sample boundary: report the failure and keep evaluating the rest.
        print(f"Failed to read ZIP contents for GT: {e}")
        continue

    if gt_v is None or gt_e is None:
        print("Missing ground truth for this sample.")
        continue

    # The predictor's third return value is not used by this evaluation.
    pred_v, pred_e, _ = script.predict_wireframe_safely(sample)

    res = hss(pred_v, pred_e, gt_v, gt_e)
    scores.append(res.hss)

    print(f"Predicted Vertices: {len(pred_v)} | Predicted Edges: {len(pred_e)}")
    print(f"GT Vertices: {len(gt_v)} | GT Edges: {len(gt_e)}")
    print(f"HSS Score: {res.hss:.4f}")

# Mean HSS over the samples that were actually scored; samples skipped for
# missing or unreadable ground truth do not contribute.
if scores:
    avg_score = sum(scores) / len(scores)
    print(f"\nAverage HSS Score on subset: {avg_score:.4f}")
    print(f"Scored {len(scores)} sample(s).")
else:
    # Do not print 0.0000 here: it would be indistinguishable from a genuine
    # zero score when in fact nothing was evaluated.
    avg_score = 0
    print("\nNo samples were scored; average HSS is undefined.")