yusufbardolia committed on
Commit
ce0ade6
·
verified ·
1 Parent(s): ec17f25

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +44 -12
script.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
5
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
6
  from tqdm import tqdm
7
 
 
8
  def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device):
9
 
10
  try:
@@ -26,13 +27,17 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
26
 
27
  try:
28
  full_img_path = os.path.join(image_path, image_name)
29
- img = Image.open(full_img_path).convert("RGB")
30
- except Exception as e:
31
  bboxes.append([])
32
  category_ids.append([])
33
  continue
34
 
35
- inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
 
 
 
 
36
 
37
  with torch.no_grad():
38
  outputs = model(**inputs)
@@ -45,7 +50,7 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
45
  target_sizes=[img.size[::-1]]
46
  )
47
 
48
- # Safe Mode: ID=0
49
  for result in results:
50
  boxes = result["boxes"]
51
  for box in boxes:
@@ -53,7 +58,7 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
53
  width = xmax - xmin
54
  height = ymax - ymin
55
  bbox.append([xmin, ymin, width, height])
56
- category_id.append(0)
57
 
58
  bboxes.append(bbox)
59
  category_ids.append(category_id)
@@ -73,11 +78,13 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
73
 
74
 
75
  if __name__ == "__main__":
 
76
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
77
  os.environ["HF_HUB_OFFLINE"] = "1"
78
  os.environ["HF_DATASETS_OFFLINE"] = "1"
79
 
80
  current_directory = os.path.dirname(os.path.abspath(__file__))
 
81
  TEST_IMAGE_PATH = "/tmp/data/test_images"
82
  SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")
83
 
@@ -89,13 +96,38 @@ if __name__ == "__main__":
89
  processor = AutoProcessor.from_pretrained(processor_path)
90
  model = AutoModelForZeroShotObjectDetection.from_pretrained(model_path)
91
  model.to(device)
 
92
 
93
- # --- OPTIMIZED SETTINGS ---
94
- # 1. Prompt: Reverting to the one that got 0.047 ("Large" helps!)
95
- PROMPT = "Monopolar Curved Scissors . Prograsp Forceps . Large Needle Driver ."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # 2. Threshold: Increased to 0.35 (Stricter than 0.25 to improve Precision)
98
- BOX_THRESHOLD = 0.35
99
- TEXT_THRESHOLD = 0.25
 
 
100
 
101
- run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device)
 
 
 
 
 
 
 
 
 
5
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
6
  from tqdm import tqdm
7
 
8
+
9
  def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device):
10
 
11
  try:
 
27
 
28
  try:
29
  full_img_path = os.path.join(image_path, image_name)
30
+ img = Image.open(full_img_path).convert("RGB")
31
+ except Exception:
32
  bboxes.append([])
33
  category_ids.append([])
34
  continue
35
 
36
+ inputs = processor(
37
+ images=img,
38
+ text=prompt,
39
+ return_tensors="pt"
40
+ ).to(device)
41
 
42
  with torch.no_grad():
43
  outputs = model(**inputs)
 
50
  target_sizes=[img.size[::-1]]
51
  )
52
 
53
+ # Safe Mode: Single category (ID = 0)
54
  for result in results:
55
  boxes = result["boxes"]
56
  for box in boxes:
 
58
  width = xmax - xmin
59
  height = ymax - ymin
60
  bbox.append([xmin, ymin, width, height])
61
+ category_id.append(0)
62
 
63
  bboxes.append(bbox)
64
  category_ids.append(category_id)
 
78
 
79
 
80
  if __name__ == "__main__":
81
+ # Offline HuggingFace settings
82
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
83
  os.environ["HF_HUB_OFFLINE"] = "1"
84
  os.environ["HF_DATASETS_OFFLINE"] = "1"
85
 
86
  current_directory = os.path.dirname(os.path.abspath(__file__))
87
+
88
  TEST_IMAGE_PATH = "/tmp/data/test_images"
89
  SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")
90
 
 
96
  processor = AutoProcessor.from_pretrained(processor_path)
97
  model = AutoModelForZeroShotObjectDetection.from_pretrained(model_path)
98
  model.to(device)
99
+ model.eval()
100
 
101
+ # =========================
102
+ # 🔥 PROMPT ENGINEERING
103
+ # =========================
104
+ PROMPT = (
105
+ "Monopolar Curved Scissors. "
106
+ "curved surgical scissors. "
107
+ "surgical scissors cutting tissue. "
108
+ "Prograsp Forceps. "
109
+ "surgical forceps grasping tissue. "
110
+ "grasping forceps. "
111
+ "Large Needle Driver. "
112
+ "needle holder. "
113
+ "surgical needle driver. "
114
+ "laparoscopic surgical instrument. "
115
+ "robotic surgical instrument. "
116
+ "metal surgical tool inside the body."
117
+ )
118
 
119
+ # =========================
120
+ # 🎯 THRESHOLDS (Recall-Oriented)
121
+ # =========================
122
+ BOX_THRESHOLD = 0.25
123
+ TEXT_THRESHOLD = 0.20
124
 
125
+ run_inference(
126
+ TEST_IMAGE_PATH,
127
+ model,
128
+ SUBMISSION_SAVE_PATH,
129
+ PROMPT,
130
+ BOX_THRESHOLD,
131
+ TEXT_THRESHOLD,
132
+ device
133
+ )