Spaces:
Sleeping
Sleeping
Senum2001 commited on
Commit ·
de87a47
1
Parent(s): 7c16ab2
Fix image download to preserve format and add classification debugging
Browse files- inference_core.py +30 -3
inference_core.py
CHANGED
|
@@ -117,9 +117,13 @@ def classify_filtered_image(filtered_img_path: str):
|
|
| 117 |
if img is None:
|
| 118 |
raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 121 |
|
| 122 |
-
# Color masks
|
| 123 |
blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
|
| 124 |
black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
|
| 125 |
yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
|
|
@@ -135,6 +139,11 @@ def classify_filtered_image(filtered_img_path: str):
|
|
| 135 |
orange_count = np.sum(orange_mask > 0)
|
| 136 |
red_count = np.sum(red_mask > 0)
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
label = "Unknown"
|
| 139 |
box_list, label_list = [], []
|
| 140 |
|
|
@@ -157,7 +166,8 @@ def classify_filtered_image(filtered_img_path: str):
|
|
| 157 |
x, y, w, h = cv2.boundingRect(cnt)
|
| 158 |
box_list.append((x, y, w, h))
|
| 159 |
label_list.append("Point Overload (Faulty)")
|
| 160 |
-
|
|
|
|
| 161 |
return label, box_list, label_list, img
|
| 162 |
|
| 163 |
|
|
@@ -213,10 +223,27 @@ def download_image_from_url(url):
|
|
| 213 |
"""Download image from URL to temp file"""
|
| 214 |
import requests
|
| 215 |
import tempfile
|
|
|
|
|
|
|
|
|
|
| 216 |
response = requests.get(url, stream=True)
|
| 217 |
if response.status_code != 200:
|
| 218 |
raise Exception(f"Failed to download image from {url}")
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
for chunk in response.iter_content(1024):
|
| 221 |
tmp.write(chunk)
|
| 222 |
tmp.close()
|
|
|
|
| 117 |
if img is None:
|
| 118 |
raise FileNotFoundError(f"Could not read filtered image: {filtered_img_path}")
|
| 119 |
|
| 120 |
+
# Ensure consistent color space
|
| 121 |
+
if img.dtype != np.uint8:
|
| 122 |
+
img = img.astype(np.uint8)
|
| 123 |
+
|
| 124 |
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 125 |
|
| 126 |
+
# Color masks (with slight tolerance adjustments for consistency)
|
| 127 |
blue_mask = cv2.inRange(hsv, (90, 50, 20), (130, 255, 255))
|
| 128 |
black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 50))
|
| 129 |
yellow_mask = cv2.inRange(hsv, (20, 130, 130), (35, 255, 255))
|
|
|
|
| 139 |
orange_count = np.sum(orange_mask > 0)
|
| 140 |
red_count = np.sum(red_mask > 0)
|
| 141 |
|
| 142 |
+
# Debug logging
|
| 143 |
+
print(f"[Classification] Image shape: {img.shape}")
|
| 144 |
+
print(f"[Classification] Color counts - Blue: {blue_count}, Black: {black_count}, "
|
| 145 |
+
f"Yellow: {yellow_count}, Orange: {orange_count}, Red: {red_count}")
|
| 146 |
+
|
| 147 |
label = "Unknown"
|
| 148 |
box_list, label_list = [], []
|
| 149 |
|
|
|
|
| 166 |
x, y, w, h = cv2.boundingRect(cnt)
|
| 167 |
box_list.append((x, y, w, h))
|
| 168 |
label_list.append("Point Overload (Faulty)")
|
| 169 |
+
|
| 170 |
+
print(f"[Classification] Final label: {label}, Boxes found: {len(box_list)}")
|
| 171 |
return label, box_list, label_list, img
|
| 172 |
|
| 173 |
|
|
|
|
| 223 |
"""Download image from URL to temp file"""
|
| 224 |
import requests
|
| 225 |
import tempfile
|
| 226 |
+
from urllib.parse import urlparse
|
| 227 |
+
import mimetypes
|
| 228 |
+
|
| 229 |
response = requests.get(url, stream=True)
|
| 230 |
if response.status_code != 200:
|
| 231 |
raise Exception(f"Failed to download image from {url}")
|
| 232 |
+
|
| 233 |
+
# Determine file extension from URL or Content-Type
|
| 234 |
+
content_type = response.headers.get('content-type', '')
|
| 235 |
+
if 'image/png' in content_type:
|
| 236 |
+
suffix = '.png'
|
| 237 |
+
elif 'image/jpeg' in content_type or 'image/jpg' in content_type:
|
| 238 |
+
suffix = '.jpg'
|
| 239 |
+
else:
|
| 240 |
+
# Try to get extension from URL
|
| 241 |
+
parsed_url = urlparse(url)
|
| 242 |
+
path = parsed_url.path
|
| 243 |
+
ext = os.path.splitext(path)[1]
|
| 244 |
+
suffix = ext if ext in ['.jpg', '.jpeg', '.png', '.bmp'] else '.jpg'
|
| 245 |
+
|
| 246 |
+
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
| 247 |
for chunk in response.iter_content(1024):
|
| 248 |
tmp.write(chunk)
|
| 249 |
tmp.close()
|