Fred808 commited on
Commit
c0c84cd
·
verified ·
1 Parent(s): 1d84ccc

Update api/mdm_loader.py

Browse files
Files changed (1) hide show
  1. api/mdm_loader.py +15 -17
api/mdm_loader.py CHANGED
@@ -1,23 +1,21 @@
1
  import torch
2
- from utils.model_util import create_model_and_diffusion, load_saved_model
3
- from utils.parser_util import generate_args
4
  import os
5
 
6
- # Load once on startup
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
-
9
- # Load args.json from models directory
10
- args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
11
- args = generate_args()
12
- if os.path.exists(args_path):
13
- import json
14
- with open(args_path, "r") as f:
15
- args.__dict__.update(json.load(f))
16
-
17
- model, diffusion = create_model_and_diffusion(args, data=None)
18
- model = load_saved_model(model, "models/model000475000.pt")
19
- model.to(device)
20
- model.eval()
21
 
22
  def generate_motion(prompt: str, num_frames: int, style: str = "default"):
23
  # TODO: Implement motion generation using MDM repo's sampling logic
 
1
  import torch
2
+ from mdm_repo.utils.model_util import create_model_and_diffusion, load_saved_model
3
+ from mdm_repo.utils.parser_util import generate_args
4
  import os
5
 
6
+ def load_model_and_args():
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
9
+ args = generate_args()
10
+ if os.path.exists(args_path):
11
+ import json
12
+ with open(args_path, "r") as f:
13
+ args.__dict__.update(json.load(f))
14
+ model, diffusion = create_model_and_diffusion(args, data=None)
15
+ model = load_saved_model(model, "models/model000475000.pt")
16
+ model.to(device)
17
+ model.eval()
18
+ return model, diffusion, args, device
 
 
19
 
20
  def generate_motion(prompt: str, num_frames: int, style: str = "default"):
21
  # TODO: Implement motion generation using MDM repo's sampling logic