Senum2001 commited on
Commit
de87a47
·
1 Parent(s): 7c16ab2

Fix image download to preserve format and add classification debugging

Browse files
Files changed (1) hide show
  1. inference_core.py +30 -3
inference_core.py CHANGED
@@ -117,9 +117,13 @@ def classify_filtered_image(filtered_img_path: str):
117
  if img is None:
118
  raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
119
 
 
 
 
 
120
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
121
 
122
- # Color masks
123
  blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
124
  black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
125
  yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
@@ -135,6 +139,11 @@ def classify_filtered_image(filtered_img_path: str):
135
  orange_count = np.sum(orange_mask > 0)
136
  red_count = np.sum(red_mask > 0)
137
 
 
 
 
 
 
138
  label = "Unknown"
139
  box_list, label_list = [], []
140
 
@@ -157,7 +166,8 @@ def classify_filtered_image(filtered_img_path: str):
157
  x, y, w, h = cv2.boundingRect(cnt)
158
  box_list.append((x, y, w, h))
159
  label_list.append("Point Overload (Faulty)")
160
-
 
161
  return label, box_list, label_list, img
162
 
163
 
@@ -213,10 +223,27 @@ def download_image_from_url(url):
213
  """Download image from URL to temp file"""
214
  import requests
215
  import tempfile
 
 
 
216
  response = requests.get(url, stream=True)
217
  if response.status_code != 200:
218
  raise Exception(f"Failed to download image from {url}")
219
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  for chunk in response.iter_content(1024):
221
  tmp.write(chunk)
222
  tmp.close()
 
117
  if img is None:
118
  raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
119
 
120
+ # Ensure consistent color space
121
+ if img.dtype != np.uint8:
122
+ img = img.astype(np.uint8)
123
+
124
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
125
 
126
+ # Color masks (with slight tolerance adjustments for consistency)
127
  blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
128
  black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
129
  yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
 
139
  orange_count = np.sum(orange_mask > 0)
140
  red_count = np.sum(red_mask > 0)
141
 
142
+ # Debug logging
143
+ print(f"[Classification] Image shape: {img.shape}")
144
+ print(f"[Classification] Color counts - Blue: {blue_count}, Black: {black_count}, "
145
+ f"Yellow: {yellow_count}, Orange: {orange_count}, Red: {red_count}")
146
+
147
  label = "Unknown"
148
  box_list, label_list = [], []
149
 
 
166
  x, y, w, h = cv2.boundingRect(cnt)
167
  box_list.append((x, y, w, h))
168
  label_list.append("Point Overload (Faulty)")
169
+
170
+ print(f"[Classification] Final label: {label}, Boxes found: {len(box_list)}")
171
  return label, box_list, label_list, img
172
 
173
 
 
223
  """Download image from URL to temp file"""
224
  import requests
225
  import tempfile
226
+ from urllib.parse import urlparse
227
+ import mimetypes
228
+
229
  response = requests.get(url, stream=True)
230
  if response.status_code != 200:
231
  raise Exception(f"Failed to download image from {url}")
232
+
233
+ # Determine file extension from URL or Content-Type
234
+ content_type = response.headers.get('content-type', '')
235
+ if 'image/png' in content_type:
236
+ suffix = '.png'
237
+ elif 'image/jpeg' in content_type or 'image/jpg' in content_type:
238
+ suffix = '.jpg'
239
+ else:
240
+ # Try to get extension from URL
241
+ parsed_url = urlparse(url)
242
+ path = parsed_url.path
243
+ ext = os.path.splitext(path)[1]
244
+ suffix = ext if ext in ['.jpg', '.jpeg', '.png', '.bmp'] else '.jpg'
245
+
246
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
247
  for chunk in response.iter_content(1024):
248
  tmp.write(chunk)
249
  tmp.close()