Spaces:
Runtime error
Runtime error
# Stable Diffusion Hugging Face App (Turbo Version with Fixes)
import gradio as gr
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

# Hub id of the lightweight Stable Diffusion Turbo checkpoint.
model_id = "stabilityai/sd-turbo"

# Run on the GPU when one is visible; half precision is only used there,
# the CPU path stays in full float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
_pipe_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=_pipe_dtype)
pipe = pipe.to(device)

# Replace the checkpoint's scheduler with DDIM, reusing its stored config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# Style name -> text appended to the user's prompt. These are plain prompt
# suffixes that only simulate a style (no learned embeddings are involved).
# Insertion order is kept: the UI builds its dropdown choices from these keys.
_STYLE_SUFFIXES = (
    ("Van Gogh", "in the style of Van Gogh"),
    ("Cyberpunk", "cyberpunk futuristic cityscape"),
    ("Pixel Art", "8-bit pixel art style"),
    ("Studio Ghibli", "studio ghibli anime style"),
    ("Surrealism", "in surrealistic dreamscape style"),
)
STYLE_MAP = dict(_STYLE_SUFFIXES)
# Custom loss placeholder (for assignment purposes)
def custom_loss_placeholder(image_tensor):
    """Return the MSE between an image's mean colour and pure yellow.

    Args:
        image_tensor: Channel-first RGB tensor, either ``(3, H, W)`` or a
            batch ``(B, 3, H, W)``. Presumably values lie in ``[0, 1]`` so
            that ``(1, 1, 0)`` means yellow -- TODO confirm against callers.

    Returns:
        Scalar tensor (same dtype as the input): mean squared error between
        the per-channel spatial mean and the yellow target, averaged over
        all channels (and batch entries, if any).
    """
    # Average over the two trailing spatial axes only, so an optional leading
    # batch axis survives and each image gets its own mean colour.
    image_mean = image_tensor.mean(dim=(-2, -1))
    # Build the target on the input's device AND dtype, then broadcast it to
    # the mean's shape so mse_loss sees matching sizes.
    yellow = torch.tensor(
        [1.0, 1.0, 0.0], device=image_tensor.device, dtype=image_tensor.dtype
    ).expand_as(image_mean)
    return torch.nn.functional.mse_loss(image_mean, yellow)
# Generate image based on prompt and style
def generate(prompt, style, seed, num_inference_steps=1, guidance_scale=0.0):
    """Render one image for ``prompt`` decorated with the chosen style suffix.

    Args:
        prompt: Free-text description from the UI textbox (``None`` is
            treated as empty; surrounding whitespace is stripped).
        style: Key into ``STYLE_MAP``; an unknown key adds no style text.
        seed: Random seed. Coerced with ``int()`` because UI sliders may
            deliver numeric values as floats.
        num_inference_steps: Denoising steps. SD-Turbo is distilled for
            single-step sampling, so the default is 1.
        guidance_scale: Classifier-free guidance weight. The SD-Turbo model
            card disables guidance (0.0); larger values slow sampling and
            push the distilled model off-distribution.

    Returns:
        The first image produced by the module-level ``pipe``.
    """
    # A dedicated CPU generator makes each request reproducible from its seed
    # without reseeding torch's process-global RNG.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    # Join only the non-empty parts so a missing style does not leave a
    # dangling ", " at the end of the prompt.
    parts = [(prompt or "").strip(), STYLE_MAP.get(style, "")]
    full_prompt = ", ".join(part for part in parts if part)

    result = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]
# Gradio UI
# NOTE(review): the scraped source lost its indentation; this assumes the three
# input widgets share one row and the button/output sit below it -- confirm.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Stable Diffusion Turbo App\n"
        "Generate styled images using text prompts and different art styles."
    )
    with gr.Row():
        prompt = gr.Textbox(label="Enter Prompt", placeholder="A fox with a monocle")
        style = gr.Dropdown(
            choices=list(STYLE_MAP),
            value="Van Gogh",
            label="Choose Style",
        )
        seed = gr.Slider(
            label="Random Seed",
            minimum=0,
            maximum=9999,
            step=1,
            value=42,
        )
    generate_btn = gr.Button("Generate Image")
    output = gr.Image(label="Stylized Output")
    # The three inputs map positionally onto generate(prompt, style, seed).
    generate_btn.click(fn=generate, inputs=[prompt, style, seed], outputs=output)

# Launch the Gradio app
demo.launch()