#!/usr/bin/env python3
"""
Bean detection prediction script.

Runs a trained Mask R-CNN bean detector over one or more images, prints a
per-image summary to the terminal, and optionally writes visualization
images plus predictions in JSON / COCO / LabelMe formats.
"""
# Standard library imports
import argparse
import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List

# Add src to path so the local bean_vision package resolves when the
# script is run from the repository root.
sys.path.insert(0, str(Path(__file__).parent / "src"))

# Local imports
from bean_vision.export import COCOExporter, LabelMeExporter
from bean_vision.inference import BeanPredictor
from bean_vision.visualization.detection_viz import DetectionVisualizer

# ANSI escape codes for terminal formatting
BOLD = '\033[1m'
RESET = '\033[0m'


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: model/input, detection, NMS, filtering, output options."""
    parser = argparse.ArgumentParser(description='Bean detection using trained MaskR-CNN')

    # Model and input arguments
    parser.add_argument('--model', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--images', nargs='+', required=True,
                        help='Input image paths (can use wildcards)')

    # Detection parameters
    parser.add_argument('--confidence', '--threshold', type=float, default=0.5,
                        dest='confidence',
                        help='Confidence threshold for detections')
    parser.add_argument('--max_detections', type=int, default=500,
                        help='Maximum detections per image')
    parser.add_argument('--mask_threshold', type=float, default=0.5,
                        help='Threshold for mask binarization')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device to use (cpu or cuda)')

    # NMS parameters (default-on flag with a paired --no_* disabler)
    parser.add_argument('--apply_nms', action='store_true', default=True,
                        help='Apply Non-Maximum Suppression to remove overlapping detections (default: True)')
    parser.add_argument('--no_nms', dest='apply_nms', action='store_false',
                        help='Disable NMS')
    parser.add_argument('--nms_type', choices=['box', 'mask'], default='box',
                        help='Type of NMS to apply (default: box - faster, mask - more accurate)')
    parser.add_argument('--nms_threshold', type=float, default=0.3,
                        help='IoU threshold for NMS (lower = more aggressive)')

    # Edge and size filtering
    parser.add_argument('--filter_edge_beans', action='store_true', default=True,
                        help='Filter out partial beans at image edges (default: True)')
    parser.add_argument('--no_edge_filter', dest='filter_edge_beans', action='store_false',
                        help='Disable edge bean filtering')
    parser.add_argument('--edge_threshold', type=int, default=10,
                        help='Pixel distance from edge to consider for filtering')
    parser.add_argument('--min_bean_area', type=float, default=500,
                        help='Minimum bean area in pixels')
    parser.add_argument('--max_bean_area', type=float, default=30000,
                        help='Maximum bean area in pixels')

    # Output options
    parser.add_argument('--output_dir', type=str, default='results',
                        help='Directory to save outputs (default: results)')
    parser.add_argument('--visualize', action='store_true', default=True,
                        help='Create visualization images (default: True)')
    parser.add_argument('--no_visualize', dest='visualize', action='store_false',
                        help='Disable visualization')
    parser.add_argument('--vis_type', choices=['masks', 'polygons', 'both'], default='both',
                        help='Visualization type (default: both)')
    parser.add_argument('--export_format', choices=['json', 'coco', 'labelme', 'all'], default='json',
                        help='Export format for predictions (default: json)')
    parser.add_argument('--include_polygons', action='store_true', default=True,
                        help='Convert masks to polygons (default: True)')

    # Polygon smoothing options
    parser.add_argument('--smooth_polygons', action='store_true', default=True,
                        help='Apply smoothing to polygons to reduce jaggedness (default: True)')
    parser.add_argument('--no_smooth', dest='smooth_polygons', action='store_false',
                        help='Disable polygon smoothing')
    parser.add_argument('--smoothing_factor', type=float, default=0.1,
                        help='Smoothing factor (0.0-1.0, 0=no smoothing, 1=maximum smoothing, default: 0.1)')

    # Legacy compatibility
    parser.add_argument('--save_json', action='store_true',
                        help='Save predictions as JSON (legacy, use --export_format json)')

    return parser


def _to_json_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Return a JSON-serializable copy of a prediction result (tensor data stripped)."""
    json_result = {
        'image_path': result['image_path'],
        'image_size': result['image_size'],
        'inference_time': result['inference_time'],
        'bean_count': result['bean_count'],
        'confidence_threshold': result['confidence_threshold'],
        'total_detections': result['total_detections'],
        'filtered_detections': result['filtered_detections'],
        'predictions': {
            'boxes': result['boxes'],
            'scores': result['scores'],
            'labels': result['labels']
        }
    }
    # Rename for backward compatibility
    json_result['inference_time_seconds'] = json_result.pop('inference_time')
    if 'polygons' in result:
        # Keep polygons in their original format for proper COCO export.
        # The format is: List[List[List[Tuple[float, float]]]] —
        # each detection has a list of polygons (usually just one).
        json_result['predictions']['polygons'] = result['polygons']
    return json_result


def main():
    """CLI entry point: parse args, run detection per image, export results."""
    args = _build_parser().parse_args()

    # Handle legacy save_json flag.
    # NOTE(review): --export_format defaults to 'json' (always truthy), so this
    # branch can never fire; retained only for backward compatibility.
    if args.save_json and not args.export_format:
        args.export_format = 'json'

    # Create output directory if needed
    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = None

    # Print header
    print('\n' + '=' * 80)
    print(f'{BOLD}BEAN DETECTION{RESET}')
    print('=' * 80)

    # Initialize predictor
    print(f'\n{BOLD}Model:{RESET} {Path(args.model).name}')
    predictor = BeanPredictor(
        model_path=Path(args.model),
        device=args.device,
        max_detections=args.max_detections,
        confidence_threshold=args.confidence,
        mask_threshold=args.mask_threshold,
        nms_threshold=args.nms_threshold,
        # BUG FIX: was `args.smooth_polygons or args.smoothing_factor > 0`,
        # which made --no_smooth a no-op whenever smoothing_factor > 0
        # (the default is 0.1). Honor the flag directly.
        smooth_polygons=args.smooth_polygons,
        smoothing_factor=args.smoothing_factor,
        apply_nms=args.apply_nms,
        nms_type=args.nms_type,
        filter_edge_beans=args.filter_edge_beans,
        edge_threshold=args.edge_threshold,
        min_bean_area=args.min_bean_area,
        max_bean_area=args.max_bean_area
    )

    # Initialize visualizer if needed (only referenced under args.visualize)
    if args.visualize:
        visualizer = DetectionVisualizer(confidence_threshold=args.confidence)

    # Initialize exporters if needed
    coco_exporter = COCOExporter("bean_predictions") if args.export_format in ['coco', 'all'] else None
    labelme_exporter = LabelMeExporter() if args.export_format in ['labelme', 'all'] else None

    # Process images
    all_results: List[Dict[str, Any]] = []
    total_beans = 0
    total_time = 0

    # Print processing header
    if len(args.images) > 1:
        print(f'\n{BOLD}Processing {len(args.images)} images...{RESET}')
    else:
        print(f'\n{BOLD}Processing image...{RESET}')

    for image_path in args.images:
        image_path = Path(image_path)
        if not image_path.exists():
            print(f"  [!] {image_path.name}: not found")
            continue

        # Run prediction - always include polygons for better analysis
        result = predictor.predict(
            image_path,
            return_polygons=True,  # Always return polygons
            return_masks=True
        )

        # Print results
        if len(args.images) == 1:
            print(f'\n{BOLD}Results:{RESET}')
            print(f'  Image: {image_path.name}')
            print(f'  Beans detected: {result["bean_count"]}')
            print(f'  Inference time: {result["inference_time"]:.2f}s')
        else:
            print(f'  {image_path.name}: {result["bean_count"]} beans ({result["inference_time"]:.1f}s)')

        total_beans += result['bean_count']
        total_time += result['inference_time']

        # Visualize if requested (silent)
        if args.visualize and output_dir:
            if args.vis_type in ['masks', 'both']:
                # Use legacy naming for backward compatibility
                mask_vis_path = output_dir / f"{image_path.stem}_prediction.png"
                visualizer.visualize_masks_with_confidence(
                    image_path, result, mask_vis_path,
                    mask_threshold=args.mask_threshold
                )
            if args.vis_type in ['polygons', 'both'] and 'polygons' in result:
                poly_vis_path = output_dir / f"{image_path.stem}_poly_vis.png"
                visualizer.visualize_polygons(
                    image_path, result, poly_vis_path
                )

        # Add to exporters
        if coco_exporter:
            img_id = coco_exporter.add_image(
                image_path, result['image_size'][0], result['image_size'][1]
            )
            coco_exporter.add_predictions(result, img_id)
        if labelme_exporter and output_dir:
            labelme_path = output_dir / f"{image_path.stem}_labelme.json"
            labelme_exporter.save(image_path, result, labelme_path)  # Silent save

        # Store result (without tensor data for JSON export)
        all_results.append(_to_json_result(result))

    # Save exports (silent)
    if output_dir:
        if coco_exporter:
            coco_path = output_dir / "predictions_coco.json"
            coco_exporter.save(coco_path)
        if args.export_format in ['json', 'all']:
            json_path = output_dir / "predictions.json"
            with open(json_path, 'w') as f:
                json.dump(all_results, f, indent=2)

    # Print summary (only meaningful for multi-image runs)
    if len(all_results) > 1:
        print(f'\n{BOLD}Summary:{RESET}')
        avg_beans = total_beans / len(all_results)
        print(f'  Total images: {len(all_results)}')
        print(f'  Total beans: {total_beans}')
        print(f'  Average per image: {avg_beans:.0f}')
        print(f'  Total time: {total_time:.1f}s')

    # Show output directory
    if output_dir:
        print(f'\n{BOLD}Output directory:{RESET} {output_dir}/')

    print('\n' + '=' * 80)
    print()  # Add final newline


if __name__ == "__main__":
    main()