"""Script for testing and evaluating a SegFormer medical image segmentation model."""
import os
import sys
import json
import argparse
from pathlib import Path

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from sklearn.metrics import confusion_matrix, jaccard_score, precision_score, recall_score


class MedicalImageSegmentationTester:
    """Load a trained SegFormer checkpoint and run inference / evaluation with it."""

    # Class ids that get an individual IoU entry in evaluate_dataset
    # (1=large_bowel, 2=small_bowel, 3=stomach).
    DEFAULT_CLASS_IDS = (1, 2, 3)

    def __init__(self, model_path, device="auto"):
        """Load the model and its image processor from ``model_path``.

        Args:
            model_path: Anything accepted by ``from_pretrained`` (local
                directory or hub id).
            device: ``"auto"`` selects CUDA when available and CPU otherwise;
                any other value (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"``) is
                passed to ``torch.device`` unchanged.
        """
        self.device = self._resolve_device(device)
        print(f"🖥️ Device: {self.device}")
        print(f"📁 Loading model from: {model_path}")

        self.model = SegformerForSemanticSegmentation.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        self.processor = SegformerImageProcessor.from_pretrained(model_path)
        print("✓ Model loaded successfully")

    @staticmethod
    def _resolve_device(device):
        """Translate the ``device`` constructor argument into a ``torch.device``."""
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Honour explicit requests; previously every non-"auto" value
        # (including "cuda") silently fell back to CPU.
        return torch.device(device)

    @staticmethod
    def _load_rgb(image_path):
        """Open an image file, convert it to RGB and close the file handle."""
        with Image.open(image_path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the file is closed.
            return img.convert("RGB")

    def _predict_image(self, image, return_probs=False):
        """Run the model on an already-loaded PIL RGB image.

        Returns ``pred_mask`` (an (H, W) array of argmax class ids at the
        image's own resolution) or ``(pred_mask, probs)`` where ``probs`` is
        the (num_classes, H, W) softmax over the channel axis.
        """
        original_size = image.size[::-1]  # PIL gives (W, H); interpolate wants (H, W)

        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(pixel_values=inputs["pixel_values"].to(self.device)).logits
            # The model's logits are not at the input resolution; resize them
            # back so the mask aligns pixel-for-pixel with the original image.
            upsampled_logits = F.interpolate(
                logits, size=original_size, mode="bilinear", align_corners=False
            )
            pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
            if not return_probs:
                return pred_mask
            probs = torch.softmax(upsampled_logits, dim=1)[0].cpu().numpy()
        return pred_mask, probs

    def predict_single(self, image_path, return_probs=False):
        """Predict a segmentation mask for one image file.

        Args:
            image_path: Path to an image readable by PIL (converted to RGB).
            return_probs: If True, also return per-class probabilities.

        Returns:
            ``pred_mask`` — (H, W) class ids at the original image size — or
            ``(pred_mask, probs)`` with ``probs`` shaped (num_classes, H, W)
            when ``return_probs`` is True.
        """
        return self._predict_image(self._load_rgb(image_path), return_probs=return_probs)

    @staticmethod
    def _overall_metrics(true_chunks, pred_chunks, class_ids):
        """Compute metrics pooled over every pixel of every evaluated image.

        Returns an empty dict when no image was evaluated.
        """
        if not true_chunks:
            return {}
        # One concatenation instead of a pixel-by-pixel Python list.
        all_true = np.concatenate(true_chunks)
        all_pred = np.concatenate(pred_chunks)

        overall = {
            # NOTE(review): 'mIoU' is the support-WEIGHTED average IoU, so it is
            # dominated by the most frequent labels (presumably background,
            # class 0) and is not the unweighted class mean "mIoU" usually
            # denotes. Kept unchanged so numbers stay comparable with past runs.
            'mIoU': float(jaccard_score(all_true, all_pred, average='weighted')),
            'precision': float(precision_score(all_true, all_pred, average='weighted', zero_division=0)),
            'recall': float(recall_score(all_true, all_pred, average='weighted', zero_division=0)),
        }

        for class_id in class_ids:
            class_true = all_true == class_id
            # Only report classes that actually occur in the ground truth.
            if class_true.any():
                class_pred = all_pred == class_id
                overall[f'class_{class_id}_IoU'] = float(
                    jaccard_score(class_true.astype(int), class_pred.astype(int))
                )
        return overall

    def evaluate_dataset(self, image_dir, mask_dir, output_dir=None, class_ids=None):
        """Evaluate the model on every ``*.png`` image in ``image_dir``.

        The ground truth for ``<stem>.png`` is ``mask_dir/<stem>_mask.png``;
        images without a matching mask are skipped (and counted). Mask pixel
        values are compared directly against predicted class ids.

        Args:
            image_dir: Directory containing the input PNG images.
            mask_dir: Directory containing the ``*_mask.png`` ground truth.
            output_dir: If given, per-image prediction PNGs and
                ``evaluation_results.json`` are written there.
            class_ids: Class ids that get an individual IoU entry; defaults to
                ``DEFAULT_CLASS_IDS``.

        Returns:
            ``{'overall_metrics': {...}, 'per_image_metrics': [...]}``.
        """
        image_dir = Path(image_dir)
        mask_dir = Path(mask_dir)
        if class_ids is None:
            class_ids = self.DEFAULT_CLASS_IDS

        image_paths = sorted(image_dir.glob("*.png"))
        print(f"\n📊 Evaluating {len(image_paths)} images...")

        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        metrics_list = []
        true_chunks = []
        pred_chunks = []
        skipped = 0

        for img_path in tqdm(image_paths):
            img_id = img_path.stem
            mask_path = mask_dir / f"{img_id}_mask.png"
            if not mask_path.exists():
                skipped += 1
                continue

            pred_mask = self.predict_single(img_path)
            with Image.open(mask_path) as mask_img:
                true_mask = np.array(mask_img)

            metrics = self.calculate_metrics(true_mask, pred_mask)
            metrics['image_id'] = img_id
            metrics_list.append(metrics)

            true_chunks.append(true_mask.ravel())
            pred_chunks.append(pred_mask.ravel())

            if output_dir:
                # Spread class ids over the 8-bit range so the PNG is viewable;
                # clip so ids above 5 saturate instead of wrapping around.
                scaled = np.clip(pred_mask * 50, 0, 255).astype(np.uint8)
                Image.fromarray(scaled).save(output_dir / f"{img_id}_pred.png")

        if skipped:
            print(f"⚠️ Skipped {skipped} images without a matching mask")

        overall_metrics = self._overall_metrics(true_chunks, pred_chunks, class_ids)

        print("\n" + "=" * 60)
        print("📈 Evaluation Results")
        print("=" * 60)

        print("\nOverall Metrics:")
        if not overall_metrics:
            print(" (no images with ground-truth masks were evaluated)")
        for metric, value in overall_metrics.items():
            print(f" {metric:20}: {value:.4f}")

        print(f"\nPer-image Statistics ({len(metrics_list)} images):")
        if metrics_list:
            for key in metrics_list[0]:
                if key == 'image_id':
                    continue
                values = [m[key] for m in metrics_list]
                print(f" {key:20}: mean={np.mean(values):.4f}, std={np.std(values):.4f}")

        results = {
            'overall_metrics': overall_metrics,
            'per_image_metrics': metrics_list,
        }

        if output_dir:
            results_path = output_dir / "evaluation_results.json"
            with open(results_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            print(f"\n✓ Results saved to {output_dir / 'evaluation_results.json'}")

        return results

    @staticmethod
    def calculate_metrics(true_mask, pred_mask):
        """Compute support-weighted IoU, precision and recall for one image.

        Args:
            true_mask: Ground-truth label array.
            pred_mask: Predicted label array with the same shape.

        Returns:
            ``{'iou': float, 'precision': float, 'recall': float}``.

        Raises:
            ValueError: If the two masks do not have the same shape.
        """
        true_mask = np.asarray(true_mask)
        pred_mask = np.asarray(pred_mask)
        if true_mask.shape != pred_mask.shape:
            # Catches e.g. RGB masks or transposed masks, which would otherwise
            # fail deep inside sklearn or be silently mis-scored.
            raise ValueError(
                f"Mask shape mismatch: ground truth {true_mask.shape} vs prediction {pred_mask.shape}"
            )
        y_true = true_mask.ravel()
        y_pred = pred_mask.ravel()
        return {
            'iou': float(jaccard_score(y_true, y_pred, average='weighted')),
            'precision': float(precision_score(y_true, y_pred, average='weighted', zero_division=0)),
            'recall': float(recall_score(y_true, y_pred, average='weighted', zero_division=0)),
        }

    def visualize_predictions(self, image_dir, mask_dir, output_dir, num_samples=5):
        """Save a figure per image: original, predicted mask, confidence map.

        Only the first ``num_samples`` ``*.png`` files of ``image_dir`` (sorted
        by name) are rendered. ``mask_dir`` is accepted for interface
        compatibility but is not used — ground truth is not drawn.
        """
        # Kept local so matplotlib is only required when visualizing and the
        # backend selected in __main__ is in effect before pyplot loads.
        import matplotlib.pyplot as plt

        image_dir = Path(image_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        image_paths = sorted(image_dir.glob("*.png"))[:num_samples]
        print(f"\n🎨 Visualizing {len(image_paths)} predictions...")

        for img_path in tqdm(image_paths):
            img_id = img_path.stem

            # Decode once and reuse for both the plot and the prediction.
            image = self._load_rgb(img_path)
            pred_mask, probs = self._predict_image(image, return_probs=True)
            confidence = np.max(probs, axis=0)  # max class probability per pixel

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            try:
                panels = (
                    (image, None, "Original Image"),
                    (pred_mask, 'viridis', "Prediction"),
                    (confidence, 'hot', "Confidence"),
                )
                for ax, (data, cmap, title) in zip(axes, panels):
                    ax.imshow(data, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')

                plt.tight_layout()
                plt.savefig(output_dir / f"{img_id}_visualization.png", dpi=100, bbox_inches='tight')
            finally:
                plt.close(fig)

        print(f"✓ Visualizations saved to {output_dir}")


def main():
    """Parse CLI arguments, evaluate the model and optionally visualize.

    Returns:
        True on success, False when the required test directories are missing.
    """
    parser = argparse.ArgumentParser(description="Test and evaluate medical image segmentation model")
    parser.add_argument("--model", type=str, required=True, help="Path to trained model")
    parser.add_argument("--test-images", type=str, help="Path to test images directory")
    parser.add_argument("--test-masks", type=str, help="Path to test masks directory")
    parser.add_argument("--output-dir", type=str, default="./test_results", help="Output directory for results")
    parser.add_argument("--visualize", action="store_true", help="Create visualizations")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to visualize")
    parser.add_argument("--device", type=str, default="auto", help="Device: auto, cpu, cuda, cuda:N")
    args = parser.parse_args()

    # Validate before loading the model so a usage error costs nothing.
    if not (args.test_images and args.test_masks):
        print("Please provide --test-images and --test-masks directories")
        return False

    tester = MedicalImageSegmentationTester(args.model, device=args.device)

    tester.evaluate_dataset(args.test_images, args.test_masks, args.output_dir)

    if args.visualize:
        tester.visualize_predictions(
            args.test_images,
            args.test_masks,
            Path(args.output_dir) / "visualizations",
            args.num_samples,
        )

    return True


if __name__ == "__main__":
    import matplotlib
    matplotlib.use('Agg')  # Non-interactive backend: figures are only written to disk

    sys.exit(0 if main() else 1)