File size: 3,332 Bytes
7653ce6
73d63fb
7653ce6
2f10316
5859b26
 
d5d2bf6
bcea283
d5d2bf6
5859b26
88bcca9
2f10316
73d63fb
d5d2bf6
 
5859b26
d5d2bf6
 
 
5859b26
bcea283
d5d2bf6
 
73d63fb
d5d2bf6
73d63fb
 
d5d2bf6
bcea283
d5d2bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcea283
d5d2bf6
b411129
5859b26
d5d2bf6
73d63fb
 
d5d2bf6
 
 
 
 
9fe1333
9bc2ffb
d5d2bf6
bcea283
73d63fb
b411129
bcea283
73d63fb
 
bcea283
d5d2bf6
b411129
fe205f5
 
d5d2bf6
5859b26
d5d2bf6
5859b26
fe205f5
bcea283
d5d2bf6
08ee949
fe205f5
bcea283
 
08ee949
 
bcea283
 
 
 
 
 
 
 
 
 
 
73d63fb
bcea283
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import subprocess
import sys
import zipfile

import gdown
import gradio as gr

# ------------------------------------------------------
# 1) DOWNLOAD SadTalker ZIP FROM GOOGLE DRIVE (ONLY ONCE)
# ------------------------------------------------------

# Google Drive file id used to build the download URL of the SadTalker archive.
SADTALKER_ZIP_ID = "1ERCCvqt2YTNiwfqY1N95aQIIBm3maApu"

# The extracted folder doubles as the "already installed" marker: the archive
# is only fetched when that folder is absent.
if not os.path.exists("SadTalker"):
    print("⬇️ Downloading SadTalker.zip...")
    gdown.download(f"https://drive.google.com/uc?id={SADTALKER_ZIP_ID}", "sad.zip", quiet=False)

    # Stop with an explicit message when nothing was downloaded, instead of
    # letting zipfile fail on a missing file a few lines further down.
    if not os.path.exists("sad.zip"):
        raise FileNotFoundError(
            "sad.zip was not downloaded from Google Drive; cannot install SadTalker."
        )

    print("📦 Extracting ZIP...")
    with zipfile.ZipFile("sad.zip", 'r') as zip_ref:
        zip_ref.extractall("./")

    # Rename folder if extracted as SadTalker-main
    if os.path.exists("SadTalker-main"):
        os.rename("SadTalker-main", "SadTalker")

# Only announce success when the folder is really there; otherwise the
# "ready" message would be printed right before os.chdir fails.
if not os.path.isdir("SadTalker"):
    raise FileNotFoundError(
        "The 'SadTalker' folder is missing after extraction; "
        "check the top-level folder name inside sad.zip."
    )

print("✔ SadTalker ready!")
# Everything below (checkpoints, inference.py, results) uses paths relative
# to the SadTalker folder.
os.chdir("SadTalker")

# ------------------------------------------------------
# 2) DOWNLOAD CHECKPOINT MODELS (ONLY IF MISSING)
# ------------------------------------------------------

def gd(file_id, out):
    """Download a Google Drive file to ``out`` unless it is already on disk.

    Args:
        file_id: Google Drive file id, inserted into the ``uc?id=`` URL.
        out: Destination path. An existing file at this path is left untouched.

    Returns:
        True when ``out`` exists after the call, False when the download did
        not produce the file.
    """
    if os.path.exists(out):
        print(f"✔ Already exists: {out}")
        return True

    # Make sure the destination folder exists before anything is written to it.
    parent = os.path.dirname(out)
    if parent:
        os.makedirs(parent, exist_ok=True)

    url = f"https://drive.google.com/uc?id={file_id}"
    print(f"⬇️ Downloading: {out}")
    gdown.download(url, out, quiet=False)

    # Report a download that left no file behind instead of ignoring it.
    if not os.path.exists(out):
        print(f"❌ Download failed: {out}")
        return False
    return True

# Folder that receives every model file listed below.
os.makedirs("checkpoints", exist_ok=True)

# (Google Drive file id, destination path) for each model file to fetch.
CHECKPOINTS = [
    ("1g3VIpU3yhpITMZtrWU2mbmyzLoF9K3Gz", "checkpoints/audio2exp_00300-model.pth"),
    ("1Jp4i_Qc-6qCms7v1kN61RE3qnT_9J8vg", "checkpoints/audio2pose_00140-model.pth"),
    ("1pTbKmpOeWRSA1NYQ3DnWe3W-9Qeyy7PK", "checkpoints/mapping_00109-model.pth.tar"),
    ("1QJins0e_hvvZM4d47gR9eoCxa0L4LpWj", "checkpoints/mapping_00229-model.pth.tar"),
    ("1caz3ESiy-aEzI0ptff8qpjcxdTf-L26E", "checkpoints/SadTalker_V0.0.2_256.safetensors"),
    ("12mj9EEs4KIS6-9FjF-PjEegF-TDKvvi6", "checkpoints/SadTalker_V0.0.2_512.safetensors"),
    ("1cag0u7e5RgdKoxBa7exY599VNGtYgNQb", "checkpoints/shape_predictor_68_face_landmarks.dat"),
]

for _file_id, _checkpoint_path in CHECKPOINTS:
    gd(_file_id, _checkpoint_path)

# Check the files on disk so the success message is only printed when it is
# true. Missing files are reported but do not stop the app from starting.
_missing_checkpoints = [path for _, path in CHECKPOINTS if not os.path.exists(path)]
if _missing_checkpoints:
    print("⚠️ Missing checkpoints: " + ", ".join(_missing_checkpoints))
else:
    print("✔ All checkpoints ready!")

# ------------------------------------------------------
# 3) GENERATE VIDEO (FAST + ALWAYS RETURNS OUTPUT)
# ------------------------------------------------------

def generate_video(image, audio):
    """Animate a face image with an audio track by running ``inference.py``.

    Args:
        image: Source face; written to ``input.png`` through its ``save``
            method (the UI is configured to supply a PIL image).
        audio: Path of the driving audio file, passed as ``--driven_audio``.

    Returns:
        Path of the most recently modified ``.mp4`` found under ``results``.

    Raises:
        gr.Error: If an input is missing, the inference process exits with a
            non-zero status, or no ``.mp4`` exists afterwards.
    """
    # Submitting the form with an empty field passes None; fail with a
    # readable message instead of an AttributeError on image.save.
    if image is None:
        raise gr.Error("Please provide a source image.")
    if not audio:
        raise gr.Error("Please provide an audio file.")

    image.save("input.png")

    # Run inference with the interpreter that runs this app, so it sees the
    # same installed packages; fall back to "python3" if it is unknown.
    interpreter = sys.executable or "python3"
    cmd = [
        interpreter, "inference.py",
        "--driven_audio", audio,
        "--source_image", "input.png",
        "--result_dir", "results",
        "--still",
        "--size", "512",
        "--enhancer", "None",
        "--expression_scale", "1.8",
        "--preprocess", "full"
    ]

    print("▶ Running inference...")
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        # Without this check a failed run would fall through and return a
        # stale video left in "results" by an earlier request.
        raise gr.Error(f"Inference failed (exit code {completed.returncode}).")

    # Find all mp4 in results
    mp4s = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk("results")
        for name in files
        if name.endswith(".mp4")
    ]

    if not mp4s:
        # The output component is a video, so a returned message string would
        # be interpreted as a file path; raise to show the message instead.
        raise gr.Error("❌ No video generated.")

    # Pick latest file
    latest = max(mp4s, key=os.path.getmtime)

    print("🎬 Video generated:", latest)
    return latest


# ------------------------------------------------------
# 4) GRADIO UI
# ------------------------------------------------------

# Keyword arguments for the interface, collected in one place: the image is
# requested as a PIL object and the audio as a file path, which is what
# generate_video receives; the function's return value feeds the video output.
interface_options = dict(
    fn=generate_video,
    inputs=[gr.Image(type="pil"), gr.Audio(type="filepath")],
    outputs=gr.Video(),
    title="SadTalker (Google Drive Version)",
    description="Fast loading + no duplicate downloads + auto video return",
)

demo = gr.Interface(**interface_options)

# Start the web UI.
demo.launch()