Tiger14n commited on
Commit
f90d455
·
1 Parent(s): 705c3b8
Files changed (2) hide show
  1. inference.py +4 -4
  2. requirements.txt +1 -1
inference.py CHANGED
@@ -1,4 +1,6 @@
1
  from pathlib import Path
 
 
2
  import torch
3
  from omegaconf import DictConfig
4
  from slider import Beatmap
@@ -8,9 +10,6 @@ from osuT5.inference import Preprocessor, Pipeline, Postprocessor, DiffisionPipe
8
  from osuT5.tokenizer import Tokenizer
9
  from osuT5.utils import get_model
10
 
11
- config_path = "configs/inference.yaml"
12
- config = DictConfig.load(config_path)
13
-
14
 
15
  def get_args_from_beatmap(args: DictConfig):
16
  if args.beatmap_path is None or args.beatmap_path == "":
@@ -49,6 +48,7 @@ def find_model(ckpt_path, args: DictConfig, device):
49
  return model
50
 
51
 
 
52
  def main(args: DictConfig):
53
  get_args_from_beatmap(args)
54
 
@@ -114,4 +114,4 @@ def main(args: DictConfig):
114
 
115
 
116
  if __name__ == "__main__":
117
- main(config)
 
1
  from pathlib import Path
2
+
3
+ import hydra
4
  import torch
5
  from omegaconf import DictConfig
6
  from slider import Beatmap
 
10
  from osuT5.tokenizer import Tokenizer
11
  from osuT5.utils import get_model
12
 
 
 
 
13
 
14
  def get_args_from_beatmap(args: DictConfig):
15
  if args.beatmap_path is None or args.beatmap_path == "":
 
48
  return model
49
 
50
 
51
+ @hydra.main(config_path="configs", config_name="inference", version_base="1.1")
52
  def main(args: DictConfig):
53
  get_args_from_beatmap(args)
54
 
 
114
 
115
 
116
  if __name__ == "__main__":
117
+ main()
requirements.txt CHANGED
@@ -7,4 +7,4 @@ tensorboard
7
  slider==0.8.1
8
  torch_tb_profiler
9
  rosu_pp_py
10
- omegaconf
 
7
  slider==0.8.1
8
  torch_tb_profiler
9
  rosu_pp_py
10
+ hydra-core