#!/usr/bin/env python3
"""
Example inference script for Heartformer model

This script demonstrates how to use the Heartformer model to detect
heart anatomy types in images.
"""

import json
import sys
from pathlib import Path

from PIL import Image, ImageDraw, ImageFont

# You'll need to install rf-detr first:
# pip install git+https://github.com/roboflow/rf-detr.git
try:
    from rfdetr import RFDETRNano
except ImportError:
    print("āŒ Error: RF-DETR not installed")
    print("Please install: pip install git+https://github.com/roboflow/rf-detr.git")
    sys.exit(1)

# Class names (matching the model training)
CLASS_NAMES = [
    "heart-anatomy-images",  # Parent category at index 0
    "heart_cadaver",
    "heart_cell",
    "heart_ct_scan",
    "heart_drawing",
    "heart_textbook",
    "heart_wall",
    "heart_xray"
]

# Class descriptions (human-readable blurbs attached to each detection)
CLASS_DESCRIPTIONS = {
    "heart_cadaver": "Real anatomical specimen from dissection",
    "heart_cell": "Microscopic/cellular view of cardiac tissue",
    "heart_ct_scan": "CT imaging of the heart",
    "heart_drawing": "Hand-drawn or digital medical illustration",
    "heart_textbook": "Educational anatomy image from textbooks",
    "heart_wall": "Cross-sectional view showing heart wall layers",
    "heart_xray": "Radiographic chest/heart image"
}

# Colors for bounding boxes (indexed by class_id modulo len(COLORS))
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
    (255, 0, 255), (0, 255, 255), (128, 0, 128), (255, 128, 0)
]


def load_model(checkpoint_path):
    """Load the Heartformer model from a checkpoint file.

    Args:
        checkpoint_path: Path to the pretrained weights file.

    Returns:
        An RFDETRNano model ready for inference.
    """
    print("šŸ¤– Loading Heartformer model...")
    # NOTE(review): num_classes counts the parent category at index 0 as
    # well — this matches how the model was trained per the comment on
    # CLASS_NAMES; confirm against the training config if detections look off.
    model = RFDETRNano(
        pretrain_weights=checkpoint_path,
        num_classes=len(CLASS_NAMES)
    )
    print("āœ… Model loaded successfully")
    return model


def run_inference(model, image_path, threshold=0.3):
    """Run detection on one image and return parsed results.

    Args:
        model: Loaded RFDETRNano model.
        image_path: Path to the input image.
        threshold: Minimum confidence for a detection to be kept.

    Returns:
        A list of dicts with keys: class_id, class_name, confidence,
        bbox ([x1, y1, x2, y2] floats), and description. Detections of
        the parent category (class_id 0) are skipped.
    """
    print(f"\nšŸ” Running inference on: {image_path}")
    print(f"   Confidence threshold: {threshold}")

    # Run detection; the model returns parallel arrays of boxes,
    # confidences and class ids.
    detections = model.predict(str(image_path), threshold=threshold)

    # Parse results into plain-Python, JSON-serializable structures
    results = []
    for bbox, conf, class_id in zip(
        detections.xyxy,
        detections.confidence,
        detections.class_id
    ):
        class_id = int(class_id)

        # Skip the parent/super category — only leaf classes are reported
        if class_id == 0:
            continue

        class_name = CLASS_NAMES[class_id]
        results.append({
            "class_id": class_id,
            "class_name": class_name,
            "confidence": float(conf),
            "bbox": [float(x) for x in bbox],
            "description": CLASS_DESCRIPTIONS.get(class_name, "")
        })

    return results


def visualize_results(image_path, results, output_path=None):
    """Draw bounding boxes and labels on the image.

    Args:
        image_path: Path to the original input image.
        results: Detection dicts as produced by run_inference().
        output_path: If given, save the annotated image there; otherwise
            open it in the default image viewer.

    Returns:
        The annotated PIL Image.
    """
    image = Image.open(image_path).convert('RGB')
    draw = ImageDraw.Draw(image)

    for detection in results:
        x1, y1, x2, y2 = detection['bbox']
        color = COLORS[detection['class_id'] % len(COLORS)]

        # Draw bounding box
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Draw label above the box. Clamp to the top edge so labels on
        # boxes near y=0 stay visible. BUGFIX: the background rectangle
        # and the text are now anchored at the same point (the original
        # measured the text at (x1, y1) but drew it at (x1, y1 - 20),
        # so the text missed its filled background).
        label = f"{detection['class_name']}: {detection['confidence']:.2f}"
        text_y = max(y1 - 20, 0)
        text_bbox = draw.textbbox((x1, text_y), label)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, text_y), label, fill=(255, 255, 255))

    if output_path:
        image.save(output_path)
        print(f"šŸ’¾ Saved visualization to: {output_path}")
    else:
        image.show()

    return image


def main():
    """Main entry point; returns a process exit code (0 ok, 1 error)."""
    import argparse

    parser = argparse.ArgumentParser(description="Run Heartformer inference")
    parser.add_argument("image", help="Path to input image")
    parser.add_argument(
        "--checkpoint",
        default="checkpoint_best_ema.pth",
        help="Path to model checkpoint"
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.3,
        help="Confidence threshold (default: 0.3)"
    )
    parser.add_argument(
        "--output",
        help="Path to save visualization (default: show in window)"
    )
    parser.add_argument(
        "--json",
        help="Path to save detection results as JSON"
    )

    args = parser.parse_args()

    # Validate inputs before paying the cost of loading the model
    if not Path(args.image).exists():
        print(f"āŒ Error: Image not found: {args.image}")
        return 1

    if not Path(args.checkpoint).exists():
        print(f"āŒ Error: Checkpoint not found: {args.checkpoint}")
        print("\nšŸ’” Download the checkpoint from:")
        print("   https://huggingface.co/giannisan/heartformer")
        return 1

    # Load model
    model = load_model(args.checkpoint)

    # Run inference
    results = run_inference(model, args.image, args.threshold)

    # Print results
    print(f"\nšŸŽÆ Found {len(results)} detection(s):")
    print("-" * 60)

    for i, det in enumerate(results, 1):
        print(f"\n{i}. {det['class_name']}")
        print(f"   Confidence: {det['confidence']:.1%}")
        print(f"   BBox: [{det['bbox'][0]:.0f}, {det['bbox'][1]:.0f}, "
              f"{det['bbox'][2]:.0f}, {det['bbox'][3]:.0f}]")
        print(f"   {det['description']}")

    # Save JSON if requested (ensure_ascii=False keeps descriptions readable)
    if args.json:
        with open(args.json, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\nšŸ’¾ Saved results to: {args.json}")

    # Visualize
    if len(results) > 0:
        print("\nšŸ“Š Creating visualization...")
        visualize_results(args.image, results, args.output)
    else:
        print("\nāš ļø  No detections found. Try lowering the threshold.")

    return 0


if __name__ == "__main__":
    # sys.exit (not the site-module `exit`) is the idiomatic way to
    # propagate main()'s return code as the process exit status.
    sys.exit(main())