Commit cbff41a
Parent(s): 958e3c5

added code for running

This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- .gitignore +5 -0
- README.md +83 -3
- accelerate_configs/.ipynb_checkpoints/deepspeed_zero2-checkpoint.yaml +22 -0
- accelerate_configs/.ipynb_checkpoints/deepspeed_zero3-checkpoint.yaml +23 -0
- accelerate_configs/deepspeed_zero1.yaml +20 -0
- accelerate_configs/deepspeed_zero2.yaml +23 -0
- accelerate_configs/deepspeed_zero3.yaml +23 -0
- accelerate_configs/multi_gpu.yaml +16 -0
- accelerate_configs/single_gpu.yaml +16 -0
- data/huggingface_data.py +31 -0
- dataset_csv/gtex_slide_url_info.csv +0 -0
- dataset_csv/indices_and_slide_ids.csv +0 -0
- dataset_csv/indices_and_slide_ids_with_folds.csv +0 -0
- dataset_csv/tcga_slide_url_info.csv +0 -0
- demo/CONCH_clip.py +11 -0
- demo/Trainer_Mixtrial_ds_demo.py +154 -0
- demo/Trainer_bert_demo.py +76 -0
- demo/UNI_clip.py +33 -0
- demo/path_clip.py +28 -0
- demo/peft_demo.py +17 -0
- demo/trl_demo.py +175 -0
- evaluation/cider_score/cider_demo.ipynb +290 -0
- evaluation/cider_score/cidereval/__init__.py +5 -0
- evaluation/cider_score/cidereval/cider/__init__.py +1 -0
- evaluation/cider_score/cidereval/cider/cider.py +67 -0
- evaluation/cider_score/cidereval/cider/cider_scorer.py +274 -0
- evaluation/cider_score/cidereval/ciderD/__init__.py +1 -0
- evaluation/cider_score/cidereval/ciderD/ciderD.py +57 -0
- evaluation/cider_score/cidereval/ciderD/ciderD_scorer.py +265 -0
- evaluation/cider_score/cidereval/data/__init__.py +0 -0
- evaluation/cider_score/cidereval/data/coco-val.p +3 -0
- evaluation/cider_score/cidereval/eval.py +40 -0
- evaluation/cider_score/cidereval/scorers.py +76 -0
- evaluation/cider_score/cidereval/tokenizer/__init__.py +4 -0
- evaluation/cider_score/cidereval/tokenizer/ptbtokenizer.py +112 -0
- evaluation/cider_score/cidereval/tokenizer/simpletokenizer.py +106 -0
- evaluation/cider_score/output_sample.xls +0 -0
- filter_dataset.py +43 -0
- gigapath/__init__.py +0 -0
- gigapath/__pycache__/__init__.cpython-310.pyc +0 -0
- gigapath/__pycache__/pos_embed.cpython-310.pyc +0 -0
- gigapath/__pycache__/slide_encoder.cpython-310.pyc +0 -0
- gigapath/__pycache__/slide_encoder_vision.cpython-310.pyc +0 -0
- gigapath/classification_head.py +92 -0
- gigapath/pipeline.py +190 -0
- gigapath/pos_embed.py +105 -0
- gigapath/preprocessing/__init__.py +0 -0
- gigapath/preprocessing/data/__init__.py +0 -0
- gigapath/preprocessing/data/box_utils.py +145 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.p filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
Conch_Llama3ins_Instruct3/
Conch_Mistral_Instruct3/
output/
logs/
wandb/
README.md CHANGED
@@ -1,3 +1,83 @@
# PathLLM

Welcome to ALPaCA. This repository aims to provide a straightforward reproduction of ALPaCA. To run ALPaCA, please first download **Llama3.1-8b-instruct** as the base model.

For data from TCGA and GTEx, you can visit the [GDC Data Portal Homepage](https://portal.gdc.cancer.gov/) and [GTEx Portal](https://www.gtexportal.org/) to download and extract features yourself. Alternatively, you can use the features we have already extracted with CONCH: `CNX-PathLLM/GMM_Embeddings` and `CNX-PathLLM/GTEx-TCGA-Embeddings`. After downloading, please unzip them into the respective `TCGA-Embedding` and `GMM_Embedding` folders.
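As a concrete starting point, here is a minimal download sketch. It assumes both embedding repos are hosted as Hugging Face *dataset* repos and that you are authenticated if they are gated; the `local_dir` names simply mirror the folders above.

```python
# Minimal sketch (assumption: both repos are dataset-type Hub repos).
from huggingface_hub import snapshot_download

for repo_id, local_dir in [
    ("CNX-PathLLM/GTEx-TCGA-Embeddings", "TCGA-Embedding"),
    ("CNX-PathLLM/GMM_Embeddings", "GMM_Embedding"),
]:
    snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
# Any .zip archives fetched this way still need to be unzipped in place.
```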
Please ensure you have access to the pathological image description data:
`CNX-PathLLM/TCGA-WSI-Description-4onew`, `CNX-PathLLM/TCGA-WSI-Description-4omini`, and `CNX-PathLLM/GTEx-WSI-Description`.

Please ensure you also have access to the WSI-QA data:
`CNX-PathLLM/TCGA-WSI-CloseQA-Balanced`, `CNX-PathLLM/GTEx-WSI-CloseQA-Balanced`, `CNX-PathLLM/TCGA-WSI-OpenQA`, and `CNX-PathLLM/GTEx-WSI-OpenQA`.

After completing the setup above and preparing the Python environment, you can start training with the provided shell scripts, e.g., `run_wsi_stage*.sh`, or follow the instructions in the [Train Step](#train-step-1) section below.

Do not forget to adjust the TCGA and GMM embedding paths to reflect your own file locations.
## Settings

### Different Aggregate Strategies
You can switch aggregation strategies with the `--agg_strategy` flag, e.g. `qformer`, `abmil`, or `longnet`. To reproduce the method described in our paper, set `--agg_strategy gmm,longnet` in the `.sh` script.

### Configurable Settings

(1) `--vision_adaptor False --hierarchical_adaptor True`

(2) `--vision_adaptor False --hierarchical_adaptor False`

(3) `--vision_adaptor True --hierarchical_adaptor True`

```
--vision_adaptor False (vision-query-question interaction)
--vision_adaptor True (vision-query interaction)

--hierarchical_adaptor False (same adaptor for all levels)
--hierarchical_adaptor True (different adaptors for different levels)
```
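To make the hierarchical-adaptor switch concrete, here is a toy sketch; the layer shapes and names are illustrative only and are not the repository's actual modules.

```python
# Illustrative only: shared vs. per-level adaptors (hypothetical shapes).
import torch.nn as nn

n_levels, dim = 3, 512  # e.g. three feature levels, matching --embed_dim 512

# hierarchical_adaptor False: one adaptor reused at every level
shared_adaptor = nn.Linear(dim, dim)

# hierarchical_adaptor True: a separate adaptor per level
per_level_adaptors = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_levels))
```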
## Train Step 1 ##
```
accelerate launch --config_file=./accelerate_configs/deepspeed_zero2.yaml run_wsi.py --learning_rate 1e-4 --max_steps 10000 --warmup_steps 100 \
    --gpu 2 --train_batch_size 4 --eval_batch_size 2 --max_seq_length 512 \
    --agg_strategy gmm,longnet --embed_dim 512 --vision_adaptor False --hierachical_token True --hierachical_adaptor True \
    --n_heads 32,16,8 --llm_requires_grad False --resume_from_checkpoint False \
    --llm_name /data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct \
    --dataset_name_list CNX-PathLLM/TCGA-WSI-Description-4onew,CNX-PathLLM/TCGA-WSI-Description-4omini,CNX-PathLLM/GTEx-WSI-Description \
    --data_cache_dir /data_local/pxb/CNX-PathLLM/.cache \
    --fea_root /path/to/CNX-PathLLM/GTEx-TCGA-Embeddings \
    --gmm_root /path/to/GMM_Embeddings \
    --output_dir path/to/output/of/step1
```
## Train Step 2 ##
```
accelerate launch --config_file=./accelerate_configs/deepspeed_zero2.yaml run_wsi.py --max_steps 20000 --warmup_steps 10 \
    --gpu 2 --train_batch_size 8 --eval_batch_size 2 --max_seq_length 256 \
    --agg_strategy gmm,longnet --embed_dim 512 --vision_adaptor False --hierachical_token True --hierachical_adaptor True \
    --n_heads 32,16,8 --llm_requires_grad True --resume_from_checkpoint False \
    --llm_name /data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct \
    --dataset_name_list CNX-PathLLM/TCGA-WSI-CloseQA-Balanced,CNX-PathLLM/GTEx-WSI-CloseQA-Balanced,CNX-PathLLM/TCGA-WSI-OpenQA,CNX-PathLLM/GTEx-WSI-OpenQA \
    --data_cache_dir /data_local/pxb/CNX-PathLLM/.cache \
    --fea_root /path/to/CNX-PathLLM/GTEx-TCGA-Embeddings \
    --gmm_root /path/to/GMM_Embeddings \
    --output_dir path/to/output/of/step2 \
    --ckpt_path path/to/ckpt.bin/of/step1
```
## Train Step 3 ##
To continue training on the detailed BRCA dataset, make sure you have access to that dataset and replace the dataset names in the command above with the ones you want.
## Test of Step2 General QA ##

```
python test_wsi.py --max_seq_length 128 --batch_size 1 --select_data_num -1 --eval_sample_size -1 --n_heads 32,16,8 --llm_name /data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct --vision_adaptor False --hierachical_token True --hierachical_adaptor True \
    --shuffle False --data_cache_dir /data_local/pxb/CNX-PathLLM/.cache \
    --dataset_name_list CNX-PathLLM/TCGA-WSI-CloseQA-Balanced,CNX-PathLLM/GTEx-WSI-CloseQA-Balanced,CNX-PathLLM/TCGA-WSI-OpenQA,CNX-PathLLM/GTEx-WSI-OpenQA \
    --agg_strategy gmm,longnet --embed_dim 512 \
    --fea_root /path/to/CNX-PathLLM/GTEx-TCGA-Embeddings \
    --gmm_root /path/to/GMM_Embeddings \
    --ckpt_path path/to/ckpt.bin/of/step2 \
    --results_save_path /path/to/the/output.csv \
    --use_peft False
```
accelerate_configs/.ipynb_checkpoints/deepspeed_zero2-checkpoint.yaml ADDED
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: "auto"
  offload_optimizer_device: "cpu"
  offload_param_device: "cpu"
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'auto'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_configs/.ipynb_checkpoints/deepspeed_zero3-checkpoint.yaml ADDED
@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 8
  offload_optimizer_device: "cpu"
  offload_param_device: "cpu"
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'auto'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_configs/deepspeed_zero1.yaml ADDED
@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: "auto"
  zero3_init_flag: false
  zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_configs/deepspeed_zero2.yaml ADDED
@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: "auto"
  offload_optimizer_device: "cpu"
  offload_param_device: "cpu"
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'auto'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 29502
accelerate_configs/deepspeed_zero3.yaml ADDED
@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 8
  offload_optimizer_device: "cpu"
  offload_param_device: "cpu"
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'auto'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_configs/multi_gpu.yaml ADDED
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_configs/single_gpu.yaml ADDED
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
data/huggingface_data.py ADDED
@@ -0,0 +1,31 @@
import pandas as pd
import datasets
import re
import os
import shutil

# Splits: train, test, and validation sets
splits = ["test", "train", "val"]
# Process each split separately
for item in splits:
    os.makedirs(f"our_clean/{item}/", exist_ok=True)
    data = pd.read_csv(f"{item}.csv")
    data["image_path"] = data["image_path"].map(lambda x: x.split("/")[-1])
    # Light text cleaning: lowercase, collapse repeated spaces, fix " ?" spacing
    f = lambda x: re.sub(' +', ' ', str(x).lower()).replace(" ?", "?").strip()
    # Hugging Face imagefolder datasets require a "file_name" column
    data.insert(0, "file_name", "")
    data["question"] = data["question"].apply(f)
    data["answer"] = data["answer"].apply(f)
    # Pair each text row with its image file
    for i, row in data.iterrows():
        file_name = f"img_{i}.jpg"
        data.loc[i, "file_name"] = file_name  # .loc avoids chained-assignment pitfalls
        shutil.copyfile(src=f"author-folder/pvqa/pvqa/images/{item}/{row['image']}.jpg", dst=f"our_clean/{item}/{file_name}")
    # Drop columns that are no longer needed
    _ = data.pop("image")
    data.drop(["pathology", "image_path"], axis=1, inplace=True)
    data.to_csv(f"our_clean/{item}/metadata.csv", index=False)
# Build an imagefolder-format dataset; data_dir is the folder holding the data,
# see https://huggingface.co/docs/datasets/en/image_load
dataset = datasets.load_dataset("imagefolder", data_dir="our_clean/")
# Publish the dataset
dataset.push_to_hub("CNX-PathLLM/PVQAClean")
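For reference, a minimal sketch of consuming the published dataset (assuming the `CNX-PathLLM/PVQAClean` repo is accessible to you):

```python
from datasets import load_dataset

# Load the published imagefolder dataset straight from the Hub.
ds = load_dataset("CNX-PathLLM/PVQAClean", split="train")
print(ds[0]["question"], "->", ds[0]["answer"])  # ds[0]["image"] is a PIL image
```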
dataset_csv/gtex_slide_url_info.csv ADDED
The diff for this file is too large to render.

dataset_csv/indices_and_slide_ids.csv ADDED
The diff for this file is too large to render.

dataset_csv/indices_and_slide_ids_with_folds.csv ADDED
The diff for this file is too large to render.

dataset_csv/tcga_slide_url_info.csv ADDED
The diff for this file is too large to render.
demo/CONCH_clip.py ADDED
@@ -0,0 +1,11 @@
from conch.open_clip_custom import create_model_from_pretrained
import torch
from PIL import Image

model, preprocess = create_model_from_pretrained('conch_ViT-B-16', "/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/load_weights/conch/pytorch_model.bin")

image = Image.open("/bask/homes/a/asiw9691/PathVLM/source/Flamingo/med-flamingo/img/test_path5.jpg")
image = preprocess(image).unsqueeze(0)
with torch.inference_mode():
    image_embs = model.encode_image(image, proj_contrast=False, normalize=False)
print(image_embs.shape)
demo/Trainer_Mixtrial_ds_demo.py ADDED
@@ -0,0 +1,154 @@
# accelerate launch --config_file=/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/accelerate_configs/deepspeed_zero2.yaml Trainer_Mixtrial_demo.py

import os
os.environ["WANDB_MODE"] = "offline"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import torch
from torch import nn
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, PreTrainedModel
from accelerate import Accelerator
from datasets import load_dataset
from typing import Optional
from dataclasses import dataclass, field


@dataclass
class ScriptArguments:
    """
    Arguments for fine-tuning a causal LM with the HF Trainer.
    """

    model_name: Optional[str] = field(default="mistralai/Mistral-7B-Instruct-v0.2", metadata={"help": "the model name, e.g. meta-llama/Llama-2-7b-chat-hf"})
    dataset_name: Optional[str] = field(default="stingning/ultrachat", metadata={"help": "the dataset name"})
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(default=8, metadata={"help": "the number of gradient accumulation steps"})

    evaluation_strategy: Optional[str] = field(default="steps", metadata={"help": "epoch or steps"})
    eval_steps: Optional[int] = field(default=1000, metadata={"help": "run evaluation every N steps"})

    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})

    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"})
    token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(default=1000, metadata={"help": "Number of update steps between two checkpoint saves"})
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(default="mistral-7b-finetuned-ultrachat", metadata={"help": "The name of the model on HF Hub"})

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


class MyCustomModel(nn.Module):
    def __init__(self, script_args, num_labels):
        super(MyCustomModel, self).__init__()
        self.num_labels = num_labels
        # Relies on quantization_config / device_map / torch_dtype being defined
        # at module level before the model is instantiated (see below).
        self.pretrained_model = AutoModelForCausalLM.from_pretrained(script_args.model_name,
                                                                     quantization_config=quantization_config,
                                                                     device_map=device_map,
                                                                     trust_remote_code=script_args.trust_remote_code,
                                                                     torch_dtype=torch_dtype,
                                                                     token=script_args.token)
        self.classifier = nn.Linear(self.pretrained_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Causal-LM outputs expose hidden states only when requested explicitly.
        outputs = self.pretrained_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        sequence_output = outputs.hidden_states[-1][:, 0, :]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else logits


dataset = load_dataset("glue", "mrpc")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(examples):
    # Tokenize the inputs (pair of sentences)
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding=True, max_length=10)


from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

small_train_dataset = dataset["train"].shuffle(seed=42).select(range(500))  # take the first 500 samples
small_train_dataset = small_train_dataset.map(preprocess_function, batched=True)


if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Copy the model to each device
    device_map = {"": Accelerator().local_process_index}
    torch_dtype = torch.bfloat16
else:
    device_map = None
    quantization_config = None
    torch_dtype = None


model = MyCustomModel(script_args, num_labels=2)


training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    # gradient_checkpointing=True,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    save_total_limit=script_args.save_total_limit,
    bf16=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    evaluation_strategy=script_args.evaluation_strategy,
    eval_steps=script_args.eval_steps,
    logging_first_step=True,
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    data_collator=data_collator,
    compute_metrics=None,
)


trainer.train()
# model.save_pretrained("./my_custom_model")
demo/Trainer_bert_demo.py ADDED
@@ -0,0 +1,76 @@
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import torch
from torch import nn
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset

# Custom model: subclass nn.Module (or one of transformers' pretrained model classes)
class MyCustomModel(nn.Module):
    def __init__(self, num_labels):
        super(MyCustomModel, self).__init__()
        self.num_labels = num_labels
        self.pretrained_model = AutoModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Linear(self.pretrained_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.pretrained_model(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[1]  # BERT pooled [CLS] output
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else logits

# Load and preprocess the dataset
dataset = load_dataset("glue", "mrpc")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_function(examples):
    # Tokenize the inputs (pair of sentences)
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding=True)

from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

small_train_dataset = dataset["train"].shuffle(seed=42).select(range(500))  # take the first 500 samples
small_train_dataset = small_train_dataset.map(preprocess_function, batched=True)

for i in small_train_dataset:
    print(i)


# Instantiate the custom model
model = MyCustomModel(num_labels=2).to("cuda")

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    data_collator=data_collator,
    compute_metrics=None,  # add a metrics function here if needed
)

# Train the model
trainer.train()

# Save the model (a plain nn.Module has no save_pretrained(), so save the state dict)
torch.save(model.state_dict(), "./my_custom_model.pt")
demo/UNI_clip.py ADDED
@@ -0,0 +1,33 @@
import os
import torch
from torchvision import transforms
import timm
from huggingface_hub import login, hf_hub_download

# login() # login with your User Access Token, found at https://huggingface.co/settings/tokens

local_dir = "/bask/homes/a/asiw9691/PathVLM/UNI"
# os.makedirs(local_dir, exist_ok=True)  # create directory if it does not exist
# hf_hub_download("MahmoodLab/UNI", filename="pytorch_model.bin", local_dir=local_dir, force_download=True)
model = timm.create_model("vit_large_patch16_224", img_size=224, patch_size=16, init_values=1e-5, num_classes=0, dynamic_img_size=True)
model.load_state_dict(torch.load(os.path.join(local_dir, "pytorch_model.bin"), map_location="cpu"), strict=True)
transform = transforms.Compose(
    [
        # transforms.Resize(224),
        transforms.Resize(256),      # resize the short side to 256 pixels first
        transforms.CenterCrop(224),  # then center-crop a 224x224 image
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        transforms.Lambda(lambda x: x.unsqueeze(0))
    ]
)
model.eval()

from PIL import Image
image = Image.open("/bask/homes/a/asiw9691/PathVLM/source/Flamingo/med-flamingo/img/test_path5.jpg")

image = transform(image)  # Image (torch.Tensor) with shape [1, 3, 224, 224] following image resizing and normalization (ImageNet parameters)
with torch.inference_mode():
    feature_emb = model(image)  # Extracted features (torch.Tensor) with shape [1, 1024]

print(feature_emb.shape)
demo/path_clip.py ADDED
@@ -0,0 +1,28 @@
import torch
from PIL import Image

import open_clip

## load the model
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained="/bask/homes/a/asiw9691/PathVLM/PathClip/pathclip-base.pt",
                                                             force_quick_gelu=True)
tokenizer = open_clip.get_tokenizer('ViT-B-16')
model = model.cuda()

## load the image and prepare the text prompts
img_path = '/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/data/test_data/test_path1.jpg'
label_description_list = ['apple', 'liver', 'cancer']  # specify the label descriptions
text_label_list = ['An image of {}'.format(i) for i in label_description_list]
image = Image.open(img_path)
image = preprocess(image).unsqueeze(0).cuda()
text = tokenizer(text_label_list).cuda()

## extract the image and text features and predict the label
with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    predict_label = torch.argmax(text_probs).item()
print(predict_label)
demo/peft_demo.py ADDED
@@ -0,0 +1,17 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model_name_or_path = "/raid/hpc/hekai/WorkShop/My_project/LLM_models/llama2/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

peft_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, peft_config)
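A quick sanity check after wrapping: `print_trainable_parameters()` is part of the standard PEFT API. The line below continues from the script above and assumes `model` from `demo/peft_demo.py`.

```python
# Continues from demo/peft_demo.py: `model` is the LoRA-wrapped model.
model.print_trainable_parameters()  # prints trainable vs. total parameter counts
```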
demo/trl_demo.py ADDED
@@ -0,0 +1,175 @@
# accelerate launch --config_file=/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/accelerate_configs/deepspeed_zero2.yaml demo/trl_demo.py

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, AutoTokenizer
from trl import SFTTrainer


tqdm.pandas()


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    Arguments for fine-tuning a causal LM with SFTTrainer.
    """

    model_name: Optional[str] = field(default="mistralai/Mistral-7B-Instruct-v0.2", metadata={"help": "the model name, e.g. meta-llama/Llama-2-7b-chat-hf"})
    dataset_name: Optional[str] = field(default="stingning/ultrachat", metadata={"help": "the dataset name"})
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(default=8, metadata={"help": "the number of gradient accumulation steps"})

    evaluation_strategy: Optional[str] = field(default="steps", metadata={"help": "epoch or steps"})
    eval_steps: Optional[int] = field(default=2, metadata={"help": "run evaluation every N steps"})

    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})

    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"})
    token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(default=1000, metadata={"help": "Number of update steps between two checkpoint saves"})
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(default="mistral-7b-finetuned-ultrachat", metadata={"help": "The name of the model on HF Hub"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Step 1: Load the dataset
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset(script_args.dataset_name, split="train[:200]")
dataset = dataset.train_test_split(test_size=0.1)

def prepare_dialogue(example):
    text = ""
    for idx, msg in enumerate(example["data"]):
        if idx % 2 == 0:
            text += f"<|user|>\n{msg}{tokenizer.eos_token}\n"
        else:
            text += f"<|assistant|>\n{msg}{tokenizer.eos_token}\n"
    example["text"] = text
    return example

dataset = dataset.map(prepare_dialogue, num_proc=4, remove_columns=["id", "data"])


# Step 2: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Copy the model to each device
    device_map = {"": Accelerator().local_process_index}
    torch_dtype = torch.bfloat16
else:
    # device_map = "auto"
    device_map = None
    quantization_config = None
    torch_dtype = None

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    token=script_args.token,
)


# Step 3: Define the LoraConfig
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None


# Step 4: Define the training arguments
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=True,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    save_total_limit=script_args.save_total_limit,
    # push_to_hub=script_args.push_to_hub,
    # hub_model_id=script_args.hub_model_id,
    bf16=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    evaluation_strategy=script_args.evaluation_strategy,
    eval_steps=script_args.eval_steps,
    logging_first_step=True,
)

def my_compute_metrics(p):
    # Placeholder metrics for demonstration; replace with real computations.
    predictions, labels = p
    return {
        'precision': 1,
        'recall': 1,
        'f1': 1,
    }

# Step 5: Define the SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    max_seq_length=script_args.seq_length,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field=script_args.dataset_text_field,
    peft_config=peft_config,
    packing=False,
    tokenizer=tokenizer,
    compute_metrics=my_compute_metrics
)

trainer.train()

# Step 6: Save the model
trainer.save_model(script_args.output_dir)
evaluation/cider_score/cider_demo.ipynb ADDED
@@ -0,0 +1,290 @@
(Jupyter notebook, 290 lines of JSON: an "Open In Colab" badge cell linking to
https://colab.research.google.com/github/michelecafagna26/cider/blob/master/cider_demo.ipynb,
an "# Install" markdown cell, and a code cell that pip-installs spacy and downloads
the en_core_web_sm-3.7.1 model; the stored cell output is the pip install/download
log. The rendered diff is truncated at this point.)
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.9)\n",
|
| 147 |
+
"Requirement already satisfied: thinc<8.3.0,>=8.2.2 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.2.3)\n",
|
| 148 |
+
"Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.2)\n",
|
| 149 |
+
"Requirement already satisfied: srsly<3.0.0,>=2.4.3 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.4.8)\n",
|
| 150 |
+
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.10)\n",
|
| 151 |
+
"Requirement already satisfied: weasel<0.4.0,>=0.1.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.3.4)\n",
|
| 152 |
+
"Requirement already satisfied: typer<0.10.0,>=0.3.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.9.4)\n",
|
| 153 |
+
"Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (6.4.0)\n",
|
| 154 |
+
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.66.1)\n",
|
| 155 |
+
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.31.0)\n",
|
| 156 |
+
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.7.1)\n",
|
| 157 |
+
"Requirement already satisfied: jinja2 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.1.2)\n",
|
| 158 |
+
"Requirement already satisfied: setuptools in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (68.0.0)\n",
|
| 159 |
+
"Requirement already satisfied: packaging>=20.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (23.2)\n",
|
| 160 |
+
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.4.0)\n",
|
| 161 |
+
"Requirement already satisfied: numpy>=1.15.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.24.4)\n",
|
| 162 |
+
"Requirement already satisfied: language-data>=1.2 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from langcodes<4.0.0,>=3.2.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.2.0)\n",
|
| 163 |
+
"Requirement already satisfied: annotated-types>=0.4.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.0)\n",
|
| 164 |
+
"Requirement already satisfied: pydantic-core==2.18.2 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.18.2)\n",
|
| 165 |
+
"Requirement already satisfied: typing-extensions>=4.6.1 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.8.0)\n",
|
| 166 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.2)\n",
|
| 167 |
+
"Requirement already satisfied: idna<4,>=2.5 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.4)\n",
|
| 168 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.7)\n",
|
| 169 |
+
"Requirement already satisfied: certifi>=2017.4.17 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2023.7.22)\n",
|
| 170 |
+
"Requirement already satisfied: blis<0.8.0,>=0.7.8 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.11)\n",
|
| 171 |
+
"Requirement already satisfied: confection<1.0.0,>=0.0.1 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.1.4)\n",
|
| 172 |
+
"Requirement already satisfied: colorama in d:\\miniconda\\envs\\llm\\lib\\site-packages (from tqdm<5.0.0,>=4.38.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.4.6)\n",
|
| 173 |
+
"Requirement already satisfied: click<9.0.0,>=7.1.1 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.1.7)\n",
|
| 174 |
+
"Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.16.0)\n",
|
| 175 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from jinja2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.1.3)\n",
|
| 176 |
+
"Requirement already satisfied: marisa-trie>=0.7.7 in d:\\miniconda\\envs\\llm\\lib\\site-packages (from language-data>=1.2->langcodes<4.0.0,>=3.2.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.1)\n",
|
| 177 |
+
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
|
| 178 |
+
"You can now load the package via spacy.load('en_core_web_sm')\n"
|
| 179 |
+
]
|
| 180 |
+
}
|
| 181 |
+
],
|
| 182 |
+
"source": [
|
| 183 |
+
"#use spacy to get rid of the influence of standford-corenlp.jar, which requires java\n",
|
| 184 |
+
"! pip install spacy\n",
|
| 185 |
+
"! python -m spacy download en_core_web_sm "
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "markdown",
|
| 190 |
+
"metadata": {
|
| 191 |
+
"id": "RF9-uATHjSJT"
|
| 192 |
+
},
|
| 193 |
+
"source": [
|
| 194 |
+
"Ready to go!"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "code",
|
| 199 |
+
"execution_count": 2,
|
| 200 |
+
"metadata": {
|
| 201 |
+
"id": "UvK_ACyMeqDD"
|
| 202 |
+
},
|
| 203 |
+
"outputs": [],
|
| 204 |
+
"source": [
|
| 205 |
+
"from cidereval import cider, ciderD\n",
|
| 206 |
+
"import pandas as pd"
|
| 207 |
+
]
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"cell_type": "code",
|
| 211 |
+
"execution_count": 6,
|
| 212 |
+
"metadata": {},
|
| 213 |
+
"outputs": [
|
| 214 |
+
{
|
| 215 |
+
"name": "stdout",
|
| 216 |
+
"output_type": "stream",
|
| 217 |
+
"text": [
|
| 218 |
+
"CIDEr Score: 1.1466548664799776\n"
|
| 219 |
+
]
|
| 220 |
+
}
|
| 221 |
+
],
|
| 222 |
+
"source": [
|
| 223 |
+
"def calculate_cider_score(file_path,pred_column,ref_column):\n",
|
| 224 |
+
" \"\"\"\n",
|
| 225 |
+
" input:\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" file_path: table to be analysed.\n",
|
| 228 |
+
" pred_column: column name that the analysed model generated\n",
|
| 229 |
+
" ref_column: column name of the ground truth\n",
|
| 230 |
+
"\n",
|
| 231 |
+
" output:\n",
|
| 232 |
+
" score:{avg_score:xxx,scores:[array]}\n",
|
| 233 |
+
" \"\"\"\n",
|
| 234 |
+
" # 读取xlsx文件\n",
|
| 235 |
+
" df = pd.read_excel(file_path)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" # 创建存储ground_truth和generated句子的列表\n",
|
| 238 |
+
" references = []\n",
|
| 239 |
+
" candidates = []\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" # 遍历每一行并将内容添加到对应的列表中\n",
|
| 242 |
+
" for index, row in df.iterrows():\n",
|
| 243 |
+
" # references.append([row['answers']])\n",
|
| 244 |
+
" # candidates.append([row['results']])\n",
|
| 245 |
+
" references.append([row[ref_column]])\n",
|
| 246 |
+
" candidates.append(row[pred_column])\n",
|
| 247 |
+
" # 创建Cider对象\n",
|
| 248 |
+
" # cider_scorer = Cider()\n",
|
| 249 |
+
"\n",
|
| 250 |
+
" # 计算CIDEr分数\n",
|
| 251 |
+
"\n",
|
| 252 |
+
" cider_score = cider(candidates,references,df=\"corpus\")\n",
|
| 253 |
+
"\n",
|
| 254 |
+
" return cider_score\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"# 调用函数并传入xlsx文件路径\n",
|
| 257 |
+
"file_path = 'output_sample.xls'\n",
|
| 258 |
+
"score = calculate_cider_score(file_path,\"results\",\"answers\")\n",
|
| 259 |
+
"print(\"CIDEr Score:\", score['avg_score'])"
|
| 260 |
+
]
|
| 261 |
+
}
|
| 262 |
+
],
|
| 263 |
+
"metadata": {
|
| 264 |
+
"colab": {
|
| 265 |
+
"authorship_tag": "ABX9TyMT4TNjFKKrEMcMc0uZ6Ubr",
|
| 266 |
+
"include_colab_link": true,
|
| 267 |
+
"name": "cider_demo.ipynb",
|
| 268 |
+
"provenance": []
|
| 269 |
+
},
|
| 270 |
+
"kernelspec": {
|
| 271 |
+
"display_name": "asiw9691_conda_env (Conda)",
|
| 272 |
+
"language": "python",
|
| 273 |
+
"name": "sys_asiw9691_conda_env"
|
| 274 |
+
},
|
| 275 |
+
"language_info": {
|
| 276 |
+
"codemirror_mode": {
|
| 277 |
+
"name": "ipython",
|
| 278 |
+
"version": 3
|
| 279 |
+
},
|
| 280 |
+
"file_extension": ".py",
|
| 281 |
+
"mimetype": "text/x-python",
|
| 282 |
+
"name": "python",
|
| 283 |
+
"nbconvert_exporter": "python",
|
| 284 |
+
"pygments_lexer": "ipython3",
|
| 285 |
+
"version": "3.10.4"
|
| 286 |
+
}
|
| 287 |
+
},
|
| 288 |
+
"nbformat": 4,
|
| 289 |
+
"nbformat_minor": 4
|
| 290 |
+
}
|
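For reference, the `cider` helper used in the notebook above can also be fed in-memory lists directly, without a spreadsheet. A minimal sketch with hypothetical toy captions (it assumes the `cidereval` package added below and the `en_core_web_sm` spaCy model installed earlier):

    from cidereval import cider

    candidates = ["a dog runs on the grass",
                  "two people ride bikes"]             # one prediction per sample
    references = [["a dog is running on the grass"],
                  ["two people are riding bicycles"]]  # one list of references per sample

    result = cider(candidates, references, df="corpus")
    print(result["avg_score"], result["scores"])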
evaluation/cider_score/cidereval/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
| 1 |
+
__author__ = 'tylin'
|
| 2 |
+
# edited by Michele Cafagna
|
| 3 |
+
from cidereval.cider.cider import Cider
|
| 4 |
+
from cidereval.ciderD.ciderD import CiderD
|
| 5 |
+
from cidereval.scorers import cider, ciderD
|
evaluation/cider_score/cidereval/cider/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
__author__ = 'tylin'
|
evaluation/cider_score/cidereval/cider/cider.py
ADDED
|
@@ -0,0 +1,67 @@
|
| 1 |
+
# Filename: cider.py
|
| 2 |
+
#
|
| 3 |
+
#
|
| 4 |
+
# Description: Describes the class to compute the CIDEr
|
| 5 |
+
# (Consensus-Based Image Description Evaluation) Metric
|
| 6 |
+
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
|
| 7 |
+
#
|
| 8 |
+
# Creation Date: Sun Feb 8 14:16:54 2015
|
| 9 |
+
#
|
| 10 |
+
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and
|
| 11 |
+
# Tsung-Yi Lin <tl483@cornell.edu>
|
| 12 |
+
|
| 13 |
+
# edited by Michele Cafagna
|
| 14 |
+
|
| 15 |
+
from .cider_scorer import CiderScorer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Cider:
|
| 19 |
+
"""
|
| 20 |
+
Main Class to compute the CIDEr metric
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, n=4, df="corpus"):
|
| 24 |
+
"""
|
| 25 |
+
Initialize the CIDEr scoring function
|
| 26 |
+
: param n (int): n-gram size
|
| 27 |
+
: param df (string): specifies where to get the IDF values from
|
| 28 |
+
takes values 'corpus', 'coco-val'
|
| 29 |
+
: return: None
|
| 30 |
+
"""
|
| 31 |
+
# set cider to sum over 1 to 4-grams
|
| 32 |
+
self._n = n
|
| 33 |
+
self._df = df
|
| 34 |
+
self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
|
| 35 |
+
|
| 36 |
+
def compute_score(self, gts, res):
|
| 37 |
+
"""
|
| 38 |
+
Main function to compute CIDEr score
|
| 39 |
+
: param gts (dict) : {image_id: list of tokenized reference sentences}
|
| 40 |
+
: param res (list) : [{'image_id': ..., 'caption': [tokenized candidate sentence]}, ...]
|
| 41 |
+
: return: cider (float) : computed CIDEr score for the corpus
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
# clear all the previous hypos and refs
|
| 45 |
+
self.cider_scorer.clear()
|
| 46 |
+
|
| 47 |
+
for res_id in res:
|
| 48 |
+
|
| 49 |
+
hypo = res_id['caption']
|
| 50 |
+
ref = gts[res_id['image_id']]
|
| 51 |
+
|
| 52 |
+
# Sanity check.
|
| 53 |
+
assert(type(hypo) is list)
|
| 54 |
+
assert(len(hypo) == 1)
|
| 55 |
+
assert(type(ref) is list)
|
| 56 |
+
assert(len(ref) > 0)
|
| 57 |
+
self.cider_scorer += (hypo[0], ref)
|
| 58 |
+
|
| 59 |
+
(score, scores) = self.cider_scorer.compute_score()
|
| 60 |
+
|
| 61 |
+
return score, scores
|
| 62 |
+
|
| 63 |
+
def save_df(self, df_name="corpus"):
|
| 64 |
+
self.cider_scorer.save_df(df_name)
|
| 65 |
+
|
| 66 |
+
def method(self):
|
| 67 |
+
return "CIDEr"
|
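To make the gts/res contract checked by the asserts in `compute_score` concrete, here is a minimal sketch with hypothetical image ids and already-tokenized captions (in practice both structures are produced by the package tokenizers):

    from cidereval import Cider

    gts = {0: ["a dog runs on the grass", "a brown dog is running"],   # image_id -> list of references
           1: ["two people ride bikes"]}
    res = [{"image_id": 0, "caption": ["a dog running on grass"]},     # one-element caption lists
           {"image_id": 1, "caption": ["two people riding bicycles"]}]

    avg, per_image = Cider(n=4, df="corpus").compute_score(gts, res)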
evaluation/cider_score/cidereval/cider/cider_scorer.py
ADDED
|
@@ -0,0 +1,274 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Tsung-Yi Lin <tl483@cornell.edu>
|
| 3 |
+
# Ramakrishna Vedantam <vrama91@vt.edu>
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import pickle
|
| 8 |
+
import math
|
| 9 |
+
from copy import copy
|
| 10 |
+
|
| 11 |
+
from importlib_resources import files, as_file
|
| 12 |
+
import numpy as np
|
| 13 |
+
import cidereval.data
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def precook(s, n=4, out=False):
|
| 17 |
+
"""
|
| 18 |
+
Takes a string as input and returns an object that can be given to
|
| 19 |
+
either cook_refs or cook_test. This is optional: cook_refs and cook_test
|
| 20 |
+
can take string arguments as well.
|
| 21 |
+
:param s: string : sentence to be converted into ngrams
|
| 22 |
+
:param n: int : number of ngrams for which representation is calculated
|
| 23 |
+
:return: term frequency vector for occurring ngrams
|
| 24 |
+
"""
|
| 25 |
+
words = s.split()
|
| 26 |
+
counts = defaultdict(int)
|
| 27 |
+
for k in range(1, n + 1):
|
| 28 |
+
for i in range(len(words) - k + 1):
|
| 29 |
+
ngram = tuple(words[i:i + k])
|
| 30 |
+
counts[ngram] += 1
|
| 31 |
+
return counts
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def cook_refs(refs, n=4): # lhuang: oracle will call with "average"
|
| 35 |
+
'''Takes a list of reference sentences for a single segment
|
| 36 |
+
and returns an object that encapsulates everything that BLEU
|
| 37 |
+
needs to know about them.
|
| 38 |
+
:param refs: list of string : reference sentences for some image
|
| 39 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 40 |
+
:return: result (list of dict)
|
| 41 |
+
'''
|
| 42 |
+
return [precook(ref, n) for ref in refs]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def cook_test(test, n=4):
|
| 46 |
+
'''Takes a test sentence and returns an object that
|
| 47 |
+
encapsulates everything that BLEU needs to know about it.
|
| 48 |
+
:param test: string : hypothesis sentence for some image
|
| 49 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 50 |
+
:return: result (dict)
|
| 51 |
+
'''
|
| 52 |
+
return precook(test, n, True)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CiderScorer(object):
|
| 56 |
+
"""CIDEr scorer.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def save_df(self, df_name="corpus", path=None):
|
| 60 |
+
"""Save the idf computed in corpus mode
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
df_name (str, optional): name of the idf file (without the file
|
| 64 |
+
extension). Defaults to "corpus".
|
| 65 |
+
|
| 66 |
+
path (str, optional): directory where to save the idf. Defaults to None;
|
| 67 |
+
if not provided, the home directory is used.
|
| 68 |
+
Raises:
|
| 69 |
+
ValueError: if this method is called before computing the scores
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
if path:
|
| 73 |
+
path = Path(path)
|
| 74 |
+
|
| 75 |
+
if not path.exists():
|
| 76 |
+
path=Path.home()
|
| 77 |
+
print(f"the path provided is not valid. The df will be saved in {path}")
|
| 78 |
+
else:
|
| 79 |
+
path=Path.home()
|
| 80 |
+
print(f"the path provided is not valid. The df will be saved in {path}")
|
| 81 |
+
|
| 82 |
+
filename = Path(path, df_name + '.p')
|
| 83 |
+
|
| 84 |
+
if len(self.document_frequency) > 0:
|
| 85 |
+
with open(filename, "wb") as fp:
|
| 86 |
+
|
| 87 |
+
df_idf = {
|
| 88 |
+
"ref_len" : np.log(float(len(self.crefs))),
|
| 89 |
+
"df": self.document_frequency
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
pickle.dump(df_idf, fp)
|
| 93 |
+
print(f"saved to {filename}")
|
| 94 |
+
else:
|
| 95 |
+
raise ValueError("document frequency not computed run 'compute_score'")
|
| 96 |
+
|
| 97 |
+
def copy(self):
|
| 98 |
+
''' copy the refs.'''
|
| 99 |
+
new = CiderScorer(n=self.n)
|
| 100 |
+
new.ctest = copy(self.ctest)  # 'copy' is the function imported from the copy module
|
| 101 |
+
new.crefs = copy(self.crefs)
|
| 102 |
+
return new
|
| 103 |
+
|
| 104 |
+
def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
|
| 105 |
+
''' singular instance '''
|
| 106 |
+
self.n = n
|
| 107 |
+
self.sigma = sigma
|
| 108 |
+
self.crefs = []
|
| 109 |
+
self.ctest = []
|
| 110 |
+
self.ref_len = None
|
| 111 |
+
self.df_mode = df_mode
|
| 112 |
+
|
| 113 |
+
if self.df_mode != "corpus":
|
| 114 |
+
if self.df_mode !="coco-val":
|
| 115 |
+
try:
|
| 116 |
+
with open(self.df_mode, 'rb') as fp:
|
| 117 |
+
df = pickle.load(fp, encoding='iso-8859-1')
|
| 118 |
+
except FileNotFoundError as e:
|
| 119 |
+
print(f"Error retrieveing {self.df_mode}.p df_mode set to 'coco-val'")
|
| 120 |
+
|
| 121 |
+
self.df_mode="coco-val"
|
| 122 |
+
df_path = files(cidereval.data).joinpath(self.df_mode + '.p')
|
| 123 |
+
with as_file(df_path) as res:
|
| 124 |
+
with open(res, 'rb') as fp:
|
| 125 |
+
df = pickle.load(fp, encoding='iso-8859-1')
|
| 126 |
+
else:
|
| 127 |
+
df_path = files(cidereval.data).joinpath(self.df_mode + '.p')
|
| 128 |
+
with as_file(df_path) as res:
|
| 129 |
+
with open(res, 'rb') as fp:
|
| 130 |
+
df = pickle.load(fp, encoding='iso-8859-1')
|
| 131 |
+
|
| 132 |
+
#df_path = os.path.join('data', df_mode + '.p')
|
| 133 |
+
#df = pickle.load(open(os.path.join('data', df_mode + '.p'), 'rb'), encoding='iso-8859-1') # TODO fix path
|
| 134 |
+
self.document_frequency = df['df']
|
| 135 |
+
self.ref_len = df['ref_len']
|
| 136 |
+
self.cook_append(test, refs)
|
| 137 |
+
|
| 138 |
+
def clear(self):
|
| 139 |
+
self.crefs = []
|
| 140 |
+
self.ctest = []
|
| 141 |
+
|
| 142 |
+
def cook_append(self, test, refs):
|
| 143 |
+
'''called by constructor and __iadd__ to avoid creating new instances.'''
|
| 144 |
+
|
| 145 |
+
if refs is not None:
|
| 146 |
+
self.crefs.append(cook_refs(refs))
|
| 147 |
+
if test is not None:
|
| 148 |
+
self.ctest.append(cook_test(test)) ## N.B.: -1
|
| 149 |
+
else:
|
| 150 |
+
self.ctest.append(None) # lens of crefs and ctest have to match
|
| 151 |
+
|
| 152 |
+
def size(self):
|
| 153 |
+
assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
|
| 154 |
+
return len(self.crefs)
|
| 155 |
+
|
| 156 |
+
def __iadd__(self, other):
|
| 157 |
+
'''add an instance (e.g., from another sentence).'''
|
| 158 |
+
|
| 159 |
+
if type(other) is tuple:
|
| 160 |
+
# avoid creating new CiderScorer instances
|
| 161 |
+
self.cook_append(other[0], other[1])
|
| 162 |
+
else:
|
| 163 |
+
self.ctest.extend(other.ctest)
|
| 164 |
+
self.crefs.extend(other.crefs)
|
| 165 |
+
|
| 166 |
+
return self
|
| 167 |
+
|
| 168 |
+
def compute_doc_freq(self):
|
| 169 |
+
'''
|
| 170 |
+
Compute term frequency for reference data.
|
| 171 |
+
This will be used to compute idf (inverse document frequency later)
|
| 172 |
+
The term frequency is stored in the object
|
| 173 |
+
:return: None
|
| 174 |
+
'''
|
| 175 |
+
for refs in self.crefs:
|
| 176 |
+
# refs, k ref captions of one image
|
| 177 |
+
for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
|
| 178 |
+
self.document_frequency[ngram] += 1
|
| 179 |
+
# maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
|
| 180 |
+
|
| 181 |
+
def compute_cider(self):
|
| 182 |
+
def counts2vec(cnts):
|
| 183 |
+
"""
|
| 184 |
+
Function maps counts of ngram to vector of tfidf weights.
|
| 185 |
+
The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
|
| 186 |
+
The n-th entry of array denotes length of n-grams.
|
| 187 |
+
:param cnts:
|
| 188 |
+
:return: vec (array of dict), norm (array of float), length (int)
|
| 189 |
+
"""
|
| 190 |
+
vec = [defaultdict(float) for _ in range(self.n)]
|
| 191 |
+
length = 0
|
| 192 |
+
norm = [0.0 for _ in range(self.n)]
|
| 193 |
+
for (ngram, term_freq) in cnts.items():
|
| 194 |
+
# give word count 1 if it doesn't appear in reference corpus
|
| 195 |
+
df = np.log(max(1.0, self.document_frequency[ngram]))
|
| 196 |
+
# ngram index
|
| 197 |
+
n = len(ngram) - 1
|
| 198 |
+
# tf (term_freq) * idf (precomputed idf) for n-grams
|
| 199 |
+
vec[n][ngram] = float(term_freq) * (self.ref_len - df)
|
| 200 |
+
# compute norm for the vector. the norm will be used for
|
| 201 |
+
# computing similarity
|
| 202 |
+
norm[n] += pow(vec[n][ngram], 2)
|
| 203 |
+
|
| 204 |
+
if n == 1:
|
| 205 |
+
length += term_freq
|
| 206 |
+
norm = [np.sqrt(n) for n in norm]
|
| 207 |
+
return vec, norm, length
|
| 208 |
+
|
| 209 |
+
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
|
| 210 |
+
'''
|
| 211 |
+
Compute the cosine similarity of two vectors.
|
| 212 |
+
:param vec_hyp: array of dictionary for vector corresponding to hypothesis
|
| 213 |
+
:param vec_ref: array of dictionary for vector corresponding to reference
|
| 214 |
+
:param norm_hyp: array of float for vector corresponding to hypothesis
|
| 215 |
+
:param norm_ref: array of float for vector corresponding to reference
|
| 216 |
+
:param length_hyp: int containing length of hypothesis
|
| 217 |
+
:param length_ref: int containing length of reference
|
| 218 |
+
:return: array of score for each n-grams cosine similarity
|
| 219 |
+
'''
|
| 220 |
+
delta = float(length_hyp - length_ref)
|
| 221 |
+
# measure cosine similarity
|
| 222 |
+
val = np.array([0.0 for _ in range(self.n)])
|
| 223 |
+
for n in range(self.n):
|
| 224 |
+
# ngram
|
| 225 |
+
for (ngram, count) in vec_hyp[n].items():
|
| 226 |
+
val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram]
|
| 227 |
+
|
| 228 |
+
if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
|
| 229 |
+
val[n] /= (norm_hyp[n] * norm_ref[n])
|
| 230 |
+
|
| 231 |
+
assert (not math.isnan(val[n]))
|
| 232 |
+
return val
|
| 233 |
+
|
| 234 |
+
# compute log reference length
|
| 235 |
+
if self.df_mode == "corpus":
|
| 236 |
+
self.ref_len = np.log(float(len(self.crefs)))
|
| 237 |
+
#elif self.df_mode == "coco-val":
|
| 238 |
+
# if coco option selected, use length of coco-val set
|
| 239 |
+
#self.ref_len = np.log(float(40504))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
scores = []
|
| 244 |
+
for test, refs in zip(self.ctest, self.crefs):
|
| 245 |
+
# compute vector for test captions
|
| 246 |
+
vec, norm, length = counts2vec(test)
|
| 247 |
+
# compute vector for ref captions
|
| 248 |
+
score = np.array([0.0 for _ in range(self.n)])
|
| 249 |
+
for ref in refs:
|
| 250 |
+
vec_ref, norm_ref, length_ref = counts2vec(ref)
|
| 251 |
+
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
|
| 252 |
+
# change by vrama91 - mean of ngram scores, instead of sum
|
| 253 |
+
score_avg = np.mean(score)
|
| 254 |
+
# divide by number of references
|
| 255 |
+
score_avg /= len(refs)
|
| 256 |
+
# multiply score by 10
|
| 257 |
+
score_avg *= 10.0
|
| 258 |
+
# append score of an image to the score list
|
| 259 |
+
scores.append(score_avg)
|
| 260 |
+
return scores
|
| 261 |
+
|
| 262 |
+
def compute_score(self, option=None, verbose=0):
|
| 263 |
+
# compute idf
|
| 264 |
+
if self.df_mode == "corpus":
|
| 265 |
+
self.document_frequency = defaultdict(float)
|
| 266 |
+
self.compute_doc_freq()
|
| 267 |
+
# assert to check document frequency
|
| 268 |
+
assert (len(self.ctest) >= max(self.document_frequency.values()))
|
| 269 |
+
# import json for now and write the corresponding files
|
| 270 |
+
# compute cider score
|
| 271 |
+
score = self.compute_cider()
|
| 272 |
+
# debug
|
| 273 |
+
# print score
|
| 274 |
+
return np.mean(np.array(score)), np.array(score)
|
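Read as formulas (this is a gloss on the scorer above, not part of the committed files): with \(N = \texttt{self.n}\) and \(|I|\) the number of reference sets, counts2vec weights each n-gram \(k\) of a sentence \(s\) as

\[ g^{(n)}_k(s) = \mathrm{TF}_k(s)\,\bigl(\log|I| - \log\max(1,\mathrm{DF}_k)\bigr), \]

and compute_cider returns, per image,

\[ \mathrm{CIDEr}(c,R) = \frac{10}{N\,|R|}\sum_{n=1}^{N}\sum_{r\in R}\frac{g^{(n)}(c)\cdot g^{(n)}(r)}{\lVert g^{(n)}(c)\rVert\,\lVert g^{(n)}(r)\rVert}. \]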
evaluation/cider_score/cidereval/ciderD/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
__author__ = 'tylin'
|
evaluation/cider_score/cidereval/ciderD/ciderD.py
ADDED
|
@@ -0,0 +1,57 @@
|
| 1 |
+
# Filename: ciderD.py
|
| 2 |
+
#
|
| 3 |
+
# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric
|
| 4 |
+
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
|
| 5 |
+
#
|
| 6 |
+
# Creation Date: Sun Feb 8 14:16:54 2015
|
| 7 |
+
#
|
| 8 |
+
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
|
| 9 |
+
|
| 10 |
+
from .ciderD_scorer import CiderScorer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CiderD:
|
| 14 |
+
"""
|
| 15 |
+
Main Class to compute the CIDEr metric
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
def __init__(self, n=4, sigma=6.0, df="corpus"):
|
| 19 |
+
# set cider to sum over 1 to 4-grams
|
| 20 |
+
self._n = n
|
| 21 |
+
# set the standard deviation parameter for gaussian penalty
|
| 22 |
+
self._sigma = sigma
|
| 23 |
+
# set where to compute document frequencies from
|
| 24 |
+
self._df = df
|
| 25 |
+
self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
|
| 26 |
+
|
| 27 |
+
def compute_score(self, gts, res):
|
| 28 |
+
"""
|
| 29 |
+
Main function to compute CIDEr score
|
| 30 |
+
:param gts (dict) : dictionary with key <image_id> and value <list of tokenized reference sentences>
|
| 31 |
+
res (list) : list of {'image_id': <image_id>, 'caption': [<tokenized hypothesis / candidate sentence>]}
|
| 32 |
+
:return: cider (float) : computed CIDEr score for the corpus
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
# clear all the previous hypos and refs
|
| 36 |
+
self.cider_scorer.clear()
|
| 37 |
+
for res_id in res:
|
| 38 |
+
|
| 39 |
+
hypo = res_id['caption']
|
| 40 |
+
ref = gts[res_id['image_id']]
|
| 41 |
+
|
| 42 |
+
# Sanity check.
|
| 43 |
+
assert(type(hypo) is list)
|
| 44 |
+
assert(len(hypo) == 1)
|
| 45 |
+
assert(type(ref) is list)
|
| 46 |
+
assert(len(ref) > 0)
|
| 47 |
+
self.cider_scorer += (hypo[0], ref)
|
| 48 |
+
|
| 49 |
+
(score, scores) = self.cider_scorer.compute_score()
|
| 50 |
+
|
| 51 |
+
return score, scores
|
| 52 |
+
|
| 53 |
+
def save_df(self, df_name="corpus"):
|
| 54 |
+
self.cider_scorer.save_df(df_name)
|
| 55 |
+
|
| 56 |
+
def method(self):
|
| 57 |
+
return "CIDEr-D"
|
evaluation/cider_score/cidereval/ciderD/ciderD_scorer.py
ADDED
|
@@ -0,0 +1,265 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Tsung-Yi Lin <tl483@cornell.edu>
|
| 3 |
+
# Ramakrishna Vedantam <vrama91@vt.edu>
|
| 4 |
+
|
| 5 |
+
# edited by Michele Cafagna
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import pickle
|
| 10 |
+
import math
|
| 11 |
+
from copy import copy
|
| 12 |
+
|
| 13 |
+
from importlib_resources import files, as_file
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cidereval.data
|
| 16 |
+
|
| 17 |
+
def precook(s, n=4, out=False):
|
| 18 |
+
"""
|
| 19 |
+
Takes a string as input and returns an object that can be given to
|
| 20 |
+
either cook_refs or cook_test. This is optional: cook_refs and cook_test
|
| 21 |
+
can take string arguments as well.
|
| 22 |
+
:param s: string : sentence to be converted into ngrams
|
| 23 |
+
:param n: int : number of ngrams for which representation is calculated
|
| 24 |
+
:return: term frequency vector for occurring ngrams
|
| 25 |
+
"""
|
| 26 |
+
words = s.split()
|
| 27 |
+
counts = defaultdict(int)
|
| 28 |
+
for k in range(1,n+1):
|
| 29 |
+
for i in range(len(words)-k+1):
|
| 30 |
+
ngram = tuple(words[i:i+k])
|
| 31 |
+
counts[ngram] += 1
|
| 32 |
+
return counts
|
| 33 |
+
|
| 34 |
+
def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
|
| 35 |
+
'''Takes a list of reference sentences for a single segment
|
| 36 |
+
and returns an object that encapsulates everything that BLEU
|
| 37 |
+
needs to know about them.
|
| 38 |
+
:param refs: list of string : reference sentences for some image
|
| 39 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 40 |
+
:return: result (list of dict)
|
| 41 |
+
'''
|
| 42 |
+
return [precook(ref, n) for ref in refs]
|
| 43 |
+
|
| 44 |
+
def cook_test(test, n=4):
|
| 45 |
+
'''Takes a test sentence and returns an object that
|
| 46 |
+
encapsulates everything that BLEU needs to know about it.
|
| 47 |
+
:param test: string : hypothesis sentence for some image
|
| 48 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 49 |
+
:return: result (dict)
|
| 50 |
+
'''
|
| 51 |
+
return precook(test, n, True)
|
| 52 |
+
|
| 53 |
+
class CiderScorer(object):
|
| 54 |
+
"""CIDEr scorer.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def save_df(self, df_name="corpus", path=None):
|
| 58 |
+
"""Save the idf computed in corpus mode
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
df_name (str, optional): name of the idf file (without the file
|
| 62 |
+
extension). Defaults to "corpus".
|
| 63 |
+
|
| 64 |
+
path (str, optional): directory where to save the idf. Defaults to None;
|
| 65 |
+
if not provided, the home directory is used.
|
| 66 |
+
Raises:
|
| 67 |
+
ValueError: if this method is called before computing the scores
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
if path:
|
| 71 |
+
path = Path(path)
|
| 72 |
+
|
| 73 |
+
if not path.exists():
|
| 74 |
+
path=Path.home()
|
| 75 |
+
print(f"the path provided is not valid. The df will be saved in {path}")
|
| 76 |
+
else:
|
| 77 |
+
path=Path.home()
|
| 78 |
+
print(f"the path provided is not valid. The df will be saved in {path}")
|
| 79 |
+
|
| 80 |
+
filename = Path(path, df_name + '.p')
|
| 81 |
+
|
| 82 |
+
if len(self.document_frequency) > 0:
|
| 83 |
+
with open(filename, "wb") as fp:
|
| 84 |
+
|
| 85 |
+
df_idf = {
|
| 86 |
+
"ref_len" : np.log(float(len(self.crefs))),
|
| 87 |
+
"df": self.document_frequency
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
pickle.dump(df_idf, fp)
|
| 91 |
+
print(f"saved to {filename}")
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError("document frequency not computed run 'compute_score'")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def copy(self):
|
| 97 |
+
''' copy the refs.'''
|
| 98 |
+
new = CiderScorer(n=self.n)
|
| 99 |
+
new.ctest = copy(self.ctest)  # 'copy' is the function imported from the copy module
|
| 100 |
+
new.crefs = copy(self.crefs)
|
| 101 |
+
return new
|
| 102 |
+
|
| 103 |
+
def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
|
| 104 |
+
''' singular instance '''
|
| 105 |
+
self.n = n
|
| 106 |
+
self.sigma = sigma
|
| 107 |
+
self.crefs = []
|
| 108 |
+
self.ctest = []
|
| 109 |
+
self.df_mode = df_mode
|
| 110 |
+
self.ref_len = None
|
| 111 |
+
|
| 112 |
+
if self.df_mode != "corpus":
|
| 113 |
+
if self.df_mode !="coco-val":
|
| 114 |
+
try:
|
| 115 |
+
with open(self.df_mode, 'rb') as fp:
|
| 116 |
+
df = pickle.load(fp, encoding='iso-8859-1')
|
| 117 |
+
except FileNotFoundError as e:
|
| 118 |
+
print(f"Error retrieveing {self.df_mode}. df_mode set to 'coco-val'")
|
| 119 |
+
else:
|
| 120 |
+
df_path = files(cidereval.data).joinpath(self.df_mode + '.p')
|
| 121 |
+
with as_file(df_path) as res:
|
| 122 |
+
with open(res, 'rb') as fp:
|
| 123 |
+
df = pickle.load(fp, encoding='iso-8859-1')
|
| 124 |
+
|
| 125 |
+
self.document_frequency = df['df']
|
| 126 |
+
self.ref_len = df['ref_len']
|
| 127 |
+
|
| 128 |
+
self.cook_append(test, refs)
|
| 129 |
+
|
| 130 |
+
def clear(self):
|
| 131 |
+
self.crefs = []
|
| 132 |
+
self.ctest = []
|
| 133 |
+
|
| 134 |
+
def cook_append(self, test, refs):
|
| 135 |
+
'''called by constructor and __iadd__ to avoid creating new instances.'''
|
| 136 |
+
|
| 137 |
+
if refs is not None:
|
| 138 |
+
self.crefs.append(cook_refs(refs))
|
| 139 |
+
if test is not None:
|
| 140 |
+
self.ctest.append(cook_test(test)) ## N.B.: -1
|
| 141 |
+
else:
|
| 142 |
+
self.ctest.append(None) # lens of crefs and ctest have to match
|
| 143 |
+
|
| 144 |
+
def size(self):
|
| 145 |
+
assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
|
| 146 |
+
return len(self.crefs)
|
| 147 |
+
|
| 148 |
+
def __iadd__(self, other):
|
| 149 |
+
'''add an instance (e.g., from another sentence).'''
|
| 150 |
+
|
| 151 |
+
if type(other) is tuple:
|
| 152 |
+
## avoid creating new CiderScorer instances
|
| 153 |
+
self.cook_append(other[0], other[1])
|
| 154 |
+
else:
|
| 155 |
+
self.ctest.extend(other.ctest)
|
| 156 |
+
self.crefs.extend(other.crefs)
|
| 157 |
+
|
| 158 |
+
return self
|
| 159 |
+
def compute_doc_freq(self):
|
| 160 |
+
'''
|
| 161 |
+
Compute term frequency for reference data.
|
| 162 |
+
This will be used to compute idf (inverse document frequency later)
|
| 163 |
+
The term frequency is stored in the object
|
| 164 |
+
:return: None
|
| 165 |
+
'''
|
| 166 |
+
for refs in self.crefs:
|
| 167 |
+
# refs, k ref captions of one image
|
| 168 |
+
for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
|
| 169 |
+
self.document_frequency[ngram] += 1
|
| 170 |
+
# maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
|
| 171 |
+
|
| 172 |
+
def compute_cider(self):
|
| 173 |
+
def counts2vec(cnts):
|
| 174 |
+
"""
|
| 175 |
+
Function maps counts of ngram to vector of tfidf weights.
|
| 176 |
+
The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
|
| 177 |
+
The n-th entry of array denotes length of n-grams.
|
| 178 |
+
:param cnts:
|
| 179 |
+
:return: vec (array of dict), norm (array of float), length (int)
|
| 180 |
+
"""
|
| 181 |
+
vec = [defaultdict(float) for _ in range(self.n)]
|
| 182 |
+
length = 0
|
| 183 |
+
norm = [0.0 for _ in range(self.n)]
|
| 184 |
+
for (ngram,term_freq) in cnts.items():
|
| 185 |
+
# give word count 1 if it doesn't appear in reference corpus
|
| 186 |
+
df = np.log(max(1.0, self.document_frequency[ngram]))
|
| 187 |
+
# ngram index
|
| 188 |
+
n = len(ngram)-1
|
| 189 |
+
# tf (term_freq) * idf (precomputed idf) for n-grams
|
| 190 |
+
vec[n][ngram] = float(term_freq)*(self.ref_len - df)
|
| 191 |
+
# compute norm for the vector. the norm will be used for computing similarity
|
| 192 |
+
norm[n] += pow(vec[n][ngram], 2)
|
| 193 |
+
|
| 194 |
+
if n == 1:
|
| 195 |
+
length += term_freq
|
| 196 |
+
norm = [np.sqrt(n) for n in norm]
|
| 197 |
+
return vec, norm, length
|
| 198 |
+
|
| 199 |
+
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
|
| 200 |
+
'''
|
| 201 |
+
Compute the cosine similarity of two vectors.
|
| 202 |
+
:param vec_hyp: array of dictionary for vector corresponding to hypothesis
|
| 203 |
+
:param vec_ref: array of dictionary for vector corresponding to reference
|
| 204 |
+
:param norm_hyp: array of float for vector corresponding to hypothesis
|
| 205 |
+
:param norm_ref: array of float for vector corresponding to reference
|
| 206 |
+
:param length_hyp: int containing length of hypothesis
|
| 207 |
+
:param length_ref: int containing length of reference
|
| 208 |
+
:return: array of score for each n-grams cosine similarity
|
| 209 |
+
'''
|
| 210 |
+
delta = float(length_hyp - length_ref)
|
| 211 |
+
# measure cosine similarity
|
| 212 |
+
val = np.array([0.0 for _ in range(self.n)])
|
| 213 |
+
for n in range(self.n):
|
| 214 |
+
# ngram
|
| 215 |
+
for (ngram,count) in vec_hyp[n].items():
|
| 216 |
+
# vrama91 : added clipping
|
| 217 |
+
val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
|
| 218 |
+
|
| 219 |
+
if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
|
| 220 |
+
val[n] /= (norm_hyp[n]*norm_ref[n])
|
| 221 |
+
|
| 222 |
+
assert(not math.isnan(val[n]))
|
| 223 |
+
# vrama91: added a length based gaussian penalty
|
| 224 |
+
val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
|
| 225 |
+
return val
|
| 226 |
+
|
| 227 |
+
# compute log reference length
|
| 228 |
+
if self.df_mode == "corpus":
|
| 229 |
+
self.ref_len = np.log(float(len(self.crefs)))
|
| 230 |
+
#elif self.df_mode == "coco-val":
|
| 231 |
+
# if coco option selected, use length of coco-val set
|
| 232 |
+
#self.ref_len = np.log(float(40504))
|
| 233 |
+
|
| 234 |
+
scores = []
|
| 235 |
+
for test, refs in zip(self.ctest, self.crefs):
|
| 236 |
+
# compute vector for test captions
|
| 237 |
+
vec, norm, length = counts2vec(test)
|
| 238 |
+
# compute vector for ref captions
|
| 239 |
+
score = np.array([0.0 for _ in range(self.n)])
|
| 240 |
+
for ref in refs:
|
| 241 |
+
vec_ref, norm_ref, length_ref = counts2vec(ref)
|
| 242 |
+
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
|
| 243 |
+
# change by vrama91 - mean of ngram scores, instead of sum
|
| 244 |
+
score_avg = np.mean(score)
|
| 245 |
+
# divide by number of references
|
| 246 |
+
score_avg /= len(refs)
|
| 247 |
+
# multiply score by 10
|
| 248 |
+
score_avg *= 10.0
|
| 249 |
+
# append score of an image to the score list
|
| 250 |
+
scores.append(score_avg)
|
| 251 |
+
return scores
|
| 252 |
+
|
| 253 |
+
def compute_score(self, option=None, verbose=0):
|
| 254 |
+
# compute idf
|
| 255 |
+
if self.df_mode == "corpus":
|
| 256 |
+
self.document_frequency = defaultdict(float)
|
| 257 |
+
self.compute_doc_freq()
|
| 258 |
+
# assert to check document frequency
|
| 259 |
+
assert(len(self.ctest) >= max(self.document_frequency.values()))
|
| 260 |
+
# import json for now and write the corresponding files
|
| 261 |
+
# compute cider score
|
| 262 |
+
score = self.compute_cider()
|
| 263 |
+
# debug
|
| 264 |
+
# print score
|
| 265 |
+
return np.mean(np.array(score)), np.array(score)
|
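The -D scorer above differs from plain CIDEr in exactly two places inside sim: the numerator is clipped, and a Gaussian length penalty of width \(\sigma\) (default 6.0) is applied:

\[ \mathrm{sim}_n(c,r) = \frac{\sum_k \min\bigl(g^{(n)}_k(c),\,g^{(n)}_k(r)\bigr)\,g^{(n)}_k(r)}{\lVert g^{(n)}(c)\rVert\,\lVert g^{(n)}(r)\rVert}\; e^{-\delta^2/(2\sigma^2)}, \qquad \delta = \ell_c - \ell_r, \]

where \(\ell_c\) and \(\ell_r\) are the hypothesis and reference lengths. Both changes make the metric harder to inflate with long, repetitive hypotheses.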
evaluation/cider_score/cidereval/data/__init__.py
ADDED
|
File without changes
|
evaluation/cider_score/cidereval/data/coco-val.p
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48e79470ebb58f0251df49dca5c9976a726a9e41b4a1d82be332b6f676b6950a
|
| 3 |
+
size 73520211
|
evaluation/cider_score/cidereval/eval.py
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
__author__ = 'rama'
|
| 2 |
+
from .tokenizer.ptbtokenizer import PTBTokenizer
|
| 3 |
+
from .cider.cider import Cider
|
| 4 |
+
from .ciderD.ciderD import CiderD
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CIDErEvalCap:
|
| 8 |
+
def __init__(self, gts, res, df):
|
| 9 |
+
print('tokenization...')
|
| 10 |
+
tokenizer = PTBTokenizer('gts')
|
| 11 |
+
_gts = tokenizer.tokenize(gts)
|
| 12 |
+
print('tokenized refs')
|
| 13 |
+
tokenizer = PTBTokenizer('res')
|
| 14 |
+
_res = tokenizer.tokenize(res)
|
| 15 |
+
print('tokenized cands')
|
| 16 |
+
|
| 17 |
+
self.gts = _gts
|
| 18 |
+
self.res = _res
|
| 19 |
+
self.df = df
|
| 20 |
+
|
| 21 |
+
def evaluate(self):
|
| 22 |
+
# =================================================
|
| 23 |
+
# Set up scorers
|
| 24 |
+
# =================================================
|
| 25 |
+
|
| 26 |
+
print('setting up scorers...')
|
| 27 |
+
scorers = [
|
| 28 |
+
(Cider(df=self.df), "CIDEr"), (CiderD(df=self.df), "CIDErD")
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
# =================================================
|
| 32 |
+
# Compute scores
|
| 33 |
+
# =================================================
|
| 34 |
+
metric_scores = {}
|
| 35 |
+
for scorer, method in scorers:
|
| 36 |
+
print('computing %s score...' % (scorer.method()))
|
| 37 |
+
score, scores = scorer.compute_score(self.gts, self.res)
|
| 38 |
+
print("Mean %s score: %0.3f" % (method, score))
|
| 39 |
+
metric_scores[method] = list(scores)
|
| 40 |
+
return metric_scores
|
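A minimal usage sketch for CIDErEvalCap (hypothetical ids and captions; note that this path tokenizes with the Stanford PTBTokenizer jar, so Java must be available on the machine):

    from cidereval.eval import CIDErEvalCap

    gts = {"img1": [{"caption": "a dog runs on the grass"},
                    {"caption": "a brown dog is running"}]}
    res = [{"image_id": "img1", "caption": "a dog running on grass"}]

    scores = CIDErEvalCap(gts, res, df="corpus").evaluate()
    # -> {"CIDEr": [...], "CIDErD": [...]}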
evaluation/cider_score/cidereval/scorers.py
ADDED
|
@@ -0,0 +1,76 @@
|
| 1 |
+
from cidereval import CiderD, Cider
|
| 2 |
+
from cidereval.tokenizer import PTBTokenizer
|
| 3 |
+
from cidereval.tokenizer import SimpleTokenizer
|
| 4 |
+
|
| 5 |
+
def _preprocess_for_cider(refs, preds):
|
| 6 |
+
r"""
|
| 7 |
+
Convert preds and refs to the cider data format
|
| 8 |
+
|
| 9 |
+
refs: List[List[str]]
|
| 10 |
+
preds : List[str]
|
| 11 |
+
|
| 12 |
+
return gts: Dict[str : List[Dict['caption':str] : str ]],
|
| 13 |
+
res: List[Dict['image_id':str]: 'caption':str]
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
assert len(refs) == len(preds)
|
| 17 |
+
|
| 18 |
+
gts = {}
|
| 19 |
+
res = []
|
| 20 |
+
|
| 21 |
+
for i, (caps, pred) in enumerate(zip(refs, preds)):
|
| 22 |
+
gts[i] = [{ 'caption': cap } for cap in caps ]
|
| 23 |
+
|
| 24 |
+
res.append({ 'image_id': i,
|
| 25 |
+
'caption': pred})
|
| 26 |
+
return gts, res
|
| 27 |
+
|
| 28 |
+
def cider(predictions, references, df="coco-val"):
|
| 29 |
+
r"""
|
| 30 |
+
Compute the cider score for the given predictions and references
|
| 31 |
+
|
| 32 |
+
predictions : List[str], model's predictions
|
| 33 |
+
references: List[List[str]], references
|
| 34 |
+
df: str, either 'coco-val' or 'corpus' (default: 'coco-val'). If 'coco-val', the TF-IDF statistics of the COCO validation split are \\
|
| 35 |
+
used. If 'corpus' the TF-IDF is computed over the reference set provided.
|
| 36 |
+
|
| 37 |
+
returns {"avg_score": mp.float, "scores": np.array(np.float)}
|
| 38 |
+
"""
|
| 39 |
+
gts, res = _preprocess_for_cider(references, predictions)
|
| 40 |
+
tokenizer_res = SimpleTokenizer('res')
|
| 41 |
+
tokenizer_gts = SimpleTokenizer('gts')
|
| 42 |
+
|
| 43 |
+
_gts = tokenizer_gts.tokenize(gts)
|
| 44 |
+
_res = tokenizer_res.tokenize(res)
|
| 45 |
+
|
| 46 |
+
scorer = Cider(df=df)
|
| 47 |
+
|
| 48 |
+
score, scores = scorer.compute_score(_gts, _res)
|
| 49 |
+
|
| 50 |
+
return {"avg_score": score, "scores": scores}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def ciderD(predictions, references, df="coco-va"):
|
| 54 |
+
r"""
|
| 55 |
+
Compute the ciderD score for the given predictions and references
|
| 56 |
+
|
| 57 |
+
predictions : List[str], model's predictions
|
| 58 |
+
references: List[List[str]], references
|
| 59 |
+
df: str, either 'coco-val' or 'corpus' (default: 'coco-val'). If 'coco-val', the TF-IDF statistics of the COCO validation split are \\
|
| 60 |
+
used. If 'corpus' the TF-IDF is computed over the reference set provided.
|
| 61 |
+
|
| 62 |
+
returns {"avg_score": mp.float, "scores": np.array(np.float)}
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
gts, res = _preprocess_for_cider(references, predictions)
|
| 66 |
+
tokenizer_res = SimpleTokenizer('res')
|
| 67 |
+
tokenizer_gts = SimpleTokenizer('gts')
|
| 68 |
+
|
| 69 |
+
_gts = tokenizer_gts.tokenize(gts)
|
| 70 |
+
_res = tokenizer_res.tokenize(res)
|
| 71 |
+
|
| 72 |
+
scorer = CiderD(df=df)
|
| 73 |
+
|
| 74 |
+
score, scores = scorer.compute_score(_gts, _res)
|
| 75 |
+
|
| 76 |
+
return { "avg_score": score, "scores": scores}
|
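Note on the df argument: both helpers default to df="coco-val", which reuses the document frequencies shipped in data/coco-val.p, while the demo notebook passes df="corpus" so the IDF statistics are recomputed from the supplied references. For small, domain-specific evaluation sets the "corpus" option is usually the more meaningful choice, since COCO caption frequencies say little about the target vocabulary.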
evaluation/cider_score/cidereval/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
| 1 |
+
__author__ = 'hfang'
|
| 2 |
+
# edited by Michele Cafagna
|
| 3 |
+
from cidereval.tokenizer.ptbtokenizer import PTBTokenizer
|
| 4 |
+
from cidereval.tokenizer.simpletokenizer import SimpleTokenizer
|
evaluation/cider_score/cidereval/tokenizer/ptbtokenizer.py
ADDED
|
@@ -0,0 +1,112 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# File Name : ptbtokenizer.py
|
| 4 |
+
#
|
| 5 |
+
# Description : Do the PTB Tokenization and remove punctuations.
|
| 6 |
+
#
|
| 7 |
+
# Creation Date : 29-12-2014
|
| 8 |
+
# Last Modified : Thu Mar 19 09:53:35 2015
|
| 9 |
+
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
import tempfile
|
| 14 |
+
|
| 15 |
+
# path to the stanford corenlp jar
|
| 16 |
+
STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
|
| 17 |
+
|
| 18 |
+
# punctuations to be removed from the sentences
|
| 19 |
+
PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
|
| 20 |
+
".", "?", "!", ",", ":", "-", "--", "...", ";"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PTBTokenizer:
|
| 24 |
+
"""Python wrapper of Stanford PTBTokenizer"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, _source='gts'):
|
| 27 |
+
self.source = _source
|
| 28 |
+
|
| 29 |
+
def tokenize(self, captions_for_image):
|
| 30 |
+
"""Tokenize a sample
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
captions_for_image :
|
| 34 |
+
|
| 35 |
+
IF _source='gts' follows format:
|
| 36 |
+
dict: { str : [
|
| 37 |
+
{ "caption" : str },
|
| 38 |
+
{ "caption" : str },
|
| 39 |
+
...
|
| 40 |
+
],
|
| 41 |
+
str : [ ... ],
|
| 42 |
+
...
|
| 43 |
+
}
|
| 44 |
+
IF _source='res' follows format:
|
| 45 |
+
list: [ {"image_id" : str,
|
| 46 |
+
"caption" : str,
|
| 47 |
+
},
|
| 48 |
+
...
|
| 49 |
+
]
|
| 50 |
+
Returns:
|
| 51 |
+
final_tokenized_captions_for_index:
|
| 52 |
+
list: [ {"image_id" : str,
|
| 53 |
+
"caption" : str,
|
| 54 |
+
},
|
| 55 |
+
...
|
| 56 |
+
]
|
| 57 |
+
"""
|
| 58 |
+
cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR,
|
| 59 |
+
'edu.stanford.nlp.process.PTBTokenizer',
|
| 60 |
+
'-preserveLines', '-lowerCase']
|
| 61 |
+
|
| 62 |
+
# ======================================================
|
| 63 |
+
# prepare data for PTB Tokenizer
|
| 64 |
+
# ======================================================
|
| 65 |
+
|
| 66 |
+
if self.source == 'gts':
|
| 67 |
+
image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
|
| 68 |
+
sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
|
| 69 |
+
final_tokenized_captions_for_image = {}
|
| 70 |
+
|
| 71 |
+
elif self.source == 'res':
|
| 72 |
+
index = [i for i, v in enumerate(captions_for_image)]
|
| 73 |
+
image_id = [v["image_id"] for v in captions_for_image]
|
| 74 |
+
sentences = '\n'.join(v["caption"].replace('\n', ' ') for v in captions_for_image)
|
| 75 |
+
final_tokenized_captions_for_index = []
|
| 76 |
+
|
| 77 |
+
# ======================================================
|
| 78 |
+
# save sentences to temporary file
|
| 79 |
+
# ======================================================
|
| 80 |
+
path_to_jar_dir_name = os.path.dirname(os.path.abspath(__file__))
|
| 81 |
+
tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dir_name, mode='w')
|
| 82 |
+
tmp_file.write(sentences)
|
| 83 |
+
tmp_file.close()
|
| 84 |
+
|
| 85 |
+
# ======================================================
|
| 86 |
+
# tokenize sentence
|
| 87 |
+
# ======================================================
|
| 88 |
+
cmd.append(os.path.basename(tmp_file.name))
|
| 89 |
+
p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dir_name, stdout=subprocess.PIPE)
|
| 90 |
+
token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0].decode("utf-8")
|
| 91 |
+
lines = token_lines.split('\n')
|
| 92 |
+
# remove temp file
|
| 93 |
+
os.remove(tmp_file.name)
|
| 94 |
+
|
| 95 |
+
# ======================================================
|
| 96 |
+
# create dictionary for tokenized captions
|
| 97 |
+
# ======================================================
|
| 98 |
+
if self.source == 'gts':
|
| 99 |
+
for k, line in zip(image_id, lines):
|
| 100 |
+
if k not in final_tokenized_captions_for_image:
|
| 101 |
+
final_tokenized_captions_for_image[k] = []
|
| 102 |
+
tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') if w not in PUNCTUATIONS])
|
| 103 |
+
final_tokenized_captions_for_image[k].append(tokenized_caption)
|
| 104 |
+
|
| 105 |
+
return final_tokenized_captions_for_image
|
| 106 |
+
|
| 107 |
+
elif self.source == 'res':
|
| 108 |
+
for k, img, line in zip(index, image_id, lines):
|
| 109 |
+
tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') if w not in PUNCTUATIONS])
|
| 110 |
+
final_tokenized_captions_for_index.append({'image_id': img, 'caption': [tokenized_caption]})
|
| 111 |
+
|
| 112 |
+
return final_tokenized_captions_for_index
|
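PTBTokenizer shells out to stanford-corenlp-3.4.1.jar, which is exactly the Java dependency the demo notebook works around by installing spaCy; the SimpleTokenizer added next reproduces the same lowercasing and punctuation stripping on top of a spaCy pipeline instead.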
evaluation/cider_score/cidereval/tokenizer/simpletokenizer.py
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# File Name : simpletokenizer.py
|
| 4 |
+
#
|
| 5 |
+
# Description : Yet another tokenizer.
|
| 6 |
+
#
|
| 7 |
+
# Creation Date : 12-11-2021
|
| 8 |
+
|
| 9 |
+
import spacy
|
| 10 |
+
from spacy.lang.char_classes import ALPHA, ALPHA_LOWER, ALPHA_UPPER
|
| 11 |
+
from spacy.lang.char_classes import CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS
|
| 12 |
+
from spacy.util import compile_infix_regex
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# punctuations to be removed from the sentences
|
| 16 |
+
PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
|
| 17 |
+
".", "?", "!", ",", ":", "-", "--", "...", ";", " ", ""]
|
| 18 |
+
|
| 19 |
+
infixes = (
|
| 20 |
+
LIST_ELLIPSES
|
| 21 |
+
+ LIST_ICONS
|
| 22 |
+
+ [
|
| 23 |
+
r"(?<=[0-9])[+\-\*^](?=[0-9-])",
|
| 24 |
+
r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(
|
| 25 |
+
al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES
|
| 26 |
+
),
|
| 27 |
+
r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
|
| 28 |
+
# ✅ Commented out regex that splits on hyphens between letters:
|
| 29 |
+
# r"(?<=[{a}])(?:{h})(?=[{a}])".format(a=ALPHA, h=HYPHENS),
|
| 30 |
+
r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA),
|
| 31 |
+
]
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SimpleTokenizer:
|
| 36 |
+
"""Simple Tokenizer"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, _source='gts'):
|
| 39 |
+
self.source = _source
|
| 40 |
+
|
| 41 |
+
# setting up the tokenizer
|
| 42 |
+
self._nlp = spacy.load("en_core_web_sm")
|
| 43 |
+
infix_re = compile_infix_regex(infixes)
|
| 44 |
+
self._nlp.tokenizer.infix_finditer = infix_re.finditer
|
| 45 |
+
self._tokenizer = self._nlp.tokenizer
|
| 46 |
+
|
| 47 |
+
def tokenize(self, captions_for_image):
|
| 48 |
+
"""Tokenize a sample
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
captions_for_image :
|
| 52 |
+
|
| 53 |
+
IF _source='gts' follows format:
|
| 54 |
+
dict: { str : [
|
| 55 |
+
{ "caption" : str },
|
| 56 |
+
{ "caption" : str },
|
| 57 |
+
...
|
| 58 |
+
],
|
| 59 |
+
str : [ ... ],
|
| 60 |
+
...
|
| 61 |
+
}
|
| 62 |
+
IF _source='res' follows format:
|
| 63 |
+
list: [ {"image_id" : str,
|
| 64 |
+
"caption" : str,
|
| 65 |
+
},
|
| 66 |
+
...
|
| 67 |
+
]
|
| 68 |
+
Returns:
|
| 69 |
+
final_tokenized_captions_for_index:
|
| 70 |
+
list: [ {"image_id" : str,
|
| 71 |
+
"caption" : str,
|
| 72 |
+
},
|
| 73 |
+
...
|
| 74 |
+
]
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
tokenized_captions = None
|
| 78 |
+
|
| 79 |
+
if self.source == 'gts':
|
| 80 |
+
tokenized_captions= {}
|
| 81 |
+
|
| 82 |
+
for k in captions_for_image:
|
| 83 |
+
|
| 84 |
+
if k not in tokenized_captions:
|
| 85 |
+
tokenized_captions[k] = []
|
| 86 |
+
|
| 87 |
+
for item in captions_for_image[k]:
|
| 88 |
+
|
| 89 |
+
tokenized_captions[k].append(
|
| 90 |
+
" ".join([ tok.text.lower().strip() for tok in self._tokenizer(item['caption']) if tok.text.lower().strip() not in PUNCTUATIONS]))
|
| 91 |
+
|
| 92 |
+
elif self.source == 'res':
|
| 93 |
+
|
| 94 |
+
tokenized_captions= []
|
| 95 |
+
|
| 96 |
+
for item in captions_for_image:
|
| 97 |
+
|
| 98 |
+
tokenized_captions.append(
|
| 99 |
+
{ 'image_id' : item['image_id'],
|
| 100 |
+
'caption' : [" ".join([ tok.text.lower().strip() for tok in self._tokenizer(item['caption']) if tok.text.lower().strip() not in PUNCTUATIONS])]
|
| 101 |
+
})
|
| 102 |
+
|
| 103 |
+
else:
|
| 104 |
+
ValueError("source can be either 'gts' or 'res' ")
|
| 105 |
+
|
| 106 |
+
return tokenized_captions
|
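
A minimal usage sketch for the spaCy-based tokenizer (not part of the diff; the import path is assumed from the package layout). It requires the `en_core_web_sm` model to be downloaded first (`python -m spacy download en_core_web_sm`). Because the hyphen-splitting infix rule above is commented out, compounds like "well-differentiated" survive tokenization intact:

```python
from cidereval.tokenizer.simpletokenizer import SimpleTokenizer  # assumed import path

tokenizer = SimpleTokenizer(_source='res')
res = [{"image_id": "img1", "caption": "Well-differentiated adenocarcinoma, grade 1."}]
print(tokenizer.tokenize(res))
# [{'image_id': 'img1', 'caption': ['well-differentiated adenocarcinoma grade 1']}]
```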
evaluation/cider_score/output_sample.xls
ADDED
Binary file (190 kB)
filter_dataset.py
ADDED
@@ -0,0 +1,43 @@
+from datasets import load_dataset, DatasetDict
+from PIL import Image, ImageFile, UnidentifiedImageError
+import io
+from tqdm import tqdm
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+# Load the dataset with load_dataset
+cache_dir = "/bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache"
+dataset_name = "CNX-PathLLM/Pathcap"
+dataset = load_dataset(dataset_name, split="train", cache_dir=cache_dir)
+
+print(f"original dataset size: {len(dataset)}")
+
+# indices of samples that pass the checks
+valid_indices = []
+
+# go through and check every element
+for idx in tqdm(range(len(dataset))):
+    try:
+        example = dataset[idx]
+
+        text = example["txt"]
+        if not isinstance(text, str):
+            raise ValueError(f"not a string: {text}")
+        valid_indices.append(idx)
+    except Exception as e:
+        print(f"Cannot recognize file {idx}: {e}")
+
+# keep only the samples whose indices passed the checks
+filtered_dataset = dataset.select(valid_indices)
+
+# Filter out images that cannot be loaded.
+# filtered_dataset = dataset.filter(lambda example: example["is_valid"])
+
+# Print the size of the filtered dataset
+print(f"filtered dataset size: {len(filtered_dataset)}")
+
+if len(dataset) != len(filtered_dataset):
+    # convert to DatasetDict
+    filtered_dataset_dict = DatasetDict({"train": filtered_dataset})
+    # push to hub
+    filtered_dataset_dict.push_to_hub(dataset_name)
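
The commented-out `dataset.filter` call hints at a vectorized alternative to the index loop. A hedged sketch of the same text check using `Dataset.filter` follows; `input_columns` restricts decoding to the `txt` column so broken images are never touched, though unlike the `try`/`except` loop it will not swallow exceptions raised while reading other columns:

```python
from datasets import load_dataset

ds = load_dataset("CNX-PathLLM/Pathcap", split="train")
# Only the "txt" column is passed to the predicate, so image decoding is skipped.
text_ok = ds.filter(lambda txt: isinstance(txt, str), input_columns=["txt"])
print(f"filtered dataset size: {len(text_ok)}")
```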
gigapath/__init__.py
ADDED
File without changes

gigapath/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes)

gigapath/__pycache__/pos_embed.cpython-310.pyc
ADDED
Binary file (2.38 kB)

gigapath/__pycache__/slide_encoder.cpython-310.pyc
ADDED
Binary file (10.2 kB)

gigapath/__pycache__/slide_encoder_vision.cpython-310.pyc
ADDED
Binary file (8.31 kB)
gigapath/classification_head.py
ADDED
@@ -0,0 +1,92 @@
+import torch
+
+from torch import nn
+from . import slide_encoder
+
+
+def reshape_input(imgs, coords, pad_mask=None):
+    if len(imgs.shape) == 4:
+        imgs = imgs.squeeze(0)
+    if len(coords.shape) == 4:
+        coords = coords.squeeze(0)
+    if pad_mask is not None:
+        if len(pad_mask.shape) != 2:
+            pad_mask = pad_mask.squeeze(0)
+    return imgs, coords, pad_mask
+
+
+class ClassificationHead(nn.Module):
+    """
+    The classification head for the slide encoder
+
+    Arguments:
+    ----------
+    input_dim: int
+        The input dimension of the slide encoder
+    latent_dim: int
+        The latent dimension of the slide encoder
+    feat_layer: str
+        The layers from which embeddings are fed to the classifier, e.g., 5-11 for taking out the 5th and 11th layers
+    n_classes: int
+        The number of classes
+    model_arch: str
+        The architecture of the slide encoder
+    pretrained: str
+        The path to the pretrained slide encoder
+    freeze: bool
+        Whether to freeze the pretrained model
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        latent_dim,
+        feat_layer,
+        n_classes=2,
+        model_arch="gigapath_slide_enc12l768d",
+        pretrained="hf_hub:prov-gigapath/prov-gigapath",
+        freeze=False,
+        **kwargs,
+    ):
+        super(ClassificationHead, self).__init__()
+
+        # setup the slide encoder
+        self.feat_layer = [eval(x) for x in feat_layer.split("-")]
+        self.feat_dim = len(self.feat_layer) * latent_dim
+        self.slide_encoder = slide_encoder.create_model(pretrained, model_arch, in_chans=input_dim, **kwargs)
+
+        # whether to freeze the pretrained model
+        if freeze:
+            print("Freezing Pretrained GigaPath model")
+            for name, param in self.slide_encoder.named_parameters():
+                param.requires_grad = False
+            print("Done")
+        # setup the classifier
+        self.classifier = nn.Sequential(*[nn.Linear(self.feat_dim, n_classes)])
+
+    def forward(self, images: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
+        """
+        Arguments:
+        ----------
+        images: torch.Tensor
+            The input images with shape [N, L, D]
+        coords: torch.Tensor
+            The input coordinates with shape [N, L, 2]
+        """
+        # inputs: [N, L, D]
+        if len(images.shape) == 2:
+            images = images.unsqueeze(0)
+        assert len(images.shape) == 3
+        # forward GigaPath slide encoder
+        img_enc = self.slide_encoder.forward(images, coords, all_layer_embed=True)
+        img_enc = [img_enc[i] for i in self.feat_layer]
+        img_enc = torch.cat(img_enc, dim=-1)
+        # classifier
+        h = img_enc.reshape([-1, img_enc.size(-1)])
+        logits = self.classifier(h)
+        return logits
+
+
+def get_model(**kwargs):
+    model = ClassificationHead(**kwargs)
+    return model
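
A hedged sketch of how this head might be wired up, not part of the diff. The tensor shapes follow the docstrings above; the concrete dimensions (1536-d tile embeddings, 768-d latents) are assumptions taken from the GigaPath defaults used elsewhere in this commit, and calling `get_model` with the default `pretrained` will try to fetch the hub checkpoint:

```python
import torch
from gigapath.classification_head import get_model

# feat_layer="5-11" concatenates the 5th and 11th layer embeddings,
# so feat_dim = 2 * latent_dim.
model = get_model(input_dim=1536, latent_dim=768, feat_layer="5-11",
                  n_classes=2, freeze=True)
tile_embeds = torch.randn(1, 1000, 1536)  # [N, L, D] tile embeddings
coords = torch.randn(1, 1000, 2)          # [N, L, 2] tile coordinates
logits = model(tile_embeds, coords)       # [*, n_classes] class logits
```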
gigapath/pipeline.py
ADDED
@@ -0,0 +1,190 @@
+# --------------------------------------------------------
+# Pipeline for running with GigaPath
+# --------------------------------------------------------
+import os
+import timm
+import torch
+import shutil
+import numpy as np
+import pandas as pd
+import gigapath.slide_encoder as slide_encoder
+
+from tqdm import tqdm
+from PIL import Image
+from pathlib import Path
+from torchvision import transforms
+from typing import List, Tuple, Union
+from torch.utils.data import Dataset, DataLoader
+from gigapath.preprocessing.data.create_tiles_dataset import process_slide
+
+
+class TileEncodingDataset(Dataset):
+    """
+    Do encoding for tiles
+
+    Arguments:
+    ----------
+    image_paths : List[str]
+        List of image paths, each image is named with its coordinates
+        Example: ['images/256x_256y.png', 'images/256x_512y.png']
+    transform : torchvision.transforms.Compose
+        Transform to apply to each image
+    """
+    def __init__(self, image_paths: List[str], transform=None):
+        self.transform = transform
+        self.image_paths = image_paths
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        img_path = self.image_paths[idx]
+        img_name = os.path.basename(img_path)
+        # get x, y coordinates from the image name
+        x, y = img_name.split('.png')[0].split('_')
+        x, y = int(x.replace('x', '')), int(y.replace('y', ''))
+        # load the image
+        with open(img_path, "rb") as f:
+            img = Image.open(f).convert("RGB")
+            if self.transform:
+                img = self.transform(img)
+        return {'img': torch.from_numpy(np.array(img)),
+                'coords': torch.from_numpy(np.array([x, y])).float()}
+
+
+def tile_one_slide(slide_file: str = '', save_dir: str = '', level: int = 0, tile_size: int = 256):
+    """
+    This function is used to tile a single slide and save the tiles to a directory.
+    -------------------------------------------------------------------------------
+    Warning: pixman 0.38 has a known bug, which produces partially broken images.
+    Make sure to use a different version of pixman.
+    -------------------------------------------------------------------------------
+
+    Arguments:
+    ----------
+    slide_file : str
+        The path to the slide file.
+    save_dir : str
+        The directory to save the tiles.
+    level : int
+        The magnification level to use for tiling. level=0 is the highest magnification level.
+    tile_size : int
+        The size of the tiles.
+    """
+    slide_id = os.path.basename(slide_file)
+    # slide_sample = {"image": slide_file, "slide_id": slide_id, "metadata": {'TP53': 1, 'Diagnosis': 'Lung Cancer'}}
+    slide_sample = {"image": slide_file, "slide_id": slide_id, "metadata": {}}
+
+    save_dir = Path(save_dir)
+    if save_dir.exists():
+        print(f"Warning: Directory {save_dir} already exists. ")
+
+    print(f"Processing slide {slide_file} at level {level} with tile size {tile_size}. Saving to {save_dir}.")
+
+    slide_dir = process_slide(
+        slide_sample,
+        level=level,
+        margin=0,
+        tile_size=tile_size,
+        foreground_threshold=None,
+        occupancy_threshold=0.1,
+        output_dir=save_dir / "output",
+        thumbnail_dir=save_dir / "thumbnails",
+        tile_progress=True,
+    )
+
+    dataset_csv_path = slide_dir / "dataset.csv"
+    dataset_df = pd.read_csv(dataset_csv_path)
+    assert len(dataset_df) > 0
+    failed_csv_path = slide_dir / "failed_tiles.csv"
+    failed_df = pd.read_csv(failed_csv_path)
+    assert len(failed_df) == 0
+
+    print(f"Slide {slide_file} has been tiled. {len(dataset_df)} tiles saved to {slide_dir}.")
+
+
+def load_tile_encoder_transforms() -> transforms.Compose:
+    """Load the transforms for the tile encoder"""
+    transform = transforms.Compose(
+        [
+            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ])
+    return transform
+
+
+def load_tile_slide_encoder(local_tile_encoder_path: str = '',
+                            local_slide_encoder_path: str = '',
+                            global_pool=False) -> Tuple[torch.nn.Module, torch.nn.Module]:
+    """Load the GigaPath tile and slide encoder models.
+    Note: Older versions of timm have compatibility issues.
+    Please ensure that you use a newer version by running the following command: pip install timm>=1.0.3.
+    """
+    if local_tile_encoder_path:
+        tile_encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False, checkpoint_path=local_tile_encoder_path)
+    else:
+        tile_encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
+    print("Tile encoder param #", sum(p.numel() for p in tile_encoder.parameters()))
+
+    if local_slide_encoder_path:
+        slide_encoder_model = slide_encoder.create_model(local_slide_encoder_path, "gigapath_slide_enc12l768d", 1536, global_pool=global_pool)
+    else:
+        slide_encoder_model = slide_encoder.create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=global_pool)
+    print("Slide encoder param #", sum(p.numel() for p in slide_encoder_model.parameters()))
+
+    return tile_encoder, slide_encoder_model
+
+
+@torch.no_grad()
+def run_inference_with_tile_encoder(image_paths: List[str], tile_encoder: torch.nn.Module, batch_size: int = 128) -> dict:
+    """
+    Run inference with the tile encoder
+
+    Arguments:
+    ----------
+    image_paths : List[str]
+        List of image paths, each image is named with its coordinates
+    tile_encoder : torch.nn.Module
+        Tile encoder model
+    """
+    tile_encoder = tile_encoder.cuda()
+    # make the tile dataloader
+    tile_dl = DataLoader(TileEncodingDataset(image_paths, transform=load_tile_encoder_transforms()), batch_size=batch_size, shuffle=False)
+    # run inference
+    tile_encoder.eval()
+    collated_outputs = {'tile_embeds': [], 'coords': []}
+    with torch.cuda.amp.autocast(dtype=torch.float16):
+        for batch in tqdm(tile_dl, desc='Running inference with tile encoder'):
+            collated_outputs['tile_embeds'].append(tile_encoder(batch['img'].cuda()).detach().cpu())
+            collated_outputs['coords'].append(batch['coords'])
+    return {k: torch.cat(v) for k, v in collated_outputs.items()}
+
+
+@torch.no_grad()
+def run_inference_with_slide_encoder(tile_embeds: torch.Tensor, coords: torch.Tensor, slide_encoder_model: torch.nn.Module) -> torch.Tensor:
+    """
+    Run inference with the slide encoder
+
+    Arguments:
+    ----------
+    tile_embeds : torch.Tensor
+        Tile embeddings
+    coords : torch.Tensor
+        Coordinates of the tiles
+    slide_encoder_model : torch.nn.Module
+        Slide encoder model
+    """
+    if len(tile_embeds.shape) == 2:
+        tile_embeds = tile_embeds.unsqueeze(0)
+        coords = coords.unsqueeze(0)
+
+    slide_encoder_model = slide_encoder_model.cuda()
+    slide_encoder_model.eval()
+    # run inference
+    with torch.cuda.amp.autocast(dtype=torch.float16):
+        slide_embeds = slide_encoder_model(tile_embeds.cuda(), coords.cuda(), all_layer_embed=True)
+        outputs = {"layer_{}_embed".format(i): slide_embeds[i].cpu() for i in range(len(slide_embeds))}
+        outputs["last_layer_embed"] = slide_embeds[-1].cpu()
+    return outputs
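
Chaining the helpers above gives the full slide-to-embedding pipeline. A hedged end-to-end sketch, not part of the diff; the tile output directory layout under `save_dir/output` is an assumption about what `process_slide` writes, so adjust it to the actual layout on disk, and `sample.svs` is a hypothetical input slide:

```python
import os
from gigapath.pipeline import (tile_one_slide, load_tile_slide_encoder,
                               run_inference_with_tile_encoder,
                               run_inference_with_slide_encoder)

slide_path = "sample.svs"  # hypothetical input slide
tile_one_slide(slide_path, save_dir="tmp/", level=1, tile_size=256)

tile_dir = os.path.join("tmp", "output", os.path.basename(slide_path))  # assumed layout
image_paths = [os.path.join(tile_dir, f) for f in os.listdir(tile_dir) if f.endswith(".png")]

tile_encoder, slide_encoder_model = load_tile_slide_encoder()
tile_out = run_inference_with_tile_encoder(image_paths, tile_encoder)
slide_out = run_inference_with_slide_encoder(tile_out['tile_embeds'], tile_out['coords'],
                                             slide_encoder_model)
print(slide_out["last_layer_embed"].shape)
```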
gigapath/pos_embed.py
ADDED
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# MAE: https://github.com/facebookresearch/mae
+# --------------------------------------------------------
+#
+# Portions Copyright Prov-GigaPath
+# Original File: https://github.com/facebookresearch/mae
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches**0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model["pos_embed"] = new_pos_embed
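
A quick shape check for the sine-cosine embedding, not part of the diff: a 14×14 grid with a 768-dim embedding yields one row per grid cell, plus a zero row prepended when a class token is requested.

```python
from gigapath.pos_embed import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768) = 1 cls token + 14 * 14 grid positions
```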
gigapath/preprocessing/__init__.py
ADDED
File without changes

gigapath/preprocessing/data/__init__.py
ADDED
File without changes
gigapath/preprocessing/data/box_utils.py
ADDED
@@ -0,0 +1,145 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+#
+# Original: https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/box_utils.py
+# ------------------------------------------------------------------------------------------
+
+from dataclasses import dataclass
+from typing import Optional, Sequence, Tuple
+
+import numpy as np
+from scipy import ndimage
+
+
+@dataclass(frozen=True)
+class Box:
+    """Utility class representing rectangular regions in 2D images.
+
+    :param x: Horizontal coordinate of the top-left corner.
+    :param y: Vertical coordinate of the top-left corner.
+    :param w: Box width.
+    :param h: Box height.
+    :raises ValueError: If either `w` or `h` are <= 0.
+    """
+    x: int
+    y: int
+    w: int
+    h: int
+
+    def __post_init__(self) -> None:
+        if self.w <= 0:
+            raise ValueError(f"Width must be strictly positive, received {self.w}")
+        if self.h <= 0:
+            raise ValueError(f"Height must be strictly positive, received {self.h}")
+
+    def __add__(self, shift: Sequence[int]) -> 'Box':
+        """Translates the box's location by a given shift.
+
+        :param shift: A length-2 sequence containing horizontal and vertical shifts.
+        :return: A new box with updated `x = x + shift[0]` and `y = y + shift[1]`.
+        :raises ValueError: If `shift` does not have two elements.
+        """
+        if len(shift) != 2:
+            raise ValueError("Shift must be two-dimensional")
+        return Box(x=self.x + shift[0],
+                   y=self.y + shift[1],
+                   w=self.w,
+                   h=self.h)
+
+    def __mul__(self, factor: float) -> 'Box':
+        """Scales the box by a given factor, e.g. when changing resolution.
+
+        :param factor: The factor by which to multiply the box's location and dimensions.
+        :return: The updated box, with location and dimensions rounded to `int`.
+        """
+        return Box(x=int(self.x * factor),
+                   y=int(self.y * factor),
+                   w=int(self.w * factor),
+                   h=int(self.h * factor))
+
+    def __rmul__(self, factor: float) -> 'Box':
+        """Scales the box by a given factor, e.g. when changing resolution.
+
+        :param factor: The factor by which to multiply the box's location and dimensions.
+        :return: The updated box, with location and dimensions rounded to `int`.
+        """
+        return self * factor
+
+    def __truediv__(self, factor: float) -> 'Box':
+        """Scales the box by a given factor, e.g. when changing resolution.
+
+        :param factor: The factor by which to divide the box's location and dimensions.
+        :return: The updated box, with location and dimensions rounded to `int`.
+        """
+        return self * (1. / factor)
+
+    def add_margin(self, margin: int) -> 'Box':
+        """Adds a symmetric margin on all sides of the box.
+
+        :param margin: The amount by which to enlarge the box.
+        :return: A new box enlarged by `margin` on all sides.
+        """
+        return Box(x=self.x - margin,
+                   y=self.y - margin,
+                   w=self.w + 2 * margin,
+                   h=self.h + 2 * margin)
+
+    def clip(self, other: 'Box') -> Optional['Box']:
+        """Clips a box to the interior of another.
+
+        This is useful to constrain a region to the interior of an image.
+
+        :param other: Box representing the new constraints.
+        :return: A new constrained box, or `None` if the boxes do not overlap.
+        """
+        x0 = max(self.x, other.x)
+        y0 = max(self.y, other.y)
+        x1 = min(self.x + self.w, other.x + other.w)
+        y1 = min(self.y + self.h, other.y + other.h)
+        try:
+            return Box(x=x0, y=y0, w=x1 - x0, h=y1 - y0)
+        except ValueError:  # Empty result, boxes don't overlap
+            return None
+
+    def to_slices(self) -> Tuple[slice, slice]:
+        """Converts the box to slices for indexing arrays.
+
+        For example: `my_2d_array[my_box.to_slices()]`.
+
+        :return: A 2-tuple with vertical and horizontal slices.
+        """
+        return (slice(self.y, self.y + self.h),
+                slice(self.x, self.x + self.w))
+
+    @staticmethod
+    def from_slices(slices: Sequence[slice]) -> 'Box':
+        """Converts a pair of vertical and horizontal slices into a box.
+
+        :param slices: A length-2 sequence containing vertical and horizontal `slice` objects.
+        :return: A box with corresponding location and dimensions.
+        """
+        vert_slice, horz_slice = slices
+        return Box(x=horz_slice.start,
+                   y=vert_slice.start,
+                   w=horz_slice.stop - horz_slice.start,
+                   h=vert_slice.stop - vert_slice.start)
+
+
+def get_bounding_box(mask: np.ndarray) -> Box:
+    """Extracts a bounding box from a binary 2D array.
+
+    :param mask: A 2D array with 0 (or `False`) as background and >0 (or `True`) as foreground.
+    :return: The smallest box covering all non-zero elements of `mask`.
+    :raises TypeError: When the input mask has more than two dimensions.
+    :raises RuntimeError: When all elements in the mask are zero.
+    """
+    if mask.ndim != 2:
+        raise TypeError(f"Expected a 2D array but got an array with shape {mask.shape}")
+
+    slices = ndimage.find_objects(mask > 0)
+    if not slices:
+        raise RuntimeError("The input mask is empty")
+    assert len(slices) == 1
+
+    return Box.from_slices(slices[0])
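
A short sketch, not part of the diff, exercising the `Box` arithmetic and `get_bounding_box` on a toy mask:

```python
import numpy as np
from gigapath.preprocessing.data.box_utils import Box, get_bounding_box

mask = np.zeros((100, 100), dtype=np.uint8)
mask[20:40, 10:50] = 1
box = get_bounding_box(mask)        # Box(x=10, y=20, w=40, h=20)
box = (box + (5, 5)).add_margin(2)  # shift by (5, 5), then pad all sides by 2
region = mask[box.to_slices()]      # (h, w) = (24, 44) crop of the array
print(box, region.shape)
```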