Spaces:
Paused
Paused
Update api/mdm_loader.py
Browse files- api/mdm_loader.py +7 -6
api/mdm_loader.py
CHANGED
|
@@ -60,16 +60,17 @@ def generate_motion(prompt: str, num_frames: int, style: str = "default"):
|
|
| 60 |
# Setup device
|
| 61 |
dist_util.setup_dist(args.device)
|
| 62 |
|
| 63 |
-
# Load model and diffusion
|
| 64 |
-
model, diffusion = create_model_and_diffusion(args, data=None)
|
| 65 |
-
load_saved_model(model, args.model_path, use_avg=getattr(args, 'use_ema', False))
|
| 66 |
-
model.to(dist_util.dev())
|
| 67 |
-
model.eval()
|
| 68 |
-
|
| 69 |
# Prepare dataset and model_kwargs
|
| 70 |
max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
|
| 71 |
n_frames = min(max_frames, int(args.motion_length * 20))
|
| 72 |
data = get_dataset_loader(name=args.dataset, batch_size=1, num_frames=max_frames, split='test', hml_mode='text_only', fixed_len=0, pred_len=0, device=dist_util.dev())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
iterator = iter(data)
|
| 74 |
_, model_kwargs = next(iterator)
|
| 75 |
model_kwargs['y']['text'] = [prompt]
|
|
|
|
| 60 |
# Setup device
|
| 61 |
dist_util.setup_dist(args.device)
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
# Prepare dataset and model_kwargs
|
| 64 |
max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
|
| 65 |
n_frames = min(max_frames, int(args.motion_length * 20))
|
| 66 |
data = get_dataset_loader(name=args.dataset, batch_size=1, num_frames=max_frames, split='test', hml_mode='text_only', fixed_len=0, pred_len=0, device=dist_util.dev())
|
| 67 |
+
|
| 68 |
+
# Load model and diffusion with dataset
|
| 69 |
+
model, diffusion = create_model_and_diffusion(args, data)
|
| 70 |
+
load_saved_model(model, args.model_path, use_avg=getattr(args, 'use_ema', False))
|
| 71 |
+
model.to(dist_util.dev())
|
| 72 |
+
model.eval()
|
| 73 |
+
|
| 74 |
iterator = iter(data)
|
| 75 |
_, model_kwargs = next(iterator)
|
| 76 |
model_kwargs['y']['text'] = [prompt]
|