|
|
|
|
|
""" |
|
|
Example inference script for Heartformer model |
|
|
|
|
|
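This script demonstrates how to use the Heartformer model (an RF-DETR Nano
detector) to detect heart images and classify them by type, e.g. CT scan,
X-ray, cadaver specimen, cellular view, or textbook illustration.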
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# RF-DETR provides the detector class; fail fast with an actionable install
# hint (on stderr, with a non-zero exit status) instead of a bare traceback.
try:
    from rfdetr import RFDETRNano
except ImportError:
    print("❌ Error: RF-DETR not installed", file=sys.stderr)
    print(
        "Please install: pip install git+https://github.com/roboflow/rf-detr.git",
        file=sys.stderr,
    )
    sys.exit(1)
|
|
|
|
|
|
|
|
# Class labels, indexed by the integer class id the model predicts. The model
# head is sized with len(CLASS_NAMES), so the order here is significant.
# Entry 0 has no description and is skipped when reporting detections.
CLASS_NAMES = [
    "heart-anatomy-images",
    "heart_cadaver",
    "heart_cell",
    "heart_ct_scan",
    "heart_drawing",
    "heart_textbook",
    "heart_wall",
    "heart_xray",
]

# Human-readable explanation for every reportable class, keyed by class name.
CLASS_DESCRIPTIONS = dict(
    heart_cadaver="Real anatomical specimen from dissection",
    heart_cell="Microscopic/cellular view of cardiac tissue",
    heart_ct_scan="CT imaging of the heart",
    heart_drawing="Hand-drawn or digital medical illustration",
    heart_textbook="Educational anatomy image from textbooks",
    heart_wall="Cross-sectional view showing heart wall layers",
    heart_xray="Radiographic chest/heart image",
)

# RGB palette for box outlines and label backgrounds; a detection is drawn
# with COLORS[class_id % len(COLORS)].
COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 0, 128),
    (255, 128, 0),
]
|
|
|
|
|
|
|
|
def load_model(checkpoint_path):
    """Load the Heartformer model.

    Builds an RF-DETR Nano detector whose classification head has one output
    per entry in ``CLASS_NAMES`` and initialises it from ``checkpoint_path``.

    Args:
        checkpoint_path: Path to the trained checkpoint file, forwarded
            unchanged as ``pretrain_weights``.

    Returns:
        The ``RFDETRNano`` model instance.
    """
    print("🤖 Loading Heartformer model...")
    model = RFDETRNano(
        pretrain_weights=checkpoint_path,
        num_classes=len(CLASS_NAMES),
    )
    # This literal was previously split across two source lines by a
    # mis-encoded emoji, which made the whole file a SyntaxError.
    print("✅ Model loaded successfully")
    return model
|
|
|
|
|
|
|
|
def run_inference(model, image_path, threshold=0.3):
    """Run inference on an image and return the detections as plain dicts.

    Args:
        model: A loaded detector exposing ``predict(path, threshold=...)``,
            such as the one returned by ``load_model``.
        image_path: Path (``str`` or ``Path``) of the image to analyse.
        threshold: Minimum confidence, passed through to ``model.predict``.

    Returns:
        A list of dicts with keys ``class_id`` (int), ``class_name`` (str),
        ``confidence`` (float), ``bbox`` (``[x1, y1, x2, y2]`` as floats) and
        ``description`` (str, empty if the class has none). Detections of
        class 0 and detections whose class id falls outside ``CLASS_NAMES``
        are omitted.
    """
    print(f"\n🔍 Running inference on: {image_path}")
    print(f" Confidence threshold: {threshold}")

    detections = model.predict(str(image_path), threshold=threshold)

    results = []
    for bbox, conf, raw_class_id in zip(
        detections.xyxy,
        detections.confidence,
        detections.class_id,
    ):
        class_id = int(raw_class_id)

        # Class 0 ("heart-anatomy-images") is deliberately not reported;
        # presumably it is the dataset-level super-category, not an image type.
        if class_id == 0:
            continue

        # Out-of-range ids would raise IndexError, and negative ids would
        # silently pick a label from the end of CLASS_NAMES.
        if not 0 < class_id < len(CLASS_NAMES):
            print(f" ⚠️ Skipping detection with unknown class id {class_id}",
                  file=sys.stderr)
            continue

        class_name = CLASS_NAMES[class_id]
        results.append({
            "class_id": class_id,
            "class_name": class_name,
            "confidence": float(conf),
            "bbox": [float(x) for x in bbox],
            "description": CLASS_DESCRIPTIONS.get(class_name, ""),
        })

    return results
|
|
|
|
|
|
|
|
def visualize_results(image_path, results, output_path=None):
    """Draw labelled bounding boxes for ``results`` on the image.

    Args:
        image_path: Path of the source image.
        results: Detections in the format returned by ``run_inference``;
            each needs ``bbox``, ``class_id``, ``class_name`` and
            ``confidence``.
        output_path: If given, save the annotated image there; otherwise
            display it with ``Image.show``.

    Returns:
        The annotated RGB ``PIL.Image.Image``.
    """
    # Image.open is lazy and keeps the file open; convert() produces an
    # independent in-memory copy, so the file handle can be released here.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    draw = ImageDraw.Draw(image)

    for detection in results:
        x1, y1, x2, y2 = detection['bbox']
        color = COLORS[detection['class_id'] % len(COLORS)]

        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        label = f"{detection['class_name']}: {detection['confidence']:.2f}"
        # Put the label 20 px above the box, clamped so it stays inside the
        # image for boxes near the top edge. The background is measured at
        # the same anchor as the text so the white text sits on the fill.
        text_pos = (x1, max(y1 - 20, 0))
        draw.rectangle(draw.textbbox(text_pos, label), fill=color)
        draw.text(text_pos, label, fill=(255, 255, 255))

    if output_path:
        image.save(output_path)
        print(f"💾 Saved visualization to: {output_path}")
    else:
        image.show()

    return image
|
|
|
|
|
|
|
|
def _build_arg_parser():
    """Return the argument parser for the command-line interface."""
    import argparse

    parser = argparse.ArgumentParser(description="Run Heartformer inference")
    parser.add_argument("image", help="Path to input image")
    parser.add_argument(
        "--checkpoint",
        default="checkpoint_best_ema.pth",
        help="Path to model checkpoint",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.3,
        help="Confidence threshold (default: 0.3)",
    )
    parser.add_argument(
        "--output",
        help="Path to save visualization (default: show in window)",
    )
    parser.add_argument(
        "--json",
        help="Path to save detection results as JSON",
    )
    return parser


def _print_detections(results):
    """Print a numbered, human-readable summary of ``results``."""
    print(f"\n🎯 Found {len(results)} detection(s):")
    print("-" * 60)
    for i, det in enumerate(results, 1):
        x1, y1, x2, y2 = det['bbox']
        print(f"\n{i}. {det['class_name']}")
        print(f" Confidence: {det['confidence']:.1%}")
        print(f" BBox: [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]")
        print(f" {det['description']}")


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line.

    Returns:
        Process exit status. ``0`` on success; ``1`` when the image or the
        checkpoint file does not exist.
    """
    args = _build_arg_parser().parse_args(argv)

    # Validate inputs before loading the model. is_file() also rejects
    # directories, which exists() would accept.
    if not Path(args.image).is_file():
        print(f"❌ Error: Image not found: {args.image}", file=sys.stderr)
        return 1

    if not Path(args.checkpoint).is_file():
        print(f"❌ Error: Checkpoint not found: {args.checkpoint}",
              file=sys.stderr)
        print("\n💡 Download the checkpoint from:", file=sys.stderr)
        print(" https://huggingface.co/giannisan/heartformer", file=sys.stderr)
        return 1

    model = load_model(args.checkpoint)
    results = run_inference(model, args.image, args.threshold)
    _print_detections(results)

    if args.json:
        with open(args.json, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"\n💾 Saved results to: {args.json}")

    if results:
        print("\n📊 Creating visualization...")
        visualize_results(args.image, results, args.output)
    else:
        print("\n⚠️ No detections found. Try lowering the threshold.")

    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit rather than the site-injected exit() helper, which is
    # missing when Python runs with -S or from a frozen executable.
    sys.exit(main())
|
|
|