Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image | |
| import torch | |
| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation | |
| import gradio as gr | |
# Pick the compute device once at import time: prefer CUDA when present.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# ---------------------------------------------------------------------------
# Monkey-patch transformers' chat-template discovery.
#
# Some transformers versions probe the Hub for "additional_chat_templates"
# and surface the resulting 404 as a hard error.  We wrap
# ``list_repo_templates`` so that this specific failure is treated as "no
# extra templates".  The patch is applied both where the function is defined
# (transformers.utils.hub) and where it is looked up at call time
# (transformers.tokenization_utils_base).
# ---------------------------------------------------------------------------
print("Patching transformers to avoid additional_chat_templates 404 error...")
import transformers.tokenization_utils_base
import transformers.utils.hub

try:
    from huggingface_hub.errors import RemoteEntryNotFoundError
except ImportError:
    # Fallback for older versions of huggingface_hub
    from huggingface_hub.utils import EntryNotFoundError as RemoteEntryNotFoundError

# Capture the original implementation exactly once so that re-running this
# module (e.g. a Gradio auto-reload) never wraps the wrapper recursively.
if not hasattr(transformers.utils.hub.list_repo_templates, "_patched"):
    _original_list_repo_templates = transformers.utils.hub.list_repo_templates
else:
    # Already patched: recover the stored original instead of re-wrapping.
    _original_list_repo_templates = transformers.utils.hub.list_repo_templates._original

def patched_list_repo_templates(repo_id, *args, **kwargs):
    """Wrap ``list_repo_templates``, suppressing additional_chat_templates 404s.

    Returns the original result materialized as a list, or an empty list when
    the error message indicates the repo has no ``additional_chat_templates``
    entry (or is a plain 404).  Any other exception is re-raised unchanged.
    """
    try:
        # Materialize the (possibly lazy) iterable from the real function.
        return list(_original_list_repo_templates(repo_id, *args, **kwargs))
    except Exception as e:
        # NOTE: RemoteEntryNotFoundError is an Exception subclass, so a single
        # broad clause with a message check covers both the specific and the
        # generic failure modes; anything else propagates via the bare raise.
        error_str = str(e).lower()
        if "additional_chat_templates" in error_str or "404" in error_str:
            print(f"Suppressing additional_chat_templates 404 error for {repo_id}")
            return []
        raise

# Mark the wrapper and remember the original so the guard above works.
patched_list_repo_templates._patched = True
patched_list_repo_templates._original = _original_list_repo_templates

# Apply the patch to BOTH lookup sites.
transformers.utils.hub.list_repo_templates = patched_list_repo_templates
transformers.tokenization_utils_base.list_repo_templates = patched_list_repo_templates
print("Patch applied to transformers.tokenization_utils_base.list_repo_templates")
# --- Processor -------------------------------------------------------------
# Assemble the CLIPSeg processor from its two sub-components; fall back to
# the combined loader (which relies on the patch above) if that fails.
print("Loading processor from original model...")
try:
    from transformers import CLIPTokenizer, CLIPImageProcessor
    _clipseg_repo = "CIDAS/clipseg-rd64-refined"
    tokenizer = CLIPTokenizer.from_pretrained(_clipseg_repo)
    image_processor = CLIPImageProcessor.from_pretrained(_clipseg_repo)
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
    print("Processor loaded successfully from original model components")
except Exception as e:
    print(f"Error loading processor components: {e}")
    # Last resort: load the combined processor directly.
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    print("Processor loaded directly with patched template check")
# --- Models ----------------------------------------------------------------
# The zero-shot baseline must load; the fine-tuned checkpoint is optional.
print("Loading pretrained model...")
model_pretrained = CLIPSegForImageSegmentation.from_pretrained(
    "CIDAS/clipseg-rd64-refined"
).to(device)
model_pretrained.eval()

print("Loading fine-tuned model...")
model_trained = None
model_trained_available = False
try:
    model_trained = CLIPSegForImageSegmentation.from_pretrained("smcs/clipseg_drywall").to(device)
    model_trained.eval()
    model_trained_available = True
    print("Fine-tuned model loaded successfully from smcs/clipseg_drywall")
except Exception as e:
    print(f"Warning: Could not load fine-tuned model from smcs/clipseg_drywall: {e}")
    # Reset in case the failure happened after the assignment (e.g. .eval()).
    model_trained = None
    model_trained_available = False
# Prompt options shown in the UI: dropdown label -> text fed to the model
# (currently the identity mapping).
PROMPTS = {p: p for p in ("segment crack", "segment taping area")}

# Bundled example images, one path per Gradio example row.
example_images = [
    [f"examples/{stem}.jpg"]
    for stem in ("crack_1", "crack_2", "drywall_1", "drywall_2")
]
def overlay_mask(image, mask, alpha=0.5, color=(255, 0, 0)):
    """Alpha-blend a binary mask onto an image as a solid colour.

    Args:
        image: Source PIL image (any mode; converted to RGB internally).
        mask: PIL image treated as a binary mask (pixels > 127 are "on"),
            or None to return ``image`` unchanged.
        alpha: Blend strength applied inside the mask, in [0, 1].
        color: RGB tuple painted over masked pixels.

    Returns:
        A new RGB ``PIL.Image`` with masked regions tinted, or the input
        ``image`` untouched when ``mask`` is None.
    """
    if mask is None:
        return image

    # Bring the mask to the image's resolution; NEAREST keeps it binary.
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.NEAREST)

    # (H, W, 1) per-pixel blend weight: alpha inside the mask, 0 outside.
    mask_binary = (np.array(mask.convert('L')) > 127).astype(np.float32)
    weight = (alpha * mask_binary)[:, :, None]

    img_array = np.array(image.convert('RGB')).astype(np.float32)
    color_arr = np.asarray(color, dtype=np.float32)

    # Vectorised blend, out = img*(1-w) + color*w, broadcasting the colour
    # over the spatial dims — same maths as the per-channel loop it replaces.
    overlay = img_array * (1.0 - weight) + color_arr * weight
    return Image.fromarray(overlay.astype(np.uint8))
def _predict_overlay(model, pixel_values, text_inputs, img_orig, color):
    """Run one CLIPSeg model and overlay its binary mask on ``img_orig``.

    Args:
        model: A ``CLIPSegForImageSegmentation`` in eval mode.
        pixel_values: Preprocessed image tensor already on ``device``.
        text_inputs: Tokenized prompt (input_ids / attention_mask) on ``device``.
        img_orig: Original-resolution PIL image the overlay is drawn on.
        color: RGB tuple used to tint the predicted region.

    Returns:
        PIL.Image with the thresholded (> 0.5) prediction blended in.
    """
    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
        )
    # Sigmoid directly on the tensor (no numpy round-trip), threshold at 0.5.
    probs = torch.sigmoid(outputs.logits[0]).cpu().numpy()
    pred_mask = (probs > 0.5).astype(np.uint8)
    mask_img = Image.fromarray(pred_mask * 255, mode='L')
    # Scale the model-resolution mask back to the original image size.
    if img_orig.size != (352, 352):
        mask_img = mask_img.resize(img_orig.size, Image.NEAREST)
    return overlay_mask(img_orig.copy(), mask_img, alpha=0.5, color=color)


def process_image(image, prompt_option):
    """Segment ``image`` with both models for the selected prompt.

    Args:
        image: PIL Image, numpy array, or a path-like accepted by Image.open.
        prompt_option: Key into PROMPTS; unknown values are used verbatim.

    Returns:
        Tuple ``(pretrained_overlay, trained_overlay)`` of PIL images; the
        second is a plain grey placeholder when the fine-tuned model is
        unavailable.  Returns ``(None, None)`` on missing input or any error.
    """
    if image is None:
        return None, None
    try:
        # Normalise whatever Gradio hands us into an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            image = Image.open(image).convert('RGB')
        else:
            image = image.convert('RGB')

        prompt = PROMPTS.get(prompt_option, prompt_option)

        # CLIPSeg operates at a fixed 352x352 resolution.
        img_orig = image.copy()
        img = img_orig.resize((352, 352), Image.BILINEAR)

        # Shared inputs for both models.
        pixel_values = processor(images=[img], return_tensors="pt")['pixel_values'].to(device)
        text_inputs = processor.tokenizer(
            prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).to(device)

        # Zero-shot baseline, tinted blue.
        pred_mask_pretrained_overlay = _predict_overlay(
            model_pretrained, pixel_values, text_inputs, img_orig, color=(0, 100, 255)
        )

        # Fine-tuned model, tinted green; grey placeholder when unavailable.
        if model_trained_available and model_trained is not None:
            pred_mask_trained_overlay = _predict_overlay(
                model_trained, pixel_values, text_inputs, img_orig, color=(0, 255, 0)
            )
        else:
            pred_mask_trained_overlay = Image.new('RGB', img_orig.size, color=(240, 240, 240))

        return pred_mask_pretrained_overlay, pred_mask_trained_overlay
    except Exception as e:
        # Gradio callbacks have no error channel here; log and clear outputs.
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        return None, None
def create_interface():
    """Create the Gradio interface.

    Builds a Blocks layout: an input row (image upload + prompt dropdown +
    submit button) above an output row with the two model results side by
    side.  ``process_image`` is wired to the button click and to changes of
    either input, so results refresh automatically.

    Returns:
        The assembled ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(title="CLIPSeg Image Segmentation") as demo:
        gr.Markdown(
            """
            # CLIPSeg Image Segmentation Demo
            This demo compares zero-shot pretrained CLIPSeg results with fine-tuned model results.
            Select an example image or upload your own, then choose a prompt to see the segmentation results.
            """
        )
        with gr.Row():
            with gr.Column():
                # Input image, delivered to the callback as a PIL image.
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=400
                )
                prompt_dropdown = gr.Dropdown(
                    choices=list(PROMPTS.keys()),
                    value=list(PROMPTS.keys())[0],
                    label="Select Prompt",
                    info="Choose the segmentation prompt"
                )
                submit_btn = gr.Button("Segment", variant="primary")
        with gr.Row():
            with gr.Column():
                pretrained_output = gr.Image(
                    label="Pretrained (Zero-shot) Result",
                    type="pil",
                    height=400
                )
            with gr.Column():
                # Label is annotated when the fine-tuned checkpoint failed to load.
                trained_output = gr.Image(
                    label="Fine-tuned Result" + (" (Not Available)" if not model_trained_available else ""),
                    type="pil",
                    height=400
                )
        if not model_trained_available:
            gr.Markdown(
                "⚠️ **Note:** Fine-tuned model could not be loaded from `smcs/clipseg_drywall`. "
                "Only pretrained results will be shown."
            )
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            label="Example Images"
        )
        # Connect the function
        submit_btn.click(
            fn=process_image,
            inputs=[image_input, prompt_dropdown],
            outputs=[pretrained_output, trained_output]
        )
        # Also process when example is selected
        image_input.change(
            fn=process_image,
            inputs=[image_input, prompt_dropdown],
            outputs=[pretrained_output, trained_output]
        )
        # Process when prompt changes
        prompt_dropdown.change(
            fn=process_image,
            inputs=[image_input, prompt_dropdown],
            outputs=[pretrained_output, trained_output]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it without a public link.
    demo = create_interface()
    demo.launch(share=False)