edit
Browse files- inference.py +4 -4
- requirements.txt +1 -1
inference.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
from pathlib import Path
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
from slider import Beatmap
|
|
@@ -8,9 +10,6 @@ from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipe
|
|
| 8 |
from osuT5.tokenizer import Tokenizer
|
| 9 |
from osuT5.utils import get_model
|
| 10 |
|
| 11 |
-
config_path = "configs/inference.yaml"
|
| 12 |
-
config = DictConfig.load(config_path)
|
| 13 |
-
|
| 14 |
|
| 15 |
def get_args_from_beatmap(args: DictConfig):
|
| 16 |
if args.beatmap_path is None or args.beatmap_path == "":
|
|
@@ -49,6 +48,7 @@ def find_model(ckpt_path, args: DictConfig, device):
|
|
| 49 |
return model
|
| 50 |
|
| 51 |
|
|
|
|
| 52 |
def main(args: DictConfig):
|
| 53 |
get_args_from_beatmap(args)
|
| 54 |
|
|
@@ -114,4 +114,4 @@ def main(args: DictConfig):
|
|
| 114 |
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
-
main(
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
import torch
|
| 5 |
from omegaconf import DictConfig
|
| 6 |
from slider import Beatmap
|
|
|
|
| 10 |
from osuT5.tokenizer import Tokenizer
|
| 11 |
from osuT5.utils import get_model
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def get_args_from_beatmap(args: DictConfig):
|
| 15 |
if args.beatmap_path is None or args.beatmap_path == "":
|
|
|
|
| 48 |
return model
|
| 49 |
|
| 50 |
|
| 51 |
+
@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
|
| 52 |
def main(args: DictConfig):
|
| 53 |
get_args_from_beatmap(args)
|
| 54 |
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
+
main()
|
requirements.txt
CHANGED
|
@@ -7,4 +7,4 @@ tensorboard
|
|
| 7 |
slider==0.8.1
|
| 8 |
torch_tb_profiler
|
| 9 |
rosu_pp_py
|
| 10 |
-
|
|
|
|
| 7 |
slider==0.8.1
|
| 8 |
torch_tb_profiler
|
| 9 |
rosu_pp_py
|
| 10 |
+
hydra-core
|