# phase2a / script.py
# Page residue from the hosting site: "yusufbardolia's picture" /
# "Update script.py" / commit ced6b45 (verified)
import os
import pandas as pd
from ultralytics import YOLO
import glob
# --- CONFIGURATION ---
# Trained YOLO weights loaded by the __main__ entry point.
MODEL_WEIGHTS: str = "best.pt"

# Directory scanned for test images, and the CSV the predictions go to.
TEST_IMAGE_PATH: str = "/tmp/data/test_images"
SUBMISSION_SAVE_PATH: str = "submission.csv"

# 1. Balanced confidence: detections scoring below this are discarded.
# 0.10 keeps weaker tool detections while keeping precision reasonable.
CONF_THRESHOLD: float = 0.10
def get_category_id(cls_id):
    """Return the submission ``category_id`` for a YOLO class index.

    The model's class indices (0, 1, 2) are already the ids the submission
    expects, so the value is only normalised to a plain ``int``.
    """
    return int(cls_id)


def _list_images(image_path, extensions):
    """Return the sorted base names of the image files directly in *image_path*.

    Extensions may be given as ``".jpg"``, ``"jpg"`` or ``"*.jpg"`` and are
    matched case-insensitively (``IMG_01.JPG`` counts as ``.jpg``). Hidden
    files (names starting with ``.``, e.g. ``._foo.jpg`` sidecars) and
    sub-directories are skipped. Prints a warning and returns ``[]`` when
    *image_path* is not a directory.
    """
    if not os.path.isdir(image_path):
        print(f"⚠️ Warning: {image_path} not found.")
        return []
    wanted = {"." + ext.lower().lstrip("*.") for ext in extensions}
    return sorted(
        name
        for name in os.listdir(image_path)
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in wanted
        and os.path.isfile(os.path.join(image_path, name))
    )


def _boxes_to_coco(results):
    """Flatten Ultralytics *results* into COCO boxes and category ids.

    Returns ``(bbox_list, category_list)``: one ``[x_min, y_min, w, h]`` list
    per detection (same units as ``box.xywh``) and the matching category ids,
    both in detection order.
    """
    bbox_list = []
    category_list = []
    for result in results:
        for box in result.boxes:
            # YOLO gives (centre_x, centre_y, w, h); COCO wants the top-left corner.
            x_c, y_c, w, h = box.xywh[0].tolist()
            bbox_list.append([x_c - w / 2, y_c - h / 2, w, h])
            category_list.append(get_category_id(int(box.cls[0])))
    return bbox_list, category_list


def run_inference(model, image_path, conf_threshold, save_path,
                  extensions=(".png", ".jpg", ".jpeg", ".bmp")):
    """Predict on every image in *image_path* and write a submission CSV.

    Every image found gets exactly one CSV row, even when nothing is detected
    (its lists are then empty). Columns:

    * ``file_name``   - image base name
    * ``bbox``        - ``str`` of a list of ``[x_min, y_min, w, h]`` boxes
    * ``category_id`` - ``str`` of the matching list of class ids

    Args:
        model: Ultralytics-style model exposing ``predict(path, **kwargs)``.
        image_path: Directory scanned (non-recursively) for images.
        conf_threshold: Minimum confidence passed to ``model.predict``.
        save_path: Destination of the CSV (header only if no images are found).
        extensions: Image extensions to accept, matched case-insensitively.
    """
    print(f"🚀 Checking for images in {image_path}...")
    test_images = _list_images(image_path, extensions)
    print(f"🔍 Found {len(test_images)} images.")

    df_rows = []
    for image_name in test_images:
        # Balanced TTA settings.
        results = model.predict(
            os.path.join(image_path, image_name),
            conf=conf_threshold,
            imgsz=1024,    # keep high resolution
            augment=True,  # test-time augmentation (boosts accuracy)
            iou=0.60,      # NMS IoU threshold for merging overlapping boxes
            max_det=30,    # 5 proved too strict; 30 is a safe cap
            verbose=False,
        )
        bbox_list, category_list = _boxes_to_coco(results)
        df_rows.append({
            "file_name": image_name,
            "bbox": str(bbox_list),
            "category_id": str(category_list),
        })

    # Explicit columns keep the CSV header even when df_rows is empty.
    df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])
    df_predictions.to_csv(save_path, index=False)
    print("✅ Done!")
if __name__ == "__main__":
    # Top-level boundary: report the failure, then re-raise so the traceback
    # is shown and the process exits non-zero instead of looking successful.
    try:
        print("🔥 Loading Model...")
        model = YOLO(MODEL_WEIGHTS)
        run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
    except Exception as e:
        print(f"❌ Error: {e}")
        raise