Pybunny commited on
Commit
db9df7d
·
verified ·
1 Parent(s): 0fc2400

Filter benchmark labels to 7-class subset

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -274,17 +274,30 @@ def _bench_data_root():
274
 
275
 
276
  def _bench_subset(n_frames):
277
- """Memory-mapped read of the first n_frames frames from benchmark/."""
278
- import tempfile
 
 
 
 
 
 
 
 
 
 
279
  root = _bench_data_root() / "benchmark"
280
  total = int(np.load(root / "x_vi_6s.npy", mmap_mode="r").shape[0])
281
  n = max(1, min(int(n_frames), total))
282
  x = np.asarray(np.load(root / "x_vi_6s.npy", mmap_mode="r")[:n],
283
  dtype=np.float32)
284
  lab = np.load(root / "labels_and_index.npz", allow_pickle=True)
285
- y = lab["y_power"][:n].astype(np.float32)
286
- cls = [str(c) for c in lab["class_names"]]
287
- return x, y, cls, total
 
 
 
288
 
289
 
290
  def _score_demo_pt(weights_file, n_frames):
 
274
 
275
 
276
  def _bench_subset(n_frames):
277
+ """Memory-mapped read of the first n_frames frames from benchmark/.
278
+
279
+ Filters the labels to the 7-category benchmark scoring set
280
+ (electrical heating is listed in the file but never activates in House 2
281
+ and is excluded by the official protocol). This matches the shape of
282
+ the bundled byom_demo.pt and any other DemoRegressor checkpoint
283
+ trained via examples/byom_demo.py.
284
+ """
285
+ BENCH_CLASSES = [
286
+ "always on", "cooking", "dishwasher", "electronics & lighting",
287
+ "fridge", "misc", "washing machine",
288
+ ]
289
  root = _bench_data_root() / "benchmark"
290
  total = int(np.load(root / "x_vi_6s.npy", mmap_mode="r").shape[0])
291
  n = max(1, min(int(n_frames), total))
292
  x = np.asarray(np.load(root / "x_vi_6s.npy", mmap_mode="r")[:n],
293
  dtype=np.float32)
294
  lab = np.load(root / "labels_and_index.npz", allow_pickle=True)
295
+ all_cls = [str(c) for c in lab["class_names"]]
296
+ keep = [all_cls.index(c) for c in BENCH_CLASSES if c in all_cls]
297
+ y_all = lab["y_power"][:n].astype(np.float32)
298
+ y = y_all[:, keep]
299
+ classes = [all_cls[i] for i in keep]
300
+ return x, y, classes, total
301
 
302
 
303
  def _score_demo_pt(weights_file, n_frames):