S-4-G-4-R committed on
Commit
b5122c5
Β·
verified Β·
1 Parent(s): 437b12c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +215 -0
  2. requirements .txt +7 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py β€” Gradio demo for Prompted Segmentation for Drywall QA
3
+ Model : CLIPSeg (CIDAS/clipseg-rd64-refined), fine-tuned on drywall datasets
4
+ Weights: best_model.pt (upload this file to your HuggingFace Space)
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import numpy as np
10
+ import torch
11
+ import gradio as gr
12
+ from PIL import Image
13
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
14
+
15
+ # ── Config ────────────────────────────────────────────────────────────────────
16
+ MODEL_NAME = "CIDAS/clipseg-rd64-refined"
17
+ CKPT_PATH = "best_model.pt" # must be in the Space root directory
18
+ IMG_SIZE = 352
19
+ THRESHOLD = 0.5
20
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
+ # Supported prompts (trained)
23
+ PROMPT_CHOICES = [
24
+ "segment crack",
25
+ "segment taping area",
26
+ ]
27
+
28
+ # ── Load model (once at startup) ──────────────────────────────────────────────
29
+ print(f"Loading CLIPSeg processor from {MODEL_NAME} ...")
30
+ processor = CLIPSegProcessor.from_pretrained(MODEL_NAME)
31
+
32
+ print(f"Loading CLIPSeg model from {MODEL_NAME} ...")
33
+ model = CLIPSegForImageSegmentation.from_pretrained(MODEL_NAME)
34
+
35
+ if os.path.exists(CKPT_PATH):
36
+ print(f"Loading fine-tuned weights from {CKPT_PATH} ...")
37
+ state_dict = torch.load(CKPT_PATH, map_location=DEVICE)
38
+ model.load_state_dict(state_dict)
39
+ print("Fine-tuned weights loaded successfully.")
40
+ else:
41
+ print(f"WARNING: {CKPT_PATH} not found β€” running with base CLIPSeg weights.")
42
+
43
+ model = model.to(DEVICE)
44
+ model.eval()
45
+ print(f"Model ready on {DEVICE}.")
46
+
47
+
48
+ # ── Inference function ────────────────────────────────────────────────────────
49
+ def predict(image: Image.Image, prompt: str, threshold: float) -> tuple:
50
+ """
51
+ Runs CLIPSeg inference and returns:
52
+ - overlay : original image blended with coloured mask
53
+ - mask_img : pure binary mask (grayscale)
54
+ - info_str : prompt used + inference time
55
+ """
56
+ if image is None:
57
+ return None, None, "Please upload an image."
58
+
59
+ original_size = image.size # (W, H) β€” to resize mask back
60
+ image_rgb = image.convert("RGB")
61
+
62
+ # Preprocess
63
+ encoding = processor(
64
+ text = [prompt],
65
+ images = [image_rgb],
66
+ return_tensors = "pt",
67
+ padding = "max_length",
68
+ truncation = True,
69
+ )
70
+ pixel_values = encoding["pixel_values"].to(DEVICE)
71
+ input_ids = encoding["input_ids"].to(DEVICE)
72
+ attention_mask = encoding["attention_mask"].to(DEVICE)
73
+
74
+ # Inference
75
+ t0 = time.time()
76
+ with torch.no_grad():
77
+ outputs = model(
78
+ pixel_values = pixel_values,
79
+ input_ids = input_ids,
80
+ attention_mask = attention_mask,
81
+ )
82
+ inf_ms = (time.time() - t0) * 1000
83
+
84
+ # Post-process logits β†’ binary mask
85
+ prob = torch.sigmoid(outputs.logits[0]).cpu().numpy() # (H, W) at 352Γ—352
86
+ pred_bin = (prob > threshold).astype(np.uint8) # 0 or 1
87
+
88
+ # Resize mask back to original image size
89
+ mask_pil = Image.fromarray((pred_bin * 255).astype(np.uint8), mode="L")
90
+ mask_pil = mask_pil.resize(original_size, Image.NEAREST)
91
+ mask_arr = np.array(mask_pil) # 0 or 255
92
+
93
+ # ── Build overlay (original + coloured mask) ──────────────────────────────
94
+ img_arr = np.array(image_rgb).astype(np.float32) # (H, W, 3)
95
+ overlay = img_arr.copy()
96
+
97
+ # Colour: teal for crack, orange for taping area
98
+ if "crack" in prompt.lower():
99
+ colour = np.array([0, 200, 220], dtype=np.float32) # teal
100
+ else:
101
+ colour = np.array([255, 160, 50], dtype=np.float32) # orange
102
+
103
+ fg = mask_arr > 0
104
+ overlay[fg] = overlay[fg] * 0.45 + colour * 0.55
105
+ overlay = np.clip(overlay, 0, 255).astype(np.uint8)
106
+
107
+ # Coverage stat
108
+ coverage = fg.sum() / fg.size * 100
109
+
110
+ info = (
111
+ f"Prompt : \"{prompt}\"\n"
112
+ f"Threshold : {threshold:.2f}\n"
113
+ f"Inference : {inf_ms:.1f} ms\n"
114
+ f"Coverage : {coverage:.2f} % of image\n"
115
+ f"Device : {DEVICE}"
116
+ )
117
+
118
+ return Image.fromarray(overlay), mask_pil, info
119
+
120
+
121
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
122
+ TITLE = "🧱 Drywall QA β€” Prompted Segmentation"
123
+
124
+ DESCRIPTION = """
125
+ Fine-tuned **CLIPSeg** for text-conditioned binary segmentation of drywall defects.
126
+
127
+ Upload a drywall image, pick a prompt, and the model highlights the defective region.
128
+
129
+ | Prompt | Target | Val mIoU | Val Dice |
130
+ |---|---|---|---|
131
+ | `segment crack` | Wall cracks | **0.735** | **0.834** |
132
+ | `segment taping area` | Joint / tape seam | **0.499** | **0.626** |
133
+
134
+ *Model: CIDAS/clipseg-rd64-refined fine-tuned for 20 epochs Β· Seed 42*
135
+ """
136
+
137
+ ARTICLE = """
138
+ ### How it works
139
+ CLIPSeg extends CLIP with a lightweight decoder that turns any text prompt into a segmentation mask.
140
+ The model was fine-tuned end-to-end on two Roboflow drywall datasets using a combined BCE + Dice loss.
141
+
142
+ **Datasets:** [Drywall-Join-Detect](https://universe.roboflow.com/objectdetect-pu6rn/drywall-join-detect) Β· [Cracks](https://universe.roboflow.com/fyp-ny1jt/cracks-3ii36)
143
+ """
144
+
145
+ with gr.Blocks(title=TITLE, theme=gr.themes.Soft()) as demo:
146
+
147
+ gr.Markdown(f"# {TITLE}")
148
+ gr.Markdown(DESCRIPTION)
149
+
150
+ with gr.Row():
151
+
152
+ # ── Left column: inputs ───────────────────────────────────────────────
153
+ with gr.Column(scale=1):
154
+ image_input = gr.Image(
155
+ type = "pil",
156
+ label = "Upload Drywall Image",
157
+ height = 320,
158
+ )
159
+ prompt_input = gr.Radio(
160
+ choices = PROMPT_CHOICES,
161
+ value = PROMPT_CHOICES[0],
162
+ label = "Segmentation Prompt",
163
+ )
164
+ threshold_slider = gr.Slider(
165
+ minimum = 0.1,
166
+ maximum = 0.9,
167
+ value = THRESHOLD,
168
+ step = 0.05,
169
+ label = "Threshold (lower β†’ more detections, higher β†’ stricter)",
170
+ )
171
+ run_btn = gr.Button("πŸ” Run Segmentation", variant="primary")
172
+
173
+ # ── Right column: outputs ─────────────────────────────────────────────
174
+ with gr.Column(scale=1):
175
+ overlay_out = gr.Image(
176
+ type = "pil",
177
+ label = "Overlay (original + mask)",
178
+ height= 320,
179
+ )
180
+ mask_out = gr.Image(
181
+ type = "pil",
182
+ label = "Binary Mask (white = detected region)",
183
+ height= 160,
184
+ )
185
+ info_out = gr.Textbox(
186
+ label = "Run Info",
187
+ lines = 5,
188
+ )
189
+
190
+ run_btn.click(
191
+ fn = predict,
192
+ inputs = [image_input, prompt_input, threshold_slider],
193
+ outputs = [overlay_out, mask_out, info_out],
194
+ )
195
+
196
+ # Also run on image upload (convenience)
197
+ image_input.change(
198
+ fn = predict,
199
+ inputs = [image_input, prompt_input, threshold_slider],
200
+ outputs = [overlay_out, mask_out, info_out],
201
+ )
202
+
203
+ gr.Markdown(ARTICLE)
204
+
205
+ gr.Examples(
206
+ examples = [], # add example image paths here if you have them
207
+ inputs = [image_input, prompt_input, threshold_slider],
208
+ outputs = [overlay_out, mask_out, info_out],
209
+ fn = predict,
210
+ cache_examples = False,
211
+ )
212
+
213
+
214
+ if __name__ == "__main__":
215
+ demo.launch()
requirements .txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ torch==2.3.1
3
+ torchvision==0.18.1
4
+ transformers==4.44.2
5
+ Pillow==10.4.0
6
+ numpy==1.26.4
7
+ matplotlib==3.9.2