Spaces:
Sleeping
Sleeping
AutoDeploy
Fix: Python 3.8 compatibility (use Tuple from typing) + Gradio 4.48.1 security update
8f59aab
| """ | |
| Script test và đánh giá mô hình | |
| """ | |
| import os | |
| import argparse | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| import json | |
| from tqdm import tqdm | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
| from sklearn.metrics import confusion_matrix, jaccard_score, precision_score, recall_score | |
class MedicalImageSegmentationTester:
    """Test and evaluate a SegFormer semantic-segmentation checkpoint.

    Both the model (``SegformerForSemanticSegmentation``) and its preprocessing
    config (``SegformerImageProcessor``) are loaded from the same
    ``model_path``. The class provides:
    - single-image prediction;
    - dataset-level evaluation against ground-truth masks;
    - side-by-side prediction/confidence visualizations.
    """

    def __init__(self, model_path, device="auto"):
        """Load the model and processor, and move the model to ``device``.

        Args:
            model_path: Anything accepted by ``from_pretrained`` (a local
                directory or a hub id).
            device: ``"auto"`` selects CUDA when available and CPU otherwise.
                Any other value (``"cpu"``, ``"cuda"``, ``"cuda:1"``, or a
                ``torch.device``) is used as given.
        """
        self.device = self._resolve_device(device)
        print(f"🖥️ Device: {self.device}")
        print(f"📁 Loading model from: {model_path}")

        # Load the weights and put the model in inference mode.
        self.model = SegformerForSemanticSegmentation.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # The preprocessing config is stored alongside the model weights.
        self.processor = SegformerImageProcessor.from_pretrained(model_path)
        print("✓ Model loaded successfully")

    @staticmethod
    def _resolve_device(device):
        """Map the ``device`` constructor argument to a ``torch.device``.

        Only ``"auto"`` triggers detection. Every other value is passed
        straight to ``torch.device`` so that an explicit choice is honored.
        """
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    def predict_single(self, image_path, return_probs=False):
        """Predict a per-pixel class map for one image.

        Args:
            image_path: Path to an image file. It is converted to RGB before
                preprocessing.
            return_probs: When True, also return the softmax probabilities.

        Returns:
            ``pred_mask``, an (H, W) array of class indices at the image's
            original resolution. When ``return_probs`` is True, returns
            ``(pred_mask, probs)``, where ``probs`` is (num_classes, H, W).
        """
        # The context manager releases the file handle right away instead of
        # waiting for garbage collection.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        original_size = image.size[::-1]  # PIL gives (W, H); interpolate wants (H, W)

        inputs = self.processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(pixel_values=inputs["pixel_values"].to(self.device))
            logits = outputs.logits  # (batch, num_classes, h, w)

            # The model's logits are not at the input resolution, so resize
            # them back to the original image size before taking the argmax.
            upsampled_logits = F.interpolate(
                logits,
                size=original_size,
                mode="bilinear",
                align_corners=False,
            )
            pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

            if return_probs:
                probs = torch.softmax(upsampled_logits, dim=1)[0].cpu().numpy()
                return pred_mask, probs
        return pred_mask

    def evaluate_dataset(self, image_dir, mask_dir, output_dir=None):
        """Evaluate the model on every ``*.png`` image in ``image_dir``.

        The ground truth for ``<id>.png`` is read from
        ``mask_dir/<id>_mask.png``. Images without a matching mask are
        skipped.

        Args:
            image_dir: Directory of input images.
            mask_dir: Directory of ground-truth masks. The mask pixel values
                are compared directly with predicted class indices, so masks
                are assumed to store class ids (TODO confirm the dataset
                encoding).
            output_dir: If given, per-image prediction PNGs and
                ``evaluation_results.json`` are written here.

        Returns:
            A dict with two entries:
            - ``overall_metrics``: pooled over all pixels of all evaluated
              images; empty if no image/mask pair was found.
            - ``per_image_metrics``: one dict per evaluated image.
        """
        image_dir = Path(image_dir)
        mask_dir = Path(mask_dir)
        image_paths = sorted(image_dir.glob("*.png"))
        print(f"\n📊 Evaluating {len(image_paths)} images...")

        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        metrics_list = []
        # Keep one flattened array per image and concatenate once at the end.
        # Extending a Python list element by element is slow and memory-heavy.
        true_chunks = []
        pred_chunks = []

        for img_path in tqdm(image_paths):
            img_id = img_path.stem
            mask_path = mask_dir / f"{img_id}_mask.png"
            if not mask_path.exists():
                continue

            pred_mask = self.predict_single(img_path)
            with Image.open(mask_path) as mask_img:
                true_mask = np.array(mask_img)

            metrics = self.calculate_metrics(true_mask, pred_mask)
            metrics['image_id'] = img_id
            metrics_list.append(metrics)

            true_chunks.append(true_mask.ravel())
            pred_chunks.append(pred_mask.ravel())

            if output_dir:
                # Class ids are multiplied by 50, presumably so the classes
                # are visually distinct in an 8-bit grayscale PNG.
                pred_img = Image.fromarray((pred_mask * 50).astype(np.uint8))
                pred_img.save(output_dir / f"{img_id}_pred.png")

        if not metrics_list:
            print("⚠️ No image/mask pairs were found; overall metrics are empty.")
        overall_metrics = self._compute_overall_metrics(true_chunks, pred_chunks)

        self._print_report(overall_metrics, metrics_list)

        results = {
            'overall_metrics': overall_metrics,
            'per_image_metrics': metrics_list,
        }

        if output_dir:
            with open(output_dir / "evaluation_results.json", 'w') as f:
                json.dump(results, f, indent=2)
            print(f"\n✓ Results saved to {output_dir / 'evaluation_results.json'}")

        return results

    @staticmethod
    def _compute_overall_metrics(true_chunks, pred_chunks):
        """Compute pixel-pooled metrics from per-image flattened label arrays.

        Returns an empty dict when there is nothing to score, instead of
        passing empty input to sklearn.
        """
        if not true_chunks:
            return {}
        all_true = np.concatenate(true_chunks)
        all_pred = np.concatenate(pred_chunks)

        # NOTE: 'mIoU' is the support-weighted Jaccard (average='weighted'),
        # not an unweighted mean over classes. The key is kept for
        # compatibility with existing result consumers.
        overall_metrics = {
            'mIoU': jaccard_score(all_true, all_pred, average='weighted'),
            'precision': precision_score(all_true, all_pred, average='weighted', zero_division=0),
            'recall': recall_score(all_true, all_pred, average='weighted', zero_division=0),
        }

        # Per-class IoU: 1=large_bowel, 2=small_bowel, 3=stomach.
        # A class is skipped when it is absent from the ground truth.
        for class_id in range(1, 4):
            class_true = (all_true == class_id).astype(int)
            class_pred = (all_pred == class_id).astype(int)
            if class_true.sum() > 0:
                overall_metrics[f'class_{class_id}_IoU'] = jaccard_score(class_true, class_pred)
        return overall_metrics

    @staticmethod
    def _print_report(overall_metrics, metrics_list):
        """Print the overall metrics and the per-image mean/std to stdout."""
        print("\n" + "=" * 60)
        print("📈 Evaluation Results")
        print("=" * 60)

        print("\nOverall Metrics:")
        for metric, value in overall_metrics.items():
            print(f" {metric:20}: {value:.4f}")

        print(f"\nPer-image Statistics ({len(metrics_list)} images):")
        if metrics_list:
            for key in metrics_list[0]:
                if key != 'image_id':
                    values = [m[key] for m in metrics_list]
                    print(f" {key:20}: mean={np.mean(values):.4f}, std={np.std(values):.4f}")

    @staticmethod
    def calculate_metrics(true_mask, pred_mask):
        """Compute support-weighted IoU, precision and recall for one image.

        This is a staticmethod because it uses no instance state. It can be
        called both as ``self.calculate_metrics(...)`` and on the class.

        Args:
            true_mask: Ground-truth label array.
            pred_mask: Predicted label array with the same number of pixels.

        Returns:
            A dict with the keys ``'iou'``, ``'precision'`` and ``'recall'``.
        """
        y_true = true_mask.flatten()
        y_pred = pred_mask.flatten()
        return {
            'iou': jaccard_score(y_true, y_pred, average='weighted'),
            'precision': precision_score(y_true, y_pred, average='weighted', zero_division=0),
            'recall': recall_score(y_true, y_pred, average='weighted', zero_division=0),
        }

    def visualize_predictions(self, image_dir, mask_dir, output_dir, num_samples=5):
        """Save an image / prediction / confidence figure for some samples.

        Uses the first ``num_samples`` ``*.png`` files in ``image_dir``, in
        sorted order. Each figure is written to
        ``output_dir/<id>_visualization.png``.

        Args:
            image_dir: Directory of input images.
            mask_dir: Accepted for interface compatibility; currently unused.
            output_dir: Destination directory; created if missing.
            num_samples: Maximum number of images to visualize.
        """
        # Imported lazily so that matplotlib is only needed for visualization,
        # and so that a backend chosen by the caller (e.g. 'Agg') takes effect.
        import matplotlib.pyplot as plt

        image_dir = Path(image_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        image_paths = sorted(image_dir.glob("*.png"))[:num_samples]
        print(f"\n🎨 Visualizing {len(image_paths)} predictions...")

        for img_path in tqdm(image_paths):
            img_id = img_path.stem
            with Image.open(img_path) as img:
                image = img.convert("RGB")

            pred_mask, probs = self.predict_single(img_path, return_probs=True)

            # Three panels: original image, predicted mask, and confidence
            # (the max class probability per pixel).
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            try:
                axes[0].imshow(image)
                axes[0].set_title("Original Image")
                axes[0].axis('off')

                axes[1].imshow(pred_mask, cmap='viridis')
                axes[1].set_title("Prediction")
                axes[1].axis('off')

                confidence = np.max(probs, axis=0)
                axes[2].imshow(confidence, cmap='hot')
                axes[2].set_title("Confidence")
                axes[2].axis('off')

                fig.tight_layout()
                fig.savefig(output_dir / f"{img_id}_visualization.png", dpi=100, bbox_inches='tight')
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)

        print(f"✓ Visualizations saved to {output_dir}")
def main():
    """Command-line entry point: evaluate a model and optionally visualize.

    The tester (and therefore the model) is constructed before the data
    arguments are checked.

    Returns:
        True when evaluation ran. False when ``--test-images`` or
        ``--test-masks`` is missing.
    """
    parser = argparse.ArgumentParser(
        description="Test and evaluate medical image segmentation model"
    )
    parser.add_argument("--model", type=str, required=True,
                        help="Path to trained model")
    for flag, help_text in (
        ("--test-images", "Path to test images directory"),
        ("--test-masks", "Path to test masks directory"),
    ):
        parser.add_argument(flag, type=str, help=help_text)
    parser.add_argument("--output-dir", type=str, default="./test_results",
                        help="Output directory for results")
    parser.add_argument("--visualize", action="store_true",
                        help="Create visualizations")
    parser.add_argument("--num-samples", type=int, default=5,
                        help="Number of samples to visualize")
    cli = parser.parse_args()

    tester = MedicalImageSegmentationTester(cli.model)

    # Guard clause: both data directories are required for any evaluation.
    if not (cli.test_images and cli.test_masks):
        print("Please provide --test-images and --test-masks directories")
        return False

    tester.evaluate_dataset(cli.test_images, cli.test_masks, cli.output_dir)

    if cli.visualize:
        vis_dir = Path(cli.output_dir) / "visualizations"
        tester.visualize_predictions(cli.test_images, cli.test_masks, vis_dir, cli.num_samples)

    return True
if __name__ == "__main__":
    import sys

    import matplotlib
    # Non-interactive backend: figures are only saved to disk, never shown.
    # This must be set before pyplot is first imported.
    matplotlib.use('Agg')

    success = main()
    # Use sys.exit rather than the site-injected `exit` builtin. That builtin
    # is intended for the interactive shell and is missing under `python -S`.
    sys.exit(0 if success else 1)