Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Interface Streamlit pour la classification de déchets - Version Hugging Face Spaces | |
| Déployé sur Hugging Face Spaces avec téléchargement automatique des modèles | |
| """ | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from PIL import Image | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.preprocessing import image | |
| import os | |
| from pathlib import Path | |
| import logging | |
| import requests | |
| import zipfile | |
| import tempfile | |
| # Configuration de la page | |
| st.set_page_config( | |
| page_title="Classificateur de Déchets IA", | |
| page_icon="♻️", | |
| layout="wide", | |
| initial_sidebar_state="expanded" | |
| ) | |
| # Configuration du logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| class WasteClassifierUI: | |
| """Classe principale pour l'interface de classification de déchets.""" | |
| def __init__(self): | |
| self.model_v1 = None | |
| self.model_v2 = None | |
| self.class_names = ["Papier", "Plastique"] | |
| self.target_size = (96, 96) | |
| # Chemins des modèles pour Hugging Face Spaces | |
| self.models_dir = Path("models") | |
| self.models_dir.mkdir(exist_ok=True) | |
| self.model_v1_path = self.models_dir / "waste_classifier_v1.h5" | |
| self.model_v2_path = self.models_dir / "waste_classifier_v2.h5" | |
| # URLs des modèles (à remplacer par vos URLs Hugging Face) | |
| # Pour Docker, vous pouvez aussi utiliser des modèles locaux | |
| self.model_v1_url = os.getenv('MODEL_V1_URL', "https://huggingface.co/360TechEnv/waste-classifier/resolve/main/models/waste_classifier_v1.h5") | |
| self.model_v2_url = os.getenv('MODEL_V2_URL', "https://huggingface.co/360TechEnv/waste-classifier/resolve/main/models/waste_classifier_v2.h5") | |
| # Vérifier si des modèles locaux existent (pour Docker) | |
| local_v1 = Path("models/waste_classifier_v1.h5") | |
| local_v2 = Path("models/waste_classifier_v2.h5") | |
| if local_v1.exists(): | |
| self.model_v1_path = local_v1 | |
| if local_v2.exists(): | |
| self.model_v2_path = local_v2 | |
| def download_model(self, url, local_path): | |
| """Télécharge un modèle depuis une URL.""" | |
| try: | |
| if local_path.exists(): | |
| logger.info(f"Modèle déjà présent: {local_path}") | |
| return True | |
| logger.info(f"Téléchargement du modèle depuis: {url}") | |
| response = requests.get(url, stream=True) | |
| response.raise_for_status() | |
| with open(local_path, 'wb') as f: | |
| for chunk in response.iter_content(chunk_size=8192): | |
| f.write(chunk) | |
| logger.info(f"Modèle téléchargé avec succès: {local_path}") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Erreur lors du téléchargement: {e}") | |
| return False | |
| def load_models(self): | |
| """Charge les modèles v1 et v2.""" | |
| try: | |
| # Télécharger le modèle v1 si nécessaire | |
| if not self.model_v1_path.exists(): | |
| st.info("Téléchargement du modèle v1...") | |
| if not self.download_model(self.model_v1_url, self.model_v1_path): | |
| st.warning("Impossible de télécharger le modèle v1") | |
| else: | |
| st.success("Modèle v1 téléchargé avec succès!") | |
| # Charger le modèle v1 | |
| if self.model_v1_path.exists(): | |
| self.model_v1 = load_model(self.model_v1_path) | |
| logger.info("Modèle v1 chargé avec succès") | |
| else: | |
| logger.warning("Modèle v1 non disponible") | |
| # Télécharger le modèle v2 si nécessaire | |
| if not self.model_v2_path.exists(): | |
| st.info("Téléchargement du modèle v2...") | |
| if not self.download_model(self.model_v2_url, self.model_v2_path): | |
| st.warning("Impossible de télécharger le modèle v2") | |
| else: | |
| st.success("Modèle v2 téléchargé avec succès!") | |
| # Charger le modèle v2 | |
| if self.model_v2_path.exists(): | |
| self.model_v2 = load_model(self.model_v2_path) | |
| logger.info("Modèle v2 chargé avec succès") | |
| else: | |
| logger.warning("Modèle v2 non disponible") | |
| except Exception as e: | |
| logger.error(f"Erreur lors du chargement des modèles: {e}") | |
| st.error(f"Erreur lors du chargement des modèles: {e}") | |
| def preprocess_image(self, img, target_size=(96, 96)): | |
| """Préprocesse une image pour la prédiction.""" | |
| try: | |
| # Redimensionner l'image | |
| img_resized = img.resize(target_size) | |
| # Convertir en array numpy | |
| img_array = image.img_to_array(img_resized) | |
| # Normaliser les pixels (0-255 -> 0-1) | |
| img_array = img_array / 255.0 | |
| # Ajouter une dimension de batch | |
| img_array = np.expand_dims(img_array, axis=0) | |
| return img_array | |
| except Exception as e: | |
| logger.error(f"Erreur lors du preprocessing: {e}") | |
| st.error(f"Erreur lors du preprocessing: {e}") | |
| return None | |
| def predict_image(self, img_array, model, model_name): | |
| """Prédit la classe d'une image avec un modèle donné.""" | |
| try: | |
| if model is None: | |
| return None | |
| # Faire la prédiction | |
| predictions = model.predict(img_array, verbose=0) | |
| # Obtenir la classe prédite et la confiance | |
| predicted_class_idx = np.argmax(predictions[0]) | |
| confidence = predictions[0][predicted_class_idx] | |
| predicted_class = self.class_names[predicted_class_idx] | |
| # Obtenir les probabilités pour toutes les classes | |
| class_probabilities = {} | |
| for i, class_name in enumerate(self.class_names): | |
| class_probabilities[class_name] = float(predictions[0][i]) | |
| result = { | |
| 'model_name': model_name, | |
| 'predicted_class': predicted_class, | |
| 'confidence': float(confidence), | |
| 'class_probabilities': class_probabilities | |
| } | |
| return result | |
| except Exception as e: | |
| logger.error(f"Erreur lors de la prédiction avec {model_name}: {e}") | |
| st.error(f"Erreur lors de la prédiction avec {model_name}: {e}") | |
| return None | |
| def create_confidence_chart(self, results): | |
| """Crée un graphique en barres des probabilités de confiance.""" | |
| if not results: | |
| return None | |
| fig, axes = plt.subplots(1, len(results), figsize=(6 * len(results), 5)) | |
| if len(results) == 1: | |
| axes = [axes] | |
| for i, result in enumerate(results): | |
| if result is None: | |
| continue | |
| classes = list(result['class_probabilities'].keys()) | |
| probabilities = list(result['class_probabilities'].values()) | |
| # Créer le graphique en barres | |
| bars = axes[i].bar(classes, probabilities, | |
| color=['#2E8B57' if c == result['predicted_class'] else '#4682B4' | |
| for c in classes]) | |
| axes[i].set_title(f"{result['model_name']}\nPrédiction: {result['predicted_class']}\nConfiance: {result['confidence']:.3f}") | |
| axes[i].set_ylabel("Probabilité") | |
| axes[i].set_ylim(0, 1) | |
| # Ajouter les valeurs sur les barres | |
| for bar, prob in zip(bars, probabilities): | |
| height = bar.get_height() | |
| axes[i].text(bar.get_x() + bar.get_width()/2., height + 0.01, | |
| f'{prob:.3f}', ha='center', va='bottom', fontweight='bold') | |
| plt.tight_layout() | |
| return fig | |
| def run(self): | |
| """Lance l'interface Streamlit.""" | |
| # Titre principal | |
| st.title("♻️ Classificateur de Déchets IA") | |
| st.markdown("---") | |
| # Charger les modèles | |
| if self.model_v1 is None or self.model_v2 is None: | |
| with st.spinner("Chargement des modèles..."): | |
| self.load_models() | |
| # Sidebar pour la configuration | |
| st.sidebar.header("Configuration") | |
| # Sélection du modèle | |
| available_models = [] | |
| if self.model_v1 is not None: | |
| available_models.append("Modèle v1") | |
| if self.model_v2 is not None: | |
| available_models.append("Modèle v2") | |
| if not available_models: | |
| st.error("Aucun modèle disponible. Vérifiez la connexion internet et réessayez.") | |
| return | |
| selected_models = st.sidebar.multiselect( | |
| "Sélectionnez les modèles à utiliser:", | |
| available_models, | |
| default=available_models | |
| ) | |
| # Upload d'image | |
| st.sidebar.header("Upload d'image") | |
| uploaded_file = st.sidebar.file_uploader( | |
| "Choisissez une image de déchet:", | |
| type=['jpg', 'jpeg', 'png', 'bmp', 'tiff'], | |
| help="Formats supportés: JPG, JPEG, PNG, BMP, TIFF" | |
| ) | |
| # Zone principale | |
| col1, col2 = st.columns([1, 1]) | |
| with col1: | |
| st.header("Image d'entrée") | |
| if uploaded_file is not None: | |
| # Afficher l'image uploadée | |
| image_pil = Image.open(uploaded_file) | |
| st.image(image_pil, caption="Image uploadée", use_column_width=True) | |
| # Informations sur l'image | |
| st.info(f"**Dimensions originales:** {image_pil.size[0]} x {image_pil.size[1]} pixels") | |
| # Bouton de prédiction | |
| if st.button("🔍 Classifier l'image", type="primary"): | |
| if not selected_models: | |
| st.warning("Veuillez sélectionner au moins un modèle.") | |
| else: | |
| with st.spinner("Classification en cours..."): | |
| # Préprocesser l'image | |
| img_array = self.preprocess_image(image_pil, self.target_size) | |
| if img_array is not None: | |
| # Faire les prédictions | |
| results = [] | |
| for model_name in selected_models: | |
| if model_name == "Modèle v1" and self.model_v1 is not None: | |
| result = self.predict_image(img_array, self.model_v1, "Modèle v1") | |
| results.append(result) | |
| elif model_name == "Modèle v2" and self.model_v2 is not None: | |
| result = self.predict_image(img_array, self.model_v2, "Modèle v2") | |
| results.append(result) | |
| # Stocker les résultats dans la session | |
| st.session_state['prediction_results'] = results | |
| st.session_state['uploaded_image'] = image_pil | |
| else: | |
| st.info("Veuillez uploader une image pour commencer la classification.") | |
| with col2: | |
| st.header("Résultats de classification") | |
| # Afficher les résultats | |
| if 'prediction_results' in st.session_state and st.session_state['prediction_results']: | |
| results = st.session_state['prediction_results'] | |
| # Résumé des prédictions | |
| st.subheader("📊 Résumé des prédictions") | |
| for result in results: | |
| if result is not None: | |
| col_pred, col_conf = st.columns([2, 1]) | |
| with col_pred: | |
| st.write(f"**{result['model_name']}:**") | |
| with col_conf: | |
| confidence_pct = result['confidence'] * 100 | |
| st.metric("Confiance", f"{confidence_pct:.1f}%") | |
| # Barre de progression pour la confiance | |
| st.progress(result['confidence']) | |
| # Détails des probabilités | |
| with st.expander(f"Détails - {result['model_name']}"): | |
| for class_name, prob in result['class_probabilities'].items(): | |
| prob_pct = prob * 100 | |
| st.write(f"**{class_name}:** {prob_pct:.2f}%") | |
| # Graphique de comparaison | |
| if len(results) > 1: | |
| st.subheader("📈 Comparaison des modèles") | |
| fig = self.create_confidence_chart(results) | |
| if fig is not None: | |
| st.pyplot(fig) | |
| # Recommandation | |
| st.subheader("💡 Recommandation") | |
| if len(results) == 1: | |
| result = results[0] | |
| if result is not None: | |
| confidence_pct = result['confidence'] * 100 | |
| if confidence_pct >= 80: | |
| st.success(f"Classification très fiable: {result['predicted_class']} ({confidence_pct:.1f}%)") | |
| elif confidence_pct >= 60: | |
| st.warning(f"Classification modérée: {result['predicted_class']} ({confidence_pct:.1f}%)") | |
| else: | |
| st.error(f"Classification incertaine: {result['predicted_class']} ({confidence_pct:.1f}%)") | |
| else: | |
| # Comparer les résultats des différents modèles | |
| predictions = [r['predicted_class'] for r in results if r is not None] | |
| confidences = [r['confidence'] for r in results if r is not None] | |
| if len(set(predictions)) == 1: | |
| st.success(f"✅ Consensus: Tous les modèles prédisent '{predictions[0]}'") | |
| else: | |
| st.warning("⚠️ Divergence: Les modèles donnent des prédictions différentes") | |
| for i, (pred, conf) in enumerate(zip(predictions, confidences)): | |
| st.write(f"- {results[i]['model_name']}: {pred} ({conf*100:.1f}%)") | |
| else: | |
| st.info("Les résultats de classification apparaîtront ici après l'analyse.") | |
| # Footer | |
| st.markdown("---") | |
| st.markdown( | |
| """ | |
| <div style='text-align: center; color: #666;'> | |
| <p>Classificateur de Déchets IA - Modèles v1 et v2</p> | |
| <p>Déployé sur Hugging Face Spaces</p> | |
| </div> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| def main(): | |
| """Fonction principale.""" | |
| classifier_ui = WasteClassifierUI() | |
| classifier_ui.run() | |
| if __name__ == "__main__": | |
| main() | |