wi-lab commited on
Commit
3c27f51
·
1 Parent(s): cd5b954

Create io_demo.py

Browse files
Files changed (1) hide show
  1. io_demo.py +47 -0
io_demo.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Optional, Tuple
3
+ import torch
4
+
5
+ DEMO_ROOT = "data"
6
+
7
+ def list_demo_tasks(root: str = DEMO_ROOT) -> List[str]:
8
+ r = Path(root)
9
+ if not r.exists():
10
+ return []
11
+ return sorted([p.name for p in r.iterdir() if p.is_dir()])
12
+
13
+ def list_demo_scenarios(task: str, root: str = DEMO_ROOT) -> List[str]:
14
+ base = Path(root) / task
15
+ if not base.exists():
16
+ return []
17
+ return sorted([p.name for p in base.iterdir() if p.is_dir()])
18
+
19
+ def _find_dataset_file(scenario_dir: Path) -> Optional[Path]:
20
+ preferred = ["train_data.pt", "data.pt", "dataset.pt"]
21
+ for name in preferred:
22
+ cand = scenario_dir / name
23
+ if cand.exists():
24
+ return cand
25
+ for ext in ("*.pt", "*.p"):
26
+ files = list(scenario_dir.glob(ext))
27
+ if files:
28
+ return files[0]
29
+ return None
30
+
31
+ def list_demo_dataset_files(task: str, root: str = DEMO_ROOT) -> List[str]:
32
+ out = []
33
+ base = Path(root) / task
34
+ if not base.exists():
35
+ return out
36
+ for scen in list_demo_scenarios(task, root):
37
+ scen_dir = base / scen
38
+ f = _find_dataset_file(scen_dir)
39
+ if f is not None:
40
+ out.append(str(f))
41
+ return out
42
+
43
+ def load_pt_dataset(path: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
44
+ obj = torch.load(path, map_location="cpu")
45
+ ch = obj["channels"]
46
+ y = obj.get("labels", None)
47
+ return ch, y