| | --- |
| | library_name: transformers |
| | tags: [] |
| | --- |
| | |
| | # Fine-tuned ViT image classifier |
| |
|
This repository provides a Vision Transformer (ViT) model fine-tuned to classify peripheral blood mononuclear cells (PBMCs) from leukemia patients.

## Model Overview

- **Base Model**: google/vit-large-patch16-224-in21k
- **Task**: 5-class classification of leukemia cells
- **Input**: 224x224-pixel dual-channel fluorescence microscopy images, packed into RGB (R: ch1, G: ch6)
- **Output**: Probability distribution over the 5 classes

## Performance

- **Architecture**: ViT-Large/16 (patch size 16x16)
- **Parameters**: ~307M
- **Accuracy**: 94.67% on the evaluation dataset

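As a rough check on the architecture figures above, the sketch below works through two things for ViT-L/16 at 224x224: the patch arithmetic and a back-of-the-envelope parameter count. The layer sizes (hidden size 1024, MLP size 4096, 24 encoder layers) are the standard ViT-Large settings and are assumed here, not read from this model's config. The count also assumes there is no pooler layer and that the head is a 5-way linear classifier.

```python
def vit_token_count(image_size: int = 224, patch_size: int = 16) -> int:
    # Non-overlapping patches per side, squared, plus one [CLS] token
    per_side = image_size // patch_size
    return per_side * per_side + 1


def vit_param_estimate(image_size: int = 224, patch_size: int = 16,
                       channels: int = 3, hidden: int = 1024, mlp: int = 4096,
                       layers: int = 24, num_labels: int = 5) -> int:
    # Rough parameter count: embeddings + encoder blocks + final norm + head
    tokens = vit_token_count(image_size, patch_size)
    patch_embed = channels * patch_size * patch_size * hidden + hidden  # conv weights + bias
    embeddings = patch_embed + tokens * hidden + hidden  # + position embeddings + [CLS]
    attention = 4 * (hidden * hidden + hidden)            # Q, K, V and output projections
    ffn = (hidden * mlp + mlp) + (mlp * hidden + hidden)  # two dense layers
    norms = 2 * (2 * hidden)                              # two LayerNorms (weight + bias)
    encoder = layers * (attention + ffn + norms)
    final_norm = 2 * hidden
    head = hidden * num_labels + num_labels               # linear classifier
    return embeddings + encoder + final_norm + head


print(vit_token_count())     # 197 tokens: 14 x 14 patches + [CLS]
print(vit_param_estimate())  # 303306757, i.e. roughly 303M
```

This estimate comes out at about 303M, close to the commonly quoted ~307M for ViT-L/16. Treat both numbers as approximate. For the exact figure, count directly on the loaded model with `sum(p.numel() for p in model.parameters())`.
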
## Data Preparation

### Prerequisites for Data Processing

```bash
# Required libraries for image processing
pip install numpy pillow tifffile
```

### Data Processing Tool

`tools/prepare_data.py` is a lightweight script for preprocessing dual-channel (ch1, ch6) cell images. It relies mainly on standard libraries and performs the following steps:

1. Detects matching ch1/ch6 image pairs
2. Normalizes each channel to the 0-255 range
3. Converts each pair to RGB (R: ch1, G: ch6, B: empty channel)
4. Saves the result to the specified output directory

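The four steps can be sketched with NumPy as follows. This is not the code of `tools/prepare_data.py`. It assumes that "0-255 scaling" means a per-image min-max normalization, and the function names `to_uint8` and `merge_channels` are hypothetical.

```python
import numpy as np


def to_uint8(channel: np.ndarray) -> np.ndarray:
    # Min-max scale one channel to 0-255 (assumed normalization scheme)
    channel = channel.astype(np.float32)
    lo, hi = channel.min(), channel.max()
    if hi == lo:  # flat image: avoid division by zero
        return np.zeros(channel.shape, dtype=np.uint8)
    scaled = (channel - lo) / (hi - lo) * 255.0
    return np.round(scaled).astype(np.uint8)


def merge_channels(ch1: np.ndarray, ch6: np.ndarray) -> np.ndarray:
    # Stack into an H x W x 3 array: R = ch1, G = ch6, B = zeros
    if ch1.shape != ch6.shape:
        raise ValueError(f"shape mismatch: {ch1.shape} vs {ch6.shape}")
    blue = np.zeros(ch1.shape, dtype=np.uint8)
    return np.stack([to_uint8(ch1), to_uint8(ch6), blue], axis=-1)
```

To apply this to real files, you could read each channel with `tifffile.imread(...)` and write the merged array with `PIL.Image.fromarray(rgb).save(...)`.
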
```bash
# Basic usage
python tools/prepare_data.py input_dir output_dir

# Example with options
python tools/prepare_data.py \
    /path/to/raw_images \
    /path/to/processed_images \
    --workers 8 \
    --recursive
```

#### Options
- `--workers`: Number of parallel workers (default: 4)
- `--recursive`: Process subdirectories recursively

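The script's source is not reproduced in this card. The following is a hypothetical sketch of how these two options could be wired up with `argparse`, `pathlib`, and `concurrent.futures`. The parser layout, the `ch1_*.tif` search pattern, and the placeholder `process_one` function are all assumptions.

```python
import argparse
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented interface of prepare_data.py
    parser = argparse.ArgumentParser(description="Merge ch1/ch6 cell images into RGB")
    parser.add_argument("input_dir", type=Path)
    parser.add_argument("output_dir", type=Path)
    parser.add_argument("--workers", type=int, default=4,
                        help="number of parallel worker processes")
    parser.add_argument("--recursive", action="store_true",
                        help="also search subdirectories for ch1 images")
    return parser


def find_ch1_images(input_dir: Path, recursive: bool) -> list:
    # rglob walks all subdirectories; glob only looks at the top level
    pattern = "ch1_*.tif"
    paths = input_dir.rglob(pattern) if recursive else input_dir.glob(pattern)
    return sorted(paths)


def process_one(ch1_path: Path) -> str:
    # Placeholder for the per-pair work (load, normalize, merge, save)
    return ch1_path.name


def main(argv=None) -> list:
    args = build_parser().parse_args(argv)
    ch1_files = find_ch1_images(args.input_dir, args.recursive)
    # Pairs are independent of each other, so they can be processed in parallel
    with ProcessPoolExecutor(max_workers=args.workers) as pool:
        return list(pool.map(process_one, ch1_files))


if __name__ == "__main__":
    main()
```

Separate processes are used here because per-image normalization is CPU-bound. A thread pool would also work if the I/O dominates.
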
#### Input Directory Structure
```
input_dir/
├── class1/
│   ├── ch1_1.tif
│   ├── ch6_1.tif
│   ├── ch1_2.tif
│   └── ch6_2.tif
└── class2/
    ├── ch1_1.tif
    ├── ch6_1.tif
    ...
```

#### Output Directory Structure
```
output_dir/
├── class1/
│   ├── merged_1.tif
│   └── merged_2.tif
└── class2/
    ├── merged_1.tif
    ...
```

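The pairing and output naming implied by these two layouts can be sketched as below. The sketch assumes files are named `ch1_<id>.tif` and `ch6_<id>.tif`, and that each pair becomes `merged_<id>.tif` inside a mirrored class folder. The helper names are hypothetical. Here, a ch1 image without a matching ch6 partner is simply skipped; the real script may handle that case differently.

```python
from pathlib import Path


def pair_channels(class_dir: Path) -> list:
    """Return (ch1, ch6, output_name) triples for one class directory."""
    triples = []
    for ch1 in sorted(class_dir.glob("ch1_*.tif")):
        image_id = ch1.stem[len("ch1_"):]  # e.g. "1" from "ch1_1"
        ch6 = class_dir / f"ch6_{image_id}.tif"
        if ch6.exists():  # skip incomplete pairs
            triples.append((ch1, ch6, f"merged_{image_id}.tif"))
    return triples


def plan_outputs(input_dir: Path, output_dir: Path) -> list:
    """Mirror each class subfolder of input_dir under output_dir."""
    plan = []
    for class_dir in sorted(p for p in input_dir.iterdir() if p.is_dir()):
        for ch1, ch6, name in pair_channels(class_dir):
            plan.append((ch1, ch6, output_dir / class_dir.name / name))
    return plan
```
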
## Model Usage

### Prerequisites for Model

```bash
# Required libraries for model inference
pip install torch torchvision transformers
```

### Usage Example

#### Single Image Inference

```python
from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from PIL import Image

# Load model and processor
model = ViTForImageClassification.from_pretrained("poprap/vit16L-FT-cellclassification")
processor = ViTImageProcessor.from_pretrained("poprap/vit16L-FT-cellclassification")
model.eval()  # disable dropout for inference

# Preprocess image (a merged R: ch1, G: ch6 image)
image = Image.open("cell_image.tif").convert("RGB")  # ensure exactly 3 channels
inputs = processor(images=image, return_tensors="pt")

# Inference (no gradients needed)
with torch.no_grad():
    outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
print(predicted_class, model.config.id2label[predicted_class])  # index and label name
```

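#### Batch Processing and Evaluation

For batch processing and computing a comprehensive set of evaluation metrics, use the script below. In addition to the packages listed above, it requires `tqdm`, `scikit-learn`, `matplotlib`, and `seaborn`.
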
```python
import torch
import numpy as np
import time
from pathlib import Path
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, accuracy_score, recall_score,
    precision_score, f1_score, roc_auc_score,
    classification_report
)
from sklearn.preprocessing import label_binarize

# --- 1. Dataset helper functions ---
def transform_function(feature_extractor, img):
    # Resize to the 224x224 model input, then let the processor normalize it
    resized = transforms.Resize((224, 224))(img)
    encoded = feature_extractor(images=resized, return_tensors="pt")
    return encoded["pixel_values"][0]

def collate_fn(batch):
    # Stack (pixel_values, label) pairs from ImageFolder into batched tensors
    pixel_values = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch])
    return {"pixel_values": pixel_values, "labels": labels}

# --- 2. Prepare the model and dataset ---
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTForImageClassification.from_pretrained("poprap/vit16L-FT-cellclassification")
feature_extractor = ViTImageProcessor.from_pretrained("poprap/vit16L-FT-cellclassification")
model.to(device)

# Build the dataset and data loader.
# ImageFolder assigns label indices by sorted folder name; these must match
# the class indices the model was trained with (see model.config.id2label).
eval_dir = Path("path/to/eval/data")  # path to the evaluation data
dataset = datasets.ImageFolder(
    root=str(eval_dir),
    transform=lambda img: transform_function(feature_extractor, img)
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn
)

# --- 3. Run batch inference ---
model.eval()
all_preds = []
all_labels = []
all_probs = []

start_time = time.time()

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Evaluating"):
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values=inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

end_time = time.time()

# --- 4. Compute performance metrics ---
# Processing time
total_images = len(all_labels)
total_time = end_time - start_time
time_per_image = total_time / total_images

# Basic metrics
cm = confusion_matrix(all_labels, all_preds)
accuracy = accuracy_score(all_labels, all_preds)
recall_weighted = recall_score(all_labels, all_preds, average="weighted")
precision_weighted = precision_score(all_labels, all_preds, average="weighted")
f1_weighted = f1_score(all_labels, all_preds, average="weighted")

# Per-class (one-vs-rest) AUC
num_classes = len(dataset.classes)
all_labels_onehot = label_binarize(all_labels, classes=range(num_classes))
all_probs = np.array(all_probs)

auc_scores = {}
for class_idx in range(num_classes):
    try:
        auc = roc_auc_score(all_labels_onehot[:, class_idx], all_probs[:, class_idx])
        auc_scores[dataset.classes[class_idx]] = auc
    except ValueError:
        # AUC is undefined when a class is absent from the evaluation labels
        auc_scores[dataset.classes[class_idx]] = None

# --- 5. Visualize results ---
# Confusion matrix heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=dataset.classes,
            yticklabels=dataset.classes)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.show()

# Print results
print("\nEvaluation Results:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Weighted Recall: {recall_weighted:.4f}")
print(f"Weighted Precision: {precision_weighted:.4f}")
print(f"Weighted F1: {f1_weighted:.4f}")
print("\nAUC Scores per Class:")
for class_name, auc in auc_scores.items():
    print(f"{class_name}: {auc:.4f}" if auc is not None else f"{class_name}: N/A")

print("\nDetailed Classification Report:")
print(classification_report(all_labels, all_preds, target_names=dataset.classes))

print("\nPerformance Metrics:")
print(f"Total images evaluated: {total_images}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Average time per image: {time_per_image:.4f} seconds")
```

This example demonstrates how to:
1. Process multiple images in batches
2. Calculate comprehensive evaluation metrics
3. Generate a confusion matrix visualization
4. Measure inference time

Key metrics calculated:
- Accuracy, plus support-weighted precision, recall, and F1-score
- Per-class (one-vs-rest) AUC scores
- Confusion matrix
- Detailed classification report
- Processing time statistics

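To make the weighted averages concrete, the sketch below computes support-weighted precision, recall, and F1 directly from a confusion matrix. It follows the same convention as scikit-learn's `confusion_matrix`: rows are true labels, columns are predictions. It is a pure-Python illustration intended to mirror `average="weighted"`, with zero division treated as 0. The helper `weighted_prf` is hypothetical.

```python
def weighted_prf(cm: list) -> tuple:
    """Support-weighted precision, recall and F1 from a square confusion matrix.

    cm[i][j] counts samples whose true class is i and predicted class is j.
    """
    k = len(cm)
    support = [sum(cm[i]) for i in range(k)]                         # true samples per class
    predicted = [sum(cm[i][j] for i in range(k)) for j in range(k)]  # predictions per class
    total = sum(support)
    p_w = r_w = f_w = 0.0
    for c in range(k):
        tp = cm[c][c]
        precision = tp / predicted[c] if predicted[c] else 0.0
        recall = tp / support[c] if support[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
        weight = support[c] / total  # each class weighted by its share of true labels
        p_w += weight * precision
        r_w += weight * recall
        f_w += weight * f1
    return p_w, r_w, f_w
```

One consequence follows directly from the weighting. Weighted recall always equals accuracy, because each class's recall `tp / support` is multiplied by `support / total`, and the sum reduces to correct predictions over all samples. The "Accuracy" and "Weighted Recall" lines printed by the evaluation script should therefore match.
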
## Training Configuration

The model was fine-tuned with the following settings:

### Hyperparameters
- Batch size: 56
- Learning rate: 1e-5
- Number of epochs: 20
- Mixed-precision training (FP16)
- Label smoothing: 0.1
- Cosine learning-rate schedule with warmup (100 warmup steps)

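Two of these settings are easy to check numerically. The sketch below reproduces the shape of a cosine schedule with linear warmup, the form used by the `cosine` scheduler type in Transformers (`get_cosine_schedule_with_warmup`). It also computes the soft targets produced by label smoothing of 0.1 over 5 classes. This card does not state the total number of training steps, so `total_steps=2000` is a hypothetical placeholder.

```python
import math


def cosine_with_warmup_lr(step: int, base_lr: float = 1e-5,
                          warmup_steps: int = 100, total_steps: int = 2000) -> float:
    # Linear warmup from 0 to base_lr, then half-cosine decay from base_lr to 0
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


def smoothed_targets(true_class: int, num_classes: int = 5,
                     epsilon: float = 0.1) -> list:
    # Uniform label smoothing: epsilon / K on every class, plus 1 - epsilon on the true one
    base = epsilon / num_classes
    return [base + (1.0 - epsilon) * (c == true_class) for c in range(num_classes)]


print(cosine_with_warmup_lr(50), cosine_with_warmup_lr(100))
print(smoothed_targets(2))
```

With 5 classes and a smoothing factor of 0.1, the true class receives a target of 0.92 and each other class receives 0.02.
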
### Data Augmentation
- RandomResizedCrop (224x224, scale=(0.8, 1.0))
- RandomHorizontalFlip
- RandomRotation (±10 degrees)
- ColorJitter (brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)

### Implementation Details
- Trained with the `Trainer` class from Hugging Face Transformers
- Checkpoint saving: every 100 steps
- Evaluation: every 100 steps
- Logging: every 10 steps

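For reference, the settings above could be expressed as `TrainingArguments` keyword arguments roughly as follows. This is a hypothetical reconstruction, not the original training script. The output directory name is a placeholder, and batch size 56 is assumed to be per device on a single GPU. `TrainingArguments` is only constructed inside `make_training_args`, because `fp16=True` requires a CUDA device.

```python
def training_kwargs(output_dir: str = "vit-cell-cls") -> dict:
    # Settings listed above as TrainingArguments keyword arguments (hypothetical reconstruction)
    return {
        "output_dir": output_dir,            # placeholder name
        "per_device_train_batch_size": 56,   # assumes a single GPU
        "learning_rate": 1e-5,
        "num_train_epochs": 20,
        "fp16": True,                        # mixed precision; needs a CUDA GPU
        "label_smoothing_factor": 0.1,
        "lr_scheduler_type": "cosine",
        "warmup_steps": 100,
        "eval_strategy": "steps",            # called evaluation_strategy in older releases
        "eval_steps": 100,
        "save_strategy": "steps",
        "save_steps": 100,
        "logging_steps": 10,
    }


def make_training_args(output_dir: str = "vit-cell-cls"):
    # Imported lazily so the settings can be inspected without transformers installed
    from transformers import TrainingArguments
    return TrainingArguments(**training_kwargs(output_dir))
```
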
## Data Source

This project uses data from the following research paper:

Phillip Eulenberg, Niklas Köhler, Thomas Blasi, Andrew Filby, Anne E. Carpenter, Paul Rees, Fabian J. Theis & F. Alexander Wolf. "Reconstructing cell cycle and disease progression using deep learning." *Nature Communications* 8, Article 463 (2017).

## License

This project is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0), the same license as the base Google Vision Transformer model.

## Citations

```bibtex
@misc{dosovitskiy2021vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Alexey Dosovitskiy and others},
  year={2021},
  eprint={2010.11929},
  archivePrefix={arXiv}
}
```