eho69 commited on
Commit
2c08aa7
Β·
verified Β·
1 Parent(s): 31dcba3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models, transforms
6
+ from PIL import Image
7
+ import numpy as np
8
+ import pickle
9
+ import os
10
+
11
+ # class EnginePartDetector:
12
+ # def __init__(self):
13
+ # self.model = models.resnet50(weights='IMAGENET1K_V1')
14
+ # self.model = nn.Sequential(*list(self.model.children())[:-1])
15
+ # self.model.eval()
16
+
17
+ # self.transform = transforms.Compose([
18
+ # transforms.Resize((224, 224)),
19
+ # transforms.ToTensor(),
20
+ # transforms.Normalize(
21
+ # mean=[0.485, 0.456, 0.406],
22
+ # std=[0.229, 0.224, 0.225]
23
+ # )
24
+ # ])
25
+
26
+ # self.templates = {}
27
+ # self.load_templates()
28
+
29
+ # def extract_features(self, image):
30
+ # if isinstance(image, np.ndarray):
31
+ # image = Image.fromarray(image)
32
+
33
+ # img_tensor = self.transform(image).unsqueeze(0)
34
+
35
+ # with torch.no_grad():
36
+ # features = self.model(img_tensor)
37
+ # features = features.squeeze().numpy()
38
+
39
+ # return features
40
+
41
+ class EnginePartDetector:
42
+ def __init__(
43
+ self,
44
+ clahe_clip_limit: float = 9.9,
45
+ clahe_tile_grid: tuple = (8, 8),
46
+ ):
47
+ # ── ResNet-50 backbone (feature extractor only) ──────────────────
48
+ self.model = models.resnet50(weights='IMAGENET1K_V1')
49
+ self.model = nn.Sequential(*list(self.model.children())[:-1])
50
+ self.model.eval()
51
+
52
+ # ── CLAHE (OpenCV) β€” applied BEFORE the torch transform ──────────
53
+ # Operates on grayscale to recover shadow-suppressed edges
54
+ # (e.g. missing bearing saddle arcs), then merged back to RGB
55
+ # so the 3-channel ResNet pipeline is unaffected.
56
+ self.clahe = cv2.createCLAHE(
57
+ clipLimit=clahe_clip_limit,
58
+ tileGridSize=clahe_tile_grid,
59
+ )
60
+
61
+ # ── ResNet normalisation transform (unchanged) ───────────────────
62
+ self.transform = transforms.Compose([
63
+ transforms.Resize((224, 224)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(
66
+ mean=[0.485, 0.456, 0.406],
67
+ std=[0.229, 0.224, 0.225],
68
+ )
69
+ ])
70
+
71
+ self.templates = {}
72
+ self.load_templates()
73
+
74
+ # ── CLAHE preprocessing ───────────────────────────────────────────────
75
+
76
+ def apply_clahe(self, image: np.ndarray) -> np.ndarray:
77
+
78
+ # Convert RGB (PIL/numpy) β†’ BGR for OpenCV
79
+ bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
80
+
81
+ # BGR β†’ LAB
82
+ lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
83
+
84
+ # Split channels; apply CLAHE only to L (luminance)
85
+ l_channel, a_channel, b_channel = cv2.split(lab)
86
+ l_enhanced = self.clahe.apply(l_channel)
87
+
88
+ # Merge enhanced L back with untouched A and B
89
+ lab_enhanced = cv2.merge([l_enhanced, a_channel, b_channel])
90
+
91
+ # LAB β†’ BGR β†’ RGB
92
+ bgr_enhanced = cv2.cvtColor(lab_enhanced, cv2.COLOR_LAB2BGR)
93
+ rgb_enhanced = cv2.cvtColor(bgr_enhanced, cv2.COLOR_BGR2RGB)
94
+
95
+ return rgb_enhanced # uint8 numpy array, same shape as input
96
+
97
+ # ── Feature extraction ────────────────────────────────────────────────
98
+
99
+ def extract_features(self, image) -> np.ndarray:
100
+
101
+ # 1. Normalise input to numpy uint8 RGB
102
+ if isinstance(image, Image.Image):
103
+ image = np.array(image.convert("RGB"))
104
+ elif isinstance(image, np.ndarray) and image.dtype != np.uint8:
105
+ image = image.astype(np.uint8)
106
+
107
+ # 2. CLAHE β€” recover shadow-suppressed structural edges
108
+ image = self.apply_clahe(image)
109
+
110
+ # 3. Mild Gaussian blur β€” reduces high-freq metallic sheen noise
111
+ # that CLAHE can amplify; kernel (3,3) is intentionally light
112
+ # so real surface-defect texture is preserved
113
+ image = cv2.GaussianBlur(image, (3, 3), 0)
114
+
115
+ # 4. Convert back to PIL for torchvision transforms
116
+ image_pil = Image.fromarray(image)
117
+
118
+ # 5. ResNet transform β†’ tensor
119
+ img_tensor = self.transform(image_pil).unsqueeze(0)
120
+
121
+ # 6. Forward pass (no grad needed β€” inference only)
122
+ with torch.no_grad():
123
+ features = self.model(img_tensor)
124
+ features = features.squeeze().numpy()
125
+
126
+ return features
127
+
128
+ def cosine_similarity(self, feat1, feat2):
129
+ return np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2))
130
+
131
+ def save_template(self, image, part_name):
132
+ if image is None or not part_name:
133
+ return "Please provide both image and part name"
134
+
135
+ features = self.extract_features(image)
136
+ self.templates[part_name] = features
137
+
138
+ with open('templates.pkl', 'wb') as f:
139
+ pickle.dump(self.templates, f)
140
+
141
+ return f"βœ… Template '{part_name}' saved successfully!"
142
+
143
+ def load_templates(self):
144
+ if os.path.exists('templates.pkl'):
145
+ try:
146
+ with open('templates.pkl', 'rb') as f:
147
+ self.templates = pickle.load(f)
148
+ except:
149
+ self.templates = {}
150
+
151
+ def match_part(self, image, threshold=0.7):
152
+ if image is None:
153
+ return "Please provide an image", None
154
+
155
+ if not self.templates:
156
+ return "⚠️ No templates available. Please add templates first.", None
157
+
158
+ query_features = self.extract_features(image)
159
+
160
+ results = []
161
+ for part_name, template_features in self.templates.items():
162
+ similarity = self.cosine_similarity(query_features, template_features)
163
+ results.append((part_name, similarity))
164
+
165
+ results.sort(key=lambda x: x[1], reverse=True)
166
+
167
+ best_match = results[0]
168
+ output_text = f"πŸ” **Best Match**: {best_match[0]}\n"
169
+ output_text += f"πŸ“Š **Confidence**: {best_match[1]:.2%}\n\n"
170
+
171
+ if best_match[1] >= threshold:
172
+ output_text += "βœ… **Status**: MATCHED\n\n"
173
+ else:
174
+ output_text += "❌ **Status**: NO MATCH (below threshold)\n\n"
175
+
176
+ output_text += "**All Results:**\n"
177
+ for part, sim in results:
178
+ output_text += f"- {part}: {sim:.2%}\n"
179
+
180
+ matched_label = best_match[0] if best_match[1] >= threshold else None
181
+ return output_text, matched_label
182
+
183
+ detector = EnginePartDetector()
184
+
185
+ def add_template(image, part_name):
186
+ return detector.save_template(image, part_name)
187
+
188
+ def detect_part(image, threshold):
189
+ return detector.match_part(image, threshold)
190
+
191
+ def list_templates():
192
+ if not detector.templates:
193
+ return "No templates saved yet"
194
+ return "\n".join([f"- {name}" for name in detector.templates.keys()])
195
+
196
+ with gr.Blocks(title="Engine Part Detection System") as demo:
197
+ gr.Markdown("""
198
+ # πŸ”§ Engine Part Detection System
199
+ ### Using ResNet50 Feature Extraction & Template Matching
200
+
201
+ **How to use:**
202
+ 1. **Add Templates**: Upload reference images of engine parts
203
+ 2. **Detect Parts**: Upload/capture images to identify parts
204
+ """)
205
+
206
+ with gr.Tab("πŸ” Detect Part"):
207
+ with gr.Row():
208
+ with gr.Column():
209
+ detect_input = gr.Image(sources=["upload", "webcam"], type="numpy")
210
+ threshold_slider = gr.Slider(0.5, 0.95, value=0.7, label="Similarity Threshold")
211
+ detect_btn = gr.Button("Detect Part", variant="primary")
212
+ with gr.Column():
213
+ detect_output = gr.Textbox(label="Detection Results", lines=10)
214
+ match_label = gr.Label(label="Matched Part")
215
+
216
+ detect_btn.click(
217
+ fn=detect_part,
218
+ inputs=[detect_input, threshold_slider],
219
+ outputs=[detect_output, match_label],
220
+ api_name="detect"
221
+ )
222
+
223
+ with gr.Tab("βž• Add Template"):
224
+ with gr.Row():
225
+ with gr.Column():
226
+ template_input = gr.Image(sources=["upload"], type="numpy")
227
+ part_name_input = gr.Textbox(label="Part Name (e.g., 'spark_plug', 'piston')")
228
+ add_btn = gr.Button("Save Template", variant="primary")
229
+ with gr.Column():
230
+ add_output = gr.Textbox(label="Status")
231
+
232
+ add_btn.click(
233
+ fn=add_template,
234
+ inputs=[template_input, part_name_input],
235
+ outputs=add_output,
236
+ api_name="add_template"
237
+ )
238
+
239
+ with gr.Tab("πŸ“‹ View Templates"):
240
+ template_list = gr.Textbox(label="Saved Templates", lines=10)
241
+ refresh_btn = gr.Button("Refresh List")
242
+ refresh_btn.click(
243
+ fn=list_templates,
244
+ outputs=template_list,
245
+ api_name="list_templates"
246
+ )
247
+ demo.load(fn=list_templates, outputs=template_list)
248
+
249
+ if __name__ == "__main__":
250
+ demo.launch()
251
+
252
+
253
+ app.py