"""Gradio demo comparing zero-shot CLIPSeg with a fine-tuned drywall model.

Loads the pretrained CIDAS/clipseg-rd64-refined model and, when available,
the fine-tuned smcs/clipseg_drywall model, then serves a Gradio UI that
overlays each model's predicted segmentation mask on the input image.
"""

import numpy as np
from pathlib import Path
from PIL import Image
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Patch to avoid additional_chat_templates 404 error.
# We need to patch the function in the module where it is USED, not just
# where it's defined.
# ---------------------------------------------------------------------------
print("Patching transformers to avoid additional_chat_templates 404 error...")
import transformers.tokenization_utils_base
import transformers.utils.hub

try:
    from huggingface_hub.errors import RemoteEntryNotFoundError
except ImportError:
    # Fallback for older versions of huggingface_hub
    from huggingface_hub.utils import EntryNotFoundError as RemoteEntryNotFoundError

# Capture the original function carefully to avoid recursion: if this module
# runs twice, the `_patched` marker lets us recover the true original instead
# of wrapping our own wrapper.
if not hasattr(transformers.utils.hub.list_repo_templates, "_patched"):
    _original_list_repo_templates = transformers.utils.hub.list_repo_templates
else:
    _original_list_repo_templates = transformers.utils.hub.list_repo_templates._original


def patched_list_repo_templates(repo_id, *args, **kwargs):
    """Wrap list_repo_templates, suppressing additional_chat_templates 404s.

    Returns an empty list when the repo merely lacks the optional
    additional_chat_templates entry; every other error propagates unchanged.
    """
    try:
        # Materialize the (possibly lazy) result in a single pass.
        return list(_original_list_repo_templates(repo_id, *args, **kwargs))
    except Exception as e:
        # NOTE: RemoteEntryNotFoundError is an Exception subclass, so one
        # broad handler replaces the original redundant
        # `(RemoteEntryNotFoundError, Exception)` tuple with identical reach.
        error_str = str(e).lower()
        if "additional_chat_templates" in error_str or "404" in error_str:
            print(f"Suppressing additional_chat_templates 404 error for {repo_id}")
            return []
        raise


# Mark as patched and store the original so re-running stays idempotent.
patched_list_repo_templates._patched = True
patched_list_repo_templates._original = _original_list_repo_templates

# Apply the patch to BOTH locations
transformers.utils.hub.list_repo_templates = patched_list_repo_templates
transformers.tokenization_utils_base.list_repo_templates = patched_list_repo_templates
print("Patch applied to transformers.tokenization_utils_base.list_repo_templates")

# Load processor from original model
print("Loading processor from original model...")
try:
    from transformers import CLIPTokenizer, CLIPImageProcessor

    # Load components separately
    tokenizer = CLIPTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
    image_processor = CLIPImageProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
    print("Processor loaded successfully from original model components")
except Exception as e:
    print(f"Error loading processor components: {e}")
    # Fallback: try loading processor directly (should work with patch)
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    print("Processor loaded directly with patched template check")

# Load models
print("Loading pretrained model...")
model_pretrained = CLIPSegForImageSegmentation.from_pretrained(
    "CIDAS/clipseg-rd64-refined"
).to(device)
model_pretrained.eval()

print("Loading fine-tuned model...")
try:
    model_trained = CLIPSegForImageSegmentation.from_pretrained(
        "smcs/clipseg_drywall"
    ).to(device)
    model_trained.eval()
    model_trained_available = True
    print("Fine-tuned model loaded successfully from smcs/clipseg_drywall")
except Exception as e:
    print(f"Warning: Could not load fine-tuned model from smcs/clipseg_drywall: {e}")
    model_trained = None
    model_trained_available = False

# Define prompts: dropdown label -> text prompt fed to CLIPSeg
PROMPTS = {
    "segment crack": "segment crack",
    "segment taping area": "segment taping area"
}

# Example images
example_images = [
    ["examples/crack_1.jpg"],
    ["examples/crack_2.jpg"],
    ["examples/drywall_1.jpg"],
    ["examples/drywall_2.jpg"]
]


def overlay_mask(image, mask, alpha=0.5, color=(255, 0, 0)):
    """Overlay mask on image with transparency and colored mask.

    Args:
        image: PIL Image (any mode; converted to RGB for blending).
        mask: PIL Image mask (converted to 'L'), or None.
        alpha: Blend strength applied where the mask is set (0..1).
        color: RGB tuple used to tint masked pixels.

    Returns:
        PIL RGB Image with masked regions tinted, or the input image
        unchanged when mask is None.
    """
    if mask is None:
        return image

    # Ensure same size
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.NEAREST)

    # Binarize the mask: gray values above 127 count as "on"
    mask_array = np.array(mask.convert('L'))
    mask_binary = (mask_array > 127).astype(np.float32)

    # Solid-color image supplying the tint
    colored_mask = np.zeros((*mask_array.shape, 3), dtype=np.uint8)
    colored_mask[:, :, 0] = color[0]  # Red channel
    colored_mask[:, :, 1] = color[1]  # Green channel
    colored_mask[:, :, 2] = color[2]  # Blue channel

    img_array = np.array(image.convert('RGB')).astype(np.float32)

    # Vectorized per-pixel blend (replaces the original per-channel loop):
    #   out = img * (1 - alpha*mask) + color * (alpha*mask)
    blend = (alpha * mask_binary)[:, :, None]
    overlay = img_array * (1.0 - blend) + colored_mask * blend

    return Image.fromarray(overlay.astype(np.uint8))


def _predict_binary_mask(model, pixel_values, text_inputs):
    """Run one CLIPSeg forward pass; return a 352x352 binary uint8 mask."""
    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask']
        )
    # sigmoid on the tensor directly avoids the original
    # tensor -> numpy -> tensor -> numpy round trip.
    probs = torch.sigmoid(outputs.logits[0]).cpu().numpy()
    return (probs > 0.5).astype(np.uint8)


def _mask_to_overlay(img_orig, mask_binary, color):
    """Upscale a binary mask to the original image size and tint it on top."""
    mask_img = Image.fromarray(mask_binary * 255, mode='L')
    if img_orig.size != (352, 352):
        mask_img = mask_img.resize(img_orig.size, Image.NEAREST)
    return overlay_mask(img_orig.copy(), mask_img, alpha=0.5, color=color)


def process_image(image, prompt_option):
    """
    Process an image with both pretrained and fine-tuned models.

    Args:
        image: PIL Image, numpy array, or path-like accepted by Image.open.
        prompt_option: Selected prompt option ("segment crack" or
            "segment taping area"); unknown options are used verbatim.

    Returns:
        Tuple (pretrained_overlay, trained_overlay) of PIL Images, or
        (None, None) when input is missing or processing fails.
    """
    if image is None:
        return None, None

    try:
        # Normalize the input to an RGB PIL Image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            image = Image.open(image).convert('RGB')
        else:
            image = image.convert('RGB')

        # Fall back to the raw option string for unknown dropdown values
        prompt = PROMPTS.get(prompt_option, prompt_option)

        # CLIPSeg runs on a fixed 352x352 input; keep the original for display
        img_orig = image.copy()
        img = img_orig.resize((352, 352), Image.BILINEAR)

        # Prepare inputs (shared by both models)
        pixel_values = processor(images=[img], return_tensors="pt")['pixel_values'].to(device)
        text_inputs = processor.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        # Pretrained (zero-shot) result, tinted blue
        mask_pretrained = _predict_binary_mask(model_pretrained, pixel_values, text_inputs)
        pred_mask_pretrained_overlay = _mask_to_overlay(img_orig, mask_pretrained, (0, 100, 255))

        # Fine-tuned result (green), or a neutral placeholder when unavailable
        if model_trained_available and model_trained is not None:
            mask_trained = _predict_binary_mask(model_trained, pixel_values, text_inputs)
            pred_mask_trained_overlay = _mask_to_overlay(img_orig, mask_trained, (0, 255, 0))
        else:
            pred_mask_trained_overlay = Image.new('RGB', img_orig.size, color=(240, 240, 240))

        return pred_mask_pretrained_overlay, pred_mask_trained_overlay

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        return None, None


def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(title="CLIPSeg Image Segmentation") as demo:
        gr.Markdown(
            """
            # CLIPSeg Image Segmentation Demo

            This demo compares zero-shot pretrained CLIPSeg results with fine-tuned model results.
            Select an example image or upload your own, then choose a prompt to see the segmentation results.
            """
        )

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=400
                )
                prompt_dropdown = gr.Dropdown(
                    choices=list(PROMPTS.keys()),
                    value=list(PROMPTS.keys())[0],
                    label="Select Prompt",
                    info="Choose the segmentation prompt"
                )
                submit_btn = gr.Button("Segment", variant="primary")

        with gr.Row():
            with gr.Column():
                pretrained_output = gr.Image(
                    label="Pretrained (Zero-shot) Result",
                    type="pil",
                    height=400
                )
            with gr.Column():
                trained_output = gr.Image(
                    label="Fine-tuned Result" + (" (Not Available)" if not model_trained_available else ""),
                    type="pil",
                    height=400
                )

        if not model_trained_available:
            gr.Markdown(
                "⚠️ **Note:** Fine-tuned model could not be loaded from `smcs/clipseg_drywall`. "
                "Only pretrained results will be shown."
            )

        gr.Examples(
            examples=example_images,
            inputs=image_input,
            label="Example Images"
        )

        # Run segmentation on button click, and also re-run automatically
        # whenever the image or the prompt selection changes.
        submit_btn.click(
            fn=process_image,
            inputs=[image_input, prompt_dropdown],
            outputs=[pretrained_output, trained_output]
        )
        image_input.change(
            fn=process_image,
            inputs=[image_input, prompt_dropdown],
            outputs=[pretrained_output, trained_output]
        )
        prompt_dropdown.change(
            fn=process_image,
            inputs=[image_input, prompt_dropdown],
            outputs=[pretrained_output, trained_output]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=False)