oliveryanzuolu committed on
Commit
d6df1df
·
verified ·
1 Parent(s): b53f48b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -58
app.py CHANGED
@@ -12,8 +12,6 @@ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCM
12
  # -----------------------------------------------------------------------------
13
  # 1. Configuration & Registry
14
  # -----------------------------------------------------------------------------
15
- # This dictionary serves as the "Registry" for valid models and their specific
16
- # trigger words. This decouples configuration from logic.
17
  LORA_REGISTRY = {
18
  "None (Base SD1.5)": {
19
  "repo": None,
@@ -22,9 +20,9 @@ LORA_REGISTRY = {
22
  },
23
  "Lego Style": {
24
  "repo": "lordjia/lelo-lego-lora-for-xl-sd1-5",
25
- "trigger": "LEGO Creator, LEGO MiniFig, ", # Combined triggers for general usage
26
  "weight": 0.8,
27
- "file": "Lego_XL_v2.1.safetensors" # Note: Auto-resolution usually handles this, but explicitly noted for context
28
  },
29
  "Claymation Style": {
30
  "repo": "DoctorDiffusion/doctor-diffusion-s-claymation-style-lora",
@@ -46,8 +44,7 @@ print("Initializing Inference Pipeline...")
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  dtype = torch.float16 if device == "cuda" else torch.float32
48
 
49
- # Load ControlNet (Canny Edge Detection)
50
- # We use the standard lllyasviel checkpoint which is the gold standard for SD1.5
51
  controlnet = ControlNetModel.from_pretrained(
52
  "lllyasviel/sd-controlnet-canny",
53
  torch_dtype=dtype,
@@ -55,7 +52,6 @@ controlnet = ControlNetModel.from_pretrained(
55
  )
56
 
57
  # Load Base Stable Diffusion 1.5
58
- # We use the official RunwayML checkpoint
59
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
60
  "stable-diffusion-v1-5/stable-diffusion-v1-5",
61
  controlnet=controlnet,
@@ -63,12 +59,10 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
63
  use_safetensors=True
64
  )
65
 
66
- # Optimization: Use UniPC Scheduler for fast convergence (20-30 steps)
67
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
68
 
69
- # Optimization: Offload model to CPU when not in use to save VRAM
70
- # Crucial for running on constrained hardware (e.g., free tier Spaces)
71
- pipe.enable_model_cpu_offload()
72
 
73
  print("Base Pipeline Loaded Successfully.")
74
 
@@ -77,23 +71,14 @@ print("Base Pipeline Loaded Successfully.")
77
  # -----------------------------------------------------------------------------
78
 
79
  def get_canny_image(image, low_threshold=100, high_threshold=200):
80
- """
81
- Converts a PIL image into a Canny edge map.
82
- The map is converted to RGB (3-channel) to match ControlNet input requirements.
83
- """
84
  image_array = np.array(image)
85
-
86
- # Canny edge detection via OpenCV
87
  canny_edges = cv2.Canny(image_array, low_threshold, high_threshold)
88
-
89
- # Replicate the single channel to 3 channels (RGB)
90
  canny_edges = canny_edges[:, :, None]
91
  canny_edges = np.concatenate([canny_edges, canny_edges, canny_edges], axis=2)
92
-
93
  return Image.fromarray(canny_edges)
94
 
95
  # -----------------------------------------------------------------------------
96
- # 4. Inference Logic (The "Middleware")
97
  # -----------------------------------------------------------------------------
98
 
99
  @spaces.GPU(duration=120)
@@ -110,52 +95,58 @@ def generate_controlled_image(
110
  raise gr.Error("Validation Error: Please upload an image first!")
111
 
112
  # 1. Preprocess Image
113
- # Resizing to 512x512 is standard for SD1.5 to avoid duplication artifacts
114
  width, height = 512, 512
115
  input_image = input_image.resize((width, height))
116
  canny_image = get_canny_image(input_image)
117
 
118
  # 2. Manage LoRA State
119
- # We must explicitly unload previous weights to prevent style contamination
 
 
 
 
 
 
 
 
120
  try:
121
- pipe.unload_lora_weights()
122
-
123
- style_config = LORA_REGISTRY[lora_selection]
124
- repo_id = style_config["repo"]
125
- trigger_text = style_config["trigger"]
126
-
127
- # Modify prompt with trigger words
128
- final_prompt = f"{trigger_text}{prompt}"
129
-
130
  if repo_id:
131
  print(f"Loading LoRA: {repo_id}")
132
  pipe.load_lora_weights(repo_id)
133
-
134
- # Note: In more complex setups with multiple adapters, we would use
135
- # pipe.set_adapters() and fuse_lora(), but for single-style swap,
136
- # load/unload is sufficient and memory-safe.
137
 
138
  except Exception as e:
139
  print(f"LoRA Load Error: {e}")
140
- # Fallback to base model if LoRA fails, but warn user via prompt
141
- final_prompt = prompt
142
- gr.Warning(f"Failed to load LoRA {lora_selection}. Using base model.")
143
 
144
- # 3. Deterministic Generation
145
- # Using a manual seed ensures reproducibility
146
  generator = torch.Generator(device).manual_seed(int(seed))
147
 
148
  print(f"Generating with Prompt: {final_prompt}")
149
 
150
- output_image = pipe(
151
- prompt=final_prompt,
152
- negative_prompt=negative_prompt,
153
- image=canny_image,
154
- num_inference_steps=int(steps),
155
- controlnet_conditioning_scale=float(controlnet_conditioning_scale),
156
- guidance_scale=7.5, # Standard CFG scale
157
- generator=generator,
158
- ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  return canny_image, output_image
161
 
@@ -168,10 +159,7 @@ css = """
168
  .guide-text {font-size: 1.1em; color: #4a5568;}
169
  """
170
 
171
- # Example Data (Using Public Domain / CC0 URLs for reproducibility)
172
- # Nested list format:
173
- # Example Data
174
- # Format: [Image URL, Prompt, Negative Prompt, LoRA Selection, ControlNet Scale, Steps, Seed]
175
  examples = [
176
  [
177
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png",
@@ -231,8 +219,6 @@ examples = [
231
  ]
232
  ]
233
 
234
-
235
-
236
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
237
 
238
  with gr.Column(elem_id="col-container"):
@@ -297,7 +283,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
297
  inputs=[input_image, prompt, negative_prompt, lora_selection, controlnet_conditioning_scale, steps, seed],
298
  outputs=[output_canny, output_result],
299
  fn=generate_controlled_image,
300
- cache_examples=True # Pre-compute examples for instant display
301
  )
302
 
303
  # Event Wiring
@@ -316,4 +302,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
316
  )
317
 
318
  if __name__ == "__main__":
319
- demo.launch()
 
12
  # -----------------------------------------------------------------------------
13
  # 1. Configuration & Registry
14
  # -----------------------------------------------------------------------------
 
 
15
  LORA_REGISTRY = {
16
  "None (Base SD1.5)": {
17
  "repo": None,
 
20
  },
21
  "Lego Style": {
22
  "repo": "lordjia/lelo-lego-lora-for-xl-sd1-5",
23
+ "trigger": "LEGO Creator, LEGO MiniFig, ",
24
  "weight": 0.8,
25
+ "file": "Lego_XL_v2.1.safetensors"
26
  },
27
  "Claymation Style": {
28
  "repo": "DoctorDiffusion/doctor-diffusion-s-claymation-style-lora",
 
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
45
  dtype = torch.float16 if device == "cuda" else torch.float32
46
 
47
+ # Load ControlNet
 
48
  controlnet = ControlNetModel.from_pretrained(
49
  "lllyasviel/sd-controlnet-canny",
50
  torch_dtype=dtype,
 
52
  )
53
 
54
  # Load Base Stable Diffusion 1.5
 
55
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
56
  "stable-diffusion-v1-5/stable-diffusion-v1-5",
57
  controlnet=controlnet,
 
59
  use_safetensors=True
60
  )
61
 
 
62
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
63
 
64
+ if device == "cuda":
65
+ pipe.to(device)
 
66
 
67
  print("Base Pipeline Loaded Successfully.")
68
 
 
71
  # -----------------------------------------------------------------------------
72
 
73
def get_canny_image(image, low_threshold=100, high_threshold=200):
    """Convert a PIL image into a 3-channel Canny edge map for ControlNet.

    Args:
        image: Input ``PIL.Image``. It is converted to RGB first so that
            RGBA or grayscale uploads (common from Gradio) do not change
            the channel layout handed to OpenCV.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        PIL.Image: An RGB image whose three identical channels hold the
        edge map, matching the 3-channel input ControlNet expects.
    """
    # Normalize mode so np.array always yields an (H, W, 3) uint8 array,
    # regardless of the upload's original mode (RGBA, L, P, ...).
    image_array = np.array(image.convert("RGB"))
    edges = cv2.Canny(image_array, low_threshold, high_threshold)
    # Replicate the single edge channel across RGB (ControlNet wants 3ch).
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges_rgb)
79
 
80
  # -----------------------------------------------------------------------------
81
+ # 4. Inference Logic
82
  # -----------------------------------------------------------------------------
83
 
84
  @spaces.GPU(duration=120)
 
95
  raise gr.Error("Validation Error: Please upload an image first!")
96
 
97
  # 1. Preprocess Image
 
98
  width, height = 512, 512
99
  input_image = input_image.resize((width, height))
100
  canny_image = get_canny_image(input_image)
101
 
102
  # 2. Manage LoRA State
103
+ pipe.unload_lora_weights()
104
+
105
+ style_config = LORA_REGISTRY[lora_selection]
106
+ repo_id = style_config["repo"]
107
+ trigger_text = style_config["trigger"]
108
+ lora_weight = style_config["weight"]
109
+
110
+ final_prompt = f"{trigger_text}{prompt}"
111
+
112
  try:
 
 
 
 
 
 
 
 
 
113
  if repo_id:
114
  print(f"Loading LoRA: {repo_id}")
115
  pipe.load_lora_weights(repo_id)
116
+ pipe.fuse_lora(lora_scale=lora_weight)
117
+ print("LoRA fused successfully.")
 
 
118
 
119
  except Exception as e:
120
  print(f"LoRA Load Error: {e}")
121
+ gr.Warning(f"Failed to load LoRA {lora_selection}. Using base model. Error: {str(e)}")
 
 
122
 
123
+ # 3. Generation
 
124
  generator = torch.Generator(device).manual_seed(int(seed))
125
 
126
  print(f"Generating with Prompt: {final_prompt}")
127
 
128
+ try:
129
+ output_image = pipe(
130
+ prompt=final_prompt,
131
+ negative_prompt=negative_prompt,
132
+ image=canny_image,
133
+ num_inference_steps=int(steps),
134
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
135
+ guidance_scale=7.5,
136
+ generator=generator,
137
+ ).images
138
+ except Exception as e:
139
+ pipe.unfuse_lora()
140
+ pipe.unload_lora_weights()
141
+ raise e
142
+
143
+ # 4. Cleanup
144
+ if repo_id:
145
+ print("Unfusing LoRA...")
146
+ pipe.unfuse_lora()
147
+ pipe.unload_lora_weights()
148
+
149
+ torch.cuda.empty_cache()
150
 
151
  return canny_image, output_image
152
 
 
159
  .guide-text {font-size: 1.1em; color: #4a5568;}
160
  """
161
 
162
+ # Example Data (Using resolve URLs)
 
 
 
163
  examples = [
164
  [
165
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png",
 
219
  ]
220
  ]
221
 
 
 
222
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
223
 
224
  with gr.Column(elem_id="col-container"):
 
283
  inputs=[input_image, prompt, negative_prompt, lora_selection, controlnet_conditioning_scale, steps, seed],
284
  outputs=[output_canny, output_result],
285
  fn=generate_controlled_image,
286
+ cache_examples=False # CRITICAL FIX: Set to False to prevent async loop errors
287
  )
288
 
289
  # Event Wiring
 
302
  )
303
 
304
  if __name__ == "__main__":
305
+ demo.launch()