File size: 11,809 Bytes
63cd310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c0344
 
 
 
 
63cd310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr

# Initialize device: run on CUDA when torch reports it available, else on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Patch to avoid additional_chat_templates 404 error.
# We need to patch the function in the module where it is USED, not just where
# it's defined: tokenization_utils_base holds its own reference to
# list_repo_templates, so both module attributes are replaced below.
print("Patching transformers to avoid additional_chat_templates 404 error...")

import transformers.tokenization_utils_base
import transformers.utils.hub
try:
    from huggingface_hub.errors import RemoteEntryNotFoundError
except ImportError:
    # Fallback for older versions of huggingface_hub
    from huggingface_hub.utils import EntryNotFoundError as RemoteEntryNotFoundError

# Resolve the unpatched function. getattr (rather than attribute access) keeps
# start-up working on transformers versions that have no list_repo_templates;
# the patch is then simply skipped. If this code runs twice in one process the
# current attribute is already our wrapper, so reuse its stored original to
# avoid wrapping the wrapper (infinite recursion).
_current_list_repo_templates = getattr(transformers.utils.hub, "list_repo_templates", None)
if getattr(_current_list_repo_templates, "_patched", False):
    _original_list_repo_templates = _current_list_repo_templates._original
else:
    _original_list_repo_templates = _current_list_repo_templates


def patched_list_repo_templates(repo_id, *args, **kwargs):
    """Call the original ``list_repo_templates`` but treat 404s as "no templates".

    Args:
        repo_id: Repository id forwarded to the original function.
        *args, **kwargs: Forwarded unchanged to the original function.

    Returns:
        The original function's results as a list, or ``[]`` when the lookup
        failed with an entry-not-found error or an error whose message
        mentions ``additional_chat_templates`` or ``404``.

    Raises:
        Any other exception raised by the original function.
    """
    try:
        # Materialise the results inside the try so that errors raised lazily
        # (if the original is a generator) are caught here, not by the caller.
        return list(_original_list_repo_templates(repo_id, *args, **kwargs))
    except RemoteEntryNotFoundError:
        # The requested remote entry does not exist: there are no templates.
        print(f"Suppressing additional_chat_templates 404 error for {repo_id}")
        return []
    except Exception as e:
        # Other error types: only suppress those that look like this 404.
        error_str = str(e).lower()
        if "additional_chat_templates" in error_str or "404" in error_str:
            print(f"Suppressing additional_chat_templates 404 error for {repo_id}")
            return []
        raise


if _original_list_repo_templates is None:
    print("transformers.utils.hub has no list_repo_templates; patch not applied")
else:
    # Mark as patched and store original (used by the re-run guard above).
    patched_list_repo_templates._patched = True
    patched_list_repo_templates._original = _original_list_repo_templates

    # Apply the patch to BOTH locations
    transformers.utils.hub.list_repo_templates = patched_list_repo_templates
    transformers.tokenization_utils_base.list_repo_templates = patched_list_repo_templates
    print(
        "Patch applied to transformers.utils.hub.list_repo_templates "
        "and transformers.tokenization_utils_base.list_repo_templates"
    )

# Load processor from original model.
# Tokenizer and image processor are loaded separately and assembled by hand;
# CLIPSegProcessor.from_pretrained is kept as a fallback.
print("Loading processor from original model...")
try:
    from transformers import CLIPTokenizer, ViTImageProcessor
    # CLIPSegProcessor declares ViTImageProcessor as its image processor class
    # and transformers type-checks processor constructor arguments, so a
    # CLIPImageProcessor is rejected with TypeError (which made this path
    # always fall through to the fallback). CLIPImageProcessor would also
    # apply its default 224x224 center crop unless the config disables it.
    tokenizer = CLIPTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
    image_processor = ViTImageProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
    print("Processor loaded successfully from original model components")
except Exception as e:
    print(f"Error loading processor components: {e}")
    # Fallback: try loading processor directly (should work with patch)
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    print("Processor loaded directly with patched template check")

# Load models
def _load_model(repo_id):
    """Load a CLIPSeg segmentation checkpoint onto ``device`` in eval mode."""
    loaded = CLIPSegForImageSegmentation.from_pretrained(repo_id).to(device)
    loaded.eval()
    return loaded


print("Loading pretrained model...")
model_pretrained = _load_model("CIDAS/clipseg-rd64-refined")

print("Loading fine-tuned model...")
# The fine-tuned checkpoint is optional: any failure is reported and the app
# continues with pretrained-only output instead of aborting start-up.
try:
    model_trained = _load_model("smcs/clipseg_drywall")
except Exception as e:
    print(f"Warning: Could not load fine-tuned model from smcs/clipseg_drywall: {e}")
    model_trained = None
else:
    print("Fine-tuned model loaded successfully from smcs/clipseg_drywall")

# Flag read by the UI and by process_image; True exactly when loading succeeded.
model_trained_available = model_trained is not None

# Define prompts
# Dropdown label -> text prompt sent to CLIPSeg. Each label is used verbatim
# as its prompt, so the mapping is the identity.
PROMPTS = {label: label for label in ("segment crack", "segment taping area")}

# Example images
# One single-element row per example (one input component); the paths are
# relative, so they resolve against the process working directory.
example_images = [
    [f"examples/{stem}.jpg"]
    for stem in ("crack_1", "crack_2", "drywall_1", "drywall_2")
]


def overlay_mask(image, mask, alpha=0.5, color=(255, 0, 0)):
    """Overlay mask on image with transparency and colored mask.

    Args:
        image: PIL image to draw on; it is converted to RGB for blending.
        mask: PIL image (any mode, converted to 'L'); pixels with a grey level
            above 127 mark the region to tint. When its size differs from
            ``image`` it is resized with nearest-neighbour sampling. ``None``
            skips blending entirely.
        alpha: Weight of ``color`` inside the masked region (0 keeps the
            image unchanged, 1 paints solid ``color``).
        color: (R, G, B) tint, each component in 0-255.

    Returns:
        A new RGB PIL image with the tint blended in, or ``image`` itself,
        unchanged, when ``mask`` is None.
    """
    if mask is None:
        return image

    # Ensure same size; NEAREST keeps the resized mask strictly binary.
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.NEAREST)

    # Per-pixel blend weight: alpha inside the mask, 0 outside. The trailing
    # axis lets the (H, W) weight broadcast across the 3 RGB channels.
    mask_binary = (np.array(mask.convert('L')) > 127).astype(np.float32)
    weight = (alpha * mask_binary)[..., np.newaxis]

    img_array = np.array(image.convert('RGB')).astype(np.float32)
    tint = np.asarray(color, dtype=np.float32)

    blended = img_array * (1 - weight) + tint * weight
    # Clip before the uint8 cast so an alpha outside [0, 1] saturates instead
    # of wrapping around; for alpha in [0, 1] this changes nothing.
    overlay = np.clip(blended, 0, 255).astype(np.uint8)
    return Image.fromarray(overlay)


def _predict_mask(model, pixel_values, text_inputs, size):
    """Run one CLIPSeg model and return its thresholded mask as an 'L' image of ``size``."""
    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
        )
    logits = outputs.logits
    # A batch of one may come back as (1, H, W) or already squeezed to
    # (H, W) depending on the transformers version; normalise to (H, W).
    if logits.dim() == 3:
        logits = logits[0]
    # Foreground where sigmoid(logit) > 0.5; values become 0 or 255.
    mask = (torch.sigmoid(logits) > 0.5).to(torch.uint8).cpu().numpy()
    mask_img = Image.fromarray(mask * 255)  # 2-D uint8 array -> mode 'L'
    # Resize back to the original image size (compare against the mask's
    # own size, since that is what actually needs to match).
    if mask_img.size != size:
        mask_img = mask_img.resize(size, Image.NEAREST)
    return mask_img


def process_image(image, prompt_option):
    """
    Process an image with both pretrained and fine-tuned models.

    Args:
        image: PIL Image, numpy array, or a path/file object accepted by
            ``Image.open``; always converted to RGB.
        prompt_option: Selected prompt option ("segment crack" or
            "segment taping area"); values not in PROMPTS are used verbatim.

    Returns:
        Tuple ``(pretrained_overlay, trained_overlay)`` of RGB PIL images at
        the input's size: the pretrained mask tinted blue and the fine-tuned
        mask tinted green. When the fine-tuned model is unavailable the second
        item is a plain light-grey placeholder. Returns ``(None, None)`` when
        ``image`` is None or processing fails (the error is printed).
    """
    if image is None:
        return None, None

    try:
        # Convert to an RGB PIL Image on every path (arrays may be RGBA/grey).
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert('RGB')

        # Get the prompt
        prompt = PROMPTS.get(prompt_option, prompt_option)

        # Resize image for processing
        model_input = image.resize((352, 352), Image.BILINEAR)

        # Prepare inputs once; both models share them.
        pixel_values = processor(images=[model_input], return_tensors="pt")['pixel_values'].to(device)
        text_inputs = processor.tokenizer(
            prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).to(device)

        # Pretrained result (blue overlay)
        pretrained_mask = _predict_mask(model_pretrained, pixel_values, text_inputs, image.size)
        pretrained_overlay = overlay_mask(image, pretrained_mask, alpha=0.5, color=(0, 100, 255))

        # Fine-tuned result (green overlay) if available, else a placeholder
        if model_trained_available and model_trained is not None:
            trained_mask = _predict_mask(model_trained, pixel_values, text_inputs, image.size)
            trained_overlay = overlay_mask(image, trained_mask, alpha=0.5, color=(0, 255, 0))
        else:
            trained_overlay = Image.new('RGB', image.size, color=(240, 240, 240))

        return pretrained_overlay, trained_overlay

    except Exception as e:
        # UI boundary: report the failure and clear both outputs rather than
        # crashing the event handler.
        import traceback

        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return None, None


def create_interface():
    """Build (but do not launch) the Gradio app comparing both CLIPSeg models.

    Layout, top to bottom: an intro note; an input column with the image,
    the prompt dropdown and a "Segment" button; a row with the pretrained and
    fine-tuned result images; a warning when the fine-tuned model is missing;
    and the example gallery. ``process_image`` reruns on a button click, on
    any image change (including picking an example) and on a prompt change.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    prompt_choices = list(PROMPTS.keys())
    trained_label = "Fine-tuned Result"
    if not model_trained_available:
        trained_label += " (Not Available)"

    with gr.Blocks(title="CLIPSeg Image Segmentation") as demo:
        gr.Markdown(
            """
            # CLIPSeg Image Segmentation Demo
            
            This demo compares zero-shot pretrained CLIPSeg results with fine-tuned model results.
            Select an example image or upload your own, then choose a prompt to see the segmentation results.
            """
        )

        # Inputs
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", type="pil", height=400)
                prompt_dropdown = gr.Dropdown(
                    choices=prompt_choices,
                    value=prompt_choices[0],
                    label="Select Prompt",
                    info="Choose the segmentation prompt",
                )
                submit_btn = gr.Button("Segment", variant="primary")

        # Side-by-side results
        with gr.Row():
            with gr.Column():
                pretrained_output = gr.Image(
                    label="Pretrained (Zero-shot) Result", type="pil", height=400
                )
            with gr.Column():
                trained_output = gr.Image(label=trained_label, type="pil", height=400)

        if not model_trained_available:
            gr.Markdown(
                "⚠️ **Note:** Fine-tuned model could not be loaded from `smcs/clipseg_drywall`. "
                "Only pretrained results will be shown."
            )

        gr.Examples(examples=example_images, inputs=image_input, label="Example Images")

        # Every trigger runs the same handler with the same inputs/outputs:
        # explicit click, image change (incl. example selection), prompt change.
        for trigger in (submit_btn.click, image_input.change, prompt_dropdown.change):
            trigger(
                fn=process_image,
                inputs=[image_input, prompt_dropdown],
                outputs=[pretrained_output, trained_output],
            )

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    # share=False: no public Gradio share link is requested.
    # `demo` is kept as a module-level name rather than chained inline.
    demo = create_interface()
    demo.launch(share=False)