# Last commit b2c0344 by smcs: Fix ImportError for RemoteEntryNotFoundError
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
# Pick the compute device once: the GPU when torch can see CUDA, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Work around the 404 that transformers can hit when probing a repo for an optional
# "additional_chat_templates" folder. The replacement has to be installed in the module
# that *calls* list_repo_templates, not only in the module that defines it.
print("Patching transformers to avoid additional_chat_templates 404 error...")
import transformers.tokenization_utils_base
import transformers.utils.hub

# huggingface_hub exposes the "entry not found" error under different names depending on
# its version; bind whichever exists to a single local name.
try:
    from huggingface_hub.errors import RemoteEntryNotFoundError
except ImportError:
    # Older huggingface_hub releases only provide EntryNotFoundError.
    from huggingface_hub.utils import EntryNotFoundError as RemoteEntryNotFoundError

# Keep a handle on the genuine implementation. If the function currently installed already
# carries the `_patched` marker, it is our wrapper, so unwrap the stored original rather than
# wrapping the wrapper (which would recurse).
_original_list_repo_templates = (
    transformers.utils.hub.list_repo_templates._original
    if hasattr(transformers.utils.hub.list_repo_templates, "_patched")
    else transformers.utils.hub.list_repo_templates
)
def patched_list_repo_templates(repo_id, *args, **kwargs):
    """Wrap transformers' ``list_repo_templates`` so a missing templates folder is not fatal.

    The original function is called with the same arguments, and its results are
    collected into a list inside the ``try``. An error raised while iterating it is
    therefore handled here instead of escaping into the caller.

    Args:
        repo_id: Hub repository id, forwarded to the original function.
        *args: Forwarded unchanged.
        **kwargs: Forwarded unchanged.

    Returns:
        list: Everything the original function yields. Returns ``[]`` instead when the
        lookup failed with an entry-not-found error, an HTTP 404 response, or an error
        whose message mentions ``additional_chat_templates``.

    Raises:
        Exception: Any other error from the original function, re-raised unchanged.
    """
    try:
        return list(_original_list_repo_templates(repo_id, *args, **kwargs))
    except Exception as e:
        # Classify from structured data where possible. A bare "404" substring test would
        # also match request ids, URLs or repo ids that merely contain those digits. That
        # would hide unrelated failures such as auth errors, rate limits or 5xx responses.
        status_code = getattr(getattr(e, "response", None), "status_code", None)
        missing_templates = (
            isinstance(e, RemoteEntryNotFoundError)
            or status_code == 404
            or "additional_chat_templates" in str(e).lower()
        )
        if not missing_templates:
            raise
        print(f"Suppressing additional_chat_templates 404 error for {repo_id}")
        return []
# Tag the wrapper: the marker tells a later pass over the setup code that the installed
# function is our wrapper, and `_original` lets it recover the real implementation.
patched_list_repo_templates.__dict__.update(
    _patched=True,
    _original=_original_list_repo_templates,
)
# Install the wrapper where the function is defined AND where it is used, since
# tokenization_utils_base holds its own module-level reference to it.
for _target_module in (transformers.utils.hub, transformers.tokenization_utils_base):
    _target_module.list_repo_templates = patched_list_repo_templates
del _target_module
print("Patch applied to transformers.tokenization_utils_base.list_repo_templates")
# Build the CLIPSeg processor (tokenizer + image preprocessing) from the base checkpoint.
print("Loading processor from original model...")
try:
    # CLIPSegProcessor pairs a CLIP tokenizer with a *ViT* image processor; ViTImageProcessor
    # is its declared image_processor_class. A CLIPImageProcessor does not match that class,
    # so ProcessorMixin's type check can reject it and divert this path into the fallback.
    from transformers import CLIPTokenizer, ViTImageProcessor

    # Load components separately
    tokenizer = CLIPTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
    image_processor = ViTImageProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
    print("Processor loaded successfully from original model components")
except Exception as e:
    print(f"Error loading processor components: {e}")
    # Fallback: let transformers choose the component classes itself. This relies on the
    # list_repo_templates patch to get past the chat-template lookup.
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    print("Processor loaded directly with patched template check")
# Load both segmentation models. The zero-shot base checkpoint is required. The drywall
# fine-tune is optional: on failure the rest of the app checks model_trained_available.
print("Loading pretrained model...")
model_pretrained = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
model_pretrained.eval()

print("Loading fine-tuned model...")
try:
    # Module.eval() returns the module itself, so the chain yields the ready model.
    model_trained = (
        CLIPSegForImageSegmentation.from_pretrained("smcs/clipseg_drywall")
        .to(device)
        .eval()
    )
except Exception as e:
    print(f"Warning: Could not load fine-tuned model from smcs/clipseg_drywall: {e}")
    model_trained, model_trained_available = None, False
else:
    model_trained_available = True
    print("Fine-tuned model loaded successfully from smcs/clipseg_drywall")
# Dropdown choices, mapped to the text prompt sent to CLIPSeg (currently identical).
PROMPTS = {prompt: prompt for prompt in ("segment crack", "segment taping area")}

# Bundled example inputs; gr.Examples expects one list per example row.
example_images = [
    [f"examples/{stem}.jpg"]
    for stem in ("crack_1", "crack_2", "drywall_1", "drywall_2")
]
def overlay_mask(image, mask, alpha=0.5, color=(255, 0, 0), threshold=127):
    """Blend a solid colour into ``image`` wherever ``mask`` is set.

    Args:
        image: PIL image to draw on; it is converted to RGB for the blend.
        mask: PIL image (any mode convertible to "L") or None. Pixels whose grey value
            is greater than ``threshold`` count as foreground. The mask is resized to
            ``image.size`` with nearest-neighbour sampling when the sizes differ.
        alpha: Opacity of the colour over foreground pixels (0 leaves the image
            untouched, 1 paints solid colour).
        color: (R, G, B) overlay colour, 0-255 per channel.
        threshold: Grey level (0-255) a mask pixel must exceed to count as foreground.

    Returns:
        A new RGB PIL image, or ``image`` itself, unchanged, when ``mask`` is None.
    """
    if mask is None:
        return image

    # Ensure same size
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.NEAREST)

    # Per-pixel blend weight: alpha on foreground, 0 elsewhere. The trailing axis lets it
    # broadcast across the three colour channels, replacing a per-channel loop and a
    # full-size colour image.
    foreground = (np.array(mask.convert('L')) > threshold).astype(np.float32)
    weight = foreground[..., None] * alpha

    rgb = np.asarray(color, dtype=np.float32)
    pixels = np.asarray(image.convert('RGB'), dtype=np.float32)
    blended = pixels * (1 - weight) + rgb * weight

    # Clip before the uint8 cast so out-of-range alpha/colour values cannot wrap around.
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
def _to_rgb_image(image):
    """Coerce a numpy array, a path/file object, or a PIL image into a new RGB PIL image."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        image = Image.open(image)
    # Convert every input kind, including arrays, so RGBA/greyscale never reaches the processor.
    return image.convert('RGB')


def _predict_mask(model, pixel_values, text_inputs, size):
    """Run one CLIPSeg model and return its binary mask as a PIL "L" image of ``size``.

    A pixel is 255 where sigmoid(logit) > 0.5 and 0 otherwise. The mask is resized with
    nearest-neighbour sampling when the model's output resolution differs from ``size``.
    """
    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
        )
    logits = outputs.logits
    # Take the first (only) sample when the output keeps a batch axis. A 2-D tensor is
    # already a single map and is used as-is.
    if logits.dim() == 3:
        logits = logits[0]
    probabilities = torch.sigmoid(logits.cpu()).numpy()
    # A 2-D uint8 array becomes an "L" image without the deprecated `mode` argument.
    mask_image = Image.fromarray((probabilities > 0.5).astype(np.uint8) * 255)
    if mask_image.size != size:
        mask_image = mask_image.resize(size, Image.NEAREST)
    return mask_image


def process_image(image, prompt_option):
    """
    Segment an image with both the pretrained and the fine-tuned CLIPSeg models.

    Args:
        image: PIL image, numpy array, or anything ``Image.open`` accepts. None yields
            ``(None, None)``.
        prompt_option: Key of ``PROMPTS`` ("segment crack" or "segment taping area").
            Unknown values are used verbatim as the text prompt.

    Returns:
        Tuple ``(pretrained_overlay, trained_overlay)`` of RGB PIL images the size of the
        input. The pretrained mask is drawn in blue and the fine-tuned one in green. When
        the fine-tuned model is unavailable, the second element is a plain grey
        placeholder. Returns ``(None, None)`` when no image is given or processing fails;
        the error is printed.
    """
    if image is None:
        return None, None
    try:
        img_orig = _to_rgb_image(image)
        prompt = PROMPTS.get(prompt_option, prompt_option)

        # Both models run on the same 352x352 input and the same tokenised prompt.
        model_input = img_orig.resize((352, 352), Image.BILINEAR)
        pixel_values = processor(images=[model_input], return_tensors="pt")['pixel_values'].to(device)
        text_inputs = processor.tokenizer(
            prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).to(device)

        pretrained_mask = _predict_mask(model_pretrained, pixel_values, text_inputs, img_orig.size)
        pretrained_overlay = overlay_mask(
            img_orig.copy(), pretrained_mask, alpha=0.5, color=(0, 100, 255)
        )

        if model_trained_available and model_trained is not None:
            trained_mask = _predict_mask(model_trained, pixel_values, text_inputs, img_orig.size)
            trained_overlay = overlay_mask(
                img_orig.copy(), trained_mask, alpha=0.5, color=(0, 255, 0)
            )
        else:
            # Neutral placeholder keeps the second output panel the same size as the input.
            trained_overlay = Image.new('RGB', img_orig.size, color=(240, 240, 240))

        return pretrained_overlay, trained_overlay
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        return None, None
def create_interface():
    """Assemble the Gradio app that shows zero-shot and fine-tuned CLIPSeg results side by side.

    Returns:
        The ``gr.Blocks`` demo, built but not launched.
    """
    intro_text = (
        "\n"
        "# CLIPSeg Image Segmentation Demo\n"
        "This demo compares zero-shot pretrained CLIPSeg results with fine-tuned model results.\n"
        "Select an example image or upload your own, then choose a prompt to see the segmentation results.\n"
    )
    prompt_choices = list(PROMPTS)
    if model_trained_available:
        trained_label = "Fine-tuned Result"
    else:
        trained_label = "Fine-tuned Result (Not Available)"

    with gr.Blocks(title="CLIPSeg Image Segmentation") as demo:
        gr.Markdown(intro_text)

        # Inputs: image, prompt choice, and an explicit trigger button.
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", type="pil", height=400)
                prompt_dropdown = gr.Dropdown(
                    choices=prompt_choices,
                    value=prompt_choices[0],
                    label="Select Prompt",
                    info="Choose the segmentation prompt",
                )
                submit_btn = gr.Button("Segment", variant="primary")

        # Outputs: the two overlays side by side.
        with gr.Row():
            with gr.Column():
                pretrained_output = gr.Image(
                    label="Pretrained (Zero-shot) Result", type="pil", height=400
                )
            with gr.Column():
                trained_output = gr.Image(label=trained_label, type="pil", height=400)
                if not model_trained_available:
                    gr.Markdown(
                        "⚠️ **Note:** Fine-tuned model could not be loaded from `smcs/clipseg_drywall`. "
                        "Only pretrained results will be shown."
                    )

        gr.Examples(examples=example_images, inputs=image_input, label="Example Images")

        # Re-run segmentation on button click, on image change (including picking an
        # example), and on prompt change.
        for register_listener in (submit_btn.click, image_input.change, prompt_dropdown.change):
            register_listener(
                fn=process_image,
                inputs=[image_input, prompt_dropdown],
                outputs=[pretrained_output, trained_output],
            )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it; share=False means no public share link is created.
    demo = create_interface()
    demo.launch(share=False)