"""Train a SegFormer model for medical image segmentation.

Expected layout under ``--data`` (verified in ``main``)::

    train_images/  train_masks/  val_images/  val_masks/

Each ``<id>.png`` image is paired with ``<id>_mask.png`` in the matching
masks directory.
"""
import os
import sys
import argparse
from pathlib import Path
import json

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from torch.optim import AdamW
import torch.nn.functional as F

# Default (height, width) that images and masks are resized to.
DEFAULT_IMAGE_SIZE = (288, 288)

# Class id -> name; label2id and num_labels are derived from this one mapping.
ID2LABEL = {0: "background", 1: "large_bowel", 2: "small_bowel", 3: "stomach"}


class MedicalSegmentationDataset(Dataset):
    """Pair ``<id>.png`` images with ``<id>_mask.png`` label masks.

    Args:
        image_dir: Directory scanned (non-recursively) for ``*.png`` images.
        mask_dir: Directory holding the ``<id>_mask.png`` masks.
        image_size: Target ``(height, width)`` for both image and mask.

    Each item is a dict with ``pixel_values`` (float tensor produced by
    ``SegformerImageProcessor``) and ``labels`` (long tensor of shape
    ``(height, width)``).
    """

    def __init__(self, image_dir, mask_dir, image_size=DEFAULT_IMAGE_SIZE):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = tuple(image_size)
        self.image_paths = sorted(self.image_dir.glob("*.png"))
        # The image is already resized to ``image_size`` in ``__getitem__``.
        # Disable the processor's own resize (default 512x512) so that
        # pixel_values keep the same spatial size as the labels. The
        # processor still rescales and normalizes.
        self.processor = SegformerImageProcessor(do_resize=False, do_reduce_labels=False)

    def __len__(self):
        return len(self.image_paths)

    def _load_mask(self, img_id, fallback_size):
        """Return the mask for ``img_id`` as a numpy array (all zeros if the file is missing).

        ``fallback_size`` is the PIL ``(width, height)`` of the source image,
        used only to shape the all-background mask.
        """
        mask_path = self.mask_dir / f"{img_id}_mask.png"
        if mask_path.exists():
            # NOTE(review): pixel values are used directly as class ids, so the
            # masks are presumably single-channel with values in 0..3.
            # Confirm this against the dataset preparation step.
            with Image.open(mask_path) as mask:
                return np.array(mask)
        # A missing mask is treated as "everything is background".
        width, height = fallback_size
        return np.zeros((height, width), dtype=np.uint8)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        height, width = self.image_size

        # Use ``with`` so the underlying file handle is closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        segmentation_map = self._load_mask(img_path.stem, image.size)

        # PIL's resize takes (width, height).
        image = image.resize((width, height))

        # F.interpolate takes (height, width). Nearest-neighbour keeps the
        # labels integral (no blending between class ids).
        mask_tensor = torch.from_numpy(segmentation_map).long()
        mask_tensor = F.interpolate(
            mask_tensor[None, None].float(), size=(height, width), mode="nearest"
        )[0, 0].long()

        encoded_inputs = self.processor(images=image, return_tensors="pt")
        # Drop the processor's batch dimension; the DataLoader re-adds it.
        item = {key: value.squeeze(0) for key, value in encoded_inputs.items()}
        item["labels"] = mask_tensor
        return item


class MedicalImageSegmentationTrainer:
    """Fine-tune SegFormer (``nvidia/mit-b0``) on the datasets named in ``args``.

    ``args`` must provide: ``output_dir``, ``epochs``, ``batch_size``,
    ``learning_rate``, ``num_workers`` and the four ``*_images_dir`` /
    ``*_masks_dir`` paths. ``image_size`` is optional.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"🖥️ Device: {self.device}")
        print(f"📁 Output directory: {self.output_dir}")

    def create_datasets(self):
        """Build the training and validation datasets.

        Raises:
            ValueError: if either split contains no ``*.png`` images. The
                per-epoch loss averages would otherwise divide by zero.
        """
        print("\n📊 Loading datasets...")
        image_size = tuple(getattr(self.args, "image_size", None) or DEFAULT_IMAGE_SIZE)
        train_dataset = MedicalSegmentationDataset(
            self.args.train_images_dir, self.args.train_masks_dir, image_size=image_size
        )
        val_dataset = MedicalSegmentationDataset(
            self.args.val_images_dir, self.args.val_masks_dir, image_size=image_size
        )
        print(f" Train dataset: {len(train_dataset)} samples")
        print(f" Val dataset: {len(val_dataset)} samples")

        for split, dataset in (("train", train_dataset), ("val", val_dataset)):
            if len(dataset) == 0:
                raise ValueError(
                    f"No *.png images found for the {split} split in {dataset.image_dir}"
                )
        return train_dataset, val_dataset

    def create_dataloaders(self, train_dataset, val_dataset):
        """Wrap the datasets in DataLoaders (only the training loader is shuffled)."""
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
        )
        return train_loader, val_loader

    def create_model(self):
        """Load SegFormer from ``nvidia/mit-b0`` with a head for ``len(ID2LABEL)`` classes."""
        print("\n🧠 Loading SegFormer model...")
        model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=len(ID2LABEL),  # background + 3 organs
            id2label=dict(ID2LABEL),
            label2id={name: idx for idx, name in ID2LABEL.items()},
            # The checkpoint's head does not match our label count; let it be
            # re-initialized instead of failing.
            ignore_mismatched_sizes=True,
        )
        model.to(self.device)
        print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters)")
        return model

    def train_epoch(self, model, train_loader, optimizer, epoch):
        """Train for one epoch and return the mean per-batch training loss.

        Uses the loss computed by the model when ``labels`` are passed.
        Gradients are clipped to a total norm of 1.0.
        """
        model.train()
        total_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.args.epochs}")
        for batch in pbar:
            pixel_values = batch["pixel_values"].to(self.device)
            labels = batch["labels"].to(self.device)

            optimizer.zero_grad()
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        return total_loss / len(train_loader)

    def validate(self, model, val_loader):
        """Return the mean per-batch loss on the validation set (no gradients)."""
        model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                pixel_values = batch["pixel_values"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = model(pixel_values=pixel_values, labels=labels)
                total_loss += outputs.loss.item()

        return total_loss / len(val_loader)

    def train(self):
        """Run the full training loop and write the artifacts to ``output_dir``.

        Writes ``best_model/`` (each time validation loss improves),
        ``final_model/`` (after the last epoch) and ``training_history.json``.
        """
        print("\n" + "="*60)
        print("🚀 Starting Training")
        print("="*60)

        train_dataset, val_dataset = self.create_datasets()
        train_loader, val_loader = self.create_dataloaders(train_dataset, val_dataset)

        model = self.create_model()

        optimizer = AdamW(model.parameters(), lr=self.args.learning_rate)
        # Stepped once per epoch (end of loop), so T_max=epochs spans the run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.args.epochs,
        )

        best_val_loss = float('inf')
        history = {'train_loss': [], 'val_loss': []}
        best_model_path = self.output_dir / "best_model"

        for epoch in range(self.args.epochs):
            print(f"\n📌 Epoch {epoch+1}/{self.args.epochs}")

            train_loss = self.train_epoch(model, train_loader, optimizer, epoch)
            history['train_loss'].append(train_loss)
            print(f" Train Loss: {train_loss:.4f}")

            val_loss = self.validate(model, val_loader)
            history['val_loss'].append(val_loss)
            print(f" Val Loss: {val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model.save_pretrained(best_model_path)
                print(f" ✓ Best model saved to {best_model_path}")

            scheduler.step()

        final_model_path = self.output_dir / "final_model"
        model.save_pretrained(final_model_path)

        with open(self.output_dir / "training_history.json", 'w', encoding="utf-8") as f:
            json.dump(history, f, indent=2)

        print("\n" + "="*60)
        print("✅ Training Complete!")
        print(f" Best Model: {self.output_dir / 'best_model'}")
        print(f" Final Model: {final_model_path}")
        print(f" History: {self.output_dir / 'training_history.json'}")
        print("="*60)


def main():
    """Parse CLI arguments, verify the dataset layout and run training.

    Returns:
        True on success; False if a dataset directory is missing.
    """
    parser = argparse.ArgumentParser(description="Train medical image segmentation model")
    parser.add_argument("--data", type=str, default="./prepared_data", help="Path to prepared dataset")
    parser.add_argument("--output-dir", type=str, default="./models", help="Output directory for models")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--num-workers", type=int, default=4, help="Number of workers for dataloader")
    parser.add_argument(
        "--image-size", type=int, nargs=2, default=list(DEFAULT_IMAGE_SIZE),
        metavar=("HEIGHT", "WIDTH"), help="Resize images and masks to HEIGHT WIDTH",
    )

    args = parser.parse_args()
    args.image_size = tuple(args.image_size)

    # Derive the per-split dataset directories from --data.
    args.train_images_dir = os.path.join(args.data, "train_images")
    args.train_masks_dir = os.path.join(args.data, "train_masks")
    args.val_images_dir = os.path.join(args.data, "val_images")
    args.val_masks_dir = os.path.join(args.data, "val_masks")

    # Fail fast if the prepared dataset is not in place.
    for dir_path in [args.train_images_dir, args.train_masks_dir, args.val_images_dir, args.val_masks_dir]:
        if not os.path.exists(dir_path):
            print(f"❌ Directory not found: {dir_path}")
            print("Please run prepare_dataset.py first")
            return False

    trainer = MedicalImageSegmentationTrainer(args)
    trainer.train()

    return True


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)