dakesan commited on
Commit ·
e363d75
1
Parent(s): 4e3239c
initial commit
Browse files- README.md +298 -192
- prepare_data.py +134 -0
README.md
CHANGED
|
@@ -3,197 +3,303 @@ library_name: transformers
|
|
| 3 |
tags: []
|
| 4 |
---
|
| 5 |
|
| 6 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
tags: []
|
| 4 |
---
|
| 5 |
|
| 6 |
+
# Fine-tuned ViT image classifier
|
| 7 |
+
|
| 8 |
+
This repository provides a fine-tuned Vision Transformer model for classifying leukemia patient peripheral blood mononuclear cells.
|
| 9 |
+
|
| 10 |
+
## Model Overview
|
| 11 |
+
|
| 12 |
+
- **Base Model**: google/vit-large-patch16-224-in21k
|
| 13 |
+
- **Task**: 5-class classification of leukemia cells
|
| 14 |
+
- **Input**: 224x224 pixel dual-channel fluorescence microscopy images (R: ch1, G: ch6)
|
| 15 |
+
- **Output**: Probability distribution over 5 classes
|
| 16 |
|
| 17 |
+
## Model Details and Performance
|
| 18 |
|
| 19 |
+
- **Architecture**: ViT-Large/16 (patch size 16x16)
|
| 20 |
+
- **Parameters**: ~307M
|
| 21 |
+
- **Accuracy**: 94.67% (evaluation dataset)
|
| 22 |
+
|
| 23 |
+
## Data Preparation
|
| 24 |
+
|
| 25 |
+
### Prerequisites for Data Processing
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
# Required libraries for image processing
|
| 29 |
+
pip install numpy pillow tifffile
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
### Data Processing Tool
|
| 33 |
+
|
| 34 |
+
`prepare_data.py` (at the repository root) is a lightweight script for preprocessing dual-channel (ch1, ch6) cell images.
|
| 35 |
+
Implemented primarily using standard libraries, it performs the following operations:
|
| 36 |
+
|
| 37 |
+
1. Detects ch1 and ch6 image pairs
|
| 38 |
+
2. Normalizes each channel (0-255 scaling)
|
| 39 |
+
3. Converts to RGB format (R: ch1, G: ch6, B: empty channel)
|
| 40 |
+
4. Saves to specified output directory
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
# Basic usage
|
| 44 |
+
python prepare_data.py input_dir output_dir
|
| 45 |
+
|
| 46 |
+
# Example with options
|
| 47 |
+
python prepare_data.py \
|
| 48 |
+
/path/to/raw_images \
|
| 49 |
+
/path/to/processed_images \
|
| 50 |
+
--workers 8 \
|
| 51 |
+
--recursive
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
#### Options
|
| 55 |
+
- `--workers`: Number of parallel workers (default: 4)
|
| 56 |
+
- `--recursive`: Process subdirectories recursively
|
| 57 |
+
|
| 58 |
+
#### Input Directory Structure
|
| 59 |
+
```
|
| 60 |
+
input_dir/
|
| 61 |
+
├── class1/
|
| 62 |
+
│ ├── ch1_1.tif
|
| 63 |
+
│ ├── ch6_1.tif
|
| 64 |
+
│ ├── ch1_2.tif
|
| 65 |
+
│ └── ch6_2.tif
|
| 66 |
+
└── class2/
|
| 67 |
+
├── ch1_1.tif
|
| 68 |
+
├── ch6_1.tif
|
| 69 |
+
...
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
#### Output Directory Structure
|
| 73 |
+
```
|
| 74 |
+
output_dir/
|
| 75 |
+
├── class1/
|
| 76 |
+
│ ├── merged_1.tif
|
| 77 |
+
│ └── merged_2.tif
|
| 78 |
+
└── class2/
|
| 79 |
+
├── merged_1.tif
|
| 80 |
+
...
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## Model Usage
|
| 84 |
+
|
| 85 |
+
### Prerequisites for Model
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# Required libraries for model inference
|
| 89 |
+
pip install torch torchvision transformers
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Usage Example
|
| 93 |
+
|
| 94 |
+
#### Single Image Inference
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
from transformers import ViTForImageClassification, ViTImageProcessor
|
| 98 |
+
import torch
|
| 99 |
+
from PIL import Image
|
| 100 |
+
|
| 101 |
+
# Load model and processor
|
| 102 |
+
model = ViTForImageClassification.from_pretrained("poprap/vit16L-FT-cellclassification")
|
| 103 |
+
processor = ViTImageProcessor.from_pretrained("poprap/vit16L-FT-cellclassification")
|
| 104 |
+
|
| 105 |
+
# Preprocess image
|
| 106 |
+
image = Image.open("cell_image.tif")
|
| 107 |
+
inputs = processor(images=image, return_tensors="pt")
|
| 108 |
+
|
| 109 |
+
# Inference
|
| 110 |
+
outputs = model(**inputs)
|
| 111 |
+
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 112 |
+
predicted_class = torch.argmax(probabilities, dim=-1).item()
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
#### Batch Processing and Evaluation
|
| 116 |
+
|
| 117 |
+
For batch processing and comprehensive evaluation metrics calculation:
|
| 118 |
+
|
| 119 |
+
```python
|
| 120 |
+
import torch
|
| 121 |
+
import numpy as np
|
| 122 |
+
import time
|
| 123 |
+
from pathlib import Path
|
| 124 |
+
from tqdm import tqdm
|
| 125 |
+
from torchvision import transforms, datasets
|
| 126 |
+
from torch.utils.data import DataLoader
|
| 127 |
+
from transformers import ViTForImageClassification, ViTImageProcessor
|
| 128 |
+
import matplotlib.pyplot as plt
|
| 129 |
+
import seaborn as sns
|
| 130 |
+
from sklearn.metrics import (
|
| 131 |
+
confusion_matrix, accuracy_score, recall_score,
|
| 132 |
+
precision_score, f1_score, roc_auc_score,
|
| 133 |
+
classification_report
|
| 134 |
+
)
|
| 135 |
+
from sklearn.preprocessing import label_binarize
|
| 136 |
+
|
| 137 |
+
# --- 1. データセット準備用関数 ---
|
| 138 |
+
def transform_function(feature_extractor, img):
|
| 139 |
+
resized = transforms.Resize((224, 224))(img)
|
| 140 |
+
encoded = feature_extractor(images=resized, return_tensors="pt")
|
| 141 |
+
return encoded["pixel_values"][0]
|
| 142 |
+
|
| 143 |
+
def collate_fn(batch):
|
| 144 |
+
pixel_values = torch.stack([item[0] for item in batch])
|
| 145 |
+
labels = torch.tensor([item[1] for item in batch])
|
| 146 |
+
return {"pixel_values": pixel_values, "labels": labels}
|
| 147 |
+
|
| 148 |
+
# --- 2. モデルとデータセットの準備 ---
|
| 149 |
+
# モデルの準備
|
| 150 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 151 |
+
model = ViTForImageClassification.from_pretrained("poprap/vit16L-FT-cellclassification")
|
| 152 |
+
feature_extractor = ViTImageProcessor.from_pretrained("poprap/vit16L-FT-cellclassification")
|
| 153 |
+
model.to(device)
|
| 154 |
+
|
| 155 |
+
# データセットとデータローダーの準備
|
| 156 |
+
eval_dir = Path("path/to/eval/data") # 評価データのパス
|
| 157 |
+
dataset = datasets.ImageFolder(
|
| 158 |
+
root=str(eval_dir),
|
| 159 |
+
transform=lambda img: transform_function(feature_extractor, img)
|
| 160 |
+
)
|
| 161 |
+
dataloader = DataLoader(
|
| 162 |
+
dataset,
|
| 163 |
+
batch_size=32,
|
| 164 |
+
shuffle=False,
|
| 165 |
+
collate_fn=collate_fn
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# --- 3. バッチ推論の実行 ---
|
| 169 |
+
model.eval()
|
| 170 |
+
all_preds = []
|
| 171 |
+
all_labels = []
|
| 172 |
+
all_probs = []
|
| 173 |
+
|
| 174 |
+
start_time = time.time()
|
| 175 |
+
|
| 176 |
+
with torch.no_grad():
|
| 177 |
+
for batch in tqdm(dataloader, desc="Evaluating"):
|
| 178 |
+
inputs = batch["pixel_values"].to(device)
|
| 179 |
+
labels = batch["labels"].to(device)
|
| 180 |
+
|
| 181 |
+
outputs = model(inputs)
|
| 182 |
+
logits = outputs.logits
|
| 183 |
+
probs = torch.softmax(logits, dim=1)
|
| 184 |
+
preds = torch.argmax(probs, dim=1)
|
| 185 |
+
|
| 186 |
+
all_preds.extend(preds.cpu().numpy())
|
| 187 |
+
all_labels.extend(labels.cpu().numpy())
|
| 188 |
+
all_probs.extend(probs.cpu().numpy())
|
| 189 |
+
|
| 190 |
+
end_time = time.time()
|
| 191 |
+
|
| 192 |
+
# --- 4. 性能指標の計算 ---
|
| 193 |
+
# 処理時間の計算
|
| 194 |
+
total_images = len(all_labels)
|
| 195 |
+
total_time = end_time - start_time
|
| 196 |
+
time_per_image = total_time / total_images
|
| 197 |
+
|
| 198 |
+
# 基本的な指標
|
| 199 |
+
cm = confusion_matrix(all_labels, all_preds)
|
| 200 |
+
accuracy = accuracy_score(all_labels, all_preds)
|
| 201 |
+
recall_weighted = recall_score(all_labels, all_preds, average="weighted")
|
| 202 |
+
precision_weighted = precision_score(all_labels, all_preds, average="weighted")
|
| 203 |
+
f1_weighted = f1_score(all_labels, all_preds, average="weighted")
|
| 204 |
+
|
| 205 |
+
# クラスごとのAUC計算
|
| 206 |
+
num_classes = len(dataset.classes)
|
| 207 |
+
all_labels_onehot = label_binarize(all_labels, classes=range(num_classes))
|
| 208 |
+
all_probs = np.array(all_probs)
|
| 209 |
+
|
| 210 |
+
auc_scores = {}
|
| 211 |
+
for class_idx in range(num_classes):
|
| 212 |
+
try:
|
| 213 |
+
auc = roc_auc_score(all_labels_onehot[:, class_idx], all_probs[:, class_idx])
|
| 214 |
+
auc_scores[dataset.classes[class_idx]] = auc
|
| 215 |
+
except ValueError:
|
| 216 |
+
auc_scores[dataset.classes[class_idx]] = None
|
| 217 |
+
|
| 218 |
+
# --- 5. 結果の可視化 ---
|
| 219 |
+
# Confusion Matrixの可視化
|
| 220 |
+
plt.figure(figsize=(10, 8))
|
| 221 |
+
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
|
| 222 |
+
xticklabels=dataset.classes,
|
| 223 |
+
yticklabels=dataset.classes)
|
| 224 |
+
plt.xlabel("Predicted Label")
|
| 225 |
+
plt.ylabel("True Label")
|
| 226 |
+
plt.title("Confusion Matrix")
|
| 227 |
+
plt.tight_layout()
|
| 228 |
+
plt.show()
|
| 229 |
+
|
| 230 |
+
# 結果の出力
|
| 231 |
+
print(f"\nEvaluation Results:")
|
| 232 |
+
print(f"Accuracy: {accuracy:.4f}")
|
| 233 |
+
print(f"Weighted Recall: {recall_weighted:.4f}")
|
| 234 |
+
print(f"Weighted Precision: {precision_weighted:.4f}")
|
| 235 |
+
print(f"Weighted F1: {f1_weighted:.4f}")
|
| 236 |
+
print(f"\nAUC Scores per Class:")
|
| 237 |
+
for class_name, auc in auc_scores.items():
|
| 238 |
+
print(f"{class_name}: {auc:.4f}" if auc is not None else f"{class_name}: N/A")
|
| 239 |
+
|
| 240 |
+
print(f"\nDetailed Classification Report:")
|
| 241 |
+
print(classification_report(all_labels, all_preds, target_names=dataset.classes))
|
| 242 |
+
|
| 243 |
+
print(f"\nPerformance Metrics:")
|
| 244 |
+
print(f"Total images evaluated: {total_images}")
|
| 245 |
+
print(f"Total time: {total_time:.2f} seconds")
|
| 246 |
+
print(f"Average time per image: {time_per_image:.4f} seconds")
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
This example demonstrates how to:
|
| 250 |
+
1. Process multiple images in batches
|
| 251 |
+
2. Calculate comprehensive evaluation metrics
|
| 252 |
+
3. Generate confusion matrix visualization
|
| 253 |
+
4. Measure inference time performance
|
| 254 |
+
|
| 255 |
+
Key metrics calculated:
|
| 256 |
+
- Accuracy, Precision, Recall, F1-score
|
| 257 |
+
- Class-wise AUC scores
|
| 258 |
+
- Confusion matrix
|
| 259 |
+
- Detailed classification report
|
| 260 |
+
- Processing time statistics
|
| 261 |
+
|
| 262 |
+
## Training Configuration
|
| 263 |
+
|
| 264 |
+
The model was fine-tuned with the following settings:
|
| 265 |
+
|
| 266 |
+
### Hyperparameters
|
| 267 |
+
- Batch size: 56
|
| 268 |
+
- Learning rate: 1e-5
|
| 269 |
+
- Number of epochs: 20
|
| 270 |
+
- Mixed precision training (FP16)
|
| 271 |
+
- Label smoothing: 0.1
|
| 272 |
+
- Cosine scheduling with warmup (warmup steps: 100)
|
| 273 |
+
|
| 274 |
+
### Data Augmentation
|
| 275 |
+
- RandomResizedCrop (224x224, scale=(0.8, 1.0))
|
| 276 |
+
- RandomHorizontalFlip
|
| 277 |
+
- RandomRotation (±10 degrees)
|
| 278 |
+
- ColorJitter (brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
|
| 279 |
+
|
| 280 |
+
### Implementation Details
|
| 281 |
+
- Utilized HuggingFace Transformers' `Trainer` class
|
| 282 |
+
- Checkpoint saving: every 100 steps
|
| 283 |
+
- Evaluation: every 100 steps
|
| 284 |
+
- Logging: every 10 steps
|
| 285 |
+
|
| 286 |
+
## Data Source
|
| 287 |
+
|
| 288 |
+
This project uses data from the following research paper:
|
| 289 |
+
|
| 290 |
+
Phillip Eulenberg, Niklas Köhler, Thomas Blasi, Andrew Filby, Anne E. Carpenter, Paul Rees, Fabian J. Theis & F. Alexander Wolf. "Reconstructing cell cycle and disease progression using deep learning." Nature Communications volume 8, Article number: 463 (2017).
|
| 291 |
+
|
| 292 |
+
## License
|
| 293 |
+
|
| 294 |
+
This project is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0), inheriting the same license as the base Google Vision Transformer model.
|
| 295 |
+
|
| 296 |
+
## Citations
|
| 297 |
+
|
| 298 |
+
```bibtex
|
| 299 |
+
@misc{dosovitskiy2021vit,
|
| 300 |
+
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
|
| 301 |
+
author={Alexey Dosovitskiy and others},
|
| 302 |
+
year={2021},
|
| 303 |
+
eprint={2010.11929},
|
| 304 |
+
archivePrefix={arXiv}
|
| 305 |
+
}
|
prepare_data.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
白血病細胞画像の前処理用スクリプト
|
| 4 |
+
ch1とch6の画像をマージし、正規化してRGB形式で保存します。
|
| 5 |
+
|
| 6 |
+
必要なライブラリ:
|
| 7 |
+
- numpy
|
| 8 |
+
- tifffile (TIFFファイルの読み込み用)
|
| 9 |
+
- PIL (画像処理用)
|
| 10 |
+
|
| 11 |
+
使用方法:
|
| 12 |
+
python prepare_data.py input_dir output_dir [--workers N] [--recursive]
|
| 13 |
+
"""
|
| 14 |
+
import argparse
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import numpy as np
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import tifffile
|
| 19 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 20 |
+
import sys
|
| 21 |
+
from typing import Tuple, List
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
def setup_logger():
    """Configure root logging at INFO level and return this module's logger."""
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)
    return logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
def load_and_normalize(path: Path) -> np.ndarray:
    """Load a TIFF image and rescale its intensities to 8-bit (0-255).

    Args:
        path: Path to a single-channel TIFF file.

    Returns:
        The image as a ``uint8`` array, linearly scaled so the minimum
        pixel maps to 0 and the maximum to 255.

    Note:
        A constant image (``max == min``) is returned as all zeros; the
        original code divided by zero in that case.
    """
    img = tifffile.imread(str(path))
    lo = np.min(img)
    span = np.max(img) - lo
    if span == 0:
        # Constant image: avoid division by zero / NaN propagation.
        return np.zeros_like(img, dtype=np.uint8)
    img_norm = (img - lo) / span * 255
    return img_norm.astype(np.uint8)
|
| 39 |
+
|
| 40 |
+
def process_image_pair(paths: Tuple[Path, Path, Path]) -> bool:
    """
    Merge one (ch1, ch6) image pair into an RGB image and save it.

    Args:
        paths: (ch1_path, ch6_path, save_path) tuple; bundled into one
            argument so the function can be submitted to a process pool.

    Returns:
        True on success, False when any step raised (the error is logged).
    """
    ch1_path, ch6_path, save_path = paths
    try:
        # Load each channel and rescale it to 8-bit.
        arr1 = load_and_normalize(ch1_path)
        arr6 = load_and_normalize(ch6_path)

        # All-zero blue channel — only R and G carry data.
        empty_channel = np.zeros_like(arr1)

        # Merge into RGB (R: ch1, G: ch6, B: empty).
        merged_array = np.stack((arr1, arr6, empty_channel), axis=-1)
        merged_image = Image.fromarray(merged_array)

        # Save, creating the destination directory if necessary.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        merged_image.save(save_path)
        return True

    except Exception as e:
        # Broad catch is deliberate: one bad pair must not kill the worker.
        logging.error(f"Error processing {ch1_path}: {e}")
        return False
|
| 65 |
+
|
| 66 |
+
def find_image_pairs(input_dir: Path) -> List[Tuple[Path, Path]]:
    """Collect (ch1, ch6) TIFF pairs sharing the same index in *input_dir*.

    A ``ch1_<idx>.tif`` file is included only when the matching
    ``ch6_<idx>.tif`` exists alongside it.
    """
    candidates = (
        (c1, c1.parent / f"ch6_{c1.stem.split('_')[1]}.tif")
        for c1 in input_dir.glob("ch1_*.tif")
    )
    return [(c1, c6) for c1, c6 in candidates if c6.exists()]
|
| 77 |
+
|
| 78 |
+
def main():
    """Command-line entry point: merge ch1/ch6 TIFF pairs into RGB images.

    Parses input/output directories plus ``--workers`` / ``--recursive``
    options, discovers ch1/ch6 pairs per directory, and processes them in
    a process pool while mirroring the input layout under the output
    directory.
    """
    parser = argparse.ArgumentParser(description='細胞画像の前処理スクリプト')
    parser.add_argument('input_dir', type=str, help='入力ディレクトリのパス')
    parser.add_argument('output_dir', type=str, help='出力ディレクトリのパス')
    parser.add_argument('--workers', type=int, default=4, help='並列処理のワーカー数')
    parser.add_argument('--recursive', action='store_true', help='サブディレクトリも処理する')
    args = parser.parse_args()

    logger = setup_logger()
    input_path = Path(args.input_dir)
    output_path = Path(args.output_dir)

    if not input_path.exists():
        logger.error(f"入力ディレクトリが存在しません: {args.input_dir}")
        sys.exit(1)

    # Directories to scan. BUG FIX: the original recursive mode used only
    # input_path.glob("**/*"), which yields descendants but not input_path
    # itself, silently skipping images placed directly in input_dir.
    if args.recursive:
        candidate_dirs = [input_path, *input_path.glob("**/*")]
    else:
        candidate_dirs = [input_path]
    target_dirs = [d for d in candidate_dirs if d.is_dir()]

    total_processed = 0
    total_failed = 0

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        for current_dir in target_dirs:
            # Find ch1/ch6 pairs in this directory; skip dirs without any.
            pairs = find_image_pairs(current_dir)
            if not pairs:
                continue

            # Mirror the input layout under output_path.
            rel_path = current_dir.relative_to(input_path)
            current_output_dir = output_path / rel_path

            # One (ch1, ch6, destination) task per detected pair.
            tasks = [
                (ch1_file, ch6_file, current_output_dir / f"merged_{ch1_file.stem.split('_')[1]}.tif")
                for ch1_file, ch6_file in pairs
            ]

            # Fan the tasks out to the worker pool.
            futures = [executor.submit(process_image_pair, task) for task in tasks]

            successful = sum(1 for future in futures if future.result())
            failed = len(futures) - successful

            total_processed += successful
            total_failed += failed

            logger.info(f"{current_dir.name}: {successful}/{len(pairs)} files processed successfully")

    logger.info(f"\n処理完了:")
    logger.info(f"成功: {total_processed}")
    logger.info(f"失敗: {total_failed}")


if __name__ == "__main__":
    main()
|