yusufbardolia commited on
Commit
ced6b45
·
verified ·
1 Parent(s): 0e1cacf

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +8 -9
script.py CHANGED
@@ -8,13 +8,12 @@ TEST_IMAGE_PATH = "/tmp/data/test_images"
8
  SUBMISSION_SAVE_PATH = "submission.csv"
9
  MODEL_WEIGHTS = "best.pt"
10
 
11
- # ✅ 1. Sane Confidence:
12
- # 0.20 was safe. We go slightly lower to 0.15 to improve Recall,
13
- # relying on 'max_det' to fix the Precision.
14
- CONF_THRESHOLD = 0.05
15
 
16
  def get_category_id(cls_id):
17
- # IDs 0, 1, 2 are correct based on your experiments.
18
  return int(cls_id)
19
 
20
  def run_inference(model, image_path, conf_threshold, save_path):
@@ -37,14 +36,14 @@ def run_inference(model, image_path, conf_threshold, save_path):
37
  for image_name in test_images:
38
  full_path = os.path.join(image_path, image_name)
39
 
40
- # 🚀 OPTIMIZED INFERENCE SETTINGS
41
  results = model.predict(
42
  full_path,
43
  conf=conf_threshold,
44
  imgsz=1024, # Keep High Res
45
- augment=False,
46
- iou=0.60, # Slightly higher IoU to keep close instruments
47
- max_det=5, # ⬅️ CRITICAL: Only allow top 5 instruments per image
48
  verbose=False
49
  )
50
 
 
8
  SUBMISSION_SAVE_PATH = "submission.csv"
9
  MODEL_WEIGHTS = "best.pt"
10
 
11
+ # ✅ 1. Balanced Confidence
12
+ # 0.10 captures weaker tools but keeps precision reasonable.
13
+ CONF_THRESHOLD = 0.10
 
14
 
15
  def get_category_id(cls_id):
16
+ # IDs 0, 1, 2 are correct.
17
  return int(cls_id)
18
 
19
  def run_inference(model, image_path, conf_threshold, save_path):
 
36
  for image_name in test_images:
37
  full_path = os.path.join(image_path, image_name)
38
 
39
+ # 🚀 BALANCED TTA SETTINGS
40
  results = model.predict(
41
  full_path,
42
  conf=conf_threshold,
43
  imgsz=1024, # Keep High Res
44
+ augment=True, # ⬅️ ENABLED: Test Time Augmentation (Boosts accuracy)
45
+ iou=0.60, # Merge overlapping boxes
46
+ max_det=30, # ⬅️ INCREASED: 5 was too strict, 30 is safe
47
  verbose=False
48
  )
49