import gradio as gr
import subprocess
import os
import time
from pathlib import Path
from huggingface_hub import snapshot_download

MODEL_REPO = "hpcai-tech/Open-Sora-v2"
CKPT_DIR = Path("ckpts")
SAMPLES_DIR = Path("samples")
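
# Note: this script assumes it is launched from the root of the Open-Sora
# repository, so that scripts/diffusion/inference.py and the configs/ tree
# resolve as relative paths.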

def ensure_ckpts():
    if CKPT_DIR.exists() and any(CKPT_DIR.iterdir()):
        print("Found existing checkpoints in", CKPT_DIR)
        return True
    hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if not hf_token:
        print("No HF token found in env. Cannot auto-download. Please add HUGGINGFACE_HUB_TOKEN or download ckpts manually.")
        return False
    print("Downloading model weights from HF... (this will take several minutes)")
    try:
        snapshot_download(
            repo_id=MODEL_REPO,
            local_dir=str(CKPT_DIR),
            token=hf_token,  # pass the token explicitly instead of relying on env lookup
        )
        print("Download complete.")
        return True
    except Exception as e:
        print("Error downloading checkpoints:", e)
        return False
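
# Hedged alternative (illustrative only, not verified against the actual file
# layout of hpcai-tech/Open-Sora-v2): snapshot_download also accepts
# allow_patterns to fetch only a subset of the repo when disk space is tight,
# e.g.
#
#   snapshot_download(
#       repo_id=MODEL_REPO,
#       local_dir=str(CKPT_DIR),
#       allow_patterns=["*.safetensors", "*.json"],
#   )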

def find_latest_video():
    SAMPLES_DIR.mkdir(exist_ok=True)
    matches = list(SAMPLES_DIR.glob("*.mp4"))
    if not matches:
        return None
    matches.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return str(matches[0])

def run_torch_inference(config, prompt, ref_image=None, aspect_ratio=None, num_frames=None, offload=False):
    SAMPLES_DIR.mkdir(exist_ok=True)
    cmd = [
        "torchrun", "--nproc_per_node", "1", "--standalone",
        "scripts/diffusion/inference.py",
        f"configs/diffusion/inference/{config}.py",
        "--save-dir", str(SAMPLES_DIR),
        "--prompt", prompt
    ]
    if ref_image:
        cmd += ["--cond_type", "i2v_head", "--ref", ref_image]
    if aspect_ratio:
        cmd += ["--aspect_ratio", aspect_ratio]
    if num_frames:
        cmd += ["--num_frames", str(num_frames)]
    if offload:
        cmd += ["--offload", "True"]
    print("Running command:", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True, env=os.environ)
    except subprocess.CalledProcessError as e:
        print("Inference failed:", e)
        raise
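
# For reference, run_torch_inference builds the same single-GPU command line
# that the Open-Sora v2 README documents (paths assumed relative to the repo
# root), roughly:
#
#   torchrun --nproc_per_node 1 --standalone \
#       scripts/diffusion/inference.py \
#       configs/diffusion/inference/t2i2v_256px.py \
#       --save-dir samples --prompt "A cinematic shot of ..."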

def generate_video(prompt, mode="t2i2v_256px", ref_image_path=None, aspect_ratio="16:9", num_frames=None, offload=False):
    # mode selects the inference config: t2i2v_256px, 256px, 768px, or t2i2v_768px
    ok = ensure_ckpts()
    if not ok:
        return "Model checkpoints not found and no HF token provided. Upload ckpts to ./ckpts or set HUGGINGFACE_HUB_TOKEN."

    # Map UI selection to config file names
    config_map = {
        "256 (t2i2v)": "t2i2v_256px",
        "256 (t2v)": "256px",
        "768 (t2v)": "768px",
        "768 (t2i2v)": "t2i2v_768px"
    }
    config = config_map.get(mode, "t2i2v_256px")
    start_time = time.time()
    try:
        run_torch_inference(config, prompt, ref_image=ref_image_path, aspect_ratio=aspect_ratio, num_frames=num_frames, offload=offload)
        # torchrun blocks until inference finishes; poll briefly afterwards in
        # case the file is still being flushed, and ignore any videos left
        # over from earlier runs
        for _ in range(120):
            latest = find_latest_video()
            if latest and os.path.getmtime(latest) >= start_time:
                return latest
            time.sleep(1)
        return "No output video detected after inference."
    except Exception as e:
        return f"Error during generation: {str(e)}"
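
def is_valid_num_frames(n):
    # Hedged helper (an assumption based on the UI hint below, not verified
    # upstream): frame counts of the form 4k + 1 (17, 33, 65, ...) appear to
    # be expected, presumably so the temporal autoencoder stride divides
    # evenly. It could be wired into on_generate before launching inference.
    return n is None or (int(n) >= 1 and (int(n) - 1) % 4 == 0)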

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Open-Sora (Open-Sora-v2): Text/Image to Video")
    with gr.Row():
        prompt = gr.Textbox(lines=3, label="Prompt", placeholder="A cinematic shot of ...")
    with gr.Row():
        mode = gr.Radio(["256 (t2i2v)", "256 (t2v)", "768 (t2v)", "768 (t2i2v)"], value="256 (t2i2v)", label="Generation Mode")
        aspect_ratio = gr.Dropdown(["16:9","9:16","1:1","2.39:1"], value="16:9", label="Aspect Ratio")
        num_frames = gr.Number(value=17, label="Frames (must be 4k+1, e.g. 17, 33, 65)", precision=0)
    with gr.Row():
        ref_image = gr.Image(type="filepath", label="Reference image (optional, for I2V)")
        offload = gr.Checkbox(label="Memory offload (slower but uses less GPU memory)", value=False)
    generate_btn = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    status = gr.Textbox(label="Status/Logs", interactive=False)

    def on_generate(prompt_text, mode_val, ar, nf, ref_img, off):
        # Note: calling status.update(...) here would not push text to the UI
        # mid-handler, so progress is reported only via the return values.
        res = generate_video(prompt_text, mode_val, ref_image_path=ref_img, aspect_ratio=ar, num_frames=int(nf) if nf else None, offload=off)
        # generate_video returns a file path on success and an error message
        # otherwise; only feed real paths to the Video component
        if res and os.path.exists(str(res)):
            return res, f"Completed: {res}"
        return None, res

    generate_btn.click(on_generate, inputs=[prompt, mode, aspect_ratio, num_frames, ref_image, offload], outputs=[output_video, status])
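
# Note (assumption): generations can run for many minutes; enabling Gradio's
# built-in request queue helps avoid HTTP timeouts on long jobs, e.g.
#
#   demo.queue(max_size=4)
#
# called before demo.launch() below.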

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))