Jazz1508 commited on
Commit
bec0074
·
verified ·
1 Parent(s): a2716a6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ import segmentation_models_pytorch as smp
6
+ from albumentations import Normalize
7
+ from albumentations.pytorch import ToTensorV2
8
+
9
+ # ================================
10
+ # CONFIG
11
+ # ================================
12
+ MODEL_PATH = "s2ds_deeplabv3plus.pth"
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ NUM_CLASSES = 7
15
+
16
+ CLASS_NAMES = {
17
+ 0: "Background",
18
+ 1: "Crack",
19
+ 2: "Spalling",
20
+ 3: "Corrosion",
21
+ 4: "Efflorescence",
22
+ 5: "Vegetation",
23
+ 6: "Control Point"
24
+ }
25
+
26
+ ID_TO_COLOR = {
27
+ 0: (0, 0, 0),
28
+ 1: (255, 255, 255),
29
+ 2: (255, 0, 0),
30
+ 3: (255, 255, 0),
31
+ 4: (0, 255, 255),
32
+ 5: (0, 255, 0),
33
+ 6: (0, 0, 255)
34
+ }
35
+
36
+ # ================================
37
+ # LOAD MODEL
38
+ # ================================
39
+ model = smp.DeepLabV3Plus(
40
+ encoder_name="resnet50",
41
+ encoder_weights=None,
42
+ in_channels=3,
43
+ classes=NUM_CLASSES
44
+ )
45
+
46
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
47
+
48
+ if "model_state_dict" in checkpoint:
49
+ model.load_state_dict(checkpoint["model_state_dict"])
50
+ else:
51
+ model.load_state_dict(checkpoint)
52
+
53
+ model.to(DEVICE)
54
+ model.eval()
55
+
56
+ # ================================
57
+ # HELPERS
58
+ # ================================
59
+ normalize = Normalize()
60
+ to_tensor = ToTensorV2()
61
+
62
+ def pad_to_16(img):
63
+ h, w = img.shape[:2]
64
+ new_h = (h + 15) // 16 * 16
65
+ new_w = (w + 15) // 16 * 16
66
+ pad_h = new_h - h
67
+ pad_w = new_w - w
68
+ padded = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT)
69
+ return padded, h, w
70
+
71
+ def colorize_mask(mask):
72
+ h, w = mask.shape
73
+ color_mask = np.zeros((h, w, 3), dtype=np.uint8)
74
+ for cls, color in ID_TO_COLOR.items():
75
+ color_mask[mask == cls] = color
76
+ return color_mask
77
+
78
+ # ================================
79
+ # INFERENCE FUNCTION
80
+ # ================================
81
+ def segment_image(image):
82
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
83
+
84
+ padded, orig_h, orig_w = pad_to_16(image)
85
+
86
+ img = normalize(image=padded)["image"]
87
+ img = to_tensor(image=img)["image"]
88
+ img = img.unsqueeze(0).to(DEVICE)
89
+
90
+ with torch.no_grad():
91
+ pred = model(img)
92
+ pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
93
+
94
+ pred_mask = pred_mask[:orig_h, :orig_w]
95
+
96
+ color_mask = colorize_mask(pred_mask)
97
+ overlay = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
98
+
99
+ vals, counts = np.unique(pred_mask, return_counts=True)
100
+ vals = vals[vals > 0]
101
+
102
+ if len(vals) > 0:
103
+ img_class = int(vals[np.argmax(counts[1:])])
104
+ label = CLASS_NAMES[img_class]
105
+ else:
106
+ label = "Background"
107
+
108
+ return overlay, f"Detected: {label}"
109
+
110
+ # ================================
111
+ # GRADIO UI
112
+ # ================================
113
+ with gr.Blocks() as demo:
114
+ gr.Markdown("# 🏗 Structural Defect Segmentation")
115
+
116
+ with gr.Tab("Image Upload"):
117
+ input_img = gr.Image(type="numpy")
118
+ output_img = gr.Image()
119
+ output_text = gr.Textbox()
120
+ btn = gr.Button("Run Segmentation")
121
+ btn.click(segment_image, inputs=input_img, outputs=[output_img, output_text])
122
+
123
+ with gr.Tab("Live Camera"):
124
+ cam = gr.Image(source="webcam", streaming=True)
125
+ cam_out = gr.Image()
126
+ cam.stream(segment_image, inputs=cam, outputs=[cam_out])
127
+
128
+ demo.launch()