File size: 2,874 Bytes
b602b83
fac07d4
6412f39
fac07d4
b602b83
6412f39
fac07d4
 
 
b602b83
ced6b45
 
 
29664ea
fac07d4
ced6b45
87f1d97
b602b83
fac07d4
 
b602b83
29664ea
fac07d4
29664ea
 
 
fac07d4
29664ea
fac07d4
87f1d97
6412f39
fac07d4
 
6412f39
fac07d4
 
 
b602b83
ced6b45
9d74de7
 
 
0e1cacf
ced6b45
 
 
9d74de7
 
b602b83
fac07d4
 
 
 
 
87f1d97
fac07d4
 
 
 
 
 
 
 
 
 
 
87f1d97
 
fac07d4
 
 
 
 
b602b83
fac07d4
6412f39
b602b83
 
fac07d4
87f1d97
fac07d4
 
 
87f1d97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import pandas as pd
from ultralytics import YOLO
import glob

# --- CONFIGURATION ---
# Model checkpoint loaded by the __main__ entry point.
MODEL_WEIGHTS: str = "best.pt"
# Folder whose image files (non-recursive) are run through the model.
TEST_IMAGE_PATH: str = "/tmp/data/test_images"
# Destination of the prediction CSV (relative to the working directory).
SUBMISSION_SAVE_PATH: str = "submission.csv"

# Confidence cut-off handed to model.predict(conf=...).
# 0.10 is a compromise: low enough to keep weaker tool detections,
# high enough to keep precision reasonable.
CONF_THRESHOLD: float = 0.10

def get_category_id(cls_id):
    """Translate a model class index into the submission category id.

    The model's class indices (0, 1, 2) already line up with the expected
    category ids, so the only work done here is coercing the value to a
    plain Python ``int``.
    """
    category_id = int(cls_id)
    return category_id

# File extensions accepted as test images (compared case-insensitively).
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")


def _list_test_images(image_path):
    """Return the sorted basenames of image files directly inside *image_path*.

    Extensions are matched case-insensitively so files such as ``IMG_01.JPG``
    are not silently skipped on case-sensitive filesystems. Hidden files
    (leading ``.``) are skipped, as the former ``glob`` patterns did, and
    directories are ignored. Prints a warning and returns ``[]`` when
    *image_path* is not a directory.
    """
    if not os.path.isdir(image_path):
        print(f"⚠️ Warning: {image_path} not found.")
        return []
    return sorted(
        name
        for name in os.listdir(image_path)
        if not name.startswith(".")
        and name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(image_path, name))
    )


def _boxes_to_coco(results):
    """Convert prediction results into parallel lists of COCO boxes and category ids.

    Each box is ``[x_min, y_min, width, height]`` (top-left corner origin);
    the i-th category id belongs to the i-th box.
    """
    bbox_list = []
    category_list = []
    for result in results:
        for box in result.boxes:
            # YOLO reports (x_center, y_center, w, h); COCO wants the top-left corner.
            x_c, y_c, w, h = box.xywh[0].tolist()
            bbox_list.append([x_c - w / 2, y_c - h / 2, w, h])
            category_list.append(get_category_id(int(box.cls[0])))
    return bbox_list, category_list


def run_inference(
    model,
    image_path,
    conf_threshold,
    save_path,
    *,
    imgsz=1024,
    augment=True,
    iou=0.60,
    max_det=30,
):
    """Run *model* on every test image in *image_path* and write a submission CSV.

    Args:
        model: Detector exposing ``predict`` (in this script an ultralytics ``YOLO``).
        image_path: Directory containing the test images (not searched recursively).
        conf_threshold: Confidence threshold forwarded to ``model.predict``.
        save_path: Destination path of the CSV file.
        imgsz: Inference image size forwarded to ``model.predict``.
        augment: Whether to enable test-time augmentation.
        iou: IoU threshold used when merging overlapping boxes.
        max_det: Maximum number of detections kept per image.

    The CSV has one row per image with columns ``file_name``, ``bbox`` (the
    ``str()`` of a list of ``[x_min, y_min, w, h]`` boxes) and ``category_id``
    (the ``str()`` of the matching list of ids). When no images are found, a
    CSV containing only the header row is written.
    """
    print(f"🚀 Checking for images in {image_path}...")
    test_images = _list_test_images(image_path)
    print(f"🔍 Found {len(test_images)} images.")

    df_rows = []
    for image_name in test_images:
        results = model.predict(
            os.path.join(image_path, image_name),
            conf=conf_threshold,
            imgsz=imgsz,
            augment=augment,
            iou=iou,
            max_det=max_det,
            verbose=False,
        )
        bbox_list, category_list = _boxes_to_coco(results)
        df_rows.append({
            "file_name": image_name,
            "bbox": str(bbox_list),
            "category_id": str(category_list),
        })

    # Passing `columns` guarantees the header even when df_rows is empty,
    # so no separate empty-frame fallback is needed.
    df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])
    df_predictions.to_csv(save_path, index=False)
    print("✅ Done!")

if __name__ == "__main__":
    # Script entry point: load the weights and write the submission CSV.
    try:
        print("🔥 Loading Model...")
        model = YOLO(MODEL_WEIGHTS)
        run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
    except Exception as e:
        print(f"❌ Error: {e}")
        # Re-raise so the traceback is shown and the process exits non-zero;
        # swallowing the error would make a failed run look successful to
        # whatever launched the script.
        raise