yusufbardolia committed on
Commit
4da07fc
·
verified ·
1 Parent(s): 20c6fc7

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +61 -80
script.py CHANGED
@@ -1,31 +1,43 @@
1
  import requests
2
-
3
  import torch
4
- from PIL import Image, ImageDraw
5
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
6
-
7
  from tqdm import tqdm
8
  import os
9
  import pandas as pd
 
 
 
 
 
10
 
 
 
 
 
 
11
 
12
  def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold,
13
  visualize_results, visualization_path, device):
14
 
15
- test_images = os.listdir(image_path)
16
- test_images.sort()
17
 
18
  bboxes = []
19
  category_ids = []
20
  test_images_names = []
21
 
 
 
22
  for image_name in tqdm(test_images):
23
-
24
  test_images_names.append(image_name)
25
  bbox = []
26
  category_id = []
27
 
28
- img = Image.open(os.path.join(image_path, image_name))
 
 
 
29
 
30
  inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
31
 
@@ -40,98 +52,67 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
40
  target_sizes=[img.size[::-1]]
41
  )
42
 
43
- # visualize results
44
  if visualize_results:
45
  draw = ImageDraw.Draw(img)
46
- print(image_name)
47
- print(results)
48
-
 
 
 
49
  for result in results:
50
  boxes = result["boxes"]
51
- for i, _ in enumerate(range(len(boxes))):
52
- box = boxes[i].tolist()
53
- label = result["labels"][i]
54
- draw.rectangle(box, outline="red", width=3, )
 
 
 
 
 
 
 
 
55
  img.save(os.path.join(visualization_path, image_name))
56
-
57
- # --- REPLACEMENT BLOCK START ---
 
58
  for result in results:
59
  boxes = result["boxes"]
60
- labels = result["labels"] # The model returns the text label here (e.g. "metal curved scissors")
61
-
62
- for i, box in enumerate(boxes):
63
  xmin, ymin, xmax, ymax = box.tolist()
64
- width = xmax - xmin
65
- height = ymax - ymin
66
- bbox.append([xmin, ymin, width, height])
67
-
68
- # Get the text label found by the model
69
- label_text = labels[i].lower()
70
-
71
- # Assign ID based on the text description
72
- # WE USE 0, 1, 2 based on your successful test
73
- if "needle" in label_text or "driver" in label_text:
74
- final_id = 0 # Large Needle Driver
75
- elif "forceps" in label_text or "grasper" in label_text:
76
- final_id = 1 # Prograsp Forceps
77
- elif "scissors" in label_text or "curved" in label_text:
78
- final_id = 2 # Monopolar Curved Scissors
79
- else:
80
- final_id = 0 # Default fallback
81
-
82
- category_id.append(final_id)
83
 
84
  bboxes.append(bbox)
85
  category_ids.append(category_id)
86
 
87
- df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
88
-
89
- for i in range(len(test_images_names)):
90
- file_name = test_images_names[i]
91
- new_row = pd.DataFrame({"file_name": file_name,
92
- "bbox": str(bboxes[i]),
93
- "category_id": str(category_ids[i]),
94
- }, index=[0])
95
- df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
96
-
97
- df_predictions.to_csv(save_path, index=False)
98
 
99
 
100
  if __name__ == "__main__":
101
-
102
- # The following environment variables are required for offline mode during HuggingFace Submission
103
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
104
- os.environ["HF_HUB_OFFLINE"] = "1"
105
- os.environ["HF_DATASETS_OFFLINE"] = "1"
106
-
107
- current_directory = os.path.dirname(os.path.abspath(__file__))
108
- TEST_IMAGE_PATH = "/tmp/data/test_images"
109
- SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")
110
 
111
- # Configure the model. More information here: https://huggingface.co/docs/transformers/model_doc/grounding-dino
112
- # If you want to use another model - you need to make it available for offline usage. More information here: https://huggingface.co/docs/transformers/installation#offline-mode
113
  model_id = "IDEA-Research/grounding-dino-tiny"
114
- device = "cuda"
115
- processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
116
- model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(current_directory, "model"))
117
 
 
 
 
 
118
  model.to(device)
119
 
120
- BOX_THRESHOLD = 0.25 # Lowered from 0.4 to catch more items
121
- TEXT_THRESHOLD = 0.20 # Lowered from 0.3
122
- # Describing the shape helps the model find the object
123
- PROMPT = "metal curved scissors . surgical grasper forceps . needle driver holder ."
 
124
 
125
- # If you want to test out your model on training images and visualize the results, set visualize_results to True - Visualization images will be saved in the "outputs" folder
126
- parent_directory = os.path.dirname(current_directory)
127
- PATH_TO_TRAINING_IMAGES_FOR_FOR_VISUALIZATION = os.path.join(parent_directory, "images")
128
- visualization_path = os.path.join(parent_directory, "outputs")
129
- visualize_results = False
130
- if visualize_results:
131
- if os.path.exists(visualization_path):
132
- os.system("rm -rf " + visualization_path)
133
- os.makedirs(visualization_path, exist_ok=True)
134
- run_inference(PATH_TO_TRAINING_IMAGES_FOR_FOR_VISUALIZATION, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, visualize_results, visualization_path, device)
135
 
136
- else:
137
- run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, visualize_results, visualization_path, device)
 
1
  import requests
 
2
  import torch
3
+ from PIL import Image, ImageDraw, ImageFont
4
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 
5
  from tqdm import tqdm
6
  import os
7
  import pandas as pd
8
+ import shutil
9
+
10
+ # --- CONFIGURATION FOR KAGGLE DEBUGGING ---
11
+ # We force visualization to TRUE so you can see the images
12
+ VISUALIZE_RESULTS = True
13
 
14
+ # Use the path to your TRAINING images (from your sidebar)
15
+ # Update this path if yours is different!
16
+ PATH_TO_IMAGES = "/kaggle/input/phase2anewdata/new_data/images/Train"
17
+ # Output folder for marked images
18
+ OUTPUT_DIR = "/kaggle/working/debug_images"
19
 
20
  def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold,
21
  visualize_results, visualization_path, device):
22
 
23
+ # Get first 5 images only for debugging
24
+ test_images = sorted(os.listdir(image_path))[:5]
25
 
26
  bboxes = []
27
  category_ids = []
28
  test_images_names = []
29
 
30
+ print(f"🕵️‍♂️ Debugging on {len(test_images)} images...")
31
+
32
  for image_name in tqdm(test_images):
 
33
  test_images_names.append(image_name)
34
  bbox = []
35
  category_id = []
36
 
37
+ full_img_path = os.path.join(image_path, image_name)
38
+ if not os.path.exists(full_img_path): continue
39
+
40
+ img = Image.open(full_img_path)
41
 
42
  inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
43
 
 
52
  target_sizes=[img.size[::-1]]
53
  )
54
 
55
+ # --- VISUALIZATION BLOCK ---
56
  if visualize_results:
57
  draw = ImageDraw.Draw(img)
58
+ # Try to load a font, fallback to default if fails
59
+ try:
60
+ font = ImageFont.truetype("arial.ttf", 20)
61
+ except:
62
+ font = None
63
+
64
  for result in results:
65
  boxes = result["boxes"]
66
+ labels = result["labels"]
67
+ scores = result["scores"]
68
+
69
+ for i, box in enumerate(boxes):
70
+ # Draw Box
71
+ b = box.tolist()
72
+ draw.rectangle(b, outline="red", width=4)
73
+
74
+ # Draw Label
75
+ label_text = f"{labels[i]} ({scores[i]:.2f})"
76
+ draw.text((b[0], b[1]), label_text, fill="yellow", font=font)
77
+
78
  img.save(os.path.join(visualization_path, image_name))
79
+ # ---------------------------
80
+
81
+ # Simple Saving Logic (All ID=0 for now)
82
  for result in results:
83
  boxes = result["boxes"]
84
+ for box in boxes:
 
 
85
  xmin, ymin, xmax, ymax = box.tolist()
86
+ bbox.append([xmin, ymin, xmax-xmin, ymax-ymin])
87
+ category_id.append(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  bboxes.append(bbox)
90
  category_ids.append(category_id)
91
 
92
+ print(f" Debug images saved to: {visualization_path}")
 
 
 
 
 
 
 
 
 
 
93
 
94
 
95
if __name__ == "__main__":
    # Resolve the directory containing this script.
    # NOTE(review): the original used os.path.abspath("__file__") — a literal
    # string — which resolves against the current working directory, not the
    # script location. Use the real __file__ when it is defined (it is absent
    # inside notebooks / interactive sessions, hence the fallback to the CWD,
    # which matches the original's effective behavior on Kaggle).
    try:
        current_directory = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_directory = os.getcwd()

    # Model setup (Grounding DINO zero-shot object detector).
    model_id = "IDEA-Research/grounding-dino-tiny"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model from local folders (Kaggle specific)
    # Ensure these folders exist in your working dir!
    processor = AutoProcessor.from_pretrained("/kaggle/working/processor")
    model = AutoModelForZeroShotObjectDetection.from_pretrained("/kaggle/working/model")
    model.to(device)

    # --- TUNING PARAMETERS ---
    BOX_THRESHOLD = 0.25   # minimum box confidence kept by post-processing
    TEXT_THRESHOLD = 0.20  # minimum text-match confidence
    # New Prompt Attempt: Using "robotic" to catch da Vinci tools
    PROMPT = "robotic needle driver . robotic graspers . curved scissors ."

    # Start from a clean debug-output folder so stale images never linger.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    run_inference(PATH_TO_IMAGES, model, "dummy.csv", PROMPT, BOX_THRESHOLD,
                  TEXT_THRESHOLD, VISUALIZE_RESULTS, OUTPUT_DIR, device)