Update script.py
Browse files
script.py
CHANGED
|
@@ -4,44 +4,52 @@ from ultralytics import YOLO
|
|
| 4 |
import glob
|
| 5 |
|
| 6 |
# --- CONFIGURATION ---
|
| 7 |
-
#
|
| 8 |
-
# The server puts images in "/tmp/data/test_images"
|
| 9 |
TEST_IMAGE_PATH = "/tmp/data/test_images"
|
| 10 |
SUBMISSION_SAVE_PATH = "submission.csv"
|
| 11 |
MODEL_WEIGHTS = "best.pt"
|
| 12 |
-
CONF_THRESHOLD = 0.30
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
|
| 16 |
-
|
| 17 |
def get_category_id(cls_id):
|
| 18 |
-
return int(cls_id) + 1
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def run_inference(model, image_path, conf_threshold, save_path):
|
| 21 |
print(f"π Checking for images in {image_path}...")
|
| 22 |
|
| 23 |
# 1. Get all images (support multiple extensions)
|
|
|
|
| 24 |
if os.path.exists(image_path):
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
test_images.sort()
|
|
|
|
|
|
|
| 27 |
else:
|
| 28 |
-
print(f"β οΈ Warning: {image_path} not found.
|
| 29 |
-
|
|
|
|
| 30 |
|
| 31 |
print(f"π Found {len(test_images)} images.")
|
| 32 |
|
| 33 |
# Prepare lists for the dataframe
|
| 34 |
-
# The baseline wants a specific format: stringified lists of lists
|
| 35 |
df_rows = []
|
| 36 |
|
| 37 |
# 2. Run Inference
|
| 38 |
if len(test_images) > 0:
|
| 39 |
-
# Load images one by one or in batches
|
| 40 |
-
# We loop to match the baseline structure exactly
|
| 41 |
for image_name in test_images:
|
| 42 |
full_path = os.path.join(image_path, image_name)
|
| 43 |
|
| 44 |
# Run YOLOv8 inference
|
|
|
|
| 45 |
results = model.predict(full_path, conf=conf_threshold, verbose=False)
|
| 46 |
|
| 47 |
bbox_list = []
|
|
@@ -49,34 +57,34 @@ def run_inference(model, image_path, conf_threshold, save_path):
|
|
| 49 |
|
| 50 |
for result in results:
|
| 51 |
for box in result.boxes:
|
| 52 |
-
#
|
| 53 |
-
# The baseline calculated width = xmax - xmin, so it wants [x, y, w, h]
|
| 54 |
x_c, y_c, w, h = box.xywh[0].tolist()
|
| 55 |
|
| 56 |
-
# Convert
|
| 57 |
x_min = x_c - (w / 2)
|
| 58 |
y_min = y_c - (h / 2)
|
| 59 |
|
| 60 |
-
#
|
| 61 |
bbox_list.append([x_min, y_min, w, h])
|
| 62 |
|
| 63 |
-
# Get
|
| 64 |
cls_id = int(box.cls[0])
|
| 65 |
category_list.append(get_category_id(cls_id))
|
| 66 |
|
| 67 |
-
# 3. Format
|
| 68 |
-
#
|
| 69 |
df_rows.append({
|
| 70 |
"file_name": image_name,
|
| 71 |
-
"bbox": str(bbox_list), # e.g. "[[10, 20, 50, 50]]"
|
| 72 |
-
"category_id": str(category_list) # e.g. "[
|
| 73 |
})
|
| 74 |
|
| 75 |
# 4. Create DataFrame and Save
|
| 76 |
df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])
|
| 77 |
|
| 78 |
-
# Safety check: if
|
| 79 |
if df_predictions.empty:
|
|
|
|
| 80 |
df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
|
| 81 |
|
| 82 |
print(f"πΎ Saving {len(df_predictions)} rows to {save_path}...")
|
|
@@ -84,11 +92,15 @@ def run_inference(model, image_path, conf_threshold, save_path):
|
|
| 84 |
print("β
Done!")
|
| 85 |
|
| 86 |
if __name__ == "__main__":
|
| 87 |
-
#
|
| 88 |
-
# Note: We do NOT use torch.hub anymore
|
| 89 |
print(f"π₯ Loading YOLOv8 model: {MODEL_WEIGHTS}...")
|
|
|
|
| 90 |
try:
|
|
|
|
| 91 |
model = YOLO(MODEL_WEIGHTS)
|
|
|
|
|
|
|
| 92 |
run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
|
|
|
|
| 93 |
except Exception as e:
|
| 94 |
-
print(f"β Critical Error: {e}")
|
|
|
|
| 4 |
import glob
|
| 5 |
|
| 6 |
# --- CONFIGURATION ---
|
| 7 |
+
# The evaluation server usually puts test images here
|
|
|
|
| 8 |
TEST_IMAGE_PATH = "/tmp/data/test_images"
|
| 9 |
SUBMISSION_SAVE_PATH = "submission.csv"
|
| 10 |
MODEL_WEIGHTS = "best.pt"
|
|
|
|
| 11 |
|
| 12 |
+
# You can tune this. 0.25 - 0.30 is usually a safe balance for mAP.
|
| 13 |
+
CONF_THRESHOLD = 0.25
|
| 14 |
+
|
| 15 |
def get_category_id(cls_id):
|
| 16 |
+
# β OLD INCORRECT LOGIC: return int(cls_id) + 1
|
| 17 |
+
# β
CORRECT LOGIC: Your JSON annotations use IDs 0, 1, 2.
|
| 18 |
+
# Since your model was trained on 0, 1, 2, we return the class ID directly.
|
| 19 |
+
return int(cls_id)
|
| 20 |
|
| 21 |
def run_inference(model, image_path, conf_threshold, save_path):
|
| 22 |
print(f"π Checking for images in {image_path}...")
|
| 23 |
|
| 24 |
# 1. Get all images (support multiple extensions)
|
| 25 |
+
test_images = []
|
| 26 |
if os.path.exists(image_path):
|
| 27 |
+
# Recursive search just in case, though usually they are flat in the folder
|
| 28 |
+
extensions = ['*.png', '*.jpg', '*.jpeg', '*.bmp']
|
| 29 |
+
for ext in extensions:
|
| 30 |
+
test_images.extend(glob.glob(os.path.join(image_path, ext)))
|
| 31 |
+
|
| 32 |
+
# Sort to ensure consistent order if needed
|
| 33 |
test_images.sort()
|
| 34 |
+
# Just keep filenames for the loop to match your specific structure requirements
|
| 35 |
+
test_images = [os.path.basename(x) for x in test_images]
|
| 36 |
else:
|
| 37 |
+
print(f"β οΈ Warning: {image_path} not found. Testing with empty list (or verify path).")
|
| 38 |
+
# If running locally for debug, you might want to point to a local folder here
|
| 39 |
+
# test_images = os.listdir(".")
|
| 40 |
|
| 41 |
print(f"π Found {len(test_images)} images.")
|
| 42 |
|
| 43 |
# Prepare lists for the dataframe
|
|
|
|
| 44 |
df_rows = []
|
| 45 |
|
| 46 |
# 2. Run Inference
|
| 47 |
if len(test_images) > 0:
|
|
|
|
|
|
|
| 48 |
for image_name in test_images:
|
| 49 |
full_path = os.path.join(image_path, image_name)
|
| 50 |
|
| 51 |
# Run YOLOv8 inference
|
| 52 |
+
# verbose=False keeps the logs clean
|
| 53 |
results = model.predict(full_path, conf=conf_threshold, verbose=False)
|
| 54 |
|
| 55 |
bbox_list = []
|
|
|
|
| 57 |
|
| 58 |
for result in results:
|
| 59 |
for box in result.boxes:
|
| 60 |
+
# YOLO format is Center-X, Center-Y, Width, Height
|
|
|
|
| 61 |
x_c, y_c, w, h = box.xywh[0].tolist()
|
| 62 |
|
| 63 |
+
# Convert to COCO format: Top-Left-X, Top-Left-Y, Width, Height
|
| 64 |
x_min = x_c - (w / 2)
|
| 65 |
y_min = y_c - (h / 2)
|
| 66 |
|
| 67 |
+
# Store as list [x, y, w, h]
|
| 68 |
bbox_list.append([x_min, y_min, w, h])
|
| 69 |
|
| 70 |
+
# Get predicted class ID (0, 1, or 2)
|
| 71 |
cls_id = int(box.cls[0])
|
| 72 |
category_list.append(get_category_id(cls_id))
|
| 73 |
|
| 74 |
+
# 3. Format Prediction String
|
| 75 |
+
# The baseline requires the bbox and category_id columns to be Strings of Lists
|
| 76 |
df_rows.append({
|
| 77 |
"file_name": image_name,
|
| 78 |
+
"bbox": str(bbox_list), # e.g. "[[10.5, 20.1, 50.0, 50.0]]"
|
| 79 |
+
"category_id": str(category_list) # e.g. "[0]" or "[0, 2]"
|
| 80 |
})
|
| 81 |
|
| 82 |
# 4. Create DataFrame and Save
|
| 83 |
df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])
|
| 84 |
|
| 85 |
+
# Safety check: if no images found, create an empty CSV with headers to prevent crash
|
| 86 |
if df_predictions.empty:
|
| 87 |
+
print("β οΈ No predictions made (no images found?). Creating empty CSV.")
|
| 88 |
df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
|
| 89 |
|
| 90 |
print(f"πΎ Saving {len(df_predictions)} rows to {save_path}...")
|
|
|
|
| 92 |
print("β
Done!")
|
| 93 |
|
| 94 |
if __name__ == "__main__":
|
| 95 |
+
# Ensure ultralytics is installed via requirements.txt before this runs
|
|
|
|
| 96 |
print(f"π₯ Loading YOLOv8 model: {MODEL_WEIGHTS}...")
|
| 97 |
+
|
| 98 |
try:
|
| 99 |
+
# Load the trained model
|
| 100 |
model = YOLO(MODEL_WEIGHTS)
|
| 101 |
+
|
| 102 |
+
# Execute inference
|
| 103 |
run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
|
| 104 |
+
|
| 105 |
except Exception as e:
|
| 106 |
+
print(f"β Critical Error in main execution: {e}")
|