Spaces:
Runtime error
Runtime error
Update sonic.py
Browse files
sonic.py
CHANGED
|
@@ -22,6 +22,11 @@ from src.dataset.face_align.align import AlignImage
|
|
| 22 |
|
| 23 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# ------------------------------------------------------------------
|
| 26 |
# single image + speech → video-tensor generator
|
| 27 |
# ------------------------------------------------------------------
|
|
@@ -29,32 +34,29 @@ def test(
|
|
| 29 |
pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
|
| 30 |
width, height, batch,
|
| 31 |
):
|
| 32 |
-
#
|
| 33 |
for k, v in batch.items():
|
| 34 |
if isinstance(v, torch.Tensor):
|
| 35 |
batch[k] = v.unsqueeze(0).to(pipe.device).float()
|
| 36 |
|
| 37 |
-
ref_img = batch["ref_img"]
|
| 38 |
clip_img = batch["clip_images"]
|
| 39 |
face_mask = batch["face_mask"]
|
| 40 |
-
image_embeds = image_encoder(clip_img).image_embeds
|
| 41 |
|
| 42 |
-
audio_feature = batch["audio_feature"]
|
| 43 |
-
audio_len = int(batch["audio_len"])
|
| 44 |
step = int(config.step)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
window = 16_000 # 1 초
|
| 48 |
audio_prompts, last_prompts = [], []
|
| 49 |
|
| 50 |
for i in range(0, audio_feature.shape[-1], window):
|
| 51 |
chunk = audio_feature[:, :, i : i + window]
|
| 52 |
-
|
| 53 |
layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
|
| 54 |
last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
last_prompts.append(last) # (1,?,1,384)
|
| 58 |
|
| 59 |
if not audio_prompts:
|
| 60 |
raise ValueError("[ERROR] No speech recognised in the provided audio.")
|
|
@@ -62,17 +64,13 @@ def test(
|
|
| 62 |
audio_prompts = torch.cat(audio_prompts, dim=1)
|
| 63 |
last_prompts = torch.cat(last_prompts, dim=1)
|
| 64 |
|
| 65 |
-
# ---------- pad to match the model's input convention -----------
|
| 66 |
audio_prompts = torch.cat(
|
| 67 |
-
[torch.zeros_like(audio_prompts[:, :4]),
|
| 68 |
-
audio_prompts,
|
| 69 |
torch.zeros_like(audio_prompts[:, :6])], dim=1)
|
| 70 |
last_prompts = torch.cat(
|
| 71 |
-
[torch.zeros_like(last_prompts[:, :24]),
|
| 72 |
-
last_prompts,
|
| 73 |
torch.zeros_like(last_prompts[:, :26])], dim=1)
|
| 74 |
|
| 75 |
-
# ---------- compute the number of chunks from the audio length ---
|
| 76 |
total_tokens = audio_prompts.shape[1]
|
| 77 |
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
|
| 78 |
|
|
@@ -81,38 +79,35 @@ def test(
|
|
| 81 |
for i in tqdm(range(num_chunks)):
|
| 82 |
start = i * 2 * step
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
clip_raw = audio_prompts[:, start : start + 10]
|
| 86 |
-
|
| 87 |
-
# w-pad
|
| 88 |
-
if clip_raw.shape[1] < 10:
|
| 89 |
pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
|
| 90 |
clip_raw = torch.cat([clip_raw, pad_w], dim=1)
|
| 91 |
|
| 92 |
-
# ★ L-pad
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
|
| 97 |
-
|
| 98 |
-
cond_clip = clip_raw.unsqueeze(1) # (1,1,10,5,384)
|
| 99 |
|
| 100 |
-
#
|
| 101 |
-
bucket_raw = last_prompts[:, start : start + 50]
|
| 102 |
if bucket_raw.shape[1] < 50:
|
| 103 |
pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
|
| 104 |
bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
|
| 105 |
-
|
| 106 |
-
bucket_clip = bucket_raw.unsqueeze(1) # (1,1,50,1,384)
|
| 107 |
|
| 108 |
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
|
| 109 |
|
| 110 |
ref_list.append(ref_img[0])
|
| 111 |
-
|
| 112 |
-
|
|
|
|
| 113 |
motion_buckets.append(motion[0])
|
| 114 |
|
| 115 |
-
#
|
| 116 |
video = pipe(
|
| 117 |
ref_img, clip_img, face_mask,
|
| 118 |
audio_list, uncond_list, motion_buckets,
|
|
@@ -137,6 +132,9 @@ def test(
|
|
| 137 |
return video.to(pipe.device).unsqueeze(0).cpu()
|
| 138 |
|
| 139 |
|
|
|
|
|
|
|
|
|
|
| 140 |
# ------------------------------------------------------------------
|
| 141 |
# Sonic class
|
| 142 |
# ------------------------------------------------------------------
|
|
|
|
| 22 |
|
| 23 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
|
| 25 |
+
# ------------------------------------------------------------------
|
| 26 |
+
# single image + speech → video-tensor generator
|
| 27 |
+
# ------------------------------------------------------------------
|
| 28 |
+
# …(imports and other definitions above are unchanged)…
|
| 29 |
+
|
| 30 |
# ------------------------------------------------------------------
|
| 31 |
# single image + speech → video-tensor generator
|
| 32 |
# ------------------------------------------------------------------
|
|
|
|
| 34 |
pipe, config, wav_enc, audio_pe, audio2bucket, image_encoder,
|
| 35 |
width, height, batch,
|
| 36 |
):
|
| 37 |
+
# ---- align the batch dimension ---------------------------------
|
| 38 |
for k, v in batch.items():
|
| 39 |
if isinstance(v, torch.Tensor):
|
| 40 |
batch[k] = v.unsqueeze(0).to(pipe.device).float()
|
| 41 |
|
| 42 |
+
ref_img = batch["ref_img"]
|
| 43 |
clip_img = batch["clip_images"]
|
| 44 |
face_mask = batch["face_mask"]
|
| 45 |
+
image_embeds = image_encoder(clip_img).image_embeds
|
| 46 |
|
| 47 |
+
audio_feature = batch["audio_feature"]
|
| 48 |
+
audio_len = int(batch["audio_len"])
|
| 49 |
step = int(config.step)
|
| 50 |
|
| 51 |
+
window = 16_000 # 1 초
|
|
|
|
| 52 |
audio_prompts, last_prompts = [], []
|
| 53 |
|
| 54 |
for i in range(0, audio_feature.shape[-1], window):
|
| 55 |
chunk = audio_feature[:, :, i : i + window]
|
|
|
|
| 56 |
layers = wav_enc.encoder(chunk, output_hidden_states=True).hidden_states
|
| 57 |
last = wav_enc.encoder(chunk).last_hidden_state.unsqueeze(-2)
|
| 58 |
+
audio_prompts.append(torch.stack(layers, dim=2)) # (1,w,L,384)
|
| 59 |
+
last_prompts.append(last)
|
|
|
|
| 60 |
|
| 61 |
if not audio_prompts:
|
| 62 |
raise ValueError("[ERROR] No speech recognised in the provided audio.")
|
|
|
|
| 64 |
audio_prompts = torch.cat(audio_prompts, dim=1)
|
| 65 |
last_prompts = torch.cat(last_prompts, dim=1)
|
| 66 |
|
|
|
|
| 67 |
audio_prompts = torch.cat(
|
| 68 |
+
[torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
|
|
|
|
| 69 |
torch.zeros_like(audio_prompts[:, :6])], dim=1)
|
| 70 |
last_prompts = torch.cat(
|
| 71 |
+
[torch.zeros_like(last_prompts[:, :24]), last_prompts,
|
|
|
|
| 72 |
torch.zeros_like(last_prompts[:, :26])], dim=1)
|
| 73 |
|
|
|
|
| 74 |
total_tokens = audio_prompts.shape[1]
|
| 75 |
num_chunks = max(1, math.ceil(total_tokens / (2 * step)))
|
| 76 |
|
|
|
|
| 79 |
for i in tqdm(range(num_chunks)):
|
| 80 |
start = i * 2 * step
|
| 81 |
|
| 82 |
+
# ------------ cond_clip : (1,1,10,5,384) ------------------
|
| 83 |
+
clip_raw = audio_prompts[:, start : start + 10] # (1,≤10,L,384)
|
| 84 |
+
if clip_raw.shape[1] < 10: # w-pad
|
|
|
|
|
|
|
| 85 |
pad_w = torch.zeros_like(clip_raw[:, :10 - clip_raw.shape[1]])
|
| 86 |
clip_raw = torch.cat([clip_raw, pad_w], dim=1)
|
| 87 |
|
| 88 |
+
# ★ L-pad → make exactly 5 layers
|
| 89 |
+
while clip_raw.shape[2] < 5:
|
| 90 |
+
clip_raw = torch.cat([clip_raw, clip_raw[:, :, -1:]], dim=2)
|
| 91 |
+
clip_raw = clip_raw[:, :, :5] # (1,10,5,384)
|
| 92 |
|
| 93 |
+
cond_clip = clip_raw.unsqueeze(1) # (1,1,10,5,384)
|
|
|
|
| 94 |
|
| 95 |
+
# ------------ bucket_clip : (1,1,50,1,384) -----------------
|
| 96 |
+
bucket_raw = last_prompts[:, start : start + 50]
|
| 97 |
if bucket_raw.shape[1] < 50:
|
| 98 |
pad_w = torch.zeros_like(bucket_raw[:, :50 - bucket_raw.shape[1]])
|
| 99 |
bucket_raw = torch.cat([bucket_raw, pad_w], dim=1)
|
| 100 |
+
bucket_clip = bucket_raw.unsqueeze(1) # (1,1,50,1,384)
|
|
|
|
| 101 |
|
| 102 |
motion = audio2bucket(bucket_clip, image_embeds) * 16 + 16
|
| 103 |
|
| 104 |
ref_list.append(ref_img[0])
|
| 105 |
+
# ★ here: squeeze(0) only (drops the batch dim); the [0] indexing is removed
|
| 106 |
+
audio_list.append(audio_pe(cond_clip).squeeze(0)) # (50,1024)
|
| 107 |
+
uncond_list.append(audio_pe(torch.zeros_like(cond_clip)).squeeze(0))
|
| 108 |
motion_buckets.append(motion[0])
|
| 109 |
|
| 110 |
+
# ---- call Stable Video Diffusion -------------------------------
|
| 111 |
video = pipe(
|
| 112 |
ref_img, clip_img, face_mask,
|
| 113 |
audio_list, uncond_list, motion_buckets,
|
|
|
|
| 132 |
return video.to(pipe.device).unsqueeze(0).cpu()
|
| 133 |
|
| 134 |
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
# ------------------------------------------------------------------
|
| 139 |
# Sonic class
|
| 140 |
# ------------------------------------------------------------------
|