File size: 3,640 Bytes
36dd4e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
Simple CLI for local image prediction (non-Docker)

Usage (PowerShell):
  python -m src.predict_cli -i path\to\image.jpg -m models\crop_disease_v3_model.pth
"""

import argparse
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from .model import CropDiseaseResNet50

# Fallback class names, used when --classes is not given or cannot be read.
# Position i must name output logit i of the model: main() reports
# class_names[argmax of the softmax]. In load_model, a checkpoint that bundles
# its own 'class_names' list of the same length takes precedence over this list.
# NOTE(review): this order is case-insensitive alphabetical. torchvision's
# ImageFolder assigns indices in plain (case-sensitive) sorted() order, which
# differs (e.g. 'Potato___Late_blight' would come before 'Potato___healthy').
# Confirm it matches the label order used at training time, or rely on
# checkpoints that bundle 'class_names'.
DEFAULT_CLASSES = [
    'Pepper__bell___Bacterial_spot',
    'Pepper__bell___healthy',
    'Potato___Early_blight',
    'Potato___healthy',
    'Potato___Late_blight',
    'Tomato__Target_Spot',
    'Tomato__Tomato_mosaic_virus',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_healthy',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite'
]


def load_model(model_path: Path, class_names: list[str]) -> tuple[torch.nn.Module, torch.device, list[str]]:
    """Build the classifier, load its weights from *model_path*, and move it to a device.

    Args:
        model_path: Checkpoint file. Either a dict holding ``'model_state_dict'``
            (and optionally ``'class_names'``) or a bare state dict. If the file
            does not exist, a warning is printed to stderr and the model keeps
            its untrained weights.
        class_names: Class names; their count sizes the classifier head.

    Returns:
        ``(model, device, class_names)``. The model is in eval mode on CUDA when
        available, otherwise on CPU. ``class_names`` is replaced by the list
        bundled in the checkpoint when that list has the same length as the one
        passed in.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict mode) when the checkpoint
            weights do not match the model's parameters.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=False)

    if model_path.exists():
        # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
        # installed torch version's weights_only default); only load trusted files.
        checkpoint = torch.load(str(model_path), map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            ckpt_classes = checkpoint.get('class_names')
            if isinstance(ckpt_classes, list):
                if len(ckpt_classes) == len(class_names):
                    # Prefer class names bundled in the checkpoint: they were saved with the weights.
                    class_names = ckpt_classes
                else:
                    # The head is sized from the caller's list, so mismatched names can't be used.
                    print(
                        f"Warning: checkpoint lists {len(ckpt_classes)} class names but "
                        f"{len(class_names)} were expected; ignoring the checkpoint's names.",
                        file=sys.stderr,
                    )
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict, strict=True)
    else:
        # stderr, so the JSON result printed on stdout stays machine-parseable.
        print(f"Warning: model file not found at {model_path}, using untrained weights.", file=sys.stderr)

    model.to(device)
    model.eval()
    return model, device, class_names


def preprocess(image_path: Path) -> torch.Tensor:
    """Load *image_path* as RGB and turn it into a normalized batch of one image.

    The image is resized to 224x224 (aspect ratio is not preserved), converted
    to a float tensor, and normalized with the ImageNet channel mean/std.

    Returns:
        A tensor of shape (1, 3, 224, 224).
    """
    # The context manager closes the file handle Image.open keeps open. convert()
    # returns a new in-memory image, so it stays usable after the file is closed.
    with Image.open(str(image_path)) as img:
        image = img.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    return transform(image).unsqueeze(0)


def _load_class_names(classes_path: Path) -> list[str]:
    """Read a JSON array of class names from *classes_path*, falling back to DEFAULT_CLASSES.

    The fallback happens, with a warning on stderr, when the file is missing,
    unreadable, not valid JSON, or not a non-empty list of strings.
    """
    if not classes_path.exists():
        print(f'Warning: classes file not found at {classes_path}, falling back to default classes.', file=sys.stderr)
        return DEFAULT_CLASSES
    try:
        loaded = json.loads(classes_path.read_text(encoding='utf-8'))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print('Warning: Failed to read classes file, falling back to default classes.', file=sys.stderr)
        return DEFAULT_CLASSES
    if not isinstance(loaded, list) or not loaded or not all(isinstance(name, str) for name in loaded):
        print('Warning: classes file must contain a non-empty JSON array of strings, '
              'falling back to default classes.', file=sys.stderr)
        return DEFAULT_CLASSES
    return loaded


def main():
    """Parse CLI arguments, classify one image, and print the result as JSON on stdout.

    All warnings go to stderr, so stdout holds only the JSON document
    (``image``, ``predicted_class``, ``confidence``).
    """
    parser = argparse.ArgumentParser(description='Local image prediction for crop disease detection')
    parser.add_argument('-i', '--image', required=True, type=Path, help='Path to input image')
    parser.add_argument('-m', '--model', default=Path('models/crop_disease_v3_model.pth'), type=Path, help='Path to model checkpoint (.pth)')
    parser.add_argument('--classes', type=Path, help='Optional JSON file containing class names array')

    args = parser.parse_args()

    # Fail with a usage error (exit status 2) rather than a PIL traceback.
    if not args.image.is_file():
        parser.error(f'image file not found: {args.image}')

    class_names = _load_class_names(args.classes) if args.classes else DEFAULT_CLASSES

    # load_model may swap in class names bundled with the checkpoint.
    model, device, class_names = load_model(args.model, class_names)

    input_tensor = preprocess(args.image).to(device)

    with torch.no_grad():
        probabilities = F.softmax(model(input_tensor), dim=1)
        confidence, predicted_idx = torch.max(probabilities, dim=1)

    result = {
        'image': str(args.image),
        'predicted_class': class_names[predicted_idx.item()],
        'confidence': float(confidence.item()),
    }

    print(json.dumps(result, indent=2))


# Entry point when run as a script/module, e.g. `python -m src.predict_cli -i <image>`.
if __name__ == '__main__':
    main()