lehehroi committed (verified)
Commit f385a2c · Parent(s): 9d1a0dd

Create app.py

Files changed (1):
  app.py (+189 −0)
app.py ADDED
@@ -0,0 +1,189 @@
+ import spaces
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import random
+ from diffusers import StableDiffusionXLPipeline
+ from diffusers import EulerAncestralDiscreteScheduler
+ import torch
+ from compel import Compel, ReturnedEmbeddingsType
+ #import os
+ #from gradio_client import Client
+ #client = Client("dhead/ntr-mix-illustrious-xl-noob-xl-xiii-sdxl", hf_token=os.getenv("HUGGING_FACE_TOKEN"))
+
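+ # `spaces` supplies the ZeroGPU decorator used on Hugging Face Spaces; `compel`
+ # builds prompt embeddings for SDXL's two text encoders, including prompts
+ # longer than CLIP's 77-token window.
+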
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Make sure to use torch.float16 consistently throughout the pipeline
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
+     torch_dtype=torch.float16,
+     variant="fp16",        # explicitly use the fp16 weight variant
+     use_safetensors=True,  # use safetensors if available
+ )
+
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+ pipe.to(device)
+
+ # Force all components to use the same dtype
+ pipe.text_encoder.to(torch.float16)
+ pipe.text_encoder_2.to(torch.float16)
+ pipe.vae.to(torch.float16)
+ pipe.unet.to(torch.float16)
+
+ # Initialize Compel for long-prompt processing
+ compel = Compel(
+     tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+     text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+     requires_pooled=[False, True],
+     truncate_long_prompts=False,
+ )
+
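+ # With truncate_long_prompts=False, Compel splits a prompt that exceeds the
+ # 77-token CLIP window into chunks, encodes each chunk, and concatenates the
+ # embeddings, so long tag-style prompts keep all of their tokens.
+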
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1216
+
+ def process_long_prompt(prompt, negative_prompt=""):
+     """Simple long-prompt processing using Compel."""
+     try:
+         conditioning, pooled = compel([prompt, negative_prompt])
+         return conditioning, pooled
+     except Exception as e:
+         print(f"Long prompt processing failed: {e}, falling back to standard processing")
+         return None, None
+
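+ # compel([prompt, negative_prompt]) encodes both prompts as one batch, padded to
+ # the same length: row 0 is the positive prompt and row 1 the negative, which is
+ # why infer() slices conditioning[0:1] / conditioning[1:2] below.
+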
+ @spaces.GPU
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+     # No hard 60-word limit: long prompts are routed through Compel instead
+     use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     try:
+         # Try long-prompt processing first if the prompt is long
+         if use_long_prompt:
+             print("Using long prompt processing...")
+             conditioning, pooled = process_long_prompt(prompt, negative_prompt)
+
+             if conditioning is not None:
+                 output_image = pipe(
+                     prompt_embeds=conditioning[0:1],
+                     pooled_prompt_embeds=pooled[0:1],
+                     negative_prompt_embeds=conditioning[1:2],
+                     negative_pooled_prompt_embeds=pooled[1:2],
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_inference_steps,
+                     width=width,
+                     height=height,
+                     generator=generator,
+                 ).images[0]
+                 return output_image
+
+         # Fall back to standard processing (short prompts, or if Compel failed)
+         output_image = pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator,
+         ).images[0]
+
+         return output_image
+     except RuntimeError as e:
+         print(f"Error during generation: {e}")
+         # Return a blank image instead of raising, so the UI stays responsive
+         error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
+         return error_img
+
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 1024px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt (long prompts are automatically supported)",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0)
+
+         result = gr.Image(format="png", label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 # value="bad quality,worst quality,worst detail,sketch,censor,"
+                 value="monochrome, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn,",
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=MAX_IMAGE_SIZE,
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=20.0,
+                     step=0.1,
+                     value=7.0,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=28,
+                     step=1,
+                     value=28,
+                 )
+
+     run_button.click(
+         fn=infer,
+         inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+         outputs=[result],
+     )
+
+ demo.queue().launch()
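
For local testing outside the Space UI, a minimal smoke test might call infer() directly (a sketch, not part of this commit: it assumes the launch() call is skipped, and that the spaces.GPU decorator degrades to a no-op when run off-Spaces):

    # Hypothetical smoke test: a prompt of 100+ words exercises the Compel long-prompt path.
    long_prompt = ", ".join(["masterpiece", "best quality", "1girl"] * 25)
    image = infer(
        prompt=long_prompt,
        negative_prompt="lowres, bad anatomy",
        seed=0,
        randomize_seed=False,
        width=832,
        height=1216,
        guidance_scale=7.0,
        num_inference_steps=28,
    )
    image.save("smoke_test.png")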