sharvari0b26 committed on
Commit
f7140ec
·
verified ·
1 Parent(s): 431142b

Update submission/script.py

Browse files
Files changed (1) hide show
  1. submission/script.py +25 -8
submission/script.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import torch
3
  import pandas as pd
 
4
  from rfdetr import RFDETRBase
5
 
6
 
@@ -17,12 +18,15 @@ def run_inference(model, image_path, conf_threshold, save_path):
17
 
18
  for image_name in test_images:
19
  test_images_names.append(image_name)
20
-
21
  image_file = os.path.join(image_path, image_name)
22
 
23
  bbox = []
24
  category_id = []
25
 
 
 
 
 
26
  preds = model.predict(image_file)
27
 
28
  if preds is not None and preds.xyxy is not None and len(preds.xyxy) > 0:
@@ -32,14 +36,26 @@ def run_inference(model, image_path, conf_threshold, save_path):
32
  preds.class_id
33
  ):
34
  score = float(score)
35
- if score >= conf_threshold:
36
- xmin, ymin, xmax, ymax = map(float, box)
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- width = xmax - xmin
39
- height = ymax - ymin
 
40
 
41
- bbox.append([xmin, ymin, width, height])
42
- category_id.append(int(label))
43
 
44
  bboxes.append(bbox)
45
  category_ids.append(category_id)
@@ -56,13 +72,14 @@ def run_inference(model, image_path, conf_threshold, save_path):
56
  df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
57
 
58
  df_predictions.to_csv(save_path, index=False)
 
59
 
60
 
61
  if __name__ == "__main__":
62
 
63
  TEST_IMAGE_PATH = "/tmp/data/test_images"
64
  SUBMISSION_SAVE_PATH = "submission.csv"
65
- CONF_THRESHOLD = 0.30
66
 
67
  model = RFDETRBase(
68
  checkpoint_path="checkpoint_best_total.pth",
 
1
  import os
2
  import torch
3
  import pandas as pd
4
+ from PIL import Image
5
  from rfdetr import RFDETRBase
6
 
7
 
 
18
 
19
  for image_name in test_images:
20
  test_images_names.append(image_name)
 
21
  image_file = os.path.join(image_path, image_name)
22
 
23
  bbox = []
24
  category_id = []
25
 
26
+ # Load image to get dimensions (IMPORTANT)
27
+ with Image.open(image_file) as img:
28
+ img_w, img_h = img.size
29
+
30
  preds = model.predict(image_file)
31
 
32
  if preds is not None and preds.xyxy is not None and len(preds.xyxy) > 0:
 
36
  preds.class_id
37
  ):
38
  score = float(score)
39
+ if score < conf_threshold:
40
+ continue
41
+
42
+ xmin, ymin, xmax, ymax = map(float, box)
43
+
44
+ # ---- CLAMP TO IMAGE BOUNDARIES ----
45
+ xmin = max(0.0, xmin)
46
+ ymin = max(0.0, ymin)
47
+ xmax = min(float(img_w), xmax)
48
+ ymax = min(float(img_h), ymax)
49
+
50
+ width = xmax - xmin
51
+ height = ymax - ymin
52
 
53
+ # ---- FILTER INVALID BOXES ----
54
+ if width <= 0 or height <= 0:
55
+ continue
56
 
57
+ bbox.append([xmin, ymin, width, height])
58
+ category_id.append(int(label))
59
 
60
  bboxes.append(bbox)
61
  category_ids.append(category_id)
 
72
  df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
73
 
74
  df_predictions.to_csv(save_path, index=False)
75
+ print(f"Submission saved to {save_path}")
76
 
77
 
78
  if __name__ == "__main__":
79
 
80
  TEST_IMAGE_PATH = "/tmp/data/test_images"
81
  SUBMISSION_SAVE_PATH = "submission.csv"
82
+ CONF_THRESHOLD = 0.30 # you may lower to 0.15 if recall is poor
83
 
84
  model = RFDETRBase(
85
  checkpoint_path="checkpoint_best_total.pth",