import os
import torch
import pandas as pd
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from tqdm import tqdm
def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device, processor=None):
    """Run zero-shot object detection over a directory of images and write a submission CSV.

    Parameters
    ----------
    image_path : str
        Directory containing the test images. If it does not exist, an
        empty ("dummy") submission with only the header row is written.
    model : transformers model
        A zero-shot object-detection model already moved to ``device``.
    save_path : str
        Destination path for the submission CSV.
    prompt : str
        Text prompt listing the target classes (period-separated phrases).
    box_threshold, text_threshold : float
        Score thresholds forwarded to the processor's post-processing.
    device : str
        Torch device string ("cuda" or "cpu").
    processor : optional
        The matching HF processor. Defaults to the module-level
        ``processor`` global for backward compatibility with the original
        call sites, which did not pass one.

    Output CSV columns: ``file_name``, ``bbox`` (stringified list of
    [x, y, w, h]), ``category_id`` (stringified list; always class 0).
    """
    if processor is None:
        # The original implementation read a module-level global; keep that
        # behavior when no processor is passed explicitly.
        processor = globals()["processor"]

    # 1. Gather the list of test images (sorted for a deterministic order).
    try:
        test_images = sorted(os.listdir(image_path))
    except FileNotFoundError:
        print(f"Warning: Path {image_path} not found. Creating dummy submission.")
        test_images = []

    bboxes = []
    category_ids = []
    test_images_names = []
    print(f"Running inference on {len(test_images)} images...")
    print(f"Prompt: {prompt}")

    # 2. Loop through all test images.
    for image_name in tqdm(test_images):
        test_images_names.append(image_name)
        bbox = []
        category_id = []
        try:
            full_img_path = os.path.join(image_path, image_name)
            # Load image and ensure 3-channel RGB (some inputs may be grayscale/RGBA).
            img = Image.open(full_img_path).convert("RGB")
        except Exception as e:
            # Keep going on a bad image, but record empty predictions so the
            # per-image lists stay aligned with test_images_names.
            print(f"Error loading {image_name}: {e}")
            bboxes.append([])
            category_ids.append([])
            continue

        inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[img.size[::-1]],  # PIL size is (W, H); the model expects (H, W)
        )

        # 3. Process results (SAFE MODE: map every detection to class id 0).
        #    Convert xyxy corners to xywh as required by the submission format.
        for result in results:
            for box in result["boxes"]:
                xmin, ymin, xmax, ymax = box.tolist()
                bbox.append([xmin, ymin, xmax - xmin, ymax - ymin])
                category_id.append(0)
        bboxes.append(bbox)
        category_ids.append(category_id)

    # 4. Build the submission DataFrame in one shot. (The original
    #    pd.concat-ed one row per iteration, which is quadratic in the
    #    number of images.) Passing columns= explicitly keeps the header
    #    even when there are no records.
    records = [
        {"file_name": name, "bbox": str(bb), "category_id": str(cid)}
        for name, bb, cid in zip(test_images_names, bboxes, category_ids)
    ]
    df_predictions = pd.DataFrame(records, columns=["file_name", "bbox", "category_id"])
    df_predictions.to_csv(save_path, index=False)
    print("Submission file generated.")
if __name__ == "__main__":
    # --- HUGGING FACE SERVER CONFIGURATION ---
    # Route any hub traffic through a mirror and force fully offline
    # operation: the processor and model are loaded from local
    # directories below, so no network access should occur.
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["HF_DATASETS_OFFLINE"] = "1"

    # Resolve paths relative to this script's location.
    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    # --- MODEL LOADING ---
    # Processor and model weights are expected in ./processor and ./model
    # next to this script (pre-downloaded for offline use).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
    model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(current_directory, "model"))
    model.to(device)

    # ==========================================
    # REVERTED WINNING CONFIGURATION
    # ==========================================
    # 1. Prompt Strategy: "Medical Names + Synonyms"
    #    Specific instrument names are used (with generic synonyms as
    #    backup phrases) because the model recognizes them better than a
    #    generic description such as "silver metal".
    PROMPT = (
        "Monopolar Curved Scissors . surgical scissors . "
        "Prograsp Forceps . grasper jaws . "
        "Large Needle Driver . needle holder ."
    )
    # 2. Threshold Strategy: "The Sweet Spot"
    #    0.40 was too high (low recall). 0.25 was too low (high noise).
    #    0.30 balances finding the tool vs ignoring the background.
    BOX_THRESHOLD = 0.30
    TEXT_THRESHOLD = 0.25
    # ==========================================
    run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device)