# Medical_Image_Segmentation / prepare_dataset.py
# (AutoDeploy commit 8f59aab — "Python 3.8 compatibility + Gradio 4.48.1 security update")
"""
Script chuẩn bị dataset: chia train/val/test, tạo masks từ RLE encoding
"""
import os
import json
import numpy as np
from pathlib import Path
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')
class DatasetPreparator:
    """Prepare a segmentation dataset: split into train/val/test and decode RLE masks."""

    def __init__(self, data_dir="./data", output_dir="./prepared_data"):
        """Create the output directory tree for all three splits.

        Args:
            data_dir: directory expected to hold ``train_images/`` and
                (optionally) ``train_masks.csv``.
            output_dir: destination root for the prepared splits.
        """
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # One images/masks directory pair per split
        self.train_images_dir = self.output_dir / "train_images"
        self.train_masks_dir = self.output_dir / "train_masks"
        self.val_images_dir = self.output_dir / "val_images"
        self.val_masks_dir = self.output_dir / "val_masks"
        self.test_images_dir = self.output_dir / "test_images"
        self.test_masks_dir = self.output_dir / "test_masks"
        for dir_path in [self.train_images_dir, self.train_masks_dir,
                         self.val_images_dir, self.val_masks_dir,
                         self.test_images_dir, self.test_masks_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def rle_decode(mask_rle, shape=(137, 236)):
        """Decode a run-length-encoded annotation into a binary (H, W) mask.

        The RLE string alternates 1-based start positions and run lengths
        over the column-major flattened image. A missing annotation (NaN)
        decodes to an all-zero mask.

        Args:
            mask_rle: RLE string (``"start len start len ..."``) or NaN.
            shape: (height, width) of the target mask.

        Returns:
            ``np.uint8`` array of shape (height, width) with 1 inside runs.
        """
        if pd.isna(mask_rle):
            # BUGFIX: was a flat (H*W,) vector; callers boolean-index a
            # 2-D mask with the result, so the empty mask must be (H, W).
            return np.zeros(shape, dtype=np.uint8)
        s = mask_rle.split()
        # BUGFIX: the original zip-and-unpack produced one element per
        # (start, length) pair, so the 2-tuple unpack only succeeded for
        # RLE strings with exactly two runs. Even-indexed tokens are
        # starts, odd-indexed tokens are lengths.
        starts = np.asarray(s[0::2], dtype=int) - 1  # 1-based -> 0-based
        lengths = np.asarray(s[1::2], dtype=int)
        ends = starts + lengths
        img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        # Runs index the image column-major: reshape to (W, H), then
        # transpose back to (H, W).
        return img.reshape(shape[::-1]).T

    def create_segmentation_mask(self, image_id, df_masks):
        """Build a multi-class mask for one image from its RLE rows.

        Classes: 0=background, 1=large_bowel, 2=small_bowel, 3=stomach.
        Where organ runs overlap, later rows overwrite earlier ones.

        Args:
            image_id: value matched against ``df_masks['id']``.
            df_masks: DataFrame with columns ``id``, ``organ``, ``segmentation``.

        Returns:
            ``np.uint8`` array of shape (137, 236) with class labels.
        """
        height, width = 137, 236
        mask = np.zeros((height, width), dtype=np.uint8)
        class_mapping = {'large_bowel': 1, 'small_bowel': 2, 'stomach': 3}
        for idx, row in df_masks[df_masks['id'] == image_id].iterrows():
            organ_class = class_mapping.get(row['organ'], 0)
            if organ_class > 0:
                rle_mask = self.rle_decode(row['segmentation'], shape=(height, width))
                mask[rle_mask == 1] = organ_class
        return mask

    def process_dataset(self, train_size=0.8, val_size=0.1):
        """Split the dataset, copy images, decode masks and save split info.

        Args:
            train_size: fraction of all images used for training.
            val_size: fraction of all images used for validation
                (the remainder after train+val becomes the test set).

        Returns:
            True on success; False when the source image folder is missing.
        """
        print("\n📊 Đang chuẩn bị dataset...")
        # 1. Locate the source training images
        if (self.data_dir / "train_images").exists():
            train_images = sorted(list((self.data_dir / "train_images").glob("*.png")))
            print(f"✓ Tìm thấy {len(train_images)} ảnh huấn luyện")
        else:
            print("✗ Không tìm thấy thư mục train_images")
            return False
        # 2. Load RLE mask annotations if present
        train_masks_csv = self.data_dir / "train_masks.csv"
        if train_masks_csv.exists():
            df_masks = pd.read_csv(train_masks_csv)
            print(f"✓ Load {len(df_masks)} mask annotations")
            has_masks = True
        else:
            print("⚠️ Không tìm thấy train_masks.csv, bỏ qua giải mã RLE")
            has_masks = False
        # 3. Train/val/test split. val_size is a fraction of the whole
        # dataset, so it is rescaled to the post-test-split remainder.
        image_ids = [img.stem for img in train_images]
        train_ids, test_ids = train_test_split(
            image_ids, test_size=(1-train_size), random_state=42
        )
        train_ids, val_ids = train_test_split(
            train_ids, test_size=val_size/(train_size), random_state=42
        )
        print(f" Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")
        # 4. Copy images and write decoded masks per split
        dataset_splits = {
            'train': (train_ids, self.train_images_dir, self.train_masks_dir),
            'val': (val_ids, self.val_images_dir, self.val_masks_dir),
            'test': (test_ids, self.test_images_dir, self.test_masks_dir)
        }
        for split_name, (ids, images_dir, masks_dir) in dataset_splits.items():
            print(f"\n 📁 Xử lý {split_name} set ({len(ids)} ảnh)...")
            for i, img_id in enumerate(ids):
                # Copy the image (re-encoded through PIL)
                src_img = self.data_dir / "train_images" / f"{img_id}.png"
                if src_img.exists():
                    dst_img = images_dir / f"{img_id}.png"
                    Image.open(src_img).save(dst_img)
                # Decode and save the multi-class mask
                if has_masks:
                    mask = self.create_segmentation_mask(img_id, df_masks)
                    mask_img = Image.fromarray(mask)
                    mask_img.save(masks_dir / f"{img_id}_mask.png")
                # Progress report roughly every 20% (and on the first item)
                if (i + 1) % max(1, len(ids) // 5) == 0 or i == 0:
                    print(f" → {i+1}/{len(ids)} hoàn thành")
        # 5. Persist the split assignment for reproducibility
        split_info = {
            'train': train_ids,
            'val': val_ids,
            'test': test_ids
        }
        with open(self.output_dir / "split.json", 'w') as f:
            json.dump(split_info, f, indent=2)
        print(f"\n✓ Split info lưu tại: {self.output_dir / 'split.json'}")
        return True

    def get_dataset_statistics(self):
        """Print the image count and on-disk size of each prepared split."""
        print("\n📈 Thống kê dataset:")
        for split_dir in [self.train_images_dir, self.val_images_dir, self.test_images_dir]:
            # BUGFIX: the split name comes from the directory itself
            # ("train_images" -> "train"); the original used
            # split_dir.parent.name, which is the shared output directory
            # and printed the same wrong label for every split.
            split_name = split_dir.name.replace('_images', '')
            num_images = len(list(split_dir.glob("*.png")))
            total_size_mb = sum(f.stat().st_size for f in split_dir.glob("*.png")) / (1024*1024)
            print(f" {split_name:8} - {num_images:5} ảnh ({total_size_mb:8.2f} MB)")
def main():
    """Entry point: prepare the dataset and, on success, report statistics.

    Returns:
        True when preparation succeeded, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("🎯 Dataset Preparation Tool")
    print(banner)
    prep = DatasetPreparator(data_dir="./data", output_dir="./prepared_data")
    # Guard clause: bail out early if the source data is missing.
    if not prep.process_dataset():
        return False
    prep.get_dataset_statistics()
    print("\n" + banner)
    print("✅ Dataset đã được chuẩn bị! Tiếp theo:")
    print(" python train.py --data ./prepared_data")
    print(banner)
    return True
if __name__ == "__main__":
    # Map the boolean result onto a conventional process exit code.
    raise SystemExit(0 if main() else 1)