File size: 4,300 Bytes
e98e09e
e14c3a3
e98e09e
8a5aefe
e14c3a3
 
 
e98e09e
e14c3a3
c4ac108
e98e09e
 
 
0d3b6b7
 
e98e09e
e14c3a3
 
 
 
594e990
c4ac108
594e990
 
e14c3a3
 
 
 
 
0aa63c6
e98e09e
c4ac108
594e990
 
 
0aa63c6
 
 
e14c3a3
0d3b6b7
e14c3a3
 
 
 
 
 
 
 
 
 
 
 
c4ac108
e14c3a3
 
594e990
4da07fc
e14c3a3
0d3b6b7
 
 
 
e14c3a3
 
 
 
c4ac108
0aa63c6
 
 
0d3b6b7
 
 
 
 
 
0aa63c6
 
e98e09e
e14c3a3
 
 
0d3b6b7
594e990
0aa63c6
 
 
 
 
0d3b6b7
0aa63c6
e14c3a3
594e990
4da07fc
0d3b6b7
c4ac108
 
0d3b6b7
e14c3a3
0d3b6b7
c4ac108
0d3b6b7
594e990
c4ac108
 
 
ce0ade6
c4ac108
 
 
ce0ade6
8a5aefe
c4ac108
 
 
 
2ed995f
594e990
0d3b6b7
0aa63c6
0d3b6b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import torch
import pandas as pd
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from tqdm import tqdm

def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device, processor=None):
    """Run zero-shot object detection on every file in a directory and write a submission CSV.

    Every directory entry (sorted by name) produces exactly one CSV row, even when it
    cannot be opened as an image, so the submission always lists every file.

    Args:
        image_path: Directory containing the test images. A missing directory is
            tolerated and produces a header-only CSV.
        model: Detection model; called as ``model(**inputs)`` on the processor output.
        save_path: Destination path of the CSV file.
        prompt: Text prompt passed to the processor together with each image.
        box_threshold: Forwarded as ``threshold`` to
            ``processor.post_process_grounded_object_detection``.
        text_threshold: Forwarded as ``text_threshold`` to the same post-processing call.
        device: Device the processor outputs are moved to (e.g. ``"cuda"`` or ``"cpu"``).
        processor: Pre/post-processor. When omitted, the module-level ``processor``
            global (created by the ``__main__`` section) is used, which is what this
            function always relied on before the parameter existed.

    Output:
        A CSV with columns ``file_name``, ``bbox`` and ``category_id``. ``bbox`` holds the
        string form of a list of ``[xmin, ymin, width, height]`` boxes (in the pixel
        coordinates produced by ``target_sizes``); ``category_id`` holds the string form of
        a list with one ``0`` per box.

    Raises:
        NameError: if there are files to process but no processor was passed and no
            module-level ``processor`` is defined.
    """
    if processor is None:
        # Backward compatibility: the script defines ``processor`` as a module global.
        processor = globals().get("processor")

    # 1. Get list of images
    try:
        test_images = sorted(os.listdir(image_path))
    except FileNotFoundError:
        print(f"⚠️ Warning: Path {image_path} not found. Creating dummy submission.")
        test_images = []

    if test_images and processor is None:
        raise NameError(
            "run_inference() needs a processor: pass processor=... or define a "
            "module-level 'processor'."
        )

    bboxes = []
    category_ids = []

    print(f"🚀 Running inference on {len(test_images)} images...")
    print(f"📝 Prompt: {prompt}")

    # 2. Loop through all test images
    for image_name in tqdm(test_images):
        full_img_path = os.path.join(image_path, image_name)
        try:
            # The context manager releases the file handle; convert() decodes the
            # pixels first, so ``img`` stays usable after the file is closed.
            with Image.open(full_img_path) as src:
                img = src.convert("RGB")
        except Exception as e:
            # Best effort: keep a row with no detections so the file is still listed.
            print(f"Error loading {image_name}: {e}")
            bboxes.append([])
            category_ids.append([])
            continue

        inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL's size is (width, height); reversed to (height, width) here.
            target_sizes=[img.size[::-1]],
        )

        # 3. Process Results (SAFE MODE: Map all to Class ID 0)
        # Convert corner boxes (xmin, ymin, xmax, ymax) to (xmin, ymin, width, height).
        bbox = []
        for result in results:
            for box in result["boxes"]:
                xmin, ymin, xmax, ymax = box.tolist()
                bbox.append([xmin, ymin, xmax - xmin, ymax - ymin])

        bboxes.append(bbox)
        category_ids.append([0] * len(bbox))

    # 4. Create Submission DataFrame in one step (avoids quadratic pd.concat in a loop).
    df_predictions = pd.DataFrame(
        {
            "file_name": test_images,
            "bbox": [str(b) for b in bboxes],
            "category_id": [str(c) for c in category_ids],
        },
        columns=["file_name", "bbox", "category_id"],
    )

    df_predictions.to_csv(save_path, index=False)
    print("✅ Submission file generated.")


if __name__ == "__main__":

    # --- HUGGING FACE SERVER CONFIGURATION ---
    # NOTE(review): huggingface_hub reads HF_ENDPOINT / HF_HUB_OFFLINE into module-level
    # constants when it is first imported, and the ``transformers`` import at the top of
    # this file has already triggered that import. Setting them here therefore may not
    # switch the loaders to offline mode. They are kept for anything that reads the
    # environment later, and offline loading is enforced explicitly below with
    # ``local_files_only=True``.
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["HF_DATASETS_OFFLINE"] = "1"

    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    # --- MODEL LOADING ---
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Processor and model are loaded from directories stored next to this script;
    # local_files_only=True guarantees no attempt is made to reach the Hub.
    # ``processor`` must remain a module-level name: run_inference() looks it up globally.
    processor = AutoProcessor.from_pretrained(
        os.path.join(current_directory, "processor"),
        local_files_only=True,
    )
    model = AutoModelForZeroShotObjectDetection.from_pretrained(
        os.path.join(current_directory, "model"),
        local_files_only=True,
    )
    model.to(device)

    # ==========================================
    # 🏆 REVERTED WINNING CONFIGURATION
    # ==========================================

    # 1. Prompt Strategy: "Medical Names + Synonyms"
    # The specific instrument names are used because the model recognizes them better
    # than generic descriptions such as "silver metal".
    PROMPT = (
        "Monopolar Curved Scissors . surgical scissors . "
        "Prograsp Forceps . grasper jaws . "
        "Large Needle Driver . needle holder ."
    )

    # 2. Threshold Strategy: "The Sweet Spot"
    # 0.40 was too high (low recall). 0.25 was too low (high noise).
    # 0.30 balances finding the tool vs ignoring the background.
    BOX_THRESHOLD = 0.30
    TEXT_THRESHOLD = 0.25

    # ==========================================

    run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device)