Spaces:
Sleeping
Sleeping
import os
import smtplib
import tempfile
import threading
import time
import unicodedata
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.utils import img_to_array
# Create base model: an EfficientNetV2-B0 backbone without its classification
# head, frozen so it acts as a feature extractor.
input_shape = (224, 224, 3)
# NOTE(review): with the default `weights` argument this fetches ImageNet
# weights at start-up, which load_weights() below presumably overwrites
# (a full-model weights file normally includes the backbone) — if so,
# weights=None would avoid the download; confirm before changing.
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable = False
# Create Functional model: input -> backbone -> global average pooling ->
# 12-unit dense layer -> softmax.
# Keep layer order and names unchanged: load_weights() must match the
# architecture the weights file was saved from.
inputs = layers.Input(shape=input_shape, name="input_layer")
# training=False runs the backbone in inference mode (e.g. BatchNorm uses
# its stored statistics).
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
# One unit per entry of CLASS_LABELS (12 classes); the L2 regulariser only
# affects training.
x = layers.Dense(12, kernel_regularizer=regularizers.l2(0.001))(x)
# Softmax computed in float32 — presumably so the output stays float32 if the
# model was trained with mixed precision.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model = tf.keras.Model(inputs, outputs)
# Compile the model. Loss/optimizer/metrics are only used for training or
# evaluation; inference through predict() does not depend on them.
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"]
)
# Load the trained weights from the working directory.
model.load_weights('pest_classif.weights.h5')
# Class labels (French, user-facing). Index i is the name of softmax output i
# of the model: classify_image() maps argmax(prediction) -> CLASS_LABELS[i].
# The order presumably mirrors the class order used at training time — never
# reorder or insert entries, only fix spelling.
CLASS_LABELS = [
    'Fourmis', 'Abeilles', 'Scarabe', 'Chenille', 'Ver de terre',
    'Perce-oreille', 'Criquet', 'Papillon de nuit', 'Limace', 'Escargot',
    'Guêpes', 'Charançon',
]
def preprocess_image(frame):
    """Turn one image into a model-ready batch of size 1.

    Args:
        frame: a PIL image in any mode, or a numpy array shaped (H, W),
            (H, W, 1), (H, W, 3) or (H, W, 4). Arrays are assumed to be in
            RGB(A) channel order; callers holding OpenCV BGR frames must
            convert them first.

    Returns:
        A float array of shape (1, H_model, W_model, 3), where H_model and
        W_model come from ``input_shape``, passed through the EfficientNetV2
        ``preprocess_input``.

    Raises:
        ValueError: if the array has an unsupported number of dimensions or
            channels.
    """
    if isinstance(frame, Image.Image):
        # PNG uploads are often RGBA or palette ('P') images and photos can
        # be greyscale ('L'); the model only accepts 3 channels.
        frame = np.asarray(frame.convert("RGB"))
    else:
        frame = np.asarray(frame)

    if frame.dtype != np.uint8:
        frame = frame.astype(np.uint8)

    if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
        frame_rgb = cv2.cvtColor(frame.reshape(frame.shape[:2]), cv2.COLOR_GRAY2RGB)
    elif frame.ndim == 3 and frame.shape[2] == 4:
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
    elif frame.ndim == 3 and frame.shape[2] == 3:
        frame_rgb = frame
    else:
        raise ValueError(f"Unsupported image shape: {frame.shape}")

    # cv2.resize takes (width, height); input_shape is (height, width, channels).
    frame_resized = cv2.resize(frame_rgb, (input_shape[1], input_shape[0]))
    img_array = img_to_array(frame_resized)
    img_array = preprocess_input(img_array)
    return np.expand_dims(img_array, axis=0)
def send_email(predicted_class, confidence):
    """E-mail a pest alert for ``predicted_class`` detected with ``confidence``.

    SMTP settings come from the environment rather than the source code:

    * ``SMTP_SENDER``   – sender address, also used as the SMTP login
    * ``SMTP_PASSWORD`` – SMTP password / app password
    * ``ALERT_RECIPIENT`` – address that receives the alert
    * ``SMTP_HOST`` (default ``smtp.gmail.com``), ``SMTP_PORT`` (default 587)

    If the sender, password or recipient is not set, the alert is skipped
    with a printed message. Delivery errors are printed, never raised, so a
    failed alert cannot break classification.

    Args:
        predicted_class: class label shown in the message body.
        confidence: probability in [0, 1], shown as a percentage.
    """
    # SECURITY: an SMTP password used to be hard-coded on this line and was
    # therefore published with the source code; it should be revoked.
    from_address = os.environ.get("SMTP_SENDER")
    password = os.environ.get("SMTP_PASSWORD")
    to_address = os.environ.get("ALERT_RECIPIENT")
    if not (from_address and password and to_address):
        print("Alerte e-mail ignorée: définissez SMTP_SENDER, SMTP_PASSWORD "
              "et ALERT_RECIPIENT pour activer les notifications.")
        return

    msg = MIMEMultipart()
    msg['From'] = from_address
    msg['To'] = to_address
    msg['Subject'] = "Alerte peste Détectée!"
    body = (
        "\n"
        f"Vous recevez ce message car nous sommes à {confidence*100:.2f}% certain "
        f"que votre serre est infestée de {predicted_class}\n"
        "Veillez à prendre les mesures qu'il faut pour pallier à cette situation\n"
    )
    msg.attach(MIMEText(body, 'plain', 'utf-8'))

    host = os.environ.get("SMTP_HOST", "smtp.gmail.com")
    try:
        port = int(os.environ.get("SMTP_PORT", "587"))
        # The context manager closes the connection even if starttls/login
        # fails; the timeout stops a dead server from hanging the caller.
        with smtplib.SMTP(host, port, timeout=30) as server:
            server.starttls()
            server.login(from_address, password)
            server.send_message(msg)
    except Exception as e:  # best-effort notification: log, never propagate
        print(f"Erreur d'envoi d'email: {e}")
# Minimum softmax probability that triggers an e-mail alert.
ALERT_CONFIDENCE_THRESHOLD = 0.85
# Minimum delay (seconds) between two alerts for the same class, so a video
# analysed frame by frame does not send one e-mail per sampled frame.
ALERT_COOLDOWN_SECONDS = 300

# label -> time.monotonic() of the last alert, guarded by _alert_lock
# (several Gradio event handlers may run at the same time).
_last_alert_time = {}
_alert_lock = threading.Lock()


def _ascii_for_overlay(text):
    """Strip accents from ``text``; OpenCV's Hershey fonts only draw ASCII."""
    return unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")


def _maybe_send_alert(predicted_label, confidence):
    """Start a background e-mail alert unless this label was alerted recently.

    Returns True if an alert thread was started, False if suppressed by the
    cooldown.
    """
    now = time.monotonic()
    with _alert_lock:
        last = _last_alert_time.get(predicted_label)
        if last is not None and now - last < ALERT_COOLDOWN_SECONDS:
            return False
        _last_alert_time[predicted_label] = now
    # SMTP can take seconds; don't block the request/video loop on it.
    threading.Thread(
        target=send_email, args=(predicted_label, confidence), daemon=True
    ).start()
    return True


def classify_image(frame):
    """Classify one image, draw the prediction on a copy and raise an alert if confident.

    Args:
        frame: PIL image or numpy array in RGB order (see preprocess_image).

    Returns:
        ``(annotated_frame, predicted_label, confidence_text)`` where
        ``annotated_frame`` is a numpy copy of the input with the prediction
        written in its top-left corner, and ``confidence_text`` looks like
        ``"91.23%"``.
    """
    # Work on a copy so the caller's image is never modified by putText.
    annotated_frame = frame.copy() if isinstance(frame, np.ndarray) else np.array(frame)

    processed_frame = preprocess_image(frame)
    predictions = model.predict(processed_frame, verbose=0)
    predicted_class_idx = int(np.argmax(predictions[0]))
    confidence = float(predictions[0][predicted_class_idx])
    predicted_label = CLASS_LABELS[predicted_class_idx]

    # Hershey fonts replace non-ASCII glyphs with '?', so "Guêpes" would be
    # drawn as "Gu?pes"; draw an accent-free version instead.
    label_text = _ascii_for_overlay(f"Classe: {predicted_label} ({confidence*100:.1f}%)")
    cv2.putText(annotated_frame, label_text, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

    if confidence > ALERT_CONFIDENCE_THRESHOLD:
        _maybe_send_alert(predicted_label, confidence)

    return annotated_frame, predicted_label, f"{confidence*100:.2f}%"
def predict_image(image):
    """Gradio handler for the image-upload tab.

    Normalises the input to a PIL image and delegates to ``classify_image``.

    Returns:
        A 3-tuple ``(annotated_image, label, confidence)``. For a missing or
        unsupported input, or when any exception occurs, the first element
        is None, the second carries a message and the third is "N/A"; the
        error is reported in the UI instead of being raised.
    """
    try:
        print(f"[DEBUG] predict_image called with image type: {type(image)}")
        if image is None:
            print("[ERROR] No image provided")
            return None, "No image provided", "N/A"

        # Gradio's type="pil" should already deliver a PIL image; numpy
        # arrays are still accepted defensively, anything else is rejected.
        if not isinstance(image, Image.Image):
            received = type(image)
            print(f"[DEBUG] Converting {received} to PIL Image...")
            if not isinstance(image, np.ndarray):
                print(f"[ERROR] Unexpected image type: {received}")
                return None, f"Unexpected type: {received}", "N/A"
            image = Image.fromarray(image)

        print(f"[DEBUG] Image size: {image.size}, mode: {image.mode}")
        result = classify_image(image)
        annotated, label, conf = result
        print(f"[DEBUG] Classification successful: {label} ({conf})")
        return annotated, label, conf
    except Exception as exc:
        print(f"[ERROR] Exception in predict_image: {exc}")
        import traceback
        traceback.print_exc()
        return None, f"Error: {exc}", "N/A"
def predict_video(video_path, frame_stride=5):
    """Classify a video and return the path of an annotated copy.

    One frame out of every ``frame_stride`` is run through
    ``classify_image`` and written with the prediction drawn on it; the
    other frames are copied unchanged.

    Args:
        video_path: path of the uploaded video, or None.
        frame_stride: classify one frame out of this many (values < 1 are
            treated as 1). Defaults to 5, the previously hard-coded value.

    Returns:
        Path of a temporary .mp4 file (not deleted here), or None if no
        video was given or it could not be opened.
    """
    if video_path is None:
        return None
    frame_stride = max(1, int(frame_stride))

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"[ERROR] Could not open video: {video_path}")
        cap.release()
        return None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates (e.g. 29.97) and fall back when the container
    # reports none (0 or NaN), which would otherwise break the writer.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 25.0

    # mktemp() is deprecated and race-prone; create the file atomically.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        output_path = tmp.name
    # NOTE(review): 'mp4v' output is often not playable in browsers; confirm
    # Gradio re-encodes it, or switch to an H.264 fourcc if available.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_stride == 0:
                # OpenCV decodes frames as BGR, while the image/webcam paths
                # feed the model RGB (PIL): convert for the model, then back
                # to BGR for the writer.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                annotated_rgb, _, _ = classify_image(frame_rgb)
                frame = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
            out.write(frame)
            frame_count += 1
    finally:
        # Release both handles even if classification raises mid-video.
        cap.release()
        out.release()
    return output_path
def predict_webcam(image):
    """Gradio handler for the webcam tab.

    The webcam component is configured with ``type="pil"`` like the upload
    component, so the snapshot is handed to ``predict_image`` unchanged and
    its 3-tuple result is returned as is.
    """
    outputs = predict_image(image)
    return outputs
# Create Gradio Interface: three tabs (image upload, video upload, webcam)
# wired to predict_image / predict_video / predict_webcam.
with gr.Blocks(title="SerraSafe - Détection de Pestes") as demo:
    # The class list is generated from CLASS_LABELS so the page always shows
    # exactly the labels the model can predict (no duplicated hand-written list).
    gr.Markdown(
        "# 🌱 SerraSafe Guardian - Système de Détection de Pestes\n\n"
        "Ce système utilise l'intelligence artificielle pour détecter et "
        "classifier les pestes dans votre serre.\n\n"
        f"**Classes détectées:** {', '.join(CLASS_LABELS)}"
    )

    # Tab 1: one uploaded image -> annotated image, label, confidence.
    with gr.Tab("📷 Télécharger Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Télécharger une image")
                image_button = gr.Button("Analyser l'image", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Résultat")
                image_label = gr.Textbox(label="Classe détectée")
                image_confidence = gr.Textbox(label="Confiance")
        image_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[image_output, image_label, image_confidence],
        )

    # Tab 2: uploaded video -> annotated video file.
    with gr.Tab("📹 Télécharger Vidéo"):
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Télécharger une vidéo")
                video_button = gr.Button("Analyser la vidéo", variant="primary")
            with gr.Column():
                video_output = gr.Video(label="Vidéo annotée")
        video_button.click(
            fn=predict_video,
            inputs=video_input,
            outputs=video_output,
        )

    # Tab 3: single webcam snapshot (not streaming) -> same outputs as tab 1.
    with gr.Tab("🎥 Webcam en Direct"):
        gr.Markdown("### Capturez une image avec votre webcam pour l'analyser")
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(sources=["webcam"], type="pil", label="Webcam", streaming=False)
                webcam_button = gr.Button("Analyser", variant="primary")
            with gr.Column():
                webcam_output = gr.Image(label="Résultat")
                webcam_label = gr.Textbox(label="Classe détectée")
                webcam_confidence = gr.Textbox(label="Confiance")
        webcam_button.click(
            fn=predict_webcam,
            inputs=webcam_input,
            outputs=[webcam_output, webcam_label, webcam_confidence],
        )
if __name__ == "__main__":
    print("Lancement de l'interface SerraSafe...")
    # Bind on all interfaces at port 7860 (Gradio's default), without a
    # public share link, and show handler errors in the UI.
    launch_options = dict(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    demo.launch(**launch_options)