File size: 5,726 Bytes
505fc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
Utility functions for the Plant Disease Detection UI
"""

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import config


def preprocess_image(image, image_size=config.IMAGE_SIZE):
    """
    Preprocess an image into a normalized, batched tensor for model input.

    Args:
        image: PIL Image or numpy array. Numpy arrays are cast to uint8
            before conversion, so they are expected to already hold 0-255
            pixel values.
        image_size: Target size passed to ``transforms.Resize``. A
            (height, width) pair resizes to exactly that size; a single int
            only matches the shorter edge (torchvision semantics), so the
            output may not be square in that case.

    Returns:
        Float tensor of shape (1, 3, H, W) ready for the model.
    """
    # Convert to PIL Image if numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))

    # The model input is 3-channel RGB, so every other mode must be converted.
    # Previously only RGBA was handled. Grayscale ('L', e.g. from a 2-D numpy
    # array), palette ('P'), 'LA' and 'CMYK' images slipped through with the
    # wrong channel count and broke Normalize / the model.
    # NOTE(review): assumes config.NORMALIZE_MEAN/STD are 3-element RGB stats.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Define preprocessing transforms
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.NORMALIZE_MEAN, std=config.NORMALIZE_STD)
    ])

    # Apply transforms -> (3, H, W), then add a batch dimension -> (1, 3, H, W)
    tensor = transform(image)
    tensor = tensor.unsqueeze(0)

    return tensor


def postprocess_predictions(logits, class_names=config.CLASS_NAMES, top_k=config.TOP_K_PREDICTIONS):
    """
    Convert model logits to human-readable predictions.

    Args:
        logits: Raw model output tensor. Either batched, shape
            (batch, num_classes), in which case only the first row is used,
            or a single unbatched vector of shape (num_classes,).
        class_names: Class names in the same order as the logit columns.
        top_k: Number of highest-probability predictions to keep in the
            first returned dictionary.

    Returns:
        Tuple ``(top_predictions, predictions)``:
            top_predictions: dict of the ``top_k`` most likely class names to
                their softmax probabilities, in descending probability order.
            predictions: dict mapping every class name to its probability.
        Names and probabilities are paired with ``zip``, so if the lengths
        differ only the overlapping prefix is kept.
    """
    # Accept an unbatched 1-D logit vector: give it a batch dimension so the
    # softmax over dim=1 still runs across the classes (previously IndexError).
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    # Convert logits to probabilities using softmax over the class dimension
    probs = torch.nn.functional.softmax(logits, dim=1)

    # Drop autograd tracking, move to CPU/numpy, keep only the first sample
    probs = probs.detach().cpu().numpy()[0]

    # Create predictions dictionary
    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}

    # Get top-k predictions (dict preserves this descending order)
    top_predictions = sorted(predictions.items(), key=lambda item: item[1], reverse=True)[:top_k]

    return dict(top_predictions), predictions


def format_prediction_for_display(predictions, *, threshold=None):
    """
    Format predictions for the Gradio Label component.

    Args:
        predictions: Dictionary of class names and probabilities.
        threshold: Minimum probability an entry needs to be kept. When None
            (the default), ``config.CONFIDENCE_THRESHOLD`` is used, which
            matches the previous behaviour.

    Returns:
        New dictionary holding only the entries whose probability is
        >= threshold. It may be empty if every prediction falls below it.
    """
    # Read the config value lazily so callers that pass an explicit
    # threshold do not depend on it.
    if threshold is None:
        threshold = config.CONFIDENCE_THRESHOLD

    # Filter out very low confidence predictions
    return {k: v for k, v in predictions.items() if v >= threshold}


def format_class_name(class_name):
    """
    Format a class name for better readability.

    Names of the form "<plant>___<condition>" become "<plant> - <condition>".
    Anything else just has its underscores turned into spaces. In both cases
    runs of whitespace are collapsed and leading/trailing whitespace is
    dropped, so a trailing underscore (e.g. "Corn_(maize)___Common_rust_")
    no longer leaves a dangling space.

    Args:
        class_name: Original class name (e.g., "Tomato___Late_blight")

    Returns:
        Formatted class name (e.g., "Tomato - Late blight")
    """
    def _tidy(text):
        # Underscores separate words; normalise the whitespace they leave.
        return " ".join(text.replace("_", " ").split())

    # "___" separates the plant from the condition
    parts = class_name.split("___")

    if len(parts) == 2:
        plant, disease = parts
        return f"{_tidy(plant)} - {_tidy(disease)}"
    return _tidy(class_name)


def get_disease_info(class_name):
    """
    Get information about a disease (placeholder for future enhancement).

    Args:
        class_name: Disease class name, expected as "<plant>___<condition>".

    Returns:
        Dictionary with keys:
            "plant": text before the first "___", with underscores turned into
                spaces and whitespace trimmed/collapsed ("Unknown" if empty).
            "disease": text between the first and second "___", cleaned the
                same way ("Unknown" if there is no "___" or it is empty).
            "is_healthy": True if "healthy" occurs anywhere in the name
                (case-insensitive substring match).
            "formatted_name": result of format_class_name(class_name).
    """
    # This is a placeholder - you could expand this with actual disease information
    parts = class_name.split("___")

    def _clean(text):
        # Underscores separate words; trim so a trailing "_" leaves no space.
        return " ".join(text.replace("_", " ").split())

    # str.split always returns at least one element, so parts[0] always
    # exists. The old `len(parts) > 0` guard was dead and "Unknown" was
    # unreachable, so fall back on an empty value instead.
    plant = _clean(parts[0]) or "Unknown"
    disease = (_clean(parts[1]) if len(parts) > 1 else "") or "Unknown"

    info = {
        "plant": plant,
        "disease": disease,
        # NOTE(review): substring match would also flag e.g. "unhealthy".
        "is_healthy": "healthy" in class_name.lower(),
        "formatted_name": format_class_name(class_name)
    }

    return info


def batch_preprocess_images(images):
    """
    Preprocess several images and join them into a single batch tensor.

    Every image goes through preprocess_image with its default settings,
    which yields a (1, C, H, W) tensor. The results are concatenated along
    dimension 0.

    Args:
        images: Iterable of PIL Images or numpy arrays.

    Returns:
        Tensor whose first dimension equals the number of images.
        torch.cat raises if ``images`` is empty or if the per-image
        spatial sizes differ.
    """
    per_image_tensors = list(map(preprocess_image, images))
    return torch.cat(per_image_tensors, dim=0)


def create_confidence_label(predictions, top_k=5):
    """
    Render the most likely predictions as a numbered, human-readable list.

    Args:
        predictions: Dictionary mapping class names to probabilities
            (rendered as percentages).
        top_k: How many entries to include, most likely first.

    Returns:
        Newline-separated lines of the form
        "<rank>. <formatted class name>: <probability*100 to 2 dp>%".
        An empty dictionary gives "".
    """
    ranked = sorted(predictions.items(), key=lambda pair: pair[1], reverse=True)
    return "\n".join(
        f"{rank}. {format_class_name(name)}: {score*100:.2f}%"
        for rank, (name, score) in enumerate(ranked[:top_k], start=1)
    )


if __name__ == "__main__":
    # Smoke-test the helpers in this module when it is run directly.
    print("Testing utility functions...")

    sample_names = (
        "Tomato___Late_blight",
        "Apple___healthy",
        "Corn_(maize)___Common_rust_",
    )

    # Class name formatting, one line per sample
    print("\nClass name formatting:")
    print("\n".join(f"  {raw} -> {format_class_name(raw)}" for raw in sample_names))

    # Parsed disease info, a four-line record per sample
    print("\nDisease info:")
    for raw in sample_names:
        details = get_disease_info(raw)
        print(
            f"  {raw}:\n"
            f"    Plant: {details['plant']}\n"
            f"    Disease: {details['disease']}\n"
            f"    Healthy: {details['is_healthy']}"
        )

    # Preprocessing of a synthetic solid-colour image
    print("\nImage preprocessing:")
    red_square = Image.new('RGB', (512, 512), color='red')
    batch_tensor = preprocess_image(red_square)
    print(f"  Input size: {red_square.size}")
    print(f"  Output tensor shape: {batch_tensor.shape}")

    # End-to-end formatting of mock model output
    print("\nMock predictions:")
    from models.mock_model import create_mock_predictions
    mock_scores = create_mock_predictions(config.CLASS_NAMES)
    # NOTE(review): these values are passed to postprocess_predictions as
    # *logits* and softmaxed there. If create_mock_predictions already returns
    # probabilities (its name suggests so; confirm), the printed percentages
    # are a flattened re-softmax rather than the mock distribution itself.
    score_row = torch.tensor([list(mock_scores.values())])
    top_scores, _ = postprocess_predictions(score_row, config.CLASS_NAMES)
    print(create_confidence_label(top_scores))