# Hugging Face Space: "HQ CPU Image Gen" (status banner artifact from page scrape)
# Standard library
import os
import random
import sys
import warnings

# Third-party
import torch
import gradio as gr
from diffusers import AutoPipelineForText2Image, LCMScheduler

# Silence noisy library warnings (diffusers/transformers are chatty on CPU).
warnings.filterwarnings("ignore")

# ==================== Configuration ====================
BASE_MODEL = "Lykon/dreamshaper-8"              # SD1.5-based quality checkpoint
LORA_ID = "latent-consistency/lcm-lora-sdv1-5"  # LCM adapter enabling few-step sampling
MAX_SEED = 2147483647                           # 2**31 - 1: largest 32-bit signed int
DEFAULT_STEPS = 8
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_GUIDANCE = 1.0
| # ==================== Model Loading ==================== | |
| def load_model(): | |
| print(f"Loading Base Model: {BASE_MODEL} on CPU...") | |
| try: | |
| # 1. Load the high-quality base model | |
| pipe = AutoPipelineForText2Image.from_pretrained( | |
| BASE_MODEL, | |
| safety_checker=None, | |
| use_safetensors=True, | |
| ) | |
| # 2. Load the LCM LoRA adapter | |
| print("Loading LCM LoRA...") | |
| pipe.load_lora_weights(LORA_ID) | |
| pipe.fuse_lora() | |
| # 3. Switch to LCM Scheduler | |
| pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) | |
| # 4. Move to CPU | |
| pipe.to("cpu", torch.float32) | |
| print("Model loaded successfully.") | |
| return pipe | |
| except Exception as e: | |
| print(f"Error loading model: {e}") | |
| return None | |
# Initialize model globally: load once at import time so every Gradio
# request reuses the same pipeline (None if loading failed).
pipe = load_model()
| # ==================== Generation Function ==================== | |
| def generate( | |
| prompt, | |
| enhance_prompt, | |
| width, | |
| height, | |
| steps, | |
| guidance_scale, | |
| seed, | |
| random_seed, | |
| progress=gr.Progress(track_tqdm=True) | |
| ): | |
| if pipe is None: | |
| raise gr.Error("Model failed to load.") | |
| if random_seed: | |
| seed = random.randint(0, MAX_SEED) | |
| # Enhanced Magic Prompt Logic | |
| final_prompt = prompt | |
| if enhance_prompt: | |
| final_prompt = f"{prompt}, (masterpiece:1.2), best quality, highres, original, extremely detailed wallpaper, perfect lighting, 8k" | |
| try: | |
| generator = torch.Generator("cpu").manual_seed(int(seed)) | |
| images = pipe( | |
| prompt=final_prompt, | |
| width=width, | |
| height=height, | |
| num_inference_steps=steps, | |
| guidance_scale=guidance_scale, | |
| generator=generator, | |
| output_type="pil", | |
| return_dict=True | |
| ).images | |
| return images[0], seed | |
| except Exception as e: | |
| raise gr.Error(f"Generation error: {e}") | |
# ==================== Gradio Interface ====================
css = """
.container { max-width: 800px; margin: auto; }
.output_image { min-height: 512px; }
"""

# FIX: custom CSS must be passed to the gr.Blocks() constructor — the
# earlier "FIXED: Removed 'css' from here" comment had it backwards;
# Blocks.launch() does not accept a css argument.
with gr.Blocks(css=css, title="HQ CPU Image Gen") as demo:
    gr.Markdown("# 💎 High-Quality CPU Generator")
    gr.Markdown(f"Using `{BASE_MODEL}` + LCM LoRA • Optimized for Free Tier")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="man eating banana...", lines=3)
            enhance_prompt = gr.Checkbox(label="✨ Magic Prompt (Enhance Detail)", value=True)
            with gr.Row():
                width = gr.Slider(256, 768, value=DEFAULT_WIDTH, step=64, label="Width")
                height = gr.Slider(256, 768, value=DEFAULT_HEIGHT, step=64, label="Height")
            with gr.Row():
                steps = gr.Slider(4, 20, value=DEFAULT_STEPS, step=1, label="Steps (8-12 recommended)")
                guidance = gr.Slider(0.1, 4.0, value=DEFAULT_GUIDANCE, step=0.1, label="Guidance Scale")
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
            btn = gr.Button("Generate HQ Image", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Result", elem_classes="output_image")
            seed_output = gr.Textbox(label="Used Seed", interactive=False)
    btn.click(
        generate,
        inputs=[prompt, enhance_prompt, width, height, steps, guidance, seed, random_seed],
        outputs=[output_image, seed_output]
    )
if __name__ == "__main__":
    # FIX: Blocks.launch() has no `css` parameter — passing one raises
    # TypeError. Custom CSS belongs on the gr.Blocks() constructor.
    demo.launch()