File size: 8,460 Bytes
2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 2eacc62 12b9e29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from unet import ImprovedUNet
# Locate the trained UNet weights: prefer the copy published on the Hugging
# Face Hub and, if the download fails for any reason (no network, missing
# repo/file, auth), fall back to a local file path instead.
_HF_REPO_ID = "faranbutt789/my-model"
_HF_WEIGHTS_FILENAME = "unet_weights.pth"
# NOTE(review): the local fallback name differs from the Hub filename
# ("unet_weights.pth"); confirm which file is actually shipped alongside the app.
_LOCAL_FALLBACK_WEIGHTS = "unet_weights_v2.pth"

try:
    weights_path = hf_hub_download(repo_id=_HF_REPO_ID, filename=_HF_WEIGHTS_FILENAME)
except Exception as download_error:
    print(f"Error downloading weights: {download_error}")
    # Existence of this file is not checked here; a missing file surfaces
    # later as a load error when the weights are read.
    weights_path = _LOCAL_FALLBACK_WEIGHTS
# Build the network on the best available device and load the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet()
try:
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint fetched from a remote repo cannot run arbitrary code.
    # (This argument needs torch >= 1.13; it is the default from torch 2.6.)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model weights loaded successfully!")
except Exception as e:
    # Best-effort: keep the app usable even without trained weights.
    print(f"Error loading model weights: {e}")
    # NOTE(review): load_state_dict may copy matching tensors before raising on
    # mismatched keys/shapes, so the model can be partially loaded here.
    print("Using randomly initialized model (for testing)")
model.to(device)
# Inference mode for layers such as BatchNorm/Dropout.
model.eval()
# Input preprocessing; it must mirror the pipeline used when training the model.
# Every image is resized to the fixed network input size and converted to a
# float tensor by ToTensor.
IMG_HEIGHT, IMG_WIDTH = 128, 128
_preprocess_steps = [
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
    # If normalization was part of training, append it here, e.g.:
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
transform = transforms.Compose(_preprocess_steps)
# Blend weight of the red crack layer in the overlay visualization.
_OVERLAY_ALPHA = 0.4


def _infer_mask(image):
    """Run the segmentation model on one uploaded image.

    Args:
        image: PIL image from the Gradio upload component (any mode).

    Returns:
        (mask, rgb_image): ``mask`` is the raw model output as a 2-D float
        NumPy array (presumably at the IMG_HEIGHT x IMG_WIDTH input
        resolution -- confirm against ImprovedUNet), and ``rgb_image`` is the
        input converted to RGB at its original size.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)  # (1, 3, IMG_HEIGHT, IMG_WIDTH)
    with torch.no_grad():
        pred = model(batch)
    # Drop the batch and channel dims; assumes a single-channel output -- TODO confirm.
    mask = pred.squeeze(0).squeeze(0).cpu().numpy()
    return mask, rgb_image


def _render_outputs(mask_u8, rgb_image):
    """Upscale a uint8 mask to the image size and build the two output images.

    Args:
        mask_u8: 2-D uint8 mask at model resolution.
        rgb_image: RGB PIL image whose size the outputs are resized to.

    Returns:
        (mask_img, overlay_img): the mask as a grayscale PIL image, and the
        RGB image blended with the mask painted into the red channel.
    """
    width, height = rgb_image.size
    # Nearest-neighbour resizing introduces no new intermediate mask values.
    mask_resized = cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)
    mask_img = Image.fromarray(mask_resized)  # 2-D uint8 array -> mode "L"

    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    colored_mask[:, :, 0] = mask_resized  # red channel marks cracks
    # Blended with black elsewhere, so non-crack pixels are dimmed by alpha too.
    overlay = cv2.addWeighted(
        np.array(rgb_image), 1 - _OVERLAY_ALPHA, colored_mask, _OVERLAY_ALPHA, 0
    )
    return mask_img, Image.fromarray(overlay)


def _blank_outputs(image):
    """Return a pair of black placeholder images for when inference fails.

    Uses the input's size when available, else the model input size, so the
    error path can never fail itself.
    """
    size = getattr(image, "size", None) or (IMG_WIDTH, IMG_HEIGHT)
    blank = Image.new("L", size, 0)
    return blank, blank


def predict(image):
    """Segment cracks and return the model's per-pixel output as images.

    Args:
        image: PIL image from the upload component, or None if nothing was uploaded.

    Returns:
        (mask_img, overlay_img): a grayscale map of the model output scaled to
        0-255 and the input with that map blended in red. ``(None, None)``
        when no image was given; two black images if inference fails.
    """
    if image is None:
        # Two output components are wired to this handler, so return two values.
        return None, None
    try:
        mask, rgb_image = _infer_mask(image)
        # Clip first: casting out-of-range floats to uint8 is undefined in NumPy
        # (values may wrap), e.g. if the model ever emits raw logits.
        mask_u8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        return _render_outputs(mask_u8, rgb_image)
    except Exception as e:
        # Best-effort UI: log with the stack trace and show blank images.
        print(f"Error in prediction: {e}")
        traceback.print_exc()
        return _blank_outputs(image)


def predict_with_threshold(image, threshold):
    """Segment cracks and return a binary mask at the given threshold.

    Args:
        image: PIL image from the upload component, or None if nothing was uploaded.
        threshold: pixels whose model output is strictly greater than this
            value are marked as crack.

    Returns:
        (mask_img, overlay_img) with mask values 0 or 255. ``(None, None)``
        when no image was given; two black images if inference fails.
    """
    if image is None:
        return None, None
    try:
        mask, rgb_image = _infer_mask(image)
        mask_u8 = (mask > threshold).astype(np.uint8) * 255
        return _render_outputs(mask_u8, rgb_image)
    except Exception as e:
        print(f"Error in prediction with threshold: {e}")
        traceback.print_exc()
        return _blank_outputs(image)
def _result_column():
    """Add the crack-mask and overlay output images as a column in the current layout.

    Must be called inside an active ``gr.Blocks``/``gr.Row`` context.

    Returns:
        (mask_image, overlay_image) Gradio components, in the order the
        prediction handlers return their outputs.
    """
    with gr.Column():
        mask_image = gr.Image(label="Crack Mask", height=400)
        overlay_image = gr.Image(label="Overlay Visualization", height=400)
    return mask_image, overlay_image


# Gradio interface: a basic tab (raw model output) and an advanced tab
# (user-adjustable binarization threshold).
with gr.Blocks(title="UNet Crack Segmentation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🔍 Concrete Crack Segmentation with UNet

Upload an image of a concrete surface to detect and segment cracks using a trained UNet model.

**Features:**

- Advanced UNet architecture with batch normalization and dropout
- Optimized for highly imbalanced crack detection
- Interactive threshold adjustment
- Colored overlay visualization
"""
    )
    with gr.Tabs():
        with gr.TabItem("Basic Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image1 = gr.Image(
                        type="pil",
                        label="Upload Concrete Image",
                        height=400,
                    )
                    predict_btn1 = gr.Button("🔍 Detect Cracks", variant="primary", size="lg")
                output_mask1, output_overlay1 = _result_column()
            predict_btn1.click(
                predict,
                inputs=[input_image1],
                outputs=[output_mask1, output_overlay1],
            )
        with gr.TabItem("Advanced Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image2 = gr.Image(
                        type="pil",
                        label="Upload Concrete Image",
                        height=400,
                    )
                    threshold_slider = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.5,
                        step=0.1,
                        label="Detection Threshold",
                    )
                    predict_btn2 = gr.Button("🔍 Detect Cracks", variant="primary", size="lg")
                output_mask2, output_overlay2 = _result_column()
            predict_btn2.click(
                predict_with_threshold,
                inputs=[input_image2, threshold_slider],
                outputs=[output_mask2, output_overlay2],
            )
    gr.Markdown(
        """
### How to use:

1. **Upload** a concrete surface image
2. **Click** "Detect Cracks" to run the segmentation
3. **View** the results: white areas in the mask indicate detected cracks
4. **Adjust** the threshold in Advanced mode for fine-tuning sensitivity

### Model Information:

- **Architecture**: Improved UNet with BatchNorm and Dropout
- **Input Size**: Images are resized to 128x128 for processing
- **Output**: Crack mask at the original image size: a grayscale map of the model's per-pixel output in Basic mode, and a thresholded binary mask in Advanced mode
- **Training**: Optimized for imbalanced crack detection using advanced loss functions

### Tips for better results:

- Use high-contrast images where cracks are visible
- Ensure good lighting conditions
- Try adjusting the threshold if results seem too sensitive or not sensitive enough
"""
    )
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860. share=True additionally
    # requests a temporary public Gradio share link, so anyone with that URL
    # can reach the app; drop it if public exposure is not intended.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )