File size: 33,576 Bytes
9c32fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcbdf35
9c32fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcbdf35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d8a71e
75c4c3f
cc800d1
 
 
9d8a71e
 
cc800d1
9d8a71e
 
cc800d1
 
 
 
 
 
 
 
 
 
 
9b74663
cc800d1
9b74663
cc800d1
9b74663
 
 
cc800d1
 
 
 
 
9b74663
cc800d1
 
 
 
 
 
 
 
 
 
 
 
9b74663
cc800d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92516fd
cc800d1
 
92516fd
 
 
a615bff
92516fd
 
 
a615bff
 
92516fd
a615bff
 
 
92516fd
cc800d1
92516fd
 
 
 
 
 
 
 
cc800d1
92516fd
 
 
 
 
 
 
 
 
d1b769c
 
 
 
92516fd
 
 
9b74663
92516fd
 
 
 
d1b769c
92516fd
 
d1b769c
 
92516fd
d1b769c
92516fd
 
 
 
 
 
 
 
cc800d1
 
 
 
 
 
 
 
 
 
 
9c32fea
 
 
 
 
73f86b7
9c32fea
 
 
 
 
 
 
 
 
75c4c3f
127cda9
9c32fea
 
 
 
 
 
dcbdf35
 
 
73f86b7
 
dcbdf35
73f86b7
dcbdf35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75c4c3f
 
127cda9
 
 
 
75c4c3f
127cda9
 
75c4c3f
 
 
9c32fea
 
 
 
 
75c4c3f
dcbdf35
 
 
 
 
3b38a35
 
 
dcbdf35
 
 
 
 
 
 
 
 
cc800d1
 
dcbdf35
 
 
 
 
3b38a35
 
73f86b7
 
 
 
 
75c4c3f
 
73f86b7
dcbdf35
127cda9
dcbdf35
 
 
 
 
 
73f86b7
dcbdf35
 
73f86b7
cc800d1
dcbdf35
cc800d1
 
 
 
 
 
75c4c3f
cc800d1
75c4c3f
 
 
d1b769c
 
 
 
cc800d1
 
 
 
 
 
 
 
9c32fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b38a35
 
9c32fea
3b38a35
 
 
 
 
 
 
 
 
 
 
 
 
 
9c32fea
 
ccd3f47
dcbdf35
9c32fea
 
 
 
 
 
 
 
 
 
 
 
 
 
cc800d1
 
5a33bb3
9c32fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b38a35
 
9c32fea
 
 
 
 
 
 
 
 
3b38a35
9c32fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcbdf35
9c32fea
3b38a35
9c32fea
cc800d1
9c32fea
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor

# Module bootstrap: everything below runs at import time, on every start of the app.

# Enable fast downloads
# (set before huggingface_hub is imported further down in this file)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "1"

# Disable torch.compile / dynamo before any torch import
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Install xformers for memory-efficient attention
# Best effort (check=False): a failed install is tolerated here; the attention
# setup below wraps the xformers import in try/except and falls back to SDPA.
subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)

# Clone LTX-2 repo at a pinned compatible commit and install packages
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
LTX_COMMIT = "ae855f8538843825f9015a419cf4ba5edaf5eec2"

# Any existing checkout is deleted and re-cloned, so the working tree always
# matches LTX_COMMIT exactly (no stale local edits survive a restart).
if os.path.exists(LTX_REPO_DIR):
    print(f"Removing existing repo at {LTX_REPO_DIR}...")
    subprocess.run(["rm", "-rf", LTX_REPO_DIR], check=True)

print(f"Cloning {LTX_REPO_URL}...")
subprocess.run(["git", "clone", LTX_REPO_URL, LTX_REPO_DIR], check=True)

print(f"Checking out commit {LTX_COMMIT}...")
subprocess.run(["git", "-C", LTX_REPO_DIR, "checkout", LTX_COMMIT], check=True)

# Editable install of both packages without touching dependencies (--no-deps).
# NOTE(review): assumes their third-party dependencies are already installed
# by the environment (e.g. requirements.txt) — confirm.
print("Installing ltx-core and ltx-pipelines from pinned repo commit...")
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "--force-reinstall", "--no-deps",
        "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
        "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines"),
    ],
    check=True,
)

# Put both source trees at the front of sys.path so the ltx_core / ltx_pipelines
# imports below resolve to this checkout in the current process. (An editable
# install performed at runtime is likely not picked up by the already-running
# interpreter.)
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))

import logging
import random
import tempfile
from pathlib import Path

import torch
# Reinforce the TORCHDYNAMO_DISABLE env var set earlier: disable dynamo through
# its runtime config too, and have dynamo errors fall back to eager execution.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

# Critical workaround: Replace inference_mode with no_grad
# Avoids "inference tensor" failures in spatial upsampler and VAE decoder
# This is a global monkey-patch of the torch module attribute. It runs before
# ltx_pipelines is imported below, so decorators evaluated at that import, and
# the `@torch.inference_mode()` on generate_video in this file, all resolve to
# torch.no_grad.
torch.inference_mode = torch.no_grad

import spaces
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download

from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.utils.args import ImageConditioningInput
from ltx_pipelines.utils.media_io import encode_video

# Patch attention backend into the LTX attention module.
import torch.nn.functional as F
from ltx_core.model.transformer import attention as _attn_mod

def _sdpa_as_mea(query, key, value, attn_bias=None, scale=None, **kwargs):
    # xformers memory_efficient_attention: (B, S, H, D) -> (B, S, H, D)
    # torch SDPA:                          (B, H, S, D) -> (B, H, S, D)
    q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v, scale=scale).transpose(1, 2)

# Choose the kernel that ltx_core's attention module calls as
# `memory_efficient_attention` by overwriting that module attribute.
# _cap is the CUDA compute capability (major, minor). It is (0, 0) without a
# visible GPU, which still satisfies `< (12, 0)`, so xformers is attempted on
# CPU-only hosts as well.
_cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
_use_xformers = False
# Devices at compute capability 12.0 or above never try xformers.
# NOTE(review): presumably the pinned xformers build lacks kernels for them — confirm.
if _cap < (12, 0):
    try:
        from xformers.ops import memory_efficient_attention as _mea
        _attn_mod.memory_efficient_attention = _mea
        _use_xformers = True
        print(f"[ATTN] Using xformers memory_efficient_attention")
    except Exception as e:
        # Any exception while importing/patching falls through to the SDPA path below.
        print(f"[ATTN] xformers unavailable ({e}), falling back to SDPA")

# Fallback: torch SDPA wrapped to mimic the xformers calling convention.
if not _use_xformers:
    _attn_mod.memory_efficient_attention = _sdpa_as_mea
    print(f"[ATTN] Using SDPA fallback (sm_{_cap[0]}{_cap[1]})")

# Show INFO-level log records from all loggers (root logger threshold).
logging.getLogger().setLevel(logging.INFO)

# Upper bound for randomly drawn seeds (2**31 - 1).
# NOTE(review): generate_video adds the clip index to the base seed, so actual
# seeds may exceed this bound by a few — confirm the pipeline accepts that.
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_PROMPT = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, "
    "the shell cracking and peeling apart in gentle low-gravity motion. "
    "Fine lunar dust lifts and drifts outward with each movement, floating "
    "in slow arcs before settling back onto the ground."
)
DEFAULT_FRAME_RATE = 24.0

# Resolution presets: (width, height)
# Keyed by quality tier, then by the aspect-ratio label returned by
# detect_aspect_ratio. Note that the "16:9" presets are actually 1.6:1 (16:10).
# All dimensions are multiples of 64.
RESOLUTIONS = {
    "high": {"16:9": (1024, 640), "9:16": (640, 1024), "1:1": (1024, 1024)},
    "low": {"16:9": (512, 320), "9:16": (320, 512), "1:1": (512, 512)},
}


# Model repos
# Hugging Face Hub repo ids: LTX-2.3 transformer + upscaler weights, and the
# Gemma 3 12B checkpoint used as the text encoder root (`gemma_root` below).
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"

# Download model checkpoints in parallel for speed
print("=" * 80)
print("Downloading Element-8 (pre-distilled LTX) + Gemma (parallel)...")
print("=" * 80)

def download_checkpoint():
    """Fetch the pre-distilled LTX-2.3 22B checkpoint and return its local file path.

    The distilled weights are used directly, so no LoRA is needed.
    """
    distilled_weights = "ltx-2.3-22b-distilled.safetensors"
    return hf_hub_download(repo_id=LTX_MODEL_REPO, filename=distilled_weights)

def download_upsampler():
    """Fetch the x2 spatial-upscaler weights from the LTX repo and return the local path."""
    upscaler_weights = "ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
    return hf_hub_download(repo_id=LTX_MODEL_REPO, filename=upscaler_weights)

def download_gemma():
    """Mirror the full Gemma text-encoder repo locally and return the snapshot directory."""
    snapshot_dir = snapshot_download(repo_id=GEMMA_REPO)
    return snapshot_dir

# Run the three downloads concurrently (I/O-bound, so threads suffice).
# `.result()` blocks until each finishes and re-raises any download exception,
# which aborts start-up at import time.
with ThreadPoolExecutor(max_workers=3) as executor:
    future_checkpoint = executor.submit(download_checkpoint)
    future_upsampler = executor.submit(download_upsampler)
    future_gemma = executor.submit(download_gemma)

    checkpoint_path = future_checkpoint.result()
    spatial_upsampler_path = future_upsampler.result()
    gemma_root = future_gemma.result()

print(f"Checkpoint: {checkpoint_path}")
print(f"Spatial upsampler: {spatial_upsampler_path}")
print(f"Gemma root: {gemma_root}")

# Initialize pipeline with pre-distilled checkpoint (no LoRA needed)
pipeline = DistilledPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=(),
)

# Preload all models for ZeroGPU tensor packing
# NOTE(review): presumably ZeroGPU needs the weights instantiated at import
# time, outside any @spaces.GPU call, so it can pack them — confirm.
print("Preloading all pipeline components via model_ledger...")

# DistilledPipeline uses model_ledger similar to other pipelines
# Each ledger accessor is called exactly once here.
# NOTE(review): assumes each accessor builds/loads its model on every call,
# which is why the results are cached below — confirm against ltx_pipelines.
ledger = pipeline.model_ledger
_transformer = ledger.transformer()
_video_encoder = ledger.video_encoder()
_video_decoder = ledger.video_decoder()
_spatial_upsampler = ledger.spatial_upsampler()
_text_encoder = ledger.text_encoder()
_embeddings_processor = ledger.gemma_embeddings_processor()
_audio_encoder = ledger.audio_encoder()
_audio_decoder = ledger.audio_decoder()
_vocoder = ledger.vocoder()

# Replace ledger methods with lambdas returning preloaded instances
# Instance attributes shadow the class methods, so every later call made by the
# pipeline returns these already-loaded objects.
ledger.transformer = lambda: _transformer
ledger.video_encoder = lambda: _video_encoder
ledger.video_decoder = lambda: _video_decoder
ledger.spatial_upsampler = lambda: _spatial_upsampler
ledger.text_encoder = lambda: _text_encoder
ledger.gemma_embeddings_processor = lambda: _embeddings_processor
ledger.audio_encoder = lambda: _audio_encoder
ledger.audio_decoder = lambda: _audio_decoder
ledger.vocoder = lambda: _vocoder

print("All models preloaded!")

print("=" * 80)
print("Pipeline ready!")
print("=" * 80)


def log_memory(tag: str):
    """Print a one-line CUDA memory report labelled *tag*; do nothing without CUDA.

    Reports allocated and peak-allocated memory from PyTorch's allocator, plus
    the device's free and total memory, all in GiB.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    used_now = torch.cuda.memory_allocated() / gib
    used_peak = torch.cuda.max_memory_allocated() / gib
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(
        f"[VRAM {tag}] allocated={used_now:.2f}GB peak={used_peak:.2f}GB "
        f"free={free_bytes / gib:.2f}GB total={total_bytes / gib:.2f}GB"
    )


def detect_aspect_ratio(image) -> str:
    """Return the preset aspect-ratio key ("16:9", "9:16" or "1:1") closest to *image*.

    Accepted inputs:
    - Array-like images exposing ``.shape`` as (height, width, ...), e.g. numpy
      arrays or tensors.
    - PIL-style images exposing ``.size`` as a (width, height) pair.

    ``None``, unrecognised objects and degenerate (zero-sized) images fall back
    to "16:9".
    """
    default = "16:9"
    if image is None:
        return default

    # Check ``shape`` first: numpy arrays also have a ``size`` attribute, but it
    # is the total element count (an int), which cannot be unpacked as (w, h).
    shape = getattr(image, "shape", None)
    size = getattr(image, "size", None)
    if shape is not None and len(shape) >= 2:
        h, w = shape[0], shape[1]
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        w, h = size
    else:
        return default

    # Guard the division below against empty/degenerate images.
    if w <= 0 or h <= 0:
        return default

    ratio = w / h
    candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))


def on_image_upload(first_image, last_image, high_res):
    """Snap the width/height controls to the preset matching the uploaded image.

    The first-frame image is used when present; otherwise the last-frame image
    is used. Returns a pair of ``gr.update`` objects: (width, height).
    """
    reference = last_image if first_image is None else first_image
    tier_presets = RESOLUTIONS["high" if high_res else "low"]
    width, height = tier_presets[detect_aspect_ratio(reference)]
    return gr.update(value=width), gr.update(value=height)


def on_highres_toggle(first_image, last_image, high_res):
    """Update the width/height controls when the high-resolution toggle changes.

    The target size depends only on the reference image's aspect ratio and the
    selected quality tier. That is exactly what ``on_image_upload`` computes, so
    this delegates to it instead of keeping a second, duplicated copy of the
    logic that could drift out of sync.

    Returns:
        A pair of ``gr.update`` objects: (width, height).
    """
    return on_image_upload(first_image, last_image, high_res)


# Default negative prompt (Chinese followed by English terms).
# The Chinese part lists artifacts to avoid:
#   - garish colors, overexposure, static/still frames, blurry details
#   - subtitles, "style/artwork/painting/picture" looks, an overall grayish tone
#   - worst/low quality, JPEG compression artifacts
#   - ugly or mutilated figures, extra fingers, poorly drawn hands and faces
#   - deformed/disfigured bodies and malformed limbs, fused fingers
#   - motionless frames, cluttered backgrounds, three legs
#   - crowded backgrounds, walking backwards
# NOTE(review): generate_video accepts this as `negative_prompt` but does not
# pass it to `pipeline(...)`, so it currently has no effect — confirm whether
# DistilledPipeline supports a negative prompt.
DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走, blurry, glasses, deformed, subtitles, text, captions, worst quality, low quality, inconsistent motion, jittery, distorted"


def remove_music_demucs(input_video_path: str, output_video_path: str) -> bool:
    """Remove background music from a video with Demucs, keeping only the vocal stem.

    Steps:
    1. ffmpeg extracts the audio as 44.1 kHz stereo 16-bit PCM.
    2. The pretrained ``htdemucs`` separator splits it into stems.
    3. The ``vocals`` stem is written to WAV.
    4. The WAV is muxed back over the original video stream (``-c:v copy``) as
       128 kb/s AAC, cut to the shorter of the two streams.

    Args:
        input_video_path: Video whose soundtrack should be cleaned.
        output_video_path: Destination for the processed video (overwritten).

    Returns:
        True on success, False on any failure (the reason is printed). This
        function never raises.
    """
    import subprocess
    import tempfile
    from pathlib import Path

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = Path(tmpdir)

            # Extract audio from video
            audio_in = tmpdir / "audio.wav"
            extract_cmd = [
                'ffmpeg', '-y', '-i', input_video_path,
                '-vn', '-acodec', 'pcm_s16le', '-ar', '44100', '-ac', '2',
                str(audio_in)
            ]
            result = subprocess.run(extract_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"[demucs] Failed to extract audio: {result.stderr[-200:]}")
                return False

            print(f"[demucs] Running music separation...")

            import soundfile as sf
            from demucs.pretrained import get_model
            from demucs.apply import apply_model

            # Previously hard-coded to 'cuda'; fall back to CPU so the function
            # also works outside a GPU context.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Presumably demucs caches the downloaded weights on disk; the model
            # object itself is rebuilt on every call.
            model = get_model('htdemucs')
            model.to(device)
            model.eval()

            # soundfile returns (frames, channels); Demucs wants (channels, frames).
            data, sr = sf.read(str(audio_in))
            wav = torch.from_numpy(data.T).float()
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            # Resample if needed
            if sr != model.samplerate:
                import torchaudio
                wav = torchaudio.functional.resample(wav, sr, model.samplerate)

            # Add the batch dimension: (1, channels, frames).
            wav = wav.unsqueeze(0).to(device)

            # Separate sources
            with torch.no_grad():
                sources = apply_model(model, wav, overlap=0.25, progress=False)

            # sources is indexed [batch, stem]. Look the vocal stem up by name
            # instead of assuming it always sits at index 3.
            vocals_idx = model.sources.index('vocals')
            vocals = sources[0, vocals_idx].cpu()

            # Save vocals; soundfile expects (frames, channels), hence the transpose.
            audio_out = tmpdir / "vocals.wav"
            audio_np = vocals.numpy().T
            sf.write(str(audio_out), audio_np, model.samplerate)

            print(f"[demucs] Merging vocals back with video...")
            merge_cmd = [
                'ffmpeg', '-y',
                '-i', input_video_path,
                '-i', str(audio_out),
                '-c:v', 'copy',
                '-map', '0:v:0', '-map', '1:a:0',
                '-c:a', 'aac', '-b:a', '128k',
                '-shortest',
                output_video_path
            ]
            result = subprocess.run(merge_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"[demucs] Failed to merge: {result.stderr[-200:]}")
                return False

            print(f"[demucs] Successfully removed music")
            return True

    except Exception as e:
        print(f"[demucs] Error: {e}")
        import traceback
        traceback.print_exc()
        return False


def apply_gaussian_blur(video_tensor: torch.Tensor, blur_amount: int) -> torch.Tensor:
    """Blur every frame of a video with a Gaussian kernel.

    Args:
        video_tensor: Frames laid out as [frames, H, W, C].
        blur_amount: Blur radius in pixels.
            - The kernel is ``2 * blur_amount + 1`` pixels wide, with
              ``sigma = blur_amount / 2``.
            - Non-integer values (e.g. from a UI slider) are truncated to int,
              because torchvision rejects a non-integer kernel size.
            - Values that end up <= 0 return the input unchanged.

    Returns:
        The blurred video in the same [frames, H, W, C] layout, or the input
        object itself when no blur is applied.
    """
    blur_amount = int(blur_amount)
    if blur_amount <= 0:
        return video_tensor

    from torchvision.transforms.functional import gaussian_blur

    kernel_size = 2 * blur_amount + 1  # odd and >= 3 because blur_amount >= 1
    sigma = blur_amount / 2.0

    # torchvision treats the tensor as [batch, C, H, W]; move channels before H, W.
    channels_first = video_tensor.permute(0, 3, 1, 2)
    blurred = gaussian_blur(channels_first, kernel_size=[kernel_size, kernel_size], sigma=[sigma, sigma])

    # Back to [frames, H, W, C].
    return blurred.permute(0, 2, 3, 1)


def loop_clips_with_audio_track(clip_paths: list[str], audio_path: str) -> str:
    """Loop the generated clips until they cover an audio track, then mux that audio in.

    This is CPU-only ffmpeg work (no GPU time):
    - The clip sequence is repeated ``floor(audio / clips) + 1`` times through
      the concat demuxer with stream copy.
    - The audio replaces any existing soundtrack, re-encoded as 192 kb/s AAC.
    - The result is cut to the audio length.

    Args:
        clip_paths: Video files to loop, in order. They are joined with
            ``-c copy``, so they must share codec parameters.
        audio_path: Audio file whose duration sets the output length.

    Returns:
        Path of the new video. On any failure the error is printed and
        ``clip_paths[0]`` is returned instead (``None`` if *clip_paths* is
        empty).
    """
    import contextlib
    import subprocess
    from pydub import AudioSegment

    def _probe_duration(path: str) -> float:
        """Return the container duration of *path* in seconds, via ffprobe."""
        probe = subprocess.run([
            'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
            '-of', 'default=noprint_wrappers=1:nokey=1', path
        ], capture_output=True, text=True, check=True)
        return float(probe.stdout.strip())

    def _new_temp_path(suffix: str) -> str:
        """Atomically create an empty temp file and return its path (tempfile.mktemp is race-prone)."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    concat_file = concat_video = final_video = None
    try:
        audio = AudioSegment.from_file(audio_path)
        audio_duration = len(audio) / 1000.0  # pydub lengths are in milliseconds

        clips_duration = sum(_probe_duration(clip) for clip in clip_paths)
        if clips_duration <= 0:
            # Explicit guard for the division below (empty list or zero-length clips).
            raise ValueError(f"Clips have no playable duration ({clips_duration:.2f}s)")

        # One extra pass guarantees the looped video is at least as long as the
        # audio; the excess is trimmed in the final mux.
        loop_count = int(audio_duration / clips_duration) + 1

        print(f"[loop] Audio: {audio_duration:.2f}s, Clips: {clips_duration:.2f}s, Loops: {loop_count}")

        # Create concat file with loops
        concat_file = _new_temp_path(".txt")
        with open(concat_file, 'w', encoding='utf-8') as f:
            for _ in range(loop_count):
                for clip in clip_paths:
                    # Concat-demuxer quoting: inside '...', a literal ' is written as '\''.
                    quoted = clip.replace("'", "'\\''")
                    f.write(f"file '{quoted}'\n")

        # Concat videos
        concat_video = _new_temp_path(".mp4")
        result = subprocess.run([
            'ffmpeg', '-y', '-f', 'concat', '-safe', '0', '-i', concat_file,
            '-c', 'copy', concat_video
        ], capture_output=True, text=True)

        if result.returncode != 0:
            raise Exception(f"Concat failed: {result.stderr[-200:]}")

        # Replace audio and trim to audio duration
        final_video = _new_temp_path(".mp4")
        result = subprocess.run([
            'ffmpeg', '-y',
            '-i', concat_video,
            '-i', audio_path,
            '-map', '0:v:0', '-map', '1:a:0',
            '-c:v', 'copy', '-c:a', 'aac', '-b:a', '192k',
            '-t', str(audio_duration),
            '-shortest',
            final_video
        ], capture_output=True, text=True)

        if result.returncode != 0:
            raise Exception(f"Audio merge failed: {result.stderr[-200:]}")

        print(f"[loop] Created looped video: {audio_duration:.2f}s")
        return final_video

    except Exception as e:
        print(f"[loop] Error: {e}")
        import traceback
        traceback.print_exc()
        # A half-written output is useless to the caller; don't leave it behind.
        if final_video:
            with contextlib.suppress(OSError):
                os.remove(final_video)
        return clip_paths[0] if clip_paths else None

    finally:
        # Intermediates are never returned, so always clean them up.
        for leftover in (concat_file, concat_video):
            if leftover:
                with contextlib.suppress(OSError):
                    os.remove(leftover)


def transcribe_with_whisper_gpu(video_path: str, model_size: str = "small") -> list[dict]:
    """Run Whisper speech recognition on *video_path* on the GPU.

    Meant to be called from inside an already-acquired GPU context. Transcription
    runs in fp16 with word-level timestamps requested.

    Returns:
        Whisper's ``segments`` list, or an empty list if loading or
        transcription fails (the error is printed).
    """
    import whisper

    try:
        print(f"[whisper] Loading {model_size} model on GPU...")
        asr_model = whisper.load_model(model_size).to('cuda')

        print(f"[whisper] Transcribing audio on GPU...")
        transcript = asr_model.transcribe(video_path, word_timestamps=True, fp16=True)
        segments = transcript['segments']

        print(f"[whisper] Transcription complete: {len(segments)} segments")
        return segments
    except Exception as err:
        print(f"[whisper] Error: {err}")
        import traceback
        traceback.print_exc()
        return []


def create_beautiful_ass_subtitles(segments: list[dict], output_path: str, video_width: int, video_height: int):
    """Write an ASS subtitle file with one fading caption per transcript segment.

    Font: Noto Sans SemiBold, downloaded once to /tmp. If the download fails,
    the style falls back to the font name "Arial".

    The single "Default" style uses:
    - a font size of 5% of the frame height
    - white primary colour with outline 2 and shadow 1
    - alignment 5, and MarginV set to 42% of the frame height

    Each caption fades in and out over 200 ms.

    Args:
        segments: Transcript segments. Each needs ``start``/``end`` (seconds)
            and ``text`` keys, as produced by Whisper's ``segments``.
        output_path: Where the ``.ass`` file is written (UTF-8).
        video_width, video_height: Frame size, used as PlayResX/PlayResY so
            positions and font size are expressed in video pixels.
    """
    # Download Noto Sans - supports all languages (Latin, CJK, Tamil, etc.)
    import urllib.request
    font_url = "https://github.com/google/fonts/raw/main/ofl/notosans/NotoSans-SemiBold.ttf"
    font_path = "/tmp/NotoSans-SemiBold.ttf"
    font_name = "Noto Sans SemiBold"

    if not os.path.exists(font_path):
        partial_path = font_path + ".part"
        try:
            urllib.request.urlretrieve(font_url, partial_path)
            # Publish only a complete download. Otherwise a failed attempt could
            # leave a truncated file that later calls would treat as cached.
            os.replace(partial_path, font_path)
        except Exception as e:  # best effort: any network/IO failure -> fallback font
            print(f"[subtitles] Font download failed ({e}); using Arial")
            font_name = "Arial"  # Fallback

    # ASS subtitle header with beautiful styling
    header = f"""[Script Info]
Title: Elegant Subtitles
ScriptType: v4.00+
WrapStyle: 0
PlayResX: {video_width}
PlayResY: {video_height}
ScaledBorderAndShadow: yes

[V4+ Styles]
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
Style: Default,{font_name},{int(video_height * 0.05)},&H00FFFFFF,&H000000FF,&H00000000,&H80000000,0,0,0,0,100,100,0,0,1,2,1,5,10,10,{int(video_height * 0.42)},1

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
"""

    fade_duration = 200  # ms, used for both fade-in and fade-out
    events = []
    for seg in segments:
        start_time = format_ass_time(seg['start'])
        end_time = format_ass_time(seg['end'])
        # Each Dialogue event must stay on one line. Embedded line breaks become
        # the ASS hard-break tag \N instead of corrupting the file.
        text = seg['text'].strip().replace('\r\n', '\n').replace('\n', '\\N')

        animated_text = f"{{\\fad({fade_duration},{fade_duration})}}{text}"
        events.append(f"Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,{animated_text}\n")

    # Join once instead of growing the string quadratically inside the loop.
    ass_content = header + "".join(events)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(ass_content)

    print(f"[subtitles] Created ASS file with {len(segments)} segments")


def format_ass_time(seconds: float) -> str:
    """Convert a time in seconds to an ASS timestamp ``H:MM:SS.cc``.

    The value is first rounded to a whole number of centiseconds (ASS's
    resolution) and only then split into fields. This has two effects:
    - Float noise such as ``1.15 % 1 == 0.1499...`` no longer drops a
      centisecond.
    - A value like 59.999 carries into the next minute correctly.

    Negative inputs are clamped to zero.
    """
    total_cs = max(0, int(round(seconds * 100)))
    hours, rem = divmod(total_cs, 360_000)   # 100 cs * 60 s * 60 min
    minutes, rem = divmod(rem, 6_000)        # 100 cs * 60 s
    secs, centisecs = divmod(rem, 100)
    return f"{hours}:{minutes:02d}:{secs:02d}.{centisecs:02d}"


def burn_subtitles_and_watermark(video_path: str, output_path: str, subtitle_path: str = None, watermark_path: str = None):
    """Burn subtitles and/or a watermark into a video with FFmpeg (CPU work).

    Pass 1 runs only if *watermark_path* exists. The watermark image is looped
    for the video's duration, scaled to the frame with ``scale2ref``, and
    overlaid at (0, 0).

    Pass 2 runs only if *subtitle_path* exists. The file is rendered with the
    ``subtitles`` filter, forcing the "Noto Sans SemiBold" font.

    If neither pass runs, the input is copied unchanged. Audio is stream-copied
    in both passes.

    Args:
        video_path: Input video.
        output_path: Destination video (overwritten).
        subtitle_path: Optional subtitle file (e.g. ``.ass``). Ignored if missing.
        watermark_path: Optional overlay image. Ignored if missing.

    Returns:
        True on success, False on failure (the error is printed).
    """
    import shutil
    import subprocess
    import tempfile
    from contextlib import suppress

    # Intermediate pass-1 output. It is deleted at the end unless it was moved
    # into place as the final output.
    temp_watermarked = None
    try:
        current_video = video_path

        # Step 1: Apply watermark if needed (first pass)
        if watermark_path and os.path.exists(watermark_path):
            print(f"[burn] Pass 1: Overlaying watermark...")

            # Get video duration for looping watermark
            probe = subprocess.run([
                'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
                '-of', 'default=noprint_wrappers=1:nokey=1', current_video
            ], capture_output=True, text=True, check=True)
            video_duration = float(probe.stdout.strip())

            # mkstemp creates the file atomically (tempfile.mktemp is race-prone);
            # ffmpeg's -y then overwrites it.
            fd, temp_watermarked = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)

            cmd = [
                'ffmpeg', '-y', '-i', current_video,
                '-loop', '1', '-t', str(video_duration), '-i', watermark_path,
                '-filter_complex', '[1:v][0:v]scale2ref[ovr][base];[base][ovr]overlay=0:0[vout]',
                '-map', '[vout]', '-map', '0:a?',
                '-c:a', 'copy', '-pix_fmt', 'yuv420p',
                temp_watermarked
            ]

            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                raise Exception(f"Watermark pass failed: {result.stderr[-200:]}")

            current_video = temp_watermarked
            print(f"[burn] Watermark applied")

        # Step 2: Apply subtitles if needed (second pass)
        if subtitle_path and os.path.exists(subtitle_path):
            print(f"[burn] Pass 2: Burning subtitles from {subtitle_path}...")

            # Escape the subtitle path for FFmpeg (replace : with \\: and \ with /)
            subtitle_path_escaped = subtitle_path.replace('\\', '/').replace(':', '\\:')

            cmd = [
                'ffmpeg', '-y', '-i', current_video,
                '-vf', f"subtitles='{subtitle_path_escaped}':force_style='FontName=Noto Sans SemiBold'",
                '-c:a', 'copy', '-pix_fmt', 'yuv420p',
                output_path
            ]

            print(f"[burn] FFmpeg command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"[burn] Subtitle stderr: {result.stderr}")
                raise Exception(f"Subtitle pass failed: {result.stderr[-500:]}")

            print(f"[burn] Subtitles burned successfully")
        elif temp_watermarked is not None:
            # Only the watermark was applied: the intermediate becomes the output.
            shutil.move(temp_watermarked, output_path)
            temp_watermarked = None  # moved into place; nothing left to clean up
        else:
            # Nothing to do, just copy
            shutil.copy2(video_path, output_path)

        print(f"[burn] Successfully burned subtitles/watermark")
        return True

    except Exception as e:
        print(f"[burn] Error: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Previously the pass-1 file leaked whenever pass 2 ran or anything failed.
        if temp_watermarked is not None:
            with suppress(OSError):
                os.remove(temp_watermarked)


def _new_temp_video_path() -> str:
    """Atomically create an empty ``.mp4`` temp file and return its path.

    Replaces ``tempfile.mktemp``, which only picks a name and is race-prone.
    NOTE(review): assumes the writers used in generate_video (``encode_video``,
    ffmpeg ``-y``) overwrite an existing file — confirm for encode_video.
    """
    fd, path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    return path


def _save_conditioning_image(image, dest: Path) -> Path:
    """Return a filesystem path for a conditioning image.

    - Objects with a ``.save`` method (PIL images) are written to *dest*.
    - Anything else is treated as an existing file path.

    JPEG cannot store alpha or palette images, and ``save(...".jpg")`` raises
    on e.g. RGBA PNG uploads, so such images are converted to RGB first.
    """
    if not hasattr(image, "save"):
        return Path(image)
    if getattr(image, "mode", "RGB") not in ("RGB", "L") and hasattr(image, "convert"):
        image = image.convert("RGB")
    image.save(dest)
    return dest


@spaces.GPU(duration=90)
@torch.inference_mode()
def generate_video(
    first_image,
    last_image,
    prompts: list[str],
    duration: float,
    enhance_prompt: bool = True,
    seed: int = 42,
    randomize_seed: bool = True,
    height: int = 320,
    width: int = 512,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    blur_amount: int = 0,
    remove_music: bool = False,
    add_subtitles: bool = False,
    audio_track = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one video clip per prompt within a single GPU allocation.

    Clip ``i`` uses seed ``base_seed + i``. The optional first/last images
    condition frame 0 and the final frame. Each clip can optionally be blurred
    and have its music removed (Demucs). If subtitles are requested and an
    audio track is given, that track is transcribed with Whisper in the same
    GPU session.

    ``negative_prompt`` is accepted but not forwarded to the pipeline call
    below. NOTE(review): confirm whether DistilledPipeline supports one.

    Returns:
        ``(clip_paths, subtitle_segments, seed_used)``. On any error the
        exception is printed and ``([], [], seed_used)`` is returned. If the
        failure happens before the seed is resolved, ``seed_used`` is the
        ``seed`` argument.
    """
    # Bound before ``try`` so the handler can always return a seed. An early
    # failure (e.g. in reset_peak_memory_stats) used to raise UnboundLocalError.
    base_seed = seed
    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        base_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        generated_clips = []

        # Frame count, conditioning images and tiling depend only on ``duration``
        # and the uploaded images, not on the prompt, so compute them once.
        frame_rate = DEFAULT_FRAME_RATE
        num_frames = int(duration * frame_rate) + 1
        # Round up so that (num_frames - 1) is a multiple of 8.
        num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1

        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        images = []
        if first_image is not None:
            first_path = _save_conditioning_image(first_image, output_dir / f"temp_first_{base_seed}.jpg")
            images.append(ImageConditioningInput(path=str(first_path), frame_idx=0, strength=1.0))
        if last_image is not None:
            last_path = _save_conditioning_image(last_image, output_dir / f"temp_last_{base_seed}.jpg")
            images.append(ImageConditioningInput(path=str(last_path), frame_idx=num_frames - 1, strength=1.0))

        from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number

        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        # Generate multiple clips in one GPU session (one per prompt)
        for clip_idx, prompt in enumerate(prompts):
            current_seed = base_seed + clip_idx
            print(f"[GPU] Generating clip {clip_idx + 1}/{len(prompts)}, prompt: {prompt[:50]}..., seed={current_seed}")
            print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")

            log_memory("before pipeline call")

            video_frames_iter, audio = pipeline(
                prompt=prompt,
                seed=current_seed,
                height=int(height),
                width=int(width),
                num_frames=num_frames,
                frame_rate=frame_rate,
                images=list(images),  # fresh list per call in case the pipeline mutates it
                enhance_prompt=enhance_prompt,
            )

            # Collect video frames
            frames = list(video_frames_iter)
            if not frames:
                raise RuntimeError("Pipeline returned no video frames")
            video_tensor = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]

            log_memory("after pipeline call")

            # Apply Gaussian blur if requested (for censoring/teaser effect)
            if blur_amount > 0:
                print(f"Applying Gaussian blur (amount={blur_amount})...")
                video_tensor = apply_gaussian_blur(video_tensor, blur_amount)
                log_memory("after blur")

            output_path = _new_temp_video_path()
            encode_video(
                video=video_tensor,
                fps=frame_rate,
                audio=audio,
                output_path=output_path,
                video_chunks_number=video_chunks_number,
            )

            log_memory("after encode_video")

            # Remove background music if requested
            if remove_music:
                print(f"Removing background music with Demucs...")
                processed_path = _new_temp_video_path()
                success = remove_music_demucs(output_path, processed_path)
                if success:
                    output_path = processed_path
                    log_memory("after demucs")
                else:
                    print(f"Warning: Music removal failed, using original video")
                    try:
                        os.remove(processed_path)
                    except OSError:
                        pass  # best-effort cleanup of the unused placeholder

            generated_clips.append(str(output_path))

        # Transcribe with Whisper if requested (still within GPU context)
        subtitle_segments = []
        if add_subtitles and audio_track:
            print("[GPU] Transcribing audio track with Whisper...")
            # Transcribe the audio track file, not the generated video (which has no audio yet)
            subtitle_segments = transcribe_with_whisper_gpu(audio_track, model_size="small")
            log_memory("after whisper")
        elif add_subtitles and not audio_track:
            print("[GPU] Warning: Subtitles requested but no audio track provided - skipping transcription")

        # Return all generated clips and subtitle segments
        return generated_clips, subtitle_segments, base_seed

    except Exception as e:
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return [], [], base_seed


def full_generation_process(
    first_image,
    last_image,
    prompt1: str,
    prompt2: str,
    prompt3: str,
    duration: float,
    enhance_prompt: bool,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    negative_prompt: str,
    blur_amount: int,
    remove_music: bool,
    add_subtitles: bool,
    watermark,
    audio_track,
    progress=gr.Progress(track_tqdm=True),
):
    """Main entry point: generate clips (GPU), then post-process them (CPU).

    Phase 1 calls ``generate_video`` to render one clip per non-empty prompt
    and, when subtitles are requested and an audio track is supplied, to
    transcribe that track. Phase 2 loops the clip(s) over the audio track when
    one is supplied. Phase 3 burns subtitles and/or a watermark into the video.

    Args:
        first_image: Optional first conditioning frame (from the UI image input).
        last_image: Optional last conditioning frame (from the UI image input).
        prompt1, prompt2, prompt3: Prompts; blank/whitespace-only ones are ignored.
        duration: Requested clip duration in seconds (forwarded to generation).
        enhance_prompt: Forwarded to ``generate_video``.
        seed: Seed to use unless ``randomize_seed`` is set.
        randomize_seed: Forwarded to ``generate_video``.
        height, width: Output size; also used to lay out the subtitle file.
        negative_prompt, blur_amount, remove_music: Forwarded to ``generate_video``.
        add_subtitles: Transcribe ``audio_track`` and burn the subtitles in.
        watermark: Optional watermark file from the UI file input.
        audio_track: Optional audio file path the clips are looped to match.
        progress: Gradio progress tracker forwarded to ``generate_video``.

    Returns:
        ``(video_path, seed)``. ``video_path`` is ``None`` when no prompt was
        given or generation produced no clips. ``seed`` is the seed reported
        by ``generate_video`` (or the input seed if generation never ran).
    """
    # Collect non-empty prompts (one clip is generated per prompt).
    prompts = [p.strip() for p in (prompt1, prompt2, prompt3) if p and p.strip()]
    if not prompts:
        return None, seed

    print(f"Generating {len(prompts)} clip(s)")

    # Phase 1: Generate clips + transcribe (GPU time counted)
    clips, subtitle_segments, final_seed = generate_video(
        first_image, last_image, prompts, duration, enhance_prompt,
        seed, randomize_seed, height, width, negative_prompt,
        blur_amount, remove_music, add_subtitles, audio_track, progress
    )

    # generate_video returns an empty clip list on failure.
    if not clips:
        return None, final_seed

    # Phase 2: CPU work (free) - loop clips with audio if provided.
    # This also applies to a single clip: the generated clips do not carry the
    # uploaded audio track, so skipping this step would drop the track (and
    # any subtitles transcribed from it would no longer match the audio).
    if audio_track:
        print("[CPU] Looping clips to match audio duration...")
        final_video = loop_clips_with_audio_track(clips, audio_track)
        if not final_video:
            # Defensive: fall back if the helper returns nothing.
            print("Warning: Looping clips with audio failed, using first clip")
            final_video = clips[0]
    else:
        # Without an audio track only the first clip is returned.
        final_video = clips[0]

    # Phase 3: CPU work (free) - add subtitles and/or watermark
    if add_subtitles or watermark:
        print("[CPU] Adding subtitles/watermark...")

        # Use subtitle segments from GPU transcription
        subtitle_file = None
        if add_subtitles and subtitle_segments:
            # NOTE(review): tempfile.mktemp is deprecated (name race); kept so
            # downstream writers/ffmpeg receive a path that does not exist yet.
            subtitle_file = tempfile.mktemp(suffix=".ass")
            create_beautiful_ass_subtitles(subtitle_segments, subtitle_file, int(width), int(height))
            subtitle_exists = os.path.exists(subtitle_file)
            print(f"[subtitles] Created subtitle file: {subtitle_file}, exists: {subtitle_exists}")
            if subtitle_exists:
                # getsize reports real bytes and avoids decoding the file with
                # the locale encoding (which can fail on non-ASCII transcripts).
                print(f"[subtitles] File size: {os.path.getsize(subtitle_file)} bytes")

        if subtitle_file or watermark:
            # Burn subtitles and/or watermark
            output_with_extras = tempfile.mktemp(suffix=".mp4")
            if burn_subtitles_and_watermark(final_video, output_with_extras, subtitle_file, watermark):
                final_video = output_with_extras
            else:
                print("Warning: Adding subtitles/watermark failed, using video without them")
        else:
            # Subtitles were requested but nothing was transcribed and there is
            # no watermark: skip a pointless re-encode.
            print("[CPU] No subtitle segments and no watermark - skipping burn step")

    return final_video, final_seed


# Build the Gradio UI. Components are laid out in creation order inside the
# nested `with` contexts, so the order of component constructors below is
# significant: inputs in the left column, the result video in the right one.
with gr.Blocks(title="Element-8 Video", delete_cache=(3600, 7200)) as demo:  # cleanup: check every 1h, delete files >2h old
    gr.Markdown("# Element-8: Fast Video Generation with Frame Conditioning")
    _intro_markdown = (
        "High quality video + audio generation with first and last frame conditioning. "
        "Pre-distilled LTX model for fast inference. "
        "[[code]](https://github.com/Lightricks/LTX-2)"
    )
    gr.Markdown(_intro_markdown)

    with gr.Row():
        # Left column: all generation inputs.
        with gr.Column():
            with gr.Row():
                first_image = gr.Image(type="pil", label="First Frame (Optional)")
                last_image = gr.Image(type="pil", label="Last Frame (Optional)")
            prompt1 = gr.Textbox(
                label="Prompt 1",
                lines=2,
                placeholder="First prompt (required)",
                value="Make this image come alive with cinematic motion, smooth animation",
            )
            prompt2 = gr.Textbox(
                label="Prompt 2 (Optional)",
                lines=2,
                placeholder="Second prompt (leave empty if not needed)",
                value="",
            )
            prompt3 = gr.Textbox(
                label="Prompt 3 (Optional)",
                lines=2,
                placeholder="Third prompt (leave empty if not needed)",
                value="",
            )
            duration = gr.Slider(
                label="Duration (seconds)",
                minimum=1.0,
                maximum=10.0,
                step=0.1,
                value=3.0,
            )
            audio_track = gr.Audio(
                label="Audio Track (Optional) - loops clips to match duration",
                sources=["upload"],
                type="filepath",
            )

            generate_btn = gr.Button("Generate Video", size="lg", variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=10)
                randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
                with gr.Row():
                    width = gr.Number(precision=0, label="Width", value=512)
                    height = gr.Number(precision=0, label="Height", value=320)
                with gr.Row():
                    enhance_prompt = gr.Checkbox(value=False, label="Enhance Prompt")
                    high_res = gr.Checkbox(value=False, label="High Resolution")
                with gr.Row():
                    blur_amount = gr.Number(precision=0, label="Blur (0=off, 36=heavy)", value=0)
                    remove_music = gr.Checkbox(value=False, label="Remove Music")
                with gr.Row():
                    add_subtitles = gr.Checkbox(value=False, label="Add Subtitles (Whisper)")
                    watermark = gr.File(
                        file_types=[".png"],
                        label="Watermark PNG (full-video size, position in your editor)",
                    )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    lines=3,
                    placeholder="What to avoid in the generated video...",
                    value=DEFAULT_NEGATIVE_PROMPT,
                )

        # Right column: the generated result.
        with gr.Column():
            output_video = gr.Video(autoplay=True, label="Generated Video")

    # Single clickable example; values map positionally onto `inputs` below.
    _pinkknit_prompt = (
        "The camera falls downward through darkness as if dropped into a tunnel. "
        "As it slows, five friends wearing pink knitted hats and sunglasses lean "
        "over and look down toward the camera with curious expressions. The lens "
        "has a strong fisheye effect, creating a circular frame around them. They "
        "crowd together closely, forming a symmetrical cluster while staring "
        "directly into the lens."
    )
    gr.Examples(
        examples=[
            [None, "pinkknit.jpg", _pinkknit_prompt, "", "", 3.0, False, 42, True, 1024, 1024],
        ],
        inputs=[
            first_image, last_image, prompt1, prompt2, prompt3, duration,
            enhance_prompt, seed, randomize_seed, height, width,
        ],
    )

    # Re-derive width/height whenever either frame or the high-res toggle
    # changes (registered in the same order as before).
    for _size_trigger, _size_fn in (
        (first_image, on_image_upload),
        (last_image, on_image_upload),
        (high_res, on_highres_toggle),
    ):
        _size_trigger.change(
            fn=_size_fn,
            inputs=[first_image, last_image, high_res],
            outputs=[width, height],
        )

    # Full pipeline: generation on GPU, looping/subtitles/watermark on CPU.
    _generation_inputs = [
        first_image, last_image, prompt1, prompt2, prompt3, duration, enhance_prompt,
        seed, randomize_seed, height, width, negative_prompt, blur_amount, remove_music,
        add_subtitles, watermark, audio_track,
    ]
    generate_btn.click(
        fn=full_generation_process,
        inputs=_generation_inputs,
        outputs=[output_video, seed],
    )


# Extra CSS passed to launch(): caps elements with the `.fillable` class
# (presumably Gradio's main container) at 1200px wide.
css = "\n.fillable{max-width: 1200px !important}\n"

if __name__ == "__main__":
    demo.launch(css=css, theme=gr.themes.Citrus())