Spaces:
Sleeping
Sleeping
| """ | |
| Memory-optimized explanation module for crop disease detection | |
| Lightweight implementation that uses minimal RAM | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import os | |
| import gc | |
| class CropDiseaseExplainerLite: | |
| """Memory-optimized explainer for crop disease detection""" | |
| def __init__(self, model, class_names, device='cpu'): | |
| """ | |
| Initialize lite explainer | |
| Args: | |
| model: Trained model | |
| class_names: List of class names | |
| device: Device to run on | |
| """ | |
| self.model = model | |
| self.class_names = class_names | |
| self.device = device | |
| self.enabled = True | |
| # Disable explanation if memory is critically low | |
| try: | |
| import psutil | |
| memory_percent = psutil.virtual_memory().percent | |
| if memory_percent > 80: # If system memory usage > 80% | |
| self.enabled = False | |
| print("⚠️ Explanations disabled due to low system memory") | |
| except ImportError: | |
| pass | |
| def generate_explanation_lite(self, image_bytes, predicted_class, max_size=112): | |
| """ | |
| Generate lightweight explanation with minimal memory usage | |
| Args: | |
| image_bytes: Raw image bytes | |
| predicted_class: Predicted disease class | |
| max_size: Maximum image size for processing (smaller = less memory) | |
| Returns: | |
| Dictionary with explanation data | |
| """ | |
| if not self.enabled: | |
| return {"error": "Explanations disabled due to memory constraints"} | |
| try: | |
| # Process image with aggressive memory optimization | |
| image = Image.open(io.BytesIO(image_bytes)) | |
| # Convert and resize aggressively to save memory | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| # Use very small size for explanation to save memory | |
| image = image.resize((max_size, max_size), Image.Resampling.LANCZOS) | |
| # Convert to numpy for simple processing | |
| img_array = np.array(image) | |
| # Generate simple attention map using basic image processing | |
| attention_map = self._generate_simple_attention(img_array) | |
| # Create explanation visualization | |
| explanation_image = self._create_explanation_visualization( | |
| img_array, attention_map, predicted_class | |
| ) | |
| # Convert to base64 with compression | |
| explanation_base64 = self._image_to_base64(explanation_image, quality=60) | |
| # Clear memory | |
| del image, img_array, attention_map, explanation_image | |
| gc.collect() | |
| return { | |
| "explanation_image": explanation_base64, | |
| "method": "simple_attention", | |
| "confidence": "High confidence regions highlighted", | |
| "size": f"{max_size}x{max_size}", | |
| "memory_optimized": True | |
| } | |
| except Exception as e: | |
| print(f"Lite explanation generation failed: {e}") | |
| return {"error": f"Explanation generation failed: {str(e)}"} | |
| def _generate_simple_attention(self, img_array): | |
| """ | |
| Generate simple attention map using image processing techniques | |
| This is much more memory-efficient than Grad-CAM | |
| """ | |
| try: | |
| # Convert to grayscale for processing | |
| gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) | |
| # Apply Gaussian blur to reduce noise | |
| blurred = cv2.GaussianBlur(gray, (5, 5), 0) | |
| # Detect edges (diseased areas often have different textures) | |
| edges = cv2.Canny(blurred, 50, 150) | |
| # Dilate edges to create attention regions | |
| kernel = np.ones((3, 3), np.uint8) | |
| attention = cv2.dilate(edges, kernel, iterations=1) | |
| # Apply Gaussian blur to smooth the attention map | |
| attention = cv2.GaussianBlur(attention.astype(np.float32), (11, 11), 0) | |
| # Normalize to 0-1 range | |
| if attention.max() > 0: | |
| attention = attention / attention.max() | |
| return attention | |
| except Exception as e: | |
| print(f"Simple attention generation failed: {e}") | |
| # Return uniform attention map as fallback | |
| return np.ones((img_array.shape[0], img_array.shape[1]), dtype=np.float32) * 0.5 | |
| def _create_explanation_visualization(self, img_array, attention_map, predicted_class): | |
| """Create explanation visualization with minimal memory usage""" | |
| try: | |
| # Create colored attention map | |
| colored_attention = cv2.applyColorMap( | |
| (attention_map * 255).astype(np.uint8), | |
| cv2.COLORMAP_JET | |
| ) | |
| # Convert colormap from BGR to RGB | |
| colored_attention = cv2.cvtColor(colored_attention, cv2.COLOR_BGR2RGB) | |
| # Blend with original image | |
| alpha = 0.4 | |
| blended = cv2.addWeighted( | |
| img_array, 1 - alpha, | |
| colored_attention, alpha, | |
| 0 | |
| ) | |
| # Add simple text overlay (predicted class) | |
| try: | |
| # Use PIL for text overlay (more memory efficient than cv2) | |
| pil_image = Image.fromarray(blended.astype(np.uint8)) | |
| # Note: For production, you might want to add text overlay here | |
| # For now, return the blended image to save memory | |
| return pil_image | |
| except Exception: | |
| # If text overlay fails, return blended image | |
| return Image.fromarray(blended.astype(np.uint8)) | |
| except Exception as e: | |
| print(f"Visualization creation failed: {e}") | |
| # Return original image as fallback | |
| return Image.fromarray(img_array) | |
| def _image_to_base64(self, pil_image, quality=60): | |
| """Convert PIL image to base64 with compression""" | |
| try: | |
| buffer = io.BytesIO() | |
| # Save as JPEG with compression to reduce size | |
| pil_image.save(buffer, format='JPEG', quality=quality, optimize=True) | |
| img_str = base64.b64encode(buffer.getvalue()).decode() | |
| # Clear buffer | |
| buffer.close() | |
| return f"data:image/jpeg;base64,{img_str}" | |
| except Exception as e: | |
| print(f"Base64 conversion failed: {e}") | |
| return "" | |
| def get_simple_explanation(self, predicted_class): | |
| """Get simple text explanation without image processing""" | |
| explanations = { | |
| # Tomato diseases | |
| "Tomato_healthy": "Healthy tomato leaves detected. No disease symptoms visible.", | |
| "Tomato_Late_blight": "Late blight detected. Look for dark spots with white fungal growth.", | |
| "Tomato_Early_blight": "Early blight detected. Characteristic dark spots with concentric rings.", | |
| "Tomato_Leaf_Mold": "Leaf mold detected. Yellow spots on top, grayish mold underneath.", | |
| "Tomato_Bacterial_spot": "Bacterial spot detected. Small dark spots with yellow halos.", | |
| "Tomato_Target_Spot": "Target spot detected. Circular spots with concentric rings.", | |
| "Tomato_mosaic_virus": "Mosaic virus detected. Mottled light and dark green patterns.", | |
| "Tomato_Yellow_Leaf_Curl": "Yellow leaf curl virus detected. Yellowing and curling leaves.", | |
| "Tomato_Septoria_leaf_spot": "Septoria leaf spot detected. Small circular spots with dark borders.", | |
| "Tomato_Spider_mites": "Spider mite damage detected. Fine webbing and stippled leaves.", | |
| # Potato diseases | |
| "Potato_healthy": "Healthy potato leaves detected. No disease symptoms visible.", | |
| "Potato_Late_blight": "Late blight detected. Dark water-soaked spots spreading rapidly.", | |
| "Potato_Early_blight": "Early blight detected. Dark brown spots with concentric rings.", | |
| # Pepper diseases | |
| "Pepper_healthy": "Healthy pepper leaves detected. No disease symptoms visible.", | |
| "Pepper_Bacterial_spot": "Bacterial spot detected. Small raised spots on leaves.", | |
| } | |
| return explanations.get(predicted_class, "Disease detected. Consult agricultural expert for detailed diagnosis.") | |
| def test_memory_usage(): | |
| """Test memory usage of lite explainer""" | |
| import psutil | |
| import os | |
| process = psutil.Process(os.getpid()) | |
| initial_memory = process.memory_info().rss / 1024 / 1024 | |
| print(f"Initial memory: {initial_memory:.2f} MB") | |
| # Create dummy model and explainer | |
| class DummyModel: | |
| pass | |
| model = DummyModel() | |
| class_names = ["healthy", "disease1", "disease2"] | |
| explainer = CropDiseaseExplainerLite(model, class_names, 'cpu') | |
| final_memory = process.memory_info().rss / 1024 / 1024 | |
| print(f"Memory after explainer creation: {final_memory:.2f} MB") | |
| print(f"Memory increase: {final_memory - initial_memory:.2f} MB") | |
| if __name__ == "__main__": | |
| test_memory_usage() |