yusufbardolia committed on
Commit
e98e09e
·
verified ·
1 Parent(s): 0aa63c6

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +45 -32
script.py CHANGED
@@ -1,34 +1,42 @@
1
- import requests
2
  import torch
 
3
  from PIL import Image, ImageDraw
4
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
5
  from tqdm import tqdm
6
- import os
7
- import pandas as pd
8
 
9
- def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold,
10
- visualize_results, visualization_path, device):
11
-
12
- test_images = os.listdir(image_path)
13
- test_images.sort()
14
 
 
 
 
 
 
 
 
 
15
  bboxes = []
16
  category_ids = []
17
  test_images_names = []
18
 
 
 
19
  for image_name in tqdm(test_images):
20
  test_images_names.append(image_name)
21
  bbox = []
22
  category_id = []
23
 
 
24
  try:
25
- img = Image.open(os.path.join(image_path, image_name))
26
- except:
27
- # Fallback if image fails to load
 
28
  bboxes.append([])
29
  category_ids.append([])
30
  continue
31
 
 
32
  inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
33
 
34
  with torch.no_grad():
@@ -42,11 +50,12 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
42
  target_sizes=[img.size[::-1]]
43
  )
44
 
45
- # --- SAFE MODE: SAVE EVERYTHING AS ID 0 ---
46
- # This ensures we don't lose points by guessing the wrong class ID.
47
- # We focus purely on finding the objects first.
48
  for result in results:
49
  boxes = result["boxes"]
 
 
50
  for box in boxes:
51
  xmin, ymin, xmax, ymax = box.tolist()
52
  width = xmax - xmin
@@ -57,46 +66,50 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
57
  bboxes.append(bbox)
58
  category_ids.append(category_id)
59
 
 
60
  df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
61
 
62
  for i in range(len(test_images_names)):
63
- file_name = test_images_names[i]
64
- new_row = pd.DataFrame({"file_name": file_name,
65
- "bbox": str(bboxes[i]),
66
- "category_id": str(category_ids[i]),
67
- }, index=[0])
 
68
  df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
69
 
70
  df_predictions.to_csv(save_path, index=False)
 
71
 
72
 
73
  if __name__ == "__main__":
 
74
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
75
  os.environ["HF_HUB_OFFLINE"] = "1"
76
  os.environ["HF_DATASETS_OFFLINE"] = "1"
77
 
 
78
  current_directory = os.path.dirname(os.path.abspath(__file__))
79
  TEST_IMAGE_PATH = "/tmp/data/test_images"
80
  SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")
81
 
82
- model_id = "IDEA-Research/grounding-dino-tiny"
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
84
 
85
- processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
86
- model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(current_directory, "model"))
87
  model.to(device)
88
 
89
- # --- TUNING SETTINGS ---
90
- # 1. Lower Threshold: Catches faint objects
91
  BOX_THRESHOLD = 0.20
92
  TEXT_THRESHOLD = 0.20
93
-
94
- # 2. Visual Prompt: Describes SHAPE rather than just name
95
- # "robotic" helps because these are da Vinci tools
96
- # "wristed" describes the joint
97
  PROMPT = "robotic surgical tool . metal curved scissors . wristed forceps grasper . needle driver ."
98
 
99
- visualize_results = False
100
- visualization_path = os.path.join(os.path.dirname(current_directory), "outputs")
101
-
102
- run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, visualize_results, visualization_path, device)
 
1
+ import os
2
  import torch
3
+ import pandas as pd
4
  from PIL import Image, ImageDraw
5
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
6
  from tqdm import tqdm
 
 
7
 
8
+ def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device):
 
 
 
 
9
 
10
+ # 1. Get list of images
11
+ try:
12
+ test_images = sorted(os.listdir(image_path))
13
+ except FileNotFoundError:
14
+ # Fallback for debugging if path is wrong
15
+ print(f"Error: Path {image_path} not found.")
16
+ return
17
+
18
  bboxes = []
19
  category_ids = []
20
  test_images_names = []
21
 
22
+ print(f"🚀 Running inference on {len(test_images)} images...")
23
+
24
  for image_name in tqdm(test_images):
25
  test_images_names.append(image_name)
26
  bbox = []
27
  category_id = []
28
 
29
+ # 2. Load Image safely
30
  try:
31
+ full_img_path = os.path.join(image_path, image_name)
32
+ img = Image.open(full_img_path).convert("RGB") # Ensure RGB
33
+ except Exception as e:
34
+ print(f"⚠️ Failed to load {image_name}: {e}")
35
  bboxes.append([])
36
  category_ids.append([])
37
  continue
38
 
39
+ # 3. Run Model
40
  inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
41
 
42
  with torch.no_grad():
 
50
  target_sizes=[img.size[::-1]]
51
  )
52
 
53
+ # 4. Save Results (SAFE MODE: All ID=0)
54
+ # We stick to ID 0 to ensure we get points for detection first.
 
55
  for result in results:
56
  boxes = result["boxes"]
57
+ # labels = result["labels"] # Not using labels for ID yet
58
+
59
  for box in boxes:
60
  xmin, ymin, xmax, ymax = box.tolist()
61
  width = xmax - xmin
 
66
  bboxes.append(bbox)
67
  category_ids.append(category_id)
68
 
69
+ # 5. Create Submission DataFrame
70
  df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
71
 
72
  for i in range(len(test_images_names)):
73
+ # Format explicitly as string for the CSV
74
+ new_row = pd.DataFrame({
75
+ "file_name": test_images_names[i],
76
+ "bbox": str(bboxes[i]),
77
+ "category_id": str(category_ids[i]),
78
+ }, index=[0])
79
  df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
80
 
81
  df_predictions.to_csv(save_path, index=False)
82
+ print("✅ Submission file generated.")
83
 
84
 
85
if __name__ == "__main__":
    # --- ENVIRONMENT SETUP ---
    # Run fully offline: point the hub at the mirror and disable network lookups.
    for _var, _val in (
        ("HF_ENDPOINT", "https://hf-mirror.com"),
        ("HF_HUB_OFFLINE", "1"),
        ("HF_DATASETS_OFFLINE", "1"),
    ):
        os.environ[_var] = _val

    # Resolve everything relative to this script's directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_SAVE_PATH = os.path.join(script_dir, "submission.csv")

    # Pick GPU when available, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🔧 Using device: {device}")

    # --- MODEL LOADING ---
    # Processor and model weights are shipped in folders next to this script
    # (NOT under /kaggle/working/). `processor` stays module-level: it is read
    # as a global inside run_inference.
    processor = AutoProcessor.from_pretrained(os.path.join(script_dir, "processor"))
    model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(script_dir, "model"))
    model.to(device)

    # --- TUNING ---
    # Detection thresholds and the zero-shot text prompt.
    BOX_THRESHOLD = 0.20
    TEXT_THRESHOLD = 0.20
    PROMPT = "robotic surgical tool . metal curved scissors . wristed forceps grasper . needle driver ."

    # Run!
    run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device)