yusufbardolia commited on
Commit
6412f39
·
verified ·
1 Parent(s): 9e492af

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +78 -57
script.py CHANGED
@@ -1,70 +1,91 @@
1
- import torch
2
- from PIL import Image
3
- import matplotlib.pyplot as plt
4
- import cv2
5
- import numpy as np
6
  import os
7
- import pandas as pd
8
- import ultralytics
9
- from torchvision import transforms
10
 
 
 
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
12
 
13
- def run_inference(model, image_path, conf_threshold, save_path):
14
-
15
- test_images = os.listdir(image_path)
16
- test_images.sort()
17
 
18
- bboxes = []
19
- category_ids = []
20
- test_images_names = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- for image in test_images:
23
-
24
- test_images_names.append(image)
25
- bbox = []
26
- category_id = []
27
 
28
- results = model(os.path.join(image_path, image))
 
29
 
30
- for pred in results.pred[0]:
31
- xmin, ymin, xmax, ymax, conf, class_id = pred.tolist()
32
- if conf >= conf_threshold:
 
 
 
33
 
34
- width = xmax - xmin
35
- height = ymax - ymin
36
-
37
- bbox.append([xmin, ymin, width, height])
38
- category_id.append(int(class_id))
39
 
40
- bboxes.append(bbox)
41
- category_ids.append(category_id) # Convert class_id to int
42
-
43
- df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
44
-
45
- for i in range(len(test_images_names)):
46
- file_name = test_images_names[i]
47
- new_row = pd.DataFrame({"file_name": file_name,
48
- "bbox": str(bboxes[i]),
49
- "category_id": str(category_ids[i]),
50
- }, index=[0])
51
- df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
52
-
53
- df_predictions.to_csv(save_path, index=False)
54
 
 
 
 
 
 
 
55
 
56
  if __name__ == "__main__":
57
-
58
-
59
- current_directory = os.path.dirname(os.path.abspath(__file__))
60
- # print(current_directory)
61
- TEST_IMAGE_PATH = "/tmp/data/test_images"
62
- SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")
63
-
64
- RUN_NAME = "instrument_detection"
65
- MODEL_WEIGHTS_PATH = os.path.join(current_directory, "yolov5", "runs", "train", RUN_NAME, "weights", "best.pt")
66
- CONF_THRESHOLD = 0.30
67
-
68
- model = torch.hub.load(os.path.join(current_directory, 'yolov5'), 'custom', path=MODEL_WEIGHTS_PATH, source="local")
69
-
70
- run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
 
 
 
 
 
 
1
import os
import json
import glob
from ultralytics import YOLO

# --- CONFIGURATION ---
# These paths are standard for most competition containers.
# If testing locally, change 'TEST_IMAGES_DIR' to your local test folder.
MODEL_PATH = "best.pt"  # The model must be in the same folder as this script
TEST_IMAGES_DIR = "test/images"  # The competition usually puts images here
OUTPUT_FILE = "submission.json"  # COCO-style detections written here by main()

# Phase 2 Category Mapping (Ensure these match your training!)
# 0: Large Needle Driver, 1: Prograsp Forceps, 2: Monopolar Curved Scissors
# We map YOLO ID -> Category ID required by submission (usually 1, 2, 3 or same)
# If your competition expects IDs 1, 2, 3, use this map:
# NOTE(review): main() falls back to cls_id + 1 for ids missing from this map —
# keep the two consistent if classes are added.
ID_MAPPING = {
    0: 1,  # Large Needle Driver
    1: 2,  # Prograsp Forceps
    2: 3   # Monopolar Curved Scissors
}
22
 
23
def main():
    """Run YOLOv8 inference over the test images and write a COCO-style
    submission file.

    Reads module-level config: MODEL_PATH, TEST_IMAGES_DIR, OUTPUT_FILE and
    ID_MAPPING. Writes OUTPUT_FILE as a JSON list of
    {"image_id", "category_id", "bbox", "score"} dicts, where bbox is
    [x_min, y_min, width, height] in pixels.

    Returns None. On model-load failure it prints the error and returns
    early without writing any file.
    """
    print(f"🚀 Loading model from {MODEL_PATH}...")

    # 1. Load the trained YOLOv8 model. The broad catch is deliberate:
    # any load failure (missing file, corrupt weights) aborts cleanly.
    try:
        model = YOLO(MODEL_PATH)
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return

    # 2. Count the test images for the progress message.
    # Filter to actual image extensions — the old "*.*" pattern also
    # counted labels/.txt/etc., making the reported count misleading.
    image_paths = [
        p for p in glob.glob(os.path.join(TEST_IMAGES_DIR, "*.*"))
        if p.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
    print(f"🔍 Found {len(image_paths)} images in {TEST_IMAGES_DIR}")

    submission_results = []

    # 3. Run inference. predict() is handed the directory itself, so
    # Ultralytics decides which files are images — the count above is
    # informational only. stream=True yields results one image at a time
    # instead of materializing them all (prevents crashing on large sets).
    results = model.predict(
        source=TEST_IMAGES_DIR,
        conf=0.25,        # Confidence threshold (0.25 is standard for mAP)
        iou=0.45,         # NMS IoU threshold
        save=False,       # Don't save plotted images (saves time)
        save_txt=False,
        verbose=False,
        stream=True,
    )

    print("🏃 Processing predictions...")

    for result in results:
        # Submission keys on the bare filename (e.g. 'frame_001.jpg').
        # NOTE(review): some competitions expect the stem without the
        # extension as image_id — confirm against the submission rules.
        file_name = os.path.basename(result.path)

        for box in result.boxes:
            cls_id = int(box.cls[0])
            score = float(box.conf[0])

            # box.xywh is in pixels: [x_center, y_center, width, height];
            # COCO-style submissions want [x_min, y_min, width, height].
            x_c, y_c, w, h = box.xywh[0].tolist()
            x_min = x_c - (w / 2)
            y_min = y_c - (h / 2)

            submission_results.append({
                "image_id": file_name,
                # Map YOLO class id -> submission category id; unmapped
                # ids fall back to the 1-based convention (cls_id + 1).
                "category_id": ID_MAPPING.get(cls_id, cls_id + 1),
                "bbox": [x_min, y_min, w, h],
                "score": score,
            })

    # 4. Save to JSON
    print(f"💾 Saving {len(submission_results)} detections to {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(submission_results, f, indent=4)

    print("✅ Done!")
89
 
90
# Script entry point: run the full inference + submission pipeline only
# when executed directly (not when imported as a module).
if __name__ == "__main__":
    main()