File size: 3,876 Bytes
bec0074
 
 
 
 
 
 
 
 
 
 
 
 
 
e451a47
bec0074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc909e1
 
 
bec0074
 
 
 
 
 
 
dc909e1
 
 
bec0074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e451a47
bec0074
 
dc909e1
 
bec0074
e451a47
 
 
 
bec0074
 
 
 
 
 
 
e451a47
 
 
 
 
 
bec0074
 
 
 
 
e451a47
 
 
 
bec0074
dc909e1
bec0074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc909e1
bec0074
 
 
 
 
 
dc909e1
e451a47
 
 
 
 
 
 
bec0074
dc909e1
bec0074
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import cv2
import torch
import numpy as np
import gradio as gr
import segmentation_models_pytorch as smp
from albumentations import Normalize
from albumentations.pytorch import ToTensorV2

# ================================
# CONFIG
# ================================
MODEL_PATH = "s2ds_deeplabv3plus.pth"  # path to the trained DeepLabV3+ checkpoint
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
NUM_CLASSES = 7  # background + the 6 non-background classes listed below
INFER_SIZE = 512  # 🔥 reduce for speed (important for live feed)

# Class id -> human-readable label; id 0 is background.
CLASS_NAMES = {
    0: "Background",
    1: "Crack",
    2: "Spalling",
    3: "Corrosion",
    4: "Efflorescence",
    5: "Vegetation",
    6: "Control Point"
}

# Class id -> RGB color used to render the segmentation overlay.
ID_TO_COLOR = {
    0: (0, 0, 0),
    1: (255, 255, 255),
    2: (255, 0, 0),
    3: (255, 255, 0),
    4: (0, 255, 255),
    5: (0, 255, 0),
    6: (0, 0, 255)
}

# ================================
# LOAD MODEL
# ================================
# DeepLabV3+ with a ResNet-50 encoder. encoder_weights=None because the
# trained weights come from the checkpoint below, so ImageNet
# pretraining would be overwritten anyway.
model = smp.DeepLabV3Plus(
    encoder_name="resnet50",
    encoder_weights=None,
    in_channels=3,
    classes=NUM_CLASSES
)

# Accept both checkpoint layouts: a training checkpoint that wraps the
# weights under "model_state_dict", or a raw state_dict saved directly.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(
    checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
)

model.to(DEVICE)
model.eval()  # inference mode: disables dropout, freezes batch-norm statistics

# Albumentations preprocessing: Normalize() with library defaults
# (presumably ImageNet mean/std — confirm this matches training), then
# HWC float -> CHW torch tensor.
normalize = Normalize()
to_tensor = ToTensorV2()

# ================================
# HELPERS
# ================================
def pad_to_16(img):
    """Reflect-pad *img* on the bottom/right so both spatial dims are
    multiples of 16.

    Returns a tuple ``(padded, original_height, original_width)`` so the
    caller can crop predictions back to the unpadded size.
    """
    height, width = img.shape[:2]
    # (-x) % 16 is the amount needed to reach the next multiple of 16
    # (0 when already aligned).
    pad_bottom = (-height) % 16
    pad_right = (-width) % 16
    padded = cv2.copyMakeBorder(
        img, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT
    )
    return padded, height, width

def colorize_mask(mask, palette=None):
    """Convert a 2-D integer class-id mask into an RGB visualization.

    Parameters
    ----------
    mask : np.ndarray
        2-D array of integer class ids, shape (H, W).
    palette : dict[int, tuple[int, int, int]] | None
        Mapping of class id -> RGB color. Defaults to the module-level
        ID_TO_COLOR table (backward compatible with the previous
        zero-argument-style call). Ids absent from the palette render
        as black.

    Returns
    -------
    np.ndarray
        uint8 array of shape (H, W, 3).
    """
    if palette is None:
        palette = ID_TO_COLOR
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in palette.items():
        color_mask[mask == cls] = color
    return color_mask

# ================================
# FAST INFERENCE FUNCTION
# ================================
def segment_image(image):
    """Run segmentation on one RGB frame and summarize the dominant defect.

    Parameters
    ----------
    image : np.ndarray | None
        HWC uint8 RGB image from Gradio (None when no input yet).

    Returns
    -------
    (np.ndarray | None, str)
        Overlay image at the original resolution (prediction blended over
        the frame), and a "Detected: <label>" summary string. Returns
        (None, "") for a None input.
    """
    if image is None:
        return None, ""

    # 🔥 Downscale to a fixed square for speed; remember the original for
    # the final upsample.
    original = image.copy()
    image = cv2.resize(image, (INFER_SIZE, INFER_SIZE))

    padded, orig_h, orig_w = pad_to_16(image)

    # Normalize, convert HWC -> CHW tensor, add batch dim, move to device.
    img = normalize(image=padded)["image"]
    img = to_tensor(image=img)["image"]
    img = img.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        if DEVICE == "cuda":
            # Mixed precision on GPU for throughput.
            with torch.cuda.amp.autocast():
                pred = model(img)
        else:
            pred = model(img)

        # Per-pixel argmax over class logits -> (H, W) id mask.
        pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()

    # Drop the reflect padding added by pad_to_16.
    pred_mask = pred_mask[:orig_h, :orig_w]

    color_mask = colorize_mask(pred_mask)
    overlay_small = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)

    # 🔥 Resize the blended overlay back to the input resolution.
    overlay = cv2.resize(overlay_small, (original.shape[1], original.shape[0]))

    # Image-level classification: the most frequent non-background class.
    vals, counts = np.unique(pred_mask, return_counts=True)
    # BUG FIX: filter values and counts with the SAME mask. The previous
    # code filtered vals with `vals > 0` but sliced `counts[1:]`, silently
    # assuming class 0 was always present at index 0; when the prediction
    # contained no background pixels the two arrays were misaligned and the
    # wrong class label was reported.
    foreground = vals > 0
    vals, counts = vals[foreground], counts[foreground]

    if vals.size > 0:
        img_class = int(vals[np.argmax(counts)])
        label = CLASS_NAMES[img_class]
    else:
        label = "Background"

    return overlay, f"Detected: {label}"

# ================================
# GRADIO UI
# ================================
with gr.Blocks() as demo:
    gr.Markdown("# 🏗 Structural Defect Segmentation")

    # Tab 1: one-shot segmentation of an uploaded image. Shows both the
    # overlay and the "Detected: <class>" text returned by segment_image.
    with gr.Tab("Image Upload"):
        input_img = gr.Image(type="numpy")  # delivers an HWC numpy array
        output_img = gr.Image()             # blended segmentation overlay
        output_text = gr.Textbox()          # image-level classification summary
        btn = gr.Button("Run Segmentation")
        btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])

    # Tab 2: continuous webcam streaming. Only the overlay is displayed;
    # the lambda drops the text label from segment_image's return tuple.
    with gr.Tab("Live Camera (Fast Mode)"):
        cam = gr.Image(
            sources=["webcam"],
            streaming=True,
            type="numpy",
            webcam_options={"facingMode": "environment"}  # 🔥 force back camera
        )
        cam_out = gr.Image()
        cam.stream(lambda x: segment_image(x)[0], inputs=cam, outputs=cam_out)

# NOTE(review): launch() runs unconditionally at import time; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
demo.launch()