File size: 2,922 Bytes
144afae
03bafc0
 
144afae
 
 
 
a8246e3
144afae
9de67ae
03bafc0
144afae
 
a8246e3
 
144afae
9de67ae
a8246e3
 
 
144afae
9de67ae
144afae
 
 
a8246e3
03bafc0
 
a8246e3
144afae
03bafc0
144afae
9de67ae
03bafc0
 
 
9de67ae
03bafc0
 
 
 
 
 
 
 
 
 
144afae
 
a8246e3
144afae
9de67ae
03bafc0
 
 
 
 
144afae
03bafc0
 
 
 
 
 
 
 
a8246e3
9de67ae
a8246e3
 
 
03bafc0
 
 
 
 
a8246e3
03bafc0
144afae
a8246e3
03bafc0
a8246e3
 
03bafc0
 
a8246e3
144afae
a8246e3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import numpy as np
import gc
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

class CLIPMatcher:
    """Rank image segments against a text query using CLIP image-text similarity.

    The model is kept on CPU between calls and moved to ``self.device``
    (GPU when available) only for the duration of a scoring pass, so VRAM
    stays free for other pipeline stages.
    """

    # Filler words stripped from queries before scoring: CLIP matches plain
    # noun phrases better than imperative edit commands ("remove the dog").
    _STOPWORDS = frozenset(['remove', 'delete', 'erase', 'the', 'a', 'an'])

    def __init__(self, model_name='openai/clip-vit-large-patch14'):
        """Load the CLIP model (parked on CPU until needed) and its processor.

        Args:
            model_name: Hugging Face model id of the CLIP checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = CLIPModel.from_pretrained(model_name).to("cpu")
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def get_top_k_segments(self, image, segments, text_query, k=5):
        """Score each segment crop against *text_query* and return the top *k*.

        Args:
            image: HxWx3 uint8 RGB numpy array (assumed — TODO confirm with caller).
            segments: list of dicts, each with a 'bbox' [x1, y1, x2, y2] and
                optionally 'area', 'mask' and 'label'. Entries without a
                'bbox', or with a degenerate (zero-area) box, are skipped.
            text_query: free-form text describing the target object.
            k: maximum number of results to return.

        Returns:
            Up to *k* dicts with keys 'mask', 'bbox', 'original_score' (raw
            CLIP logit), 'weighted_score' (logit plus an area bonus) and
            'label', sorted by 'weighted_score' descending. Empty list when
            there is nothing to score or CLIP inference fails.
        """
        if not segments:
            return []

        # Strip edit-command filler so CLIP sees a plain object description;
        # fall back to the raw query if stripping removes everything.
        words = [w for w in text_query.lower().split() if w not in self._STOPWORDS]
        clean_text = " ".join(words) if words else text_query

        pil_image = Image.fromarray(image)
        crops = []
        valid_segments = []

        h_img, w_img = image.shape[:2]
        total_img_area = h_img * w_img

        for seg in segments:
            if 'bbox' not in seg:
                continue

            x1, y1, x2, y2 = np.array(seg['bbox']).astype(int)

            # Skip degenerate boxes: a zero-area crop makes the CLIP
            # processor raise, and since all crops go through one batch,
            # a single bad box would abort scoring for every segment.
            if x2 <= x1 or y2 <= y1:
                continue

            # Pad the crop 30% per side so CLIP sees some context around
            # the object, clamped to the image bounds.
            pad_x = int((x2 - x1) * 0.3)
            pad_y = int((y2 - y1) * 0.3)

            crop_x1 = max(0, x1 - pad_x)
            crop_y1 = max(0, y1 - pad_y)
            crop_x2 = min(w_img, x2 + pad_x)
            crop_y2 = min(h_img, y2 + pad_y)

            crops.append(pil_image.crop((crop_x1, crop_y1, crop_x2, crop_y2)))
            valid_segments.append(seg)

        if not crops:
            return []

        try:
            # GPU only for this pass; moved back to CPU in `finally`.
            self.model.to(self.device)
            inputs = self.processor(
                text=[clean_text], images=crops, return_tensors="pt", padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                # Raw image-text similarity logits — NOT probabilities.
                # Softmax over a single text prompt would collapse every
                # score to 1.0, so the logits are the useful quantity.
                scores = outputs.logits_per_image.cpu().numpy().flatten()
        except Exception as e:
            # Best-effort: a CLIP failure degrades to "no matches" rather
            # than crashing the surrounding pipeline.
            print(f"CLIP Error: {e}")
            return []
        finally:
            self.model.to("cpu")
            # Release memory held by the forward pass before returning.
            if self.device == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

        final_results = []
        for seg, score in zip(valid_segments, scores):
            # Area bonus nudges ranking toward larger segments, which are
            # more likely the intended object than tiny fragments.
            if 'area' in seg:
                area_ratio = seg['area'] / total_img_area
            else:
                bx1, by1, bx2, by2 = seg['bbox']
                area_ratio = ((bx2 - bx1) * (by2 - by1)) / total_img_area

            weighted_score = float(score) + (area_ratio * 2.0)

            final_results.append({
                'mask': seg.get('mask', None),
                'bbox': seg['bbox'],
                'original_score': float(score),
                'weighted_score': weighted_score,
                'label': seg.get('label', 'object')
            })

        final_results.sort(key=lambda x: x['weighted_score'], reverse=True)
        return final_results[:k]