import torch
import numpy as np
import cv2
import os
import librosa
import math

def calculate_frame_num_from_audio(audio_paths, fps=24, mode="pad"):
    """
    Calculate corresponding frame number based on audio file length
    
    Args:
        audio_paths (list): List of audio file paths
        fps (int): Video frame rate, default 24fps
        mode (str): Audio processing mode, "pad" or "concat". 
                   In "pad" mode, returns max duration.
                   In "concat" mode, returns sum of all durations.
        
    Returns:
        int: Calculated frame number, returns default 81 if audio file does not exist
    """
    if not audio_paths:
        raise ValueError("No audio files, cannot determine frame number")
    
    if mode == "concat":
        # Concat mode: sum all audio durations
        total_duration = 0
        for audio_path in audio_paths:
            if audio_path and os.path.exists(audio_path):
                try:
                    # Use librosa to get audio duration
                    duration = librosa.get_duration(filename=audio_path)
                    total_duration += duration
                    print(f"audio file {audio_path} duration: {duration:.2f} seconds")
                except Exception as e:
                    raise ValueError(f"Failed to read audio file {audio_path}: {e}")
        
        if total_duration > 0:
            # Calculate frame number, round up
            frame_num = int(math.ceil(total_duration * fps))
            # Ensure frame number is in 4n+1 format (model requirement)
            frame_num = ((frame_num - 1) // 4) * 4 + 1
            print(f"Calculated frame number (concat mode): {frame_num} based on total audio duration {total_duration:.2f}s and frame rate {fps}fps")
            return frame_num
        else:
            raise ValueError("No audio files, cannot determine frame number")
    else:
        # Pad mode: use max duration (original behavior)
        max_duration = 0
        for audio_path in audio_paths:
            if audio_path and os.path.exists(audio_path):
                try:
                    # Use librosa to get audio duration
                    duration = librosa.get_duration(filename=audio_path)
                    max_duration = max(max_duration, duration)
                    print(f"audio file {audio_path} duration: {duration:.2f} seconds")
                except Exception as e:
                    raise ValueError(f"Failed to read audio file {audio_path}: {e}")
        
        if max_duration > 0:
            # Calculate frame number, round up
            frame_num = int(math.ceil(max_duration * fps))
            # Ensure frame number is in 4n+1 format (model requirement)
            frame_num = ((frame_num - 1) // 4) * 4 + 1
            print(f"Calculated frame number (pad mode): {frame_num} based on max audio duration {max_duration:.2f}s and frame rate {fps}fps")
            return frame_num
        else:
            raise ValueError("No audio files, cannot determine frame number")

# Count the model's parameters (in millions)
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    total_params_in_millions = total_params / 1e6  # Convert to millions
    return total_params_in_millions
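
# Illustrative usage (the Linear layer is a stand-in, not a model from this
# repo): Linear(1024, 1024) has 1024 * 1024 + 1024 = 1,049,600 parameters, so
# count_parameters reports roughly 1.05 (in millions).
def _demo_count_parameters():
    layer = torch.nn.Linear(1024, 1024)
    print(f"Linear(1024, 1024): {count_parameters(layer):.2f}M parameters")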

# Build null-conditioned audio_ref_features, adapted for the multi-person case
def create_null_audio_ref_features(audio_ref_features):
    null_features = {}
    # Handle ref_face_list (multi-person case): reference faces are kept as detached copies
    if 'ref_face_list' in audio_ref_features and audio_ref_features['ref_face_list']:
        null_ref_face_list = []
        for ref_face in audio_ref_features['ref_face_list']:
            if ref_face is not None:
                null_ref_face_list.append(ref_face.clone().detach())
            else:
                null_ref_face_list.append(None)
        null_features['ref_face_list'] = null_ref_face_list
    else:
        null_features['ref_face_list'] = []
    
    # Handle audio_list (multi-person case): audio features are zeroed out
    if 'audio_list' in audio_ref_features and audio_ref_features['audio_list']:
        null_audio_list = []
        for audio in audio_ref_features['audio_list']:
            if audio is not None:
                null_audio_list.append(torch.zeros_like(audio))
            else:
                null_audio_list.append(None)
        null_features['audio_list'] = null_audio_list
    else:
        null_features['audio_list'] = []
    
    return null_features
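
# Minimal sketch of the null-conditioning behavior (the feature shapes below
# are made up, not the repo's real shapes): reference faces survive as detached
# copies while audio features are zeroed, presumably so the null branch of
# classifier-free guidance drops only the audio condition.
def _demo_null_audio_ref_features():
    feats = {
        "ref_face_list": [torch.randn(1, 512)],
        "audio_list": [torch.randn(81, 768)],
    }
    null_feats = create_null_audio_ref_features(feats)
    assert torch.equal(null_feats["ref_face_list"][0], feats["ref_face_list"][0])
    assert not null_feats["audio_list"][0].any()  # audio features replaced by zeros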

def process_audio_features(
    audio_paths=None,
    audio=None,
    mode="pad",
    F=None,
    frame_num=None,
    task_key=None,
    fps=None,
    wav2vec_model=None,
    vocal_separator_model=None,
    audio_output_dir=None,
    device=None,
    use_half=False,
    half_dtype=None,
    preprocess_audio=None,
    resample_audio=None,
    trim_to_4s=False,  # Fast mode: trim audio to 4 seconds
):
    """
    Process audio files and extract audio features.
    
    Args:
        audio_paths (list): List of audio file paths (new format, supports multiple audio files)
        audio (str): Single audio file path (legacy format)
        mode (str): Audio processing mode, "pad" or "concat"
        F (int): Target frame number (already calculated outside)
        frame_num (int): Frame number for cache file naming (legacy)
        task_key (str): Task key for cache file naming
        fps (int): Frames per second
        wav2vec_model (str): Path to wav2vec model
        vocal_separator_model (str): Path to vocal separator model
        audio_output_dir (str): Directory for audio output
        device: Device to use for processing
        use_half (bool): Whether to use half precision
        half_dtype: Reduced-precision dtype used when use_half is True (typically torch.float16)
        preprocess_audio: Function to preprocess audio
        resample_audio: Function to resample audio
        trim_to_4s (bool): Fast mode, trim audio to at most 4 seconds (97 frames)
        
    Returns:
        list: List of audio feature tensors
    """
    from .audio_utils import preprocess_audio as _preprocess_audio, resample_audio as _resample_audio
    
    # Use provided functions or import from audio_utils
    if preprocess_audio is None:
        preprocess_audio = _preprocess_audio
    if resample_audio is None:
        resample_audio = _resample_audio
    
    audio_feat_list = []
    
    if audio_paths and len(audio_paths) > 0:
        print(f"Processing {len(audio_paths)} audio files in {mode} mode: {audio_paths}")
        cache_dir = os.path.join(audio_output_dir, "audio_preprocess")
        os.makedirs(cache_dir, exist_ok=True)

        if mode == "concat":
            # Concat mode: record each audio's length, calculate total length, 
            # and pad each audio with zeros in non-speaker segments
            audio_lengths = []  # Store actual length of each audio in frames
            raw_audio_feat_list = []  # Store raw audio features before padding
            
            # First pass: process all audios and record their actual lengths
            for i, audio_path in enumerate(audio_paths):
                if audio_path and os.path.exists(audio_path):
                    print(f"Processing audio {i} (first pass): {audio_path}")
                    target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio_path).split('.')[0]}-{task_key}_16k_concat.wav")
                    if not os.path.exists(target_resampled_audio_path):
                        resample_audio(
                            audio_path,
                            target_resampled_audio_path,
                        )
                    with torch.no_grad():
                        # Process audio without padding to get actual length
                        audio_emb, audio_length = preprocess_audio(
                            wav_path=target_resampled_audio_path,
                            num_generated_frames_per_clip=-1,  # -1 means no padding
                            fps=fps,
                            wav2vec_model=wav2vec_model,
                            vocal_separator_model=vocal_separator_model,
                            cache_dir=cache_dir,
                            device=device,
                        )
                        # If half precision is enabled, use the configured half_dtype; otherwise use bfloat16
                        audio_dtype = half_dtype if use_half else torch.bfloat16
                        audio_emb = audio_emb.to(device, dtype=audio_dtype)
                    
                    # Get actual frame length (audio_length is in frames)
                    actual_frame_length = audio_emb.shape[0]
                    audio_lengths.append(actual_frame_length)
                    raw_audio_feat_list.append(audio_emb)
                    print(f"Audio {i} actual length: {actual_frame_length} frames, shape: {audio_emb.shape}")
                else:
                    print(f"Warning: Audio {i} path is empty or file not found: {audio_path}")
                    audio_lengths.append(0)
                    raw_audio_feat_list.append(None)
            
            # Calculate total length from actual processed frames
            total_length = sum(audio_lengths)
            print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
            
            # Fast mode: trim to 4 seconds if trim_to_4s is True
            if trim_to_4s:
                # 4 seconds is fixed at 97 frames (4n+1 format: 4 s * 24 fps = 96 frames, rounded up to 97)
                max_frames_4s = 97
                if total_length > max_frames_4s:
                    print(f"Fast mode: Trimming audio from {total_length} frames to {max_frames_4s} frames (4 seconds)")
                    # Truncate each audio proportionally
                    scale_factor = max_frames_4s / total_length
                    cumulative_length = 0
                    for i, audio_len in enumerate(audio_lengths):
                        if audio_len > 0:
                            new_audio_len = int(audio_len * scale_factor)
                            # Ensure it fits within remaining space
                            remaining_space = max_frames_4s - cumulative_length
                            new_audio_len = min(new_audio_len, remaining_space)
                            audio_lengths[i] = new_audio_len
                            # Truncate the corresponding raw audio feature
                            if raw_audio_feat_list[i] is not None:
                                raw_audio_feat_list[i] = raw_audio_feat_list[i][:new_audio_len]
                            cumulative_length += new_audio_len
                    total_length = sum(audio_lengths)
                    print(f"After trimming: total_length = {total_length} frames")
            
            # Ensure total length is in 4n+1 format (model requirement)
            total_length = ((total_length - 1) // 4) * 4 + 1
            print(f"Adjusted total length to 4n+1 format: {total_length} frames")
            
            # Note: F was already calculated outside and passed as parameter
            # We should not update F here because it has been used to create other tensors (noise, mask, etc.)
            # If there's a mismatch, it means the calculation outside was inaccurate, but we'll use F as is
            if total_length > F:
                print(f"Warning: Actual processed frames ({total_length}) > pre-calculated F ({F}). Using F={F} to maintain consistency with other tensors.")
            elif total_length < F:
                print(f"Info: Actual processed frames ({total_length}) < pre-calculated F ({F}). Using F={F}.")
            else:
                print(f"Info: Actual processed frames ({total_length}) matches pre-calculated F={F}.")
            
            # Second pass: create padded audio features for each audio
            # Each audio is placed in its corresponding time segment, with zeros elsewhere
            cumulative_length = 0
            reference_feat_shape = None
            
            # First, find a reference feature shape from valid audio
            for raw_audio_feat in raw_audio_feat_list:
                if raw_audio_feat is not None:
                    reference_feat_shape = raw_audio_feat.shape[1:]  # Get shape without frame dimension
                    break
            
            if reference_feat_shape is None:
                raise ValueError("No valid audio files found in concat mode")
            
            for i, (raw_audio_feat, audio_len) in enumerate(zip(raw_audio_feat_list, audio_lengths)):
                if raw_audio_feat is not None and audio_len > 0:
                    # Create zero tensor with total length and same feature shape
                    padded_audio_feat = torch.zeros(
                        (F,) + reference_feat_shape,
                        dtype=raw_audio_feat.dtype,
                        device=raw_audio_feat.device
                    )
                    
                    # Place audio data in its corresponding time segment
                    end_pos = min(cumulative_length + audio_len, F)
                    actual_audio_len = end_pos - cumulative_length
                    padded_audio_feat[cumulative_length:end_pos] = raw_audio_feat[:actual_audio_len]
                    
                    audio_feat_list.append(padded_audio_feat)
                    print(f"Audio {i} padded: placed at frames [{cumulative_length}:{end_pos}], shape: {padded_audio_feat.shape}")
                    cumulative_length += audio_len
                else:
                    # Create zero features for missing audio with total length
                    zero_audio_feat = torch.zeros(
                        (F,) + reference_feat_shape,
                        dtype=torch.bfloat16 if not use_half else half_dtype,
                        device=device
                    )
                    audio_feat_list.append(zero_audio_feat)
                    print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
        else:
            # Pad mode: keep existing logic, but apply trim_to_4s if needed
            for i, audio_path in enumerate(audio_paths):
                if audio_path and os.path.exists(audio_path):
                    print(f"Processing audio {i}: {audio_path}")
                    target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio_path).split('.')[0]}-{task_key}_16k_{F}.wav")
                    if not os.path.exists(target_resampled_audio_path):
                        resample_audio(
                            audio_path,
                            target_resampled_audio_path,
                        )
                    with torch.no_grad():
                        print(f"wav2vec_model: {wav2vec_model}")
                        print(f"cache_dir:{cache_dir}")
                        # Fast mode: if trim_to_4s, limit to 4 seconds
                        target_frames = F
                        if trim_to_4s:
                            # 4 seconds is fixed at 97 frames (4n+1 format: 4 s * 24 fps = 96 frames, rounded up to 97)
                            max_frames_4s = 97
                            target_frames = min(F, max_frames_4s)
                            if F > max_frames_4s:
                                print(f"Fast mode: Trimming audio {i} from {F} frames to {max_frames_4s} frames (4 seconds)")
                        # Use dynamically determined frame number
                        audio_emb, audio_length = preprocess_audio(
                            wav_path=target_resampled_audio_path,
                            num_generated_frames_per_clip=target_frames,  # Use target frames (may be trimmed)
                            fps=fps,
                            wav2vec_model=wav2vec_model,
                            vocal_separator_model=vocal_separator_model,
                            cache_dir=cache_dir,
                            device=device,
                        )
                        # If half precision is enabled, use the configured half_dtype; otherwise use bfloat16
                        audio_dtype = half_dtype if use_half else torch.bfloat16
                        audio_emb = audio_emb.to(device, dtype=audio_dtype)
                    
                    # Ensure we don't exceed F frames (for consistency with other tensors)
                    audio_feat = audio_emb[:F]  # Use F to maintain consistency
                    audio_feat_list.append(audio_feat)
                    print(f"Audio {i} processed, shape: {audio_feat.shape}")
                else:
                    print(f"Warning: Audio {i} path is empty or file not found: {audio_path}")
                    # Create zero features for missing audio
                    if len(audio_feat_list) > 0:
                        # Use first audio's shape to create zero features
                        zero_audio_feat = torch.zeros_like(audio_feat_list[0])
                        audio_feat_list.append(zero_audio_feat)
                    else:
                        print(f"Error: No valid audio files found, cannot create zero features")
    else:
        # Compatible with old format: use single audio parameter
        if audio is not None:
            print(f"Processing single audio (legacy format): {audio}")
            cache_dir = os.path.join(audio_output_dir, "audio_preprocess")
            os.makedirs(cache_dir, exist_ok=True)

            target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio).split('.')[0]}-16k.wav")
            if not os.path.exists(target_resampled_audio_path):
                audio = resample_audio(
                    audio,
                    target_resampled_audio_path,
                )
            with torch.no_grad():
                # Fast mode: if trim_to_4s, limit to 4 seconds
                target_frames = F
                if trim_to_4s:
                    # 4 seconds is fixed at 97 frames (4n+1 format: 4 s * 24 fps = 96 frames, rounded up to 97)
                    max_frames_4s = 97
                    target_frames = min(F, max_frames_4s)
                    if F > max_frames_4s:
                        print(f"Fast mode: Trimming single audio from {F} frames to {max_frames_4s} frames (4 seconds)")
                # Use dynamically determined frame number
                audio_emb, audio_length = preprocess_audio(
                    wav_path=audio,
                    num_generated_frames_per_clip=target_frames,  # Use target frames (may be trimmed)
                    fps=fps,
                    wav2vec_model=wav2vec_model,
                    vocal_separator_model=vocal_separator_model,
                    cache_dir=cache_dir,
                    device=device,
                )
                # If half precision is enabled, use the configured half_dtype; otherwise use bfloat16
                audio_dtype = half_dtype if use_half else torch.bfloat16
                audio_emb = audio_emb.to(device, dtype=audio_dtype)
            
            # Ensure we don't exceed F frames (for consistency with other tensors)
            audio_feat = audio_emb[:F]  # Use F to maintain consistency
            audio_feat_list.append(audio_feat)
            print(f"Single audio processed, shape: {audio_feat.shape}")
        else:
            print("No audio files provided")
    
    return audio_feat_list
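
# Standalone sketch of the concat-mode layout above (toy shapes, independent of
# wav2vec): each speaker's features occupy their own segment of the shared
# F-frame timeline and are zero elsewhere, so summing the padded tensors
# reconstructs the concatenated sequence.
def _demo_concat_padding_layout():
    F, feat_dim = 9, 8
    feats = [torch.ones(4, feat_dim), 2 * torch.ones(5, feat_dim)]  # two speakers: 4 + 5 = 9 frames
    padded, offset = [], 0
    for feat in feats:
        buf = torch.zeros(F, feat_dim)
        buf[offset:offset + feat.shape[0]] = feat  # place into this speaker's time segment
        padded.append(buf)
        offset += feat.shape[0]
    assert torch.equal(sum(padded), torch.cat(feats, dim=0))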

@torch.cuda.amp.autocast(dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):
    """Return the per-row least-squares scale s* = argmin_s ||positive - s * negative||,
    equivalently (||positive|| / ||negative||) * cos(positive, negative)."""
    # Calculate vector norms
    positive_norm = torch.norm(positive_flat, dim=-1, keepdim=True)
    negative_norm = torch.norm(negative_flat, dim=-1, keepdim=True)
    
    # Calculate cosine similarity
    cosine_sim = torch.sum(positive_flat * negative_flat, dim=-1, keepdim=True) / (positive_norm * negative_norm + 1e-8)
    
    # Calculate scale factor
    scale = (positive_norm / (negative_norm + 1e-8)) * cosine_sim
    
    return scale
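
# Numerical check (made-up tensors): the returned scale is the least-squares
# coefficient argmin_s ||positive - s * negative|| per row, so when
# negative_flat is a 0.5x copy of positive_flat the scale is ~2.0 and
# scale * negative recovers positive.
def _demo_optimized_scale():
    positive = torch.randn(2, 16)
    negative = 0.5 * positive
    scale = optimized_scale(positive, negative)  # shape [2, 1], values ~2.0
    assert torch.allclose(scale * negative, positive, atol=1e-4)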


def expand_face_mask_flexible(face_mask, width_scale_factor, height_scale_factor):
    """
    将face_mask中值为1的区域按指定的宽度和高度倍数独立扩大
    
    Args:
        face_mask: tensor, shape: [H, W],原始的face mask
        width_scale_factor: float, 宽度扩大倍数
        height_scale_factor: float, 高度扩大倍数
        
    Returns:
        tensor: shape: [H, W],扩大后的face mask
    """
    if width_scale_factor == 1.0 and height_scale_factor == 1.0:
        return face_mask
        
    # Find the bounding box of the non-zero region in the mask
    mask_indices = torch.nonzero(face_mask > 0.5)
    if mask_indices.numel() == 0:
        return face_mask
        
    # Compute the bounding box of the current mask
    min_h, min_w = mask_indices.min(dim=0)[0]
    max_h, max_w = mask_indices.max(dim=0)[0]
    
    # Compute the center point
    center_h = (min_h + max_h) / 2.0
    center_w = (min_w + max_w) / 2.0
    
    # Compute the current bbox size
    current_h = max_h - min_h + 1
    current_w = max_w - min_w + 1
    
    # Compute the expanded size; width and height scale independently
    new_h = int(current_h * height_scale_factor)
    new_w = int(current_w * width_scale_factor)
    
    # Compute the new bounding box (expanded about the center)
    new_min_h = int(center_h - new_h / 2.0)
    new_max_h = int(center_h + new_h / 2.0)
    new_min_w = int(center_w - new_w / 2.0)
    new_max_w = int(center_w + new_w / 2.0)
    
    # Clamp the new bounding box to the image bounds
    H, W = face_mask.shape
    new_min_h = max(0, new_min_h)
    new_max_h = min(H - 1, new_max_h)
    new_min_w = max(0, new_min_w)
    new_max_w = min(W - 1, new_max_w)
    
    # Create the new mask
    expanded_mask = torch.zeros_like(face_mask)
    
    # Fit the original mask region into the new bounding box
    if new_max_h > new_min_h and new_max_w > new_min_w:
        # Extract the original mask content
        original_content = face_mask[min_h:max_h+1, min_w:max_w+1]
        
        # Resize the original content to the new size
        target_h = new_max_h - new_min_h + 1
        target_w = new_max_w - new_min_w + 1
        
        if target_h > 0 and target_w > 0:
            scaled_content = torch.nn.functional.interpolate(
                original_content.unsqueeze(0).unsqueeze(0),
                size=(target_h, target_w),
                mode='bilinear',
                align_corners=False
            ).squeeze(0).squeeze(0)
            
            # Place the resized content at the new position
            expanded_mask[new_min_h:new_max_h+1, new_min_w:new_max_w+1] = scaled_content
    
    return expanded_mask
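
# Usage sketch on a synthetic mask (the 64x64 grid and 10x10 square are
# assumptions, not real detector output): grown 2.0x in width and 1.5x in
# height, the square becomes roughly 21 wide by 16 tall (the bbox math is
# inclusive), still centered on the original region.
def _demo_expand_face_mask():
    mask = torch.zeros(64, 64)
    mask[27:37, 27:37] = 1.0  # 10x10 face region
    expanded = expand_face_mask_flexible(mask, width_scale_factor=2.0, height_scale_factor=1.5)
    ys, xs = torch.nonzero(expanded > 0.5, as_tuple=True)
    print(f"expanded region: h={int(ys.max() - ys.min() + 1)}, w={int(xs.max() - xs.min() + 1)}")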


def gen_inference_masks(masks, img_shape, num_frames=None):
    """
    为推理生成与训练时相同格式的mask
    注意:推理时的mask是按整个图片标记的,不需要切割50%的逻辑
    为了适配训练格式,需要添加batch维度和帧维度 [H, W] -> [1, F, H, W]
    
    Args:
        masks: list of tensors, 人脸检测模型生成的mask列表,每个mask都是[H, W]格式
        img_shape: tuple, 图像形状 (H, W)
        num_frames: int, 视频帧数
        
    Returns:
        dict: 包含face_mask_list的字典,human_mask_list设为None
    """
    H, W = img_shape
    F = num_frames if num_frames is not None else 1
    num_faces = len(masks)
    
    print(f"gen_inference_masks: 处理{num_faces}个人脸,图像尺寸{H}x{W},帧数{F}")
    
    with torch.no_grad():
        face_mask_list = []
        
        # Generate a multi-frame mask for each face
        for i, mask in enumerate(masks):
            # Build the multi-frame mask: every frame reuses the same face_mask
            face_mask_multi = mask.unsqueeze(0).unsqueeze(0).repeat(1, 1, F, 1, 1)  # [B, C, F, H, W]
            face_mask_list.append(face_mask_multi)
        
        # Build the concat mask: concatenate all masks along the width dimension
        if num_faces > 1:
            face_mask_concat = torch.cat(face_mask_list, dim=4)  # [B, C, F, H, num_faces*W]
        else:
            face_mask_concat = face_mask_list[0]
        
        return {
            "face_mask_list": face_mask_list,
            "human_mask_list": None,  # 不再使用human mask
            "face_mask_concat": face_mask_concat,
            "num_faces": num_faces
        }
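
# Shape sanity sketch (synthetic masks; real ones come from the face detector):
# two 32x32 masks over 5 frames give per-face tensors of shape
# [1, 1, 5, 32, 32] and a width-concatenated tensor of shape [1, 1, 5, 32, 64].
def _demo_gen_inference_masks():
    masks = [torch.ones(32, 32), torch.zeros(32, 32)]
    out = gen_inference_masks(masks, img_shape=(32, 32), num_frames=5)
    print(out["face_mask_list"][0].shape)  # torch.Size([1, 1, 5, 32, 32])
    print(out["face_mask_concat"].shape)   # torch.Size([1, 1, 5, 32, 64])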


def expand_bbox_and_crop_image(img, bbox, width_scale_factor, height_scale_factor):
    """
    将bbox按scale_factor放大并从图像中安全切割对应区域
    
    Args:
        img: tensor, shape: [C, H, W], 输入图像 (值域为-1到1)
        bbox: list or tuple, [x1, y1, x2, y2], bbox坐标
        width_scale_factor: float, 宽度放大倍数
        height_scale_factor: float, 高度放大倍数
        
    Returns:
        tuple: (cropped_image, new_bbox)
        - cropped_image: tensor, shape: [C, new_h, new_w], 切割后的图像
        - new_bbox: list, [new_x1, new_y1, new_x2, new_y2], 调整后的bbox坐标
    """
    # Unpack the original bbox coordinates
    x1, y1, x2, y2 = bbox
    
    # Get the image size
    _, img_h, img_w = img.shape
    
    # Compute the bbox center and original size
    center_x = (x1 + x2) / 2.0
    center_y = (y1 + y2) / 2.0
    original_w = x2 - x1
    original_h = y2 - y1
    
    # Compute the enlarged size
    new_w = original_w * width_scale_factor
    new_h = original_h * height_scale_factor
    
    # Compute the enlarged bbox coordinates (centered on the original center)
    new_x1 = center_x - new_w / 2.0
    new_y1 = center_y - new_h / 2.0
    new_x2 = center_x + new_w / 2.0
    new_y2 = center_y + new_h / 2.0
    
    # Clamp the bbox to the image bounds while preserving a minimum size
    new_x1 = max(0, new_x1)
    new_y1 = max(0, new_y1)
    new_x2 = min(img_w, new_x2)
    new_y2 = min(img_h, new_y2)
    
    # Ensure the cropped size is at least 1 pixel
    if new_x2 <= new_x1:
        # If the width is 0 or negative, fall back to the minimum usable width
        if center_x < img_w / 2:
            new_x1 = max(0, int(center_x) - 1)
            new_x2 = min(img_w, new_x1 + max(1, int(original_w)))
        else:
            new_x2 = min(img_w, int(center_x) + 1)
            new_x1 = max(0, new_x2 - max(1, int(original_w)))
    
    if new_y2 <= new_y1:
        # If the height is 0 or negative, fall back to the minimum usable height
        if center_y < img_h / 2:
            new_y1 = max(0, int(center_y) - 1)
            new_y2 = min(img_h, new_y1 + max(1, int(original_h)))
        else:
            new_y2 = min(img_h, int(center_y) + 1)
            new_y1 = max(0, new_y2 - max(1, int(original_h)))
    
    # Convert to integer coordinates
    new_x1, new_y1, new_x2, new_y2 = int(new_x1), int(new_y1), int(new_x2), int(new_y2)
    
    # Final check to ensure the coordinates are valid
    assert new_x2 > new_x1 and new_y2 > new_y1, f"Invalid bbox after adjustment: [{new_x1}, {new_y1}, {new_x2}, {new_y2}]"
    
    # Crop the enlarged region from the original image
    cropped_image = img[:, new_y1:new_y2, new_x1:new_x2]
    
    return cropped_image, [new_x1, new_y1, new_x2, new_y2]
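
# Usage sketch with a synthetic image (image size and bbox values are
# assumptions): a 50x50 bbox centered in a 3x256x256 image, grown 1.5x in both
# directions, yields a 75x75 crop clamped to the image bounds.
def _demo_expand_bbox_and_crop():
    img = torch.rand(3, 256, 256) * 2 - 1  # value range [-1, 1], as documented above
    cropped, new_bbox = expand_bbox_and_crop_image(img, [103, 103, 153, 153], 1.5, 1.5)
    print(f"cropped shape: {tuple(cropped.shape)}, new bbox: {new_bbox}")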


def gen_smooth_transition_mask_for_dit(face_mask, lat_h, lat_w, F, device, mask_dtype, target_translate=(0, 0), target_scale=1.0):
    """
    Generate a smooth transition mask for DIT based on face_mask and the latent shape.
    The first frame is all white (all 1s); subsequent frames gradually transition from the original position to the target position and scale.
    
    Args:
        face_mask: tensor, shape: [H, W]
        lat_h: int, latent height
        lat_w: int, latent width
        F: int, number of frames in original video
        device: torch.device, device to create tensors on
        mask_dtype: torch.dtype, dtype for mask tensors
        target_translate: tuple, target translation amount; only the horizontal component (index 1) is used
        target_scale: float, target scale ratio
        
    Returns:
        tensor: shape: [4, (F + 3) // 4, lat_h, lat_w], mask for DIT
    """
    
    # Resize face_mask to latent size
    face_mask_resized = torch.nn.functional.interpolate(
        face_mask.unsqueeze(0).unsqueeze(0),  # [1, 1, H, W]
        size=(lat_h, lat_w),
        mode='bilinear',
        align_corners=False
    ).squeeze(0).squeeze(0)  # [lat_h, lat_w]
    
    # Create mask, first frame all white (all 1s), remaining frames gradually transition
    msk = torch.zeros(1, F, lat_h, lat_w, device=device, dtype=mask_dtype)
    msk[:, 0:1] = 1.0  # First frame all white
    
    if F > 1:
        # Generate different transformation parameters for each frame to achieve smooth transition
        for frame_idx in range(1, F):
            # Calculate transition progress for current frame (0 to 1)
            progress = (frame_idx - 1) / (F - 2) if F > 2 else 1.0
            
            # progress is already linear, giving a uniform transition
            
            # Translation and scale for current frame (only horizontal translation allowed)
            current_translate = (
                0,  # Vertical direction always 0, no vertical movement allowed
                int(target_translate[1] * progress)  # Only use horizontal translation
            )
            current_scale = 1.0 + (target_scale - 1.0) * progress
            
            # Generate mask for current frame
            if current_scale != 1.0:
                # Calculate scaled size
                scaled_h = int(lat_h * current_scale)
                scaled_w = int(lat_w * current_scale)
                
                # Scale mask
                scaled_mask = torch.nn.functional.interpolate(
                    face_mask_resized.unsqueeze(0).unsqueeze(0),  # [1, 1, lat_h, lat_w]
                    size=(scaled_h, scaled_w),
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0).squeeze(0)  # [scaled_h, scaled_w]
                
                # Create zero mask of target size
                transformed_mask = torch.zeros(lat_h, lat_w, device=device, dtype=mask_dtype)
                
                # Calculate placement position (centered)
                start_h = max(0, (lat_h - scaled_h) // 2)
                start_w = max(0, (lat_w - scaled_w) // 2)
                end_h = min(lat_h, start_h + scaled_h)
                end_w = min(lat_w, start_w + scaled_w)
                
                # Calculate crop range in scaled_mask
                src_start_h = max(0, (scaled_h - lat_h) // 2)
                src_start_w = max(0, (scaled_w - lat_w) // 2)
                src_end_h = src_start_h + (end_h - start_h)
                src_end_w = src_start_w + (end_w - start_w)
                
                # Place scaled mask to target position
                transformed_mask[start_h:end_h, start_w:end_w] = scaled_mask[src_start_h:src_end_h, src_start_w:src_end_w]
            else:
                transformed_mask = face_mask_resized.clone().to(dtype=mask_dtype)
            
            # Apply horizontal translation, stop when touching boundary
            translate_w = current_translate[1]  # Only take horizontal translation
            if translate_w != 0:
                # Find horizontal boundaries of mask
                mask_indices = torch.nonzero(transformed_mask > 0.5)
                if mask_indices.numel() > 0:
                    mask_min_w = mask_indices[:, 1].min().item()
                    mask_max_w = mask_indices[:, 1].max().item()
                    
                    # Calculate actual available horizontal translation amount
                    if translate_w < 0:
                        # When moving left, check left boundary
                        max_translate_w = -mask_min_w
                        actual_translate_w = max(translate_w, max_translate_w)
                    else:
                        # When moving right, check right boundary
                        max_translate_w = lat_w - 1 - mask_max_w
                        actual_translate_w = min(translate_w, max_translate_w)
                    
                    # If there is valid translation amount, execute translation
                    if actual_translate_w != 0:
                        # Use torch.roll for horizontal translation, but ensure not exceeding boundary
                        if abs(actual_translate_w) <= min(mask_min_w, lat_w - 1 - mask_max_w):
                            # Only use roll within safe range
                            transformed_mask = torch.roll(transformed_mask, shifts=actual_translate_w, dims=1)
                        else:
                            # Manually copy to avoid wrapping
                            new_mask = torch.zeros_like(transformed_mask, dtype=mask_dtype)
                            if actual_translate_w > 0:
                                # Move right
                                new_mask[:, actual_translate_w:] = transformed_mask[:, :-actual_translate_w]
                            else:
                                # Move left
                                new_mask[:, :actual_translate_w] = transformed_mask[:, -actual_translate_w:]
                            transformed_mask = new_mask
            
            # Assign mask for current frame
            msk[:, frame_idx:frame_idx+1] = transformed_mask.unsqueeze(0).unsqueeze(0)
    
    # Following the encode_image_vae processing, convert the mask to the format required by DIT
    msk = torch.concat([
        torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
    ], dim=1)
    
    msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
    msk = msk.transpose(1, 2)[0]  # shape: [4, F, lat_h, lat_w]
    return msk
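
# Shape sanity sketch (synthetic inputs): the first frame is repeated 4x before
# packing, so F=17 video frames pack into (17 + 3) / 4 = 5 latent frames and
# the returned mask has shape [4, 5, lat_h, lat_w].
def _demo_smooth_transition_mask():
    face_mask = torch.zeros(64, 64)
    face_mask[20:44, 20:44] = 1.0
    msk = gen_smooth_transition_mask_for_dit(
        face_mask, lat_h=16, lat_w=16, F=17,
        device=torch.device("cpu"), mask_dtype=torch.float32,
        target_translate=(0, 4), target_scale=1.2,
    )
    print(msk.shape)  # torch.Size([4, 5, 16, 16])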