oliveryanzuolu committed
Commit 71a28f3 · verified · 1 Parent(s): b1fa94a

Update app.py

Files changed (1): app.py (+30 -46)
app.py CHANGED
@@ -7,24 +7,21 @@ from PIL import Image
 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
 
 # -----------------------------------------------------------------------------
-# 1. Model Setup (Global Loading)
+# 1. Model Setup
 # -----------------------------------------------------------------------------
 print("Loading models... This might take a minute.")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16
 
-# A. Load ControlNet (Canny)
-# We use Canny because it's the most intuitive for students to understand "Edge Detection"
+# Load ControlNet (Canny)
 controlnet = ControlNetModel.from_pretrained(
     "lllyasviel/sd-controlnet-canny",
     torch_dtype=dtype,
     use_safetensors=True
 )
 
-# B. Load Base Stable Diffusion 1.5
-# SD1.5 is chosen over SDXL here because swapping LoRAs on the fly is much faster
-# and less memory intensive for a live demo.
+# Load Base SD 1.5
 model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
 pipe = StableDiffusionControlNetPipeline.from_pretrained(
     model_id,
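
Note: one pitfall in the setup hunk above is that device falls back to "cpu" while dtype stays torch.float16, and most float16 ops are unsupported on CPU. A minimal device-aware sketch, not part of this commit:

    # Pick the dtype from the detected device; fp16 kernels are spotty on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
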
@@ -33,7 +30,6 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
     use_safetensors=True
 )
 
-# Use a fast scheduler
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 pipe.enable_model_cpu_offload()
 
@@ -44,14 +40,8 @@ print("Base models loaded.")
 # -----------------------------------------------------------------------------
 
 def get_canny_image(image, low_threshold=100, high_threshold=200):
-    """
-    Converts an input image into a Canny edge map.
-    This helps students visualize what the ControlNet actually 'sees'.
-    """
     image = np.array(image)
-    # Convert to grayscale edges
     canny_image = cv2.Canny(image, low_threshold, high_threshold)
-    # Convert back to 3-channel RGB for the model
     canny_image = canny_image[:, :, None]
     canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
     return Image.fromarray(canny_image)
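
Note: get_canny_image can be sanity-checked outside the app. A minimal sketch of the same round-trip (sample.jpg is a hypothetical placeholder path):

    import cv2
    import numpy as np
    from PIL import Image

    # Any RGB input, resized to SD1.5's native 512x512 (hypothetical file).
    img = Image.open("sample.jpg").convert("RGB").resize((512, 512))

    # cv2.Canny returns a single-channel uint8 edge map; stack it to three
    # channels because the ControlNet expects an RGB-shaped conditioning image.
    edges = cv2.Canny(np.array(img), 100, 200)
    edges_rgb = np.stack([edges, edges, edges], axis=2)
    Image.fromarray(edges_rgb).save("sample_edges.png")
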
@@ -60,14 +50,10 @@ def get_canny_image(image, low_threshold=100, high_threshold=200):
 # 3. Inference Logic
 # -----------------------------------------------------------------------------
 
-# Define available LoRAs for the tutorial
-# Format: "Display Name": "HuggingFace_Path"
 LORA_OPTIONS = {
     "None (Base SD1.5)": None,
-    "Lego Style": "minimaxir/sd-1-5-lego-lora", # Turns objects into Lego
-    "Claymation Style": "MoShin/clay-style-lora-sd1.5", # Turns objects into Clay
-    "Pixel Art": "nerijs/pixel-art-xl", # Note: Some LoRAs might be specific, stick to SD1.5 ones usually
-    # Let's use a reliable Pixel Art for 1.5
+    "Lego Style": "minimaxir/sd-1-5-lego-lora",
+    "Claymation Style": "MoShin/clay-style-lora-sd1.5",
     "Pixel Art (SD1.5)": "ismail/pixel-art-style-lora"
 }
 
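
Note: if any of the Hub repo ids in LORA_OPTIONS are renamed or gated, the try/except in generate_controlled_image only surfaces the failure at click time. A startup probe is one way to catch that earlier; a sketch assuming huggingface_hub is available (it is a dependency of diffusers):

    from huggingface_hub import model_info

    # Probe each configured LoRA repo once at startup so a bad repo id
    # shows up in the logs instead of failing silently per request.
    for name, repo_id in LORA_OPTIONS.items():
        if repo_id is not None:
            try:
                model_info(repo_id)
            except Exception as err:
                print(f"LoRA {name!r} ({repo_id}) is unreachable: {err}")
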
@@ -75,7 +61,7 @@ LORA_OPTIONS = {
 def generate_controlled_image(
     input_image,
     prompt,
-    negative_prompt,
+    negative_prompt, # Added this argument
     lora_selection,
     controlnet_conditioning_scale,
     steps,
@@ -84,35 +70,27 @@ def generate_controlled_image(
     if input_image is None:
         raise gr.Error("Please upload an image first!")
 
-    # 1. Preprocess: Create Canny Map
-    # We resize to 512x512 for standard SD1.5 inference
+    # Resize for SD1.5
     input_image = input_image.resize((512, 512))
     canny_image = get_canny_image(input_image)
 
-    # 2. Manage LoRA Adapters
-    # This is the key educational part: Dynamic Adapter Swapping
+    # Manage LoRA
     try:
-        pipe.unload_lora_weights() # Clear previous LoRAs
-
+        pipe.unload_lora_weights()
         lora_path = LORA_OPTIONS[lora_selection]
         if lora_path:
             print(f"Loading LoRA: {lora_path}")
-            # adapter_name is optional but good practice
             pipe.load_lora_weights(lora_path)
-
     except Exception as e:
         print(f"Error loading LoRA: {e}")
-        # Continue without LoRA if it fails
 
-    # 3. Generator for reproducibility
-    generator = torch.Generator("cuda").manual_seed(seed)
+    generator = torch.Generator("cuda").manual_seed(int(seed))
 
-    # 4. Inference
     print("Generating...")
     result = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
-        image=canny_image, # The ControlNet input
+        image=canny_image,
         num_inference_steps=steps,
         controlnet_conditioning_scale=float(controlnet_conditioning_scale),
         generator=generator,
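
Note: two details in the hunk above. manual_seed requires a Python int and Gradio number inputs can arrive as floats, which is what int(seed) guards against; but torch.Generator("cuda") still raises on CPU-only machines, even though the script computes a device fallback at startup. A device-safe sketch reusing that module-level device string:

    # A CPU generator also works with diffusers pipelines (just without CUDA);
    # reuse the detected device instead of hard-coding "cuda".
    generator = torch.Generator(device=device).manual_seed(int(seed))
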
@@ -128,7 +106,8 @@ css = """
 #col-container {max_width: 1200px; margin-left: auto; margin-right: auto;}
 """
 
-with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+# FIX: Moved theme and css to launch()
+with gr.Blocks() as demo:
 
     with gr.Column(elem_id="col-container"):
         gr.Markdown("# Tutorial: ControlNet + LoRA")
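
Note: a caveat on this hunk. gr.Blocks() accepts css and theme as constructor arguments (exactly as the removed line had them), while launch() is not documented to take either; if the app comes up unthemed after this change, the constructor form is the safe one:

    # Constructor-side styling, as in the pre-commit code.
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        ...
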
@@ -138,7 +117,6 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
         )
 
         with gr.Row():
-            # Left Column: Settings
             with gr.Column(scale=1):
                 input_image = gr.Image(label="Input Image (Structure Source)", type="pil")
 
@@ -148,7 +126,13 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     lines=2
                 )
 
-                # LoRA Selector
+                # FIX: Added a Negative Prompt component
+                negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value="blurry, low quality, distorted, ugly, bad anatomy",
+                    lines=1
+                )
+
                 lora_selection = gr.Dropdown(
                     label="Select LoRA Style",
                     choices=list(LORA_OPTIONS.keys()),
@@ -156,7 +140,6 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     info="LoRA changes the artistic style without changing the model architecture."
                 )
 
-                # ControlNet Settings
                 with gr.Accordion("Control & Inference Settings", open=True):
                     controlnet_conditioning_scale = gr.Slider(
                         label="ControlNet Strength",
@@ -168,25 +151,26 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
 
                 submit_btn = gr.Button("Generate", variant="primary")
 
-            # Right Column: Results
             with gr.Column(scale=1):
                 with gr.Row():
                     output_canny = gr.Image(label="Detected Edges (ControlNet Input)", type="pil")
                     output_result = gr.Image(label="Final Generated Image", type="pil")
 
+        # FIX: inputs list now contains only Gradio components
         submit_btn.click(
             fn=generate_controlled_image,
             inputs=[
-                input_image, prompt, "blurry, low quality, distorted",
-                lora_selection, controlnet_conditioning_scale, steps, seed
+                input_image,
+                prompt,
+                negative_prompt, # Passed the component variable here
+                lora_selection,
+                controlnet_conditioning_scale,
+                steps,
+                seed
             ],
             outputs=[output_canny, output_result]
         )
 
-# Examples are crucial for tutorials
-# Note: You would need to host a local image or use a URL for the example to work perfectly in Spaces
-# But here is the structure:
-# gr.Examples(...)
-
 if __name__ == "__main__":
-    demo.launch()
+    # FIX: Passed theme and css here
+    demo.launch(theme=gr.themes.Soft(), css=css)
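
Note: the underlying bug fixed here is that a Gradio inputs list may contain only components, so the raw string "blurry, low quality, distorted" was invalid. Exposing a Negative Prompt textbox, as this commit does, is the most flexible fix; if the value should stay fixed and hidden instead, a thin wrapper keeps the inputs list component-only. A sketch of that alternative:

    submit_btn.click(
        # Bind the constant negative prompt inside the callback itself.
        fn=lambda img, p, lora, scale, n_steps, sd: generate_controlled_image(
            img, p, "blurry, low quality, distorted", lora, scale, n_steps, sd
        ),
        inputs=[input_image, prompt, lora_selection,
                controlnet_conditioning_scale, steps, seed],
        outputs=[output_canny, output_result],
    )
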
 
 