File size: 4,885 Bytes
4e424ea
0aa6485
4e424ea
711b244
 
4e424ea
711b244
ca5f9a7
4e424ea
711b244
 
 
 
4e424ea
502938a
 
4e424ea
80502d4
0aa6485
ca753f0
ca5f9a7
 
 
711b244
 
 
 
 
 
 
 
 
 
 
 
 
80502d4
4e424ea
80502d4
4e424ea
70af8f2
4e424ea
 
 
f0f4c78
89f1ae7
ca5f9a7
4e424ea
 
ca5f9a7
80502d4
ca5f9a7
 
 
 
 
 
 
e7dadb6
02b0954
ca5f9a7
 
02b0954
80502d4
 
02b0954
 
 
80502d4
02b0954
 
ca5f9a7
 
02b0954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca5f9a7
02b0954
 
632fdb4
02b0954
ca5f9a7
 
711b244
 
 
 
02b0954
 
 
 
632fdb4
f0f4c78
02b0954
80502d4
711b244
665534e
711b244
 
 
0aa6485
80502d4
 
 
 
 
 
 
 
4e424ea
80502d4
 
4e424ea
80502d4
4e424ea
 
80502d4
 
 
 
4e424ea
 
 
502938a
 
 
4e424ea
 
0aa6485
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import re
import subprocess
import sys
import time

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from tqdm import tqdm

# Force the device to CPU
# NOTE(review): `device` is never read anywhere in this file — generation runs
# in a subprocess (generate.py) that picks its own device; confirm it is safe
# to drop before removing.
device = torch.device("cpu")

# Download model weights once at startup into the local directory that
# generate.py's --ckpt_dir flag points at (see infer() below).
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="./Wan2.1-T2V-1.3B"
)
print("✅ Model downloaded successfully.")

def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* by running generate.py in a subprocess.

    The subprocess's stdout is parsed line by line to drive three tqdm bars,
    which Gradio mirrors into the UI via ``track_tqdm=True``: an overall bar
    over the pipeline steps, a per-step activity bar, and the sampling-loop
    bar emitted by the subprocess itself.

    Args:
        prompt: Text prompt forwarded to generate.py via --prompt.
        progress: Gradio progress tracker (must stay in the signature so
            Gradio hooks tqdm; not referenced directly).

    Returns:
        Path to the generated MP4 file ("generated_video.mp4").

    Raises:
        gr.Error: if the subprocess exits non-zero or the output is missing.
    """
    total_process_steps = 11
    irrelevant_steps = 4  # early INFO lines (setup/logging) not shown as steps
    relevant_steps = total_process_steps - irrelevant_steps

    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    processed_steps = 0

    # Matches tqdm-style lines from the subprocess, e.g. " 40%|####  | 20/50".
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    video_progress_bar = None

    sub_bar = None          # activity bar for the pipeline step in flight
    sub_tick_total = 1500   # nominal size of the per-step activity bar

    def finish_sub_bar():
        """Fill, close and discard the current step bar; advance the overall bar."""
        nonlocal sub_bar
        if sub_bar is not None:
            sub_bar.update(sub_tick_total - sub_bar.n)
            sub_bar.close()
            overall_bar.update(1)
            sub_bar = None

    # Use the interpreter running this script rather than a bare "python",
    # which may resolve to a different environment on the host.
    command = [
        sys.executable, "-u", "generate.py",  # generate.py must sit in the CWD
        "--task", "t2v-1.3B",
        "--size", "480*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--t5_cpu",
        "--offload_model", "True",
        "--save_file", "generated_video.mp4"
    ]

    print("🚀 Starting video generation process...")
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr so INFO lines are seen too
        text=True,
        bufsize=1  # line-buffered so progress lines arrive promptly
    )

    for line in iter(process.stdout.readline, ''):
        stripped_line = line.strip()
        print(f"[SUBPROCESS]: {stripped_line}")  # Debug print

        if not stripped_line:
            continue

        # Sampling-loop progress (tqdm output re-emitted by the subprocess).
        progress_match = progress_pattern.search(stripped_line)
        if progress_match:
            # Any per-step bar still running ends when sampling output starts.
            finish_sub_bar()
            current = int(progress_match.group(2))
            total = int(progress_match.group(3))
            if video_progress_bar is None:
                video_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                          ncols=120, dynamic_ncols=True, leave=True)
            # Jump to the absolute position reported by the subprocess.
            video_progress_bar.update(current - video_progress_bar.n)
            if video_progress_bar.n >= video_progress_bar.total:
                overall_bar.update(1)
                video_progress_bar.close()
                video_progress_bar = None
            continue

        # Each INFO line marks the start of a new pipeline step.
        if "INFO:" in stripped_line:
            parts = stripped_line.split("INFO:", 1)
            msg = parts[1].strip() if len(parts) > 1 else ""
            print(f"[INFO]: {msg}")
            if processed_steps < irrelevant_steps:
                # Skip the early setup chatter.
                processed_steps += 1
            else:
                finish_sub_bar()
                sub_bar = tqdm(total=sub_tick_total, desc=msg, position=2,
                               ncols=120, dynamic_ncols=False, leave=True)
            continue

    process.stdout.close()  # release the pipe once output is exhausted
    process.wait()

    # Final cleanup of any bars still open.
    if video_progress_bar is not None:
        video_progress_bar.close()
    if sub_bar is not None:
        sub_bar.close()
    overall_bar.close()

    # Success requires both a zero exit code and the output file on disk.
    if process.returncode == 0:
        if os.path.exists("generated_video.mp4"):
            print("✅ Video generation completed successfully.")
            return "generated_video.mp4"
        print("❌ Video generation finished but output file is missing.")
        raise gr.Error("Output video not found after generation.")
    print("❌ Subprocess failed.")
    raise gr.Error("Video generation failed. Check logs above.")

# ✅ Gradio UI — single-column layout wired to infer()
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1 1.3B - Text to Video")
        gr.Markdown("Generate short videos from prompts. Duplicate this space to avoid queue limits.")
        prompt_box = gr.Textbox(label="Enter your prompt")
        generate_button = gr.Button("Generate Video")
        output_video = gr.Video(label="Generated Video")

    # Clicking the button runs inference and sends the result to the player.
    generate_button.click(fn=infer, inputs=[prompt_box], outputs=[output_video])

# Queue requests so concurrent users are serialized; surface errors in the UI.
demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)