chmcbs committed on
Commit
88b9d97
·
1 Parent(s): 01e5e97

Add Gradio application and inference code

Browse files
Files changed (2) hide show
  1. app.py +81 -0
  2. inference.py +173 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import os
from inference import GWFSSModel
from PIL import Image
import numpy as np
from scipy import ndimage
from skimage.feature import peak_local_max
from huggingface_hub import hf_hub_download

# Fetch the trained checkpoint from the Hugging Face Hub at startup.
print("Downloading model from Hugging Face...")
MODEL_PATH = hf_hub_download(repo_id="chmcbs/HeadCount", filename="model.pth")
print("✓ Model downloaded successfully")

# Build the segmentation model once; it is shared by every request.
print("Loading model...")
model = GWFSSModel(MODEL_PATH)
print("✓ Model loaded successfully")
def process_image(image):
    """Segment wheat heads in *image* and return (markdown message, overlay).

    Args:
        image: PIL image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        A ("", None) pair when *image* is None; otherwise a markdown string
        with the head count and a PIL overlay image with detected head
        centres marked in red. On failure an "Error: ..." message is
        returned instead of raising, so the UI stays responsive.
    """
    if image is None:
        return "", None

    try:
        predictions = model.predict(image)

        # Locate head centres once: distance transform of the head mask
        # (class 3) followed by a local-maxima search. Deriving the count
        # from the same coordinates we draw guarantees the reported number
        # always matches the markers shown, and avoids the second distance
        # transform that model.count_heads() would recompute.
        head_mask = (predictions == 3).astype(np.uint8)
        if head_mask.any():
            distance = ndimage.distance_transform_edt(head_mask)
            coords = peak_local_max(distance, min_distance=15, labels=head_mask)
        else:
            coords = np.empty((0, 2), dtype=int)
        num_heads = len(coords)

        # Blend the head segmentation over the image and mark each centre.
        overlay = model.overlay_mask(image, predictions, alpha=0.5, heads_only=True)
        overlay_np = np.array(overlay)
        for y, x in coords:
            # 7x7 red square; max() stops negative indices wrapping around,
            # and numpy clips slice ends past the array edge automatically.
            overlay_np[max(0, y - 3):y + 4, max(0, x - 3):x + 4] = [255, 0, 0]

        overlay = Image.fromarray(overlay_np)
        count_message = f"### 🌾 {num_heads} heads detected!"
        return count_message, overlay
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"Error: {str(e)}", None
45
+
46
+ # Get example images
47
+ example_images = []
48
+ if os.path.exists("examples"):
49
+ example_files = sorted([f for f in os.listdir("examples")
50
+ if f.endswith(('.jpg', '.jpeg', '.png'))])[:5]
51
+ example_images = [os.path.join("examples", f) for f in example_files]
52
+
53
+ # Create Gradio interface
54
+ with gr.Blocks(title="HeadCount") as demo:
55
+ gr.Markdown("# 🌾 HeadCount: Automated Wheat Head Counter")
56
+ gr.Markdown("Upload an image to automatically detect and count wheat heads.")
57
+
58
+ with gr.Row():
59
+ image_input = gr.Image(type="pil", label="Upload Image")
60
+
61
+ with gr.Column():
62
+ overlay_output = gr.Image(label="Segmentation Overlay")
63
+ generate_btn = gr.Button("Generate", variant="primary")
64
+
65
+ with gr.Row():
66
+ with gr.Column():
67
+ if example_images:
68
+ gr.Markdown("### Example Images")
69
+ gr.Examples(examples=example_images, inputs=image_input)
70
+
71
+ with gr.Column():
72
+ head_count_output = gr.Markdown(value="")
73
+
74
+ generate_btn.click(
75
+ fn=process_image,
76
+ inputs=image_input,
77
+ outputs=[head_count_output, overlay_output]
78
+ )
79
+
80
+ if __name__ == "__main__":
81
+ demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
inference.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Inference module for counting wheat heads in field images using a DeepLabV3+ semantic
segmentation model trained on the GWFSS dataset.

The model performs multi-class segmentation (Background, Leaf, Stem, Head) to accurately
distinguish wheat heads from other plant organs, then uses connected component analysis
to count individual heads.
"""

import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import segmentation_models_pytorch as smp
from scipy import ndimage
from skimage.feature import peak_local_max

# ImageNet normalisation constants (match the pretrained ResNet encoder)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Mask colours for visualization, indexed by class id 0-3
MASK_COLORS = [
    (0, 0, 0),       # Background: black
    (214, 255, 50),  # Leaf: yellow-green
    (50, 132, 255),  # Stem: blue
    (50, 255, 132),  # Head: cyan-green
]
29
+
30
+ class GWFSSModel:
31
+ def __init__(self, model_path, device=None):
32
+ if device is None:
33
+ if torch.cuda.is_available():
34
+ self.device = torch.device("cuda")
35
+ elif torch.backends.mps.is_available():
36
+ self.device = torch.device("mps")
37
+ else:
38
+ self.device = torch.device("cpu")
39
+ else:
40
+ self.device = device
41
+
42
+ # Load model architecture
43
+ self.model = smp.DeepLabV3Plus(
44
+ encoder_name="resnet50",
45
+ encoder_weights=None,
46
+ in_channels=3,
47
+ classes=4,
48
+ )
49
+
50
+ # Load trained weights
51
+ checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
52
+ self.model.load_state_dict(checkpoint['model_state_dict'])
53
+ self.model = self.model.to(self.device)
54
+ self.model.eval()
55
+
56
+ # Image preprocessing
57
+ self.transform = transforms.Compose([
58
+ transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
61
+ ])
62
+
63
+ def preprocess_image(self, image):
64
+ if isinstance(image, np.ndarray):
65
+ image = Image.fromarray(image)
66
+
67
+ if image.mode != 'RGB':
68
+ image = image.convert('RGB')
69
+
70
+ image_tensor = self.transform(image).unsqueeze(0)
71
+ return image_tensor.to(self.device)
72
+
73
+ def predict(self, image):
74
+ if isinstance(image, str):
75
+ image = Image.open(image)
76
+
77
+ image_tensor = self.preprocess_image(image)
78
+
79
+ with torch.no_grad():
80
+ logits = self.model(image_tensor)
81
+
82
+ predictions = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
83
+ return predictions
84
+
85
+ def count_heads(self, predictions, min_distance=15):
86
+ head_mask = (predictions == 3).astype(np.uint8)
87
+
88
+ if head_mask.sum() == 0:
89
+ return 0
90
+
91
+ # Compute distance transform
92
+ distance = ndimage.distance_transform_edt(head_mask)
93
+
94
+ # Find local peaks (head centers)
95
+ coords = peak_local_max(distance, min_distance=min_distance, labels=head_mask)
96
+
97
+ # Count the peaks
98
+ num_heads = len(coords)
99
+
100
+ return num_heads
101
+
102
+ def create_colored_mask(self, predictions):
103
+ h, w = predictions.shape
104
+ mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
105
+
106
+ for class_id, color in enumerate(MASK_COLORS):
107
+ mask_rgb[predictions == class_id] = color
108
+
109
+ return Image.fromarray(mask_rgb)
110
+
111
+ def overlay_mask(self, image, predictions, alpha=0.5, heads_only=True):
112
+ if isinstance(image, np.ndarray):
113
+ image = Image.fromarray(image)
114
+
115
+ if image.size != (512, 512):
116
+ image = image.resize((512, 512), Image.Resampling.BILINEAR)
117
+
118
+ # Create mask
119
+ h, w = predictions.shape
120
+ mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
121
+
122
+ if heads_only:
123
+ # Only highlight heads
124
+ mask_rgb[predictions == 3] = (50, 255, 132)
125
+ else:
126
+ # Show all classes
127
+ for class_id, color in enumerate(MASK_COLORS):
128
+ mask_rgb[predictions == class_id] = color
129
+
130
+ mask_img = Image.fromarray(mask_rgb)
131
+ overlay = Image.blend(image.convert('RGB'), mask_img, alpha)
132
+ return overlay
133
+
134
+ def predict_and_overlay(self, image, alpha=0.5, heads_only=True):
135
+ predictions = self.predict(image)
136
+ overlay = self.overlay_mask(image, predictions, alpha=alpha, heads_only=heads_only)
137
+ return overlay
138
+
if __name__ == "__main__":
    import sys

    # Require at least an image path; the model path is optional.
    if len(sys.argv) < 2:
        print("Usage: python inference.py <image_path> [model_path]")
        sys.exit(1)

    image_path = sys.argv[1]
    model_path = sys.argv[2] if len(sys.argv) > 2 else "cache/02_dice_stem.pth"

    print(f"Loading model from {model_path}...")
    model = GWFSSModel(model_path)

    print(f"Processing image: {image_path}")
    image = Image.open(image_path)
    predictions = model.predict(image)

    # Report the detected head count.
    num_heads = model.count_heads(predictions)
    print(f"\n🌾 {num_heads} heads detected!")

    # Render both overlay variants: heads only, and all four classes.
    print("\nGenerating visualisations...")
    overlay_heads = model.overlay_mask(image, predictions, alpha=0.5, heads_only=True)
    overlay_all = model.overlay_mask(image, predictions, alpha=0.5, heads_only=False)

    # Derive output file names from the input name and save next to it.
    stem = image_path.rsplit('.', 1)[0]
    output_heads = stem + '_heads_only.png'
    output_all = stem + '_all_classes.png'

    overlay_heads.save(output_heads)
    overlay_all.save(output_all)

    print(f"✓ Saved head overlay to: {output_heads}")
    print(f"✓ Saved full segmentation to: {output_all}")