"""Gradio front-end for Open-Sora v2 text/image-to-video inference.

The app shells out to Open-Sora's ``scripts/diffusion/inference.py`` through
``torchrun`` and then shows the newest ``.mp4`` written under ``samples/``
during that run.
"""

import gradio as gr
import subprocess
import os
import glob
import shlex
import time
from pathlib import Path
from huggingface_hub import snapshot_download

MODEL_REPO = "hpcai-tech/Open-Sora-v2"
CKPT_DIR = Path("ckpts")
SAMPLES_DIR = Path("samples")

# UI label -> inference config name (configs/diffusion/inference/<name>.py).
# Insertion order is also the order of the choices shown in the UI.
CONFIG_MAP = {
    "256 (t2i2v)": "t2i2v_256px",
    "256 (t2v)": "256px",
    "768 (t2v)": "768px",
    "768 (t2i2v)": "t2i2v_768px",
}
DEFAULT_CONFIG = "t2i2v_256px"

# Seconds of tolerance when comparing output mtimes with the run start time,
# so coarse filesystem timestamp resolution cannot hide a freshly written video.
_MTIME_SLACK_S = 2.0

# Number of one-second polls for the output video after inference returns.
_OUTPUT_WAIT_POLLS = 120


def ensure_ckpts():
    """Make sure model checkpoints are present in ``CKPT_DIR``.

    If ``CKPT_DIR`` already exists and is non-empty it is trusted as-is.
    Otherwise the weights for ``MODEL_REPO`` are downloaded from the Hugging
    Face Hub, which requires ``HUGGINGFACE_HUB_TOKEN`` or ``HF_TOKEN`` to be set.

    Returns:
        True if checkpoints are available afterwards, False otherwise.
    """
    if CKPT_DIR.exists() and any(CKPT_DIR.iterdir()):
        print("Found existing checkpoints in", CKPT_DIR)
        return True

    hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if not hf_token:
        print("No HF token found in env. Cannot auto-download. Please add HUGGINGFACE_HUB_TOKEN or download ckpts manually.")
        return False

    print("Downloading model weights from HF... (this will take several minutes)")
    try:
        # Pass the token explicitly: huggingface_hub only picks up HF_TOKEN
        # (and the legacy HUGGING_FACE_HUB_TOKEN) from the environment by
        # itself, so a token found under HUGGINGFACE_HUB_TOKEN would otherwise
        # be ignored.
        snapshot_download(
            repo_id=MODEL_REPO,
            local_dir=str(CKPT_DIR),
            local_dir_use_symlinks=False,
            token=hf_token,
        )
        print("Download complete.")
        return True
    except Exception as e:
        # Best-effort download: report and let the caller show a message.
        print("Error downloading checkpoints:", e)
        return False


def find_latest_video(since=None):
    """Return the path of the newest ``.mp4`` under ``SAMPLES_DIR``, or None.

    ``SAMPLES_DIR`` is created if missing and searched recursively, so videos
    written into sub-directories of the save dir are found as well.

    Args:
        since: optional UNIX timestamp. When given, only videos modified at or
            after ``since`` (minus a small slack) are considered, so a video
            left over from an earlier run is not mistaken for a new result.

    Returns:
        The newest matching video path as a string, or None if there is none.
    """
    SAMPLES_DIR.mkdir(exist_ok=True)
    candidates = []
    for path in SAMPLES_DIR.rglob("*.mp4"):
        try:
            mtime = path.stat().st_mtime
        except OSError:
            # File disappeared between listing and stat; just skip it.
            continue
        if since is None or mtime >= since - _MTIME_SLACK_S:
            candidates.append((mtime, path))
    if not candidates:
        return None
    _, newest = max(candidates, key=lambda item: item[0])
    return str(newest)


def run_torch_inference(config, prompt, ref_image=None, aspect_ratio=None, num_frames=None, offload=False):
    """Run one Open-Sora inference job via single-process ``torchrun``.

    Args:
        config: config name, resolved to
            ``configs/diffusion/inference/{config}.py``.
        prompt: text prompt passed as ``--prompt``.
        ref_image: optional reference image path; when set, adds
            ``--cond_type i2v_head --ref <path>``.
        aspect_ratio: optional value for ``--aspect_ratio`` (omitted if falsy).
        num_frames: optional value for ``--num_frames`` (omitted if falsy).
        offload: when True, adds ``--offload True``.

    Raises:
        subprocess.CalledProcessError: if the inference process exits non-zero.
        OSError: if ``torchrun`` cannot be launched.
    """
    SAMPLES_DIR.mkdir(exist_ok=True)
    # Argument list with no shell, so the user-supplied prompt is never
    # interpreted by a shell.
    cmd = [
        "torchrun", "--nproc_per_node", "1", "--standalone",
        "scripts/diffusion/inference.py",
        f"configs/diffusion/inference/{config}.py",
        "--save-dir", str(SAMPLES_DIR),
        "--prompt", prompt,
    ]
    if ref_image:
        cmd += ["--cond_type", "i2v_head", "--ref", ref_image]
    if aspect_ratio:
        cmd += ["--aspect_ratio", aspect_ratio]
    if num_frames:
        cmd += ["--num_frames", str(num_frames)]
    if offload:
        cmd += ["--offload", "True"]

    # shlex.join quotes multi-word arguments, so the log line is copy-pasteable.
    print("Running command:", shlex.join(cmd))
    try:
        subprocess.run(cmd, check=True, env=os.environ)
    except subprocess.CalledProcessError as e:
        print("Inference failed:", e)
        raise


def _resolve_config(mode):
    """Map a UI label or a raw config name to an inference config name.

    Unknown values fall back to ``DEFAULT_CONFIG``.
    """
    if mode in CONFIG_MAP:
        return CONFIG_MAP[mode]
    if mode in CONFIG_MAP.values():
        return mode
    return DEFAULT_CONFIG


def generate_video(prompt, mode="t2i2v_256px", ref_image_path=None, aspect_ratio="16:9", num_frames=None, offload=False):
    """Generate a video and return its path, or a human-readable message.

    Args:
        prompt: text prompt; must be non-empty.
        mode: either a UI label (e.g. ``"256 (t2i2v)"``) or a config name
            (``t2i2v_256px``, ``256px``, ``768px``, ``t2i2v_768px``).
            Unknown values fall back to ``t2i2v_256px``.
        ref_image_path: optional reference image for image-to-video.
        aspect_ratio: aspect ratio string forwarded to inference.
        num_frames: optional frame count forwarded to inference.
        offload: enable memory offload in the inference script.

    Returns:
        The path of a video produced by this run. If no video is produced,
        returns a status/error message string instead; callers can tell the
        two apart by checking whether the result is an existing file.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    ok = ensure_ckpts()
    if not ok:
        return "Model checkpoints not found and no HF token provided. Upload ckpts to ./ckpts or set HUGGINGFACE_HUB_TOKEN."

    config = _resolve_config(mode)
    # Only videos written after this moment count as output of this run.
    started_at = time.time()
    try:
        run_torch_inference(config, prompt, ref_image=ref_image_path, aspect_ratio=aspect_ratio, num_frames=num_frames, offload=offload)
        # Inference is synchronous; poll briefly in case the file system is slow
        # to expose the finished file.
        for _ in range(_OUTPUT_WAIT_POLLS):
            latest = find_latest_video(since=started_at)
            if latest:
                return latest
            time.sleep(1)
        return "No output video detected after inference."
    except Exception as e:
        return f"Error during generation: {str(e)}"


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Open-Sora (Open-Sora-v2) — Text/Image to Video")
    with gr.Row():
        prompt = gr.Textbox(lines=3, label="Prompt", placeholder="A cinematic shot of ...")
    with gr.Row():
        mode = gr.Radio(list(CONFIG_MAP), value="256 (t2i2v)", label="Generation Mode")
        aspect_ratio = gr.Dropdown(["16:9", "9:16", "1:1", "2.39:1"], value="16:9", label="Aspect Ratio")
        num_frames = gr.Number(value=17, label="Frames (use 4k+1 rules)", precision=0)
    with gr.Row():
        ref_image = gr.Image(type="filepath", label="Reference image (optional, for I2V)")
        offload = gr.Checkbox(label="Memory offload (slower but uses less GPU memory)", value=False)
    generate_btn = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    status = gr.Textbox(label="Status/Logs", interactive=False)

    def on_generate(prompt_text, mode_val, ar, nf, ref_img, off):
        """Click handler: run generation and route the result to the outputs.

        Returns a ``(video, status)`` pair. The video output only receives a
        path that exists on disk. On failure it gets None, and the message is
        shown in the status box instead of being handed to the video player.
        """
        # Component values reach the browser only through this return value,
        # so there is no in-handler status.update(...) call here.
        res = generate_video(prompt_text, mode_val, ref_image_path=ref_img, aspect_ratio=ar, num_frames=int(nf) if nf else None, offload=off)
        if os.path.isfile(res):
            return res, f"Completed: {res}"
        return None, res

    generate_btn.click(on_generate, inputs=[prompt, mode, aspect_ratio, num_frames, ref_image, offload], outputs=[output_video, status])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))