| import string |
|
|
| import h5py |
| import torch |
|
|
| from siclib.datasets.base_dataset import collate |
| from siclib.models.base_model import BaseModel |
| from siclib.settings import DATA_PATH |
| from siclib.utils.tensor import batch_to_device |
|
|
| |
| |
|
|
|
|
def pad_line_features(pred, seq_l: int = None):
    """Placeholder for a padding function over line features.

    The signature matches how ``CacheLoader`` invokes its configured padding
    function: ``padding_fn(pred, padding_length)``.

    Args:
        pred: dictionary of loaded predictions for a single sample.
        seq_l: target sequence length to pad to (may be ``None``).

    Raises:
        NotImplementedError: always; no padding logic is provided here.
    """
    raise NotImplementedError
|
|
|
|
def recursive_load(grp, pkeys):
    """Load the entries ``pkeys`` of an HDF5 group into a (nested) dict.

    Args:
        grp: an opened ``h5py`` group (or file) supporting ``grp[key]``.
        pkeys: iterable of keys of ``grp`` to load.

    Returns:
        A dict mapping each key to a ``torch.Tensor`` when the entry is an
        ``h5py.Dataset``, or to a nested dict (built recursively) otherwise.
    """
    loaded = {}
    for key in pkeys:
        item = grp[key]  # look the entry up once instead of twice
        if isinstance(item, h5py.Dataset):
            loaded[key] = torch.from_numpy(item.__array__())
        else:
            # Descend with the CHILD group's own keys. Passing the parent's
            # keys here would index the sub-group with names it does not
            # contain (KeyError) or pick up unrelated entries.
            loaded[key] = recursive_load(item, list(item.keys()))
    return loaded
|
|
|
|
class CacheLoader(BaseModel):
    """Pseudo-model that returns predictions previously stored in HDF5 files.

    For every sample of the batch, the file path is built from ``conf.path``
    (a format string filled with fields of ``data``), the HDF5 group named
    after ``data["name"][i]`` is read, and the loaded tensors are optionally
    cast, rescaled and padded before being collated.
    """

    default_conf = {
        "path": "???",  # format string; its fields are looked up in `data`
        "data_keys": None,  # keys to load from the group; None loads all keys
        "device": None,  # target device; None infers it from tensors in `data`
        "trainable": False,  # not read in this class (presumably used by BaseModel)
        "add_data_path": True,  # prepend DATA_PATH to the formatted path
        "collate": True,  # collate per-sample predictions into one batch
        "scale": ["keypoints"],  # key prefixes multiplied by the sample scales
        "padding_fn": None,  # name of a padding function, resolved via eval
        "padding_length": None,  # second argument passed to the padding function
        "numeric_type": "float32",  # dtype for floating tensors; None keeps them
    }

    required_data_keys = ["name"]

    def _init(self, conf):
        """Resolve the padding function and the floating-point dtype."""
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # NOTE(security): eval() executes arbitrary code from the config
            # string. Only use configurations from trusted sources.
            self.padding_fn = eval(self.padding_fn)
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):
        """Load the cached predictions for every sample named in ``data``.

        Args:
            data: batch dictionary. Must contain ``"name"`` (one group name
                per sample), every field referenced by ``conf.path``, and
                ``"scales"`` (or ``data["view<idx>"]["scales"]``) for keys
                matching a prefix in ``conf.scale``.

        Returns:
            The collated predictions when ``conf.collate`` is set, otherwise
            the predictions of the single sample in the batch.
        """
        device = self.conf.device
        if not device:
            # Infer the device from the top-level tensors of the batch.
            if devices := {
                v.device for v in data.values() if isinstance(v, torch.Tensor)
            }:
                assert len(devices) == 1
                device = devices.pop()
            else:
                device = "cpu"

        # Names of the replacement fields appearing in the path template.
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]

        preds = []
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath

            # The context manager guarantees the file is closed even when the
            # group is missing or loading fails. recursive_load materializes
            # the data into tensors, so nothing refers to the file afterwards.
            with h5py.File(str(fpath), "r") as hfile:
                grp = hfile[name]
                pkeys = (
                    self.conf.data_keys
                    if self.conf.data_keys is not None
                    else grp.keys()
                )
                pred = recursive_load(grp, pkeys)

            if self.numeric_dtype is not None:
                # Cast floating-point tensors only; leave everything else as is.
                pred = {
                    k: (
                        v.to(dtype=self.numeric_dtype)
                        if isinstance(v, torch.Tensor) and torch.is_floating_point(v)
                        else v
                    )
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)

            # Iterate over a snapshot of the keys since values are reassigned.
            for k in list(pred):
                for pattern in self.conf.scale:
                    if not k.startswith(pattern):
                        continue
                    # Whatever follows the matched prefix is the view index.
                    # Slice the prefix off rather than str.replace, which
                    # would also strip later occurrences of the pattern.
                    view_idx = k[len(pattern):]
                    scales = (
                        data[f"view{view_idx}"]["scales"] if view_idx else data["scales"]
                    )
                    pred[k] = pred[k] * scales[i]

            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)

        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        assert len(preds) == 1
        return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        """Not supported: this class defines no loss."""
        raise NotImplementedError
|
|