Filter benchmark labels to 7-class subset
Browse files
app.py
CHANGED
|
@@ -274,17 +274,30 @@ def _bench_data_root():
|
|
| 274 |
|
| 275 |
|
| 276 |
def _bench_subset(n_frames):
|
| 277 |
-
"""Memory-mapped read of the first n_frames frames from benchmark/.
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
root = _bench_data_root() / "benchmark"
|
| 280 |
total = int(np.load(root / "x_vi_6s.npy", mmap_mode="r").shape[0])
|
| 281 |
n = max(1, min(int(n_frames), total))
|
| 282 |
x = np.asarray(np.load(root / "x_vi_6s.npy", mmap_mode="r")[:n],
|
| 283 |
dtype=np.float32)
|
| 284 |
lab = np.load(root / "labels_and_index.npz", allow_pickle=True)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
def _score_demo_pt(weights_file, n_frames):
|
|
|
|
| 274 |
|
| 275 |
|
| 276 |
def _bench_subset(n_frames):
|
| 277 |
+
"""Memory-mapped read of the first n_frames frames from benchmark/.
|
| 278 |
+
|
| 279 |
+
Filters the labels to the 7-category benchmark scoring set
|
| 280 |
+
(electrical heating is listed in the file but never activates in House 2
|
| 281 |
+
and is excluded by the official protocol). This matches the shape of
|
| 282 |
+
the bundled byom_demo.pt and any other DemoRegressor checkpoint
|
| 283 |
+
trained via examples/byom_demo.py.
|
| 284 |
+
"""
|
| 285 |
+
BENCH_CLASSES = [
|
| 286 |
+
"always on", "cooking", "dishwasher", "electronics & lighting",
|
| 287 |
+
"fridge", "misc", "washing machine",
|
| 288 |
+
]
|
| 289 |
root = _bench_data_root() / "benchmark"
|
| 290 |
total = int(np.load(root / "x_vi_6s.npy", mmap_mode="r").shape[0])
|
| 291 |
n = max(1, min(int(n_frames), total))
|
| 292 |
x = np.asarray(np.load(root / "x_vi_6s.npy", mmap_mode="r")[:n],
|
| 293 |
dtype=np.float32)
|
| 294 |
lab = np.load(root / "labels_and_index.npz", allow_pickle=True)
|
| 295 |
+
all_cls = [str(c) for c in lab["class_names"]]
|
| 296 |
+
keep = [all_cls.index(c) for c in BENCH_CLASSES if c in all_cls]
|
| 297 |
+
y_all = lab["y_power"][:n].astype(np.float32)
|
| 298 |
+
y = y_all[:, keep]
|
| 299 |
+
classes = [all_cls[i] for i in keep]
|
| 300 |
+
return x, y, classes, total
|
| 301 |
|
| 302 |
|
| 303 |
def _score_demo_pt(weights_file, n_frames):
|