Spaces:
Running
Running
r"""
Simple CLI for local image prediction (non-Docker).

Usage (PowerShell):
    python -m src.predict_cli -i path\to\image.jpg -m models\crop_disease_v3_model.pth

Optionally pass --classes path\to\classes.json (a JSON array of class names).
The prediction is printed to stdout as JSON.
"""
import argparse
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from .model import CropDiseaseResNet50
# Fallback label names, one per model output: main() maps the argmax logit
# index straight into this list, so element i must name output unit i.
# load_model may replace these with names bundled in the checkpoint.
# NOTE(review): this is case-insensitive alphabetical order, which differs from
# the code-point order produced by sorted() / torchvision ImageFolder
# (e.g. 'Potato___healthy' precedes 'Potato___Late_blight' here) — confirm it
# matches the label indices used when the model was trained.
DEFAULT_CLASSES = [
    # Bell pepper
    "Pepper__bell___Bacterial_spot",
    "Pepper__bell___healthy",
    # Potato
    "Potato___Early_blight",
    "Potato___healthy",
    "Potato___Late_blight",
    # Tomato
    "Tomato__Target_Spot",
    "Tomato__Tomato_mosaic_virus",
    "Tomato__Tomato_YellowLeaf__Curl_Virus",
    "Tomato_Bacterial_spot",
    "Tomato_Early_blight",
    "Tomato_healthy",
    "Tomato_Late_blight",
    "Tomato_Leaf_Mold",
    "Tomato_Septoria_leaf_spot",
    "Tomato_Spider_mites_Two_spotted_spider_mite",
]
def load_model(
    model_path: Path, class_names: list[str]
) -> tuple[torch.nn.Module, torch.device, list[str]]:
    """Build the classifier, load its weights and switch it to eval mode.

    Args:
        model_path: ``.pth`` file holding either a bare state dict or a dict
            with a ``'model_state_dict'`` entry (optionally also
            ``'class_names'``). If the file does not exist, a warning is
            written to stderr and the model keeps untrained weights.
        class_names: Fallback label names, used unless the checkpoint bundles
            its own non-empty list of strings under ``'class_names'``.

    Returns:
        ``(model, device, class_names)``. ``class_names`` is the list the
        model's output layer was sized from, i.e. the checkpoint's names when
        present, otherwise the ``class_names`` argument.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when the weights
            do not fit the constructed architecture.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    weights_found = model_path.exists()
    state_dict = None
    if weights_found:
        # NOTE(review): torch.load unpickles the file (unless weights_only=True
        # is in effect for the installed PyTorch); only load trusted checkpoints.
        checkpoint = torch.load(str(model_path), map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            ckpt_classes = checkpoint.get('class_names')
            if (
                isinstance(ckpt_classes, list)
                and ckpt_classes
                and all(isinstance(name, str) for name in ckpt_classes)
            ):
                # The checkpoint's own names describe its output layer, so
                # prefer them even when their count differs from the fallback.
                if len(ckpt_classes) != len(class_names):
                    print(
                        f"Warning: checkpoint bundles {len(ckpt_classes)} class names "
                        f"but {len(class_names)} were supplied; using the checkpoint's.",
                        file=sys.stderr,
                    )
                class_names = list(ckpt_classes)
        else:
            # Bare state dict saved with torch.save(model.state_dict(), ...).
            state_dict = checkpoint

    # Build only after the final class list is known so the output layer
    # matches the checkpoint weights.
    model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=False)
    if weights_found:
        model.load_state_dict(state_dict, strict=True)
    else:
        # stderr keeps the CLI's JSON output on stdout clean.
        print(
            f"Warning: model file not found at {model_path}, using untrained weights.",
            file=sys.stderr,
        )
    model.to(device)
    model.eval()
    return model, device, class_names
def preprocess(image_path: Path) -> torch.Tensor:
    """Load an image and return it as a normalized batch of shape (1, 3, 224, 224).

    The image is converted to RGB and resized to exactly 224x224 (aspect ratio
    not preserved). Channels are then normalized with the mean/std values
    below, which are the standard ImageNet statistics.

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        A float tensor with a leading batch dimension of 1.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Close the file handle deterministically. convert() returns a new, fully
    # loaded image, so it stays usable after the context exits.
    with Image.open(str(image_path)) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)
def _read_class_names(path: Path, default: list[str]) -> list[str]:
    """Return the JSON array of class names in ``path``, or ``default`` if it is unusable."""
    try:
        data = json.loads(path.read_text(encoding='utf-8'))
    except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
        print('Warning: Failed to read classes file, falling back to default classes.',
              file=sys.stderr)
        return default
    # Anything other than a non-empty list of strings would break index lookup later.
    if not isinstance(data, list) or not data or not all(isinstance(n, str) for n in data):
        print('Warning: classes file must contain a non-empty JSON array of strings, '
              'falling back to default classes.', file=sys.stderr)
        return default
    return data


def main():
    """Classify one image and print the top prediction as JSON on stdout.

    Warnings are written to stderr so stdout stays machine-parseable.
    """
    parser = argparse.ArgumentParser(description='Local image prediction for crop disease detection')
    parser.add_argument('-i', '--image', required=True, type=Path, help='Path to input image')
    parser.add_argument('-m', '--model', default=Path('models/crop_disease_v3_model.pth'), type=Path, help='Path to model checkpoint (.pth)')
    parser.add_argument('--classes', type=Path, help='Optional JSON file containing class names array')
    args = parser.parse_args()

    # Fail with a usage message instead of a traceback deep inside PIL.
    if not args.image.is_file():
        parser.error(f'image not found: {args.image}')

    # Resolve class names: a valid --classes file overrides the defaults.
    class_names = DEFAULT_CLASSES
    if args.classes is not None:
        if args.classes.exists():
            class_names = _read_class_names(args.classes, DEFAULT_CLASSES)
        else:
            print(f'Warning: classes file not found at {args.classes}, '
                  'falling back to default classes.', file=sys.stderr)

    # load_model may substitute the class names bundled in the checkpoint.
    model, device, class_names = load_model(args.model, class_names)
    input_tensor = preprocess(args.image).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)

    result = {
        'image': str(args.image),
        'predicted_class': class_names[predicted_idx.item()],
        'confidence': float(confidence.item()),
    }
    print(json.dumps(result, indent=2))


if __name__ == '__main__':
    main()