""" Simple CLI for local image prediction (non-Docker) Usage (PowerShell): python -m src.predict_cli -i path\to\image.jpg -m models\crop_disease_v3_model.pth """ import argparse import json from pathlib import Path import torch import torch.nn.functional as F from PIL import Image from torchvision import transforms from .model import CropDiseaseResNet50 DEFAULT_CLASSES = [ 'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Potato___Late_blight', 'Tomato__Target_Spot', 'Tomato__Tomato_mosaic_virus', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_healthy', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite' ] def load_model(model_path: Path, class_names: list[str]) -> tuple[torch.nn.Module, torch.device]: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=False) if model_path.exists(): checkpoint = torch.load(str(model_path), map_location=device) if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] if 'class_names' in checkpoint: # Prefer class names bundled in checkpoint, if present ckpt_classes = checkpoint['class_names'] if isinstance(ckpt_classes, list) and len(ckpt_classes) == len(class_names): class_names = ckpt_classes else: state_dict = checkpoint model.load_state_dict(state_dict, strict=True) else: print(f"Warning: model file not found at {model_path}, using untrained weights.") model.to(device) model.eval() return model, device, class_names def preprocess(image_path: Path) -> torch.Tensor: image = Image.open(str(image_path)).convert('RGB') transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return transform(image).unsqueeze(0) def main(): parser = argparse.ArgumentParser(description='Local image prediction for crop disease detection') parser.add_argument('-i', '--image', required=True, type=Path, help='Path to input image') parser.add_argument('-m', '--model', default=Path('models/crop_disease_v3_model.pth'), type=Path, help='Path to model checkpoint (.pth)') parser.add_argument('--classes', type=Path, help='Optional JSON file containing class names array') args = parser.parse_args() # Resolve class names class_names = DEFAULT_CLASSES if args.classes and args.classes.exists(): try: class_names = json.loads(Path(args.classes).read_text(encoding='utf-8')) except Exception: print('Warning: Failed to read classes file, falling back to default classes.') model, device, class_names = load_model(args.model, class_names) input_tensor = preprocess(args.image).to(device) with torch.no_grad(): outputs = model(input_tensor) probabilities = F.softmax(outputs, dim=1) confidence, predicted_idx = torch.max(probabilities, 1) result = { 'image': str(args.image), 'predicted_class': class_names[predicted_idx.item()], 'confidence': float(confidence.item()) } print(json.dumps(result, indent=2)) if __name__ == '__main__': main()