primerz committed on
Commit
ae0aa20
·
verified ·
1 Parent(s): ce1d33e

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +54 -9
generator.py CHANGED
@@ -21,6 +21,7 @@ from models import (
21
  load_sdxl_pipeline, load_lora, setup_compel,
22
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
23
  )
 
24
 
25
 
26
  class RetroArtConverter:
@@ -36,10 +37,13 @@ class RetroArtConverter:
36
  'zoe_depth': False
37
  }
38
 
39
- # Load face analysis
 
 
 
40
  self.face_app, self.face_detection_enabled = load_face_analysis()
41
 
42
- # Load depth detector
43
  self.zoe_depth, zoe_success = load_depth_detector()
44
  self.models_loaded['zoe_depth'] = zoe_success
45
 
@@ -67,7 +71,7 @@ class RetroArtConverter:
67
  # Optimize
68
  optimize_pipeline(self.pipe)
69
 
70
- # Load caption model
71
  self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
72
 
73
  # Set CLIP skip
@@ -76,7 +80,10 @@ class RetroArtConverter:
76
  # Print status
77
  self._print_status()
78
 
79
- print(" [OK] RetroArtConverter initialized with InstantID!")
 
 
 
80
 
81
  def _print_status(self):
82
  """Print model loading status"""
@@ -89,7 +96,7 @@ class RetroArtConverter:
89
  print("===================\n")
90
 
91
  def get_depth_map(self, image):
92
- """Generate depth map using Zoe Depth"""
93
  if self.zoe_depth is not None:
94
  try:
95
  if image.mode != 'RGB':
@@ -108,16 +115,27 @@ class RetroArtConverter:
108
  size_for_depth = (int(target_width), int(target_height))
109
  image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
110
 
 
 
 
 
111
  depth_array = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
112
  depth_image = Image.fromarray(depth_array)
113
 
 
 
 
 
114
  if depth_image.size != image.size:
115
  depth_image = depth_image.resize(image.size, Image.LANCZOS)
116
 
117
- print(f"[DEPTH] Generated depth map: {depth_image.size}")
118
  return depth_image, depth_array
119
  except Exception as e:
120
  print(f"[DEPTH] Generation failed: {e}, using grayscale")
 
 
 
121
  return image.convert('L').convert('RGB'), None
122
  else:
123
  print("[DEPTH] Detector not available, using grayscale")
@@ -162,11 +180,14 @@ class RetroArtConverter:
162
  return None
163
 
164
  def generate_caption(self, image):
165
- """Generate caption for image"""
166
  if not self.caption_enabled or self.caption_model is None:
167
  return None
168
 
169
  try:
 
 
 
170
  if self.caption_model_type == 'git':
171
  inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
172
  generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
@@ -176,11 +197,19 @@ class RetroArtConverter:
176
  generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
177
  caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
178
  else:
 
179
  return None
180
 
 
 
 
 
181
  return sanitize_text(caption)
182
  except Exception as e:
183
  print(f"[CAPTION] Generation failed: {e}")
 
 
 
184
  return None
185
 
186
  def generate_retro_art(
@@ -389,10 +418,26 @@ class RetroArtConverter:
389
  return generated_image
390
 
391
  finally:
392
- # Memory cleanup
393
  if torch.cuda.is_available():
394
  torch.cuda.empty_cache()
395
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
 
398
  print("[OK] Generator class ready with InstantID support")
 
21
  load_sdxl_pipeline, load_lora, setup_compel,
22
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
23
  )
24
+ from memory_utils import MemoryManager, ModelOffloader
25
 
26
 
27
  class RetroArtConverter:
 
37
  'zoe_depth': False
38
  }
39
 
40
+ # Initialize memory manager
41
+ self.memory_manager = MemoryManager(device=device, dtype=dtype, verbose=True)
42
+
43
+ # Load face analysis (stays on CPU)
44
  self.face_app, self.face_detection_enabled = load_face_analysis()
45
 
46
+ # Load depth detector (starts on CPU)
47
  self.zoe_depth, zoe_success = load_depth_detector()
48
  self.models_loaded['zoe_depth'] = zoe_success
49
 
 
71
  # Optimize
72
  optimize_pipeline(self.pipe)
73
 
74
+ # Load caption model (starts on CPU)
75
  self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
76
 
77
  # Set CLIP skip
 
80
  # Print status
81
  self._print_status()
82
 
83
+ # Initial memory cleanup
84
+ self.memory_manager.cleanup_memory(aggressive=True)
85
+
86
+ print(" [OK] RetroArtConverter initialized with optimized memory management!")
87
 
88
  def _print_status(self):
89
  """Print model loading status"""
 
96
  print("===================\n")
97
 
98
  def get_depth_map(self, image):
99
+ """Generate depth map using Zoe Depth with optimized GPU usage"""
100
  if self.zoe_depth is not None:
101
  try:
102
  if image.mode != 'RGB':
 
115
  size_for_depth = (int(target_width), int(target_height))
116
  image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
117
 
118
+ # Move depth model to GPU temporarily
119
+ self.zoe_depth = self.zoe_depth.to(self.device)
120
+
121
+ # Generate depth map
122
  depth_array = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
123
  depth_image = Image.fromarray(depth_array)
124
 
125
+ # Move depth model back to CPU to free GPU memory
126
+ self.zoe_depth = self.zoe_depth.to("cpu")
127
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
128
+
129
  if depth_image.size != image.size:
130
  depth_image = depth_image.resize(image.size, Image.LANCZOS)
131
 
132
+ print(f"[DEPTH] Generated depth map: {depth_image.size} (model offloaded to CPU)")
133
  return depth_image, depth_array
134
  except Exception as e:
135
  print(f"[DEPTH] Generation failed: {e}, using grayscale")
136
+ # Ensure model is back on CPU even if error
137
+ if hasattr(self, 'zoe_depth') and self.zoe_depth is not None:
138
+ self.zoe_depth = self.zoe_depth.to("cpu")
139
  return image.convert('L').convert('RGB'), None
140
  else:
141
  print("[DEPTH] Detector not available, using grayscale")
 
180
  return None
181
 
182
  def generate_caption(self, image):
183
+ """Generate caption for image with optimized GPU usage"""
184
  if not self.caption_enabled or self.caption_model is None:
185
  return None
186
 
187
  try:
188
+ # Move caption model to GPU temporarily
189
+ self.caption_model = self.caption_model.to(self.device)
190
+
191
  if self.caption_model_type == 'git':
192
  inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
193
  generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
 
197
  generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
198
  caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
199
  else:
200
+ self.caption_model = self.caption_model.to("cpu") # Move back to CPU
201
  return None
202
 
203
+ # Move caption model back to CPU to free GPU memory
204
+ self.caption_model = self.caption_model.to("cpu")
205
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
206
+
207
  return sanitize_text(caption)
208
  except Exception as e:
209
  print(f"[CAPTION] Generation failed: {e}")
210
+ # Ensure model is back on CPU even if error
211
+ if hasattr(self, 'caption_model') and self.caption_model is not None:
212
+ self.caption_model = self.caption_model.to("cpu")
213
  return None
214
 
215
  def generate_retro_art(
 
418
  return generated_image
419
 
420
  finally:
421
+ # Aggressive memory cleanup
422
  if torch.cuda.is_available():
423
  torch.cuda.empty_cache()
424
+ torch.cuda.synchronize() # Ensure all GPU operations complete
425
+
426
+ # Force garbage collection multiple times for thorough cleanup
427
+ for _ in range(3):
428
+ gc.collect()
429
+
430
+ # Additional cleanup for large tensors
431
+ if 'pipe_kwargs' in locals():
432
+ for key in list(pipe_kwargs.keys()):
433
+ if isinstance(pipe_kwargs.get(key), torch.Tensor):
434
+ del pipe_kwargs[key]
435
+
436
+ # Log memory status if in debug mode
437
+ if torch.cuda.is_available():
438
+ allocated = torch.cuda.memory_allocated() / 1024**3
439
+ reserved = torch.cuda.memory_reserved() / 1024**3
440
+ print(f"[MEMORY] GPU: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
441
 
442
 
443
  print("[OK] Generator class ready with InstantID support")