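# s23-model / test_local.py
# Uploaded by IhorIvanyshyn01 — commit 8f748c3 ("Update tuned hyperparameters").
# Local smoke test: runs the solution script on a few training samples and
# scores the predicted wireframes against ground truth with the HSS metric.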
import io
import itertools
import json
import os
import sys
import tarfile
import zipfile

import numpy as np
from datasets import load_dataset
from hoho2025.metric_helper import hss

# Import the solution script locally
import script
# Local parquet shard to evaluate against, and how many samples to score.
DATA_FILE = "/tmp/data/data/train-00000-of-00002.parquet"
MAX_SAMPLES = 5


def _load_ground_truth(raw_zip):
    """Extract ground-truth vertices and edges from a sample's zipped payload.

    Args:
        raw_zip: Bytes of a ZIP archive (the sample's ``data`` field).

    Returns:
        ``(gt_vertices, gt_edges)`` as loaded by ``np.load``. Either item is
        ``None`` when the archive has no ``gt_vertices.npy`` /
        ``gt_edges.npy`` member.

    Raises:
        zipfile.BadZipFile: if ``raw_zip`` is not a valid ZIP archive.
    """
    with zipfile.ZipFile(io.BytesIO(raw_zip), "r") as zf:
        members = set(zf.namelist())  # build once instead of re-listing per lookup

        def _read(name):
            if name not in members:
                return None
            return np.load(io.BytesIO(zf.read(name)))

        return _read("gt_vertices.npy"), _read("gt_edges.npy")


def main(data_file=DATA_FILE, max_samples=MAX_SAMPLES):
    """Score ``script.predict_wireframe_safely`` on the first few samples.

    Loads ``data_file`` as the ``train`` split of a parquet dataset. For each
    of the first ``max_samples`` rows, it reads the ground truth from the
    row's zipped ``data`` field, runs the prediction, computes HSS, and
    prints per-sample and average scores.

    Args:
        data_file: Path to the parquet shard to evaluate.
        max_samples: Number of leading samples to evaluate (non-negative), or
            ``None`` to evaluate every sample.

    Returns:
        The mean HSS over the samples that could be scored, or ``0`` when none
        could be scored.
    """
    print("Loading dataset from local parquet...")
    dataset = load_dataset('parquet', data_files={"train": data_file})
    train = dataset['train']
    print(f"Loaded {len(train)} examples.")

    scores = []
    for idx, sample in enumerate(itertools.islice(train, max_samples)):
        order_id = sample.get('order_id', str(idx))
        print(f"\n--- Testing order_id: {order_id} ---")

        # 1. Extract the ground truth first, so the (potentially expensive)
        #    prediction is skipped for samples that cannot be scored anyway.
        try:
            gt_v, gt_e = _load_ground_truth(sample['data'])
        except Exception as e:  # best-effort: report and move to the next sample
            print(f"Failed to read ZIP contents for GT: {e}")
            continue
        if gt_v is None or gt_e is None:
            print("Missing ground truth for this sample.")
            continue

        # 2. Run the prediction. NOTE(review): per the original note, this may
        #    fall back to an empty solution when 'gestalt' data is missing.
        pred_v, pred_e, _ = script.predict_wireframe_safely(sample)

        # 3. Compute the HSS metric score.
        res = hss(pred_v, pred_e, gt_v, gt_e)
        scores.append(res.hss)
        print(f"Predicted Vertices: {len(pred_v)} | Predicted Edges: {len(pred_e)}")
        print(f"GT Vertices: {len(gt_v)} | GT Edges: {len(gt_e)}")
        print(f"HSS Score: {res.hss:.4f}")

    avg_score = sum(scores) / len(scores) if scores else 0
    # Report how many samples were scored, so an average of 0 from "nothing
    # scored" can be told apart from genuinely zero scores.
    print(f"\nScored {len(scores)} sample(s).")
    print(f"\nAverage HSS Score on subset: {avg_score:.4f}")
    return avg_score


# Guarded so that importing this module (e.g. pytest collecting test_*.py)
# does not load the dataset and run the whole evaluation.
if __name__ == "__main__":
    main()