"""Zero-shot surgical-tool detection inference script.

Runs a grounded object-detection model over a directory of test images and
writes a submission CSV with columns (file_name, bbox, category_id).
"""

import os

import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor


def run_inference(image_path, model, save_path, prompt, box_threshold,
                  text_threshold, device, processor=None):
    """Run zero-shot detection over every image in ``image_path``.

    Args:
        image_path: Directory containing the test images.
        model: ``transformers`` zero-shot object-detection model, already on
            ``device``.
        save_path: Destination path for the submission CSV.
        prompt: Grounding-style text prompt (" . "-separated phrases).
        box_threshold: Box confidence threshold for post-processing.
        text_threshold: Text confidence threshold for post-processing.
        device: Torch device string, e.g. "cuda" or "cpu".
        processor: Matching ``transformers`` processor. Defaults to the
            module-level ``processor`` global for backward compatibility with
            callers that relied on it implicitly.

    Side effects:
        Writes ``save_path``: one row per image; ``bbox`` and ``category_id``
        are stringified Python lists, boxes in COCO ``[x, y, w, h]`` format.
    """
    if processor is None:
        # Backward compatibility: the original implementation read a
        # module-level ``processor`` global instead of taking a parameter.
        processor = globals()["processor"]

    # 1. Get list of images (sorted for deterministic output ordering).
    try:
        test_images = sorted(os.listdir(image_path))
    except FileNotFoundError:
        print(f"⚠️ Warning: Path {image_path} not found. Creating dummy submission.")
        test_images = []

    bboxes = []
    category_ids = []
    test_images_names = []

    print(f"🚀 Running inference on {len(test_images)} images...")
    print(f"📝 Prompt: {prompt}")

    # 2. Loop through all test images
    for image_name in tqdm(test_images):
        test_images_names.append(image_name)
        bbox = []
        category_id = []

        try:
            full_img_path = os.path.join(image_path, image_name)
            # Load image and ensure RGB (model expects 3-channel input).
            img = Image.open(full_img_path).convert("RGB")
        except Exception as e:
            # Unreadable image: record empty predictions so the three result
            # lists stay aligned one-to-one with test_images_names.
            print(f"Error loading {image_name}: {e}")
            bboxes.append([])
            category_ids.append([])
            continue

        inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # target_sizes expects (height, width); PIL's img.size is (width, height).
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[img.size[::-1]],
        )

        # 3. Process Results (SAFE MODE: Map all to Class ID 0)
        for result in results:
            for box in result["boxes"]:
                xmin, ymin, xmax, ymax = box.tolist()
                # Convert corner format (x1, y1, x2, y2) to COCO (x, y, w, h).
                bbox.append([xmin, ymin, xmax - xmin, ymax - ymin])
                category_id.append(0)

        bboxes.append(bbox)
        category_ids.append(category_id)

    # 4. Create Submission DataFrame.
    # Build all rows first and construct the frame once: concatenating a
    # one-row frame per image inside the loop is O(n^2) and triggers
    # FutureWarnings when concatenating with an empty frame in recent pandas.
    records = [
        {"file_name": name, "bbox": str(bb), "category_id": str(cid)}
        for name, bb, cid in zip(test_images_names, bboxes, category_ids)
    ]
    df_predictions = pd.DataFrame(records, columns=["file_name", "bbox", "category_id"])
    df_predictions.to_csv(save_path, index=False)
    print("✅ Submission file generated.")


if __name__ == "__main__":
    # --- HUGGING FACE SERVER CONFIGURATION ---
    # Force mirror endpoint + offline mode so no network calls happen at runtime.
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["HF_DATASETS_OFFLINE"] = "1"

    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    # --- MODEL LOADING ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
    model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(current_directory, "model"))
    model.to(device)

    # ==========================================
    # 🏆 REVERTED WINNING CONFIGURATION
    # ==========================================
    # 1. Prompt Strategy: "Medical Names + Synonyms"
    # We are bringing back the specific names because the model recognizes them better
    # than generic "silver metal".
    PROMPT = (
        "Monopolar Curved Scissors . surgical scissors . "
        "Prograsp Forceps . grasper jaws . "
        "Large Needle Driver . needle holder ."
    )

    # 2. Threshold Strategy: "The Sweet Spot"
    # 0.40 was too high (low recall). 0.25 was too low (high noise).
    # 0.30 balances finding the tool vs ignoring the background.
    BOX_THRESHOLD = 0.30
    TEXT_THRESHOLD = 0.25
    # ==========================================

    run_inference(
        TEST_IMAGE_PATH,
        model,
        SUBMISSION_SAVE_PATH,
        PROMPT,
        BOX_THRESHOLD,
        TEXT_THRESHOLD,
        device,
        processor=processor,
    )