oliveryanzuolu committed on
Commit
18016af
·
verified ·
1 Parent(s): 6580922

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -37,7 +37,7 @@ LORA_REGISTRY = {
37
  }
38
 
39
  # -----------------------------------------------------------------------------
40
- # Model Initialization (CPU only, ZeroGPU handles device transfer)
41
  # -----------------------------------------------------------------------------
42
  print("Initializing SDXL Pipeline on CPU...")
43
 
@@ -62,7 +62,7 @@ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
62
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
64
 
65
- print("Pipeline loaded. ZeroGPU will handle device management.")
66
 
67
  # -----------------------------------------------------------------------------
68
  # Helper Functions
@@ -92,6 +92,10 @@ def generate_controlled_image(
92
  if input_image is None:
93
  raise gr.Error("Please upload an image first!")
94
 
 
 
 
 
95
  width, height = 1024, 1024
96
  input_image = input_image.resize((width, height))
97
  canny_image = get_canny_image(input_image)
@@ -101,7 +105,6 @@ def generate_controlled_image(
101
  style_config = LORA_REGISTRY[lora_selection]
102
  repo_id = style_config["repo"]
103
  trigger_text = style_config["trigger"]
104
- lora_weight = style_config["weight"]
105
  lora_file = style_config.get("file", None)
106
 
107
  final_prompt = f"{trigger_text}{prompt}"
@@ -118,7 +121,7 @@ def generate_controlled_image(
118
  print(f"LoRA Load Error: {e}")
119
  gr.Warning(f"Failed to load LoRA. Using base model.")
120
 
121
- generator = torch.Generator("cuda").manual_seed(int(seed))
122
 
123
  print(f"Generating: {final_prompt[:100]}...")
124
 
@@ -138,7 +141,9 @@ def generate_controlled_image(
138
  raise e
139
 
140
  pipe.unload_lora_weights()
141
- torch.cuda.empty_cache()
 
 
142
 
143
  return canny_image, output_image
144
 
 
37
  }
38
 
39
  # -----------------------------------------------------------------------------
40
+ # Model Initialization
41
  # -----------------------------------------------------------------------------
42
  print("Initializing SDXL Pipeline on CPU...")
43
 
 
62
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
64
 
65
+ print("Pipeline loaded successfully.")
66
 
67
  # -----------------------------------------------------------------------------
68
  # Helper Functions
 
92
  if input_image is None:
93
  raise gr.Error("Please upload an image first!")
94
 
95
+ device = "cuda" if torch.cuda.is_available() else "cpu"
96
+
97
+ pipe.to(device)
98
+
99
  width, height = 1024, 1024
100
  input_image = input_image.resize((width, height))
101
  canny_image = get_canny_image(input_image)
 
105
  style_config = LORA_REGISTRY[lora_selection]
106
  repo_id = style_config["repo"]
107
  trigger_text = style_config["trigger"]
 
108
  lora_file = style_config.get("file", None)
109
 
110
  final_prompt = f"{trigger_text}{prompt}"
 
121
  print(f"LoRA Load Error: {e}")
122
  gr.Warning(f"Failed to load LoRA. Using base model.")
123
 
124
+ generator = torch.Generator(device).manual_seed(int(seed))
125
 
126
  print(f"Generating: {final_prompt[:100]}...")
127
 
 
141
  raise e
142
 
143
  pipe.unload_lora_weights()
144
+
145
+ if device == "cuda":
146
+ torch.cuda.empty_cache()
147
 
148
  return canny_image, output_image
149