Mert Yerlikaya
Add feature-rich Gradio UI with mock model
505fc99
raw
history blame
5.73 kB
"""
Utility functions for the Plant Disease Detection UI
"""
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import config
def preprocess_image(image, image_size=config.IMAGE_SIZE):
    """
    Preprocess an image into a normalized, batched tensor for model input.

    Args:
        image: PIL Image or numpy array. Numpy input is cast to uint8 before
            conversion, so it is expected to already hold 0-255 pixel values.
        image_size: Size passed to ``transforms.Resize``. A (height, width)
            pair resizes to exactly that size; a single int resizes the
            shorter edge to that value and keeps the aspect ratio
            (torchvision semantics).

    Returns:
        Float tensor of shape (1, 3, H, W), normalized with
        ``config.NORMALIZE_MEAN`` / ``config.NORMALIZE_STD``.
    """
    # Convert to PIL Image if numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))

    # Force 3-channel RGB for every colour mode. Grayscale ('L'), palette
    # ('P'), 'LA' or 'CMYK' images would otherwise give ToTensor 1, 2 or 4
    # channels, which cannot match (presumably per-RGB-channel) mean/std
    # values or the model's expected input.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.NORMALIZE_MEAN, std=config.NORMALIZE_STD),
    ])

    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    return transform(image).unsqueeze(0)
def postprocess_predictions(logits, class_names=config.CLASS_NAMES, top_k=config.TOP_K_PREDICTIONS):
    """
    Convert raw model logits into human-readable class probabilities.

    Args:
        logits: Raw model output tensor of shape (batch, num_classes) or
            (num_classes,). Only the first sample of a batch is decoded.
        class_names: Class names, ordered to match the logit columns.
        top_k: Number of highest-probability classes to keep in the first
            returned dict.

    Returns:
        Tuple ``(top_predictions, predictions)``:
            top_predictions: dict of the ``top_k`` most likely class names to
                their probabilities, ordered from highest to lowest.
            predictions: dict mapping every class name to its probability.

    Raises:
        ValueError: If the number of class names differs from the number of
            logit columns.
    """
    # Treat an un-batched logit vector as a batch of one so dim=1 is valid.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    class_names = list(class_names)
    num_classes = logits.shape[1]
    # zip() would silently truncate on a mismatch and drop classes unnoticed.
    if len(class_names) != num_classes:
        raise ValueError(
            f"Got {len(class_names)} class names for {num_classes} logit columns"
        )

    # Softmax over the class dimension, then take the first sample as numpy.
    probs = torch.nn.functional.softmax(logits.detach(), dim=1)[0].cpu().numpy()

    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
    top_predictions = sorted(
        predictions.items(), key=lambda item: item[1], reverse=True
    )[:top_k]
    return dict(top_predictions), predictions
def format_prediction_for_display(predictions, threshold=None):
    """
    Drop low-confidence entries so only meaningful predictions are displayed.

    Args:
        predictions: Mapping of class name -> probability.
        threshold: Minimum probability (inclusive) for an entry to be kept.
            When omitted, ``config.CONFIDENCE_THRESHOLD`` is used.

    Returns:
        New dict with only the entries whose probability is >= threshold,
        in the original iteration order. It is suitable for a Gradio Label
        component.
    """
    # Read the config value lazily, at call time, as the original did.
    if threshold is None:
        threshold = config.CONFIDENCE_THRESHOLD
    return {name: prob for name, prob in predictions.items() if prob >= threshold}
def format_class_name(class_name):
    """
    Turn a raw class name into a human-readable label.

    Args:
        class_name: Original class name (e.g., "Tomato___Late_blight").

    Returns:
        "Plant - Disease" when the name contains exactly one "___" separator
        (e.g., "Tomato - Late blight"). Otherwise, the name with underscores
        replaced by spaces. Whitespace left behind by leading or trailing
        underscores is stripped, e.g. "Corn_(maize)___Common_rust_" becomes
        "Corn (maize) - Common rust".
    """
    # Split on the "___" plant/disease separator first, then replace the
    # remaining underscores with spaces in each part.
    parts = class_name.split("___")
    if len(parts) == 2:
        plant, disease = (part.replace("_", " ").strip() for part in parts)
        return f"{plant} - {disease}"
    return class_name.replace("_", " ").strip()
def get_disease_info(class_name):
    """
    Split a class name into display-friendly plant/disease fields.

    This is a placeholder that could later be expanded with real disease
    information.

    Args:
        class_name: Class name of the form "<Plant>___<Disease>", e.g.
            "Tomato___Late_blight".

    Returns:
        Dict with keys:
            "plant": first "___"-separated part with underscores turned into
                spaces and surrounding whitespace stripped, or "Unknown" if empty.
            "disease": second part treated the same way, or "Unknown" if it
                is missing or empty.
            "is_healthy": True if "healthy" occurs anywhere in the lower-cased
                class name (plain substring match).
            "formatted_name": result of ``format_class_name(class_name)``.
    """
    parts = class_name.split("___")
    # str.split always yields at least one element, so the plant part always
    # exists. Fall back to "Unknown" when it is empty instead.
    plant = parts[0].replace("_", " ").strip() or "Unknown"
    disease = parts[1].replace("_", " ").strip() if len(parts) > 1 else ""
    return {
        "plant": plant,
        "disease": disease or "Unknown",
        "is_healthy": "healthy" in class_name.lower(),
        "formatted_name": format_class_name(class_name),
    }
def batch_preprocess_images(images, image_size=config.IMAGE_SIZE):
    """
    Preprocess several images and concatenate them into one batch tensor.

    Args:
        images: Iterable of PIL Images or numpy arrays.
        image_size: Forwarded to ``preprocess_image``. Defaults to
            ``config.IMAGE_SIZE``, matching the previous behaviour.

    Returns:
        Tensor formed by concatenating each image's batched (1, C, H, W)
        tensor along dim 0. Every image must come out at the same spatial
        size for the concatenation to succeed, e.g. when ``image_size`` is a
        (height, width) pair.

    Raises:
        ValueError: If ``images`` is empty.
    """
    tensors = [preprocess_image(img, image_size=image_size) for img in images]
    if not tensors:
        # torch.cat([]) fails with an opaque RuntimeError; report the real cause.
        raise ValueError("batch_preprocess_images() requires at least one image")
    return torch.cat(tensors, dim=0)
def create_confidence_label(predictions, top_k=5):
    """
    Render the most confident predictions as a numbered, multi-line string.

    Args:
        predictions: Mapping of class name -> probability in [0, 1].
        top_k: Maximum number of entries to include.

    Returns:
        One line per prediction, highest probability first, formatted as
        "<rank>. <readable class name>: <percentage with 2 decimals>%".
    """
    ranked = sorted(predictions.items(), key=lambda entry: entry[1], reverse=True)
    return "\n".join(
        f"{rank}. {format_class_name(label)}: {confidence*100:.2f}%"
        for rank, (label, confidence) in enumerate(ranked[:top_k], start=1)
    )
if __name__ == "__main__":
    # Manual smoke test for the helpers defined in this module.
    print("Testing utility functions...")

    sample_names = [
        "Tomato___Late_blight",
        "Apple___healthy",
        "Corn_(maize)___Common_rust_",
    ]

    # Class name formatting
    print("\nClass name formatting:")
    for sample in sample_names:
        print(f" {sample} -> {format_class_name(sample)}")

    # Disease info
    print("\nDisease info:")
    for sample in sample_names:
        details = get_disease_info(sample)
        print(f" {sample}:")
        print(f" Plant: {details['plant']}")
        print(f" Disease: {details['disease']}")
        print(f" Healthy: {details['is_healthy']}")

    # Image preprocessing on a solid-colour dummy image
    print("\nImage preprocessing:")
    red_square = Image.new('RGB', (512, 512), color='red')
    processed = preprocess_image(red_square)
    print(f" Input size: {red_square.size}")
    print(f" Output tensor shape: {processed.shape}")

    # Mock predictions
    print("\nMock predictions:")
    from models.mock_model import create_mock_predictions

    mock_scores = create_mock_predictions(config.CLASS_NAMES)
    # NOTE(review): these values are passed to postprocess_predictions as
    # logits and softmaxed there. If create_mock_predictions already returns
    # probabilities, the printed confidences will be flattened — confirm.
    # NOTE(review): relies on mock_scores' value order matching
    # config.CLASS_NAMES — verify against models.mock_model.
    score_tensor = torch.tensor([list(mock_scores.values())])
    top_preds, _all_preds = postprocess_predictions(score_tensor, config.CLASS_NAMES)
    print(create_confidence_label(top_preds))