| | |
| | import streamlit as st |
| | import torch |
| | from diffusers import CogVideoXPipeline |
| | from diffusers.utils import export_to_video |
| | from pathlib import Path |
| |
|
# Global page configuration — must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Text-to-Video Generator",
    page_icon="๐ฌ",
    layout="wide",
)

# Scratch directory for rendered videos; created up front so later writes
# cannot fail on a missing directory.
TEMP_DIR = Path("/tmp") / "video_gen"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
| |
|
@st.cache_resource
def load_model(model_name):
    """Load a CogVideoX pipeline and cache it for the app's lifetime.

    Args:
        model_name: Hugging Face model id (e.g. "THUDM/CogVideoX-5b").

    Returns:
        The loaded pipeline, or None if loading failed (an error message is
        shown in the UI in that case).
    """
    try:
        with st.spinner(f"Loading {model_name}... First load may take 10-15 minutes."):
            pipe = CogVideoXPipeline.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                cache_dir="/tmp/huggingface_cache"
            )

            if torch.cuda.is_available():
                # Offload idle weights to CPU and slice/tile the VAE so the
                # model fits in limited VRAM.
                pipe.enable_model_cpu_offload()
                pipe.enable_vae_slicing()
                pipe.enable_vae_tiling()

            return pipe
    except Exception as e:
        # BUG FIX: without clearing, st.cache_resource would cache the None
        # failure value, so every subsequent call (e.g. after a transient
        # network error) would return None without ever retrying the load.
        load_model.clear()
        st.error(f"Error loading model: {str(e)}")
        return None
| |
|
| | |
# ---- Page header ----
st.title("๐ฌ AI Video Generator")
st.markdown("Generate videos from text descriptions")


# ---- Sidebar: model choice and generation parameters ----
with st.sidebar:
    st.header("๐ค Model Selection")

    # Display label -> Hugging Face model id.
    model_options = {
        "CogVideoX-5B (Recommended)": "THUDM/CogVideoX-5b",
        "CogVideoX-2B (Faster)": "THUDM/CogVideoX-2b",
        "CogVideoX-5B-I2V (Image to Video)": "THUDM/CogVideoX-5b-I2V",
    }

    selected_model_name = st.selectbox(
        "Choose Model",
        options=list(model_options.keys()),
        index=0,
    )
    selected_model = model_options[selected_model_name]

    st.divider()
    st.header("โ๏ธ Generation Settings")

    num_frames = st.slider(
        "Frames", min_value=16, max_value=49, value=49,
        help="More frames = longer video",
    )
    num_inference_steps = st.slider(
        "Quality Steps", min_value=20, max_value=100, value=50,
        help="More steps = better quality",
    )
    guidance_scale = st.slider(
        "Guidance Scale", min_value=1.0, max_value=20.0, value=6.0, step=0.5,
    )
    seed = st.number_input("Seed (-1 = random)", value=-1, step=1)
    fps = st.slider("FPS", min_value=4, max_value=16, value=8)
| |
|
| | |
# ---- Environment check: this app is CUDA-only ----
col1, col2 = st.columns(2)
with col1:
    if torch.cuda.is_available():
        # BUG FIX: this literal contained a raw control character (the tail
        # bytes of a mojibake'd checkmark emoji) that split the string across
        # two lines — a syntax error. Restored as a plain "✅".
        st.success(f"✅ GPU: {torch.cuda.get_device_name(0)}")
        vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
        st.info(f"๐พ VRAM: {vram:.1f} GB")

        if vram < 16:
            # 5B model in fp16 plus activations is tight below 16 GB.
            st.warning("โ ๏ธ Low VRAM. Use CogVideoX-2B or reduce frames.")
    else:
        st.error("โ No GPU found! This requires CUDA GPU.")
        st.stop()  # halt this script run — nothing below works without CUDA


with col2:
    st.info(f"๐ฅ PyTorch: {torch.__version__}")
    st.info(f"๐ฎ CUDA: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")
| |
|
| | |
# ---- Prompt input ----
st.subheader("๐ Your Video Description")

prompt = st.text_area(
    "Describe your video",
    value="A cat walking through a beautiful garden, photorealistic, 4k, cinematic",
    height=120,
    help="Be descriptive for best results",
)


# Clickable example prompts. Clicking one stashes it in session state and
# reruns the script so the stashed text can override the text-area value.
with st.expander("๐ก Example Prompts"):
    examples = [
        "A cat walking on the grass, realistic style, high quality",
        "A panda playing guitar in a bamboo forest, cinematic lighting",
        "Waves crashing on a beach at sunset, aerial view, 4k",
        "A futuristic car driving through a neon city at night",
        "Flowers blooming in time-lapse, macro photography, vibrant colors",
    ]

    example_cols = st.columns(2)
    for idx, sample in enumerate(examples):
        target_col = example_cols[idx % 2]
        if target_col.button(f"๐ {sample[:30]}...", key=f"ex_{idx}", use_container_width=True):
            st.session_state.selected_prompt = sample
            st.rerun()


# One-shot override: consume the selected example, then forget it so the
# next rerun falls back to the text-area contents.
if 'selected_prompt' in st.session_state:
    prompt = st.session_state.selected_prompt
    del st.session_state.selected_prompt
| |
|
| | |
# ---- Generation flow: validate, load model, run pipeline, show results ----
if st.button("๐ฌ Generate Video", type="primary", use_container_width=True):

    # Guard: a blank/whitespace-only prompt produces nothing useful.
    if not prompt.strip():
        st.error("Please enter a prompt!")
        st.stop()

    pipe = load_model(selected_model)

    if pipe is None:
        st.error("Failed to load model!")
        st.stop()

    progress_bar = st.progress(0)
    status = st.empty()

    try:
        status.info("๐จ Preparing generation...")
        progress_bar.progress(10)

        # A fixed seed makes the run reproducible; -1 means "let torch pick".
        if seed != -1:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))
        else:
            generator = None

        status.info(f"๐ฌ Generating {num_frames} frames... This may take 5-10 minutes...")
        progress_bar.progress(20)

        # inference_mode disables autograd bookkeeping for speed and memory.
        with torch.inference_mode():
            video_frames = pipe(
                prompt=prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            ).frames[0]

        progress_bar.progress(80)
        status.info("๐พ Encoding video...")

        # NOTE(review): str hash() is salted per interpreter process, so this
        # name is only stable within one app run — fine for a /tmp scratch
        # file, but switch to hashlib if cross-run stability is ever needed.
        output_path = TEMP_DIR / f"video_{abs(hash(prompt + str(seed)))}.mp4"
        export_to_video(video_frames, str(output_path), fps=fps)

        progress_bar.progress(100)
        # BUG FIX: this literal contained a raw control character (the tail
        # bytes of a mojibake'd checkmark emoji) that split the string across
        # two lines — a syntax error. Restored as a plain "✅".
        status.success("✅ Generation complete!")

        st.balloons()

        col1, col2 = st.columns([3, 1])

        with col1:
            st.video(str(output_path))

        with col2:
            st.markdown("### ๐ Info")
            st.write(f"**Frames:** {num_frames}")
            st.write(f"**Steps:** {num_inference_steps}")
            st.write(f"**Guidance:** {guidance_scale}")
            st.write(f"**Seed:** {seed if seed != -1 else 'Random'}")
            st.write(f"**FPS:** {fps}")
            st.write(f"**Duration:** ~{num_frames/fps:.1f}s")

        # Streamlit streams the open file handle into the download button.
        with open(output_path, "rb") as f:
            st.download_button(
                "๐ฅ Download MP4",
                f,
                file_name=f"generated_{seed if seed != -1 else 'random'}.mp4",
                mime="video/mp4",
                use_container_width=True
            )

        with st.expander("๐ Generation Details"):
            st.write(f"**Model:** {selected_model_name}")
            st.write(f"**Prompt:** {prompt}")

    except torch.cuda.OutOfMemoryError:
        # OOM is the common failure mode here; give actionable remediation.
        progress_bar.empty()
        status.empty()
        st.error("โ Out of GPU memory! Try:\n- Reduce number of frames\n- Use CogVideoX-2B model\n- Lower inference steps")
    except Exception as e:
        # Catch-all UI boundary: surface the traceback instead of a blank page.
        progress_bar.empty()
        status.empty()
        st.error(f"โ Error: {str(e)}")
        st.exception(e)
| |
|
| | |
# ---- Footer ----
st.markdown("---")
_FOOTER_HTML = """
    <div style='text-align: center; color: #666;'>
        <p>Built with Streamlit โข Powered by CogVideoX & Hugging Face ๐ค</p>
    </div>
    """
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)