Fred808 commited on
Commit
f389eb9
·
verified ·
1 Parent(s): 8a45a74

Update api/mdm_loader.py

Browse files
Files changed (1) hide show
  1. api/mdm_loader.py +25 -25
api/mdm_loader.py CHANGED
@@ -1,25 +1,25 @@
1
- import torch
2
- from mdm_repo.utils.model_util import create_model_and_diffusion, load_saved_model
3
- from mdm_repo.utils.parser_util import generate_args
4
- import os
5
-
6
- # Load once on startup
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
-
9
- # Load args.json from models directory
10
- args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
11
- args = generate_args()
12
- if os.path.exists(args_path):
13
- import json
14
- with open(args_path, "r") as f:
15
- args.__dict__.update(json.load(f))
16
-
17
- model, diffusion = create_model_and_diffusion(args, data=None)
18
- model = load_saved_model(model, "models/model000475000.pt")
19
- model.to(device)
20
- model.eval()
21
-
22
- def generate_motion(prompt: str, num_frames: int, style: str = "default"):
23
- # TODO: Implement motion generation using MDM repo's sampling logic
24
- # See mdm_repo/sample/generate.py for reference
25
- raise NotImplementedError("Motion generation logic needs to be implemented using MDM repo's sampling code.")
 
1
+ import torch
2
+ from utils.model_util import create_model_and_diffusion, load_saved_model
3
+ from utils.parser_util import generate_args
4
+ import os
5
+
6
+ # Load once on startup
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load args.json from models directory
10
+ args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
11
+ args = generate_args()
12
+ if os.path.exists(args_path):
13
+ import json
14
+ with open(args_path, "r") as f:
15
+ args.__dict__.update(json.load(f))
16
+
17
+ model, diffusion = create_model_and_diffusion(args, data=None)
18
+ model = load_saved_model(model, "models/model000475000.pt")
19
+ model.to(device)
20
+ model.eval()
21
+
22
+ def generate_motion(prompt: str, num_frames: int, style: str = "default"):
23
+ # TODO: Implement motion generation using MDM repo's sampling logic
24
+ # See mdm_repo/sample/generate.py for reference
25
+ raise NotImplementedError("Motion generation logic needs to be implemented using MDM repo's sampling code.")