"""Run a YOLO detector over a folder of test images and write a submission CSV.

Each image produces one CSV row holding its file name, the list of detected
boxes in COCO ``[x_min, y_min, width, height]`` form, and the matching list of
category ids. The two lists are stored as their ``str()`` representation.
"""

import glob
import os

import pandas as pd
from ultralytics import YOLO

# --- CONFIGURATION ---
TEST_IMAGE_PATH = "/tmp/data/test_images"
SUBMISSION_SAVE_PATH = "submission.csv"
MODEL_WEIGHTS = "best.pt"

# Balanced confidence: 0.10 captures weaker tools but keeps precision
# reasonable.
CONF_THRESHOLD = 0.10

# File suffixes treated as images. Compared case-insensitively so that
# files such as "IMG_01.JPG" are not silently left out of the submission.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")

# Column order of the submission file.
SUBMISSION_COLUMNS = ["file_name", "bbox", "category_id"]


def get_category_id(cls_id):
    """Map a model class index to the submission category id.

    Per the original author's note, class ids 0, 1 and 2 are already the
    correct category ids, so this is a plain integer conversion.
    """
    return int(cls_id)


def _list_test_images(image_path):
    """Return the sorted base names of the image files directly in image_path.

    Prints a warning and returns an empty list when the path does not exist.
    """
    if not os.path.exists(image_path):
        print(f"⚠️ Warning: {image_path} not found.")
        return []

    # Escape the directory so characters like "[" in its name are taken
    # literally; "*" still skips dot-files, as the per-extension globs did.
    pattern = os.path.join(glob.escape(os.fspath(image_path)), "*")
    names = [
        os.path.basename(path)
        for path in glob.glob(pattern)
        if os.path.isfile(path) and path.lower().endswith(IMAGE_EXTENSIONS)
    ]
    names.sort()
    return names


def _predict_image(model, full_path, conf_threshold):
    """Run the model on one image and return ``(bbox_list, category_list)``.

    Boxes are converted from the model's centre-based ``xywh`` form to
    COCO top-left ``[x_min, y_min, w, h]``.
    """
    results = model.predict(
        full_path,
        conf=conf_threshold,
        imgsz=1024,     # keep high resolution
        augment=True,   # test-time augmentation
        iou=0.60,       # merge overlapping boxes
        max_det=30,     # 5 was too strict, 30 is safe
        verbose=False,
    )

    bbox_list = []
    category_list = []
    for result in results:
        boxes = result.boxes
        if boxes is None:
            continue
        # Convert each tensor once per result instead of once per box.
        centres = boxes.xywh.tolist()
        classes = boxes.cls.tolist()
        for (x_c, y_c, w, h), cls_id in zip(centres, classes):
            bbox_list.append([x_c - (w / 2), y_c - (h / 2), w, h])
            category_list.append(get_category_id(int(cls_id)))
    return bbox_list, category_list


def run_inference(model, image_path, conf_threshold, save_path):
    """Predict every image in image_path and write the results to save_path.

    Args:
        model: Loaded detector exposing ``predict(...)`` (here an
            ultralytics ``YOLO`` instance).
        image_path: Directory that holds the test images.
        conf_threshold: Minimum confidence passed to ``model.predict``.
        save_path: Destination of the CSV file.

    A CSV containing only the header row is written when no images are
    found, including when image_path does not exist.
    """
    print(f"🚀 Checking for images in {image_path}...")
    test_images = _list_test_images(image_path)
    print(f"🔍 Found {len(test_images)} images.")

    df_rows = []
    for image_name in test_images:
        full_path = os.path.join(image_path, image_name)
        bbox_list, category_list = _predict_image(
            model, full_path, conf_threshold
        )
        df_rows.append({
            "file_name": image_name,
            "bbox": str(bbox_list),
            "category_id": str(category_list),
        })

    # Passing columns= keeps the header even when df_rows is empty.
    df_predictions = pd.DataFrame(df_rows, columns=SUBMISSION_COLUMNS)
    df_predictions.to_csv(save_path, index=False)
    print("✅ Done!")


def main():
    """Load the model weights and run inference with the module settings."""
    # Top-level boundary: failures are reported rather than re-raised, so
    # the script still exits with status 0 after printing the error.
    try:
        print("🔥 Loading Model...")
        model = YOLO(MODEL_WEIGHTS)
        run_inference(
            model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH
        )
    except Exception as e:
        print(f"❌ Error: {e}")


if __name__ == "__main__":
    main()