crop_ai_diseases / src /explain_lite.py
vivek12coder's picture
Upload 20960 files
c8df794 verified
"""
Memory-optimized explanation module for crop disease detection
Lightweight implementation that uses minimal RAM
"""
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import io
import base64
import os
import gc
class CropDiseaseExplainerLite:
"""Memory-optimized explainer for crop disease detection"""
def __init__(self, model, class_names, device='cpu'):
"""
Initialize lite explainer
Args:
model: Trained model
class_names: List of class names
device: Device to run on
"""
self.model = model
self.class_names = class_names
self.device = device
self.enabled = True
# Disable explanation if memory is critically low
try:
import psutil
memory_percent = psutil.virtual_memory().percent
if memory_percent > 80: # If system memory usage > 80%
self.enabled = False
print("⚠️ Explanations disabled due to low system memory")
except ImportError:
pass
def generate_explanation_lite(self, image_bytes, predicted_class, max_size=112):
"""
Generate lightweight explanation with minimal memory usage
Args:
image_bytes: Raw image bytes
predicted_class: Predicted disease class
max_size: Maximum image size for processing (smaller = less memory)
Returns:
Dictionary with explanation data
"""
if not self.enabled:
return {"error": "Explanations disabled due to memory constraints"}
try:
# Process image with aggressive memory optimization
image = Image.open(io.BytesIO(image_bytes))
# Convert and resize aggressively to save memory
if image.mode != 'RGB':
image = image.convert('RGB')
# Use very small size for explanation to save memory
image = image.resize((max_size, max_size), Image.Resampling.LANCZOS)
# Convert to numpy for simple processing
img_array = np.array(image)
# Generate simple attention map using basic image processing
attention_map = self._generate_simple_attention(img_array)
# Create explanation visualization
explanation_image = self._create_explanation_visualization(
img_array, attention_map, predicted_class
)
# Convert to base64 with compression
explanation_base64 = self._image_to_base64(explanation_image, quality=60)
# Clear memory
del image, img_array, attention_map, explanation_image
gc.collect()
return {
"explanation_image": explanation_base64,
"method": "simple_attention",
"confidence": "High confidence regions highlighted",
"size": f"{max_size}x{max_size}",
"memory_optimized": True
}
except Exception as e:
print(f"Lite explanation generation failed: {e}")
return {"error": f"Explanation generation failed: {str(e)}"}
def _generate_simple_attention(self, img_array):
"""
Generate simple attention map using image processing techniques
This is much more memory-efficient than Grad-CAM
"""
try:
# Convert to grayscale for processing
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
# Apply Gaussian blur to reduce noise
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
# Detect edges (diseased areas often have different textures)
edges = cv2.Canny(blurred, 50, 150)
# Dilate edges to create attention regions
kernel = np.ones((3, 3), np.uint8)
attention = cv2.dilate(edges, kernel, iterations=1)
# Apply Gaussian blur to smooth the attention map
attention = cv2.GaussianBlur(attention.astype(np.float32), (11, 11), 0)
# Normalize to 0-1 range
if attention.max() > 0:
attention = attention / attention.max()
return attention
except Exception as e:
print(f"Simple attention generation failed: {e}")
# Return uniform attention map as fallback
return np.ones((img_array.shape[0], img_array.shape[1]), dtype=np.float32) * 0.5
def _create_explanation_visualization(self, img_array, attention_map, predicted_class):
"""Create explanation visualization with minimal memory usage"""
try:
# Create colored attention map
colored_attention = cv2.applyColorMap(
(attention_map * 255).astype(np.uint8),
cv2.COLORMAP_JET
)
# Convert colormap from BGR to RGB
colored_attention = cv2.cvtColor(colored_attention, cv2.COLOR_BGR2RGB)
# Blend with original image
alpha = 0.4
blended = cv2.addWeighted(
img_array, 1 - alpha,
colored_attention, alpha,
0
)
# Add simple text overlay (predicted class)
try:
# Use PIL for text overlay (more memory efficient than cv2)
pil_image = Image.fromarray(blended.astype(np.uint8))
# Note: For production, you might want to add text overlay here
# For now, return the blended image to save memory
return pil_image
except Exception:
# If text overlay fails, return blended image
return Image.fromarray(blended.astype(np.uint8))
except Exception as e:
print(f"Visualization creation failed: {e}")
# Return original image as fallback
return Image.fromarray(img_array)
def _image_to_base64(self, pil_image, quality=60):
"""Convert PIL image to base64 with compression"""
try:
buffer = io.BytesIO()
# Save as JPEG with compression to reduce size
pil_image.save(buffer, format='JPEG', quality=quality, optimize=True)
img_str = base64.b64encode(buffer.getvalue()).decode()
# Clear buffer
buffer.close()
return f"data:image/jpeg;base64,{img_str}"
except Exception as e:
print(f"Base64 conversion failed: {e}")
return ""
def get_simple_explanation(self, predicted_class):
"""Get simple text explanation without image processing"""
explanations = {
# Tomato diseases
"Tomato_healthy": "Healthy tomato leaves detected. No disease symptoms visible.",
"Tomato_Late_blight": "Late blight detected. Look for dark spots with white fungal growth.",
"Tomato_Early_blight": "Early blight detected. Characteristic dark spots with concentric rings.",
"Tomato_Leaf_Mold": "Leaf mold detected. Yellow spots on top, grayish mold underneath.",
"Tomato_Bacterial_spot": "Bacterial spot detected. Small dark spots with yellow halos.",
"Tomato_Target_Spot": "Target spot detected. Circular spots with concentric rings.",
"Tomato_mosaic_virus": "Mosaic virus detected. Mottled light and dark green patterns.",
"Tomato_Yellow_Leaf_Curl": "Yellow leaf curl virus detected. Yellowing and curling leaves.",
"Tomato_Septoria_leaf_spot": "Septoria leaf spot detected. Small circular spots with dark borders.",
"Tomato_Spider_mites": "Spider mite damage detected. Fine webbing and stippled leaves.",
# Potato diseases
"Potato_healthy": "Healthy potato leaves detected. No disease symptoms visible.",
"Potato_Late_blight": "Late blight detected. Dark water-soaked spots spreading rapidly.",
"Potato_Early_blight": "Early blight detected. Dark brown spots with concentric rings.",
# Pepper diseases
"Pepper_healthy": "Healthy pepper leaves detected. No disease symptoms visible.",
"Pepper_Bacterial_spot": "Bacterial spot detected. Small raised spots on leaves.",
}
return explanations.get(predicted_class, "Disease detected. Consult agricultural expert for detailed diagnosis.")
def test_memory_usage():
"""Test memory usage of lite explainer"""
import psutil
import os
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / 1024 / 1024
print(f"Initial memory: {initial_memory:.2f} MB")
# Create dummy model and explainer
class DummyModel:
pass
model = DummyModel()
class_names = ["healthy", "disease1", "disease2"]
explainer = CropDiseaseExplainerLite(model, class_names, 'cpu')
final_memory = process.memory_info().rss / 1024 / 1024
print(f"Memory after explainer creation: {final_memory:.2f} MB")
print(f"Memory increase: {final_memory - initial_memory:.2f} MB")
if __name__ == "__main__":
test_memory_usage()