| import os |
| import pandas as pd |
| from ultralytics import YOLO |
| import glob |
|
|
| |
# --- Input / output locations -------------------------------------------
# Directory handed to run_inference() as the folder to scan for images.
TEST_IMAGE_PATH = "/tmp/data/test_images"
# CSV file that run_inference() writes the predictions to.
SUBMISSION_SAVE_PATH = "submission.csv"

# --- Model settings -----------------------------------------------------
# Weights file passed to the YOLO constructor.
MODEL_WEIGHTS = "best.pt"
# Forwarded to model.predict() as its ``conf`` argument.
CONF_THRESHOLD = 0.10
|
|
def get_category_id(cls_id):
    """Return ``cls_id`` coerced to a built-in ``int``.

    Currently an identity mapping apart from the ``int()`` coercion.
    NOTE(review): presumably the hook where model class indices would be
    remapped to dataset category ids - confirm before relying on that.
    """
    category_id = int(cls_id)
    return category_id
|
|
def run_inference(model, image_path, conf_threshold, save_path, *,
                  imgsz=1024, iou=0.60, max_det=30, augment=True):
    """Run detection on every image in a directory and write a CSV.

    Args:
        model: Object exposing ``predict(source, conf=, imgsz=, augment=,
            iou=, max_det=, verbose=)`` that returns an iterable of results,
            each with a ``boxes`` iterable (the script passes an Ultralytics
            ``YOLO`` instance).
        image_path: Directory scanned (non-recursively) for image files.
        conf_threshold: Forwarded to ``model.predict`` as ``conf``.
        save_path: Destination path of the CSV file.
        imgsz: Forwarded to ``model.predict``; defaults to the value that
            used to be hard-coded.
        iou: Forwarded to ``model.predict``; same note as ``imgsz``.
        max_det: Forwarded to ``model.predict``; same note as ``imgsz``.
        augment: Forwarded to ``model.predict``; same note as ``imgsz``.

    The CSV has one row per image with the columns ``file_name``, ``bbox``
    (stringified list of ``[x_min, y_min, width, height]``) and
    ``category_id`` (stringified list of ints). An image without detections
    gets ``"[]"`` in both list columns. If the directory is missing or holds
    no images, a header-only CSV is written.
    """
    # Status messages are plain ASCII: the previous glyph prefixes were
    # corrupted (one even split a string literal across two lines) and
    # non-ASCII output can raise UnicodeEncodeError on limited consoles.
    print(f"Checking for images in {image_path}...")
    image_names = _list_image_names(image_path)
    print(f"Found {len(image_names)} images.")

    rows = []
    for image_name in image_names:
        results = model.predict(
            os.path.join(image_path, image_name),
            conf=conf_threshold,
            imgsz=imgsz,
            augment=augment,
            iou=iou,
            max_det=max_det,
            verbose=False,
        )
        bboxes, categories = _collect_detections(results)
        rows.append({
            "file_name": image_name,
            "bbox": str(bboxes),
            "category_id": str(categories),
        })

    # Passing ``columns`` keeps the header even when ``rows`` is empty, so no
    # separate empty-frame branch is needed.
    predictions = pd.DataFrame(rows, columns=["file_name", "bbox", "category_id"])
    predictions.to_csv(save_path, index=False)
    print("Done!")


def _list_image_names(image_path):
    """Return the sorted file names of the images directly inside a directory."""
    if not os.path.isdir(image_path):
        print(f"Warning: {image_path} not found.")
        return []

    suffixes = (".png", ".jpg", ".jpeg", ".bmp")
    with os.scandir(image_path) as entries:
        return sorted(
            entry.name
            for entry in entries
            if entry.is_file()
            # Dot-files stay excluded, as they were with the glob patterns.
            and not entry.name.startswith(".")
            # Case-insensitive so that e.g. ``IMG_01.JPG`` is not skipped.
            and entry.name.lower().endswith(suffixes)
        )


def _collect_detections(results):
    """Flatten prediction results into parallel bbox and category-id lists."""
    bboxes = []
    categories = []
    for result in results:
        for box in result.boxes:
            # ``xywh`` is centre-based; shift by half the size to get the
            # top-left corner expected in the output.
            x_center, y_center, width, height = box.xywh[0].tolist()
            bboxes.append([
                x_center - (width / 2),
                y_center - (height / 2),
                width,
                height,
            ])
            categories.append(get_category_id(int(box.cls[0])))
    return bboxes, categories
|
|
if __name__ == "__main__":
    # Script entry point: load the weights, run inference over the test
    # images and write the CSV. Imports are local because they are only
    # needed for error reporting when run as a script.
    import sys
    import traceback

    try:
        print("Loading Model...")
        model = YOLO(MODEL_WEIGHTS)
        run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and exit
        # non-zero so callers can tell the run failed, instead of printing a
        # one-line message and exiting successfully.
        print(f"Error: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)