Spaces:
Build error
Build error
| import torch | |
| from torch.optim import Optimizer | |
| from lightning.pytorch import LightningModule | |
| from lightning.pytorch.callbacks import BasePredictionWriter | |
| from .ar import ARSystem, ARWriter | |
| from .skin import SkinSystem, SkinWriter | |
def get_system(**kwargs) -> LightningModule:
    """Instantiate the system class selected by the ``__target__`` keyword.

    ``__target__`` is removed from ``kwargs``; every remaining keyword
    argument is forwarded unchanged to the selected class's constructor.

    Args:
        **kwargs: must contain ``__target__`` (one of ``'ar'``, ``'skin'``)
            plus the constructor arguments of the selected class.

    Returns:
        The object produced by calling the selected class.

    Raises:
        KeyError: if ``__target__`` is not supplied.
        ValueError: if ``__target__`` is not one of the registered names.
    """
    registry = {
        'ar': ARSystem,
        'skin': SkinSystem,
    }
    # pop() both reads the selector and keeps it out of the constructor call.
    target = kwargs.pop('__target__')
    if target not in registry:
        # Raised explicitly instead of asserted so the check is not stripped
        # when Python runs with -O.
        raise ValueError(f"expect: [{','.join(registry)}], found: {target}")
    return registry[target](**kwargs)
def get_writer(**kwargs) -> BasePredictionWriter:
    """Instantiate the prediction-writer class selected by ``__target__``.

    ``__target__`` is removed from ``kwargs``; every remaining keyword
    argument is forwarded unchanged to the selected class's constructor.

    Args:
        **kwargs: must contain ``__target__`` (one of ``'ar'``, ``'skin'``)
            plus the constructor arguments of the selected class.

    Returns:
        The object produced by calling the selected class.

    Raises:
        KeyError: if ``__target__`` is not supplied.
        ValueError: if ``__target__`` is not one of the registered names.
    """
    registry = {
        'ar': ARWriter,
        'skin': SkinWriter,
    }
    # pop() both reads the selector and keeps it out of the constructor call.
    target = kwargs.pop('__target__')
    if target not in registry:
        # Raised explicitly instead of asserted so the check is not stripped
        # when Python runs with -O.
        raise ValueError(f"expect: [{','.join(registry)}], found: {target}")
    return registry[target](**kwargs)