File size: 2,874 Bytes
b602b83 fac07d4 6412f39 fac07d4 b602b83 6412f39 fac07d4 b602b83 ced6b45 29664ea fac07d4 ced6b45 87f1d97 b602b83 fac07d4 b602b83 29664ea fac07d4 29664ea fac07d4 29664ea fac07d4 87f1d97 6412f39 fac07d4 6412f39 fac07d4 b602b83 ced6b45 9d74de7 0e1cacf ced6b45 9d74de7 b602b83 fac07d4 87f1d97 fac07d4 87f1d97 fac07d4 b602b83 fac07d4 6412f39 b602b83 fac07d4 87f1d97 fac07d4 87f1d97 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | import os
import pandas as pd
from ultralytics import YOLO
import glob
# --- CONFIGURATION ---
TEST_IMAGE_PATH: str = "/tmp/data/test_images"  # folder holding the images to predict on
SUBMISSION_SAVE_PATH: str = "submission.csv"  # output CSV (file_name, bbox, category_id)
MODEL_WEIGHTS: str = "best.pt"  # weights file loaded with YOLO(...)

# Confidence threshold handed to model.predict: 0.10 keeps weaker tool
# detections while keeping precision reasonable.
CONF_THRESHOLD: float = 0.10
def get_category_id(cls_id):
    """Translate a predicted class index into the category id written to the CSV.

    Class indices 0, 1 and 2 already line up with the expected category ids,
    so this is just the identity mapping after coercion to ``int``. It is kept
    as a separate hook so a remapping can be added in one place.
    """
    category_id = int(cls_id)
    return category_id
def _list_test_images(image_path, extensions):
    """Return the sorted basenames of image files directly inside ``image_path``.

    The extension match is case-insensitive, so ``IMG_01.JPG`` is found as well
    as ``img_02.jpg``. A plain ``glob("*.jpg")`` misses the former on
    case-sensitive filesystems, and merely adding upper-case patterns would
    double-count files on case-insensitive ones. Names starting with ``.`` are
    skipped, matching ``glob``'s default behaviour. Prints a warning and
    returns an empty list when ``image_path`` is not a directory.
    """
    if not os.path.isdir(image_path):
        print(f"⚠️ Warning: {image_path} not found.")
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        name
        for name in os.listdir(image_path)
        if not name.startswith(".")
        and name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(image_path, name))
    )


def _predict_boxes(model, full_path, conf_threshold, imgsz, augment, iou, max_det):
    """Run ``model.predict`` on one image and return ``(bbox_list, category_list)``.

    Each bbox is ``[x_min, y_min, width, height]``, converted from the
    centre-based ``box.xywh`` reported by the model. Category ids come from
    ``get_category_id``.
    """
    results = model.predict(
        full_path,
        conf=conf_threshold,
        imgsz=imgsz,
        augment=augment,  # test-time augmentation
        iou=iou,  # IoU threshold used by NMS to suppress overlapping boxes
        max_det=max_det,
        verbose=False,
    )
    bbox_list = []
    category_list = []
    for result in results:
        for box in result.boxes:
            # YOLO centre-based (x_c, y_c, w, h) -> COCO top-left (x_min, y_min, w, h).
            x_c, y_c, w, h = box.xywh[0].tolist()
            bbox_list.append([x_c - (w / 2), y_c - (h / 2), w, h])
            category_list.append(get_category_id(int(box.cls[0])))
    return bbox_list, category_list


def run_inference(
    model,
    image_path,
    conf_threshold,
    save_path,
    *,
    imgsz=1024,
    augment=True,
    iou=0.60,
    max_det=30,
    extensions=(".png", ".jpg", ".jpeg", ".bmp"),
):
    """Predict boxes for every image in ``image_path`` and write a submission CSV.

    Args:
        model: Object with an Ultralytics-style ``predict`` method (e.g. ``YOLO``).
        image_path: Directory scanned non-recursively for image files.
        conf_threshold: Minimum confidence forwarded to ``model.predict``.
        save_path: Destination path of the CSV file.
        imgsz, augment, iou, max_det: Forwarded to ``model.predict``. The
            defaults reproduce the original hard-coded settings.
        extensions: File suffixes, matched case-insensitively, treated as images.

    The CSV has columns ``file_name``, ``bbox`` and ``category_id``, with one
    row per image, including images with no detections (empty lists). ``bbox``
    and ``category_id`` hold the ``str()`` of Python lists. If no images are
    found, a header-only CSV is still written.
    """
    print(f"🚀 Checking for images in {image_path}...")
    test_images = _list_test_images(image_path, extensions)
    print(f"🔍 Found {len(test_images)} images.")

    df_rows = []
    for image_name in test_images:
        bbox_list, category_list = _predict_boxes(
            model,
            os.path.join(image_path, image_name),
            conf_threshold,
            imgsz,
            augment,
            iou,
            max_det,
        )
        df_rows.append({
            "file_name": image_name,
            "bbox": str(bbox_list),
            "category_id": str(category_list),
        })

    # Passing ``columns`` keeps the header even when df_rows is empty.
    df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])
    df_predictions.to_csv(save_path, index=False)
    print("✅ Done!")
if __name__ == "__main__":
    import sys
    import traceback

    try:
        print("🔥 Loading Model...")
        model = YOLO(MODEL_WEIGHTS)
        run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
    except Exception as e:
        # Keep the short message, but also show where it failed, and exit
        # non-zero so a wrapping job does not treat a failed run (possibly
        # with no submission written) as a success.
        print(f"❌ Error: {e}")
        traceback.print_exc()
        sys.exit(1)