# coffee-bean-maskrcnn / predict_beans.py
# Uploaded to the Hugging Face Hub by Kunitomi (commit 196c526, verified)
#!/usr/bin/env python3
"""
Bean detection prediction script
"""
# Standard library imports
import argparse
import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List
# Add src to path
sys.path.insert(0, str(Path(__file__).parent / "src"))
# Local imports
from bean_vision.export import COCOExporter, LabelMeExporter
from bean_vision.inference import BeanPredictor
from bean_vision.visualization.detection_viz import DetectionVisualizer
# ANSI escape codes for terminal formatting
BOLD = '\033[1m'
RESET = '\033[0m'
def _build_parser():
    """Build the command-line argument parser for the bean-detection CLI."""
    parser = argparse.ArgumentParser(description='Bean detection using trained MaskR-CNN')
    # Model and input arguments
    parser.add_argument('--model', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--images', nargs='+', required=True,
                        help='Input image paths (can use wildcards)')
    # Detection parameters
    parser.add_argument('--confidence', '--threshold', type=float, default=0.5,
                        dest='confidence',
                        help='Confidence threshold for detections')
    parser.add_argument('--max_detections', type=int, default=500,
                        help='Maximum detections per image')
    parser.add_argument('--mask_threshold', type=float, default=0.5,
                        help='Threshold for mask binarization')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device to use (cpu or cuda)')
    # NMS parameters (--apply_nms / --no_nms share the 'apply_nms' dest)
    parser.add_argument('--apply_nms', action='store_true', default=True,
                        help='Apply Non-Maximum Suppression to remove overlapping detections (default: True)')
    parser.add_argument('--no_nms', dest='apply_nms', action='store_false',
                        help='Disable NMS')
    parser.add_argument('--nms_type', choices=['box', 'mask'], default='box',
                        help='Type of NMS to apply (default: box - faster, mask - more accurate)')
    parser.add_argument('--nms_threshold', type=float, default=0.3,
                        help='IoU threshold for NMS (lower = more aggressive)')
    # Edge and size filtering
    parser.add_argument('--filter_edge_beans', action='store_true', default=True,
                        help='Filter out partial beans at image edges (default: True)')
    parser.add_argument('--no_edge_filter', dest='filter_edge_beans', action='store_false',
                        help='Disable edge bean filtering')
    parser.add_argument('--edge_threshold', type=int, default=10,
                        help='Pixel distance from edge to consider for filtering')
    parser.add_argument('--min_bean_area', type=float, default=500,
                        help='Minimum bean area in pixels')
    parser.add_argument('--max_bean_area', type=float, default=30000,
                        help='Maximum bean area in pixels')
    # Output options
    parser.add_argument('--output_dir', type=str, default='results',
                        help='Directory to save outputs (default: results)')
    parser.add_argument('--visualize', action='store_true', default=True,
                        help='Create visualization images (default: True)')
    parser.add_argument('--no_visualize', dest='visualize', action='store_false',
                        help='Disable visualization')
    parser.add_argument('--vis_type', choices=['masks', 'polygons', 'both'],
                        default='both', help='Visualization type (default: both)')
    parser.add_argument('--export_format', choices=['json', 'coco', 'labelme', 'all'],
                        default='json', help='Export format for predictions (default: json)')
    parser.add_argument('--include_polygons', action='store_true', default=True,
                        help='Convert masks to polygons (default: True)')
    # Polygon smoothing options
    parser.add_argument('--smooth_polygons', action='store_true', default=True,
                        help='Apply smoothing to polygons to reduce jaggedness (default: True)')
    parser.add_argument('--no_smooth', dest='smooth_polygons', action='store_false',
                        help='Disable polygon smoothing')
    parser.add_argument('--smoothing_factor', type=float, default=0.1,
                        help='Smoothing factor (0.0-1.0, 0=no smoothing, 1=maximum smoothing, default: 0.1)')
    # Legacy compatibility
    parser.add_argument('--save_json', action='store_true',
                        help='Save predictions as JSON (legacy, use --export_format json)')
    return parser


def _save_visualizations(visualizer, image_path, result, output_dir, vis_type, mask_threshold):
    """Write mask and/or polygon overlay images for one prediction (silent)."""
    if vis_type in ['masks', 'both']:
        # Keep the legacy "_prediction.png" name for backward compatibility
        visualizer.visualize_masks_with_confidence(
            image_path,
            result,
            output_dir / f"{image_path.stem}_prediction.png",
            mask_threshold=mask_threshold
        )
    if vis_type in ['polygons', 'both'] and 'polygons' in result:
        visualizer.visualize_polygons(
            image_path,
            result,
            output_dir / f"{image_path.stem}_poly_vis.png"
        )


def _to_json_result(result):
    """Reduce a prediction result to the JSON-serializable subset (no tensors).

    The timing key is emitted as 'inference_time_seconds' (last) to preserve
    the key order produced by the original build-then-rename code.
    """
    json_result = {
        'image_path': result['image_path'],
        'image_size': result['image_size'],
        'bean_count': result['bean_count'],
        'confidence_threshold': result['confidence_threshold'],
        'total_detections': result['total_detections'],
        'filtered_detections': result['filtered_detections'],
        'predictions': {
            'boxes': result['boxes'],
            'scores': result['scores'],
            'labels': result['labels']
        },
        'inference_time_seconds': result['inference_time'],
    }
    if 'polygons' in result:
        # Keep polygons in their original format for proper COCO export:
        # List[List[List[Tuple[float, float]]]] — each detection has a list
        # of polygons (usually just one).
        json_result['predictions']['polygons'] = result['polygons']
    return json_result


def main():
    """Run bean detection on the input images, then visualize and export.

    Parses CLI arguments, loads the trained Mask R-CNN checkpoint, runs
    inference per image, optionally writes visualization overlays, and
    exports predictions in the requested format(s) to --output_dir.
    """
    args = _build_parser().parse_args()
    # Create output directory if needed. The default 'results' is always
    # truthy, but keep the guard in case the default ever becomes empty.
    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = None
    # Print header
    print('\n' + '='*80)
    print(f'{BOLD}BEAN DETECTION{RESET}')
    print('='*80)
    # Initialize predictor
    print(f'\n{BOLD}Model:{RESET} {Path(args.model).name}')
    predictor = BeanPredictor(
        model_path=Path(args.model),
        device=args.device,
        max_detections=args.max_detections,
        confidence_threshold=args.confidence,
        mask_threshold=args.mask_threshold,
        nms_threshold=args.nms_threshold,
        # BUGFIX: was `args.smooth_polygons or args.smoothing_factor > 0`,
        # which made --no_smooth a no-op whenever smoothing_factor > 0
        # (the default is 0.1). Honor the flag directly.
        smooth_polygons=args.smooth_polygons,
        smoothing_factor=args.smoothing_factor,
        apply_nms=args.apply_nms,
        nms_type=args.nms_type,
        filter_edge_beans=args.filter_edge_beans,
        edge_threshold=args.edge_threshold,
        min_bean_area=args.min_bean_area,
        max_bean_area=args.max_bean_area
    )
    # Initialize visualizer/exporters only when they will be used
    visualizer = DetectionVisualizer(confidence_threshold=args.confidence) if args.visualize else None
    coco_exporter = COCOExporter("bean_predictions") if args.export_format in ['coco', 'all'] else None
    labelme_exporter = LabelMeExporter() if args.export_format in ['labelme', 'all'] else None
    # Process images
    all_results = []
    total_beans = 0
    total_time = 0
    # Print processing header
    if len(args.images) > 1:
        print(f'\n{BOLD}Processing {len(args.images)} images...{RESET}')
    else:
        print(f'\n{BOLD}Processing image...{RESET}')
    for raw_path in args.images:
        image_path = Path(raw_path)
        if not image_path.exists():
            print(f" [!] {image_path.name}: not found")
            continue
        # Run prediction - always include polygons for better analysis
        result = predictor.predict(
            image_path,
            return_polygons=True,  # Always return polygons
            return_masks=True
        )
        # Print results (detailed for a single image, compact per line otherwise)
        if len(args.images) == 1:
            print(f'\n{BOLD}Results:{RESET}')
            print(f' Image: {image_path.name}')
            print(f' Beans detected: {result["bean_count"]}')
            print(f' Inference time: {result["inference_time"]:.2f}s')
        else:
            print(f' {image_path.name}: {result["bean_count"]} beans ({result["inference_time"]:.1f}s)')
        total_beans += result['bean_count']
        total_time += result['inference_time']
        # Visualize if requested (silent)
        if visualizer is not None and output_dir:
            _save_visualizations(visualizer, image_path, result, output_dir,
                                 args.vis_type, args.mask_threshold)
        # Add to exporters
        if coco_exporter:
            img_id = coco_exporter.add_image(
                image_path,
                result['image_size'][0],
                result['image_size'][1]
            )
            coco_exporter.add_predictions(result, img_id)
        if labelme_exporter and output_dir:
            # Silent save of per-image LabelMe annotation
            labelme_exporter.save(image_path, result,
                                  output_dir / f"{image_path.stem}_labelme.json")
        # Store result (without tensor data for JSON export)
        all_results.append(_to_json_result(result))
    # Save exports (silent)
    if output_dir:
        if coco_exporter:
            coco_exporter.save(output_dir / "predictions_coco.json")
        # BUGFIX: the old legacy guard (`args.save_json and not
        # args.export_format`) could never fire because export_format
        # defaults to 'json'. Honor --save_json here so it also works
        # alongside --export_format coco/labelme.
        if args.export_format in ['json', 'all'] or args.save_json:
            json_path = output_dir / "predictions.json"
            with open(json_path, 'w') as f:
                json.dump(all_results, f, indent=2)
    # Print summary (only meaningful when more than one image was processed)
    if len(all_results) > 0:
        if len(all_results) > 1:
            print(f'\n{BOLD}Summary:{RESET}')
            avg_beans = total_beans / len(all_results)
            print(f' Total images: {len(all_results)}')
            print(f' Total beans: {total_beans}')
            print(f' Average per image: {avg_beans:.0f}')
            print(f' Total time: {total_time:.1f}s')
        # Show output directory
        if output_dir:
            print(f'\n{BOLD}Output directory:{RESET} {output_dir}/')
    print('\n' + '='*80)
    print()  # Add final newline
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()