File size: 4,599 Bytes
c5bfeea
e7039b9
9bc086f
e7039b9
9bc086f
c5bfeea
9bc086f
 
 
 
c5bfeea
9bc086f
 
 
e7039b9
9bc086f
 
 
 
 
e7039b9
9bc086f
 
e7039b9
 
9bc086f
 
 
 
 
 
c5bfeea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bc086f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5bfeea
532bae2
e7039b9
9bc086f
 
e7039b9
c5bfeea
9bc086f
e7039b9
 
c5bfeea
9bc086f
e7039b9
 
c5bfeea
9bc086f
e7039b9
 
 
 
 
 
c5bfeea
9bc086f
e7039b9
c5bfeea
e7039b9
c5bfeea
9bc086f
 
 
 
c5bfeea
 
 
 
1efa926
 
 
 
 
9bc086f
 
 
 
 
1efa926
 
9bc086f
 
 
 
1efa926
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import gradio as gr
import spaces
from functools import lru_cache
from diffusers import StableDiffusionXLPipeline

# ===============================
# 🩹 FIX for Gradio bug (bool schema issue)
# ===============================
import gradio_client.utils as gu

# Workaround for "TypeError: argument of type 'bool' is not iterable":
# some Gradio versions hand gu.get_type a bare (non-dict) JSON schema,
# which the original implementation cannot handle. The flag attribute
# guards against applying the patch twice.
if not hasattr(gu, "_patched_json_schema_to_python_type"):
    _original_get_type = gu.get_type

    def _tolerant_get_type(schema):
        """Delegate to the original get_type, tolerating non-dict schemas."""
        if isinstance(schema, dict):
            return _original_get_type(schema)
        # Non-dict schemas (e.g. bare booleans) are stringified instead of crashing.
        return str(schema)

    gu.get_type = _tolerant_get_type
    gu._patched_json_schema_to_python_type = True


# ===============================
# 🎨 Model and Styles Configuration
# ===============================

# LoRA repository used for the coloring-book effect, and the trigger phrase
# that must be appended to the prompt for the LoRA to take effect.
color_book_lora_path = "artificialguybr/ColoringBookRedmond-V2"
color_book_trigger = ", ColoringBookAF, Coloring Book"

# Style presets as (name, prompt, negative_prompt) triples. Insertion order
# is preserved below and drives the order of the UI dropdown choices.
_STYLE_PRESETS = [
    (
        "Neonpunk",
        "neonpunk style, cyberpunk, vaporwave, neon, vibrant, stunningly beautiful, crisp, "
        "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic",
        "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    ),
    (
        "Retro Cyberpunk",
        "retro cyberpunk, 80's inspired, synthwave, neon, vibrant, detailed, retro futurism",
        "modern, desaturated, black and white, realism, low contrast",
    ),
    (
        "Dark Fantasy",
        "Dark Fantasy Art, dark, moody, dark fantasy style",
        "ugly, deformed, noisy, blurry, low contrast, bright, sunny",
    ),
    (
        "Double Exposure",
        "Double Exposure Style, double image ghost effect, image combination, double exposure style",
        "ugly, deformed, noisy, blurry, low contrast",
    ),
    (
        "None",
        "8K",
        "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    ),
]

# Mapping consumed by generate_image: name -> {"prompt": ..., "negative_prompt": ...}.
styles = {
    name: {"prompt": prompt, "negative_prompt": negative}
    for name, prompt, negative in _STYLE_PRESETS
}


# ===============================
# 🚀 Pipeline Loader (with caching)
# ===============================
# NOTE: use_lora is boolean, so there are at most two distinct pipelines.
# maxsize=1 (the previous value) evicted the cached pipeline every time the
# LoRA checkbox was toggled, forcing a full multi-gigabyte reload; maxsize=2
# keeps both variants resident.
@lru_cache(maxsize=2)
def load_pipeline(use_lora: bool):
    """Load the Stable Diffusion XL base pipeline, optionally with LoRA weights.

    Args:
        use_lora: When True, the ColoringBookRedmond LoRA weights are loaded
            on top of the base model.

    Returns:
        A StableDiffusionXLPipeline kept on CPU; the caller moves it to the
        GPU for the duration of a generation.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        # fp16 only when a GPU is available; CPU inference requires fp32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        use_safetensors=True
    )
    pipe.to("cpu")  # keep the cached copy off the GPU between requests

    if use_lora:
        pipe.load_lora_weights(color_book_lora_path)

    return pipe


# ===============================
# 🎨 Image Generation Function
# ===============================
@spaces.GPU  # ZeroGPU: allocate GPU only when generating
def generate_image(prompt: str, style_name: str, use_lora: bool):
    """Generate an image using Stable Diffusion XL with optional LoRA fine-tuning."""

    # Fetch the cached pipeline and move it onto the GPU for this request.
    pipeline = load_pipeline(use_lora)
    pipeline.to("cuda")

    # Look the style preset up once; unknown names fall back to empty strings.
    style = styles.get(style_name, {})
    style_prompt = style.get("prompt", "")
    negative_prompt = style.get("negative_prompt", "")

    # The LoRA only activates when its trigger phrase is in the prompt.
    if use_lora:
        prompt = prompt + color_book_trigger

    full_prompt = f"{prompt} {style_prompt}"
    full_negative = "blurred, ugly, watermark, low resolution, " + negative_prompt

    result = pipeline(
        prompt=full_prompt,
        negative_prompt=full_negative,
        num_inference_steps=20,
        guidance_scale=9.0
    )
    image = result.images[0]

    # Park the model back on CPU so the ZeroGPU allocation is released.
    pipeline.to("cpu")

    return image


# ===============================
# 🌐 Gradio Interface (for Spaces)
# ===============================

# Input widgets: free-text prompt, style preset dropdown (defaults to "None"),
# and a toggle for the coloring-book LoRA. Dropdown choices follow the
# insertion order of the styles mapping.
_prompt_input = gr.Textbox(label="Enter Your Prompt", placeholder="A cute lion")
_style_input = gr.Dropdown(label="Select a Style", choices=list(styles.keys()), value="None")
_lora_input = gr.Checkbox(label="Use Coloring Book LoRA", value=False)

interface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_input, _style_input, _lora_input],
    outputs=gr.Image(label="Generated Image"),
    title="🎨 AI Coloring Book & Style Generator",
    description=(
        "Generate AI-powered art using Stable Diffusion XL on Hugging Face Spaces. "
        "Choose a style or enable a LoRA fine-tuned coloring book effect. "
        "This app dynamically allocates GPU (ZeroGPU) only during generation."
    )
)


# ===============================
# 🏁 Launch App
# ===============================
# Start the Gradio server only when executed directly, not when imported.
if __name__ == "__main__":
    interface.launch()