yusufbardolia commited on
Commit
fac07d4
Β·
verified Β·
1 Parent(s): b5ada13

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +78 -75
script.py CHANGED
@@ -1,91 +1,94 @@
1
  import os
2
- import json
3
- import glob
4
  from ultralytics import YOLO
 
5
 
6
  # --- CONFIGURATION ---
7
- # These paths are standard for most competition containers.
8
- # If testing locally, change 'TEST_IMAGES_DIR' to your local test folder.
9
- MODEL_PATH = "best.pt" # The model must be in the same folder as this script
10
- TEST_IMAGES_DIR = "test/images" # The competition usually puts images here
11
- OUTPUT_FILE = "submission.json"
 
12
 
13
- # Phase 2 Category Mapping (Ensure these match your training!)
14
- # 0: Large Needle Driver, 1: Prograsp Forceps, 2: Monopolar Curved Scissors
15
- # We map YOLO ID -> Category ID required by submission (usually 1, 2, 3 or same)
16
- # If your competition expects IDs 1, 2, 3, use this map:
17
- ID_MAPPING = {
18
- 0: 1, # Large Needle Driver
19
- 1: 2, # Prograsp Forceps
20
- 2: 3 # Monopolar Curved Scissors
21
- }
22
 
23
- def main():
24
- print(f"πŸš€ Loading model from {MODEL_PATH}...")
25
 
26
- # 1. Load the trained YOLOv8 model
27
- try:
28
- model = YOLO(MODEL_PATH)
29
- except Exception as e:
30
- print(f"❌ Error loading model: {e}")
31
- return
32
-
33
- # 2. Get all test images
34
- # We look for jpg, png, and jpeg
35
- image_paths = glob.glob(os.path.join(TEST_IMAGES_DIR, "*.*"))
36
- print(f"πŸ” Found {len(image_paths)} images in {TEST_IMAGES_DIR}")
37
 
38
- submission_results = []
39
 
40
- # 3. Run Inference
41
- # stream=True prevents crashing on large datasets
42
- results = model.predict(
43
- source=TEST_IMAGES_DIR,
44
- conf=0.25, # Confidence threshold (0.25 is standard for mAP)
45
- iou=0.45, # NMS IoU threshold
46
- save=False, # Don't save plotted images (saves time)
47
- save_txt=False,
48
- verbose=False,
49
- stream=True
50
- )
51
 
52
- print("πŸƒ Processing predictions...")
53
-
54
- for result in results:
55
- # Get filename (e.g., 'frame_001.jpg')
56
- file_name = os.path.basename(result.path)
57
-
58
- # Get filename without extension (often used as image_id)
59
- image_id = os.path.splitext(file_name)[0]
60
-
61
- # Loop through detections
62
- for box in result.boxes:
63
- # Get data
64
- cls_id = int(box.cls[0])
65
- score = float(box.conf[0])
66
- bbox = box.xywh[0].tolist() # x_center, y_center, width, height (Normalized? No, usually pixels)
67
 
68
- # YOLOv8 .xywh returns pixels: [x_center, y_center, width, height]
69
- # MTEC/COCO usually wants: [x_min, y_min, width, height]
70
- x_c, y_c, w, h = bbox
71
- x_min = x_c - (w / 2)
72
- y_min = y_c - (h / 2)
73
 
74
- # Create annotation entry
75
- annotation = {
76
- "image_id": file_name, # Or use image_id variable depending on rules
77
- "category_id": ID_MAPPING.get(cls_id, cls_id + 1), # Map to correct ID
78
- "bbox": [x_min, y_min, w, h],
79
- "score": score
80
- }
81
- submission_results.append(annotation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # 4. Save to JSON
84
- print(f"πŸ’Ύ Saving {len(submission_results)} detections to {OUTPUT_FILE}...")
85
- with open(OUTPUT_FILE, 'w') as f:
86
- json.dump(submission_results, f, indent=4)
87
-
88
  print("βœ… Done!")
89
 
90
  if __name__ == "__main__":
91
- main()
 
 
 
 
 
 
 
 
1
  import os
2
+ import pandas as pd
 
3
  from ultralytics import YOLO
4
+ import glob
5
 
6
  # --- CONFIGURATION ---
7
+ # Use the exact paths from the baseline script to avoid "File Not Found" errors
8
+ # The server puts images in "/tmp/data/test_images"
9
+ TEST_IMAGE_PATH = "/tmp/data/test_images"
10
+ SUBMISSION_SAVE_PATH = "submission.csv"
11
+ MODEL_WEIGHTS = "best.pt"
12
+ CONF_THRESHOLD = 0.30
13
 
14
+ # Mapping: YOLO ID (0,1,2) -> Competition Category ID (likely 1,2,3)
15
+ # If your previous training used 0,1,2, usually competitions want 1-based IDs.
16
+ # We will add +1 to be safe, matching standard Phase 2 rules.
17
+ def get_category_id(cls_id):
18
+ return int(cls_id) + 1
 
 
 
 
19
 
20
+ def run_inference(model, image_path, conf_threshold, save_path):
21
+ print(f"πŸš€ Checking for images in {image_path}...")
22
 
23
+ # 1. Get all images (support multiple extensions)
24
+ if os.path.exists(image_path):
25
+ test_images = [f for f in os.listdir(image_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
26
+ test_images.sort()
27
+ else:
28
+ print(f"⚠️ Warning: {image_path} not found. Using current folder for testing.")
29
+ test_images = []
 
 
 
 
30
 
31
+ print(f"πŸ” Found {len(test_images)} images.")
32
 
33
+ # Prepare lists for the dataframe
34
+ # The baseline wants a specific format: stringified lists of lists
35
+ df_rows = []
 
 
 
 
 
 
 
 
36
 
37
+ # 2. Run Inference
38
+ if len(test_images) > 0:
39
+ # Load images one by one or in batches
40
+ # We loop to match the baseline structure exactly
41
+ for image_name in test_images:
42
+ full_path = os.path.join(image_path, image_name)
 
 
 
 
 
 
 
 
 
43
 
44
+ # Run YOLOv8 inference
45
+ results = model.predict(full_path, conf=conf_threshold, verbose=False)
 
 
 
46
 
47
+ bbox_list = []
48
+ category_list = []
49
+
50
+ for result in results:
51
+ for box in result.boxes:
52
+ # Get box in xywh format (x, y, width, height)
53
+ # The baseline calculated width = xmax - xmin, so it wants [x, y, w, h]
54
+ x_c, y_c, w, h = box.xywh[0].tolist()
55
+
56
+ # Convert Center-XY to Top-Left-XY
57
+ x_min = x_c - (w / 2)
58
+ y_min = y_c - (h / 2)
59
+
60
+ # Append to list
61
+ bbox_list.append([x_min, y_min, w, h])
62
+
63
+ # Get Class ID
64
+ cls_id = int(box.cls[0])
65
+ category_list.append(get_category_id(cls_id))
66
+
67
+ # 3. Format exactly like the baseline
68
+ # It expects columns: file_name, bbox (string), category_id (string)
69
+ df_rows.append({
70
+ "file_name": image_name,
71
+ "bbox": str(bbox_list), # e.g. "[[10, 20, 50, 50]]"
72
+ "category_id": str(category_list) # e.g. "[1]"
73
+ })
74
+
75
+ # 4. Create DataFrame and Save
76
+ df_predictions = pd.DataFrame(df_rows, columns=["file_name", "bbox", "category_id"])
77
+
78
+ # Safety check: if empty, create empty CSV with correct headers
79
+ if df_predictions.empty:
80
+ df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
81
 
82
+ print(f"πŸ’Ύ Saving {len(df_predictions)} rows to {save_path}...")
83
+ df_predictions.to_csv(save_path, index=False)
 
 
 
84
  print("βœ… Done!")
85
 
86
  if __name__ == "__main__":
87
+ # Load your YOLOv8 model
88
+ # Note: We do NOT use torch.hub anymore
89
+ print(f"πŸ”₯ Loading YOLOv8 model: {MODEL_WEIGHTS}...")
90
+ try:
91
+ model = YOLO(MODEL_WEIGHTS)
92
+ run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)
93
+ except Exception as e:
94
+ print(f"❌ Critical Error: {e}")