360TechEnv commited on
Commit
c1803de
·
verified ·
1 Parent(s): 6ef8e1f

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +370 -0
app.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Interface Streamlit pour la classification de déchets - Version Hugging Face Spaces
4
+ Déployé sur Hugging Face Spaces avec téléchargement automatique des modèles
5
+ """
6
+
7
+ import streamlit as st
8
+ import numpy as np
9
+ import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ from PIL import Image
13
+ import tensorflow as tf
14
+ from tensorflow.keras.models import load_model
15
+ from tensorflow.keras.preprocessing import image
16
+ import os
17
+ from pathlib import Path
18
+ import logging
19
+ import requests
20
+ import zipfile
21
+ import tempfile
22
+
23
+ # Configuration de la page
24
+ st.set_page_config(
25
+ page_title="Classificateur de Déchets IA",
26
+ page_icon="♻️",
27
+ layout="wide",
28
+ initial_sidebar_state="expanded"
29
+ )
30
+
31
+ # Configuration du logging
32
+ logging.basicConfig(level=logging.INFO)
33
+ logger = logging.getLogger(__name__)
34
+
35
+ class WasteClassifierUI:
36
+ """Classe principale pour l'interface de classification de déchets."""
37
+
38
+ def __init__(self):
39
+ self.model_v1 = None
40
+ self.model_v2 = None
41
+ self.class_names = ["Papier", "Plastique"]
42
+ self.target_size = (96, 96)
43
+
44
+ # Chemins des modèles pour Hugging Face Spaces
45
+ self.models_dir = Path("models")
46
+ self.models_dir.mkdir(exist_ok=True)
47
+
48
+ self.model_v1_path = self.models_dir / "waste_classifier_v1.h5"
49
+ self.model_v2_path = self.models_dir / "waste_classifier_v2.h5"
50
+
51
+ # URLs des modèles (à remplacer par vos URLs Hugging Face)
52
+ # Pour Docker, vous pouvez aussi utiliser des modèles locaux
53
+ self.model_v1_url = os.getenv('MODEL_V1_URL', "https://huggingface.co/your-username/waste-classifier/resolve/main/models/waste_classifier_v1.h5")
54
+ self.model_v2_url = os.getenv('MODEL_V2_URL', "https://huggingface.co/your-username/waste-classifier/resolve/main/models/waste_classifier_v2.h5")
55
+
56
+ # Vérifier si des modèles locaux existent (pour Docker)
57
+ local_v1 = Path("models/waste_classifier_v1.h5")
58
+ local_v2 = Path("models/waste_classifier_v2.h5")
59
+
60
+ if local_v1.exists():
61
+ self.model_v1_path = local_v1
62
+ if local_v2.exists():
63
+ self.model_v2_path = local_v2
64
+
65
+ def download_model(self, url, local_path):
66
+ """Télécharge un modèle depuis une URL."""
67
+ try:
68
+ if local_path.exists():
69
+ logger.info(f"Modèle déjà présent: {local_path}")
70
+ return True
71
+
72
+ logger.info(f"Téléchargement du modèle depuis: {url}")
73
+ response = requests.get(url, stream=True)
74
+ response.raise_for_status()
75
+
76
+ with open(local_path, 'wb') as f:
77
+ for chunk in response.iter_content(chunk_size=8192):
78
+ f.write(chunk)
79
+
80
+ logger.info(f"Modèle téléchargé avec succès: {local_path}")
81
+ return True
82
+
83
+ except Exception as e:
84
+ logger.error(f"Erreur lors du téléchargement: {e}")
85
+ return False
86
+
87
+ def load_models(self):
88
+ """Charge les modèles v1 et v2."""
89
+ try:
90
+ # Télécharger le modèle v1 si nécessaire
91
+ if not self.model_v1_path.exists():
92
+ st.info("Téléchargement du modèle v1...")
93
+ if not self.download_model(self.model_v1_url, self.model_v1_path):
94
+ st.warning("Impossible de télécharger le modèle v1")
95
+ else:
96
+ st.success("Modèle v1 téléchargé avec succès!")
97
+
98
+ # Charger le modèle v1
99
+ if self.model_v1_path.exists():
100
+ self.model_v1 = load_model(self.model_v1_path)
101
+ logger.info("Modèle v1 chargé avec succès")
102
+ else:
103
+ logger.warning("Modèle v1 non disponible")
104
+
105
+ # Télécharger le modèle v2 si nécessaire
106
+ if not self.model_v2_path.exists():
107
+ st.info("Téléchargement du modèle v2...")
108
+ if not self.download_model(self.model_v2_url, self.model_v2_path):
109
+ st.warning("Impossible de télécharger le modèle v2")
110
+ else:
111
+ st.success("Modèle v2 téléchargé avec succès!")
112
+
113
+ # Charger le modèle v2
114
+ if self.model_v2_path.exists():
115
+ self.model_v2 = load_model(self.model_v2_path)
116
+ logger.info("Modèle v2 chargé avec succès")
117
+ else:
118
+ logger.warning("Modèle v2 non disponible")
119
+
120
+ except Exception as e:
121
+ logger.error(f"Erreur lors du chargement des modèles: {e}")
122
+ st.error(f"Erreur lors du chargement des modèles: {e}")
123
+
124
+ def preprocess_image(self, img, target_size=(96, 96)):
125
+ """Préprocesse une image pour la prédiction."""
126
+ try:
127
+ # Redimensionner l'image
128
+ img_resized = img.resize(target_size)
129
+
130
+ # Convertir en array numpy
131
+ img_array = image.img_to_array(img_resized)
132
+
133
+ # Normaliser les pixels (0-255 -> 0-1)
134
+ img_array = img_array / 255.0
135
+
136
+ # Ajouter une dimension de batch
137
+ img_array = np.expand_dims(img_array, axis=0)
138
+
139
+ return img_array
140
+
141
+ except Exception as e:
142
+ logger.error(f"Erreur lors du preprocessing: {e}")
143
+ st.error(f"Erreur lors du preprocessing: {e}")
144
+ return None
145
+
146
+ def predict_image(self, img_array, model, model_name):
147
+ """Prédit la classe d'une image avec un modèle donné."""
148
+ try:
149
+ if model is None:
150
+ return None
151
+
152
+ # Faire la prédiction
153
+ predictions = model.predict(img_array, verbose=0)
154
+
155
+ # Obtenir la classe prédite et la confiance
156
+ predicted_class_idx = np.argmax(predictions[0])
157
+ confidence = predictions[0][predicted_class_idx]
158
+ predicted_class = self.class_names[predicted_class_idx]
159
+
160
+ # Obtenir les probabilités pour toutes les classes
161
+ class_probabilities = {}
162
+ for i, class_name in enumerate(self.class_names):
163
+ class_probabilities[class_name] = float(predictions[0][i])
164
+
165
+ result = {
166
+ 'model_name': model_name,
167
+ 'predicted_class': predicted_class,
168
+ 'confidence': float(confidence),
169
+ 'class_probabilities': class_probabilities
170
+ }
171
+
172
+ return result
173
+
174
+ except Exception as e:
175
+ logger.error(f"Erreur lors de la prédiction avec {model_name}: {e}")
176
+ st.error(f"Erreur lors de la prédiction avec {model_name}: {e}")
177
+ return None
178
+
179
+ def create_confidence_chart(self, results):
180
+ """Crée un graphique en barres des probabilités de confiance."""
181
+ if not results:
182
+ return None
183
+
184
+ fig, axes = plt.subplots(1, len(results), figsize=(6 * len(results), 5))
185
+ if len(results) == 1:
186
+ axes = [axes]
187
+
188
+ for i, result in enumerate(results):
189
+ if result is None:
190
+ continue
191
+
192
+ classes = list(result['class_probabilities'].keys())
193
+ probabilities = list(result['class_probabilities'].values())
194
+
195
+ # Créer le graphique en barres
196
+ bars = axes[i].bar(classes, probabilities,
197
+ color=['#2E8B57' if c == result['predicted_class'] else '#4682B4'
198
+ for c in classes])
199
+
200
+ axes[i].set_title(f"{result['model_name']}\nPrédiction: {result['predicted_class']}\nConfiance: {result['confidence']:.3f}")
201
+ axes[i].set_ylabel("Probabilité")
202
+ axes[i].set_ylim(0, 1)
203
+
204
+ # Ajouter les valeurs sur les barres
205
+ for bar, prob in zip(bars, probabilities):
206
+ height = bar.get_height()
207
+ axes[i].text(bar.get_x() + bar.get_width()/2., height + 0.01,
208
+ f'{prob:.3f}', ha='center', va='bottom', fontweight='bold')
209
+
210
+ plt.tight_layout()
211
+ return fig
212
+
213
+ def run(self):
214
+ """Lance l'interface Streamlit."""
215
+ # Titre principal
216
+ st.title("♻️ Classificateur de Déchets IA")
217
+ st.markdown("---")
218
+
219
+ # Charger les modèles
220
+ if self.model_v1 is None or self.model_v2 is None:
221
+ with st.spinner("Chargement des modèles..."):
222
+ self.load_models()
223
+
224
+ # Sidebar pour la configuration
225
+ st.sidebar.header("Configuration")
226
+
227
+ # Sélection du modèle
228
+ available_models = []
229
+ if self.model_v1 is not None:
230
+ available_models.append("Modèle v1")
231
+ if self.model_v2 is not None:
232
+ available_models.append("Modèle v2")
233
+
234
+ if not available_models:
235
+ st.error("Aucun modèle disponible. Vérifiez la connexion internet et réessayez.")
236
+ return
237
+
238
+ selected_models = st.sidebar.multiselect(
239
+ "Sélectionnez les modèles à utiliser:",
240
+ available_models,
241
+ default=available_models
242
+ )
243
+
244
+ # Upload d'image
245
+ st.sidebar.header("Upload d'image")
246
+ uploaded_file = st.sidebar.file_uploader(
247
+ "Choisissez une image de déchet:",
248
+ type=['jpg', 'jpeg', 'png', 'bmp', 'tiff'],
249
+ help="Formats supportés: JPG, JPEG, PNG, BMP, TIFF"
250
+ )
251
+
252
+ # Zone principale
253
+ col1, col2 = st.columns([1, 1])
254
+
255
+ with col1:
256
+ st.header("Image d'entrée")
257
+ if uploaded_file is not None:
258
+ # Afficher l'image uploadée
259
+ image_pil = Image.open(uploaded_file)
260
+ st.image(image_pil, caption="Image uploadée", use_column_width=True)
261
+
262
+ # Informations sur l'image
263
+ st.info(f"**Dimensions originales:** {image_pil.size[0]} x {image_pil.size[1]} pixels")
264
+
265
+ # Bouton de prédiction
266
+ if st.button("🔍 Classifier l'image", type="primary"):
267
+ if not selected_models:
268
+ st.warning("Veuillez sélectionner au moins un modèle.")
269
+ else:
270
+ with st.spinner("Classification en cours..."):
271
+ # Préprocesser l'image
272
+ img_array = self.preprocess_image(image_pil, self.target_size)
273
+
274
+ if img_array is not None:
275
+ # Faire les prédictions
276
+ results = []
277
+ for model_name in selected_models:
278
+ if model_name == "Modèle v1" and self.model_v1 is not None:
279
+ result = self.predict_image(img_array, self.model_v1, "Modèle v1")
280
+ results.append(result)
281
+ elif model_name == "Modèle v2" and self.model_v2 is not None:
282
+ result = self.predict_image(img_array, self.model_v2, "Modèle v2")
283
+ results.append(result)
284
+
285
+ # Stocker les résultats dans la session
286
+ st.session_state['prediction_results'] = results
287
+ st.session_state['uploaded_image'] = image_pil
288
+ else:
289
+ st.info("Veuillez uploader une image pour commencer la classification.")
290
+
291
+ with col2:
292
+ st.header("Résultats de classification")
293
+
294
+ # Afficher les résultats
295
+ if 'prediction_results' in st.session_state and st.session_state['prediction_results']:
296
+ results = st.session_state['prediction_results']
297
+
298
+ # Résumé des prédictions
299
+ st.subheader("📊 Résumé des prédictions")
300
+
301
+ for result in results:
302
+ if result is not None:
303
+ col_pred, col_conf = st.columns([2, 1])
304
+ with col_pred:
305
+ st.write(f"**{result['model_name']}:**")
306
+ with col_conf:
307
+ confidence_pct = result['confidence'] * 100
308
+ st.metric("Confiance", f"{confidence_pct:.1f}%")
309
+
310
+ # Barre de progression pour la confiance
311
+ st.progress(result['confidence'])
312
+
313
+ # Détails des probabilités
314
+ with st.expander(f"Détails - {result['model_name']}"):
315
+ for class_name, prob in result['class_probabilities'].items():
316
+ prob_pct = prob * 100
317
+ st.write(f"**{class_name}:** {prob_pct:.2f}%")
318
+
319
+ # Graphique de comparaison
320
+ if len(results) > 1:
321
+ st.subheader("📈 Comparaison des modèles")
322
+ fig = self.create_confidence_chart(results)
323
+ if fig is not None:
324
+ st.pyplot(fig)
325
+
326
+ # Recommandation
327
+ st.subheader("💡 Recommandation")
328
+ if len(results) == 1:
329
+ result = results[0]
330
+ if result is not None:
331
+ confidence_pct = result['confidence'] * 100
332
+ if confidence_pct >= 80:
333
+ st.success(f"Classification très fiable: {result['predicted_class']} ({confidence_pct:.1f}%)")
334
+ elif confidence_pct >= 60:
335
+ st.warning(f"Classification modérée: {result['predicted_class']} ({confidence_pct:.1f}%)")
336
+ else:
337
+ st.error(f"Classification incertaine: {result['predicted_class']} ({confidence_pct:.1f}%)")
338
+ else:
339
+ # Comparer les résultats des différents modèles
340
+ predictions = [r['predicted_class'] for r in results if r is not None]
341
+ confidences = [r['confidence'] for r in results if r is not None]
342
+
343
+ if len(set(predictions)) == 1:
344
+ st.success(f"✅ Consensus: Tous les modèles prédisent '{predictions[0]}'")
345
+ else:
346
+ st.warning("⚠️ Divergence: Les modèles donnent des prédictions différentes")
347
+ for i, (pred, conf) in enumerate(zip(predictions, confidences)):
348
+ st.write(f"- {results[i]['model_name']}: {pred} ({conf*100:.1f}%)")
349
+ else:
350
+ st.info("Les résultats de classification apparaîtront ici après l'analyse.")
351
+
352
+ # Footer
353
+ st.markdown("---")
354
+ st.markdown(
355
+ """
356
+ <div style='text-align: center; color: #666;'>
357
+ <p>Classificateur de Déchets IA - Modèles v1 et v2</p>
358
+ <p>Déployé sur Hugging Face Spaces</p>
359
+ </div>
360
+ """,
361
+ unsafe_allow_html=True
362
+ )
363
+
364
+ def main():
365
+ """Fonction principale."""
366
+ classifier_ui = WasteClassifierUI()
367
+ classifier_ui.run()
368
+
369
+ if __name__ == "__main__":
370
+ main()