File size: 1,822 Bytes
8f748c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import io
import json
import tarfile
import zipfile
import numpy as np
import sys
import os
from datasets import load_dataset
from hoho2025.metric_helper import hss

# Import the solution script locally
import script

# Smoke test: run the solution's predictor on the first few training samples
# and score each prediction against its ground truth with the HSS metric.

# Local parquet shard to evaluate. Set HOHO_TRAIN_PARQUET to point at a
# different shard without editing the script.
DATA_FILE = os.environ.get(
    "HOHO_TRAIN_PARQUET", "/tmp/data/data/train-00000-of-00002.parquet"
)
# Only the first MAX_SAMPLES rows are scored; this is a quick check, not a full eval.
MAX_SAMPLES = 5

# Member names of the ground-truth arrays inside each sample's ZIP payload.
GT_VERTICES_NAME = 'gt_vertices.npy'
GT_EDGES_NAME = 'gt_edges.npy'


def _read_npy_member(zf, name):
    """Return the array stored as *name* in the open ZIP *zf*, or ``None`` if absent."""
    try:
        raw = zf.read(name)  # ZipFile.read raises KeyError for a missing member
    except KeyError:
        return None
    return np.load(io.BytesIO(raw))


def _load_ground_truth(payload):
    """Return ``(gt_vertices, gt_edges)`` decoded from a sample's ZIP bytes.

    Either element is ``None`` when its ``.npy`` member is missing from the
    archive. Errors from opening the archive or decoding an array propagate
    to the caller.
    """
    with zipfile.ZipFile(io.BytesIO(payload), "r") as zf:
        return (
            _read_npy_member(zf, GT_VERTICES_NAME),
            _read_npy_member(zf, GT_EDGES_NAME),
        )


print("Loading dataset from local parquet...")
dataset = load_dataset('parquet', data_files={"train": DATA_FILE})
print(f"Loaded {len(dataset['train'])} examples.")

scores = []
skipped = 0  # samples attempted but not scored (unreadable/missing ground truth)

for idx, sample in enumerate(dataset['train']):
    if idx >= MAX_SAMPLES:
        break

    order_id = sample.get('order_id', str(idx))
    print(f"\n--- Testing order_id: {order_id} ---")

    # 1. Evaluate prediction.
    # NOTE(review): expected to fall back to an empty solution when 'gestalt'
    # is missing from the sample — confirm against script.py.
    pred_v, pred_e, _ = script.predict_wireframe_safely(sample)

    # 2. Extract ground truth from the sample's ZIP payload.
    try:
        gt_v, gt_e = _load_ground_truth(sample['data'])
    except Exception as e:  # best-effort: a bad payload skips this sample only
        print(f"Failed to read ZIP contents for GT: {e}")
        skipped += 1
        continue

    if gt_v is None or gt_e is None:
        print("Missing ground truth for this sample.")
        skipped += 1
        continue

    # 3. Compute the HSS metric score.
    res = hss(pred_v, pred_e, gt_v, gt_e)
    scores.append(res.hss)

    print(f"Predicted Vertices: {len(pred_v)} | Predicted Edges: {len(pred_e)}")
    print(f"GT Vertices: {len(gt_v)} | GT Edges: {len(gt_e)}")
    print(f"HSS Score: {res.hss:.4f}")

# avg_score stays defined (0 when nothing was scored) for anyone reading it
# afterwards, but an empty run is reported explicitly rather than as 0.0000,
# which would be indistinguishable from a genuinely zero score.
avg_score = sum(scores) / len(scores) if scores else 0
if scores:
    print(f"\nAverage HSS Score on subset: {avg_score:.4f}")
else:
    print("\nNo samples were scored; average HSS is undefined.")
if skipped:
    print(f"({skipped} sample(s) skipped due to missing or unreadable ground truth; "
          f"average covers {len(scores)} sample(s).)")