primerz committed on
Commit
4236be3
·
verified ·
1 Parent(s): 256cde6

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +59 -44
generator.py CHANGED
@@ -97,50 +97,65 @@ class RetroArtConverter:
97
  print("===================\n")
98
 
99
  def get_depth_map(self, image):
100
- """Generate depth map using Zoe Depth"""
101
- if self.zoe_depth is not None:
102
- try:
103
- if image.mode != 'RGB':
104
- image = image.convert('RGB')
105
-
106
- # Use safe helpers for type safety
107
- orig_width, orig_height = safe_image_size(image)
108
-
109
- # FIXED: Use multiples of 64 (not 32)
110
- target_width = ensure_int((orig_width // 64) * 64)
111
- target_height = ensure_int((orig_height // 64) * 64)
112
-
113
- target_width = ensure_int(max(64, target_width))
114
- target_height = ensure_int(max(64, target_height))
115
-
116
- # Create an explicit tuple of standard ints
117
- size_for_depth = (target_width, target_height)
118
-
119
- # Always resize using the explicit int tuple
120
- image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
121
-
122
- # Generate depth map
123
- depth_image = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
124
-
125
- # Resize to match original if needed
126
- if (depth_image.width, depth_image.height) != (orig_width, orig_height):
127
- depth_image = depth_image.resize((orig_width, orig_height), Image.LANCZOS)
128
-
129
- # Convert to RGB if needed
130
- if depth_image.mode != 'RGB':
131
- depth_image = depth_image.convert('RGB')
132
-
133
- return depth_image, np.array(depth_image)
134
-
135
- except Exception as e:
136
- print(f"Depth map generation failed: {e}")
137
- import traceback
138
- traceback.print_exc()
139
- return None, None
140
- else:
141
- print(" Zoe Depth not available")
142
- return None, None
143
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  def generate_caption(self, image):
145
  """Generate caption for image using loaded caption model"""
146
  if not self.caption_enabled or self.caption_model is None:
 
97
  print("===================\n")
98
 
99
  def get_depth_map(self, image):
100
+ """
101
+ Generate depth map using available depth detector.
102
+ Supports: LeresDetector, ZoeDetector, or MidasDetector.
103
+ """
104
+ if self.depth_detector is not None:
105
+ try:
106
+ if image.mode != 'RGB':
107
+ image = image.convert('RGB')
108
+
109
+ orig_width, orig_height = image.size
110
+ orig_width = int(orig_width)
111
+ orig_height = int(orig_height)
112
+
113
+ target_width = int((orig_width // 64) * 64)
114
+ target_height = int((orig_height // 64) * 64)
115
+
116
+ target_width = int(max(64, target_width))
117
+ target_height = int(max(64, target_height))
118
+
119
+ size_for_depth = (int(target_width), int(target_height))
120
+ image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
121
+
122
+ if target_width != orig_width or target_height != orig_height:
123
+ print(f"[DEPTH] Resized for {self.depth_type.upper()}Detector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
124
+
125
+ # Use torch.no_grad() and clear cache
126
+ with torch.no_grad():
127
+ # --- FIX: Move model to GPU for inference and back to CPU ---
128
+ self.depth_detector.to(self.device)
129
+ depth_image = self.depth_detector(image_for_depth)
130
+ self.depth_detector.to("cpu")
131
+
132
+ # ADDED: Clear GPU cache after depth detection
133
+ if torch.cuda.is_available():
134
+ torch.cuda.empty_cache()
135
+
136
+ depth_width, depth_height = depth_image.size
137
+ if depth_width != orig_width or depth_height != orig_height:
138
+ depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
139
+
140
+ print(f"[DEPTH] {self.depth_type.upper()} depth map generated: {orig_width}x{orig_height}")
141
+ return depth_image
142
+
143
+ except Exception as e:
144
+ print(f"[DEPTH] {self.depth_type.upper()}Detector failed ({e}), falling back to grayscale depth")
145
+ # ADDED: Clear cache on error
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+
149
+ gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
150
+ depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
151
+ return Image.fromarray(depth_colored)
152
+ else:
153
+ print("[DEPTH] No depth detector available, using grayscale fallback")
154
+ gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
155
+ depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
156
+ return Image.fromarray(depth_colored)
157
+
158
+
159
  def generate_caption(self, image):
160
  """Generate caption for image using loaded caption model"""
161
  if not self.caption_enabled or self.caption_model is None: