"""Gradio demo: concrete crack segmentation with a UNet model.

At import time the script downloads the model weights from the Hugging Face
Hub (falling back to a local file), loads them into ``ImprovedUNet`` and builds
a two-tab Gradio interface:

* "Basic Prediction" shows the raw model output scaled to 0-255.
* "Advanced Prediction" binarizes the model output at a user-chosen threshold.
"""

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from unet import ImprovedUNet

# Load trained model weights from Hugging Face Hub
try:
    weights_path = hf_hub_download(
        repo_id="faranbutt789/my-model",
        filename="unet_weights.pth",
    )
except Exception as e:
    print(f"Error downloading weights: {e}")
    # Fallback to local file if available
    weights_path = "unet_weights_v2.pth"

# Initialize and load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet()

try:
    # NOTE(review): torch.load unpickles the file, and the file normally comes
    # from a remote repo. Consider passing weights_only=True once the minimum
    # supported torch version is confirmed to accept that argument.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    print("Model weights loaded successfully!")
except Exception as e:
    # Deliberate best-effort: the demo still starts, but with untrained weights.
    print(f"Error loading model weights: {e}")
    print("Using randomly initialized model (for testing)")

model.to(device)
model.eval()

# Preprocessing: same as training
IMG_HEIGHT, IMG_WIDTH = 128, 128

# Blend weight of the red crack mask in the overlay image.
OVERLAY_ALPHA = 0.4

transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
    # Add normalization if you used it during training
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def _infer_mask(image):
    """Run the model on an RGB PIL image and return its raw output as a NumPy array.

    The image is passed through ``transform`` (resize + ``ToTensor``), a batch
    dimension is added, and the first two dimensions of the model output are
    squeezed away before the result is moved to the CPU.
    """
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img_tensor)
    # assumes the model returns a (1, 1, H, W) tensor — TODO confirm
    return pred.squeeze(0).squeeze(0).cpu().numpy()


def _render_outputs(image, mask_small):
    """Build the ``(mask_img, overlay_img)`` pair from a uint8 mask.

    Args:
        image: RGB PIL image the mask belongs to.
        mask_small: 2-D ``uint8`` array at model resolution.

    Returns:
        A grayscale PIL image of the mask resized to ``image.size`` and an RGB
        PIL image with the mask blended into the red channel of ``image``.
    """
    orig_w, orig_h = image.size
    mask_resized = cv2.resize(
        mask_small, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST
    )
    mask_img = Image.fromarray(mask_resized, mode="L")

    # Red-only mask: channel 0 carries the mask, green and blue stay zero.
    colored_mask = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    colored_mask[:, :, 0] = mask_resized

    overlay = cv2.addWeighted(
        np.array(image), 1 - OVERLAY_ALPHA, colored_mask, OVERLAY_ALPHA, 0
    )
    return mask_img, Image.fromarray(overlay)


def _segment(image, to_uint8, error_label):
    """Shared pipeline behind both prediction handlers.

    Args:
        image: PIL image supplied by Gradio (must not be ``None``).
        to_uint8: Callable turning the raw model output into a ``uint8`` mask.
        error_label: Prefix for the message printed when prediction fails.

    Returns:
        ``(mask_img, overlay_img)``; two blank grayscale images of the input
        size if anything in the pipeline raises.
    """
    # Read the size before the try block so the fallback below can always use it.
    orig_size = image.size
    try:
        if image.mode != "RGB":
            image = image.convert("RGB")
        return _render_outputs(image, to_uint8(_infer_mask(image)))
    except Exception as e:
        print(f"{error_label}: {e}")
        # Return a blank image in case of error
        blank = Image.new("L", orig_size, 0)
        return blank, blank


def _scale_to_uint8(mask):
    """Map model output to 0-255, clipping to [0, 1] first.

    Clipping keeps out-of-range values from overflowing the ``uint8``
    conversion; values already inside [0, 1] are unaffected.
    NOTE(review): presumably the model ends in a sigmoid — confirm.
    """
    return (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)


def predict(image):
    """Segment cracks and return the soft mask plus a red overlay.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.

    Returns:
        ``(mask_img, overlay_img)``, or ``(None, None)`` when ``image`` is
        ``None`` — one value per wired output component.
    """
    if image is None:
        return None, None
    return _segment(image, _scale_to_uint8, "Error in prediction")


def predict_with_threshold(image, threshold):
    """Segment cracks and return a binary mask plus a red overlay.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.
        threshold: Model outputs strictly greater than this become 255 in the
            mask; everything else becomes 0.

    Returns:
        ``(mask_img, overlay_img)``, or ``(None, None)`` when ``image`` is
        ``None``.
    """
    if image is None:
        return None, None

    def binarize(mask):
        return (mask > threshold).astype(np.uint8) * 255

    return _segment(image, binarize, "Error in prediction with threshold")


# Create Gradio interface with multiple tabs
with gr.Blocks(title="UNet Crack Segmentation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔍 Concrete Crack Segmentation with UNet

        Upload an image of a concrete surface to detect and segment cracks using a trained UNet model.

        **Features:**
        - Advanced UNet architecture with batch normalization and dropout
        - Optimized for highly imbalanced crack detection
        - Interactive threshold adjustment
        - Colored overlay visualization
        """
    )

    with gr.Tabs():
        with gr.TabItem("Basic Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image1 = gr.Image(
                        type="pil",
                        label="Upload Concrete Image",
                        height=400
                    )
                    predict_btn1 = gr.Button("🔍 Detect Cracks", variant="primary", size="lg")

                with gr.Column():
                    output_mask1 = gr.Image(
                        label="Crack Mask",
                        height=400
                    )
                    output_overlay1 = gr.Image(
                        label="Overlay Visualization",
                        height=400
                    )

            predict_btn1.click(
                predict,
                inputs=[input_image1],
                outputs=[output_mask1, output_overlay1]
            )

        with gr.TabItem("Advanced Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image2 = gr.Image(
                        type="pil",
                        label="Upload Concrete Image",
                        height=400
                    )
                    threshold_slider = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.5,
                        step=0.1,
                        label="Detection Threshold"
                    )
                    predict_btn2 = gr.Button("🔍 Detect Cracks", variant="primary", size="lg")

                with gr.Column():
                    output_mask2 = gr.Image(
                        label="Crack Mask",
                        height=400
                    )
                    output_overlay2 = gr.Image(
                        label="Overlay Visualization",
                        height=400
                    )

            predict_btn2.click(
                predict_with_threshold,
                inputs=[input_image2, threshold_slider],
                outputs=[output_mask2, output_overlay2]
            )

    gr.Markdown(
        """
        ### How to use:
        1. **Upload** a concrete surface image
        2. **Click** "Detect Cracks" to run the segmentation
        3. **View** the results: white areas in the mask indicate detected cracks
        4. **Adjust** the threshold in Advanced mode for fine-tuning sensitivity

        ### Model Information:
        - **Architecture**: Improved UNet with BatchNorm and Dropout
        - **Input Size**: Images are resized to 128x128 for processing
        - **Output**: Binary segmentation mask highlighting crack regions
        - **Training**: Optimized for imbalanced crack detection using advanced loss functions

        ### Tips for better results:
        - Use high-contrast images where cracks are visible
        - Ensure good lighting conditions
        - Try adjusting the threshold if results seem too sensitive or not sensitive enough
        """
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )