---
license: mit
language:
  - en
library_name: pytorch
base_model:
  - microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
  - torchvision/densenet169
datasets:
  - kaggle-brain-stroke-ct
tags:
  - medical-imaging
  - brain-ct
  - stroke-classification
  - binary-classification
  - pytorch
---
# Ischemic Stroke Binary Classifier

PyTorch checkpoint artifacts for the MultiAgentMedClassifier stroke task.
Contains a DenseNet169 CNN checkpoint and a BiomedCLIP linear-probe checkpoint
for classifying brain CT images as normal or ischemic stroke.

These are checkpoint files for the accompanying project loaders, not standalone
Transformers models.

## Model Description

- Task: binary stroke CT classification
- CNN architecture: DenseNet169
- Vision-language backbone for probe: `microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224`
- Framework: PyTorch

## Classes

- `normal`
- `stroke`

The project-level BiomedCLIP labels are:

- `normal brain CT`
- `ischemic stroke brain CT`
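
A minimal sketch of how a two-logit output could be mapped back to these names. It assumes the index order follows the class list above (`normal` = 0, `stroke` = 1) and that the model emits raw logits. Neither is confirmed by this card, so check the project loaders before relying on it. `CLASS_NAMES`, `BIOMEDCLIP_LABELS`, and `decode_logits` are hypothetical names.

```python
import math

# Assumed index order (normal = 0, stroke = 1); verify against the project loaders.
CLASS_NAMES = ["normal", "stroke"]

# Project-level BiomedCLIP label for each class name, as listed above.
BIOMEDCLIP_LABELS = {
    "normal": "normal brain CT",
    "stroke": "ischemic stroke brain CT",
}


def decode_logits(logits):
    """Convert two raw logits into (class_name, biomedclip_label, probability)."""
    if len(logits) != len(CLASS_NAMES):
        raise ValueError(f"expected {len(CLASS_NAMES)} logits, got {len(logits)}")
    # Numerically stable softmax: subtract the largest logit before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class (first one wins on a tie).
    idx = max(range(len(probs)), key=probs.__getitem__)
    name = CLASS_NAMES[idx]
    return name, BIOMEDCLIP_LABELS[name], probs[idx]
```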

## Files

- `stroke/cnn/densenet169_CT_stroke_binary_norm_final.pt`: DenseNet169 CNN checkpoint for binary stroke CT classification.
- `stroke/biomedclip/linear_probe_BiomedCLIP_CT_stroke_binary_norm_best.pt`: BiomedCLIP linear-probe checkpoint for binary stroke CT classification.

## Training Details

- Input size: 224 x 224 RGB
- Normalization: ImageNet mean/std
- CNN checkpoint: DenseNet169 fine-tuned for the stroke task
- BiomedCLIP probe: linear/MLP probe over frozen BiomedCLIP image features (layer 6)

## Metrics

| Model | Accuracy |
| --- | --- |
| DenseNet169 CNN | 97.7% |

## Inference Example

```python
from huggingface_hub import hf_hub_download

# Project-specific modules from the MultiAgentMedClassifier codebase
from agents.cnn_tool import CNNClassifier
from config import DEFAULT_CONFIG

# Download the DenseNet169 checkpoint (cached locally) and get its file path
checkpoint_path = hf_hub_download(
    repo_id="tamara-kostova/multiagentmed-stroke",
    filename="stroke/cnn/densenet169_CT_stroke_binary_norm_final.pt",
)

# Point the project config at the downloaded checkpoint for the "stroke" task
DEFAULT_CONFIG.model.cnn_checkpoints["stroke"] = checkpoint_path

classifier = CNNClassifier(DEFAULT_CONFIG.model, DEFAULT_CONFIG.preprocess)
result = classifier.classify("path/to/brain_ct.png", task="stroke")
print(result)
```

## Intended Use

Research and experimentation only. Not a medical device. Always validate on your own held-out test set before using it in any pipeline.
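
On a held-out set, accuracy alone can hide poor stroke recall, particularly when the classes are imbalanced. A minimal sketch for computing accuracy, sensitivity, and specificity from label lists, assuming `stroke` is the positive class; `binary_metrics` is a hypothetical helper, not part of the project.

```python
from typing import Optional, Sequence


def binary_metrics(
    y_true: Sequence[str], y_pred: Sequence[str], positive: str = "stroke"
) -> dict[str, Optional[float]]:
    """Accuracy, sensitivity (recall on the positive class), and specificity."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("need at least one example")
    tp = fn = tn = fp = 0
    # Tally the confusion matrix, treating `positive` as the positive class.
    for t, p in zip(y_true, y_pred):
        if t == positive:
            if p == positive:
                tp += 1
            else:
                fn += 1
        elif p == positive:
            fp += 1
        else:
            tn += 1
    total = tp + fn + tn + fp
    return {
        "accuracy": (tp + tn) / total,
        # None when the split contains no examples of that class.
        "sensitivity": tp / (tp + fn) if tp + fn else None,
        "specificity": tn / (tn + fp) if tn + fp else None,
    }
```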