# (Hugging Face Spaces status banner — "Spaces: Sleeping" — scraper artifact, not source code)
| # --- START OF FILE media.py (FINAL WITH LIVE PROGRESS & FIXES) --- | |
| # --- LIBRARIES --- | |
| import torch | |
| import gradio as gr | |
| import random | |
| import time | |
| from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline, EulerAncestralDiscreteScheduler | |
| import gc | |
| import os | |
| import imageio | |
| import numpy as np | |
| import threading | |
| from queue import Queue, Empty as QueueEmpty | |
| from PIL import Image | |
| import os | |
| from huggingface_hub import login | |
# --- DYNAMIC HARDWARE DETECTION ---
# Pick the compute device once at import time. fp16 halves VRAM use on GPU;
# CPU inference uses fp32 (half precision on CPU is unsupported/slow here).
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
    print("✅ GPU detected. Using CUDA.")
else:
    device = "cpu"
    torch_dtype = torch.float32
    print("⚠️ No GPU detected. Using CPU.")

# --- OPTIONAL HUGGING FACE AUTHENTICATION ---
# A token lets the app download gated model checkpoints. A missing or bad
# token is non-fatal: public models still work, so we only log and continue.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    print("✅ Found HF_TOKEN secret. Logging in...")
    try:
        login(token=HF_TOKEN)
        print("✅ Hugging Face Authentication successful.")
    except Exception as e:
        # Keep the app usable even if authentication fails.
        print(f"❌ Hugging Face login failed: {e}")
else:
    # This message will show when you run the app locally, which is fine.
    print("⚠️ No HF_TOKEN secret found. This is normal for local testing.")
    print(" The deployed app will use the secret you set on Hugging Face.")
# --- CONFIGURATION & STATE ---
# Display name -> Hugging Face model id. NOTE: the substring "Video" (and
# "Turbo") in a display name is used elsewhere to branch between the image
# and video pipelines and to adjust UI defaults — keep that when adding models.
available_models = {
    "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
    "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
    "Photorealism (Juggernaut)": "RunDiffusion/Juggernaut-XL-v9",
    "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b"
}
# Only one pipeline is kept resident at a time; switching models releases the
# previous one to keep VRAM usage bounded.
model_state = { "current_pipe": None, "loaded_model_name": None }
# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS & FIXES ---
def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
    """Load the requested model (if not already resident) and generate media.

    This is a Gradio generator callback: it yields dicts keyed by the global
    UI components (output_image, output_video, status_textbox) so the UI can
    show live progress while generation runs.

    Args:
        model_key: Display name from ``available_models`` (substring "Video"
            selects the video pipeline).
        prompt / negative_prompt: Text conditioning.
        steps, cfg_scale, width, height: Sampler settings (may arrive as
            floats from Gradio sliders; cast before use).
        seed: RNG seed. May arrive as a float from ``gr.Number`` — cast to int.
        num_frames: Frame count for video generation only.
    """
    global model_state

    # --- Model Loading & Cleanup ---
    if model_state.get("loaded_model_name") != model_key:
        yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
        # Aggressively release the previous pipeline: move it off the GPU,
        # drop the reference, then force GC and reclaim the CUDA cache.
        pipe_to_delete = model_state.pop("current_pipe", None)
        if pipe_to_delete:
            print("Offloading previous model to CPU...")
            pipe_to_delete.to("cpu")
            del pipe_to_delete
            print("Previous model deleted.")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Load the new pipeline.
        # FIX: request the "fp16" weight variant only when we actually run in
        # float16 — on CPU (fp32) that variant may not exist for a checkpoint
        # and would mismatch the compute dtype anyway.
        model_id = available_models[model_key]
        load_kwargs = {"torch_dtype": torch_dtype}
        if torch_dtype == torch.float16:
            load_kwargs["variant"] = "fp16"
        if "Video" in model_key:
            pipe = TextToVideoSDPipeline.from_pretrained(model_id, **load_kwargs)
        else:
            pipe = AutoPipelineForText2Image.from_pretrained(model_id, **load_kwargs)
        # NOTE(review): the original flat source applied this scheduler swap
        # unconditionally (image AND video pipelines); preserved as-is.
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

        if device == "cuda" and "Video" not in model_key:
            # FIX: model CPU offload manages device placement itself; calling
            # .to("cuda") first would pin the full model in VRAM and defeat
            # the point of offloading. VAE slicing further trims peak memory.
            pipe.enable_model_cpu_offload()
            pipe.enable_vae_slicing()
        else:
            pipe.to(device)

        model_state["current_pipe"] = pipe
        model_state["loaded_model_name"] = model_key
        print(f"✅ Model '{model_key}' loaded on {device.upper()}.")

    pipe = model_state["current_pipe"]
    # FIX: gr.Number delivers floats; torch.Generator.manual_seed requires an
    # int, so the original raised TypeError at runtime.
    seed = int(seed)
    generator = torch.Generator(device).manual_seed(seed)

    # --- Generation Logic ---
    if "Video" in model_key:
        yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
        try:
            video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
            # FIX: tolerate both diffusers frame formats — an unbatched float
            # array in [0, 1] (older releases) and a batched and/or uint8
            # array (newer releases).
            frames = np.asarray(video_frames)
            if frames.ndim == 5:  # (batch, frames, H, W, C): take first batch
                frames = frames[0]
            if frames.dtype != np.uint8:
                frames = (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8)
            video_path = f"video_{seed}.mp4"
            with imageio.get_writer(video_path, fps=12) as writer:
                for frame in frames:
                    writer.append_data(frame)
            yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
        except Exception as e:
            print(f"An error occurred during video generation: {e}")
            yield {status_textbox: f"Error during video generation: {e}"}
    else:  # Image generation, with live progress reported by a worker thread.
        progress_queue = Queue()

        def run_pipe():
            # Runs the blocking diffusion call off the main (yielding) thread
            # and streams per-step progress through the queue.
            start_time = time.time()

            def progress_callback(pipe, step, timestep, callback_kwargs):
                elapsed_time = time.time() - start_time
                if elapsed_time > 0:
                    its_per_sec = (step + 1) / elapsed_time
                    progress_queue.put(("progress", (step + 1, its_per_sec)))
                return callback_kwargs

            try:
                final_image = pipe(
                    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=int(steps),
                    guidance_scale=float(cfg_scale), width=int(width), height=int(height),
                    generator=generator,
                    callback_on_step_end=progress_callback
                ).images[0]
                progress_queue.put(("result", final_image))
            except Exception as e:
                print(f"An error occurred in the generation thread: {e}")
                progress_queue.put(("error", str(e)))

        thread = threading.Thread(target=run_pipe)
        thread.start()
        total_steps = int(steps)
        yield {status_textbox: "Generating..."}
        while True:
            try:
                update_type, payload = progress_queue.get(timeout=1.0)
                if update_type == "result":
                    yield {output_image: payload, status_textbox: f"Generation complete! Seed: {seed}"}
                    break
                elif update_type == "progress":
                    current_step, its_per_sec = payload
                    progress_percent = (current_step / total_steps) * 100
                    steps_remaining = total_steps - current_step
                    eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
                    eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
                    status_text = (
                        f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
                        f"{its_per_sec:.2f}it/s | "
                        f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
                    )
                    yield {status_textbox: status_text}
                elif update_type == "error":
                    yield {status_textbox: f"Error: {payload}. Check console."}
                    break
            except QueueEmpty:
                # No message within the timeout: check whether the worker died
                # without posting anything (e.g. a hard crash / OOM kill).
                if not thread.is_alive():
                    # FIX: drain any message posted between the timeout and
                    # this liveness check before declaring failure.
                    if not progress_queue.empty():
                        continue
                    print("⚠️ Generation thread finished unexpectedly.")
                    yield {status_textbox: "Generation failed. Check console for details."}
                    break
        thread.join()
        print("Generation thread joined.")
# --- GRADIO UI ---
# Declarative layout: controls on the left, results + status on the right.
# Component variables (output_image, output_video, status_textbox, ...) are
# module-level on purpose — the generation callback yields dicts keyed by them.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("# The Generative Media Suite")
    gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
    # Holds the resolved seed (random seeds are materialized before generation
    # so the final value can be reported back to the user).
    seed_state = gr.State(-1)
    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
            prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text, overblown, high contrast, not photorealistic")
            with gr.Accordion("Settings", open=True):
                steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
                cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
                with gr.Row():
                    width_slider = gr.Slider(256, 1024, 768, step=64, label="Width")
                    height_slider = gr.Slider(256, 1024, 768, step=64, label="Height")
                # Hidden until a video model is selected (see update_ui_on_model_change).
                num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=False)
                seed_input = gr.Number(-1, label="Seed (-1 for random)")
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
            output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
            status_textbox = gr.Textbox(label="Status", interactive=False)

    def update_ui_on_model_change(model_key):
        # Adapt the controls to the selected model: video models swap the
        # image output/size sliders for the video output/frame slider; Turbo
        # models lock steps/CFG to the values the model was distilled for.
        is_video = "Video" in model_key
        is_turbo = "Turbo" in model_key
        return {
            steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
            cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
            width_slider: gr.update(visible=not is_video),
            height_slider: gr.update(visible=not is_video),
            num_frames_slider: gr.update(visible=is_video),
            output_image: gr.update(visible=not is_video),
            output_video: gr.update(visible=is_video)
        }
    model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
    # Two-step click: first resolve -1 into a concrete random seed (queue=False
    # so it runs immediately), then stream generation updates from the callback.
    generate_button.click(
        fn=lambda s: s if s != -1 else random.randint(0, 2**32 - 1),
        inputs=seed_input,
        outputs=seed_state,
        queue=False
    ).then(
        fn=generate_media_live_progress,
        inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
        outputs=[output_image, output_video, status_textbox]
    )
# Script entry point: launch the app with a public share link and debug logging.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)