Spaces:
Sleeping
Sleeping
slau8405 commited on
Commit ·
c495fed
1
Parent(s): 8d41fcd
Added few more features
Browse files
app.py
CHANGED
|
@@ -9,6 +9,7 @@ import supervision as sv
|
|
| 9 |
import uuid
|
| 10 |
import random
|
| 11 |
from pathlib import Path
|
|
|
|
| 12 |
|
| 13 |
class PolygonAugmentation:
|
| 14 |
def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
|
|
@@ -34,12 +35,10 @@ class PolygonAugmentation:
|
|
| 34 |
else:
|
| 35 |
data = json.load(json_file)
|
| 36 |
|
| 37 |
-
# Check for 'shapes' (LabelMe) or 'segments' (custom format)
|
| 38 |
shapes = []
|
| 39 |
if 'shapes' in data and isinstance(data['shapes'], list):
|
| 40 |
shapes = data['shapes']
|
| 41 |
elif 'segments' in data and isinstance(data['segments'], list):
|
| 42 |
-
# Convert custom 'segments' to LabelMe 'shapes' format
|
| 43 |
shapes = [
|
| 44 |
{
|
| 45 |
"label": seg.get("class", "unknown"),
|
|
@@ -52,7 +51,7 @@ class PolygonAugmentation:
|
|
| 52 |
for seg in data['segments']
|
| 53 |
]
|
| 54 |
else:
|
| 55 |
-
raise ValueError("Invalid JSON: Neither 'shapes' nor '
|
| 56 |
|
| 57 |
polygons = []
|
| 58 |
labels = []
|
|
@@ -156,7 +155,7 @@ class PolygonAugmentation:
|
|
| 156 |
"group_id": None,
|
| 157 |
"shape_type": "polygon",
|
| 158 |
"flags": {},
|
| 159 |
-
"confidence": 1.0
|
| 160 |
})
|
| 161 |
|
| 162 |
aug_data = {
|
|
@@ -343,7 +342,7 @@ class PolygonAugmentation:
|
|
| 343 |
contrast_limit=aug_param,
|
| 344 |
p=1.0
|
| 345 |
),
|
| 346 |
-
"pixel_dropout": A.PixelDropout(dropout_prob=aug_param, p=1.0)
|
| 347 |
}
|
| 348 |
|
| 349 |
if aug_type not in aug_dict:
|
|
@@ -372,6 +371,21 @@ class PolygonAugmentation:
|
|
| 372 |
|
| 373 |
def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
|
| 374 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
# Convert PIL image to NumPy
|
| 376 |
img_np = np.array(image)
|
| 377 |
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
|
@@ -387,13 +401,18 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
|
|
| 387 |
img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
|
| 388 |
)
|
| 389 |
|
| 390 |
-
# Create a color map for unique labels
|
| 391 |
unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
# Convert augmented image to RGB for visualization
|
| 399 |
aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
|
|
@@ -403,7 +422,7 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
|
|
| 403 |
height, width = aug_image.shape[:2]
|
| 404 |
for shape in aug_data['shapes']:
|
| 405 |
label = shape['label']
|
| 406 |
-
color = label_color_map
|
| 407 |
points = np.array(shape['points'], dtype=np.int32)
|
| 408 |
|
| 409 |
# Draw filled mask with transparency
|
|
@@ -411,7 +430,7 @@ def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param:
|
|
| 411 |
cv2.fillPoly(mask, [points], 1)
|
| 412 |
colored_mask = np.zeros_like(aug_image_rgb)
|
| 413 |
colored_mask[mask == 1] = color
|
| 414 |
-
alpha = 0.3
|
| 415 |
cv2.addWeighted(colored_mask, alpha, overlay, 1 - alpha, 0, overlay)
|
| 416 |
|
| 417 |
# Draw polygon outline
|
|
@@ -455,7 +474,8 @@ def create_interface():
|
|
| 455 |
minimum=aug_options["Rotate"]["range"][0],
|
| 456 |
maximum=aug_options["Rotate"]["range"][1],
|
| 457 |
value=aug_options["Rotate"]["default"],
|
| 458 |
-
label=aug_options["Rotate"]["param_name"]
|
|
|
|
| 459 |
)
|
| 460 |
|
| 461 |
def update_slider(aug_type):
|
|
@@ -464,7 +484,8 @@ def create_interface():
|
|
| 464 |
minimum=aug_options[aug_type]["range"][0],
|
| 465 |
maximum=aug_options[aug_type]["range"][1],
|
| 466 |
value=aug_options[aug_type]["default"],
|
| 467 |
-
label=aug_options[aug_type]["param_name"]
|
|
|
|
| 468 |
)
|
| 469 |
}
|
| 470 |
|
|
|
|
| 9 |
import uuid
|
| 10 |
import random
|
| 11 |
from pathlib import Path
|
| 12 |
+
import colorsys
|
| 13 |
|
| 14 |
class PolygonAugmentation:
|
| 15 |
def __init__(self, tolerance=0.2, area_threshold=0.01, debug=False):
|
|
|
|
| 35 |
else:
|
| 36 |
data = json.load(json_file)
|
| 37 |
|
|
|
|
| 38 |
shapes = []
|
| 39 |
if 'shapes' in data and isinstance(data['shapes'], list):
|
| 40 |
shapes = data['shapes']
|
| 41 |
elif 'segments' in data and isinstance(data['segments'], list):
|
|
|
|
| 42 |
shapes = [
|
| 43 |
{
|
| 44 |
"label": seg.get("class", "unknown"),
|
|
|
|
| 51 |
for seg in data['segments']
|
| 52 |
]
|
| 53 |
else:
|
| 54 |
+
raise ValueError("Invalid JSON: Neither 'shapes' nor 'segments' key found or not a list")
|
| 55 |
|
| 56 |
polygons = []
|
| 57 |
labels = []
|
|
|
|
| 155 |
"group_id": None,
|
| 156 |
"shape_type": "polygon",
|
| 157 |
"flags": {},
|
| 158 |
+
"confidence": 1.0
|
| 159 |
})
|
| 160 |
|
| 161 |
aug_data = {
|
|
|
|
| 342 |
contrast_limit=aug_param,
|
| 343 |
p=1.0
|
| 344 |
),
|
| 345 |
+
"pixel_dropout": A.PixelDropout(dropout_prob=min(max(aug_param, 0.0), 1.0), p=1.0)
|
| 346 |
}
|
| 347 |
|
| 348 |
if aug_type not in aug_dict:
|
|
|
|
| 371 |
|
| 372 |
def augment_image(image: Image.Image, json_file: Any, aug_type: str, aug_param: float):
|
| 373 |
try:
|
| 374 |
+
# Validate aug_param based on aug_type
|
| 375 |
+
aug_ranges = {
|
| 376 |
+
"Rotate": (-30, 30),
|
| 377 |
+
"Horizontal Flip": (0, 1),
|
| 378 |
+
"Vertical Flip": (0, 1),
|
| 379 |
+
"Scale": (0.5, 1.5),
|
| 380 |
+
"Brightness/Contrast": (-0.3, 0.3),
|
| 381 |
+
"Pixel Dropout": (0.01, 0.1)
|
| 382 |
+
}
|
| 383 |
+
if aug_type not in aug_ranges:
|
| 384 |
+
raise ValueError(f"Invalid augmentation type: {aug_type}")
|
| 385 |
+
min_val, max_val = aug_ranges[aug_type]
|
| 386 |
+
if not (min_val <= aug_param <= max_val):
|
| 387 |
+
raise ValueError(f"Parameter {aug_param} for {aug_type} is out of range [{min_val}, {max_val}]")
|
| 388 |
+
|
| 389 |
# Convert PIL image to NumPy
|
| 390 |
img_np = np.array(image)
|
| 391 |
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
|
|
|
| 401 |
img_np, polygons, labels, original_areas, original_data, aug_type, aug_param
|
| 402 |
)
|
| 403 |
|
| 404 |
+
# Create a dynamic color map for unique labels
|
| 405 |
unique_labels = list(set(shape['label'] for shape in aug_data['shapes']))
|
| 406 |
+
if not unique_labels:
|
| 407 |
+
label_color_map = {"unknown": (0, 255, 0)}
|
| 408 |
+
else:
|
| 409 |
+
num_labels = len(unique_labels)
|
| 410 |
+
hues = [i / num_labels for i in range(num_labels)]
|
| 411 |
+
label_color_map = {}
|
| 412 |
+
for label, hue in zip(unique_labels, hues):
|
| 413 |
+
rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
|
| 414 |
+
rgb = tuple(int(c * 255) for c in rgb)
|
| 415 |
+
label_color_map[label] = rgb
|
| 416 |
|
| 417 |
# Convert augmented image to RGB for visualization
|
| 418 |
aug_image_rgb = cv2.cvtColor(aug_image, cv2.COLOR_BGR2RGB)
|
|
|
|
| 422 |
height, width = aug_image.shape[:2]
|
| 423 |
for shape in aug_data['shapes']:
|
| 424 |
label = shape['label']
|
| 425 |
+
color = label_color_map.get(label, (0, 255, 0))
|
| 426 |
points = np.array(shape['points'], dtype=np.int32)
|
| 427 |
|
| 428 |
# Draw filled mask with transparency
|
|
|
|
| 430 |
cv2.fillPoly(mask, [points], 1)
|
| 431 |
colored_mask = np.zeros_like(aug_image_rgb)
|
| 432 |
colored_mask[mask == 1] = color
|
| 433 |
+
alpha = 0.3
|
| 434 |
cv2.addWeighted(colored_mask, alpha, overlay, 1 - alpha, 0, overlay)
|
| 435 |
|
| 436 |
# Draw polygon outline
|
|
|
|
| 474 |
minimum=aug_options["Rotate"]["range"][0],
|
| 475 |
maximum=aug_options["Rotate"]["range"][1],
|
| 476 |
value=aug_options["Rotate"]["default"],
|
| 477 |
+
label=aug_options["Rotate"]["param_name"],
|
| 478 |
+
step=0.01
|
| 479 |
)
|
| 480 |
|
| 481 |
def update_slider(aug_type):
|
|
|
|
| 484 |
minimum=aug_options[aug_type]["range"][0],
|
| 485 |
maximum=aug_options[aug_type]["range"][1],
|
| 486 |
value=aug_options[aug_type]["default"],
|
| 487 |
+
label=aug_options[aug_type]["param_name"],
|
| 488 |
+
step=0.01 if aug_type in ["Pixel Dropout", "Brightness/Contrast", "Scale"] else 1
|
| 489 |
)
|
| 490 |
}
|
| 491 |
|