Update script.py
Browse files
script.py
CHANGED
|
@@ -1,31 +1,43 @@
|
|
| 1 |
import requests
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
-
from PIL import Image, ImageDraw
|
| 5 |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
| 6 |
-
|
| 7 |
from tqdm import tqdm
|
| 8 |
import os
|
| 9 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold,
|
| 13 |
visualize_results, visualization_path, device):
|
| 14 |
|
| 15 |
-
|
| 16 |
-
test_images.
|
| 17 |
|
| 18 |
bboxes = []
|
| 19 |
category_ids = []
|
| 20 |
test_images_names = []
|
| 21 |
|
|
|
|
|
|
|
| 22 |
for image_name in tqdm(test_images):
|
| 23 |
-
|
| 24 |
test_images_names.append(image_name)
|
| 25 |
bbox = []
|
| 26 |
category_id = []
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
|
| 31 |
|
|
@@ -40,98 +52,67 @@ def run_inference(image_path, model, save_path, prompt, box_threshold, text_thre
|
|
| 40 |
target_sizes=[img.size[::-1]]
|
| 41 |
)
|
| 42 |
|
| 43 |
-
#
|
| 44 |
if visualize_results:
|
| 45 |
draw = ImageDraw.Draw(img)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
for result in results:
|
| 50 |
boxes = result["boxes"]
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
img.save(os.path.join(visualization_path, image_name))
|
| 56 |
-
|
| 57 |
-
|
|
|
|
| 58 |
for result in results:
|
| 59 |
boxes = result["boxes"]
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
for i, box in enumerate(boxes):
|
| 63 |
xmin, ymin, xmax, ymax = box.tolist()
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
bbox.append([xmin, ymin, width, height])
|
| 67 |
-
|
| 68 |
-
# Get the text label found by the model
|
| 69 |
-
label_text = labels[i].lower()
|
| 70 |
-
|
| 71 |
-
# Assign ID based on the text description
|
| 72 |
-
# WE USE 0, 1, 2 based on your successful test
|
| 73 |
-
if "needle" in label_text or "driver" in label_text:
|
| 74 |
-
final_id = 0 # Large Needle Driver
|
| 75 |
-
elif "forceps" in label_text or "grasper" in label_text:
|
| 76 |
-
final_id = 1 # Prograsp Forceps
|
| 77 |
-
elif "scissors" in label_text or "curved" in label_text:
|
| 78 |
-
final_id = 2 # Monopolar Curved Scissors
|
| 79 |
-
else:
|
| 80 |
-
final_id = 0 # Default fallback
|
| 81 |
-
|
| 82 |
-
category_id.append(final_id)
|
| 83 |
|
| 84 |
bboxes.append(bbox)
|
| 85 |
category_ids.append(category_id)
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
for i in range(len(test_images_names)):
|
| 90 |
-
file_name = test_images_names[i]
|
| 91 |
-
new_row = pd.DataFrame({"file_name": file_name,
|
| 92 |
-
"bbox": str(bboxes[i]),
|
| 93 |
-
"category_id": str(category_ids[i]),
|
| 94 |
-
}, index=[0])
|
| 95 |
-
df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
|
| 96 |
-
|
| 97 |
-
df_predictions.to_csv(save_path, index=False)
|
| 98 |
|
| 99 |
|
| 100 |
if __name__ == "__main__":
|
| 101 |
-
|
| 102 |
-
# The following environment variables are required for offline mode during HuggingFace Submission
|
| 103 |
-
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 104 |
-
os.environ["HF_HUB_OFFLINE"] = "1"
|
| 105 |
-
os.environ["HF_DATASETS_OFFLINE"] = "1"
|
| 106 |
-
|
| 107 |
-
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 108 |
-
TEST_IMAGE_PATH = "/tmp/data/test_images"
|
| 109 |
-
SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")
|
| 110 |
|
| 111 |
-
#
|
| 112 |
-
# If you want to use another model - you need to make it available for offline usage. More information here: https://huggingface.co/docs/transformers/installation#offline-mode
|
| 113 |
model_id = "IDEA-Research/grounding-dino-tiny"
|
| 114 |
-
device = "cuda"
|
| 115 |
-
processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
|
| 116 |
-
model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(current_directory, "model"))
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
model.to(device)
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
visualization_path = os.path.join(parent_directory, "outputs")
|
| 129 |
-
visualize_results = False
|
| 130 |
-
if visualize_results:
|
| 131 |
-
if os.path.exists(visualization_path):
|
| 132 |
-
os.system("rm -rf " + visualization_path)
|
| 133 |
-
os.makedirs(visualization_path, exist_ok=True)
|
| 134 |
-
run_inference(PATH_TO_TRAINING_IMAGES_FOR_FOR_VISUALIZATION, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, visualize_results, visualization_path, device)
|
| 135 |
|
| 136 |
-
|
| 137 |
-
run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, visualize_results, visualization_path, device)
|
|
|
|
| 1 |
import requests
|
|
|
|
| 2 |
import torch
|
| 3 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
|
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
import os
|
| 7 |
import pandas as pd
|
| 8 |
+
import shutil
|
| 9 |
+
|
| 10 |
+
# --- CONFIGURATION FOR KAGGLE DEBUGGING ---
|
| 11 |
+
# We force visualization to TRUE so you can see the images
|
| 12 |
+
VISUALIZE_RESULTS = True
|
| 13 |
|
| 14 |
+
# Use the path to your TRAINING images (from your sidebar)
|
| 15 |
+
# Update this path if yours is different!
|
| 16 |
+
PATH_TO_IMAGES = "/kaggle/input/phase2anewdata/new_data/images/Train"
|
| 17 |
+
# Output folder for marked images
|
| 18 |
+
OUTPUT_DIR = "/kaggle/working/debug_images"
|
| 19 |
|
| 20 |
def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold,
|
| 21 |
visualize_results, visualization_path, device):
|
| 22 |
|
| 23 |
+
# Get first 5 images only for debugging
|
| 24 |
+
test_images = sorted(os.listdir(image_path))[:5]
|
| 25 |
|
| 26 |
bboxes = []
|
| 27 |
category_ids = []
|
| 28 |
test_images_names = []
|
| 29 |
|
| 30 |
+
print(f"🕵️♂️ Debugging on {len(test_images)} images...")
|
| 31 |
+
|
| 32 |
for image_name in tqdm(test_images):
|
|
|
|
| 33 |
test_images_names.append(image_name)
|
| 34 |
bbox = []
|
| 35 |
category_id = []
|
| 36 |
|
| 37 |
+
full_img_path = os.path.join(image_path, image_name)
|
| 38 |
+
if not os.path.exists(full_img_path): continue
|
| 39 |
+
|
| 40 |
+
img = Image.open(full_img_path)
|
| 41 |
|
| 42 |
inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
|
| 43 |
|
|
|
|
| 52 |
target_sizes=[img.size[::-1]]
|
| 53 |
)
|
| 54 |
|
| 55 |
+
# --- VISUALIZATION BLOCK ---
|
| 56 |
if visualize_results:
|
| 57 |
draw = ImageDraw.Draw(img)
|
| 58 |
+
# Try to load a font, fallback to default if fails
|
| 59 |
+
try:
|
| 60 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
| 61 |
+
except:
|
| 62 |
+
font = None
|
| 63 |
+
|
| 64 |
for result in results:
|
| 65 |
boxes = result["boxes"]
|
| 66 |
+
labels = result["labels"]
|
| 67 |
+
scores = result["scores"]
|
| 68 |
+
|
| 69 |
+
for i, box in enumerate(boxes):
|
| 70 |
+
# Draw Box
|
| 71 |
+
b = box.tolist()
|
| 72 |
+
draw.rectangle(b, outline="red", width=4)
|
| 73 |
+
|
| 74 |
+
# Draw Label
|
| 75 |
+
label_text = f"{labels[i]} ({scores[i]:.2f})"
|
| 76 |
+
draw.text((b[0], b[1]), label_text, fill="yellow", font=font)
|
| 77 |
+
|
| 78 |
img.save(os.path.join(visualization_path, image_name))
|
| 79 |
+
# ---------------------------
|
| 80 |
+
|
| 81 |
+
# Simple Saving Logic (All ID=0 for now)
|
| 82 |
for result in results:
|
| 83 |
boxes = result["boxes"]
|
| 84 |
+
for box in boxes:
|
|
|
|
|
|
|
| 85 |
xmin, ymin, xmax, ymax = box.tolist()
|
| 86 |
+
bbox.append([xmin, ymin, xmax-xmin, ymax-ymin])
|
| 87 |
+
category_id.append(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
bboxes.append(bbox)
|
| 90 |
category_ids.append(category_id)
|
| 91 |
|
| 92 |
+
print(f"✅ Debug images saved to: {visualization_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
if __name__ == "__main__":
    # Debug entry point: load the locally stored processor and zero-shot
    # detection model, reset the output folder, then run inference with the
    # module-level visualization settings.

    # Directory containing this script. `__file__` is undefined in notebook
    # cells, so fall back to the working directory there (which is what the
    # previously quoted "__file__" literal always resolved to).
    try:
        current_directory = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_directory = os.getcwd()

    # Model Setup
    model_id = "IDEA-Research/grounding-dino-tiny"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model from local folders (Kaggle specific).
    # Ensure these folders exist in your working dir!
    # NOTE: `processor` must remain a module-level name; run_inference reads
    # it as a global rather than receiving it as a parameter.
    processor = AutoProcessor.from_pretrained("/kaggle/working/processor")
    model = AutoModelForZeroShotObjectDetection.from_pretrained("/kaggle/working/model")
    model.to(device)

    # --- TUNING PARAMETERS ---
    BOX_THRESHOLD = 0.25
    TEXT_THRESHOLD = 0.20
    # New Prompt Attempt: Using "robotic" to catch da Vinci tools
    PROMPT = "robotic needle driver . robotic graspers . curved scissors ."

    # Start from an empty output folder so images from earlier runs do not
    # mix with the current ones.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # "dummy.csv" is a placeholder value for the save_path argument.
    run_inference(
        PATH_TO_IMAGES,
        model,
        "dummy.csv",
        PROMPT,
        BOX_THRESHOLD,
        TEXT_THRESHOLD,
        VISUALIZE_RESULTS,
        OUTPUT_DIR,
        device,
    )
|
|
|