|
|
import gradio as gr |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from unet import ImprovedUNet |
|
|
from huggingface_hub import hf_hub_download |
|
|
import cv2 |
|
|
|
|
|
|
|
|
# Resolve the checkpoint location: prefer the copy published on the
# Hugging Face Hub, and fall back to a relative local filename when the
# download fails for any reason.
try:
    weights_path = hf_hub_download(repo_id="faranbutt789/my-model", filename="unet_weights.pth")
except Exception as download_error:
    print(f"Error downloading weights: {download_error}")
    # Best-effort fallback; whether this file exists is only discovered
    # later, when the weights are actually loaded.
    weights_path = "unet_weights_v2.pth"
|
|
|
|
|
|
|
|
# Pick the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet()

# Load the trained parameters. Failure is tolerated on purpose: the app still
# starts, but with an untrained network, which is only useful for testing.
try:
    # The checkpoint may come from a remote download, so restrict unpickling
    # to tensors and plain containers instead of arbitrary Python objects.
    # A plain state dict (what load_state_dict expects) loads fine this way.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model weights loaded successfully!")
except Exception as e:
    print(f"Error loading model weights: {e}")
    print("Using randomly initialized model (for testing)")

model.to(device)
# Inference only: switch the module to evaluation mode.
model.eval()
|
|
|
|
|
|
|
|
# Spatial size every input image is resized to before it reaches the model.
IMG_HEIGHT = 128
IMG_WIDTH = 128

# Preprocessing applied to each uploaded PIL image: resize, then convert
# to a tensor.
_preprocess_steps = [
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
def _infer_mask(image):
    """Run the model on a PIL image and return ``(rgb_image, raw_mask)``.

    The image is converted to RGB when it is in another mode, passed through
    the module-level ``transform`` and ``model`` on ``device`` without
    gradient tracking, and the prediction is returned as a NumPy array with
    its two leading axes squeezed away.
    """
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(batch)
    # NOTE(review): assumes the model emits one single-channel map per image,
    # i.e. shape (1, 1, H, W) — confirm against ImprovedUNet.
    return rgb_image, pred.squeeze(0).squeeze(0).cpu().numpy()


def _render_outputs(rgb_image, mask_u8, alpha=0.4):
    """Build the grayscale mask image and the red-tinted overlay.

    Args:
        rgb_image: RGB PIL image at its original resolution.
        mask_u8: 2-D ``uint8`` mask at model resolution.
        alpha: Blend weight given to the colored mask in the overlay.

    Returns:
        ``(mask_img, overlay_img)`` as PIL images sized like ``rgb_image``.
    """
    width, height = rgb_image.size
    # Nearest-neighbour keeps the mask values unchanged while upscaling.
    mask_full = cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)
    mask_img = Image.fromarray(mask_full, mode='L')

    # Put the mask into channel 0 (red for an RGB array); other channels stay 0.
    tint = np.zeros((height, width, 3), dtype=np.uint8)
    tint[:, :, 0] = mask_full

    blended = cv2.addWeighted(np.array(rgb_image), 1 - alpha, tint, alpha, 0)
    return mask_img, Image.fromarray(blended)


def _blank_pair(image):
    """Return two references to one all-black 'L' image sized like ``image``.

    Falls back to the model input size when ``image`` exposes no ``size``,
    so the error path itself can never fail on an unbound dimension.
    """
    size = getattr(image, "size", (IMG_WIDTH, IMG_HEIGHT))
    blank = Image.new('L', size, 0)
    return blank, blank


def predict(image):
    """Segment cracks and return a soft mask plus an overlay.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.

    Returns:
        ``(mask_img, overlay_img)``. ``(None, None)`` when ``image`` is
        ``None`` (one value per wired output, matching
        ``predict_with_threshold``). Two blank images when inference fails.
    """
    if image is None:
        return None, None

    try:
        rgb_image, raw_mask = _infer_mask(image)
        # Clamp before scaling: casting out-of-range floats to uint8 wraps or
        # is undefined, which would corrupt the mask.
        # NOTE(review): presumes the model output is meant to lie in [0, 1]
        # (e.g. a sigmoid head) — confirm.
        mask_u8 = (np.clip(raw_mask, 0.0, 1.0) * 255).astype(np.uint8)
        return _render_outputs(rgb_image, mask_u8)
    except Exception as e:
        print(f"Error in prediction: {e}")
        return _blank_pair(image)


def predict_with_threshold(image, threshold):
    """Segment cracks and return a binary mask plus an overlay.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.
        threshold: Model outputs strictly greater than this value are marked
            as crack (255); everything else is 0.

    Returns:
        ``(mask_img, overlay_img)``. ``(None, None)`` when ``image`` is
        ``None``. Two blank images when inference fails.
    """
    if image is None:
        return None, None

    try:
        rgb_image, raw_mask = _infer_mask(image)
        mask_binary = (raw_mask > threshold).astype(np.uint8) * 255
        return _render_outputs(rgb_image, mask_binary)
    except Exception as e:
        print(f"Error in prediction with threshold: {e}")
        return _blank_pair(image)
|
|
|
|
|
|
|
|
# Two-tab Gradio interface: a one-click tab wired to predict() and a tab with
# a threshold slider wired to predict_with_threshold().
with gr.Blocks(title="UNet Crack Segmentation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # π Concrete Crack Segmentation with UNet

        Upload an image of a concrete surface to detect and segment cracks using a trained UNet model.

        **Features:**
        - Advanced UNet architecture with batch normalization and dropout
        - Optimized for highly imbalanced crack detection
        - Interactive threshold adjustment
        - Colored overlay visualization
        """
    )

    with gr.Tabs():
        # Tab 1: image in, mask and overlay out.
        with gr.TabItem("Basic Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image1 = gr.Image(type="pil", label="Upload Concrete Image", height=400)
                    predict_btn1 = gr.Button("π Detect Cracks", variant="primary", size="lg")

                with gr.Column():
                    output_mask1 = gr.Image(label="Crack Mask", height=400)
                    output_overlay1 = gr.Image(label="Overlay Visualization", height=400)

            predict_btn1.click(
                predict,
                inputs=[input_image1],
                outputs=[output_mask1, output_overlay1],
            )

        # Tab 2: same layout plus a slider that feeds the threshold argument.
        with gr.TabItem("Advanced Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image2 = gr.Image(type="pil", label="Upload Concrete Image", height=400)
                    threshold_slider = gr.Slider(
                        minimum=0.1, maximum=0.9, value=0.5, step=0.1,
                        label="Detection Threshold",
                    )
                    predict_btn2 = gr.Button("π Detect Cracks", variant="primary", size="lg")

                with gr.Column():
                    output_mask2 = gr.Image(label="Crack Mask", height=400)
                    output_overlay2 = gr.Image(label="Overlay Visualization", height=400)

            predict_btn2.click(
                predict_with_threshold,
                inputs=[input_image2, threshold_slider],
                outputs=[output_mask2, output_overlay2],
            )

    gr.Markdown(
        """
        ### How to use:
        1. **Upload** a concrete surface image
        2. **Click** "Detect Cracks" to run the segmentation
        3. **View** the results: white areas in the mask indicate detected cracks
        4. **Adjust** the threshold in Advanced mode for fine-tuning sensitivity

        ### Model Information:
        - **Architecture**: Improved UNet with BatchNorm and Dropout
        - **Input Size**: Images are resized to 128x128 for processing
        - **Output**: Binary segmentation mask highlighting crack regions
        - **Training**: Optimized for imbalanced crack detection using advanced loss functions

        ### Tips for better results:
        - Use high-contrast images where cracks are visible
        - Ensure good lighting conditions
        - Try adjusting the threshold if results seem too sensitive or not sensitive enough
        """
    )
|
|
|
|
|
if __name__ == "__main__":
    # Launch settings gathered in one place: listen on every interface at
    # port 7860 and pass share=True through to Gradio.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)