Spaces:
Running
on
Zero
Running
on
Zero
alex
committed on
Commit
·
5fb16c4
1
Parent(s):
f77572e
unlogged users
Browse files
app.py
CHANGED
|
@@ -252,6 +252,8 @@ def is_portrait(video_file):
|
|
| 252 |
|
| 253 |
def calculate_time_required(max_duration_s, rc_bool):
|
| 254 |
|
|
|
|
|
|
|
| 255 |
if max_duration_s == 2:
|
| 256 |
return 120
|
| 257 |
elif max_duration_s == 4:
|
|
@@ -300,9 +302,15 @@ def _animate(input_video, max_duration_s, edited_frame, rc_bool, pts_by_frame, l
|
|
| 300 |
else:
|
| 301 |
w, h = 832, 480
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
tag_string = "replace_flag" if rc_bool else "retarget_flag"
|
| 304 |
|
| 305 |
-
preprocess_model = load_preprocess_models()
|
| 306 |
|
| 307 |
# NOTE: run_preprocess now receives dicts keyed by frame_idx.
|
| 308 |
# Update run_preprocess(...) accordingly in your preprocess_data.py.
|
|
@@ -422,10 +430,9 @@ def get_sam_mask(prompt_state, frames, frame_idx, input_points, input_labels):
|
|
| 422 |
:return: (H, W) boolean mask, (H, W) float32 logits (or None), prompt_state
|
| 423 |
"""
|
| 424 |
|
| 425 |
-
|
| 426 |
-
model_cfg = "sam2_hiera_l.yaml"
|
| 427 |
ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
|
| 428 |
-
sam2_checkpoint_path = os.path.join(ckpt_path, 'sam2/sam2_hiera_large.pt')
|
| 429 |
|
| 430 |
video_predictor_local = build_sam2_video_predictor(model_cfg, sam2_checkpoint_path, device="cpu")
|
| 431 |
inference_state = video_predictor_local.init_state(images=np.array(frames), device="cpu")
|
|
@@ -508,7 +515,16 @@ def animate_scene(input_video, max_duration_s, edited_frame, rc_str,
|
|
| 508 |
except Exception as e:
|
| 509 |
err = str(e).lower()
|
| 510 |
print(f"{session_id} failed due to {err}")
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
final_video_path = os.path.join(output_dir, 'final_result.mp4')
|
| 514 |
|
|
|
|
| 252 |
|
| 253 |
def calculate_time_required(max_duration_s, rc_bool):
|
| 254 |
|
| 255 |
+
if max_duration_s == -1:
|
| 256 |
+
return 75
|
| 257 |
if max_duration_s == 2:
|
| 258 |
return 120
|
| 259 |
elif max_duration_s == 4:
|
|
|
|
| 302 |
else:
|
| 303 |
w, h = 832, 480
|
| 304 |
|
| 305 |
+
if max_duration_s == -1:
|
| 306 |
+
if is_portrait(input_video):
|
| 307 |
+
w, h = 360, 640
|
| 308 |
+
else:
|
| 309 |
+
w, h = 640, 360
|
| 310 |
+
|
| 311 |
tag_string = "replace_flag" if rc_bool else "retarget_flag"
|
| 312 |
|
| 313 |
+
preprocess_model = load_preprocess_models(max_duration_s)
|
| 314 |
|
| 315 |
# NOTE: run_preprocess now receives dicts keyed by frame_idx.
|
| 316 |
# Update run_preprocess(...) accordingly in your preprocess_data.py.
|
|
|
|
| 430 |
:return: (H, W) boolean mask, (H, W) float32 logits (or None), prompt_state
|
| 431 |
"""
|
| 432 |
|
| 433 |
+
model_cfg = "sam2_hiera_s.yaml"
|
|
|
|
| 434 |
ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
|
| 435 |
+
sam2_checkpoint_path = os.path.join(ckpt_path, 'sam2/sam2_hiera_small.pt')
|
| 436 |
|
| 437 |
video_predictor_local = build_sam2_video_predictor(model_cfg, sam2_checkpoint_path, device="cpu")
|
| 438 |
inference_state = video_predictor_local.init_state(images=np.array(frames), device="cpu")
|
|
|
|
| 515 |
except Exception as e:
|
| 516 |
err = str(e).lower()
|
| 517 |
print(f"{session_id} failed due to {err}")
|
| 518 |
+
try:
|
| 519 |
+
|
| 520 |
+
output_video_path = _animate(
|
| 521 |
+
input_video, -1, edited_frame_png, rc_bool,
|
| 522 |
+
pts_by_frame, lbs_by_frame, session_id, progress
|
| 523 |
+
)
|
| 524 |
+
except Exception as e:
|
| 525 |
+
err = str(e).lower()
|
| 526 |
+
print(f"{session_id} failed due to {err}")
|
| 527 |
+
raise
|
| 528 |
|
| 529 |
final_video_path = os.path.join(output_dir, 'final_result.mp4')
|
| 530 |
|
wan/modules/animate/preprocess/preprocess_data.py
CHANGED
|
@@ -33,12 +33,18 @@ def _parse_args():
|
|
| 33 |
|
| 34 |
return args
|
| 35 |
|
| 36 |
-
def load_preprocess_models():
|
| 37 |
ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
|
| 38 |
|
| 39 |
pose2d_checkpoint_path = os.path.join(ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
|
| 40 |
det_checkpoint_path = os.path.join(ckpt_path, 'det/yolov10m.onnx')
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
flux_kontext_path = None
|
| 43 |
|
| 44 |
process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)
|
|
|
|
| 33 |
|
| 34 |
return args
|
| 35 |
|
| 36 |
+
def load_preprocess_models(max_duration_s):
|
| 37 |
ckpt_path = "./Wan2.2-Animate-14B/process_checkpoint"
|
| 38 |
|
| 39 |
pose2d_checkpoint_path = os.path.join(ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
|
| 40 |
det_checkpoint_path = os.path.join(ckpt_path, 'det/yolov10m.onnx')
|
| 41 |
+
|
| 42 |
+
if max_duration_s == -1:
|
| 43 |
+
print("using small sam2")
|
| 44 |
+
sam2_checkpoint_path = [os.path.join(ckpt_path, 'sam2/sam2_hiera_small.pt'),"sam2_hiera_s.yaml"]
|
| 45 |
+
else:
|
| 46 |
+
sam2_checkpoint_path = [os.path.join(ckpt_path, 'sam2/sam2_hiera_large.pt'),"sam2_hiera_l.yaml"]
|
| 47 |
+
|
| 48 |
flux_kontext_path = None
|
| 49 |
|
| 50 |
process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)
|
wan/modules/animate/preprocess/process_pipepline.py
CHANGED
|
@@ -73,9 +73,9 @@ class ProcessPipeline():
|
|
| 73 |
def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
|
| 74 |
self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
|
| 75 |
|
| 76 |
-
model_cfg = "sam2_hiera_l.yaml"
|
| 77 |
if sam_checkpoint_path is not None:
|
| 78 |
-
|
|
|
|
| 79 |
if flux_kontext_path is not None:
|
| 80 |
self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
|
| 81 |
|
|
|
|
| 73 |
def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
|
| 74 |
self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
|
| 75 |
|
|
|
|
| 76 |
if sam_checkpoint_path is not None:
|
| 77 |
+
model_cfg = sam_checkpoint_path[1]
|
| 78 |
+
self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path[0], device="cuda")
|
| 79 |
if flux_kontext_path is not None:
|
| 80 |
self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
|
| 81 |
|