Spaces:
Sleeping
Sleeping
| """ | |
| Training script for a SegFormer model for medical image segmentation | |
| """ | |
| import os | |
| import argparse | |
| from pathlib import Path | |
| import json | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
| from torch.optim import AdamW | |
| import torch.nn.functional as F | |
class MedicalSegmentationDataset(Dataset):
    """Paired image / mask dataset for semantic segmentation.

    Samples are every ``*.png`` file found directly in ``image_dir``. The mask
    for ``<stem>.png`` is looked up as ``<stem>_mask.png`` inside ``mask_dir``;
    an image without a mask file gets an all-zero mask.

    Args:
        image_dir: Directory containing the input PNG images.
        mask_dir: Directory containing the ``*_mask.png`` label images.
        image_size: Target ``(height, width)`` that both the image and its
            mask are resized to.
    """

    def __init__(self, image_dir, mask_dir, image_size=(288, 288)):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = image_size
        # Sorted so the sample order is deterministic across runs.
        self.image_paths = sorted(self.image_dir.glob("*.png"))
        # NOTE(review): the processor is created with its default resize
        # settings, so it may rescale the already-resized image again
        # (512x512 by default according to the transformers docs). Confirm
        # whether `do_resize=False` or `size=image_size` is intended here.
        self.processor = SegformerImageProcessor(do_reduce_labels=False)

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.image_paths)

    @staticmethod
    def _resize_mask(segmentation_maps, size):
        """Resize a 2-D label array with nearest-neighbour sampling.

        Args:
            segmentation_maps: 2-D numpy array of class indices.
            size: Target ``(height, width)``.

        Returns:
            A ``torch.int64`` tensor of shape ``size``.

        Raises:
            ValueError: If ``segmentation_maps`` is not 2-D (for example a
                multi-channel mask image).
        """
        labels = torch.from_numpy(segmentation_maps)
        if labels.ndim != 2:
            raise ValueError(
                f"Expected a single-channel 2-D mask, got shape {tuple(labels.shape)}"
            )
        # interpolate works on float (N, C, H, W) tensors and takes the target
        # size as (height, width); nearest mode keeps the label values intact.
        resized = F.interpolate(
            labels[None, None].float(),
            size=tuple(size),
            mode="nearest",
        )
        return resized[0, 0].long()

    def __getitem__(self, idx):
        """Load, resize and encode the sample at position ``idx``.

        Returns:
            The processor output for the image with its leading batch
            dimension removed, plus a ``"labels"`` entry holding the resized
            mask as a ``torch.int64`` tensor.
        """
        img_path = self.image_paths[idx]
        mask_path = self.mask_dir / f"{img_path.stem}_mask.png"

        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixel data has been read.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        if mask_path.exists():
            with Image.open(mask_path) as raw_mask:
                segmentation_maps = np.array(raw_mask)
        else:
            # No annotation file: treat every pixel as class 0.
            segmentation_maps = np.zeros((image.height, image.width), dtype=np.uint8)

        height, width = self.image_size
        # PIL expects (width, height) whereas torch expects (height, width).
        image = image.resize((width, height))
        mask_tensor = self._resize_mask(segmentation_maps, (height, width))

        encoded_inputs = self.processor(images=image, return_tensors="pt")
        # Drop the batch dimension added by the processor; the DataLoader
        # adds its own when collating.
        for tensor in encoded_inputs.values():
            tensor.squeeze_(0)
        encoded_inputs["labels"] = mask_tensor
        return encoded_inputs
class MedicalImageSegmentationTrainer:
    """Fine-tune a SegFormer model and save checkpoints plus a loss history.

    Args:
        args: Namespace-like object. The methods read ``output_dir``,
            ``train_images_dir``, ``train_masks_dir``, ``val_images_dir``,
            ``val_masks_dir``, ``batch_size``, ``num_workers``, ``epochs`` and
            ``learning_rate`` from it.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"🖥️ Device: {self.device}")
        print(f"📁 Output directory: {self.output_dir}")

    @staticmethod
    def _mean_loss(total_loss, num_batches):
        """Return the per-batch average of ``total_loss``.

        Raises:
            ValueError: If ``num_batches`` is zero, i.e. the loader produced
                no batches and an average is undefined.
        """
        if num_batches == 0:
            raise ValueError("Cannot average the loss: the data loader has no batches")
        return total_loss / num_batches

    def create_datasets(self):
        """Create the training and validation datasets.

        Returns:
            Tuple ``(train_dataset, val_dataset)``.

        Raises:
            ValueError: If either dataset contains no samples.
        """
        print("\n📊 Loading datasets...")
        train_dataset = MedicalSegmentationDataset(
            self.args.train_images_dir,
            self.args.train_masks_dir,
            image_size=(288, 288),
        )
        val_dataset = MedicalSegmentationDataset(
            self.args.val_images_dir,
            self.args.val_masks_dir,
            image_size=(288, 288),
        )
        print(f" Train dataset: {len(train_dataset)} samples")
        print(f" Val dataset: {len(val_dataset)} samples")

        # Fail fast with a readable message instead of an obscure sampler
        # error or a division by zero after the first epoch.
        for split, dataset, directory in (
            ("training", train_dataset, self.args.train_images_dir),
            ("validation", val_dataset, self.args.val_images_dir),
        ):
            if len(dataset) == 0:
                raise ValueError(f"No *.png images found for the {split} set in {directory}")
        return train_dataset, val_dataset

    def create_dataloaders(self, train_dataset, val_dataset):
        """Wrap both datasets in data loaders; only the training loader shuffles."""
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
        )
        return train_loader, val_loader

    def create_model(self):
        """Load the pretrained ``nvidia/mit-b0`` checkpoint with a 4-label head.

        Returns:
            The model, already moved to ``self.device``.
        """
        print("\n🧠 Loading SegFormer model...")
        model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=4,  # background + 3 organs
            id2label={0: "background", 1: "large_bowel", 2: "small_bowel", 3: "stomach"},
            label2id={"background": 0, "large_bowel": 1, "small_bowel": 2, "stomach": 3},
            # Lets loading proceed when checkpoint tensor shapes differ from
            # this 4-label configuration.
            ignore_mismatched_sizes=True,
        )
        model.to(self.device)
        print(f"✓ Model loaded ({sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters)")
        return model

    def train_epoch(self, model, train_loader, optimizer, epoch):
        """Run one training epoch.

        Args:
            model: Model to optimise; called with ``pixel_values`` and ``labels``.
            train_loader: Iterable of batches with those two keys.
            optimizer: Optimizer stepping the model parameters.
            epoch: Zero-based epoch index, used for the progress-bar label.

        Returns:
            Mean training loss per batch.
        """
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.args.epochs}")
        for batch in pbar:
            pixel_values = batch["pixel_values"].to(self.device)
            labels = batch["labels"].to(self.device)

            optimizer.zero_grad()
            loss = model(pixel_values=pixel_values, labels=labels).loss
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})
        return self._mean_loss(total_loss, len(train_loader))

    def validate(self, model, val_loader):
        """Evaluate on the validation set.

        Returns:
            Mean validation loss per batch.
        """
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                pixel_values = batch["pixel_values"].to(self.device)
                labels = batch["labels"].to(self.device)
                outputs = model(pixel_values=pixel_values, labels=labels)
                total_loss += outputs.loss.item()
        return self._mean_loss(total_loss, len(val_loader))

    def train(self):
        """Run the full training loop.

        Writes ``best_model`` (lowest validation loss so far), ``final_model``
        (weights after the last epoch) and ``training_history.json`` into
        ``self.output_dir``.
        """
        print("\n" + "="*60)
        print("🚀 Starting Training")
        print("="*60)

        # Data
        train_dataset, val_dataset = self.create_datasets()
        train_loader, val_loader = self.create_dataloaders(train_dataset, val_dataset)

        # Model
        model = self.create_model()

        # Optimizer with a cosine learning-rate schedule spanning all epochs
        optimizer = AdamW(model.parameters(), lr=self.args.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.args.epochs
        )

        best_val_loss = float('inf')
        history = {'train_loss': [], 'val_loss': []}
        for epoch in range(self.args.epochs):
            print(f"\n📌 Epoch {epoch+1}/{self.args.epochs}")

            train_loss = self.train_epoch(model, train_loader, optimizer, epoch)
            history['train_loss'].append(train_loss)
            print(f" Train Loss: {train_loss:.4f}")

            val_loss = self.validate(model, val_loader)
            history['val_loss'].append(val_loss)
            print(f" Val Loss: {val_loss:.4f}")

            # Keep a checkpoint of the best model seen so far.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model_path = self.output_dir / "best_model"
                model.save_pretrained(model_path)
                print(f" ✓ Best model saved to {model_path}")

            # The schedule advances once per epoch.
            scheduler.step()

        final_model_path = self.output_dir / "final_model"
        model.save_pretrained(final_model_path)

        with open(self.output_dir / "training_history.json", 'w', encoding="utf-8") as f:
            json.dump(history, f, indent=2)

        print("\n" + "="*60)
        print("✅ Training Complete!")
        print(f" Best Model: {self.output_dir / 'best_model'}")
        print(f" Final Model: {final_model_path}")
        print(f" History: {self.output_dir / 'training_history.json'}")
        print("="*60)
def main():
    """Parse CLI arguments, verify the dataset layout, and run training.

    The dataset root given by ``--data`` must contain the sub-directories
    ``train_images``, ``train_masks``, ``val_images`` and ``val_masks``.

    Returns:
        bool: ``True`` once training has finished, ``False`` if any required
        dataset directory is missing.
    """
    parser = argparse.ArgumentParser(description="Train medical image segmentation model")
    parser.add_argument("--data", type=str, default="./prepared_data",
                        help="Path to prepared dataset")
    parser.add_argument("--output-dir", type=str, default="./models",
                        help="Output directory for models")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Batch size")
    parser.add_argument("--learning-rate", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--num-workers", type=int, default=4,
                        help="Number of workers for dataloader")
    args = parser.parse_args()

    # Attach the derived dataset paths to args so the trainer can read them.
    args.train_images_dir = os.path.join(args.data, "train_images")
    args.train_masks_dir = os.path.join(args.data, "train_masks")
    args.val_images_dir = os.path.join(args.data, "val_images")
    args.val_masks_dir = os.path.join(args.data, "val_masks")

    # Verify the dataset layout before doing any expensive work.
    required_dirs = [
        args.train_images_dir,
        args.train_masks_dir,
        args.val_images_dir,
        args.val_masks_dir,
    ]
    for dir_path in required_dirs:
        # isdir rather than exists: a regular file that happens to carry the
        # expected name is not a usable dataset directory.
        if not os.path.isdir(dir_path):
            print(f"❌ Directory not found: {dir_path}")
            print("Please run prepare_dataset.py first")
            return False

    trainer = MedicalImageSegmentationTrainer(args)
    trainer.train()
    return True
if __name__ == "__main__":
    # Map main()'s boolean result onto the process exit status (0 = success).
    # SystemExit is raised directly because the bare `exit` helper is injected
    # by the `site` module for interactive use and is absent under `python -S`.
    success = main()
    raise SystemExit(0 if success else 1)