Update script.py
Browse files
script.py
CHANGED
|
@@ -1,91 +1,94 @@
|
|
| 1 |
import os
|
| 2 |
-
import
|
| 3 |
-
import glob
|
| 4 |
from ultralytics import YOLO
|
|
|
|
| 5 |
|
| 6 |
# --- CONFIGURATION ---
|
| 7 |
-
#
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
#
|
| 15 |
-
# We
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
0: 1, # Large Needle Driver
|
| 19 |
-
1: 2, # Prograsp Forceps
|
| 20 |
-
2: 3 # Monopolar Curved Scissors
|
| 21 |
-
}
|
| 22 |
|
| 23 |
-
def
|
| 24 |
-
print(f"π
|
| 25 |
|
| 26 |
-
# 1.
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
# 2. Get all test images
|
| 34 |
-
# We look for jpg, png, and jpeg
|
| 35 |
-
image_paths = glob.glob(os.path.join(TEST_IMAGES_DIR, "*.*"))
|
| 36 |
-
print(f"π Found {len(image_paths)} images in {TEST_IMAGES_DIR}")
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
#
|
| 42 |
-
|
| 43 |
-
source=TEST_IMAGES_DIR,
|
| 44 |
-
conf=0.25, # Confidence threshold (0.25 is standard for mAP)
|
| 45 |
-
iou=0.45, # NMS IoU threshold
|
| 46 |
-
save=False, # Don't save plotted images (saves time)
|
| 47 |
-
save_txt=False,
|
| 48 |
-
verbose=False,
|
| 49 |
-
stream=True
|
| 50 |
-
)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# Get filename without extension (often used as image_id)
|
| 59 |
-
image_id = os.path.splitext(file_name)[0]
|
| 60 |
-
|
| 61 |
-
# Loop through detections
|
| 62 |
-
for box in result.boxes:
|
| 63 |
-
# Get data
|
| 64 |
-
cls_id = int(box.cls[0])
|
| 65 |
-
score = float(box.conf[0])
|
| 66 |
-
bbox = box.xywh[0].tolist() # x_center, y_center, width, height (Normalized? No, usually pixels)
|
| 67 |
|
| 68 |
-
# YOLOv8
|
| 69 |
-
|
| 70 |
-
x_c, y_c, w, h = bbox
|
| 71 |
-
x_min = x_c - (w / 2)
|
| 72 |
-
y_min = y_c - (h / 2)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
with open(OUTPUT_FILE, 'w') as f:
|
| 86 |
-
json.dump(submission_results, f, indent=4)
|
| 87 |
-
|
| 88 |
print("β
Done!")
|
| 89 |
|
| 90 |
if __name__ == "__main__":
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import pandas as pd
|
|
|
|
| 3 |
from ultralytics import YOLO
|
| 4 |
+
import glob
|
| 5 |
|
| 6 |
# --- CONFIGURATION ---
# Use the exact paths from the baseline script to avoid "File Not Found" errors
# The server puts images in "/tmp/data/test_images"
TEST_IMAGE_PATH = "/tmp/data/test_images"  # directory the evaluation server mounts test images into
SUBMISSION_SAVE_PATH = "submission.csv"    # output CSV consumed by the scorer
MODEL_WEIGHTS = "best.pt"                  # trained YOLOv8 checkpoint, expected next to this script
CONF_THRESHOLD = 0.30                      # minimum detection confidence kept in the submission
| 13 |
|
| 14 |
+
# Mapping: YOLO ID (0,1,2) -> Competition Category ID (likely 1,2,3)
# If your previous training used 0,1,2, usually competitions want 1-based IDs.
# We will add +1 to be safe, matching standard Phase 2 rules.
def get_category_id(cls_id):
    """Translate a zero-based YOLO class index into the 1-based competition category id."""
    return 1 + int(cls_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
def run_inference(model, image_path, conf_threshold, save_path):
    """Run YOLOv8 detection over every image in *image_path* and write a
    submission CSV to *save_path*.

    Args:
        model: a loaded ``ultralytics.YOLO`` model (only ``.predict`` is used).
        image_path: directory containing the test images.
        conf_threshold: minimum confidence for a detection to be kept.
        save_path: output CSV path; columns are file_name, bbox, category_id,
            where bbox/category_id are *stringified* Python lists (baseline format).
    """
    print(f"π Checking for images in {image_path}...")

    # 1. Get all images (support multiple extensions)
    if os.path.exists(image_path):
        test_images = [f for f in os.listdir(image_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
        test_images.sort()  # deterministic ordering across runs
    else:
        # NOTE(review): despite the message below, no fallback scan of the
        # current folder actually happens — the image list is simply left empty.
        print(f"β οΈ Warning: {image_path} not found. Using current folder for testing.")
        test_images = []

    print(f"π Found {len(test_images)} images.")

    # Prepare lists for the dataframe
    # The baseline wants a specific format: stringified lists of lists
    df_rows = []

    # 2. Run Inference
    if len(test_images) > 0:
        # Load images one by one or in batches
        # We loop to match the baseline structure exactly
        for image_name in test_images:
            full_path = os.path.join(image_path, image_name)

            # Run YOLOv8 inference (one call per image; verbose off to keep logs clean)
            results = model.predict(full_path, conf=conf_threshold, verbose=False)

            # Per-image accumulators: one submission row per image
            bbox_list = []
            category_list = []

            for result in results:
                for box in result.boxes:
                    # Get box in xywh format (x, y, width, height)
                    # The baseline calculated width = xmax - xmin, so it wants [x, y, w, h]
                    x_c, y_c, w, h = box.xywh[0].tolist()

                    # Convert Center-XY to Top-Left-XY
                    x_min = x_c - (w / 2)
                    y_min = y_c - (h / 2)

                    # Append to list
                    bbox_list.append([x_min, y_min, w, h])

                    # Get Class ID and shift to the competition's 1-based ids
                    cls_id = int(box.cls[0])
                    category_list.append(get_category_id(cls_id))

            # 3. Format exactly like the baseline
            # It expects columns: file_name, bbox (string), category_id (string)
            df_rows.append({
                "file_name": image_name,
                "bbox": str(bbox_list),            # e.g. "[[10, 20, 50, 50]]"
                "category_id": str(category_list)  # e.g. "[1]"
            })

    # 4. Create DataFrame and Save
    df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])

    # Safety check: if empty, create empty CSV with correct headers
    # (pandas already keeps the columns when df_rows is empty, but this makes
    # the intent explicit and guards against future refactors)
    if df_predictions.empty:
        df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])

    print(f"πΎ Saving {len(df_predictions)} rows to {save_path}...")
    df_predictions.to_csv(save_path, index=False)
    print("β Done!")
|
| 85 |
|
| 86 |
if __name__ == "__main__":
    # Load your YOLOv8 model
    # Note: We do NOT use torch.hub anymore
    print(f"π₯ Loading YOLOv8 model: {MODEL_WEIGHTS}...")
    try:
        model = YOLO(MODEL_WEIGHTS)
        run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
    except Exception as e:
        # Broad catch at the script boundary on purpose: surface the failure as a
        # readable one-liner in the grading log instead of a raw traceback.
        print(f"β Critical Error: {e}")
|