chmcbs committed on
Commit
e2895d0
·
verified ·
1 Parent(s): 78f1163

Add model weights, inference code, and dependencies

Browse files
Files changed (3) hide show
  1. inference.py +173 -0
  2. model.pth +3 -0
  3. requirements.txt +7 -0
inference.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference module for counting wheat heads in field images using a DeepLabV3+ semantic
3
+ segmentation model trained on the GWFSS dataset.
4
+
5
+ The model performs multi-class segmentation (Background, Leaf, Stem, Head) to accurately
6
+ distinguish wheat heads from other plant organs, then detects individual head
7
+ centers as local maxima of the head mask's distance transform to count them.
8
+ """
9
+
10
+ import torch
11
+ import torchvision.transforms as transforms
12
+ from PIL import Image
13
+ import numpy as np
14
+ import segmentation_models_pytorch as smp
15
+ from scipy import ndimage
16
+ from skimage.feature import peak_local_max
17
+
18
+ # ImageNet normalisation constants
19
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
20
+ IMAGENET_STD = [0.229, 0.224, 0.225]
21
+
22
+ # Mask colours for visualization
23
+ MASK_COLORS = [
24
+ (0, 0, 0), # Background: black
25
+ (214, 255, 50), # Leaf: yellow-green
26
+ (50, 132, 255), # Stem: blue
27
+ (50, 255, 132), # Head: cyan-green
28
+ ]
29
+
30
class GWFSSModel:
    """DeepLabV3+ wrapper for wheat-organ segmentation and head counting.

    The network predicts 4 classes per pixel (0=Background, 1=Leaf, 2=Stem,
    3=Head) at a fixed 512x512 input resolution. Helper methods count
    individual heads and render colored mask overlays.
    """

    def __init__(self, model_path, device=None):
        """Build the architecture and load trained weights.

        Args:
            model_path: Path to a checkpoint dict containing 'model_state_dict'.
            device: torch.device to run on; when None, auto-selects
                cuda, then mps, then cpu.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = device

        # Architecture must match training: ResNet-50 encoder, 4 output classes.
        self.model = smp.DeepLabV3Plus(
            encoder_name="resnet50",
            encoder_weights=None,
            in_channels=3,
            classes=4,
        )

        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Preprocessing mirrors training: resize to 512x512, ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    def preprocess_image(self, image):
        """Convert a PIL image or numpy array into a normalized (1, 3, 512, 512)
        float tensor on self.device."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        image_tensor = self.transform(image).unsqueeze(0)
        return image_tensor.to(self.device)

    def predict(self, image):
        """Run segmentation on an image.

        Args:
            image: File path, PIL image, or numpy array.

        Returns:
            (512, 512) numpy array of per-pixel class ids (0-3).
        """
        if isinstance(image, str):
            image = Image.open(image)

        image_tensor = self.preprocess_image(image)

        with torch.no_grad():
            logits = self.model(image_tensor)

        predictions = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
        return predictions

    def count_heads(self, predictions, min_distance=15):
        """Count individual wheat heads in a class-id prediction map.

        Touching heads are separated by treating local maxima of the head
        mask's distance transform as head centers.

        Args:
            predictions: (H, W) array of class ids from predict().
            min_distance: Minimum pixel separation between detected centers.

        Returns:
            Number of heads as an int (0 when no head pixels exist).
        """
        head_mask = (predictions == 3).astype(np.uint8)

        if head_mask.sum() == 0:
            return 0

        # Distance to the nearest background pixel; blob centers are maxima.
        distance = ndimage.distance_transform_edt(head_mask)

        # exclude_border=False: the default (True) silently drops peaks within
        # min_distance of the image edge, undercounting heads at the border.
        coords = peak_local_max(
            distance,
            min_distance=min_distance,
            labels=head_mask,
            exclude_border=False,
        )
        num_heads = len(coords)

        # Fallback: peak suppression can discard every candidate for very
        # small or thin blobs; never report 0 while head pixels are present.
        if num_heads == 0:
            num_heads = int(ndimage.label(head_mask)[1])

        return num_heads

    def create_colored_mask(self, predictions):
        """Render the class-id map as an RGB PIL image using MASK_COLORS."""
        h, w = predictions.shape
        mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        for class_id, color in enumerate(MASK_COLORS):
            mask_rgb[predictions == class_id] = color

        return Image.fromarray(mask_rgb)

    def overlay_mask(self, image, predictions, alpha=0.5, heads_only=True):
        """Blend the segmentation mask over the (resized) input image.

        Args:
            image: PIL image or numpy array; resized to 512x512 to match
                the prediction map.
            predictions: (H, W) class-id array from predict().
            alpha: Blend weight of the mask (0 = image only, 1 = mask only).
            heads_only: When True, highlight only the Head class; otherwise
                draw all classes.

        Returns:
            Blended PIL RGB image.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.size != (512, 512):
            image = image.resize((512, 512), Image.Resampling.BILINEAR)

        h, w = predictions.shape
        mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)

        if heads_only:
            # Only highlight heads (class 3), using the Head color.
            mask_rgb[predictions == 3] = (50, 255, 132)
        else:
            # Show all classes.
            for class_id, color in enumerate(MASK_COLORS):
                mask_rgb[predictions == class_id] = color

        mask_img = Image.fromarray(mask_rgb)
        overlay = Image.blend(image.convert('RGB'), mask_img, alpha)
        return overlay

    def predict_and_overlay(self, image, alpha=0.5, heads_only=True):
        """Convenience wrapper: predict() then overlay_mask() in one call."""
        predictions = self.predict(image)
        overlay = self.overlay_mask(image, predictions, alpha=alpha, heads_only=heads_only)
        return overlay
138
+
139
+ if __name__ == "__main__":
140
+ import sys
141
+
142
+ if len(sys.argv) < 2:
143
+ print("Usage: python inference.py <image_path> [model_path]")
144
+ sys.exit(1)
145
+
146
+ image_path = sys.argv[1]
147
+ model_path = sys.argv[2] if len(sys.argv) > 2 else "cache/02_dice_stem.pth"
148
+
149
+ print(f"Loading model from {model_path}...")
150
+ model = GWFSSModel(model_path)
151
+
152
+ print(f"Processing image: {image_path}")
153
+ image = Image.open(image_path)
154
+ predictions = model.predict(image)
155
+
156
+ # Count heads
157
+ num_heads = model.count_heads(predictions)
158
+ print(f"\n🌾 {num_heads} heads detected!")
159
+
160
+ # Create visualisations
161
+ print("\nGenerating visualisations...")
162
+ overlay_heads = model.overlay_mask(image, predictions, alpha=0.5, heads_only=True)
163
+ overlay_all = model.overlay_mask(image, predictions, alpha=0.5, heads_only=False)
164
+
165
+ # Save outputs
166
+ output_heads = image_path.rsplit('.', 1)[0] + '_heads_only.png'
167
+ output_all = image_path.rsplit('.', 1)[0] + '_all_classes.png'
168
+
169
+ overlay_heads.save(output_heads)
170
+ overlay_all.save(output_all)
171
+
172
+ print(f"✓ Saved head overlay to: {output_heads}")
173
+ print(f"✓ Saved full segmentation to: {output_all}")
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a44b1504d0ce10a601cda4adf11bbc967d87f42bb2b2622fa922b5667bcaf17
3
+ size 320679895
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ segmentation-models-pytorch>=0.3.3
4
+ Pillow>=9.0.0
5
+ numpy>=1.24.0
6
+ scipy>=1.10.0
7
+ scikit-image>=0.20.0