File size: 6,698 Bytes
8f59aab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Dataset preparation script: split into train/val/test and build masks from RLE encodings.
"""

import os
import json
import numpy as np
from pathlib import Path
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

class DatasetPreparator:
    """Prepare a segmentation dataset: train/val/test split plus RLE-decoded masks."""

    def __init__(self, data_dir="./data", output_dir="./prepared_data"):
        """Record source/output locations and create the output directory tree.

        Args:
            data_dir: directory expected to hold ``train_images/`` and
                (optionally) ``train_masks.csv``.
            output_dir: destination root; image/mask subdirectories are
                created per split.
        """
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # One image/mask directory pair per split.
        self.train_images_dir = self.output_dir / "train_images"
        self.train_masks_dir = self.output_dir / "train_masks"
        self.val_images_dir = self.output_dir / "val_images"
        self.val_masks_dir = self.output_dir / "val_masks"
        self.test_images_dir = self.output_dir / "test_images"
        self.test_masks_dir = self.output_dir / "test_masks"

        for dir_path in [self.train_images_dir, self.train_masks_dir,
                         self.val_images_dir, self.val_masks_dir,
                         self.test_images_dir, self.test_masks_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def rle_decode(mask_rle, shape=(137, 236)):
        """Decode a run-length-encoded string into a binary uint8 mask.

        Args:
            mask_rle: RLE string ``"start1 len1 start2 len2 ..."`` with
                1-indexed starts, or NaN for an empty annotation.
            shape: (height, width) of the output mask.

        Returns:
            np.uint8 array of shape ``shape`` with 1 on encoded pixels.
        """
        if pd.isna(mask_rle):
            # BUG FIX: return a 2-D mask so callers can index it exactly
            # like the decoded case (the old code returned a flat array,
            # which broke boolean indexing in create_segmentation_mask).
            return np.zeros(shape, dtype=np.uint8)

        s = mask_rle.split()
        # BUG FIX: the original zip/unpack comprehension produced one
        # scalar per run using only the starts — it crashed for any run
        # count other than two and misread the lengths. Take alternating
        # tokens instead.
        starts = np.asarray(s[0::2], dtype=int) - 1  # RLE starts are 1-indexed
        lengths = np.asarray(s[1::2], dtype=int)
        ends = starts + lengths

        img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        # reshape-then-transpose keeps the original column-wise pixel
        # ordering — NOTE(review): confirm this matches the dataset's
        # RLE convention (some datasets encode row-wise).
        return img.reshape(shape[::-1]).T

    def create_segmentation_mask(self, image_id, df_masks):
        """Build a multi-class mask from all RLE annotations of one image.

        Args:
            image_id: value matched against ``df_masks['id']``.
            df_masks: DataFrame with ``id``, ``organ``, ``segmentation`` columns.

        Returns:
            (137, 236) uint8 mask; 0 = background, 1 = large_bowel,
            2 = small_bowel, 3 = stomach. Later annotations overwrite
            earlier ones where they overlap.
        """
        height, width = 137, 236
        mask = np.zeros((height, width), dtype=np.uint8)

        # Class ids: 1=large_bowel, 2=small_bowel, 3=stomach; unknown organs map to 0.
        class_mapping = {'large_bowel': 1, 'small_bowel': 2, 'stomach': 3}

        for _, row in df_masks[df_masks['id'] == image_id].iterrows():
            organ_class = class_mapping.get(row['organ'], 0)
            if organ_class > 0:
                rle_mask = self.rle_decode(row['segmentation'], shape=(height, width))
                mask[rle_mask == 1] = organ_class

        return mask

    def process_dataset(self, train_size=0.8, val_size=0.1):
        """Split the dataset and write images + masks for each split.

        Args:
            train_size: fraction of all images kept for train (before the
                validation subset is carved out of it).
            val_size: fraction of the WHOLE dataset used for validation.

        Returns:
            True on success, False when ``train_images/`` is missing.
        """
        print("\n📊 Đang chuẩn bị dataset...")

        # 1. Locate source images; bail out early if the folder is absent.
        if (self.data_dir / "train_images").exists():
            train_images = sorted((self.data_dir / "train_images").glob("*.png"))
            print(f"✓ Tìm thấy {len(train_images)} ảnh huấn luyện")
        else:
            print("✗ Không tìm thấy thư mục train_images")
            return False

        # 2. Load RLE mask annotations when present.
        train_masks_csv = self.data_dir / "train_masks.csv"
        if train_masks_csv.exists():
            df_masks = pd.read_csv(train_masks_csv)
            print(f"✓ Load {len(df_masks)} mask annotations")
            has_masks = True
        else:
            print("⚠️  Không tìm thấy train_masks.csv, bỏ qua giải mã RLE")
            has_masks = False

        # 3. Train/val/test split. val_size is expressed relative to the
        # whole dataset, so rescale it against the train portion.
        image_ids = [img.stem for img in train_images]
        train_ids, test_ids = train_test_split(
            image_ids, test_size=(1 - train_size), random_state=42
        )
        train_ids, val_ids = train_test_split(
            train_ids, test_size=val_size / train_size, random_state=42
        )

        print(f"  Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

        # 4. Copy images and generate masks for each split.
        dataset_splits = {
            'train': (train_ids, self.train_images_dir, self.train_masks_dir),
            'val': (val_ids, self.val_images_dir, self.val_masks_dir),
            'test': (test_ids, self.test_images_dir, self.test_masks_dir)
        }

        for split_name, (ids, images_dir, masks_dir) in dataset_splits.items():
            print(f"\n  📁 Xử lý {split_name} set ({len(ids)} ảnh)...")

            for i, img_id in enumerate(ids):
                # Copy the image. BUG FIX: use a context manager so the
                # source file handle is closed (the old code leaked it).
                src_img = self.data_dir / "train_images" / f"{img_id}.png"
                if src_img.exists():
                    dst_img = images_dir / f"{img_id}.png"
                    with Image.open(src_img) as im:
                        im.save(dst_img)

                # Build and save the multi-class mask alongside the image.
                if has_masks:
                    mask = self.create_segmentation_mask(img_id, df_masks)
                    mask_img = Image.fromarray(mask)
                    mask_img.save(masks_dir / f"{img_id}_mask.png")

                # Progress report roughly every 20% (and on the first item).
                if (i + 1) % max(1, len(ids) // 5) == 0 or i == 0:
                    print(f"    → {i+1}/{len(ids)} hoàn thành")

        # 5. Persist the split so training runs are reproducible.
        split_info = {
            'train': train_ids,
            'val': val_ids,
            'test': test_ids
        }

        with open(self.output_dir / "split.json", 'w') as f:
            json.dump(split_info, f, indent=2)

        print(f"\n✓ Split info lưu tại: {self.output_dir / 'split.json'}")

        return True

    def get_dataset_statistics(self):
        """Print per-split image counts and total size in MB."""
        print("\n📈 Thống kê dataset:")

        for split_dir in [self.train_images_dir, self.val_images_dir, self.test_images_dir]:
            # BUG FIX: derive the split label from the directory itself
            # (e.g. "train_images"), not from its parent (the output root).
            split_name = split_dir.name.replace('_images', '')
            images = list(split_dir.glob("*.png"))
            num_images = len(images)
            total_size_mb = sum(f.stat().st_size for f in images) / (1024 * 1024)
            print(f"  {split_name:8} - {num_images:5} ảnh ({total_size_mb:8.2f} MB)")

def main():
    """Entry point: prepare the dataset, then report statistics.

    Returns:
        True when preparation succeeded, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("🎯 Dataset Preparation Tool")
    print(banner)

    preparator = DatasetPreparator(
        data_dir="./data",
        output_dir="./prepared_data"
    )

    # Guard clause: stop early when preparation fails.
    if not preparator.process_dataset():
        return False

    preparator.get_dataset_statistics()

    print("\n" + banner)
    print("✅ Dataset đã được chuẩn bị! Tiếp theo:")
    print("   python train.py --data ./prepared_data")
    print(banner)
    return True

if __name__ == "__main__":
    # Exit status mirrors the outcome: 0 on success, 1 on failure.
    raise SystemExit(0 if main() else 1)