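"""Gradio front end for Open-Sora v2 text/image-to-video generation.

Checkpoints are fetched from the Hugging Face Hub on first use, and each
request shells out to the repository's torchrun-based inference script.
"""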
import gradio as gr
import subprocess
import os
import time
from pathlib import Path
from huggingface_hub import snapshot_download

MODEL_REPO = "hpcai-tech/Open-Sora-v2"
CKPT_DIR = Path("ckpts")
SAMPLES_DIR = Path("samples")
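
# The relative paths above and below assume the app runs from the root of an
# Open-Sora repo checkout, where scripts/diffusion/inference.py and its configs live.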
def ensure_ckpts():
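    """Return True if model checkpoints are available locally, downloading them from the Hub when possible."""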
    if CKPT_DIR.exists() and any(CKPT_DIR.iterdir()):
        print("Found existing checkpoints in", CKPT_DIR)
        return True
    hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if not hf_token:
        print("No HF token found in env. Cannot auto-download. Please add HUGGINGFACE_HUB_TOKEN or download ckpts manually.")
        return False
    print("Downloading model weights from HF... (this will take several minutes)")
    try:
        # Pass the token explicitly rather than relying on implicit env pickup.
        # local_dir_use_symlinks is deprecated in recent huggingface_hub releases but still accepted.
        snapshot_download(repo_id=MODEL_REPO, local_dir=str(CKPT_DIR), local_dir_use_symlinks=False, token=hf_token)
        print("Download complete.")
        return True
    except Exception as e:
        print("Error downloading checkpoints:", e)
        return False

def find_latest_video():
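    """Return the path of the most recently modified .mp4 in SAMPLES_DIR, or None if none exist."""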
    SAMPLES_DIR.mkdir(exist_ok=True)
    matches = list(SAMPLES_DIR.glob("*.mp4"))
    if not matches:
        return None
    matches.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return str(matches[0])

def run_torch_inference(config, prompt, ref_image=None, aspect_ratio=None, num_frames=None, offload=False):
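    """Invoke the Open-Sora inference script via single-process torchrun, raising on a non-zero exit."""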
    SAMPLES_DIR.mkdir(exist_ok=True)
    cmd = [
        "torchrun", "--nproc_per_node", "1", "--standalone",
        "scripts/diffusion/inference.py",
        f"configs/diffusion/inference/{config}.py",
        "--save-dir", str(SAMPLES_DIR),
        "--prompt", prompt
    ]
    if ref_image:
        cmd += ["--cond_type", "i2v_head", "--ref", ref_image]
    if aspect_ratio:
        cmd += ["--aspect_ratio", aspect_ratio]
    if num_frames:
        cmd += ["--num_frames", str(num_frames)]
    if offload:
        cmd += ["--offload", "True"]
    print("Running command:", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True, env=os.environ)
    except subprocess.CalledProcessError as e:
        print("Inference failed:", e)
        raise

def generate_video(prompt, mode="t2i2v_256px", ref_image_path=None, aspect_ratio="16:9", num_frames=None, offload=False):
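    """Generate a video for `prompt` and return the path to the new output .mp4, or an error message string."""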
    # mode options: t2i2v_256px, 256px (text2video direct), 768px...
    ok = ensure_ckpts()
    if not ok:
        return "Model checkpoints not found and no HF token provided. Upload ckpts to ./ckpts or set HUGGINGFACE_HUB_TOKEN."
    # Map UI selection to config file names
    config_map = {
        "256 (t2i2v)": "t2i2v_256px",
        "256 (t2v)": "256px",
        "768 (t2v)": "768px",
        "768 (t2i2v)": "t2i2v_768px"
    }
    config = config_map.get(mode, "t2i2v_256px")
    try:
        SAMPLES_DIR.mkdir(exist_ok=True)
        # Snapshot existing outputs so a stale video from an earlier run is not returned.
        existing = set(SAMPLES_DIR.glob("*.mp4"))
        run_torch_inference(config, prompt, ref_image=ref_image_path, aspect_ratio=aspect_ratio, num_frames=num_frames, offload=offload)
        # Wait for a new output file to appear.
        for _ in range(120):
            latest = find_latest_video()
            if latest and Path(latest) not in existing:
                return latest
            time.sleep(1)
        return "No output video detected after inference."
    except Exception as e:
        return f"Error during generation: {str(e)}"

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Open-Sora (Open-Sora-v2) – Text/Image to Video")
    with gr.Row():
        prompt = gr.Textbox(lines=3, label="Prompt", placeholder="A cinematic shot of ...")
    with gr.Row():
        mode = gr.Radio(["256 (t2i2v)", "256 (t2v)", "768 (t2v)", "768 (t2i2v)"], value="256 (t2i2v)", label="Generation Mode")
        aspect_ratio = gr.Dropdown(["16:9", "9:16", "1:1", "2.39:1"], value="16:9", label="Aspect Ratio")
        num_frames = gr.Number(value=17, label="Frames (use the 4k+1 rule, e.g. 17)", precision=0)
    with gr.Row():
        ref_image = gr.Image(type="filepath", label="Reference image (optional, for I2V)")
        offload = gr.Checkbox(label="Memory offload (slower but uses less GPU memory)", value=False)
    generate_btn = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    status = gr.Textbox(label="Status/Logs", interactive=False)
    def on_generate(prompt_text, mode_val, ar, nf, ref_img, off):
        res = generate_video(prompt_text, mode_val, ref_image_path=ref_img, aspect_ratio=ar, num_frames=int(nf) if nf else None, offload=off)
        # generate_video returns either a path to the rendered .mp4 or an error
        # message; only hand real files to the video component.
        if res and os.path.exists(str(res)):
            return res, f"Completed: {res}"
        return None, str(res)

    generate_btn.click(on_generate, inputs=[prompt, mode, aspect_ratio, num_frames, ref_image, offload], outputs=[output_video, status])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))