
ALPaCA: Adapting Llama for Pathology Context Analysis

Welcome to ALPaCA, a multimodal training framework tailored for slide-level question answering in computational pathology. ALPaCA integrates Llama-3.1-8B-Instruct as the language backbone and CONCH as the patch-level vision encoder. The trained model is named Llama-slideQA.

Setup

  1. Download the base model Llama-3.1-8B-Instruct (Hugging Face: meta-llama/Llama-3.1-8B-Instruct).
  2. Install dependencies:
    pip install -r requirements.txt
    
  3. Prepare WSI patch features. Tile each WSI at three scales (level 0 with 1024 px tiles, level 1 with 512 px tiles, and level 1 with 1024 px tiles; file suffixes 0_1024, 1_512, 1_1024) and run CONCH inference to produce per-patch features. The data-processing code lives at https://github.com/ZeyuGaoAi/SMMILe.
  4. Prepare GMM prototype features (required for the gmm aggregation branch). Run GMM_feature_extraction/src/wsi_single_demo.py (or its wrapper run_wsi_single.sh) using the provided CONCH prototypes under GMM_feature_extraction/GTEx-TCGA-prototypes/. The prototypes are split into GTEx_TCGA_merged_output_less/ (32/16/8 clusters) and GTEx_TCGA_merged_output_more/ (256/128/64 clusters); pick the one that matches your --version setting.
  5. Make sure you have access to the CNX-PathLLM datasets on Hugging Face that are used in the training and testing commands below.
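Choosing the prototype set in step 4 amounts to a small lookup. The sketch below encodes the two sets described above; the helper name is hypothetical, and it assumes --version is spelled "less" or "more", inferred from the directory suffixes and the .../GTEx-TCGA-GMM_Conch/more paths used later in this card.

```python
from pathlib import Path
from typing import Tuple

# Cluster counts for the two prototype sets described in step 4.
# Assumption: --version is spelled "less" or "more", matching the directory suffixes.
PROTOTYPE_SETS = {
    "less": (32, 16, 8),
    "more": (256, 128, 64),
}
PROTO_ROOT = Path("GMM_feature_extraction/GTEx-TCGA-prototypes")


def select_prototypes(version: str) -> Tuple[Path, Tuple[int, ...]]:
    """Return (prototype directory, cluster counts) for a version (hypothetical helper)."""
    if version not in PROTOTYPE_SETS:
        raise ValueError(f"unknown version {version!r}; expected one of {sorted(PROTOTYPE_SETS)}")
    return PROTO_ROOT / f"GTEx_TCGA_merged_output_{version}", PROTOTYPE_SETS[version]


proto_dir, clusters = select_prototypes("more")
print(proto_dir, clusters)
```

The same "more" directory is the one passed to --proto-root in the single-WSI demo later in this card.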

Aggregation strategies

--agg_strategy selects how patch features are aggregated to slide-level tokens:

  • longnet: LongNet + question-conditioned cross-attention over multi-scale patch features.
  • gmm: cross-attention over GMM prototype tokens (prob | mean | cov per prototype).
  • abmil: gated attention pooling (ABMIL) over patches; a pure-vision baseline with no text interaction.
  • kmeans: passes pre-computed k-means cluster centroids straight to the shared resampler; a lightweight prototype baseline.
  • random: replaces the aggregated slide embedding with Gaussian noise; a sanity-check (random-image) baseline that ignores the WSI entirely.
  • longnet,gmm: hybrid; both branches are computed and the slide-level embeddings are summed. This is the configuration reported in the paper.
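To make the composite behaviour concrete, the sketch below mimics how a hybrid such as longnet,gmm combines branches: each branch yields slide-level token embeddings and the results are summed elementwise, while the random branch reads its input, discards it, and emits Gaussian noise. Nested Python lists stand in for tensors, and the branch outputs and function names are hypothetical, not the repository's modules. It also assumes every branch produces the same (num_tokens, dim) shape, which an elementwise sum requires.

```python
import random
from typing import List

Embedding = List[List[float]]  # (num_tokens, dim), nested lists standing in for a tensor


def random_branch(features: Embedding, num_tokens: int, dim: int, rng: random.Random) -> Embedding:
    """Random-image baseline: ignore the slide features and return N(0, 1) noise."""
    del features  # read but discarded, like the random strategy
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_tokens)]


def combine_branches(outputs: List[Embedding]) -> Embedding:
    """Sum the slide-level embeddings of several branches elementwise."""
    num_tokens, dim = len(outputs[0]), len(outputs[0][0])
    for out in outputs:
        if len(out) != num_tokens or any(len(row) != dim for row in out):
            raise ValueError("all branches must produce the same (num_tokens, dim) shape")
    return [[sum(out[t][d] for out in outputs) for d in range(dim)] for t in range(num_tokens)]


# Hypothetical outputs of two branches: 2 slide tokens of width 3.
longnet_out = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]]
gmm_out = [[0.0, 1.0, -2.0], [0.5, -0.5, 0.5]]
print(combine_branches([longnet_out, gmm_out]))  # [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]]

# A random,gmm composite: noise replaces the first branch's output.
noise = random_branch(longnet_out, num_tokens=2, dim=3, rng=random.Random(0))
print(len(combine_branches([noise, gmm_out])))  # 2
```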

Strategies are also composable beyond the hybrid above (e.g. abmil,gmm or random,gmm). When using a composite, --fea_root accepts a comma-separated list whose order matches --agg_strategy: the N-th path is consumed by the N-th strategy. The random branch still reads features (they are discarded), so any valid path works for that slot.
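The positional pairing rule can be checked before launching a job. The sketch below is a hypothetical pre-flight helper (not part of run_wsi.py) that splits both flags on commas, rejects strategy names not listed above, and pairs the N-th strategy with the N-th path.

```python
from typing import List, Tuple

KNOWN_STRATEGIES = {"longnet", "gmm", "abmil", "kmeans", "random"}


def pair_strategies_with_roots(agg_strategy: str, fea_root: str) -> List[Tuple[str, str]]:
    """Pair the N-th strategy in --agg_strategy with the N-th path in --fea_root."""
    strategies = [s.strip() for s in agg_strategy.split(",") if s.strip()]
    roots = [p.strip() for p in fea_root.split(",") if p.strip()]
    unknown = [s for s in strategies if s not in KNOWN_STRATEGIES]
    if unknown:
        raise ValueError(f"unknown strategies: {unknown}")
    if len(strategies) != len(roots):
        raise ValueError(f"{len(strategies)} strategies but {len(roots)} feature roots")
    return list(zip(strategies, roots))


pairs = pair_strategies_with_roots("longnet,gmm", "/path/to/Conch,/path/to/GTEx-TCGA-GMM_Conch/more")
for strategy, root in pairs:
    print(f"{strategy:8s} <- {root}")
```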

Note on Q-Former. An earlier prototype shipped a qformer strategy that mirrored longnet but replaced the LongNet backbone with a vanilla BERT-style Transformer encoder for the question-conditioned cross-attention. If you want to reproduce it, swap the slide_encoder.create_model(...) call inside the longnet branch with a standard BertModel (or any nn.TransformerEncoder) operating over the same query / patch / instruction inputs.

Other key flags

  • --venc_name: patch-feature encoder name, one of Conch (512-dim, default), MUSK / UNI (1024-dim) or PathGen (512-dim). Must match the directory layout under --fea_root.
  • --embed_dim: patch-feature dimensionality. Set to 512 for Conch/PathGen and 1024 for MUSK/UNI.
  • --n_heads: comma-separated number of query/attention heads per pyramid level (e.g. 32,16,8).
  • --hierachical_token: True inserts <|High|><|Mid|><|Low|> magnification tokens before each level.
  • --hierachical_adaptor: True uses a separate adaptor per level; False shares one adaptor across all levels.
  • --gmm_need_query: True makes the GMM branch use learnable queries + cross-attention with cos/TV auxiliary losses; False passes GMM features through a single resampler.
  • --att_loss_weight: weight on the LongNet+GMM auxiliary cos/TV losses (default 1.0).
  • --llm_requires_grad: False freezes the LLM (Stage 1); True fine-tunes it (Stages 2 and 3).
  • --ckpt_path: path to a .bin checkpoint produced by a previous stage.
  • --dataset_multiplier_list: optional per-dataset upsampling factors aligned with --dataset_name_list (e.g. "1,1,5,5" upsamples the last two datasets 5x).
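Several of these flags must agree with each other. The sketch below collects three hypothetical pre-flight helpers (none of them exist in the repository): checking --embed_dim against --venc_name using the widths listed above, pairing each --n_heads entry with a magnification token, and expanding datasets according to --dataset_multiplier_list. It assumes the levels run from highest to lowest magnification (so the first --n_heads entry goes with <|High|>) and that upsampling means plain repetition of samples; both are assumptions.

```python
from typing import Dict, List, Optional, Tuple

# Feature widths listed for --venc_name above.
EMBED_DIMS = {"Conch": 512, "PathGen": 512, "MUSK": 1024, "UNI": 1024}
# Assumption: pyramid levels run from highest to lowest magnification.
LEVEL_TOKENS = ["<|High|>", "<|Mid|>", "<|Low|>"]


def check_embed_dim(venc_name: str, embed_dim: int) -> None:
    """Raise if --embed_dim does not match the encoder's feature width."""
    expected = EMBED_DIMS[venc_name]
    if embed_dim != expected:
        raise ValueError(f"{venc_name} features are {expected}-dim, got --embed_dim {embed_dim}")


def parse_levels(n_heads: str) -> List[Tuple[str, int]]:
    """Pair each --n_heads entry with a magnification token (assumed order)."""
    heads = [int(h) for h in n_heads.split(",")]
    if len(heads) != len(LEVEL_TOKENS):
        raise ValueError(f"expected {len(LEVEL_TOKENS)} levels, got {len(heads)}")
    return list(zip(LEVEL_TOKENS, heads))


def upsample(datasets: Dict[str, list], multipliers: Optional[str] = None) -> list:
    """Concatenate datasets, repeating each one's samples by its factor.

    Assumption: upsampling is plain repetition; dict order must follow --dataset_name_list.
    """
    factors = [int(m) for m in multipliers.split(",")] if multipliers else [1] * len(datasets)
    if len(factors) != len(datasets):
        raise ValueError("--dataset_multiplier_list must align with --dataset_name_list")
    merged: list = []
    for samples, k in zip(datasets.values(), factors):
        merged.extend(samples * k)
    return merged


check_embed_dim("Conch", 512)
print(parse_levels("32,16,8"))  # [('<|High|>', 32), ('<|Mid|>', 16), ('<|Low|>', 8)]
print(len(upsample({"A": ["a1"], "B": ["b1", "b2"]}, "1,3")))  # 7
```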

Training

Three stages are provided as ready-to-run scripts. Each .sh shows multiple encoder variants; comment or uncomment the block matching your --venc_name.

Stage 1: slide-level description pretraining (LLM frozen)

accelerate launch --config_file=./accelerate_configs/deepspeed_zero2.yaml run_wsi.py \
    --max_steps 20000 --warmup_steps 1000 \
    --gpu 2 --train_batch_size 4 --eval_batch_size 2 --max_seq_length 512 \
    --agg_strategy longnet,gmm --embed_dim 512 --att_loss_weight 1 \
    --n_heads 32,16,8 --hierachical_token True --hierachical_adaptor True --gmm_need_query True \
    --llm_name meta-llama/Llama-3.1-8B-Instruct --venc_name Conch \
    --dataset_name_list CNX-PathLLM/TCGA-WSI-Description-4o,CNX-PathLLM/TCGA-WSI-Description-4omini,CNX-PathLLM/GTEx-WSI-Description \
    --data_cache_dir ~/.cache \
    --fea_root /path/to/Conch,/path/to/GTEx-TCGA-GMM_Conch/more \
    --output_dir /path/to/output/stage1 \
    --llm_requires_grad False --resume_from_checkpoint False

Also see run_wsi_stage1.sh.

Stage 2: instruction QA fine-tuning (LLM unfrozen)

accelerate launch --config_file=./accelerate_configs/deepspeed_zero2.yaml run_wsi.py \
    --max_steps 20000 --warmup_steps 100 --save_steps 300 \
    --gpu 2 --train_batch_size 8 --eval_batch_size 2 --max_seq_length 256 \
    --agg_strategy longnet,gmm --embed_dim 512 --att_loss_weight 1 \
    --n_heads 32,16,8 --hierachical_token True --hierachical_adaptor True --gmm_need_query True \
    --llm_name meta-llama/Llama-3.1-8B-Instruct --venc_name Conch \
    --dataset_name_list CNX-PathLLM/TCGA-WSI-CloseQA-Balanced,CNX-PathLLM/GTEx-WSI-CloseQA-Balanced,CNX-PathLLM/TCGA-WSI-OpenQA,CNX-PathLLM/GTEx-WSI-OpenQA \
    --data_cache_dir ~/.cache \
    --fea_root /path/to/Conch,/path/to/GTEx-TCGA-GMM_Conch/more \
    --output_dir /path/to/output/stage2 \
    --ckpt_path /path/to/output/stage1/ckpt20000.bin \
    --llm_requires_grad True --resume_from_checkpoint False

Also see run_wsi_stage2.sh.
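The Stage 1 and Stage 2 commands above differ mainly in batch size and sequence length. A rough training budget follows from simple arithmetic, assuming --gpu is the number of devices, --train_batch_size is per device, and there is no gradient accumulation (all three are assumptions; check run_wsi.py and your accelerate config).

```python
from typing import Tuple


def training_budget(max_steps: int, num_devices: int, per_device_batch: int, grad_accum: int = 1) -> Tuple[int, int]:
    """Return (global batch size, total training samples) for one stage."""
    global_batch = num_devices * per_device_batch * grad_accum
    return global_batch, global_batch * max_steps


# Values taken from the Stage 1 and Stage 2 commands.
for stage, steps, devices, batch in [("Stage 1", 20000, 2, 4), ("Stage 2", 20000, 2, 8)]:
    gb, total = training_budget(steps, devices, batch)
    print(f"{stage}: global batch {gb}, ~{total:,} samples over {steps:,} steps")
# Stage 1: global batch 8, ~160,000 samples over 20,000 steps
# Stage 2: global batch 16, ~320,000 samples over 20,000 steps
```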

Stage 3: domain-specific fine-tuning

Continue training from the Stage 2 checkpoint with a specialized dataset:

  • TCGA-BRCA detailed QA: CNX-PathLLM/TCGA-BRCA-Details-CloseQA,CNX-PathLLM/TCGA-BRCA-Details-OpenQA
  • PathChat morphological QA (TCGA-STAD / -KIRC / -OV): CNX-PathLLM/PathChat_CloseQA_Balanced,CNX-PathLLM/PathChat_OpenQA. See PathChat.

Use the same command as Stage 2, but set --dataset_name_list, --ckpt_path (pointing at the Stage 2 checkpoint), and --output_dir accordingly. See run_wsi_stage3.sh.
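Because Stage 2 saves every 300 steps, the output directory can hold many checkpoints. The sketch below picks the most recent one to pass as --ckpt_path; it assumes files follow the ckpt<step>.bin pattern seen in ckpt20000.bin, and both helper names are hypothetical.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Assumption: checkpoints are named ckpt<step>.bin, like ckpt20000.bin.
CKPT_PATTERN = re.compile(r"^ckpt(\d+)\.bin$")


def pick_latest(names: Iterable[str]) -> Optional[str]:
    """Return the ckpt<step>.bin name with the highest step, or None if there is none."""
    best_step, best_name = -1, None
    for name in names:
        match = CKPT_PATTERN.match(name)
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Find the most recent numbered checkpoint in a stage's --output_dir."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    name = pick_latest(p.name for p in root.iterdir())
    return root / name if name else None


ckpt = latest_checkpoint("/path/to/output/stage2")
print(f"--ckpt_path {ckpt}" if ckpt else "no ckpt<step>.bin found")
```

Steps are compared as integers rather than strings, so ckpt19800.bin correctly ranks above ckpt2100.bin.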

Tip: rephrase QA each epoch to prevent memorisation (Stages 2 and 3). Both Stage 2 and Stage 3 unfreeze the LLM and train on a fixed QA pool, so the model can start to memorise question wording and recover the answer from the prompt alone. We recommend regenerating the QA each epoch with an open-source LLM that paraphrases the question (and, for open QA, the answer) while preserving meaning, so the model has to ground its prediction in the slide rather than the surface form. For the released checkpoint we used DeepSeek-V3 to rewrite our QA pool once per epoch.
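The per-epoch rewrite can be organised as a small loop over the QA pool. In the sketch below the record fields (question, answer, kind) and the paraphrase callable are hypothetical; a real paraphraser would wrap whatever LLM you use (DeepSeek-V3 in our case) and must keep multiple-choice option letters intact. Closed-QA answers are left untouched so the labels stay valid.

```python
from typing import Callable, Dict, List

QA = Dict[str, str]  # hypothetical record: {"question": ..., "answer": ..., "kind": "close" or "open"}


def rephrase_epoch(pool: List[QA], paraphrase: Callable[[str, int], str], epoch: int) -> List[QA]:
    """Return a fresh copy of the QA pool with reworded text for this epoch.

    Questions are always paraphrased; answers only for open QA, so closed-QA
    labels (option letters, Yes/No) stay exactly as they were.
    """
    rephrased = []
    for qa in pool:
        new = dict(qa)  # leave the original pool untouched
        new["question"] = paraphrase(qa["question"], epoch)
        if qa["kind"] == "open":
            new["answer"] = paraphrase(qa["answer"], epoch)
        rephrased.append(new)
    return rephrased


def toy_paraphrase(text: str, epoch: int) -> str:
    """Stand-in for an LLM call; a real one must preserve meaning and option letters."""
    return f"[epoch {epoch}] {text}"


# Made-up records for illustration only.
pool = [
    {"question": "Is necrosis present?", "answer": "Yes", "kind": "close"},
    {"question": "Describe the slide.", "answer": "<reference description>", "kind": "open"},
]
for epoch in range(2):
    epoch_pool = rephrase_epoch(pool, toy_paraphrase, epoch)
    print(epoch_pool[0]["question"], "->", epoch_pool[0]["answer"])
```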

Testing

General QA (after Stage 2)

python test_wsi.py \
    --max_seq_length 128 --batch_size 16 --select_data_num -1 --eval_sample_size -1 \
    --n_heads 32,16,8 --hierachical_token True --hierachical_adaptor True --gmm_need_query True \
    --llm_name meta-llama/Llama-3.1-8B-Instruct --venc_name Conch \
    --shuffle False --data_cache_dir ~/.cache \
    --dataset_name_list CNX-PathLLM/TCGA-WSI-CloseQA-Balanced,CNX-PathLLM/GTEx-WSI-CloseQA-Balanced,CNX-PathLLM/TCGA-WSI-OpenQA,CNX-PathLLM/GTEx-WSI-OpenQA \
    --agg_strategy longnet,gmm --embed_dim 512 \
    --fea_root /path/to/Conch,/path/to/GTEx-TCGA-GMM_Conch/more \
    --ckpt_path /path/to/output/stage2/ckpt.bin \
    --results_save_path /path/to/output.csv

Also see test_wsi_stage2.sh.
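For closed QA, a quick accuracy figure can be computed from the file written to --results_save_path. The column names answer and prediction below are hypothetical; check the header of your CSV and pass the real names. Exact match after light normalisation is an assumption that suits closed QA (option letters, Yes/No); open-QA answers need a different metric.

```python
import csv
import io
from typing import Dict, Iterable


def normalise(text: str) -> str:
    """Lower-case and strip whitespace and a trailing period, e.g. ' Yes. ' -> 'yes'."""
    return text.strip().rstrip(".").strip().lower()


def closed_qa_accuracy(rows: Iterable[Dict[str, str]], answer_col: str = "answer", pred_col: str = "prediction") -> float:
    """Fraction of rows whose normalised prediction equals the normalised answer."""
    rows = list(rows)
    if not rows:
        raise ValueError("no rows to score")
    hits = sum(normalise(r[pred_col]) == normalise(r[answer_col]) for r in rows)
    return hits / len(rows)


# In-memory stand-in for open("/path/to/output.csv", newline="").
sample = "answer,prediction\nA,A\nYes,yes.\nB,C\nNo,No\n"
print(closed_qa_accuracy(csv.DictReader(io.StringIO(sample))))  # 0.75
```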

Specific QA (after Stage 3)

Same command, swap --dataset_name_list for the domain-specific dataset and --ckpt_path for the Stage 3 checkpoint. See test_wsi_stage3.sh.

Single-WSI demo (from .svs to an answer)

For a quick end-to-end smoke test on one slide, the recommended workflow is:

Step A: patch + CONCH feature extraction. Use CLAM_PreProcessing (a fork of CLAM tailored to this codebase). Run it three times so that you end up with the three magnification scales used by ALPaCA:

<conch-dir>/
    <slide_id>_0_1024.npy        # level 0 (40x), 1024 px tiles
    <slide_id>_1_512.npy         # level 1 (20x),  512 px tiles
    <slide_id>_1_1024.npy        # level 1 (20x), 1024 px tiles

Each file is a 0-d np.ndarray wrapping {"feature": (N, 512) float32, "index": ["row_col_*", ...]}.
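Before running the demo it helps to confirm that all three scale files exist for a slide. The sketch below builds the expected paths from the layout above and parses index entries; the helper names are hypothetical, and it assumes each index string begins with numeric row and column fields ("row_col_*", with the suffix ignored).

```python
import re
from pathlib import Path
from typing import Dict, List, Tuple

SCALES = ("0_1024", "1_512", "1_1024")  # <level>_<tile size> suffixes from the layout above
INDEX_RE = re.compile(r"^(\d+)_(\d+)")  # assumption: entries start with "<row>_<col>"


def expected_feature_files(conch_dir: str, slide_id: str) -> Dict[str, Path]:
    """Map each scale suffix to the .npy path the layout expects."""
    return {scale: Path(conch_dir) / f"{slide_id}_{scale}.npy" for scale in SCALES}


def missing_scales(conch_dir: str, slide_id: str) -> List[str]:
    """List the scale suffixes whose feature file is not present."""
    return [s for s, p in expected_feature_files(conch_dir, slide_id).items() if not p.is_file()]


def parse_index(entry: str) -> Tuple[int, int]:
    """Extract (row, col) from an index entry such as 'row_col_*'."""
    match = INDEX_RE.match(entry)
    if match is None:
        raise ValueError(f"unexpected index entry: {entry!r}")
    return int(match.group(1)), int(match.group(2))


print(missing_scales("/path/to/conch_features", "TCGA-XX-XXXX-01Z-00-DX1.<uuid>"))
```

To read one file, np.load(path, allow_pickle=True).item() returns the wrapped dict; allow_pickle is required because the payload is a Python object.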

Skip Step A: grab the sample feature set. We ship pre-extracted CONCH features for 20 TCGA + 20 GTEx slides drawn from the test set at CNX-PathLLM/Llama-slideQA-Sample-Features. The exact slide-ids, questions, answers, and per-slide mode (close-mc / close-tf / open) are listed in demo/slides_sample.csv, so you can reproduce our demo predictions out of the box.

Step B: GMM tokenisation + inference. Drive both with demo/single_wsi_demo.py:

python demo/single_wsi_demo.py \
    --conch-dir  /path/to/conch_features \
    --slide-id   TCGA-XX-XXXX-01Z-00-DX1.<uuid> \
    --out-dir    ./demo_out \
    --proto-root GMM_feature_extraction/GTEx-TCGA-prototypes/GTEx_TCGA_merged_output_more \
    --vlm-ckpt   /path/to/Llama-slideQA.bin \
    --llm-name   meta-llama/Llama-3.1-8B-Instruct \
    --question   "Describe the morphology of this slide."

The script does:

  1. tokenise the three CONCH feature scales against the CONCH prototypes under --proto-root → one <slide_id>.npy GMM feature file,
  2. instantiate WPathVLM with the same flags as Stage 2 training (--agg_strategy longnet,gmm, --gmm_need_query True, hierarchical tokens on), load --vlm-ckpt, build a single-sample batch, and call model.generate().

Pass --skip-stage1 to reuse an existing GMM feature file in --out-dir.
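The GMM branch consumes one token per prototype made of a mixture probability, a mean and a covariance (the prob | mean | cov layout of the gmm strategy). The sketch below shows one generic way to produce such tokens for a single scale: soft-assign each patch feature to fixed diagonal-Gaussian prototypes, then take responsibility-weighted statistics. It illustrates the idea under the assumption of diagonal covariances and is not the code in GMM_feature_extraction/src/wsi_single_demo.py; the function name, the fallback for unused prototypes, and the tiny 2-dim toy data are all hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def gmm_tokens(
    features: List[Vector], means: List[Vector], variances: List[Vector], weights: Vector
) -> List[Tuple[float, Vector, Vector]]:
    """Return one (prob, mean, var) token per prototype for a single scale.

    features: N x D patch features; means / variances: K x D diagonal-Gaussian
    prototype parameters; weights: K prior mixture weights.
    """
    K, D = len(means), len(means[0])

    # Soft assignment: r[n][k] is proportional to w_k * N(x_n | mu_k, diag(var_k)).
    resp = []
    for x in features:
        log_p = []
        for k in range(K):
            ll = math.log(weights[k])
            for d in range(D):
                diff = x[d] - means[k][d]
                ll -= 0.5 * (math.log(2 * math.pi * variances[k][d]) + diff * diff / variances[k][d])
            log_p.append(ll)
        top = max(log_p)  # log-sum-exp trick for numerical stability
        exps = [math.exp(v - top) for v in log_p]
        total = sum(exps)
        resp.append([e / total for e in exps])

    # Responsibility-weighted statistics per prototype.
    tokens = []
    for k in range(K):
        nk = sum(r[k] for r in resp)
        if nk < 1e-12:  # prototype unused by this slide: fall back to its own parameters
            tokens.append((0.0, list(means[k]), list(variances[k])))
            continue
        mean = [sum(r[k] * x[d] for r, x in zip(resp, features)) / nk for d in range(D)]
        var = [
            sum(r[k] * (x[d] - mean[d]) ** 2 for r, x in zip(resp, features)) / nk + 1e-6
            for d in range(D)
        ]
        tokens.append((nk / len(features), mean, var))
    return tokens


# Toy example: two 2-dim prototypes and three patches.
protos_mu = [[0.0, 0.0], [5.0, 5.0]]
protos_var = [[1.0, 1.0], [1.0, 1.0]]
patches = [[0.1, -0.1], [-0.2, 0.0], [5.1, 4.9]]
for prob, mean, var in gmm_tokens(patches, protos_mu, protos_var, [0.5, 0.5]):
    print(round(prob, 3), [round(m, 2) for m in mean])
```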

--mode selects the prompt template (defaults to open, equivalent to wsi_formatting_qa_open_test):

  • open (default): free-text questions such as descriptions, "what is..." or "where is...". Prompt template: <|Question|>{q}<|Answer|>
  • close-mc: multiple choice; forces a single letter. Prompt template: <|Question|>{q}<|Prompt|> Please provide only the answer (for example, A, B, etc.) ... <|Answer|>
  • close-tf: true/false; forces Yes or No. Prompt template: <|Question|>{q}<|Prompt|> Please provide only the answer (either Yes or No) ... <|Answer|>

These match how test_wsi.py auto-routes samples in the closed-QA pipeline (see utils/formatting_funcs.py).
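The template layouts for the three modes can be assembled with a few lines of string formatting. The helper below is hypothetical; for the closed modes it takes the exact <|Prompt|> instruction as an argument instead of reproducing it, because the full wording lives in utils/formatting_funcs.py. Spacing around the special tokens follows the templates as printed and may differ from the repository.

```python
from typing import Optional

MODES = ("open", "close-mc", "close-tf")


def build_prompt(question: str, mode: str = "open", instruction: Optional[str] = None) -> str:
    """Assemble a prompt using the template layout for the given --mode.

    For close-mc / close-tf, `instruction` must be the exact <|Prompt|> text
    from utils/formatting_funcs.py.
    """
    if mode not in MODES:
        raise ValueError(f"mode must be one of {MODES}, got {mode!r}")
    if mode == "open":
        return f"<|Question|>{question}<|Answer|>"
    if not instruction:
        raise ValueError(f"mode {mode!r} needs the closed-QA instruction text")
    return f"<|Question|>{question}<|Prompt|> {instruction} <|Answer|>"


print(build_prompt("Describe the morphology of this slide."))
# <|Question|>Describe the morphology of this slide.<|Answer|>
```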

Checkpoint

Llama-slideQA.bin (trained with general QA following Stage 2) is released alongside this model card on the Hugging Face hub.

Load it into the same model definition by passing --ckpt_path /path/to/Llama-slideQA.bin to test_wsi.py or demo/single_wsi_demo.py.

Disclaimer

This repository and all associated models are intended solely for academic research and non-commercial use. The model involves medical data (e.g. TCGA, GTEx) and pathology-related tasks, but is not approved for clinical diagnosis or medical decision-making. The developers are not responsible for any misuse of this code or model in medical or commercial contexts.

License

This model uses Meta's Llama 3.1 (Llama-3.1-8B-Instruct) as part of its architecture and follows the Llama 3.1 Community License.
