pepitolechevalier commited on
Commit
bdddba7
·
verified ·
1 Parent(s): 6cae383

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -18
app.py CHANGED
@@ -130,15 +130,24 @@ class ImgLoader:
130
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
131
  ])
132
 
133
- def load(self, path):
134
- ori_img = cv2.imread(path)
135
- img = copy.deepcopy(ori_img[:, :, ::-1])
136
- img = Image.fromarray(img)
 
 
 
 
 
 
 
 
137
  return self.transform(img).unsqueeze(0)
138
 
139
- def cal_backward(out):
140
  target_layer_names = ['layer1', 'layer2', 'layer3', 'layer4',
141
  'FPN1_layer1', 'FPN1_layer2', 'FPN1_layer3', 'FPN1_layer4', 'comb_outs']
 
142
  sum_out = None
143
  for name in target_layer_names:
144
  tmp_out = out[name].mean(1) if name != "comb_outs" else out[name]
@@ -156,9 +165,12 @@ def cal_backward(out):
156
  V = V - min(V)
157
  V = V / sum(V)
158
 
159
- top5 = np.argsort(-V)[:5]
160
- accs = -np.sort(-V)[:5]
161
- return [f"{classes_list[int(cls)]}: {acc*100:.2f}%" for cls, acc in zip(top5, accs)]
 
 
 
162
 
163
  # === Chargement du modèle
164
  model = build_model("weights.pt")
@@ -167,20 +179,33 @@ img_loader = ImgLoader(data_size)
167
  def predict_image(image: Image.Image) -> List[str]:
168
  global features, grads, module_id_mapper
169
  features, grads, module_id_mapper = {}, {}, {}
170
- image_path = "temp_image.jpg"
171
- image.save(image_path)
172
- img_tensor = img_loader.load(image_path)
 
 
 
 
 
 
173
  out = model(img_tensor)
174
  return cal_backward(out)
175
 
176
  # === Interface Gradio
177
- demo = gr.Interface(
178
- fn=predict_image,
179
- inputs=gr.Image(type="pil"),
180
- outputs=gr.Textbox(label="Prédictions top 5"),
181
- title="Classification OGSO",
182
- description="Détection de la catégorie de quincaillerie via Swin Transformer avec architecture personnalisée"
183
- )
 
 
 
 
 
 
 
184
 
185
  if __name__ == "__main__":
186
  demo.launch()
 
130
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
131
  ])
132
 
133
+ def load(self, input_img):
134
+ if isinstance(input_img, str):
135
+ ori_img = cv2.imread(input_img)
136
+ img = Image.fromarray(cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB))
137
+ elif isinstance(input_img, Image.Image):
138
+ img = input_img
139
+ else:
140
+ raise ValueError("Image invalide")
141
+
142
+ if img.mode != "RGB":
143
+ img = img.convert("RGB")
144
+
145
  return self.transform(img).unsqueeze(0)
146
 
147
+ def cal_backward(out) -> dict:
148
  target_layer_names = ['layer1', 'layer2', 'layer3', 'layer4',
149
  'FPN1_layer1', 'FPN1_layer2', 'FPN1_layer3', 'FPN1_layer4', 'comb_outs']
150
+
151
  sum_out = None
152
  for name in target_layer_names:
153
  tmp_out = out[name].mean(1) if name != "comb_outs" else out[name]
 
165
  V = V - min(V)
166
  V = V / sum(V)
167
 
168
+ top5_indices = np.argsort(-V)[:5]
169
+ top5_scores = -np.sort(-V)[:5]
170
+
171
+ # Construction du dictionnaire pour gr.Label
172
+ top5_dict = {classes_list[int(idx)]: float(f"{score:.4f}") for idx, score in zip(top5_indices, top5_scores)}
173
+ return top5_dict
174
 
175
  # === Chargement du modèle
176
  model = build_model("weights.pt")
 
179
  def predict_image(image: Image.Image) -> List[str]:
180
  global features, grads, module_id_mapper
181
  features, grads, module_id_mapper = {}, {}, {}
182
+
183
+ if image is None:
184
+ return {}
185
+ # raise ValueError("Aucune image reçue. Vérifie l'entrée.")
186
+
187
+ if image.mode != "RGB":
188
+ image = image.convert("RGB")
189
+
190
+ img_tensor = img_loader.load(image)
191
  out = model(img_tensor)
192
  return cal_backward(out)
193
 
194
  # === Interface Gradio
195
+ with gr.Blocks() as demo:
196
+ with gr.Row():
197
+ with gr.Column():
198
+ with gr.Tab("Téléversement"):
199
+ upload_input = gr.Image(type="pil", label="Image téléchargée", sources=["upload"], show_label=True)
200
+ with gr.Tab("Webcam"):
201
+ webcam_input = gr.Image(type="pil", label="Webcam", sources=["webcam"], show_label=True)
202
+ with gr.Column():
203
+ output = gr.Label(num_top_classes=5, label="Prédiction")
204
+
205
+ # Connexion des callbacks
206
+ upload_input.change(fn=predict_image, inputs=upload_input, outputs=output)
207
+ webcam_input.change(fn=predict_image, inputs=webcam_input, outputs=output)
208
+
209
 
210
  if __name__ == "__main__":
211
  demo.launch()