---
license: mit
language:
- en
library_name: pytorch
base_model:
- microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
- torchvision/resnet101
datasets:
- ms-flair-mri
tags:
- medical-imaging
- brain-mri
- multiple-sclerosis
- binary-classification
- pytorch
---

# Multiple Sclerosis Binary Classifier

PyTorch checkpoint artifacts for the MultiAgentMedClassifier MS task. This repository contains a ResNet101 CNN checkpoint and a BiomedCLIP linear-probe checkpoint for classifying brain FLAIR MRI images as normal or multiple sclerosis (MS). These are checkpoint files for the accompanying project loaders, not standalone Transformers models.

## Model Description

- Task: binary MS classification of brain FLAIR MRI
- CNN architecture: ResNet101
- Vision-language backbone for the probe: `microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224`
- Framework: PyTorch

## Classes

- `normal`
- `ms`

The corresponding project-level BiomedCLIP text labels are:

- `normal brain FLAIR MRI`
- `multiple sclerosis brain FLAIR MRI`

## Files

- `ms/cnn/resnet101_MRI_ms_norm_final.pt`: ResNet101 CNN checkpoint for binary MS brain FLAIR MRI classification.
- `ms/biomedclip/linear_probe_BiomedCLIP_MRI_ms_norm_best.pt`: BiomedCLIP linear-probe checkpoint for binary MS brain FLAIR MRI classification.

## Training Details

- Input size: 224 x 224 RGB
- Normalization: ImageNet mean/std
- CNN checkpoint: ResNet101 fine-tuned for the `ms` task
- BiomedCLIP probe: linear/MLP probe over frozen BiomedCLIP image features (layer 6)

## Metrics

| Model | Accuracy |
|-------|----------|
| ResNet101 CNN | 59.7% |

No accuracy is reported here for the BiomedCLIP probe.

Note: MS classification from FLAIR MRI is a challenging task; the relatively low accuracy reflects the difficulty of distinguishing subtle white-matter lesion patterns. Recompute metrics on your own held-out test set.
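## Inference Example

The example uses the project's own `agents` and `config` modules, so run it from the MultiAgentMedClassifier project root.

```python
from huggingface_hub import hf_hub_download

# Project-local modules from MultiAgentMedClassifier (not pip packages).
from agents.cnn_tool import CNNClassifier
from config import DEFAULT_CONFIG

# Download the ResNet101 checkpoint from the Hub (cached locally after the first call).
checkpoint_path = hf_hub_download(
    repo_id="tamara-kostova/multiagentmed-ms",
    filename="ms/cnn/resnet101_MRI_ms_norm_final.pt",
)

# Point the project config at the downloaded checkpoint for the "ms" task.
DEFAULT_CONFIG.model.cnn_checkpoints["ms"] = checkpoint_path

# Build the classifier with the project's model and preprocessing settings.
classifier = CNNClassifier(DEFAULT_CONFIG.model, DEFAULT_CONFIG.preprocess)
result = classifier.classify("path/to/brain_flair.png", task="ms")
print(result)
```

## Intended Use

Research and experimentation only. This is not a medical device. Always validate on your own held-out test set before using it in any pipeline.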
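When validating on a held-out set, a single accuracy figure hides how errors split between the two classes. The sketch below computes accuracy, sensitivity (recall on `ms`) and specificity (recall on `normal`) in plain Python. It assumes you have already collected true and predicted labels as the strings `normal`/`ms`; the function name `binary_metrics` is hypothetical and not part of the project, and how you extract a label from `classify` output depends on the project's return format.

```python
from typing import Dict, Sequence


def binary_metrics(
    y_true: Sequence[str], y_pred: Sequence[str], positive: str = "ms"
) -> Dict[str, float]:
    """Hypothetical helper: accuracy, sensitivity and specificity for normal vs. ms."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("need at least one example")

    # Count confusion-matrix cells with `positive` ("ms") as the positive class.
    tp = fp = tn = fn = 0
    for truth, pred in zip(y_true, y_pred):
        if truth == positive:
            if pred == positive:
                tp += 1
            else:
                fn += 1
        else:
            if pred == positive:
                fp += 1
            else:
                tn += 1

    def ratio(num: int, den: int) -> float:
        # Undefined rates (e.g. no positives in the set) are reported as NaN.
        return num / den if den else float("nan")

    return {
        "accuracy": (tp + tn) / len(y_true),
        "sensitivity": ratio(tp, tp + fn),
        "specificity": ratio(tn, tn + fp),
    }


if __name__ == "__main__":
    truth = ["ms", "normal", "normal"]
    preds = ["ms", "normal", "ms"]
    print(binary_metrics(truth, preds))
```

For the small example above, one MS case is caught and one of two normal cases is misclassified, so sensitivity is 1.0 and specificity is 0.5.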