oliveryanzuolu committed on
Commit
da2d97d
·
verified ·
1 Parent(s): 18016af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -30
app.py CHANGED
@@ -7,9 +7,6 @@ from PIL import Image
7
 
8
  from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
9
 
10
- # -----------------------------------------------------------------------------
11
- # Configuration & Registry
12
- # -----------------------------------------------------------------------------
13
  LORA_REGISTRY = {
14
  "None (Base SDXL)": {
15
  "repo": None,
@@ -36,19 +33,18 @@ LORA_REGISTRY = {
36
  }
37
  }
38
 
39
- # -----------------------------------------------------------------------------
40
- # Model Initialization
41
- # -----------------------------------------------------------------------------
42
- print("Initializing SDXL Pipeline on CPU...")
43
 
44
  vae = AutoencoderKL.from_pretrained(
45
  "madebyollin/sdxl-vae-fp16-fix",
46
- torch_dtype=torch.float16
47
  )
48
 
49
  controlnet = ControlNetModel.from_pretrained(
50
  "diffusers/controlnet-canny-sdxl-1.0",
51
- torch_dtype=torch.float16,
52
  use_safetensors=True
53
  )
54
 
@@ -56,17 +52,15 @@ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
56
  "stabilityai/stable-diffusion-xl-base-1.0",
57
  controlnet=controlnet,
58
  vae=vae,
59
- torch_dtype=torch.float16,
60
  use_safetensors=True
61
  )
62
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
64
 
65
- print("Pipeline loaded successfully.")
66
 
67
- # -----------------------------------------------------------------------------
68
- # Helper Functions
69
- # -----------------------------------------------------------------------------
70
 
71
  def get_canny_image(image, low_threshold=100, high_threshold=200):
72
  image_array = np.array(image)
@@ -75,10 +69,6 @@ def get_canny_image(image, low_threshold=100, high_threshold=200):
75
  canny_edges = np.concatenate([canny_edges, canny_edges, canny_edges], axis=2)
76
  return Image.fromarray(canny_edges)
77
 
78
- # -----------------------------------------------------------------------------
79
- # Inference Logic
80
- # -----------------------------------------------------------------------------
81
-
82
  @spaces.GPU(duration=120)
83
  def generate_controlled_image(
84
  input_image,
@@ -91,10 +81,6 @@ def generate_controlled_image(
91
  ):
92
  if input_image is None:
93
  raise gr.Error("Please upload an image first!")
94
-
95
- device = "cuda" if torch.cuda.is_available() else "cpu"
96
-
97
- pipe.to(device)
98
 
99
  width, height = 1024, 1024
100
  input_image = input_image.resize((width, height))
@@ -121,7 +107,7 @@ def generate_controlled_image(
121
  print(f"LoRA Load Error: {e}")
122
  gr.Warning(f"Failed to load LoRA. Using base model.")
123
 
124
- generator = torch.Generator(device).manual_seed(int(seed))
125
 
126
  print(f"Generating: {final_prompt[:100]}...")
127
 
@@ -141,16 +127,10 @@ def generate_controlled_image(
141
  raise e
142
 
143
  pipe.unload_lora_weights()
144
-
145
- if device == "cuda":
146
- torch.cuda.empty_cache()
147
 
148
  return canny_image, output_image
149
 
150
- # -----------------------------------------------------------------------------
151
- # Gradio UI
152
- # -----------------------------------------------------------------------------
153
-
154
  css = """
155
  #col-container {max-width: 1200px; margin-left: auto; margin-right: auto;}
156
  .guide-text {font-size: 1.1em; color: #4a5568;}
 
7
 
8
  from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
9
 
 
 
 
10
  LORA_REGISTRY = {
11
  "None (Base SDXL)": {
12
  "repo": None,
 
33
  }
34
  }
35
 
36
+ print("Loading SDXL Pipeline...")
37
+
38
+ dtype = torch.float16
 
39
 
40
  vae = AutoencoderKL.from_pretrained(
41
  "madebyollin/sdxl-vae-fp16-fix",
42
+ torch_dtype=dtype
43
  )
44
 
45
  controlnet = ControlNetModel.from_pretrained(
46
  "diffusers/controlnet-canny-sdxl-1.0",
47
+ torch_dtype=dtype,
48
  use_safetensors=True
49
  )
50
 
 
52
  "stabilityai/stable-diffusion-xl-base-1.0",
53
  controlnet=controlnet,
54
  vae=vae,
55
+ torch_dtype=dtype,
56
  use_safetensors=True
57
  )
58
 
59
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
60
 
61
+ pipe.enable_model_cpu_offload()
62
 
63
+ print("Pipeline loaded successfully.")
 
 
64
 
65
  def get_canny_image(image, low_threshold=100, high_threshold=200):
66
  image_array = np.array(image)
 
69
  canny_edges = np.concatenate([canny_edges, canny_edges, canny_edges], axis=2)
70
  return Image.fromarray(canny_edges)
71
 
 
 
 
 
72
  @spaces.GPU(duration=120)
73
  def generate_controlled_image(
74
  input_image,
 
81
  ):
82
  if input_image is None:
83
  raise gr.Error("Please upload an image first!")
 
 
 
 
84
 
85
  width, height = 1024, 1024
86
  input_image = input_image.resize((width, height))
 
107
  print(f"LoRA Load Error: {e}")
108
  gr.Warning(f"Failed to load LoRA. Using base model.")
109
 
110
+ generator = torch.Generator("cuda").manual_seed(int(seed))
111
 
112
  print(f"Generating: {final_prompt[:100]}...")
113
 
 
127
  raise e
128
 
129
  pipe.unload_lora_weights()
130
+ torch.cuda.empty_cache()
 
 
131
 
132
  return canny_image, output_image
133
 
 
 
 
 
134
  css = """
135
  #col-container {max-width: 1200px; margin-left: auto; margin-right: auto;}
136
  .guide-text {font-size: 1.1em; color: #4a5568;}