# Hugging Face Spaces demo — structural defect segmentation (DeepLabV3+ / Gradio)
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import segmentation_models_pytorch as smp | |
| from albumentations import Normalize | |
| from albumentations.pytorch import ToTensorV2 | |
# ================================
# CONFIG
# ================================
MODEL_PATH = "s2ds_deeplabv3plus.pth"  # checkpoint loaded at startup below
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 7  # must match the segmentation head in the checkpoint
INFER_SIZE = 512  # square inference resolution; reduce for speed (important for live feed)
# Class id -> human-readable defect name (id 0 is background).
CLASS_NAMES = {
    0: "Background",
    1: "Crack",
    2: "Spalling",
    3: "Corrosion",
    4: "Efflorescence",
    5: "Vegetation",
    6: "Control Point"
}
# Class id -> overlay colour triple. Presumably RGB, since Gradio supplies
# RGB numpy arrays — TODO confirm (with BGR frames red/blue would swap).
ID_TO_COLOR = {
    0: (0, 0, 0),        # Background: black
    1: (255, 255, 255),  # Crack: white
    2: (255, 0, 0),      # Spalling: red
    3: (255, 255, 0),    # Corrosion: yellow
    4: (0, 255, 255),    # Efflorescence: cyan
    5: (0, 255, 0),      # Vegetation: green
    6: (0, 0, 255)       # Control Point: blue
}
# ================================
# LOAD MODEL
# ================================
# DeepLabV3+ with a ResNet-50 encoder. encoder_weights=None: all weights
# come from the checkpoint, so no pretrained download happens here.
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights=None,
    in_channels=3,
    classes=NUM_CLASSES
)
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
# Accept either a raw state_dict or a training checkpoint that wraps it
# under the "model_state_dict" key.
model.load_state_dict(
    checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
)
model.to(DEVICE)
model.eval()  # inference mode: freezes batch-norm/dropout behaviour
# Albumentations preprocessing: Normalize() with its default (ImageNet-style)
# mean/std, then HWC uint8 numpy -> CHW float torch tensor.
normalize = Normalize()
to_tensor = ToTensorV2()
| # ================================ | |
| # HELPERS | |
| # ================================ | |
def pad_to_16(img):
    """Pad an image on the bottom/right so both spatial dims are multiples of 16.

    Uses edge-inclusive mirror padding (numpy "symmetric", equivalent to
    OpenCV's BORDER_REFLECT), so the padded border blends with the content.
    Works for both grayscale (H, W) and colour (H, W, C) arrays of any dtype.

    Parameters:
        img: numpy array of shape (H, W) or (H, W, C).

    Returns:
        (padded, h, w): the padded array plus the ORIGINAL height and width,
        so callers can crop predictions back to the input size.
    """
    h, w = img.shape[:2]
    # (-h) % 16 is the smallest non-negative amount that rounds h up to a
    # multiple of 16 (0 when already aligned).
    pad_h = (-h) % 16
    pad_w = (-w) % 16
    # Pad only bottom/right; leave any trailing (channel) axes untouched.
    pad_spec = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (img.ndim - 2)
    padded = np.pad(img, pad_spec, mode="symmetric")
    return padded, h, w
def colorize_mask(mask, id_to_color=None):
    """Convert a class-id mask into a colour image for overlay display.

    Parameters:
        mask: (H, W) integer array of class ids.
        id_to_color: optional {class_id: (r, g, b)} palette; defaults to the
            module-level ID_TO_COLOR, keeping existing callers unchanged.

    Returns:
        (H, W, 3) uint8 array; pixels whose id is missing from the palette
        stay black.
    """
    if id_to_color is None:
        id_to_color = ID_TO_COLOR
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in id_to_color.items():
        color_mask[mask == cls] = color
    return color_mask
| # ================================ | |
| # FAST INFERENCE FUNCTION | |
| # ================================ | |
def segment_image(image):
    """Segment structural defects in an RGB frame.

    Parameters:
        image: (H, W, 3) uint8 numpy array (Gradio supplies RGB), or None.

    Returns:
        (overlay, text): the input-resolution image blended with the coloured
        prediction mask, and a "Detected: <class>" summary string. Returns
        (None, "") when no image is provided (e.g. empty webcam frame).
    """
    if image is None:
        return None, ""
    # Downscale to a fixed square for speed. INFER_SIZE is a multiple of 16,
    # so pad_to_16 below is effectively a no-op safety net.
    original = image.copy()
    image = cv2.resize(image, (INFER_SIZE, INFER_SIZE))
    padded, orig_h, orig_w = pad_to_16(image)
    img = normalize(image=padded)["image"]
    img = to_tensor(image=img)["image"]
    img = img.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        if DEVICE == "cuda":
            # Mixed precision on GPU for extra throughput.
            with torch.cuda.amp.autocast():
                pred = model(img)
        else:
            pred = model(img)
    pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
    pred_mask = pred_mask[:orig_h, :orig_w]  # crop any pad back off
    color_mask = colorize_mask(pred_mask)
    overlay_small = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
    # Upscale the overlay back to the caller's resolution.
    overlay = cv2.resize(overlay_small, (original.shape[1], original.shape[0]))
    # Image-level label: the most frequent non-background class.
    # BUGFIX: filter vals and counts with the SAME mask. The previous code
    # kept vals[vals > 0] but sliced counts[1:], which only line up when
    # class 0 is present in the mask; otherwise the reported class was wrong.
    vals, counts = np.unique(pred_mask, return_counts=True)
    foreground = vals > 0
    if foreground.any():
        img_class = int(vals[foreground][np.argmax(counts[foreground])])
        label = CLASS_NAMES[img_class]
    else:
        label = "Background"
    return overlay, f"Detected: {label}"
| # ================================ | |
| # GRADIO UI | |
| # ================================ | |
# ================================
# GRADIO UI
# ================================
with gr.Blocks() as demo:
    gr.Markdown("# π Structural Defect Segmentation")
    # Tab 1: one-shot segmentation of an uploaded image.
    with gr.Tab("Image Upload"):
        input_img = gr.Image(type="numpy")  # delivered as an RGB numpy array
        output_img = gr.Image()
        output_text = gr.Textbox()
        btn = gr.Button("Run Segmentation")
        btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
    # Tab 2: continuous webcam streaming through the model.
    with gr.Tab("Live Camera (Fast Mode)"):
        cam = gr.Image(
            sources=["webcam"],
            streaming=True,
            type="numpy",
            # Prefer the rear camera on mobile devices.
            # NOTE(review): webcam_options is a newer Gradio parameter —
            # confirm the installed Gradio version supports it.
            webcam_options={"facingMode": "environment"}
        )
        cam_out = gr.Image()
        # Stream each frame through the model; keep only the overlay image
        # (index 0), dropping the text label for speed.
        cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)
demo.launch()