File size: 3,088 Bytes
505fc99
 
 
 
 
 
 
 
27a811d
505fc99
39b825d
505fc99
8bd7feb
 
505fc99
8bd7feb
 
505fc99
8bd7feb
27a811d
 
 
 
 
e346658
 
 
8bd7feb
 
 
505fc99
 
 
 
 
 
 
 
8bd7feb
505fc99
8bd7feb
505fc99
 
 
8bd7feb
505fc99
 
e346658
505fc99
8bd7feb
505fc99
e346658
 
 
505fc99
 
 
 
 
 
 
 
 
8bd7feb
505fc99
8bd7feb
505fc99
8bd7feb
505fc99
 
 
 
8bd7feb
505fc99
 
 
 
 
 
 
 
8bd7feb
 
505fc99
 
 
 
8bd7feb
505fc99
 
 
8bd7feb
505fc99
 
 
 
 
 
 
 
 
8bd7feb
505fc99
 
 
8bd7feb
 
 
 
505fc99
 
 
e346658
 
bcc26c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Utility functions for the Plant Disease Detection UI
"""

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import os

# Target (height, width) that every input image is resized to.
IMAGE_SIZE = (256, 256)

# Per-channel RGB normalization statistics. These are the widely used ImageNet
# mean/std values; presumably the model was trained with them — confirm.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Default number of top predictions to keep/render.
TOP_K_PREDICTIONS = 5
# Predictions below this probability are hidden from the display.
CONFIDENCE_THRESHOLD = 0.01


BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Path to classNames.txt relative to this file
CLASS_NAMES_FILE = os.path.join(BASE_DIR, "classNames.txt")

# One class name per line; blank lines are skipped. Line order must match the
# model's output indices, because postprocess_predictions pairs names with
# probabilities positionally. An explicit encoding keeps loading independent
# of the host locale.
with open(CLASS_NAMES_FILE, "r", encoding="utf-8") as f:
    CLASS_NAMES = [name for name in (line.strip() for line in f) if name]

def preprocess_image(image):
    """
    Preprocess an image into a model-ready batch tensor.

    Args:
        image: A PIL image, or a NumPy array that ``Image.fromarray`` can
            consume after being cast to uint8.

    Returns:
        A float tensor of shape (1, 3, H, W), where (H, W) is IMAGE_SIZE,
        normalized with NORMALIZE_MEAN / NORMALIZE_STD.
    """
    if isinstance(image, np.ndarray):
        # NOTE(review): astype('uint8') assumes pixel values are already in
        # 0..255; float arrays scaled to [0, 1] would collapse to near-black.
        image = Image.fromarray(image.astype('uint8'))

    # Normalize below uses 3-channel statistics, so the image must be RGB.
    # Convert every other mode (RGBA, L, LA, P, CMYK, ...), not just RGBA;
    # otherwise grayscale or palette images fail with a channel mismatch.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD)
    ])

    tensor = transform(image)
    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    return tensor.unsqueeze(0)


def postprocess_predictions(logits, class_names=None, top_k=TOP_K_PREDICTIONS):
    """
    Turn raw model logits into per-class probabilities.

    Args:
        logits: Tensor with classes along dim 1; only the first row of the
            batch (index 0) is used.
        class_names: Labels paired positionally with the logit columns.
            Falls back to CLASS_NAMES when omitted.
        top_k: How many of the highest-probability entries to keep in the
            first returned dict.

    Returns:
        A tuple ``(top, everything)``. ``top`` maps the ``top_k`` most
        probable class names to their probability, highest first.
        ``everything`` maps every paired class name to its probability.
    """
    names = CLASS_NAMES if class_names is None else class_names

    # Softmax across the class axis, then pull the first sample back to NumPy.
    first_row = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()[0]

    all_scores = {}
    for label, score in zip(names, first_row):
        all_scores[label] = float(score)

    ranked = sorted(all_scores.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:top_k]), all_scores


def format_prediction_for_display(predictions, confidence_threshold=CONFIDENCE_THRESHOLD):
    """
    Drop low-confidence entries before handing predictions to Gradio.

    Args:
        predictions: Mapping of class name -> confidence score.
        confidence_threshold: Minimum score (inclusive) an entry needs in
            order to be kept.

    Returns:
        A new dict holding only the entries whose score is
        >= ``confidence_threshold``, in their original order.
    """
    kept = {}
    for label, score in predictions.items():
        if score >= confidence_threshold:
            kept[label] = score
    return kept


def format_class_name(class_name):
    """
    Turn a raw class label into a human-readable string.

    Labels of the form "<plant>___<disease>" (exactly one "___" separator)
    become "<plant> - <disease>". Any other label is returned with
    underscores replaced by spaces.

    Example: "Tomato___Late_blight" -> "Tomato - Late blight"
    """
    pieces = class_name.split("___")
    if len(pieces) != 2:
        # Not a plant/disease pair: just swap underscores for spaces.
        return class_name.replace("_", " ")
    return " - ".join(piece.replace("_", " ") for piece in pieces)


def get_disease_info(class_name):
    """
    Extract structured disease info from a raw class name.

    The name is split on the "___" separator. The first piece is the plant
    and the second piece, if present, is the disease. Underscores become
    spaces.

    Args:
        class_name: Raw class label, e.g. "Tomato___Late_blight".

    Returns:
        dict with keys:
            "plant": plant name, or "Unknown" if that piece is empty.
            "disease": disease name, or "Unknown" if missing or empty.
            "is_healthy": True when "healthy" appears as a whole word.
            "formatted_name": format_class_name(class_name).
    """
    parts = class_name.split("___")

    # str.split always yields at least one element, so parts[0] exists.
    # Fall back to "Unknown" only when a piece is actually empty.
    plant = parts[0].replace("_", " ") or "Unknown"
    disease = (parts[1].replace("_", " ") if len(parts) > 1 else "") or "Unknown"

    # Whole-word match (underscores and spaces act as separators), so labels
    # such as "unhealthy" are not mistaken for healthy ones.
    words = class_name.lower().replace("_", " ").split()

    return {
        "plant": plant,
        "disease": disease,
        "is_healthy": "healthy" in words,
        "formatted_name": format_class_name(class_name)
    }


def create_confidence_label(predictions, top_k=TOP_K_PREDICTIONS):
    """
    Render a numbered, multiline list of the highest-scoring predictions.

    Args:
        predictions: Mapping of raw class name -> probability (presumably in
            [0, 1], since it is rendered as ``prob * 100`` percent).
        top_k: Maximum number of lines to render. Defaults to the shared
            TOP_K_PREDICTIONS constant, which keeps it in step with
            postprocess_predictions.

    Returns:
        Newline-joined lines such as "1. Tomato - Late blight: 97.31%",
        highest probability first. Returns "" when ``predictions`` is empty.
    """
    top_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]

    lines = [
        f"{i}. {format_class_name(name)}: {prob*100:.2f}%"
        for i, (name, prob) in enumerate(top_preds, 1)
    ]
    return "\n".join(lines)


def get_class_names():
    """
    Return the class names loaded from classNames.txt.

    Each call returns a new list. Callers can then sort or modify it
    without corrupting the module-level CLASS_NAMES, whose order
    postprocess_predictions relies on to map logits to labels.
    """
    return list(CLASS_NAMES)