mojitocup committed on
Commit
0b06336
·
verified ·
1 Parent(s): 8d89fdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -64
app.py CHANGED
@@ -1,72 +1,176 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
 
 
 
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
39
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  image = pipe(
42
  prompt=prompt,
43
  negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
  width=width,
47
  height=height,
48
- generator=generator,
 
 
49
  ).images[0]
50
 
51
  return image, seed
52
 
53
-
54
  examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
  ]
59
 
60
  css = """
61
  #col-container {
62
  margin: 0 auto;
63
- max-width: 640px;
64
  }
65
  """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
 
 
 
 
 
 
 
 
 
 
70
 
71
  with gr.Row():
72
  prompt = gr.Text(
@@ -76,8 +180,7 @@ with gr.Blocks(css=css) as demo:
76
  placeholder="Enter your prompt",
77
  container=False,
78
  )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
  result = gr.Image(label="Result", show_label=False)
83
 
@@ -86,7 +189,6 @@ with gr.Blocks(css=css) as demo:
86
  label="Negative prompt",
87
  max_lines=1,
88
  placeholder="Enter a negative prompt",
89
- visible=False,
90
  )
91
 
92
  seed = gr.Slider(
@@ -102,27 +204,26 @@ with gr.Blocks(css=css) as demo:
102
  with gr.Row():
103
  width = gr.Slider(
104
  label="Width",
105
- minimum=256,
106
  maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
110
-
111
  height = gr.Slider(
112
  label="Height",
113
- minimum=256,
114
  maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
118
 
119
  with gr.Row():
120
  guidance_scale = gr.Slider(
121
  label="Guidance scale",
122
  minimum=0.0,
123
- maximum=10.0,
124
  step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
  )
127
 
128
  num_inference_steps = gr.Slider(
@@ -130,25 +231,25 @@ with gr.Blocks(css=css) as demo:
130
  minimum=1,
131
  maximum=50,
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
 
 
 
 
 
 
 
 
 
 
137
  gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
import gradio as gr
import spaces
import numpy as np
import random
import torch
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login
import os

# Probe for peft up front so diffusers' LoRA machinery (which needs
# BaseTunerLayer) fails loudly here rather than deep inside load_lora_weights().
try:
    from peft.tuners.tuners_utils import BaseTunerLayer
except ImportError:
    print("Warning: peft not installed. LoRA functionality may be limited.")

# Authenticate with the Hub only when a token is actually configured.
# Fix: login(token=None) — which happens when the HF_TOKEN secret is missing —
# tries to prompt interactively (or raises) inside a headless Space.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("Warning: HF_TOKEN not set; gated or private models may fail to load.")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Base model checkpoint, loaded once and shared by every request.
repo = "mojitocup/realistic-xl"
pipe = StableDiffusion3Pipeline.from_pretrained(
    repo,
    torch_dtype=dtype,
    use_safetensors=True,
    # NOTE(review): assumes the repo publishes an fp16 weight variant —
    # confirm, otherwise from_pretrained raises when running on CUDA.
    variant="fp16" if dtype == torch.float16 else None
).to(device)

# Display name -> Hub repo id for the selectable LoRA adapters ("None" disables).
loras = {
    "None": None,
    "SD3.5 Photorealistic": "prithivMLmods/SD3.5-Large-Photorealistic-LoRA",
    "Face Helper SDXL": "ostris/face-helper-sdxl-lora",
    "LCM LoRA SDXL": "latent-consistency/lcm-lora-sdxl"
}

MAX_SEED = np.iinfo(np.int32).max  # largest 32-bit signed integer
MAX_IMAGE_SIZE = 1536              # upper bound for the width/height sliders
42
class LoRAManager:
    """Manages LoRA loading and unloading on a shared pipeline.

    The currently-fused adapter is tracked on the *class*, not the instance:
    infer() constructs a fresh manager on every request, so per-instance
    state would be lost between calls and previously fused adapters would
    silently stack on top of each other.
    """

    # Repo id of the LoRA currently fused into the shared pipeline, or None.
    current_lora = None

    def __init__(self, pipe):
        # The (shared) diffusers pipeline this manager operates on.
        self.pipe = pipe

    def load_lora(self, lora_repo, lora_scale=0.8):
        """Load and fuse a LoRA adapter; return True on success, False otherwise."""
        try:
            # Drop whatever adapter is currently fused before loading a new one.
            self.unfuse_current_lora()

            # Repos name their weight files inconsistently; probe the common
            # names first, then fall back to diffusers auto-detection.
            weight_names_to_try = [
                "pytorch_lora_weights.safetensors",
                "Photorealistic-SD3.5-Large-LoRA.safetensors",  # prithivMLmods repo
                "diffusion_pytorch_model.safetensors",
                None,  # let diffusers auto-detect
            ]

            success = False
            for weight_name in weight_names_to_try:
                try:
                    if weight_name:
                        self.pipe.load_lora_weights(lora_repo, weight_name=weight_name)
                    else:
                        self.pipe.load_lora_weights(lora_repo)
                    success = True
                    break
                except Exception as e:
                    print(f"Failed to load with weight_name='{weight_name}': {e}")
                    continue

            if not success:
                print(f"Error loading LoRA {lora_repo}: No compatible weight file found")
                return False

            self.pipe.fuse_lora(lora_scale=lora_scale)
            LoRAManager.current_lora = lora_repo  # class-level: survives per-call managers
            print(f"Successfully loaded LoRA: {lora_repo}")
            return True

        except Exception as e:
            print(f"Error loading LoRA {lora_repo}: {e}")
            return False

    def unfuse_current_lora(self):
        """Safely unfuse and unload the currently active LoRA, if any."""
        if LoRAManager.current_lora is None:
            return

        try:
            self.pipe.unfuse_lora()
            # Fix: unfuse_lora() only reverses the weight fusion; the adapter
            # itself stays loaded. Unload it too so repeated switches do not
            # accumulate adapters in memory.
            unload = getattr(self.pipe, "unload_lora_weights", None)
            if callable(unload):
                unload()
            print(f"Unfused LoRA: {LoRAManager.current_lora}")
        except Exception as e:
            print(f"Warning: Could not unfuse LoRA: {e}")
        finally:
            LoRAManager.current_lora = None  # reset regardless of outcome
def truncate_prompt(prompt, max_length=77):
    """Crudely cap *prompt* at *max_length* whitespace-separated words.

    Word count only approximates CLIP's 77-token limit (tokens != words),
    but it prevents the encoder from silently dropping the prompt's tail.
    Falsy inputs (None, "") are returned unchanged.
    """
    if not prompt:
        return prompt

    words = prompt.split()
    if len(words) <= max_length:
        return prompt

    truncated = " ".join(words[:max_length])
    print(f"Warning: Prompt truncated from {len(words)} to {max_length} words")
    return truncated
115
+
116
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
          guidance_scale, num_inference_steps, lora_choice,
          progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the shared SD3.5 pipeline.

    Returns (image, seed) so the UI can display the seed actually used.
    Raises gr.Error when the selected LoRA adapter cannot be loaded.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Seed a generator on the pipeline's device for reproducible results.
    generator = torch.Generator(device=device).manual_seed(seed)

    # Keep both prompts under CLIP's token limit (word-based approximation).
    prompt = truncate_prompt(prompt, max_length=70)
    negative_prompt = truncate_prompt(negative_prompt, max_length=70)

    # One manager serves both branches (the original duplicated its creation).
    # NOTE(review): a manager is re-created per request — it must track fused
    # state globally, otherwise adapters would stack across calls; verify.
    lora_manager = LoRAManager(pipe)
    if lora_choice != "None":
        if not lora_manager.load_lora(loras[lora_choice]):
            raise gr.Error(f"Failed to load LoRA adapter: {lora_choice}")
    else:
        lora_manager.unfuse_current_lora()

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator
    ).images[0]

    return image, seed
147
 
 
148
# Sample prompts surfaced in the gr.Examples widget.
examples = [
    "Samurai girl in the snow forest. Show on Sony camera f1.2",
    "A stylish Japanese woman in her early 20s stands confidently in front of a cold, industrial background in a cinematic close-up. She wears round black sunglasses, a wide-brimmed black hat, and a brown suede coat with white shearling lining layered over a black turtleneck. A bold silver chain necklace adds a sharp urban edge to her look. Her expression is fierce and composed, staring straight into the camera with quiet intensity. The lighting is cool and bluish, creating a moody, neo-noir vibe — evoking the feel of a modern Tokyo underground fashion scene.",
    "A young Indonesian woman from Bandung walks directly toward the camera across a cracked, sun-bleached desert highway under a vast, empty sky. She wears a bold futuristic high-fashion outfit: a structured matte-black coat with exaggerated shoulders and a high collar, flowing just above her minimalistic sand-toned boots. Her hijab is styled tightly and sleek under the high collar, blending seamlessly with the look — matte black with wind-swept form, no hair visible. Her makeup is bold and geometric, featuring sharp silvery eyeliner and face highlights under the eyes, catching the sun's glare. Her expression is cold, focused, and defiant — one brow slightly raised as she pierces the lens with calm intensity. The directional sunlight from the right casts elongated, sharp-edged shadows across the cracked highway. The background is pure dystopian silence: endless dry plains stretch into the distance with a few collapsed, rusted billboards leaning in the far horizon. No buildings, no people — just wind, fashion, and desert solitude. Shot in ultra high-resolution, harsh sunlight, cinematic composition, dystopian fashion editorial.",
]

# Page CSS: centers the main column and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""
160
 
161
  with gr.Blocks(css=css) as demo:
162
  with gr.Column(elem_id="col-container"):
163
+ gr.Markdown(f"""
164
+ # Customized Stable Diffusion Realistic XL with LoRA + Photoreal Enhancements
165
+ Choose a high-quality LoRA model to enhance your generations. All models are tested and compatible with SD3.5.
166
+
167
+ **Available LoRA Models:**
168
+ - **Photorealistic**: Specialized for photorealistic portraits and scenes
169
+ - **Face Helper**: Enhances facial features and expressions
170
+ - **LCM LoRA**: Reduces inference steps for faster generation
171
+
172
+ Based on [StabilityAI SD3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large).
173
+ """)
174
 
175
  with gr.Row():
176
  prompt = gr.Text(
 
180
  placeholder="Enter your prompt",
181
  container=False,
182
  )
183
+ run_button = gr.Button("Run", scale=0)
 
184
 
185
  result = gr.Image(label="Result", show_label=False)
186
 
 
189
  label="Negative prompt",
190
  max_lines=1,
191
  placeholder="Enter a negative prompt",
 
192
  )
193
 
194
  seed = gr.Slider(
 
204
  with gr.Row():
205
  width = gr.Slider(
206
  label="Width",
207
+ minimum=512,
208
  maximum=MAX_IMAGE_SIZE,
209
+ step=64,
210
+ value=1024,
211
  )
 
212
  height = gr.Slider(
213
  label="Height",
214
+ minimum=512,
215
  maximum=MAX_IMAGE_SIZE,
216
+ step=64,
217
+ value=1024,
218
  )
219
 
220
  with gr.Row():
221
  guidance_scale = gr.Slider(
222
  label="Guidance scale",
223
  minimum=0.0,
224
+ maximum=20.0,
225
  step=0.1,
226
+ value=7.5,
227
  )
228
 
229
  num_inference_steps = gr.Slider(
 
231
  minimum=1,
232
  maximum=50,
233
  step=1,
234
+ value=30,
235
  )
236
 
237
+ lora_choice = gr.Dropdown(
238
+ label="LoRA adapter",
239
+ choices=list(loras.keys()),
240
+ value="None"
241
+ )
242
+
243
+ gr.Examples(
244
+ examples=examples,
245
+ inputs=[prompt]
246
+ )
247
+
248
  gr.on(
249
+ triggers=[run_button.click, prompt.submit, negative_prompt.submit],
250
  fn=infer,
251
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_choice],
252
+ outputs=[result, seed]
 
 
 
 
 
 
 
 
 
253
  )
254

# Fix: restore the entry-point guard the previous revision had — without it,
# importing this module (e.g. under `gradio` reload mode or in tests) would
# launch the server as a side effect.
if __name__ == "__main__":
    demo.launch()