Astridkraft committed on
Commit
602e4ba
·
verified ·
1 Parent(s): f893fde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -12
app.py CHANGED
@@ -7,6 +7,8 @@ import time
7
  import os
8
  import tempfile
9
  import random
 
 
10
 
11
  # === OPTIMIERTE EINSTELLUNGEN ===
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,14 +22,17 @@ current_txt2img_task_id = None
20
  current_img2img_task_id = None
21
  cancel_txt2img = False
22
  cancel_img2img = False
 
23
 
24
def interrupt_callback(pipe, step, timestep, callback_kwargs):
    """Diffusers step-end callback: abort the current run if its cancel flag is set.

    Args:
        pipe: The pipeline instance invoking the callback (txt2img or img2img).
        step: Current denoising step index (used for logging only).
        timestep: Scheduler timestep — unused, but required by the callback API.
        callback_kwargs: Dict of tensors the pipeline threads through callbacks.

    Returns:
        ``callback_kwargs`` unchanged when no interrupt is requested.

    Raises:
        RuntimeError: When the cancel flag matching this pipeline is set,
            which aborts the pipeline's denoising loop.
    """
    global cancel_txt2img, cancel_img2img
    # Use identity (`is`), not equality (`==`): we only need to know *which*
    # pipeline instance called us, and `==` on arbitrary objects may invoke
    # custom __eq__ semantics.
    if (pipe is pipe_txt2img and cancel_txt2img) or (pipe is pipe_img2img and cancel_img2img):
        print(f"🛑 Interrupt ausgelöst für Pipeline (Step: {step})")
        raise RuntimeError("Task interrupted by user")
    return callback_kwargs
 
 
31
 
32
  # === GESICHTSMASKEN-FUNKTIONEN ===
33
  def create_face_mask(image, bbox_coords):
@@ -130,13 +135,18 @@ def text_to_image(prompt, steps, guidance_scale):
130
  if current_txt2img_task_id is not None:
131
  print(f"🛑 Cancel vorheriger Text-to-Image Task (ID: {current_txt2img_task_id})")
132
  cancel_txt2img = True
133
- time.sleep(1.0) # Längere Pause für CPU
134
 
135
  current_txt2img_task_id = task_id
136
  cancel_txt2img = False
137
  print(f"Starting generation for: {prompt} (Task ID: {task_id})")
138
  start_time = time.time()
139
 
 
 
 
 
 
140
  pipe = load_txt2img()
141
 
142
  # ZUFÄLLIGER SEED für Variation
@@ -144,14 +154,21 @@ def text_to_image(prompt, steps, guidance_scale):
144
  generator = torch.Generator(device=device).manual_seed(seed)
145
  print(f"🎲 Using seed: {seed}")
146
 
 
 
 
 
 
 
 
 
147
  image = pipe(
148
  prompt=prompt,
149
  height=IMG_SIZE,
150
  width=IMG_SIZE,
151
  num_inference_steps=steps,
152
  guidance_scale=guidance_scale,
153
- generator=generator,
154
- callback_on_step_end=interrupt_callback
155
  ).images[0]
156
 
157
  end_time = time.time()
@@ -188,7 +205,7 @@ def img_to_image(image, prompt, neg_prompt, strength, steps, guidance_scale, fac
188
  if current_img2img_task_id is not None:
189
  print(f"🛑 Cancel vorheriger Image-to-Image Task (ID: {current_img2img_task_id})")
190
  cancel_img2img = True
191
- time.sleep(1.0) # Längere Pause für CPU
192
 
193
  current_img2img_task_id = task_id
194
  cancel_img2img = False
@@ -197,6 +214,11 @@ def img_to_image(image, prompt, neg_prompt, strength, steps, guidance_scale, fac
197
  print(f"Gesicht beibehalten: {face_preserve}")
198
  start_time = time.time()
199
 
 
 
 
 
 
200
  pipe = load_img2img()
201
  img_resized = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
202
 
@@ -238,6 +260,14 @@ def img_to_image(image, prompt, neg_prompt, strength, steps, guidance_scale, fac
238
  mask = None
239
 
240
  # --- PIPELINE-AUFRUF ---
 
 
 
 
 
 
 
 
241
  result = pipe(
242
  prompt=prompt,
243
  negative_prompt=neg_prompt,
@@ -246,8 +276,7 @@ def img_to_image(image, prompt, neg_prompt, strength, steps, guidance_scale, fac
246
  strength=adj_strength,
247
  num_inference_steps=int(steps),
248
  guidance_scale=adj_guidance,
249
- generator=generator,
250
- callback_on_step_end=interrupt_callback
251
  )
252
 
253
  end_time = time.time()
 
7
  import os
8
  import tempfile
9
  import random
10
+ import threading
11
+ import queue
12
 
13
  # === OPTIMIERTE EINSTELLUNGEN ===
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
22
  current_img2img_task_id = None
23
  cancel_txt2img = False
24
  cancel_img2img = False
25
+ interrupt_queue = queue.Queue()
26
 
27
def interrupt_monitor(task_id, is_txt2img):
    """Watch the module-level cancel flags from a background thread.

    Polls every 0.1 s until the flag selected by *is_txt2img* becomes true,
    then pushes an ``("interrupt", task_id)`` message onto ``interrupt_queue``
    and returns.

    NOTE(review): if the task finishes without ever being cancelled, this
    loop never exits — callers start it as a daemon thread, which masks but
    does not remove the leak. Presumably acceptable for this app; verify.

    Args:
        task_id: Identifier of the task being monitored (forwarded in the
            queue message so consumers can filter stale interrupts).
        is_txt2img: True to watch ``cancel_txt2img``, False for ``cancel_img2img``.
    """
    global cancel_txt2img, cancel_img2img

    def flag_set():
        # Re-read the module globals each poll; they are mutated elsewhere.
        return cancel_txt2img if is_txt2img else cancel_img2img

    while not flag_set():
        time.sleep(0.1)  # polling interval

    print(f"🛑 Interrupt ausgelöst für Task ID: {task_id}")
    interrupt_queue.put(("interrupt", task_id))
36
 
37
  # === GESICHTSMASKEN-FUNKTIONEN ===
38
  def create_face_mask(image, bbox_coords):
 
135
  if current_txt2img_task_id is not None:
136
  print(f"🛑 Cancel vorheriger Text-to-Image Task (ID: {current_txt2img_task_id})")
137
  cancel_txt2img = True
138
+ time.sleep(1.0) # Wartezeit für CPU
139
 
140
  current_txt2img_task_id = task_id
141
  cancel_txt2img = False
142
  print(f"Starting generation for: {prompt} (Task ID: {task_id})")
143
  start_time = time.time()
144
 
145
+ # Starte Interrupt-Monitor in einem separaten Thread
146
+ monitor_thread = threading.Thread(target=interrupt_monitor, args=(task_id, True))
147
+ monitor_thread.daemon = True
148
+ monitor_thread.start()
149
+
150
  pipe = load_txt2img()
151
 
152
  # ZUFÄLLIGER SEED für Variation
 
154
  generator = torch.Generator(device=device).manual_seed(seed)
155
  print(f"🎲 Using seed: {seed}")
156
 
157
+ # Prüfe Interrupt-Queue während der Generierung
158
+ for step in range(steps):
159
+ if not interrupt_queue.empty():
160
+ action, interrupted_task_id = interrupt_queue.get()
161
+ if action == "interrupt" and interrupted_task_id == task_id:
162
+ raise RuntimeError("Task interrupted by user")
163
+ time.sleep(0.01) # Kurze Pause für Queue-Prüfung
164
+
165
  image = pipe(
166
  prompt=prompt,
167
  height=IMG_SIZE,
168
  width=IMG_SIZE,
169
  num_inference_steps=steps,
170
  guidance_scale=guidance_scale,
171
+ generator=generator
 
172
  ).images[0]
173
 
174
  end_time = time.time()
 
205
  if current_img2img_task_id is not None:
206
  print(f"🛑 Cancel vorheriger Image-to-Image Task (ID: {current_img2img_task_id})")
207
  cancel_img2img = True
208
+ time.sleep(1.0) # Wartezeit für CPU
209
 
210
  current_img2img_task_id = task_id
211
  cancel_img2img = False
 
214
  print(f"Gesicht beibehalten: {face_preserve}")
215
  start_time = time.time()
216
 
217
+ # Starte Interrupt-Monitor in einem separaten Thread
218
+ monitor_thread = threading.Thread(target=interrupt_monitor, args=(task_id, False))
219
+ monitor_thread.daemon = True
220
+ monitor_thread.start()
221
+
222
  pipe = load_img2img()
223
  img_resized = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
224
 
 
260
  mask = None
261
 
262
  # --- PIPELINE-AUFRUF ---
263
+ # Prüfe Interrupt-Queue während der Generierung
264
+ for step in range(int(steps)):
265
+ if not interrupt_queue.empty():
266
+ action, interrupted_task_id = interrupt_queue.get()
267
+ if action == "interrupt" and interrupted_task_id == task_id:
268
+ raise RuntimeError("Task interrupted by user")
269
+ time.sleep(0.01) # Kurze Pause für Queue-Prüfung
270
+
271
  result = pipe(
272
  prompt=prompt,
273
  negative_prompt=neg_prompt,
 
276
  strength=adj_strength,
277
  num_inference_steps=int(steps),
278
  guidance_scale=adj_guidance,
279
+ generator=generator
 
280
  )
281
 
282
  end_time = time.time()