alex committed on
Commit
5fb16c4
·
1 Parent(s): f77572e

unlogged users

Browse files
app.py CHANGED
@@ -252,6 +252,8 @@ def is_portrait(video_file):
252
 
253
  def calculate_time_required(max_duration_s, rc_bool):
254
 
 
 
255
  if max_duration_s == 2:
256
  return 120
257
  elif max_duration_s == 4:
@@ -300,9 +302,15 @@ def _animate(input_video, max_duration_s, edited_frame, rc_bool, pts_by_frame, l
300
  else:
301
  w, h = 832, 480
302
 
 
 
 
 
 
 
303
  tag_string = "replace_flag" if rc_bool else "retarget_flag"
304
 
305
- preprocess_model = load_preprocess_models()
306
 
307
  # NOTE: run_preprocess now receives dicts keyed by frame_idx.
308
  # Update run_preprocess(...) accordingly in your preprocess_data.py.
@@ -422,10 +430,9 @@ def get_sam_mask(prompt_state, frames, frame_idx, input_points, input_labels):
422
  :return: (H, W) boolean mask, (H, W) float32 logits (or None), prompt_state
423
  """
424
 
425
-
426
- model_cfg = "sam2_hiera_l.yaml"
427
  ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
428
- sam2_checkpoint_path = os.path.join(ckpt_path, 'sam2/sam2_hiera_large.pt')
429
 
430
  video_predictor_local = build_sam2_video_predictor(model_cfg, sam2_checkpoint_path, device="cpu")
431
  inference_state = video_predictor_local.init_state(images=np.array(frames), device="cpu")
@@ -508,7 +515,16 @@ def animate_scene(input_video, max_duration_s, edited_frame, rc_str,
508
  except Exception as e:
509
  err = str(e).lower()
510
  print(f"{session_id} failed due to {err}")
511
- raise
 
 
 
 
 
 
 
 
 
512
 
513
  final_video_path = os.path.join(output_dir, 'final_result.mp4')
514
 
 
252
 
253
  def calculate_time_required(max_duration_s, rc_bool):
254
 
255
+ if max_duration_s == -1:
256
+ return 75
257
  if max_duration_s == 2:
258
  return 120
259
  elif max_duration_s == 4:
 
302
  else:
303
  w, h = 832, 480
304
 
305
+ if max_duration_s == -1:
306
+ if is_portrait(input_video):
307
+ w, h = 360, 640
308
+ else:
309
+ w, h = 640, 360
310
+
311
  tag_string = "replace_flag" if rc_bool else "retarget_flag"
312
 
313
+ preprocess_model = load_preprocess_models(max_duration_s)
314
 
315
  # NOTE: run_preprocess now receives dicts keyed by frame_idx.
316
  # Update run_preprocess(...) accordingly in your preprocess_data.py.
 
430
  :return: (H, W) boolean mask, (H, W) float32 logits (or None), prompt_state
431
  """
432
 
433
+ model_cfg = "sam2_hiera_s.yaml"
 
434
  ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
435
+ sam2_checkpoint_path = os.path.join(ckpt_path, 'sam2/sam2_hiera_small.pt')
436
 
437
  video_predictor_local = build_sam2_video_predictor(model_cfg, sam2_checkpoint_path, device="cpu")
438
  inference_state = video_predictor_local.init_state(images=np.array(frames), device="cpu")
 
515
  except Exception as e:
516
  err = str(e).lower()
517
  print(f"{session_id} failed due to {err}")
518
+ try:
519
+
520
+ output_video_path = _animate(
521
+ input_video, -1, edited_frame_png, rc_bool,
522
+ pts_by_frame, lbs_by_frame, session_id, progress
523
+ )
524
+ except Exception as e:
525
+ err = str(e).lower()
526
+ print(f"{session_id} failed due to {err}")
527
+ raise
528
 
529
  final_video_path = os.path.join(output_dir, 'final_result.mp4')
530
 
wan/modules/animate/preprocess/preprocess_data.py CHANGED
@@ -33,12 +33,18 @@ def _parse_args():
33
 
34
  return args
35
 
36
- def load_preprocess_models():
37
  ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
38
 
39
  pose2d_checkpoint_path = os.path.join(ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
40
  det_checkpoint_path = os.path.join(ckpt_path, 'det/yolov10m.onnx')
41
- sam2_checkpoint_path = os.path.join(ckpt_path, 'sam2/sam2_hiera_large.pt')
 
 
 
 
 
 
42
  flux_kontext_path = None
43
 
44
  process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)
 
33
 
34
  return args
35
 
36
+ def load_preprocess_models(max_duration_s):
37
  ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
38
 
39
  pose2d_checkpoint_path = os.path.join(ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
40
  det_checkpoint_path = os.path.join(ckpt_path, 'det/yolov10m.onnx')
41
+
42
+ if max_duration_s == -1:
43
+ print("using small sam2")
44
+ sam2_checkpoint_path = [os.path.join(ckpt_path, 'sam2/sam2_hiera_small.pt'),"sam2_hiera_s.yaml"]
45
+ else:
46
+ sam2_checkpoint_path = [os.path.join(ckpt_path, 'sam2/sam2_hiera_large.pt'),"sam2_hiera_l.yaml"]
47
+
48
  flux_kontext_path = None
49
 
50
  process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)
wan/modules/animate/preprocess/process_pipepline.py CHANGED
@@ -73,9 +73,9 @@ class ProcessPipeline():
73
  def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
74
  self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
75
 
76
- model_cfg = "sam2_hiera_l.yaml"
77
  if sam_checkpoint_path is not None:
78
- self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path, device="cuda")
 
79
  if flux_kontext_path is not None:
80
  self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
81
 
 
73
  def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
74
  self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
75
 
 
76
  if sam_checkpoint_path is not None:
77
+ model_cfg = sam_checkpoint_path[1]
78
+ self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path[0], device="cuda")
79
  if flux_kontext_path is not None:
80
  self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
81