Update script.py
Browse files
script.py
CHANGED
|
@@ -8,13 +8,12 @@ TEST_IMAGE_PATH = "/tmp/data/test_images"
|
|
| 8 |
SUBMISSION_SAVE_PATH = "submission.csv"
|
| 9 |
MODEL_WEIGHTS = "best.pt"
|
| 10 |
|
| 11 |
-
# ✅ 1.
|
| 12 |
-
# 0.
|
| 13 |
-
|
| 14 |
-
CONF_THRESHOLD = 0.05
|
| 15 |
|
| 16 |
def get_category_id(cls_id):
|
| 17 |
-
# IDs 0, 1, 2 are correct
|
| 18 |
return int(cls_id)
|
| 19 |
|
| 20 |
def run_inference(model, image_path, conf_threshold, save_path):
|
|
@@ -37,14 +36,14 @@ def run_inference(model, image_path, conf_threshold, save_path):
|
|
| 37 |
for image_name in test_images:
|
| 38 |
full_path = os.path.join(image_path, image_name)
|
| 39 |
|
| 40 |
-
# 🚀
|
| 41 |
results = model.predict(
|
| 42 |
full_path,
|
| 43 |
conf=conf_threshold,
|
| 44 |
imgsz=1024, # Keep High Res
|
| 45 |
-
augment=
|
| 46 |
-
iou=0.60, #
|
| 47 |
-
max_det=
|
| 48 |
verbose=False
|
| 49 |
)
|
| 50 |
|
|
|
|
| 8 |
SUBMISSION_SAVE_PATH = "submission.csv"
|
| 9 |
MODEL_WEIGHTS = "best.pt"
|
| 10 |
|
| 11 |
+
# ✅ 1. Balanced Confidence
|
| 12 |
+
# 0.10 captures weaker tools but keeps precision reasonable.
|
| 13 |
+
CONF_THRESHOLD = 0.10
|
|
|
|
| 14 |
|
| 15 |
def get_category_id(cls_id):
|
| 16 |
+
# IDs 0, 1, 2 are correct.
|
| 17 |
return int(cls_id)
|
| 18 |
|
| 19 |
def run_inference(model, image_path, conf_threshold, save_path):
|
|
|
|
| 36 |
for image_name in test_images:
|
| 37 |
full_path = os.path.join(image_path, image_name)
|
| 38 |
|
| 39 |
+
# 🚀 BALANCED TTA SETTINGS
|
| 40 |
results = model.predict(
|
| 41 |
full_path,
|
| 42 |
conf=conf_threshold,
|
| 43 |
imgsz=1024, # Keep High Res
|
| 44 |
+
augment=True, # ⬅️ ENABLED: Test Time Augmentation (Boosts accuracy)
|
| 45 |
+
iou=0.60, # Merge overlapping boxes
|
| 46 |
+
max_det=30, # ⬅️ INCREASED: 5 was too strict, 30 is safe
|
| 47 |
verbose=False
|
| 48 |
)
|
| 49 |
|