Add model weights, inference code, and dependencies
Browse files- inference.py +173 -0
- model.pth +3 -0
- requirements.txt +7 -0
inference.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference module for counting wheat heads in field images using a DeepLabV3+ semantic
|
| 3 |
+
segmentation model trained on the GWFSS dataset.
|
| 4 |
+
|
| 5 |
+
The model performs multi-class segmentation (Background, Leaf, Stem, Head) to accurately
|
| 6 |
+
distinguish wheat heads from other plant organs, then uses connected component analysis
|
| 7 |
+
to count individual heads.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import segmentation_models_pytorch as smp
|
| 15 |
+
from scipy import ndimage
|
| 16 |
+
from skimage.feature import peak_local_max
|
| 17 |
+
|
| 18 |
+
# ImageNet normalisation constants
|
| 19 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 20 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 21 |
+
|
| 22 |
+
# Mask colours for visualization
|
| 23 |
+
MASK_COLORS = [
|
| 24 |
+
(0, 0, 0), # Background: black
|
| 25 |
+
(214, 255, 50), # Leaf: yellow-green
|
| 26 |
+
(50, 132, 255), # Stem: blue
|
| 27 |
+
(50, 255, 132), # Head: cyan-green
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
class GWFSSModel:
    """DeepLabV3+ (ResNet-50 encoder) wheat-organ segmenter with head counting.

    Loads a 4-class (Background, Leaf, Stem, Head) checkpoint trained on the
    GWFSS dataset and exposes prediction, head counting, and visualization
    helpers. All inference is performed at a fixed 512x512 resolution.
    """

    def __init__(self, model_path, device=None):
        """Build the network, load weights, and prepare the preprocessing pipeline.

        Args:
            model_path: Path to a checkpoint containing a 'model_state_dict' key.
            device: Optional torch.device; auto-selects CUDA > MPS > CPU when None.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = device

        # Architecture must match training; encoder_weights=None because the
        # checkpoint below supplies all parameters.
        self.model = smp.DeepLabV3Plus(
            encoder_name="resnet50",
            encoder_weights=None,
            in_channels=3,
            classes=4,
        )

        # NOTE(security): weights_only=False allows arbitrary pickled objects to
        # execute on load — only open checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Preprocessing mirrors training: resize to 512x512, ImageNet-normalize.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ])

    def preprocess_image(self, image):
        """Convert a PIL image or numpy array to a normalized (1, 3, 512, 512) tensor on self.device."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        image_tensor = self.transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    def predict(self, image):
        """Run segmentation and return a (512, 512) int array of class ids (0-3).

        Args:
            image: File path, PIL image, or numpy array.
        """
        if isinstance(image, str):
            image = Image.open(image)

        image_tensor = self.preprocess_image(image)

        with torch.no_grad():
            logits = self.model(image_tensor)

        predictions = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
        return predictions

    def count_heads(self, predictions, min_distance=15):
        """Count individual wheat heads in a predicted segmentation map.

        Uses a Euclidean distance transform plus per-component peak detection:
        each local maximum of the distance map is treated as one head center,
        which splits touching heads that merge into a single blob.

        Args:
            predictions: (H, W) class-id array from predict(); class 3 == Head.
            min_distance: Minimum pixel separation between detected head centers.

        Returns:
            Number of detected heads (0 when no head pixels are present).
        """
        head_mask = (predictions == 3).astype(np.uint8)

        if head_mask.sum() == 0:
            return 0

        # FIX: label connected components and pass the labels to peak_local_max.
        # Previously the binary mask itself was used as `labels`, which merged
        # every head pixel into one region, so min_distance suppression spanned
        # separate heads and undercounted adjacent-but-distinct components.
        labeled_mask, _num_components = ndimage.label(head_mask)

        # Distance transform peaks approximate head centers.
        distance = ndimage.distance_transform_edt(head_mask)
        coords = peak_local_max(distance, min_distance=min_distance, labels=labeled_mask)

        return len(coords)

    def create_colored_mask(self, predictions):
        """Render the class-id map as an RGB PIL image using MASK_COLORS."""
        h, w = predictions.shape
        mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in enumerate(MASK_COLORS):
            mask_rgb[predictions == class_id] = color

        return Image.fromarray(mask_rgb)

    def overlay_mask(self, image, predictions, alpha=0.5, heads_only=True):
        """Blend the segmentation mask over the (512x512-resized) input image.

        Args:
            image: PIL image or numpy array; resized to 512x512 to match predictions.
            predictions: (H, W) class-id array from predict().
            alpha: Blend weight of the mask (0 = image only, 1 = mask only).
            heads_only: When True, color only the Head class; otherwise all classes.

        Returns:
            A blended RGB PIL image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.size != (512, 512):
            image = image.resize((512, 512), Image.Resampling.BILINEAR)

        h, w = predictions.shape
        mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        if heads_only:
            # Only highlight heads (class 3).
            mask_rgb[predictions == 3] = (50, 255, 132)
        else:
            # Show all classes.
            for class_id, color in enumerate(MASK_COLORS):
                mask_rgb[predictions == class_id] = color

        mask_img = Image.fromarray(mask_rgb)
        overlay = Image.blend(image.convert('RGB'), mask_img, alpha)
        return overlay

    def predict_and_overlay(self, image, alpha=0.5, heads_only=True):
        """Convenience wrapper: segment `image` and return the blended overlay."""
        predictions = self.predict(image)
        overlay = self.overlay_mask(image, predictions, alpha=alpha, heads_only=heads_only)
        return overlay
|
| 138 |
+
|
| 139 |
+
if __name__ == "__main__":
|
| 140 |
+
import sys
|
| 141 |
+
|
| 142 |
+
if len(sys.argv) < 2:
|
| 143 |
+
print("Usage: python inference.py <image_path> [model_path]")
|
| 144 |
+
sys.exit(1)
|
| 145 |
+
|
| 146 |
+
image_path = sys.argv[1]
|
| 147 |
+
model_path = sys.argv[2] if len(sys.argv) > 2 else "cache/02_dice_stem.pth"
|
| 148 |
+
|
| 149 |
+
print(f"Loading model from {model_path}...")
|
| 150 |
+
model = GWFSSModel(model_path)
|
| 151 |
+
|
| 152 |
+
print(f"Processing image: {image_path}")
|
| 153 |
+
image = Image.open(image_path)
|
| 154 |
+
predictions = model.predict(image)
|
| 155 |
+
|
| 156 |
+
# Count heads
|
| 157 |
+
num_heads = model.count_heads(predictions)
|
| 158 |
+
print(f"\n🌾 {num_heads} heads detected!")
|
| 159 |
+
|
| 160 |
+
# Create visualisations
|
| 161 |
+
print("\nGenerating visualisations...")
|
| 162 |
+
overlay_heads = model.overlay_mask(image, predictions, alpha=0.5, heads_only=True)
|
| 163 |
+
overlay_all = model.overlay_mask(image, predictions, alpha=0.5, heads_only=False)
|
| 164 |
+
|
| 165 |
+
# Save outputs
|
| 166 |
+
output_heads = image_path.rsplit('.', 1)[0] + '_heads_only.png'
|
| 167 |
+
output_all = image_path.rsplit('.', 1)[0] + '_all_classes.png'
|
| 168 |
+
|
| 169 |
+
overlay_heads.save(output_heads)
|
| 170 |
+
overlay_all.save(output_all)
|
| 171 |
+
|
| 172 |
+
print(f"✓ Saved head overlay to: {output_heads}")
|
| 173 |
+
print(f"✓ Saved full segmentation to: {output_all}")
|
model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a44b1504d0ce10a601cda4adf11bbc967d87f42bb2b2622fa922b5667bcaf17
|
| 3 |
+
size 320679895
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
segmentation-models-pytorch>=0.3.3
|
| 4 |
+
Pillow>=9.0.0
|
| 5 |
+
numpy>=1.24.0
|
| 6 |
+
scipy>=1.10.0
|
| 7 |
+
scikit-image>=0.20.0
|