File size: 11,049 Bytes
196c526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
#!/usr/bin/env python3
"""
Bean detection prediction script
"""

# Standard library imports
import argparse
import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List

# Add src to path
sys.path.insert(0, str(Path(__file__).parent / "src"))

# Local imports
from bean_vision.export import COCOExporter, LabelMeExporter
from bean_vision.inference import BeanPredictor
from bean_vision.visualization.detection_viz import DetectionVisualizer

# ANSI escape codes for terminal formatting
BOLD = '\033[1m'
RESET = '\033[0m'


def main():
    """Run bean detection on one or more images from the command line.

    Parses CLI arguments, loads a trained Mask R-CNN checkpoint through
    ``BeanPredictor``, runs inference on every input image, and optionally
    writes visualizations (masks/polygons) and exports (plain JSON, COCO,
    LabelMe) into ``--output_dir``. Prints a per-image line and a summary
    to stdout.
    """
    parser = argparse.ArgumentParser(description='Bean detection using trained MaskR-CNN')

    # Model and input arguments
    parser.add_argument('--model', type=str, required=True,
                       help='Path to trained model checkpoint')
    parser.add_argument('--images', nargs='+', required=True,
                       help='Input image paths (can use wildcards)')

    # Detection parameters
    parser.add_argument('--confidence', '--threshold', type=float, default=0.5,
                       dest='confidence',
                       help='Confidence threshold for detections')
    parser.add_argument('--max_detections', type=int, default=500,
                       help='Maximum detections per image')
    parser.add_argument('--mask_threshold', type=float, default=0.5,
                       help='Threshold for mask binarization')
    parser.add_argument('--device', type=str, default='cpu',
                       help='Device to use (cpu or cuda)')

    # NMS parameters (--apply_nms defaults True; --no_nms flips the same dest)
    parser.add_argument('--apply_nms', action='store_true', default=True,
                       help='Apply Non-Maximum Suppression to remove overlapping detections (default: True)')
    parser.add_argument('--no_nms', dest='apply_nms', action='store_false',
                       help='Disable NMS')
    parser.add_argument('--nms_type', choices=['box', 'mask'], default='box',
                       help='Type of NMS to apply (default: box - faster, mask - more accurate)')
    parser.add_argument('--nms_threshold', type=float, default=0.3,
                       help='IoU threshold for NMS (lower = more aggressive)')

    # Edge and size filtering
    parser.add_argument('--filter_edge_beans', action='store_true', default=True,
                       help='Filter out partial beans at image edges (default: True)')
    parser.add_argument('--no_edge_filter', dest='filter_edge_beans', action='store_false',
                       help='Disable edge bean filtering')
    parser.add_argument('--edge_threshold', type=int, default=10,
                       help='Pixel distance from edge to consider for filtering')
    parser.add_argument('--min_bean_area', type=float, default=500,
                       help='Minimum bean area in pixels')
    parser.add_argument('--max_bean_area', type=float, default=30000,
                       help='Maximum bean area in pixels')

    # Output options
    parser.add_argument('--output_dir', type=str, default='results',
                       help='Directory to save outputs (default: results)')
    parser.add_argument('--visualize', action='store_true', default=True,
                       help='Create visualization images (default: True)')
    parser.add_argument('--no_visualize', dest='visualize', action='store_false',
                       help='Disable visualization')
    parser.add_argument('--vis_type', choices=['masks', 'polygons', 'both'],
                       default='both', help='Visualization type (default: both)')
    parser.add_argument('--export_format', choices=['json', 'coco', 'labelme', 'all'],
                       default='json', help='Export format for predictions (default: json)')
    parser.add_argument('--include_polygons', action='store_true', default=True,
                       help='Convert masks to polygons (default: True)')

    # Polygon smoothing options
    parser.add_argument('--smooth_polygons', action='store_true', default=True,
                       help='Apply smoothing to polygons to reduce jaggedness (default: True)')
    parser.add_argument('--no_smooth', dest='smooth_polygons', action='store_false',
                       help='Disable polygon smoothing')
    parser.add_argument('--smoothing_factor', type=float, default=0.1,
                       help='Smoothing factor (0.0-1.0, 0=no smoothing, 1=maximum smoothing, default: 0.1)')

    # Legacy compatibility
    parser.add_argument('--save_json', action='store_true',
                       help='Save predictions as JSON (legacy, use --export_format json)')

    args = parser.parse_args()

    # Honor the legacy --save_json flag: write plain JSON even when another
    # --export_format was chosen.
    # BUGFIX: the old check (`args.save_json and not args.export_format`) was
    # dead code because --export_format defaults to 'json' and is always truthy.
    save_plain_json = args.save_json or args.export_format in ('json', 'all')

    # Create output directory if needed
    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = None

    # Print header
    print('\n' + '='*80)
    print(f'{BOLD}BEAN DETECTION{RESET}')
    print('='*80)

    # Initialize predictor
    print(f'\n{BOLD}Model:{RESET} {Path(args.model).name}')

    predictor = BeanPredictor(
        model_path=Path(args.model),
        device=args.device,
        max_detections=args.max_detections,
        confidence_threshold=args.confidence,
        mask_threshold=args.mask_threshold,
        nms_threshold=args.nms_threshold,
        # BUGFIX: previously `args.smooth_polygons or args.smoothing_factor > 0`,
        # which made --no_smooth a no-op because the factor defaults to 0.1.
        smooth_polygons=args.smooth_polygons,
        smoothing_factor=args.smoothing_factor,
        apply_nms=args.apply_nms,
        nms_type=args.nms_type,
        filter_edge_beans=args.filter_edge_beans,
        edge_threshold=args.edge_threshold,
        min_bean_area=args.min_bean_area,
        max_bean_area=args.max_bean_area
    )

    # Initialize visualizer if needed (None otherwise, so misuse fails loudly)
    visualizer = None
    if args.visualize:
        visualizer = DetectionVisualizer(confidence_threshold=args.confidence)

    # Initialize exporters if needed
    coco_exporter = COCOExporter("bean_predictions") if args.export_format in ['coco', 'all'] else None
    labelme_exporter = LabelMeExporter() if args.export_format in ['labelme', 'all'] else None

    # Process images
    all_results = []
    total_beans = 0
    total_time = 0

    # Print processing header
    if len(args.images) > 1:
        print(f'\n{BOLD}Processing {len(args.images)} images...{RESET}')
    else:
        print(f'\n{BOLD}Processing image...{RESET}')

    for image_path in args.images:
        image_path = Path(image_path)

        if not image_path.exists():
            print(f"  [!] {image_path.name}: not found")
            continue

        # Run prediction - always include polygons for better analysis
        result = predictor.predict(
            image_path,
            return_polygons=True,  # Always return polygons
            return_masks=True
        )

        # Print results (verbose block for a single image, one line otherwise)
        if len(args.images) == 1:
            print(f'\n{BOLD}Results:{RESET}')
            print(f'  Image: {image_path.name}')
            print(f'  Beans detected: {result["bean_count"]}')
            print(f'  Inference time: {result["inference_time"]:.2f}s')
        else:
            print(f'  {image_path.name}: {result["bean_count"]} beans ({result["inference_time"]:.1f}s)')

        total_beans += result['bean_count']
        total_time += result['inference_time']

        # Visualize if requested (silent)
        if args.visualize and output_dir:
            if args.vis_type in ['masks', 'both']:
                # Use legacy naming for backward compatibility
                mask_vis_path = output_dir / f"{image_path.stem}_prediction.png"
                visualizer.visualize_masks_with_confidence(
                    image_path,
                    result,
                    mask_vis_path,
                    mask_threshold=args.mask_threshold
                )

            if args.vis_type in ['polygons', 'both'] and 'polygons' in result:
                poly_vis_path = output_dir / f"{image_path.stem}_poly_vis.png"
                visualizer.visualize_polygons(
                    image_path,
                    result,
                    poly_vis_path
                )

        # Add to exporters
        if coco_exporter:
            img_id = coco_exporter.add_image(
                image_path,
                result['image_size'][0],
                result['image_size'][1]
            )
            coco_exporter.add_predictions(result, img_id)

        if labelme_exporter and output_dir:
            labelme_path = output_dir / f"{image_path.stem}_labelme.json"
            labelme_exporter.save(image_path, result, labelme_path)
            # Silent save

        # Store result (without tensor data for JSON export)
        json_result = {
            'image_path': result['image_path'],
            'image_size': result['image_size'],
            'inference_time': result['inference_time'],
            'bean_count': result['bean_count'],
            'confidence_threshold': result['confidence_threshold'],
            'total_detections': result['total_detections'],
            'filtered_detections': result['filtered_detections'],
            'predictions': {
                'boxes': result['boxes'],
                'scores': result['scores'],
                'labels': result['labels']
            }
        }

        # Rename for backward compatibility
        json_result['inference_time_seconds'] = json_result.pop('inference_time')

        if 'polygons' in result:
            # Keep polygons in their original format for proper COCO export
            # The format is: List[List[List[Tuple[float, float]]]]
            # Each detection has a list of polygons (usually just one)
            json_result['predictions']['polygons'] = result['polygons']

        all_results.append(json_result)

    # Save exports (silent)
    if output_dir:
        if coco_exporter:
            coco_path = output_dir / "predictions_coco.json"
            coco_exporter.save(coco_path)

        if save_plain_json:
            json_path = output_dir / "predictions.json"
            with open(json_path, 'w') as f:
                json.dump(all_results, f, indent=2)

    # Print summary
    if len(all_results) > 0:
        if len(all_results) > 1:
            print(f'\n{BOLD}Summary:{RESET}')
            avg_beans = total_beans / len(all_results)
            print(f'  Total images: {len(all_results)}')
            print(f'  Total beans: {total_beans}')
            print(f'  Average per image: {avg_beans:.0f}')
            print(f'  Total time: {total_time:.1f}s')

        # Show output directory
        if output_dir:
            print(f'\n{BOLD}Output directory:{RESET} {output_dir}/')

    print('\n' + '='*80)
    print()  # Add final newline


# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()