Spaces:
Running
Running
File size: 3,640 Bytes
36dd4e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
"""
Simple CLI for local image prediction (non-Docker)
Usage (PowerShell):
python -m src.predict_cli -i path\to\image.jpg -m models\crop_disease_v3_model.pth
"""
import argparse
import json
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from .model import CropDiseaseResNet50
# Fallback class labels; position i names the model's i-th output logit.
# NOTE(review): this order is not plain ASCII order (e.g. 'Potato___healthy'
# precedes 'Potato___Late_blight'), so confirm it matches the label order the
# model was trained with. Checkpoints that bundle 'class_names' override it.
DEFAULT_CLASSES = [
    "Pepper__bell___Bacterial_spot",
    "Pepper__bell___healthy",
    "Potato___Early_blight",
    "Potato___healthy",
    "Potato___Late_blight",
    "Tomato__Target_Spot",
    "Tomato__Tomato_mosaic_virus",
    "Tomato__Tomato_YellowLeaf__Curl_Virus",
    "Tomato_Bacterial_spot",
    "Tomato_Early_blight",
    "Tomato_healthy",
    "Tomato_Late_blight",
    "Tomato_Leaf_Mold",
    "Tomato_Septoria_leaf_spot",
    "Tomato_Spider_mites_Two_spotted_spider_mite",
]
def load_model(
    model_path: Path, class_names: list[str]
) -> tuple[torch.nn.Module, torch.device, list[str]]:
    """Build a CropDiseaseResNet50 and load weights from ``model_path`` if present.

    Args:
        model_path: Path to a ``.pth`` checkpoint. It may hold either a bare
            state dict, or a dict with a ``'model_state_dict'`` entry and,
            optionally, a ``'class_names'`` list/tuple.
        class_names: Fallback class labels, used when the checkpoint does not
            bundle a non-empty list of its own.

    Returns:
        ``(model, device, class_names)``. ``model`` is in eval mode on
        ``device`` (CUDA when available, otherwise CPU). ``class_names`` is the
        label list actually used; its length is the model's output size.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint weights do
            not match the model architecture.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = None

    if model_path.exists():
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(str(model_path), map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            ckpt_classes = checkpoint.get('class_names')
            # Prefer class names bundled in the checkpoint: they describe the
            # head the weights were saved with. They are resolved BEFORE the
            # model is built so the output layer is sized to match; otherwise a
            # length mismatch would make the strict load below fail.
            if isinstance(ckpt_classes, (list, tuple)) and ckpt_classes:
                class_names = list(ckpt_classes)
        else:
            state_dict = checkpoint
    else:
        # stderr keeps stdout clean for the CLI's JSON result.
        print(
            f"Warning: model file not found at {model_path}, using untrained weights.",
            file=sys.stderr,
        )

    model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=False)
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    return model, device, class_names
def preprocess(image_path: Path) -> torch.Tensor:
    """Load an image file and return it as a normalized single-image batch.

    The image is converted to RGB and resized to 224x224 (aspect ratio is not
    preserved). ``ToTensor`` then scales it to [0, 1], and it is normalized
    with the standard ImageNet per-channel mean/std.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager closes the file handle deterministically.
    # convert() reads the pixel data before the file is closed.
    with Image.open(str(image_path)) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)
def _load_class_names(classes_path: 'Path | None') -> list[str]:
    """Return class names from a JSON-array file, or DEFAULT_CLASSES as a fallback.

    The function falls back to DEFAULT_CLASSES, printing a warning to stderr,
    when the file:
    - is missing or unreadable,
    - is not valid UTF-8 JSON, or
    - does not hold a non-empty JSON array of strings.

    ``None`` (no ``--classes`` option given) falls back silently.
    """
    if classes_path is None:
        return DEFAULT_CLASSES
    if not classes_path.exists():
        print(
            f'Warning: classes file not found at {classes_path}, falling back to default classes.',
            file=sys.stderr,
        )
        return DEFAULT_CLASSES
    try:
        loaded = json.loads(classes_path.read_text(encoding='utf-8'))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print('Warning: Failed to read classes file, falling back to default classes.', file=sys.stderr)
        return DEFAULT_CLASSES
    if not (isinstance(loaded, list) and loaded and all(isinstance(name, str) for name in loaded)):
        print(
            'Warning: classes file must contain a non-empty JSON array of strings, '
            'falling back to default classes.',
            file=sys.stderr,
        )
        return DEFAULT_CLASSES
    return loaded


def main():
    """Parse CLI arguments, classify one image, and print the result as JSON.

    The JSON object (keys ``image``, ``predicted_class``, ``confidence``) is the
    only thing written to stdout. Warnings go to stderr, so the output stays
    machine-parseable.
    """
    parser = argparse.ArgumentParser(description='Local image prediction for crop disease detection')
    parser.add_argument('-i', '--image', required=True, type=Path, help='Path to input image')
    parser.add_argument('-m', '--model', default=Path('models/crop_disease_v3_model.pth'), type=Path, help='Path to model checkpoint (.pth)')
    parser.add_argument('--classes', type=Path, help='Optional JSON file containing class names array')
    args = parser.parse_args()

    # Fail with a normal CLI usage error instead of a PIL traceback.
    if not args.image.is_file():
        parser.error(f'image not found: {args.image}')

    class_names = _load_class_names(args.classes)
    # load_model may replace class_names with the list bundled in the checkpoint.
    model, device, class_names = load_model(args.model, class_names)

    input_tensor = preprocess(args.image).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)

    result = {
        'image': str(args.image),
        'predicted_class': class_names[predicted_idx.item()],
        'confidence': float(confidence.item())
    }
    print(json.dumps(result, indent=2))
if __name__ == "__main__":
    # Entry point when executed directly or via ``python -m``; importing does nothing.
    main()
|