Fred808 commited on
Commit
d25e8a6
·
verified ·
1 Parent(s): 7480aa2

Update api/mdm_loader.py

Browse files
Files changed (1) hide show
  1. api/mdm_loader.py +7 -6
api/mdm_loader.py CHANGED
@@ -60,16 +60,17 @@ def generate_motion(prompt: str, num_frames: int, style: str = "default"):
60
  # Setup device
61
  dist_util.setup_dist(args.device)
62
 
63
- # Load model and diffusion
64
- model, diffusion = create_model_and_diffusion(args, data=None)
65
- load_saved_model(model, args.model_path, use_avg=getattr(args, 'use_ema', False))
66
- model.to(dist_util.dev())
67
- model.eval()
68
-
69
  # Prepare dataset and model_kwargs
70
  max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
71
  n_frames = min(max_frames, int(args.motion_length * 20))
72
  data = get_dataset_loader(name=args.dataset, batch_size=1, num_frames=max_frames, split='test', hml_mode='text_only', fixed_len=0, pred_len=0, device=dist_util.dev())
 
 
 
 
 
 
 
73
  iterator = iter(data)
74
  _, model_kwargs = next(iterator)
75
  model_kwargs['y']['text'] = [prompt]
 
60
  # Setup device
61
  dist_util.setup_dist(args.device)
62
 
 
 
 
 
 
 
63
  # Prepare dataset and model_kwargs
64
  max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
65
  n_frames = min(max_frames, int(args.motion_length * 20))
66
  data = get_dataset_loader(name=args.dataset, batch_size=1, num_frames=max_frames, split='test', hml_mode='text_only', fixed_len=0, pred_len=0, device=dist_util.dev())
67
+
68
+ # Load model and diffusion with dataset
69
+ model, diffusion = create_model_and_diffusion(args, data)
70
+ load_saved_model(model, args.model_path, use_avg=getattr(args, 'use_ema', False))
71
+ model.to(dist_util.dev())
72
+ model.eval()
73
+
74
  iterator = iter(data)
75
  _, model_kwargs = next(iterator)
76
  model_kwargs['y']['text'] = [prompt]