weiheng-1009 committed
Commit cbff41a · 1 Parent(s): 958e3c5

added code for running

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +5 -0
  3. README.md +83 -3
  4. accelerate_configs/.ipynb_checkpoints/deepspeed_zero2-checkpoint.yaml +22 -0
  5. accelerate_configs/.ipynb_checkpoints/deepspeed_zero3-checkpoint.yaml +23 -0
  6. accelerate_configs/deepspeed_zero1.yaml +20 -0
  7. accelerate_configs/deepspeed_zero2.yaml +23 -0
  8. accelerate_configs/deepspeed_zero3.yaml +23 -0
  9. accelerate_configs/multi_gpu.yaml +16 -0
  10. accelerate_configs/single_gpu.yaml +16 -0
  11. data/huggingface_data.py +31 -0
  12. dataset_csv/gtex_slide_url_info.csv +0 -0
  13. dataset_csv/indices_and_slide_ids.csv +0 -0
  14. dataset_csv/indices_and_slide_ids_with_folds.csv +0 -0
  15. dataset_csv/tcga_slide_url_info.csv +0 -0
  16. demo/CONCH_clip.py +11 -0
  17. demo/Trainer_Mixtrial_ds_demo.py +154 -0
  18. demo/Trainer_bert_demo.py +76 -0
  19. demo/UNI_clip.py +33 -0
  20. demo/path_clip.py +28 -0
  21. demo/peft_demo.py +17 -0
  22. demo/trl_demo.py +175 -0
  23. evaluation/cider_score/cider_demo.ipynb +290 -0
  24. evaluation/cider_score/cidereval/__init__.py +5 -0
  25. evaluation/cider_score/cidereval/cider/__init__.py +1 -0
  26. evaluation/cider_score/cidereval/cider/cider.py +67 -0
  27. evaluation/cider_score/cidereval/cider/cider_scorer.py +274 -0
  28. evaluation/cider_score/cidereval/ciderD/__init__.py +1 -0
  29. evaluation/cider_score/cidereval/ciderD/ciderD.py +57 -0
  30. evaluation/cider_score/cidereval/ciderD/ciderD_scorer.py +265 -0
  31. evaluation/cider_score/cidereval/data/__init__.py +0 -0
  32. evaluation/cider_score/cidereval/data/coco-val.p +3 -0
  33. evaluation/cider_score/cidereval/eval.py +40 -0
  34. evaluation/cider_score/cidereval/scorers.py +76 -0
  35. evaluation/cider_score/cidereval/tokenizer/__init__.py +4 -0
  36. evaluation/cider_score/cidereval/tokenizer/ptbtokenizer.py +112 -0
  37. evaluation/cider_score/cidereval/tokenizer/simpletokenizer.py +106 -0
  38. evaluation/cider_score/output_sample.xls +0 -0
  39. filter_dataset.py +43 -0
  40. gigapath/__init__.py +0 -0
  41. gigapath/__pycache__/__init__.cpython-310.pyc +0 -0
  42. gigapath/__pycache__/pos_embed.cpython-310.pyc +0 -0
  43. gigapath/__pycache__/slide_encoder.cpython-310.pyc +0 -0
  44. gigapath/__pycache__/slide_encoder_vision.cpython-310.pyc +0 -0
  45. gigapath/classification_head.py +92 -0
  46. gigapath/pipeline.py +190 -0
  47. gigapath/pos_embed.py +105 -0
  48. gigapath/preprocessing/__init__.py +0 -0
  49. gigapath/preprocessing/data/__init__.py +0 -0
  50. gigapath/preprocessing/data/box_utils.py +145 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.p filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ Conch_Llama3ins_Instruct3/
+ Conch_Mistral_Instruct3/
+ output/
+ logs/
+ wandb/
README.md CHANGED
@@ -1,3 +1,83 @@
- ---
- license: mit
- ---
+ # PathLLM
+
+ Welcome to ALPaCA. This repository aims to provide a straightforward reproduction of ALPaCA. To run ALPaCA, please first download **Llama3.1-8b-instruct** as the base model.
+
+ For the TCGA and GTEx data, you can visit the [GDC Data Portal](https://portal.gdc.cancer.gov/) and the [GTEx Portal](https://www.gtexportal.org/) to download the slides and extract features yourself. Alternatively, you can use the features we have already extracted with CONCH: `CNX-PathLLM/GMM_Embeddings` and `CNX-PathLLM/GTEx-TCGA-Embeddings`. After downloading, please unzip them into the respective `TCGA-Embedding` and `GMM_Embedding` folders.
+
+ Please ensure you have access to the pathological image description data:
+ `CNX-PathLLM/TCGA-WSI-Description-4onew`, `CNX-PathLLM/TCGA-WSI-Description-4omini`, and `CNX-PathLLM/GTEx-WSI-Description`.
+
+ Please also ensure you have access to the WSI-QA data:
+ `CNX-PathLLM/TCGA-WSI-CloseQA-Balanced`, `CNX-PathLLM/GTEx-WSI-CloseQA-Balanced`, `CNX-PathLLM/TCGA-WSI-OpenQA`, and `CNX-PathLLM/GTEx-WSI-OpenQA`.
+
+ After completing the setup above and creating the correct Python environment, you can start training with one of the provided shell scripts, e.g., `run_wsi_stage*.sh`, or follow the instructions in the [Train Step](#train-step-1) sections below.
+
+ Do not forget to adjust the TCGA and GMM embedding paths to reflect your own file locations.
+
+
+ ## Settings
+
+ ### Different Aggregation Strategies
+ You can change the aggregation strategy with the `--agg_strategy` flag, e.g., `qformer`, `abmil`, or `longnet`. To reproduce the method described in our paper, set `--agg_strategy gmm,longnet` in the `.sh` script.
+
+ ### Configurable Settings
+
+ (1) `--vision_adaptor False --hierarchical_adaptor True`
+
+ (2) `--vision_adaptor False --hierarchical_adaptor False`
+
+ (3) `--vision_adaptor True --hierarchical_adaptor True`
+
+ ```
+ --vision_adaptor False (vision-query-question interaction)
+ --vision_adaptor True (vision-query interaction)
+
+ --hierarchical_adaptor False (same adaptor for all levels)
+ --hierarchical_adaptor True (different adaptors for different levels)
+ ```
+
+ ## Train Step 1 ##
+ ```
+ accelerate launch --config_file=./accelerate_configs/deepspeed_zero2.yaml run_wsi.py --learning_rate 1e-4 --max_steps 10000 --warmup_steps 100 \
+ --gpu 2 --train_batch_size 4 --eval_batch_size 2 --max_seq_length 512 \
+ --agg_strategy gmm,longnet --embed_dim 512 --vision_adaptor False --hierachical_token True --hierachical_adaptor True \
+ --n_heads 32,16,8 --llm_requires_grad False --resume_from_checkpoint False \
+ --llm_name /data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct \
+ --dataset_name_list CNX-PathLLM/TCGA-WSI-Description-4onew,CNX-PathLLM/TCGA-WSI-Description-4omini,CNX-PathLLM/GTEx-WSI-Description \
+ --data_cache_dir /data_local/pxb/CNX-PathLLM/.cache \
+ --fea_root /path/to/CNX-PathLLM/GTEx-TCGA-Embeddings \
+ --gmm_root /path/to/GMM_Embeddings \
+ --output_dir path/to/output/of/step1
+ ```
+
+ ## Train Step 2 ##
+ ```
+ accelerate launch --config_file=./accelerate_configs/deepspeed_zero2.yaml run_wsi.py --max_steps 20000 --warmup_steps 10 \
+ --gpu 2 --train_batch_size 8 --eval_batch_size 2 --max_seq_length 256 \
+ --agg_strategy gmm,longnet --embed_dim 512 --vision_adaptor False --hierachical_token True --hierachical_adaptor True \
+ --n_heads 32,16,8 --llm_requires_grad True --resume_from_checkpoint False \
+ --llm_name /data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct \
+ --dataset_name_list CNX-PathLLM/TCGA-WSI-CloseQA-Balanced,CNX-PathLLM/GTEx-WSI-CloseQA-Balanced,CNX-PathLLM/TCGA-WSI-OpenQA,CNX-PathLLM/GTEx-WSI-OpenQA \
+ --data_cache_dir /data_local/pxb/CNX-PathLLM/.cache \
+ --fea_root /path/to/CNX-PathLLM/GTEx-TCGA-Embeddings \
+ --gmm_root /path/to/GMM_Embeddings \
+ --output_dir path/to/output/of/step2 \
+ --ckpt_path path/to/ckpt.bin/of/step1
+ ```
+
+ ## Train Step 3 ##
+ To continue training on the detailed BRCA-specific dataset, make sure you have access to that dataset and replace the dataset names in the command above with the ones you want.
+
+ ## Test of Step 2 General QA ##
+
+ ```
+ python test_wsi.py --max_seq_length 128 --batch_size 1 --select_data_num -1 --eval_sample_size -1 --n_heads 32,16,8 --llm_name /data_local/pxb/LLM_models/llama3/llama3.1-8b-instruct --vision_adaptor False --hierachical_token True --hierachical_adaptor True \
+ --shuffle False --data_cache_dir /data_local/pxb/CNX-PathLLM/.cache \
+ --dataset_name_list CNX-PathLLM/TCGA-WSI-CloseQA-Balanced,CNX-PathLLM/GTEx-WSI-CloseQA-Balanced,CNX-PathLLM/TCGA-WSI-OpenQA,CNX-PathLLM/GTEx-WSI-OpenQA \
+ --agg_strategy gmm,longnet --embed_dim 512 \
+ --fea_root /path/to/CNX-PathLLM/GTEx-TCGA-Embeddings \
+ --gmm_root /path/to/GMM_Embeddings \
+ --ckpt_path path/to/ckpt.bin/of/step2 \
+ --results_save_path /path/to/the/output.csv \
+ --use_peft False
+ ```
accelerate_configs/.ipynb_checkpoints/deepspeed_zero2-checkpoint.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   gradient_accumulation_steps: "auto"
+   offload_optimizer_device: "cpu"
+   offload_param_device: "cpu"
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'auto'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/.ipynb_checkpoints/deepspeed_zero3-checkpoint.yaml ADDED
@@ -0,0 +1,23 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   gradient_accumulation_steps: 8
+   offload_optimizer_device: "cpu"
+   offload_param_device: "cpu"
+   zero3_init_flag: true
+   zero3_save_16bit_model: true
+   zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'auto'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/deepspeed_zero1.yaml ADDED
@@ -0,0 +1,20 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   gradient_accumulation_steps: "auto"
+   zero3_init_flag: false
+   zero_stage: 1
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/deepspeed_zero2.yaml ADDED
@@ -0,0 +1,23 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   gradient_accumulation_steps: "auto"
+   offload_optimizer_device: "cpu"
+   offload_param_device: "cpu"
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'auto'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
+ main_process_port: 29502
accelerate_configs/deepspeed_zero3.yaml ADDED
@@ -0,0 +1,23 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   gradient_accumulation_steps: 8
+   offload_optimizer_device: "cpu"
+   offload_param_device: "cpu"
+   zero3_init_flag: true
+   zero3_save_16bit_model: true
+   zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'auto'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/multi_gpu.yaml ADDED
@@ -0,0 +1,16 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/single_gpu.yaml ADDED
@@ -0,0 +1,16 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: "NO"
+ downcast_bf16: 'no'
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'bf16'
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
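Any of these configs can be passed directly to `accelerate launch`, as the README's training commands do. A minimal sketch (here `run_wsi.py` and its flags stand in for whichever script you are launching):

```
accelerate launch --config_file=./accelerate_configs/single_gpu.yaml run_wsi.py --learning_rate 1e-4 ...
```

Adjust `num_processes` and `gpu_ids` in the chosen YAML to match your hardware before launching.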
data/huggingface_data.py ADDED
@@ -0,0 +1,31 @@
+ import pandas as pd
+ import datasets
+ import re
+ import os
+ import shutil
+ # Split into train / test / val sets
+ splits = ["test", "train", "val"]
+ # Process each split separately
+ for item in splits:
+     os.makedirs(f"our_clean/{item}/", exist_ok=True)
+     data = pd.read_csv(f"{item}.csv")
+     data["image_path"] = data["image_path"].map(lambda x: x.split("/")[-1])
+     # Lightly clean the text content
+     f = lambda x: re.sub(' +', ' ', str(x).lower()).replace(" ?", "?").strip()
+     # Hugging Face imagefolder datasets require a "file_name" column
+     data.insert(0, "file_name", "")
+     data["question"] = data["question"].apply(f)
+     data["answer"] = data["answer"].apply(f)
+     # Pair each image with its text (use .at to avoid chained-assignment issues)
+     for i, row in data.iterrows():
+         file_name = f"img_{i}.jpg"
+         data.at[i, "file_name"] = file_name
+         shutil.copyfile(src=f"author-folder/pvqa/pvqa/images/{item}/{row['image']}.jpg", dst=f"our_clean/{item}/{file_name}")
+     # Drop columns that are no longer needed
+     _ = data.pop("image")
+     data.drop(["pathology", "image_path"], axis=1, inplace=True)
+     data.to_csv(f"our_clean/{item}/metadata.csv", index=False)
+ # Build an imagefolder-format dataset; data_dir is the data folder, see https://huggingface.co/docs/datasets/en/image_load
+ dataset = datasets.load_dataset("imagefolder", data_dir="our_clean/")
+ # Publish the dataset
+ dataset.push_to_hub("CNX-PathLLM/PVQAClean")
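Once pushed, the published dataset can be loaded back from the Hub. A minimal sketch (assumes your account has access to the `CNX-PathLLM` organization):

```
from datasets import load_dataset

# The train/test/val splits are preserved by push_to_hub
dataset = load_dataset("CNX-PathLLM/PVQAClean")
print(dataset["train"][0]["question"])
```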
dataset_csv/gtex_slide_url_info.csv ADDED
The diff for this file is too large to render. See raw diff
dataset_csv/indices_and_slide_ids.csv ADDED
The diff for this file is too large to render. See raw diff
dataset_csv/indices_and_slide_ids_with_folds.csv ADDED
The diff for this file is too large to render. See raw diff
dataset_csv/tcga_slide_url_info.csv ADDED
The diff for this file is too large to render. See raw diff
demo/CONCH_clip.py ADDED
@@ -0,0 +1,11 @@
+ from conch.open_clip_custom import create_model_from_pretrained
+ import torch
+ from PIL import Image
+
+ model, preprocess = create_model_from_pretrained('conch_ViT-B-16', "/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/load_weights/conch/pytorch_model.bin")
+
+ image = Image.open("/bask/homes/a/asiw9691/PathVLM/source/Flamingo/med-flamingo/img/test_path5.jpg")
+ image = preprocess(image).unsqueeze(0)
+ with torch.inference_mode():
+     image_embs = model.encode_image(image, proj_contrast=False, normalize=False)
+ print(image_embs.shape)
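The demo extracts raw, pre-projection features (`proj_contrast=False, normalize=False`). For CLIP-style image-text matching, the same call can instead return projected, normalized embeddings; a sketch using the two flags already shown above:

```
with torch.inference_mode():
    # Embeddings in CONCH's contrastive (CLIP-style) space
    clip_embs = model.encode_image(image, proj_contrast=True, normalize=True)
print(clip_embs.shape)
```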
demo/Trainer_Mixtrial_ds_demo.py ADDED
@@ -0,0 +1,154 @@
+ # accelerate launch --config_file=/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/accelerate_configs/deepspeed_zero2.yaml Trainer_Mixtrial_demo.py
+
+ import os
+ os.environ["WANDB_MODE"] = "offline"
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+
+ import torch
+ from torch import nn
+ from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, PreTrainedModel
+ from accelerate import Accelerator
+ from datasets import load_dataset
+ from typing import Optional
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class ScriptArguments:
+     """
+     The name of the causal LM model we wish to fine-tune with SFTTrainer
+     """
+
+     model_name: Optional[str] = field(default="mistralai/Mistral-7B-Instruct-v0.2", metadata={"help": "the model name, e.g. meta-llama/Llama-2-7b-chat-hf"})
+     dataset_name: Optional[str] = field(default="stingning/ultrachat", metadata={"help": "the dataset name"})
+     dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
+     log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
+     learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"})
+     batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
+     seq_length: Optional[int] = field(default=1024, metadata={"help": "input sequence length"})
+     gradient_accumulation_steps: Optional[int] = field(default=8, metadata={"help": "the number of gradient accumulation steps"})
+
+     evaluation_strategy: Optional[str] = field(default="steps", metadata={"help": "'epoch' or 'steps'"})
+     eval_steps: Optional[int] = field(default=1000, metadata={"help": "the number of steps between evaluations"})
+
+     load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
+     load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "load the model in 4-bit precision"})
+     use_peft: Optional[bool] = field(default=True, metadata={"help": "whether to use PEFT to train adapters"})
+     trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "enable `trust_remote_code`"})
+
+     output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
+     peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
+     peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
+     logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"})
+     token: Optional[bool] = field(default=True, metadata={"help": "use the HF auth token to access the model"})
+     num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
+     max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
+     save_steps: Optional[int] = field(default=1000, metadata={"help": "number of update steps between two checkpoint saves"})
+     save_total_limit: Optional[int] = field(default=10, metadata={"help": "limits the total number of checkpoints"})
+     push_to_hub: Optional[bool] = field(default=False, metadata={"help": "push the model to the HF Hub"})
+     hub_model_id: Optional[str] = field(default="mistral-7b-finetuned-ultrachat", metadata={"help": "the name of the model on the HF Hub"})
+
+ parser = HfArgumentParser(ScriptArguments)
+ script_args = parser.parse_args_into_dataclasses()[0]
+
+
+ class MyCustomModel(nn.Module):
+     def __init__(self, script_args, num_labels):
+         super(MyCustomModel, self).__init__()
+         self.num_labels = num_labels
+         # quantization_config, device_map and torch_dtype are defined below, before the model is instantiated
+         self.pretrained_model = AutoModelForCausalLM.from_pretrained(script_args.model_name,
+                                                                      quantization_config=quantization_config,
+                                                                      device_map=device_map,
+                                                                      trust_remote_code=script_args.trust_remote_code,
+                                                                      torch_dtype=torch_dtype,
+                                                                      token=script_args.token)
+         self.classifier = nn.Linear(self.pretrained_model.config.hidden_size, num_labels)
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         # A causal LM output has no last_hidden_state; request hidden states and take the last layer
+         outputs = self.pretrained_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
+         sequence_output = outputs.hidden_states[-1][:, 0, :]
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         return {"loss": loss, "logits": logits} if loss is not None else logits
+
+
+ dataset = load_dataset("glue", "mrpc")
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def preprocess_function(examples):
+     # Tokenize the inputs (pair of sentences); max_length is kept tiny for a quick demo
+     return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding=True, max_length=10)
+
+
+ from transformers import DataCollatorWithPadding
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+ small_train_dataset = dataset["train"].shuffle(seed=42).select(range(500))  # select the first 500 samples
+ small_train_dataset = small_train_dataset.map(preprocess_function, batched=True)
+
+
+ if script_args.load_in_8bit and script_args.load_in_4bit:
+     raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
+ elif script_args.load_in_8bit or script_args.load_in_4bit:
+     quantization_config = BitsAndBytesConfig(
+         load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
+     )
+     # Copy the model to each device
+     device_map = {"": Accelerator().local_process_index}
+     torch_dtype = torch.bfloat16
+ else:
+     device_map = None
+     quantization_config = None
+     torch_dtype = None
+
+
+ model = MyCustomModel(script_args, num_labels=2)
+
+
+ training_args = TrainingArguments(
+     output_dir=script_args.output_dir,
+     per_device_train_batch_size=script_args.batch_size,
+     gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+     # gradient_checkpointing=True,
+     learning_rate=script_args.learning_rate,
+     logging_steps=script_args.logging_steps,
+     num_train_epochs=script_args.num_train_epochs,
+     max_steps=script_args.max_steps,
+     report_to=script_args.log_with,
+     save_steps=script_args.save_steps,
+     save_total_limit=script_args.save_total_limit,
+     bf16=True,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.1,
+     evaluation_strategy=script_args.evaluation_strategy,
+     eval_steps=script_args.eval_steps,
+     logging_first_step=True,
+ )
+
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=small_train_dataset,
+     data_collator=data_collator,
+     compute_metrics=None,
+ )
+
+
+ trainer.train()
+ # model.save_pretrained("./my_custom_model")
demo/Trainer_bert_demo.py ADDED
@@ -0,0 +1,76 @@
+ import os
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+
+ import torch
+ from torch import nn
+ from transformers import Trainer, TrainingArguments
+ from transformers import AutoTokenizer, AutoModel
+ from datasets import load_dataset
+
+ # Custom model: inherit from nn.Module or from a transformers pretrained-model class
+ class MyCustomModel(nn.Module):
+     def __init__(self, num_labels):
+         super(MyCustomModel, self).__init__()
+         self.num_labels = num_labels
+         self.pretrained_model = AutoModel.from_pretrained("bert-base-uncased")
+         self.classifier = nn.Linear(self.pretrained_model.config.hidden_size, num_labels)
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         outputs = self.pretrained_model(input_ids, attention_mask=attention_mask)
+         sequence_output = outputs[1]  # BERT pooler output
+         logits = self.classifier(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         return {"loss": loss, "logits": logits} if loss is not None else logits
+
+ # Load and preprocess the dataset
+ dataset = load_dataset("glue", "mrpc")
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+ def preprocess_function(examples):
+     # Tokenize the inputs (pair of sentences)
+     return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding=True)
+
+ from transformers import DataCollatorWithPadding
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+ small_train_dataset = dataset["train"].shuffle(seed=42).select(range(500))  # select the first 500 samples
+ small_train_dataset = small_train_dataset.map(preprocess_function, batched=True)
+
+ # Debug: inspect the tokenized samples
+ for i in small_train_dataset:
+     print(i)
+
+
+ # Instantiate the custom model
+ model = MyCustomModel(num_labels=2).to("cuda")
+
+ # Define the training arguments
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=3,
+     per_device_train_batch_size=8,
+     warmup_steps=500,
+     weight_decay=0.01,
+     logging_dir='./logs',
+     logging_steps=10,
+ )
+
+ # Initialize the Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=small_train_dataset,
+     data_collator=data_collator,
+     compute_metrics=None,  # add a metrics function here if needed
+ )
+
+ # Train the model
+ trainer.train()
+
+ # Save the model (MyCustomModel is a plain nn.Module, so save its state dict; it has no save_pretrained)
+ os.makedirs("./my_custom_model", exist_ok=True)
+ torch.save(model.state_dict(), "./my_custom_model/pytorch_model.bin")
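To reuse the fine-tuned weights later, rebuild the module and load the saved state dict; a sketch assuming the save path used above:

```
model = MyCustomModel(num_labels=2)
model.load_state_dict(torch.load("./my_custom_model/pytorch_model.bin"))
model.eval()
```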
demo/UNI_clip.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import torch
+ from torchvision import transforms
+ import timm
+ from huggingface_hub import login, hf_hub_download
+
+ # login() # login with your User Access Token, found at https://huggingface.co/settings/tokens
+
+ local_dir = "/bask/homes/a/asiw9691/PathVLM/UNI"
+ # os.makedirs(local_dir, exist_ok=True)  # create directory if it does not exist
+ # hf_hub_download("MahmoodLab/UNI", filename="pytorch_model.bin", local_dir=local_dir, force_download=True)
+ model = timm.create_model("vit_large_patch16_224", img_size=224, patch_size=16, init_values=1e-5, num_classes=0, dynamic_img_size=True)
+ model.load_state_dict(torch.load(os.path.join(local_dir, "pytorch_model.bin"), map_location="cpu"), strict=True)
+ transform = transforms.Compose(
+     [
+         # transforms.Resize(224),
+         transforms.Resize(256),      # resize the shorter side to 256 pixels first
+         transforms.CenterCrop(224),  # then center-crop a 224x224 image
+         transforms.ToTensor(),
+         transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+         transforms.Lambda(lambda x: x.unsqueeze(0))
+     ]
+ )
+ model.eval()
+
+ from PIL import Image
+ image = Image.open("/bask/homes/a/asiw9691/PathVLM/source/Flamingo/med-flamingo/img/test_path5.jpg")
+
+ image = transform(image)  # image (torch.Tensor) with shape [1, 3, 224, 224] after resizing and normalization (ImageNet parameters)
+ with torch.inference_mode():
+     feature_emb = model(image)  # extracted features (torch.Tensor) with shape [1, 1024]
+
+ print(feature_emb.shape)
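The extracted `[1, 1024]` feature is typically consumed by a lightweight task head rather than used directly; a minimal linear-probe sketch (`num_classes` is a placeholder for your task):

```
from torch import nn

num_classes = 2  # placeholder: set to your task's number of labels
probe = nn.Linear(1024, num_classes)
with torch.inference_mode():
    logits = probe(feature_emb)  # shape [1, num_classes]
print(logits.shape)
```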
demo/path_clip.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ from PIL import Image
+
+ import open_clip
+
+ ## Load the model
+ model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained="/bask/homes/a/asiw9691/PathVLM/PathClip/pathclip-base.pt",
+                                                              force_quick_gelu=True)
+ tokenizer = open_clip.get_tokenizer('ViT-B-16')
+ model = model.cuda()
+
+ ## Load the image and prepare the text prompts
+ img_path = '/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/data/test_data/test_path1.jpg'
+ label_description_list = ['apple', 'liver', 'cancer']  # specify the label descriptions
+ text_label_list = ['An image of {}'.format(i) for i in label_description_list]
+ image = Image.open(img_path)
+ image = preprocess(image).unsqueeze(0).cuda()
+ text = tokenizer(text_label_list).cuda()
+
+ ## Extract the image and text features and predict the label
+ with torch.no_grad(), torch.cuda.amp.autocast():
+     image_features = model.encode_image(image)
+     text_features = model.encode_text(text)
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+ predict_label = torch.argmax(text_probs).item()
+ print(predict_label)
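To report the predicted description rather than its index, the index can be mapped back through the label list (a small addition, not in the original demo):

```
print(label_description_list[predict_label])  # e.g. 'liver'
print(text_probs)                             # per-label probabilities
```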
demo/peft_demo.py ADDED
@@ -0,0 +1,17 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import LoraConfig, TaskType, get_peft_model
+
+ model_name_or_path = "/raid/hpc/hekai/WorkShop/My_project/LLM_models/llama2/Llama-2-7b-chat-hf"
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+
+ peft_config = LoraConfig(
+     r=8,
+     target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+     bias="none",
+     task_type=TaskType.CAUSAL_LM,
+ )
+
+ model = get_peft_model(model, peft_config)
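After wrapping the model, PEFT can report how many parameters the LoRA config actually leaves trainable; `print_trainable_parameters()` is part of the peft model API:

```
model.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...
```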
demo/trl_demo.py ADDED
@@ -0,0 +1,175 @@
+ # accelerate launch --config_file=/raid/hpc/hekai/WorkShop/My_project/PathLLM_new/accelerate_configs/deepspeed_zero2.yaml demo/trl_demo.py
+
+ import os
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+ from accelerate import Accelerator
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, AutoTokenizer
+ from trl import SFTTrainer
+
+
+ tqdm.pandas()
+
+
+ # Define and parse arguments.
+ @dataclass
+ class ScriptArguments:
+     """
+     The name of the causal LM model we wish to fine-tune with SFTTrainer
+     """
+
+     model_name: Optional[str] = field(default="mistralai/Mistral-7B-Instruct-v0.2", metadata={"help": "the model name, e.g. meta-llama/Llama-2-7b-chat-hf"})
+     dataset_name: Optional[str] = field(default="stingning/ultrachat", metadata={"help": "the dataset name"})
+     dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
+     log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
+     learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"})
+     batch_size: Optional[int] = field(default=1, metadata={"help": "the batch size"})
+     seq_length: Optional[int] = field(default=1024, metadata={"help": "input sequence length"})
+     gradient_accumulation_steps: Optional[int] = field(default=8, metadata={"help": "the number of gradient accumulation steps"})
+
+     evaluation_strategy: Optional[str] = field(default="steps", metadata={"help": "'epoch' or 'steps'"})
+     eval_steps: Optional[int] = field(default=2, metadata={"help": "the number of steps between evaluations"})
+
+     load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
+     load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
+     use_peft: Optional[bool] = field(default=True, metadata={"help": "whether to use PEFT to train adapters"})
+     trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "enable `trust_remote_code`"})
+
+     output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
+     peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
+     peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
+     logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"})
+     token: Optional[bool] = field(default=True, metadata={"help": "use the HF auth token to access the model"})
+     num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
+     max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
+     save_steps: Optional[int] = field(default=1000, metadata={"help": "number of update steps between two checkpoint saves"})
+     save_total_limit: Optional[int] = field(default=10, metadata={"help": "limits the total number of checkpoints"})
+     push_to_hub: Optional[bool] = field(default=False, metadata={"help": "push the model to the HF Hub"})
+     hub_model_id: Optional[str] = field(default="mistral-7b-finetuned-ultrachat", metadata={"help": "the name of the model on the HF Hub"})
+
+
+ parser = HfArgumentParser(ScriptArguments)
+ script_args = parser.parse_args_into_dataclasses()[0]
+
+ # Step 1: Load the dataset
+ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
+ tokenizer.padding_side = 'right'
+ tokenizer.pad_token = tokenizer.eos_token
+
+ dataset = load_dataset(script_args.dataset_name, split="train[:200]")
+ dataset = dataset.train_test_split(test_size=0.1)
+
+ def prepare_dialogue(example):
+     text = ""
+     for idx, msg in enumerate(example["data"]):
+         if idx % 2 == 0:
+             text += f"<|user|>\n{msg}{tokenizer.eos_token}\n"
+         else:
+             text += f"<|assistant|>\n{msg}{tokenizer.eos_token}\n"
+     example["text"] = text
+     return example
+
+ dataset = dataset.map(prepare_dialogue, num_proc=4, remove_columns=["id", "data"])
+
+
+ # Step 2: Load the model
+ if script_args.load_in_8bit and script_args.load_in_4bit:
+     raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
+ elif script_args.load_in_8bit or script_args.load_in_4bit:
+     quantization_config = BitsAndBytesConfig(
+         load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
+     )
+     # Copy the model to each device
+     device_map = {"": Accelerator().local_process_index}
+     torch_dtype = torch.bfloat16
+ else:
+     # device_map = "auto"
+     device_map = None
+     quantization_config = None
+     torch_dtype = None
+
+ model = AutoModelForCausalLM.from_pretrained(
+     script_args.model_name,
+     quantization_config=quantization_config,
+     device_map=device_map,
+     trust_remote_code=script_args.trust_remote_code,
+     torch_dtype=torch_dtype,
+     token=script_args.token,
+ )
+
+
+ # Step 3: Define the LoraConfig
+ if script_args.use_peft:
+     peft_config = LoraConfig(
+         r=script_args.peft_lora_r,
+         lora_alpha=script_args.peft_lora_alpha,
+         bias="none",
+         task_type="CAUSAL_LM",
+     )
+ else:
+     peft_config = None
+
+
+ # Step 4: Define the training arguments
+ training_args = TrainingArguments(
+     output_dir=script_args.output_dir,
+     per_device_train_batch_size=script_args.batch_size,
+     gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+     gradient_checkpointing=True,
+     learning_rate=script_args.learning_rate,
+     logging_steps=script_args.logging_steps,
+     num_train_epochs=script_args.num_train_epochs,
+     max_steps=script_args.max_steps,
+     report_to=script_args.log_with,
+     save_steps=script_args.save_steps,
+     save_total_limit=script_args.save_total_limit,
+     # push_to_hub=script_args.push_to_hub,
+     # hub_model_id=script_args.hub_model_id,
+     bf16=True,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.1,
+     evaluation_strategy=script_args.evaluation_strategy,
+     eval_steps=script_args.eval_steps,
+     logging_first_step=True,
+ )
+
+ def my_compute_metrics(p):
+     # Placeholder metrics for demonstration only; replace with real computations
+     predictions, labels = p
+     return {
+         'precision': 1,
+         'recall': 1,
+         'f1': 1,
+     }
+
+ # Step 5: Define the trainer and train
+ trainer = SFTTrainer(
+     model=model,
+     args=training_args,
+     max_seq_length=script_args.seq_length,
+     train_dataset=dataset["train"],
+     eval_dataset=dataset["test"],
+     dataset_text_field=script_args.dataset_text_field,
+     peft_config=peft_config,
+     packing=False,
+     tokenizer=tokenizer,
+     compute_metrics=my_compute_metrics
+ )
+
+ trainer.train()
+
+ # Step 6: Save the model
+ trainer.save_model(script_args.output_dir)
evaluation/cider_score/cider_demo.ipynb ADDED
@@ -0,0 +1,290 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "colab_type": "text",
+     "id": "view-in-github"
+    },
+    "source": [
+     "<a href=\"https://colab.research.google.com/github/michelecafagna26/cider/blob/master/cider_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "GWTmEvA9jNbE"
+    },
+    "source": [
+     "# Install"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "metadata": {
+     "colab": {
+      "base_uri": "https://localhost:8080/"
+     },
+     "id": "AeLUPV23cglP",
+     "outputId": "61e4fef5-0481-4c41-ee8a-0e8073f0d3e1"
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "... (pip install log trimmed: spacy and its dependencies were already satisfied; en-core-web-sm 3.7.1 was downloaded and installed) ...\n",
+       "✔ Download and installation successful\n",
+       "You can now load the package via spacy.load('en_core_web_sm')\n"
+      ]
+     }
+    ],
+    "source": [
+     "# Use spacy to avoid depending on stanford-corenlp.jar, which requires Java\n",
+     "! pip install spacy\n",
+     "! python -m spacy download en_core_web_sm"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "RF9-uATHjSJT"
+    },
+    "source": [
+     "Ready to go!"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {
+     "id": "UvK_ACyMeqDD"
+    },
+    "outputs": [],
+    "source": [
+     "from cidereval import cider, ciderD\n",
+     "import pandas as pd"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "CIDEr Score: 1.1466548664799776\n"
+      ]
+     }
+    ],
+    "source": [
+     "def calculate_cider_score(file_path, pred_column, ref_column):\n",
+     "    \"\"\"\n",
+     "    input:\n",
+     "\n",
+     "    file_path: table to be analysed\n",
+     "    pred_column: name of the column holding the analysed model's generations\n",
+     "    ref_column: name of the column holding the ground truth\n",
+     "\n",
+     "    output:\n",
+     "    score: {avg_score: xxx, scores: [array]}\n",
+     "    \"\"\"\n",
+     "    # Read the xlsx file\n",
+     "    df = pd.read_excel(file_path)\n",
+     "\n",
+     "    # Lists holding the ground-truth and generated sentences\n",
+     "    references = []\n",
+     "    candidates = []\n",
+     "\n",
+     "    # Iterate over the rows and append the contents to the corresponding lists\n",
+     "    for index, row in df.iterrows():\n",
+     "        # references.append([row['answers']])\n",
+     "        # candidates.append([row['results']])\n",
+     "        references.append([row[ref_column]])\n",
+     "        candidates.append(row[pred_column])\n",
+     "    # Create the Cider object\n",
+     "    # cider_scorer = Cider()\n",
+     "\n",
+     "    # Compute the CIDEr score\n",
+     "    cider_score = cider(candidates, references, df=\"corpus\")\n",
+     "\n",
+     "    return cider_score\n",
+     "\n",
+     "# Call the function with the path to the xlsx file\n",
+     "file_path = 'output_sample.xls'\n",
+     "score = calculate_cider_score(file_path, \"results\", \"answers\")\n",
+     "print(\"CIDEr Score:\", score['avg_score'])"
+    ]
+   }
+  ],
+  "metadata": {
+   "colab": {
+    "authorship_tag": "ABX9TyMT4TNjFKKrEMcMc0uZ6Ubr",
+    "include_colab_link": true,
+    "name": "cider_demo.ipynb",
+    "provenance": []
+   },
+   "kernelspec": {
+    "display_name": "asiw9691_conda_env (Conda)",
+    "language": "python",
+    "name": "sys_asiw9691_conda_env"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.4"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 4
+ }
evaluation/cider_score/cidereval/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ __author__ = 'tylin'
2
+ # edited by Michele Cafagna
3
+ from cidereval.cider.cider import Cider
4
+ from cidereval.ciderD.ciderD import CiderD
5
+ from cidereval.scorers import cider, ciderD
evaluation/cider_score/cidereval/cider/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
evaluation/cider_score/cidereval/cider/cider.py ADDED
@@ -0,0 +1,67 @@
1
+ # Filename: cider.py
2
+ #
3
+ #
4
+ # Description: Describes the class to compute the CIDEr
5
+ # (Consensus-Based Image Description Evaluation) Metric
6
+ # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
7
+ #
8
+ # Creation Date: Sun Feb 8 14:16:54 2015
9
+ #
10
+ # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and
11
+ # Tsung-Yi Lin <tl483@cornell.edu>
12
+
13
+ # edited by Michele Cafagna
14
+
15
+ from .cider_scorer import CiderScorer
16
+
17
+
18
+ class Cider:
19
+ """
20
+ Main Class to compute the CIDEr metric
21
+
22
+ """
23
+ def __init__(self, n=4, df="corpus"):
24
+ """
25
+ Initialize the CIDEr scoring function
26
+ : param n (int): n-gram size
27
+ : param df (string): specifies where to get the IDF values from
28
+ takes values 'corpus', 'coco-val'
29
+ : return: None
30
+ """
31
+ # set cider to sum over 1 to 4-grams
32
+ self._n = n
33
+ self._df = df
34
+ self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
35
+
36
+ def compute_score(self, gts, res):
37
+ """
38
+ Main function to compute CIDEr score
39
+ : param gts (dict) : {image_id: list of tokenized reference sentences}
40
+ : param res (list) : list of {'image_id', 'caption'} dicts with tokenized candidates
41
+ : return: cider (float) : computed CIDEr score for the corpus
42
+ """
43
+
44
+ # clear all the previous hypos and refs
45
+ self.cider_scorer.clear()
46
+
47
+ for res_id in res:
48
+
49
+ hypo = res_id['caption']
50
+ ref = gts[res_id['image_id']]
51
+
52
+ # Sanity check.
53
+ assert(type(hypo) is list)
54
+ assert(len(hypo) == 1)
55
+ assert(type(ref) is list)
56
+ assert(len(ref) > 0)
57
+ self.cider_scorer += (hypo[0], ref)
58
+
59
+ (score, scores) = self.cider_scorer.compute_score()
60
+
61
+ return score, scores
62
+
63
+ def save_df(self, df_name="corpus"):
64
+ self.cider_scorer.save_df(df_name)
65
+
66
+ def method(self):
67
+ return "CIDEr"
evaluation/cider_score/cidereval/cider/cider_scorer.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/usr/bin/env python
2
+ # Tsung-Yi Lin <tl483@cornell.edu>
3
+ # Ramakrishna Vedantam <vrama91@vt.edu>
4
+
5
+ from pathlib import Path
6
+ from collections import defaultdict
7
+ import pickle
8
+ import math
9
+ from copy import copy
10
+
11
+ from importlib_resources import files, as_file
12
+ import numpy as np
13
+ import cidereval.data
14
+
15
+
16
+ def precook(s, n=4, out=False):
17
+ """
18
+ Takes a string as input and returns an object that can be given to
19
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
20
+ can take string arguments as well.
21
+ :param s: string : sentence to be converted into ngrams
22
+ :param n: int : number of ngrams for which representation is calculated
23
+ :return: term frequency vector for occurring ngrams
24
+ """
25
+ words = s.split()
26
+ counts = defaultdict(int)
27
+ for k in range(1, n + 1):
28
+ for i in range(len(words) - k + 1):
29
+ ngram = tuple(words[i:i + k])
30
+ counts[ngram] += 1
31
+ return counts
32
+
33
+
34
+ def cook_refs(refs, n=4): # lhuang: oracle will call with "average"
35
+ '''Takes a list of reference sentences for a single segment
36
+ and returns an object that encapsulates everything that BLEU
37
+ needs to know about them.
38
+ :param refs: list of string : reference sentences for some image
39
+ :param n: int : number of ngrams for which (ngram) representation is calculated
40
+ :return: result (list of dict)
41
+ '''
42
+ return [precook(ref, n) for ref in refs]
43
+
44
+
45
+ def cook_test(test, n=4):
46
+ '''Takes a test sentence and returns an object that
47
+ encapsulates everything that BLEU needs to know about it.
48
+ :param test: list of string : hypothesis sentence for some image
49
+ :param n: int : number of ngrams for which (ngram) representation is calculated
50
+ :return: result (dict)
51
+ '''
52
+ return precook(test, n, True)
53
+
54
+
55
+ class CiderScorer(object):
56
+ """CIDEr scorer.
57
+ """
58
+
59
+ def save_df(self, df_name="corpus", path=None):
60
+ """Save the idf computed in corpus mode
61
+
62
+ Args:
63
+ df_name (str, optional): [description]. Defaults to "corpus". name of idf file
64
+ (without the file extension). Defaults to "corpus".
65
+
66
+ path (str, optional): directory in which to save the idf; if not provided,
67
+ the home directory is used. Defaults to None.
68
+ Raises:
69
+ ValueError: [description] if you try to call this method before computing the scores
70
+ """
71
+
72
+ if path:
73
+ path = Path(path)
74
+
75
+ if not path.exists():
76
+ path = Path.home()
77
+ print(f"the path provided is not valid. The df will be saved in {path}")
78
+ else:
79
+ path = Path.home()
80
+ print(f"the path provided is not valid. The df will be saved in {path}")
81
+
82
+ filename = Path(path, df_name + '.p')
83
+
84
+ if len(self.document_frequency) > 0:
85
+ with open(filename, "wb") as fp:
86
+
87
+ df_idf = {
88
+ "ref_len" : np.log(float(len(self.crefs))),
89
+ "df": self.document_frequency
90
+ }
91
+
92
+ pickle.dump(df_idf, fp)
93
+ print(f"saved to {filename}")
94
+ else:
95
+ raise ValueError("document frequency not computed run 'compute_score'")
96
+
97
+ def copy(self):
98
+ ''' copy the refs.'''
99
+ new = CiderScorer(n=self.n)
100
+ new.ctest = copy(self.ctest)  # 'copy' is the function imported from the copy module
101
+ new.crefs = copy(self.crefs)
102
+ return new
103
+
104
+ def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
105
+ ''' singular instance '''
106
+ self.n = n
107
+ self.sigma = sigma
108
+ self.crefs = []
109
+ self.ctest = []
110
+ self.ref_len = None
111
+ self.df_mode = df_mode
112
+
113
+ if self.df_mode != "corpus":
114
+ if self.df_mode !="coco-val":
115
+ try:
116
+ with open(self.df_mode, 'rb') as fp:
117
+ df = pickle.load(fp, encoding='iso-8859-1')
118
+ except FileNotFoundError as e:
119
+ print(f"Error retrieveing {self.df_mode}.p df_mode set to 'coco-val'")
120
+
121
+ self.df_mode="coco-val"
122
+ df_path = files(cidereval.data).joinpath(self.df_mode + '.p')
123
+ with as_file(df_path) as res:
124
+ with open(res, 'rb') as fp:
125
+ df = pickle.load(fp, encoding='iso-8859-1')
126
+ else:
127
+ df_path = files(cidereval.data).joinpath(self.df_mode + '.p')
128
+ with as_file(df_path) as res:
129
+ with open(res, 'rb') as fp:
130
+ df = pickle.load(fp, encoding='iso-8859-1')
131
+
132
+ #df_path = os.path.join('data', df_mode + '.p')
133
+ #df = pickle.load(open(os.path.join('data', df_mode + '.p'), 'rb'), encoding='iso-8859-1') # TODO fix path
134
+ self.document_frequency = df['df']
135
+ self.ref_len = df['ref_len']
136
+ self.cook_append(test, refs)
137
+
138
+ def clear(self):
139
+ self.crefs = []
140
+ self.ctest = []
141
+
142
+ def cook_append(self, test, refs):
143
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
144
+
145
+ if refs is not None:
146
+ self.crefs.append(cook_refs(refs))
147
+ if test is not None:
148
+ self.ctest.append(cook_test(test)) ## N.B.: -1
149
+ else:
150
+ self.ctest.append(None) # lens of crefs and ctest have to match
151
+
152
+ def size(self):
153
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
154
+ return len(self.crefs)
155
+
156
+ def __iadd__(self, other):
157
+ '''add an instance (e.g., from another sentence).'''
158
+
159
+ if type(other) is tuple:
160
+ # avoid creating new CiderScorer instances
161
+ self.cook_append(other[0], other[1])
162
+ else:
163
+ self.ctest.extend(other.ctest)
164
+ self.crefs.extend(other.crefs)
165
+
166
+ return self
167
+
168
+ def compute_doc_freq(self):
169
+ '''
170
+ Compute document frequency for the reference data.
171
+ This will be used to compute the idf (inverse document frequency) later.
172
+ The document frequency is stored in the object.
173
+ :return: None
174
+ '''
175
+ for refs in self.crefs:
176
+ # refs, k ref captions of one image
177
+ for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
178
+ self.document_frequency[ngram] += 1
179
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
180
+
181
+ def compute_cider(self):
182
+ def counts2vec(cnts):
183
+ """
184
+ Function maps counts of ngram to vector of tfidf weights.
185
+ The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
186
+ The n-th entry of array denotes length of n-grams.
187
+ :param cnts:
188
+ :return: vec (array of dict), norm (array of float), length (int)
189
+ """
190
+ vec = [defaultdict(float) for _ in range(self.n)]
191
+ length = 0
192
+ norm = [0.0 for _ in range(self.n)]
193
+ for (ngram, term_freq) in cnts.items():
194
+ # give word count 1 if it doesn't appear in reference corpus
195
+ df = np.log(max(1.0, self.document_frequency[ngram]))
196
+ # ngram index
197
+ n = len(ngram) - 1
198
+ # tf (term_freq) * idf (precomputed idf) for n-grams
199
+ vec[n][ngram] = float(term_freq) * (self.ref_len - df)
200
+ # compute norm for the vector. the norm will be used for
201
+ # computing similarity
202
+ norm[n] += pow(vec[n][ngram], 2)
203
+
204
+ if n == 1:
205
+ length += term_freq
206
+ norm = [np.sqrt(n) for n in norm]
207
+ return vec, norm, length
208
+
209
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
210
+ '''
211
+ Compute the cosine similarity of two vectors.
212
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
213
+ :param vec_ref: array of dictionary for vector corresponding to reference
214
+ :param norm_hyp: array of float for vector corresponding to hypothesis
215
+ :param norm_ref: array of float for vector corresponding to reference
216
+ :param length_hyp: int containing length of hypothesis
217
+ :param length_ref: int containing length of reference
218
+ :return: array of score for each n-grams cosine similarity
219
+ '''
220
+ delta = float(length_hyp - length_ref)
221
+ # measure cosine similarity
222
+ val = np.array([0.0 for _ in range(self.n)])
223
+ for n in range(self.n):
224
+ # ngram
225
+ for (ngram, count) in vec_hyp[n].items():
226
+ val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram]
227
+
228
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
229
+ val[n] /= (norm_hyp[n] * norm_ref[n])
230
+
231
+ assert (not math.isnan(val[n]))
232
+ return val
233
+
234
+ # compute log reference length
235
+ if self.df_mode == "corpus":
236
+ self.ref_len = np.log(float(len(self.crefs)))
237
+ #elif self.df_mode == "coco-val":
238
+ # if coco option selected, use length of coco-val set
239
+ #self.ref_len = np.log(float(40504))
240
+
241
+
242
+
243
+ scores = []
244
+ for test, refs in zip(self.ctest, self.crefs):
245
+ # compute vector for test captions
246
+ vec, norm, length = counts2vec(test)
247
+ # compute vector for ref captions
248
+ score = np.array([0.0 for _ in range(self.n)])
249
+ for ref in refs:
250
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
251
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
252
+ # change by vrama91 - mean of ngram scores, instead of sum
253
+ score_avg = np.mean(score)
254
+ # divide by number of references
255
+ score_avg /= len(refs)
256
+ # multiply score by 10
257
+ score_avg *= 10.0
258
+ # append score of an image to the score list
259
+ scores.append(score_avg)
260
+ return scores
261
+
262
+ def compute_score(self, option=None, verbose=0):
263
+ # compute idf
264
+ if self.df_mode == "corpus":
265
+ self.document_frequency = defaultdict(float)
266
+ self.compute_doc_freq()
267
+ # assert to check document frequency
268
+ assert (len(self.ctest) >= max(self.document_frequency.values()))
269
+ # import json for now and write the corresponding files
270
+ # compute cider score
271
+ score = self.compute_cider()
272
+ # debug
273
+ # print score
274
+ return np.mean(np.array(score)), np.array(score)
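
To make the scoring above concrete: `precook` turns a sentence into 1- to 4-gram counts, and `counts2vec` then weights each count by `tf * (log(len(crefs)) - log(df))` to form the tf-idf vectors that `sim` compares. A quick illustration of the counting step, on a hypothetical sentence:

```python
from cidereval.cider.cider_scorer import precook

counts = precook("the cat sat on the mat", n=2)
print(counts[("the",)])        # 2  -> unigram term frequency
print(counts[("the", "cat")])  # 1  -> bigram term frequency
```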
evaluation/cider_score/cidereval/ciderD/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
evaluation/cider_score/cidereval/ciderD/ciderD.py ADDED
@@ -0,0 +1,57 @@
1
+ # Filename: ciderD.py
2
+ #
3
+ # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric
4
+ # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5
+ #
6
+ # Creation Date: Sun Feb 8 14:16:54 2015
7
+ #
8
+ # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
9
+
10
+ from .ciderD_scorer import CiderScorer
11
+ import pdb
12
+
13
+ class CiderD:
14
+ """
15
+ Main Class to compute the CIDEr-D metric
16
+
17
+ """
18
+ def __init__(self, n=4, sigma=6.0, df="corpus"):
19
+ # set cider to sum over 1 to 4-grams
20
+ self._n = n
21
+ # set the standard deviation parameter for gaussian penalty
22
+ self._sigma = sigma
23
+ # set which where to compute document frequencies from
24
+ self._df = df
25
+ self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
26
+
27
+ def compute_score(self, gts, res):
28
+ """
29
+ Main function to compute CIDEr score
30
+ :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
31
+ ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
32
+ :return: cider (float) : computed CIDEr score for the corpus
33
+ """
34
+
35
+ # clear all the previous hypos and refs
36
+ self.cider_scorer.clear()
37
+ for res_id in res:
38
+
39
+ hypo = res_id['caption']
40
+ ref = gts[res_id['image_id']]
41
+
42
+ # Sanity check.
43
+ assert(type(hypo) is list)
44
+ assert(len(hypo) == 1)
45
+ assert(type(ref) is list)
46
+ assert(len(ref) > 0)
47
+ self.cider_scorer += (hypo[0], ref)
48
+
49
+ (score, scores) = self.cider_scorer.compute_score()
50
+
51
+ return score, scores
52
+
53
+ def save_df(self, df_name="corpus"):
54
+ self.cider_scorer.save_df(df_name)
55
+
56
+ def method(self):
57
+ return "CIDEr-D"
evaluation/cider_score/cidereval/ciderD/ciderD_scorer.py ADDED
@@ -0,0 +1,265 @@
1
+ #!/usr/bin/env python
2
+ # Tsung-Yi Lin <tl483@cornell.edu>
3
+ # Ramakrishna Vedantam <vrama91@vt.edu>
4
+
5
+ # edited by Michele Cafagna
6
+
7
+ from pathlib import Path
8
+ from collections import defaultdict
9
+ import pickle
10
+ import math
11
+ from copy import copy
12
+
13
+ from importlib_resources import files, as_file
14
+ import numpy as np
15
+ import cidereval.data
16
+
17
+ def precook(s, n=4, out=False):
18
+ """
19
+ Takes a string as input and returns an object that can be given to
20
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
21
+ can take string arguments as well.
22
+ :param s: string : sentence to be converted into ngrams
23
+ :param n: int : number of ngrams for which representation is calculated
24
+ :return: term frequency vector for occurring ngrams
25
+ """
26
+ words = s.split()
27
+ counts = defaultdict(int)
28
+ for k in range(1,n+1):
29
+ for i in range(len(words)-k+1):
30
+ ngram = tuple(words[i:i+k])
31
+ counts[ngram] += 1
32
+ return counts
33
+
34
+ def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
35
+ '''Takes a list of reference sentences for a single segment
36
+ and returns an object that encapsulates everything that BLEU
37
+ needs to know about them.
38
+ :param refs: list of string : reference sentences for some image
39
+ :param n: int : number of ngrams for which (ngram) representation is calculated
40
+ :return: result (list of dict)
41
+ '''
42
+ return [precook(ref, n) for ref in refs]
43
+
44
+ def cook_test(test, n=4):
45
+ '''Takes a test sentence and returns an object that
46
+ encapsulates everything that BLEU needs to know about it.
47
+ :param test: list of string : hypothesis sentence for some image
48
+ :param n: int : number of ngrams for which (ngram) representation is calculated
49
+ :return: result (dict)
50
+ '''
51
+ return precook(test, n, True)
52
+
53
+ class CiderScorer(object):
54
+ """CIDEr scorer.
55
+ """
56
+
57
+ def save_df(self, df_name="corpus", path=None):
58
+ """Save the idf computed in corpus mode
59
+
60
+ Args:
61
+ df_name (str, optional): [description]. Defaults to "corpus". name of idf file
62
+ (without the file extension). Defaults to "corpus".
63
+
64
+ path (str, optional): directory in which to save the idf; if not provided,
65
+ the home directory is used. Defaults to None.
66
+ Raises:
67
+ ValueError: [description] if you try to call this method before computing the scores
68
+ """
69
+
70
+ if path:
71
+ path = Path(path)
72
+
73
+ if not path.exists():
74
+ path = Path.home()
75
+ print(f"the path provided is not valid. The df will be saved in {path}")
76
+ else:
77
+ path = Path.home()
78
+ print(f"the path provided is not valid. The df will be saved in {path}")
79
+
80
+ filename = Path(path, df_name + '.p')
81
+
82
+ if len(self.document_frequency) > 0:
83
+ with open(filename, "wb") as fp:
84
+
85
+ df_idf = {
86
+ "ref_len" : np.log(float(len(self.crefs))),
87
+ "df": self.document_frequency
88
+ }
89
+
90
+ pickle.dump(df_idf, fp)
91
+ print(f"saved to {filename}")
92
+ else:
93
+ raise ValueError("document frequency not computed run 'compute_score'")
94
+
95
+
96
+ def copy(self):
97
+ ''' copy the refs.'''
98
+ new = CiderScorer(n=self.n)
99
+ new.ctest = copy(self.ctest)  # 'copy' is the function imported from the copy module
100
+ new.crefs = copy(self.crefs)
101
+ return new
102
+
103
+ def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
104
+ ''' singular instance '''
105
+ self.n = n
106
+ self.sigma = sigma
107
+ self.crefs = []
108
+ self.ctest = []
109
+ self.df_mode = df_mode
110
+ self.ref_len = None
111
+
112
+ if self.df_mode != "corpus":
113
+ if self.df_mode !="coco-val":
114
+ try:
115
+ with open(self.df_mode, 'rb') as fp:
116
+ df = pickle.load(fp, encoding='iso-8859-1')
117
+ except FileNotFoundError as e:
118
+ print(f"Error retrieveing {self.df_mode}. df_mode set to 'coco-val'")
119
+ else:
120
+ df_path = files(cidereval.data).joinpath(self.df_mode + '.p')
121
+ with as_file(df_path) as res:
122
+ with open(res, 'rb') as fp:
123
+ df = pickle.load(fp, encoding='iso-8859-1')
124
+
125
+ self.document_frequency = df['df']
126
+ self.ref_len = df['ref_len']
127
+
128
+ self.cook_append(test, refs)
129
+
130
+ def clear(self):
131
+ self.crefs = []
132
+ self.ctest = []
133
+
134
+ def cook_append(self, test, refs):
135
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
136
+
137
+ if refs is not None:
138
+ self.crefs.append(cook_refs(refs))
139
+ if test is not None:
140
+ self.ctest.append(cook_test(test)) ## N.B.: -1
141
+ else:
142
+ self.ctest.append(None) # lens of crefs and ctest have to match
143
+
144
+ def size(self):
145
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
146
+ return len(self.crefs)
147
+
148
+ def __iadd__(self, other):
149
+ '''add an instance (e.g., from another sentence).'''
150
+
151
+ if type(other) is tuple:
152
+ ## avoid creating new CiderScorer instances
153
+ self.cook_append(other[0], other[1])
154
+ else:
155
+ self.ctest.extend(other.ctest)
156
+ self.crefs.extend(other.crefs)
157
+
158
+ return self
159
+ def compute_doc_freq(self):
160
+ '''
161
+ Compute document frequency for the reference data.
162
+ This will be used to compute the idf (inverse document frequency) later.
163
+ The document frequency is stored in the object.
164
+ :return: None
165
+ '''
166
+ for refs in self.crefs:
167
+ # refs, k ref captions of one image
168
+ for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
169
+ self.document_frequency[ngram] += 1
170
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
171
+
172
+ def compute_cider(self):
173
+ def counts2vec(cnts):
174
+ """
175
+ Function maps counts of ngram to vector of tfidf weights.
176
+ The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
177
+ The n-th entry of array denotes length of n-grams.
178
+ :param cnts:
179
+ :return: vec (array of dict), norm (array of float), length (int)
180
+ """
181
+ vec = [defaultdict(float) for _ in range(self.n)]
182
+ length = 0
183
+ norm = [0.0 for _ in range(self.n)]
184
+ for (ngram,term_freq) in cnts.items():
185
+ # give word count 1 if it doesn't appear in reference corpus
186
+ df = np.log(max(1.0, self.document_frequency[ngram]))
187
+ # ngram index
188
+ n = len(ngram)-1
189
+ # tf (term_freq) * idf (precomputed idf) for n-grams
190
+ vec[n][ngram] = float(term_freq)*(self.ref_len - df)
191
+ # compute norm for the vector. the norm will be used for computing similarity
192
+ norm[n] += pow(vec[n][ngram], 2)
193
+
194
+ if n == 1:
195
+ length += term_freq
196
+ norm = [np.sqrt(n) for n in norm]
197
+ return vec, norm, length
198
+
199
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
200
+ '''
201
+ Compute the cosine similarity of two vectors.
202
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
203
+ :param vec_ref: array of dictionary for vector corresponding to reference
204
+ :param norm_hyp: array of float for vector corresponding to hypothesis
205
+ :param norm_ref: array of float for vector corresponding to reference
206
+ :param length_hyp: int containing length of hypothesis
207
+ :param length_ref: int containing length of reference
208
+ :return: array of score for each n-grams cosine similarity
209
+ '''
210
+ delta = float(length_hyp - length_ref)
211
+ # measure cosine similarity
212
+ val = np.array([0.0 for _ in range(self.n)])
213
+ for n in range(self.n):
214
+ # ngram
215
+ for (ngram,count) in vec_hyp[n].items():
216
+ # vrama91 : added clipping
217
+ val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
218
+
219
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
220
+ val[n] /= (norm_hyp[n]*norm_ref[n])
221
+
222
+ assert(not math.isnan(val[n]))
223
+ # vrama91: added a length based gaussian penalty
224
+ val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
225
+ return val
226
+
227
+ # compute log reference length
228
+ if self.df_mode == "corpus":
229
+ self.ref_len = np.log(float(len(self.crefs)))
230
+ #elif self.df_mode == "coco-val":
231
+ # if coco option selected, use length of coco-val set
232
+ #self.ref_len = np.log(float(40504))
233
+
234
+ scores = []
235
+ for test, refs in zip(self.ctest, self.crefs):
236
+ # compute vector for test captions
237
+ vec, norm, length = counts2vec(test)
238
+ # compute vector for ref captions
239
+ score = np.array([0.0 for _ in range(self.n)])
240
+ for ref in refs:
241
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
242
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
243
+ # change by vrama91 - mean of ngram scores, instead of sum
244
+ score_avg = np.mean(score)
245
+ # divide by number of references
246
+ score_avg /= len(refs)
247
+ # multiply score by 10
248
+ score_avg *= 10.0
249
+ # append score of an image to the score list
250
+ scores.append(score_avg)
251
+ return scores
252
+
253
+ def compute_score(self, option=None, verbose=0):
254
+ # compute idf
255
+ if self.df_mode == "corpus":
256
+ self.document_frequency = defaultdict(float)
257
+ self.compute_doc_freq()
258
+ # assert to check document frequency
259
+ assert(len(self.ctest) >= max(self.document_frequency.values()))
260
+ # import json for now and write the corresponding files
261
+ # compute cider score
262
+ score = self.compute_cider()
263
+ # debug
264
+ # print score
265
+ return np.mean(np.array(score)), np.array(score)
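
CIDEr-D differs from plain CIDEr inside `sim` above in two ways: the hypothesis counts are clipped against the reference (`min(vec_hyp, vec_ref)`), and each n-gram score is damped by a Gaussian length penalty. The penalty shown in isolation (a sketch; `sigma` defaults to 6.0 as in the constructor):

```python
import numpy as np

def length_penalty(length_hyp: int, length_ref: int, sigma: float = 6.0) -> float:
    # e^(-delta^2 / (2 * sigma^2)), as applied per n-gram order in `sim`
    delta = float(length_hyp - length_ref)
    return float(np.e ** (-(delta ** 2) / (2 * sigma ** 2)))

print(length_penalty(10, 10))  # 1.0   (equal lengths, no penalty)
print(length_penalty(10, 22))  # ~0.14 (large length mismatches are damped)
```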
evaluation/cider_score/cidereval/data/__init__.py ADDED
File without changes
evaluation/cider_score/cidereval/data/coco-val.p ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48e79470ebb58f0251df49dca5c9976a726a9e41b4a1d82be332b6f676b6950a
3
+ size 73520211
evaluation/cider_score/cidereval/eval.py ADDED
@@ -0,0 +1,40 @@
1
+ __author__ = 'rama'
2
+ from .tokenizer.ptbtokenizer import PTBTokenizer
3
+ from .cider.cider import Cider
4
+ from .ciderD.ciderD import CiderD
5
+
6
+
7
+ class CIDErEvalCap:
8
+ def __init__(self, gts, res, df):
9
+ print('tokenization...')
10
+ tokenizer = PTBTokenizer('gts')
11
+ _gts = tokenizer.tokenize(gts)
12
+ print('tokenized refs')
13
+ tokenizer = PTBTokenizer('res')
14
+ _res = tokenizer.tokenize(res)
15
+ print('tokenized cands')
16
+
17
+ self.gts = _gts
18
+ self.res = _res
19
+ self.df = df
20
+
21
+ def evaluate(self):
22
+ # =================================================
23
+ # Set up scorers
24
+ # =================================================
25
+
26
+ print('setting up scorers...')
27
+ scorers = [
28
+ (Cider(df=self.df), "CIDEr"), (CiderD(df=self.df), "CIDErD")
29
+ ]
30
+
31
+ # =================================================
32
+ # Compute scores
33
+ # =================================================
34
+ metric_scores = {}
35
+ for scorer, method in scorers:
36
+ print('computing %s score...' % (scorer.method()))
37
+ score, scores = scorer.compute_score(self.gts, self.res)
38
+ print("Mean %s score: %0.3f" % (method, score))
39
+ metric_scores[method] = list(scores)
40
+ return metric_scores
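
A hypothetical usage sketch of this wrapper. Note that it tokenizes through `PTBTokenizer`, so it needs Java plus the CoreNLP jar (see `tokenizer/ptbtokenizer.py`); the demo notebook sidesteps this by using the spacy-based tokenizer instead.

```python
from cidereval.eval import CIDErEvalCap

gts = {"img0": [{"caption": "a dog runs on the grass"},
                {"caption": "a dog is running outside"}]}
res = [{"image_id": "img0", "caption": "a dog runs on the grass"}]

evaluator = CIDErEvalCap(gts, res, df="corpus")
metric_scores = evaluator.evaluate()   # {"CIDEr": [...], "CIDErD": [...]}
```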
evaluation/cider_score/cidereval/scorers.py ADDED
@@ -0,0 +1,76 @@
1
+ from cidereval import CiderD, Cider
2
+ from cidereval.tokenizer import PTBTokenizer
3
+ from cidereval.tokenizer import SimpleTokenizer
4
+
5
+ def _preprocess_for_cider(refs, preds):
6
+ r"""
7
+ Convert preds and refs to the cider data format
8
+
9
+ refs: List[List[str]]
10
+ preds : List[str]
11
+
12
+ return gts: Dict[str : List[Dict['caption':str] : str ]],
13
+ res: List[Dict['image_id':str]: 'caption':str]
14
+ """
15
+
16
+ assert len(refs) == len(preds)
17
+
18
+ gts = {}
19
+ res = []
20
+
21
+ for i, (caps, pred) in enumerate(zip(refs, preds)):
22
+ gts[i] = [{ 'caption': cap } for cap in caps ]
23
+
24
+ res.append({ 'image_id': i,
25
+ 'caption': pred})
26
+ return gts, res
27
+
28
+ def cider(predictions, references, df="coco-val"):
29
+ r"""
30
+ Compute the cider score for the given predictions and references
31
+
32
+ predictions : List[str], model's predictions
33
+ references: List[List[str]], references
34
+ df: str, either 'coco-val' or 'corpus' (default: 'coco-val'). If 'coco-val', the TF-IDF precomputed on the COCO validation split is \\
35
+ used. If 'corpus', the TF-IDF is computed over the provided reference set.
36
+
37
+ returns {"avg_score": mp.float, "scores": np.array(np.float)}
38
+ """
39
+ gts, res = _preprocess_for_cider(references, predictions)
40
+ tokenizer_res = SimpleTokenizer('res')
41
+ tokenizer_gts = SimpleTokenizer('gts')
42
+
43
+ _gts = tokenizer_gts.tokenize(gts)
44
+ _res = tokenizer_res.tokenize(res)
45
+
46
+ scorer = Cider(df=df)
47
+
48
+ score, scores = scorer.compute_score(_gts, _res)
49
+
50
+ return {"avg_score": score, "scores": scores}
51
+
52
+
53
+ def ciderD(predictions, references, df="coco-va"):
54
+ r"""
55
+ Compute the ciderD score for the given predictions and references
56
+
57
+ predictions : List[str], model's predictions
58
+ references: List[List[str]], references
59
+ df: str, either 'coco-val' or 'corpus' (default: 'coco-val'). If 'coco-val', the TF-IDF precomputed on the COCO validation split is \\
60
+ used. If 'corpus', the TF-IDF is computed over the provided reference set.
61
+
62
+ returns {"avg_score": mp.float, "scores": np.array(np.float)}
63
+ """
64
+
65
+ gts, res = _preprocess_for_cider(references, predictions)
66
+ tokenizer_res = SimpleTokenizer('res')
67
+ tokenizer_gts = SimpleTokenizer('gts')
68
+
69
+ _gts = tokenizer_gts.tokenize(gts)
70
+ _res = tokenizer_res.tokenize(res)
71
+
72
+ scorer = CiderD(df=df)
73
+
74
+ score, scores = scorer.compute_score(_gts, _res)
75
+
76
+ return { "avg_score": score, "scores": scores}
evaluation/cider_score/cidereval/tokenizer/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ __author__ = 'hfang'
2
+ # edited by Michele Cafagna
3
+ from cidereval.tokenizer.ptbtokenizer import PTBTokenizer
4
+ from cidereval.tokenizer.simpletokenizer import SimpleTokenizer
evaluation/cider_score/cidereval/tokenizer/ptbtokenizer.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : ptbtokenizer.py
4
+ #
5
+ # Description : Do the PTB Tokenization and remove punctuations.
6
+ #
7
+ # Creation Date : 29-12-2014
8
+ # Last Modified : Thu Mar 19 09:53:35 2015
9
+ # Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
10
+
11
+ import os
12
+ import subprocess
13
+ import tempfile
14
+
15
+ # path to the stanford corenlp jar
16
+ STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
17
+
18
+ # punctuations to be removed from the sentences
19
+ PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
20
+ ".", "?", "!", ",", ":", "-", "--", "...", ";"]
21
+
22
+
23
+ class PTBTokenizer:
24
+ """Python wrapper of Stanford PTBTokenizer"""
25
+
26
+ def __init__(self, _source='gts'):
27
+ self.source = _source
28
+
29
+ def tokenize(self, captions_for_image):
30
+ """Tokenize a sample
31
+
32
+ Args:
33
+ captions_for_image :
34
+
35
+ IF _source='gts' follows format:
36
+ dict: { str : [
37
+ { "caption" : str },
38
+ { "caption" : str },
39
+ ...
40
+ ],
41
+ str : [ ... ],
42
+ ...
43
+ }
44
+ IF _source='res' follows format:
45
+ list: [ {"image_id" : str,
46
+ "caption" : str,
47
+ },
48
+ ...
49
+ ]
50
+ Returns:
51
+ final_tokenized_captions_for_index:
52
+ list: [ {"image_id" : str,
53
+ "caption" : str,
54
+ },
55
+ ...
56
+ ]
57
+ """
58
+ cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR,
59
+ 'edu.stanford.nlp.process.PTBTokenizer',
60
+ '-preserveLines', '-lowerCase']
61
+
62
+ # ======================================================
63
+ # prepare data for PTB Tokenizer
64
+ # ======================================================
65
+
66
+ if self.source == 'gts':
67
+ image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
68
+ sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
69
+ final_tokenized_captions_for_image = {}
70
+
71
+ elif self.source == 'res':
72
+ index = [i for i, v in enumerate(captions_for_image)]
73
+ image_id = [v["image_id"] for v in captions_for_image]
74
+ sentences = '\n'.join(v["caption"].replace('\n', ' ') for v in captions_for_image)
75
+ final_tokenized_captions_for_index = []
76
+
77
+ # ======================================================
78
+ # save sentences to temporary file
79
+ # ======================================================
80
+ path_to_jar_dir_name = os.path.dirname(os.path.abspath(__file__))
81
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dir_name, mode='w')
82
+ tmp_file.write(sentences)
83
+ tmp_file.close()
84
+
85
+ # ======================================================
86
+ # tokenize sentence
87
+ # ======================================================
88
+ cmd.append(os.path.basename(tmp_file.name))
89
+ p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dir_name, stdout=subprocess.PIPE)
90
+ token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0].decode("utf-8")
91
+ lines = token_lines.split('\n')
92
+ # remove temp file
93
+ os.remove(tmp_file.name)
94
+
95
+ # ======================================================
96
+ # create dictionary for tokenized captions
97
+ # ======================================================
98
+ if self.source == 'gts':
99
+ for k, line in zip(image_id, lines):
100
+ if k not in final_tokenized_captions_for_image:
101
+ final_tokenized_captions_for_image[k] = []
102
+ tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') if w not in PUNCTUATIONS])
103
+ final_tokenized_captions_for_image[k].append(tokenized_caption)
104
+
105
+ return final_tokenized_captions_for_image
106
+
107
+ elif self.source == 'res':
108
+ for k, img, line in zip(index, image_id, lines):
109
+ tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') if w not in PUNCTUATIONS])
110
+ final_tokenized_captions_for_index.append({'image_id': img, 'caption': [tokenized_caption]})
111
+
112
+ return final_tokenized_captions_for_index
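
This wrapper shells out to Java, so `stanford-corenlp-3.4.1.jar` must sit next to this module and a JRE must be on the PATH. A hypothetical call on the 'res' format, assuming that setup:

```python
from cidereval.tokenizer.ptbtokenizer import PTBTokenizer

res = [{"image_id": "img0", "caption": "A dog runs, quickly!"}]
tokenized = PTBTokenizer('res').tokenize(res)
# -> [{'image_id': 'img0', 'caption': ['a dog runs quickly']}]
#    (lowercased, punctuation stripped)
```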
evaluation/cider_score/cidereval/tokenizer/simpletokenizer.py ADDED
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : simpletokenizer.py
4
+ #
5
+ # Description : Yet another tokenizer.
6
+ #
7
+ # Creation Date : 12-11-2021
8
+
9
+ import spacy
10
+ from spacy.lang.char_classes import ALPHA, ALPHA_LOWER, ALPHA_UPPER
11
+ from spacy.lang.char_classes import CONCAT_QUOTES, LIST_ELLIPSES, LIST_ICONS
12
+ from spacy.util import compile_infix_regex
13
+
14
+
15
+ # punctuations to be removed from the sentences
16
+ PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
17
+ ".", "?", "!", ",", ":", "-", "--", "...", ";", " ", ""]
18
+
19
+ infixes = (
20
+ LIST_ELLIPSES
21
+ + LIST_ICONS
22
+ + [
23
+ r"(?<=[0-9])[+\-\*^](?=[0-9-])",
24
+ r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(
25
+ al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES
26
+ ),
27
+ r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
28
+ # ✅ Commented out regex that splits on hyphens between letters:
29
+ # r"(?<=[{a}])(?:{h})(?=[{a}])".format(a=ALPHA, h=HYPHENS),
30
+ r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA),
31
+ ]
32
+ )
33
+
34
+
35
+ class SimpleTokenizer:
36
+ """Simple Tokenizer"""
37
+
38
+ def __init__(self, _source='gts'):
39
+ self.source = _source
40
+
41
+ # setting up the tokenizer
42
+ self._nlp = spacy.load("en_core_web_sm")
43
+ infix_re = compile_infix_regex(infixes)
44
+ self._nlp.tokenizer.infix_finditer = infix_re.finditer
45
+ self._tokenizer = self._nlp.tokenizer
46
+
47
+ def tokenize(self, captions_for_image):
48
+ """Tokenize a sample
49
+
50
+ Args:
51
+ captions_for_image :
52
+
53
+ IF _source='gts' follows format:
54
+ dict: { str : [
55
+ { "caption" : str },
56
+ { "caption" : str },
57
+ ...
58
+ ],
59
+ str : [ ... ],
60
+ ...
61
+ }
62
+ IF _source='res' follows format:
63
+ list: [ {"image_id" : str,
64
+ "caption" : str,
65
+ },
66
+ ...
67
+ ]
68
+ Returns:
69
+ final_tokenized_captions_for_index:
70
+ list: [ {"image_id" : str,
71
+ "caption" : str,
72
+ },
73
+ ...
74
+ ]
75
+ """
76
+
77
+ tokenized_captions = None
78
+
79
+ if self.source == 'gts':
80
+ tokenized_captions= {}
81
+
82
+ for k in captions_for_image:
83
+
84
+ if k not in tokenized_captions:
85
+ tokenized_captions[k] = []
86
+
87
+ for item in captions_for_image[k]:
88
+
89
+ tokenized_captions[k].append(
90
+ " ".join([ tok.text.lower().strip() for tok in self._tokenizer(item['caption']) if tok.text.lower().strip() not in PUNCTUATIONS]))
91
+
92
+ elif self.source == 'res':
93
+
94
+ tokenized_captions= []
95
+
96
+ for item in captions_for_image:
97
+
98
+ tokenized_captions.append(
99
+ { 'image_id' : item['image_id'],
100
+ 'caption' : [" ".join([ tok.text.lower().strip() for tok in self._tokenizer(item['caption']) if tok.text.lower().strip() not in PUNCTUATIONS])]
101
+ })
102
+
103
+ else:
104
+ ValueError("source can be either 'gts' or 'res' ")
105
+
106
+ return tokenized_captions
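
This is the Java-free replacement used by `scorers.py`; it requires the `en_core_web_sm` spacy model (installed in the demo notebook). A sketch of the same call as above:

```python
from cidereval.tokenizer.simpletokenizer import SimpleTokenizer

res = [{"image_id": "img0", "caption": "A dog runs, quickly!"}]
tokenized = SimpleTokenizer('res').tokenize(res)
# -> [{'image_id': 'img0', 'caption': ['a dog runs quickly']}]
```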
evaluation/cider_score/output_sample.xls ADDED
Binary file (190 kB). View file
 
filter_dataset.py ADDED
@@ -0,0 +1,43 @@
1
+ from datasets import load_dataset, DatasetDict
2
+ from PIL import Image, ImageFile, UnidentifiedImageError
3
+ import io
4
+ from tqdm import tqdm
5
+
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ # Assuming your dataset is loaded using load_dataset
9
+ cache_dir = "/bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache"
10
+ dataset_name = "CNX-PathLLM/Pathcap"
11
+ dataset = load_dataset(dataset_name, split="train", cache_dir=cache_dir)
12
+
13
+ print(f"original dataset size: {len(dataset)}")
14
+
15
+ # keep valid indices
16
+ valid_indices = []
17
+
18
+ # go through and check every element
19
+ for idx in tqdm(range(len(dataset))):
20
+ try:
21
+ example = dataset[idx]
22
+
23
+ text = example["txt"]
24
+ if not isinstance(text, str):
25
+ raise ValueError(f"not a string: {text}")
26
+ valid_indices.append(idx)
27
+ except Exception as e:
28
+ print(f"Cannot recognize file {idx}: {e}")
29
+
30
+ # Select valid samples according to the indices of valid samples.
31
+ filtered_dataset = dataset.select(valid_indices)
32
+
33
+ # Filter out images that cannot be loaded.
34
+ # filtered_dataset = dataset.filter(lambda example: example["is_valid"])
35
+
36
+ # Print the size of the filtered dataset
37
+ print(f"filtered dataset size: {len(filtered_dataset)}")
38
+
39
+ if len(dataset) != len(filtered_dataset):
40
+ # convert to DatasetDict
41
+ filtered_dataset_dict = DatasetDict({"train": filtered_dataset})
42
+ # push to hub
43
+ filtered_dataset_dict.push_to_hub(dataset_name)
gigapath/__init__.py ADDED
File without changes
gigapath/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
gigapath/__pycache__/pos_embed.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
gigapath/__pycache__/slide_encoder.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
gigapath/__pycache__/slide_encoder_vision.cpython-310.pyc ADDED
Binary file (8.31 kB). View file
 
gigapath/classification_head.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+
3
+ from torch import nn
4
+ from . import slide_encoder
5
+
6
+
7
+ def reshape_input(imgs, coords, pad_mask=None):
8
+ if len(imgs.shape) == 4:
9
+ imgs = imgs.squeeze(0)
10
+ if len(coords.shape) == 4:
11
+ coords = coords.squeeze(0)
12
+ if pad_mask is not None:
13
+ if len(pad_mask.shape) != 2:
14
+ pad_mask = pad_mask.squeeze(0)
15
+ return imgs, coords, pad_mask
16
+
17
+
18
+ class ClassificationHead(nn.Module):
19
+ """
20
+ The classification head for the slide encoder
21
+
22
+ Arguments:
23
+ ----------
24
+ input_dim: int
25
+ The input dimension of the slide encoder
26
+ latent_dim: int
27
+ The latent dimension of the slide encoder
28
+ feat_layer: str
29
+ The layers from which embeddings are fed to the classifier, e.g., 5-11 for taking out the 5th and 11th layers
30
+ n_classes: int
31
+ The number of classes
32
+ model_arch: str
33
+ The architecture of the slide encoder
34
+ pretrained: str
35
+ The path to the pretrained slide encoder
36
+ freeze: bool
37
+ Whether to freeze the pretrained model
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ input_dim,
43
+ latent_dim,
44
+ feat_layer,
45
+ n_classes=2,
46
+ model_arch="gigapath_slide_enc12l768d",
47
+ pretrained="hf_hub:prov-gigapath/prov-gigapath",
48
+ freeze=False,
49
+ **kwargs,
50
+ ):
51
+ super(ClassificationHead, self).__init__()
52
+
53
+ # setup the slide encoder
54
+ self.feat_layer = [eval(x) for x in feat_layer.split("-")]
55
+ self.feat_dim = len(self.feat_layer) * latent_dim
56
+ self.slide_encoder = slide_encoder.create_model(pretrained, model_arch, in_chans=input_dim, **kwargs)
57
+
58
+ # whether to freeze the pretrained model
59
+ if freeze:
60
+ print("Freezing Pretrained GigaPath model")
61
+ for name, param in self.slide_encoder.named_parameters():
62
+ param.requires_grad = False
63
+ print("Done")
64
+ # setup the classifier
65
+ self.classifier = nn.Sequential(*[nn.Linear(self.feat_dim, n_classes)])
66
+
67
+ def forward(self, images: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
68
+ """
69
+ Arguments:
70
+ ----------
71
+ images: torch.Tensor
72
+ The input images with shape [N, L, D]
73
+ coords: torch.Tensor
74
+ The input coordinates with shape [N, L, 2]
75
+ """
76
+ # inputs: [N, L, D]
77
+ if len(images.shape) == 2:
78
+ images = images.unsqueeze(0)
79
+ assert len(images.shape) == 3
80
+ # forward GigaPath slide encoder
81
+ img_enc = self.slide_encoder.forward(images, coords, all_layer_embed=True)
82
+ img_enc = [img_enc[i] for i in self.feat_layer]
83
+ img_enc = torch.cat(img_enc, dim=-1)
84
+ # classifier
85
+ h = img_enc.reshape([-1, img_enc.size(-1)])
86
+ logits = self.classifier(h)
87
+ return logits
88
+
89
+
90
+ def get_model(**kwargs):
91
+ model = ClassificationHead(**kwargs)
92
+ return model
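
A construction sketch under the defaults above (`model_arch="gigapath_slide_enc12l768d"`, pretrained slide encoder pulled from the Hugging Face hub); the tensor shapes follow the `forward` docstring, and the random inputs are placeholders:

```python
import torch
from gigapath.classification_head import get_model

model = get_model(input_dim=1536,       # tile embedding dimension
                  latent_dim=768,       # slide encoder latent dimension
                  feat_layer="5-11",    # concatenate layer-5 and layer-11 embeddings
                  n_classes=2,
                  freeze=True)          # keep the pretrained slide encoder frozen

images = torch.randn(1, 100, 1536)      # [N, L, D] tile embeddings
coords = torch.rand(1, 100, 2) * 10000  # [N, L, 2] tile coordinates
logits = model(images, coords)          # [N, n_classes]
```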
gigapath/pipeline.py ADDED
@@ -0,0 +1,190 @@
1
+ # --------------------------------------------------------
2
+ # Pipeline for running with GigaPath
3
+ # --------------------------------------------------------
4
+ import os
5
+ import timm
6
+ import torch
7
+ import shutil
8
+ import numpy as np
9
+ import pandas as pd
10
+ import gigapath.slide_encoder as slide_encoder
11
+
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ from pathlib import Path
15
+ from torchvision import transforms
16
+ from typing import List, Tuple, Union
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from gigapath.preprocessing.data.create_tiles_dataset import process_slide
19
+
20
+
21
+ class TileEncodingDataset(Dataset):
22
+ """
23
+ Do encoding for tiles
24
+
25
+ Arguments:
26
+ ----------
27
+ image_paths : List[str]
28
+ List of image paths, each image is named with its coordinates
29
+ Example: ['images/256x_256y.png', 'images/256x_512y.png']
30
+ transform : torchvision.transforms.Compose
31
+ Transform to apply to each image
32
+ """
33
+ def __init__(self, image_paths: List[str], transform=None):
34
+ self.transform = transform
35
+ self.image_paths = image_paths
36
+
37
+ def __len__(self):
38
+ return len(self.image_paths)
39
+
40
+ def __getitem__(self, idx):
41
+ img_path = self.image_paths[idx]
42
+ img_name = os.path.basename(img_path)
43
+ # get x, y coordinates from the image name
44
+ x, y = img_name.split('.png')[0].split('_')
45
+ x, y = int(x.replace('x', '')), int(y.replace('y', ''))
46
+ # load the image
47
+ with open(img_path, "rb") as f:
48
+ img = Image.open(f).convert("RGB")
49
+ if self.transform:
50
+ img = self.transform(img)
51
+ return {'img': torch.from_numpy(np.array(img)),
52
+ 'coords': torch.from_numpy(np.array([x, y])).float()}
53
+
54
+
55
+ def tile_one_slide(slide_file:str='', save_dir:str='', level:int=0, tile_size:int=256):
56
+ """
57
+ This function is used to tile a single slide and save the tiles to a directory.
58
+ -------------------------------------------------------------------------------
59
+ Warning: pixman 0.38 has a known bug that produces partially broken images.
60
+ Make sure to use a different version of pixman.
61
+ -------------------------------------------------------------------------------
62
+
63
+ Arguments:
64
+ ----------
65
+ slide_file : str
66
+ The path to the slide file.
67
+ save_dir : str
68
+ The directory to save the tiles.
69
+ level : int
70
+ The magnification level to use for tiling. level=0 is the highest magnification level.
71
+ tile_size : int
72
+ The size of the tiles.
73
+ """
74
+ slide_id = os.path.basename(slide_file)
75
+ # slide_sample = {"image": slide_file, "slide_id": slide_id, "metadata": {'TP53': 1, 'Diagnosis': 'Lung Cancer'}}
76
+ slide_sample = {"image": slide_file, "slide_id": slide_id, "metadata": {}}
77
+
78
+ save_dir = Path(save_dir)
79
+ if save_dir.exists():
80
+ print(f"Warning: Directory {save_dir} already exists. ")
81
+
82
+ print(f"Processing slide {slide_file} at level {level} with tile size {tile_size}. Saving to {save_dir}.")
83
+
84
+ slide_dir = process_slide(
85
+ slide_sample,
86
+ level=level,
87
+ margin=0,
88
+ tile_size=tile_size,
89
+ foreground_threshold=None,
90
+ occupancy_threshold=0.1,
91
+ output_dir=save_dir / "output",
92
+ thumbnail_dir=save_dir / "thumbnails",
93
+ tile_progress=True,
94
+ )
95
+
96
+ dataset_csv_path = slide_dir / "dataset.csv"
97
+ dataset_df = pd.read_csv(dataset_csv_path)
98
+ assert len(dataset_df) > 0
99
+ failed_csv_path = slide_dir / "failed_tiles.csv"
100
+ failed_df = pd.read_csv(failed_csv_path)
101
+ assert len(failed_df) == 0
102
+
103
+ print(f"Slide {slide_file} has been tiled. {len(dataset_df)} tiles saved to {slide_dir}.")
104
+
105
+
106
+ def load_tile_encoder_transforms() -> transforms.Compose:
107
+ """Load the transforms for the tile encoder"""
108
+ transform = transforms.Compose(
109
+ [
110
+ transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
111
+ transforms.CenterCrop(224),
112
+ transforms.ToTensor(),
113
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
114
+ ])
115
+ return transform
116
+
117
+
118
+ def load_tile_slide_encoder(local_tile_encoder_path: str='',
119
+ local_slide_encoder_path: str='',
120
+ global_pool=False) -> Tuple[torch.nn.Module, torch.nn.Module]:
121
+ """Load the GigaPath tile and slide encoder models.
122
+ Note: Older versions of timm have compatibility issues.
123
+ Please ensure that you use a newer version by running the following command: pip install timm>=1.0.3.
124
+ """
125
+ if local_tile_encoder_path:
126
+ tile_encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=False, checkpoint_path=local_tile_encoder_path)
127
+ else:
128
+ tile_encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
129
+ print("Tile encoder param #", sum(p.numel() for p in tile_encoder.parameters()))
130
+
131
+ if local_slide_encoder_path:
132
+ slide_encoder_model = slide_encoder.create_model(local_slide_encoder_path, "gigapath_slide_enc12l768d", 1536, global_pool=global_pool)
133
+ else:
134
+ slide_encoder_model = slide_encoder.create_model("hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536, global_pool=global_pool)
135
+ print("Slide encoder param #", sum(p.numel() for p in slide_encoder_model.parameters()))
136
+
137
+ return tile_encoder, slide_encoder_model
138
+
139
+
140
+ @torch.no_grad()
141
+ def run_inference_with_tile_encoder(image_paths: List[str], tile_encoder: torch.nn.Module, batch_size: int=128) -> dict:
142
+ """
143
+ Run inference with the tile encoder
144
+
145
+ Arguments:
146
+ ----------
147
+ image_paths : List[str]
148
+ List of image paths, each image is named with its coordinates
149
+ tile_encoder : torch.nn.Module
150
+ Tile encoder model
151
+ """
152
+ tile_encoder = tile_encoder.cuda()
153
+ # make the tile dataloader
154
+ tile_dl = DataLoader(TileEncodingDataset(image_paths, transform=load_tile_encoder_transforms()), batch_size=batch_size, shuffle=False)
155
+ # run inference
156
+ tile_encoder.eval()
157
+ collated_outputs = {'tile_embeds': [], 'coords': []}
158
+ with torch.cuda.amp.autocast(dtype=torch.float16):
159
+ for batch in tqdm(tile_dl, desc='Running inference with tile encoder'):
160
+ collated_outputs['tile_embeds'].append(tile_encoder(batch['img'].cuda()).detach().cpu())
161
+ collated_outputs['coords'].append(batch['coords'])
162
+ return {k: torch.cat(v) for k, v in collated_outputs.items()}
163
+
164
+
165
+ @torch.no_grad()
166
+ def run_inference_with_slide_encoder(tile_embeds: torch.Tensor, coords: torch.Tensor, slide_encoder_model: torch.nn.Module) -> dict:
167
+ """
168
+ Run inference with the slide encoder
169
+
170
+ Arguments:
171
+ ----------
172
+ tile_embeds : torch.Tensor
173
+ Tile embeddings
174
+ coords : torch.Tensor
175
+ Coordinates of the tiles
176
+ slide_encoder_model : torch.nn.Module
177
+ Slide encoder model
178
+ """
179
+ if len(tile_embeds.shape) == 2:
180
+ tile_embeds = tile_embeds.unsqueeze(0)
181
+ coords = coords.unsqueeze(0)
182
+
183
+ slide_encoder_model = slide_encoder_model.cuda()
184
+ slide_encoder_model.eval()
185
+ # run inference
186
+ with torch.cuda.amp.autocast(dtype=torch.float16):
187
+ slide_embeds = slide_encoder_model(tile_embeds.cuda(), coords.cuda(), all_layer_embed=True)
188
+ outputs = {"layer_{}_embed".format(i): slide_embeds[i].cpu() for i in range(len(slide_embeds))}
189
+ outputs["last_layer_embed"] = slide_embeds[-1].cpu()
190
+ return outputs
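
A sketch chaining the helpers above. The slide path and glob pattern are placeholders, and a CUDA GPU plus access to the prov-gigapath hub weights are assumed:

```python
import glob
from gigapath.pipeline import (tile_one_slide, load_tile_slide_encoder,
                               run_inference_with_tile_encoder,
                               run_inference_with_slide_encoder)

tile_one_slide("slides/sample.svs", save_dir="tmp/", level=0, tile_size=256)
tile_paths = glob.glob("tmp/output/**/*.png", recursive=True)  # tiles named like 256x_512y.png

tile_encoder, slide_encoder_model = load_tile_slide_encoder()
tile_out = run_inference_with_tile_encoder(tile_paths, tile_encoder)
slide_out = run_inference_with_slide_encoder(tile_out["tile_embeds"],
                                             tile_out["coords"],
                                             slide_encoder_model)
print(slide_out["last_layer_embed"].shape)
```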
gigapath/pos_embed.py ADDED
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # References:
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ # DeiT: https://github.com/facebookresearch/deit
+ # MAE: https://github.com/facebookresearch/mae
+ # --------------------------------------------------------
+ #
+ # Portions Copyright Prov-GigaPath
+ # Original File: https://github.com/facebookresearch/mae
+ # --------------------------------------------------------
+ # Position embedding utils
+ # --------------------------------------------------------
+
+ import numpy as np
+
+ import torch
+
+
+ # --------------------------------------------------------
+ # 2D sine-cosine position embedding
+ # References:
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
+ # --------------------------------------------------------
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=float)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if "pos_embed" in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model["pos_embed"]
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches**0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model["pos_embed"] = new_pos_embed
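
A minimal usage sketch of the utilities above. The 14x14 grid and 768-dim embedding are illustrative values (e.g. a 224x224 image split into 16x16 patches), and freezing the parameter follows the MAE convention rather than anything this file mandates:

import torch
from gigapath.pos_embed import get_2d_sincos_pos_embed

# Build a fixed sine-cosine table for a 14x14 patch grid plus a cls token.
embed_dim, grid_size = 768, 14
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
assert pos_embed.shape == (1 + grid_size * grid_size, embed_dim)

# Copy it into a non-learnable ViT positional-embedding parameter.
pe = torch.nn.Parameter(torch.zeros(1, 1 + grid_size * grid_size, embed_dim),
                        requires_grad=False)
pe.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

interpolate_pos_embed plays the complementary role at load time: when a checkpoint's grid (say 14x14) does not match the model's, it bicubically resizes the position tokens while leaving the cls/dist tokens untouched, so load_state_dict can proceed.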
gigapath/preprocessing/__init__.py ADDED
File without changes
gigapath/preprocessing/data/__init__.py ADDED
File without changes
gigapath/preprocessing/data/box_utils.py ADDED
@@ -0,0 +1,145 @@
+ # ------------------------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+ #
+ # Original: https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/box_utils.py
+ # ------------------------------------------------------------------------------------------
+
+ from dataclasses import dataclass
+ from typing import Optional, Sequence, Tuple
+
+ import numpy as np
+ from scipy import ndimage
+
+
+ @dataclass(frozen=True)
+ class Box:
+     """Utility class representing rectangular regions in 2D images.
+
+     :param x: Horizontal coordinate of the top-left corner.
+     :param y: Vertical coordinate of the top-left corner.
+     :param w: Box width.
+     :param h: Box height.
+     :raises ValueError: If either `w` or `h` is <= 0.
+     """
+     x: int
+     y: int
+     w: int
+     h: int
+
+     def __post_init__(self) -> None:
+         if self.w <= 0:
+             raise ValueError(f"Width must be strictly positive, received {self.w}")
+         if self.h <= 0:
+             raise ValueError(f"Height must be strictly positive, received {self.h}")
+
+     def __add__(self, shift: Sequence[int]) -> 'Box':
+         """Translates the box's location by a given shift.
+
+         :param shift: A length-2 sequence containing horizontal and vertical shifts.
+         :return: A new box with updated `x = x + shift[0]` and `y = y + shift[1]`.
+         :raises ValueError: If `shift` does not have two elements.
+         """
+         if len(shift) != 2:
+             raise ValueError("Shift must be two-dimensional")
+         return Box(x=self.x + shift[0],
+                    y=self.y + shift[1],
+                    w=self.w,
+                    h=self.h)
+
+     def __mul__(self, factor: float) -> 'Box':
+         """Scales the box by a given factor, e.g. when changing resolution.
+
+         :param factor: The factor by which to multiply the box's location and dimensions.
+         :return: The updated box, with location and dimensions rounded to `int`.
+         """
+         return Box(x=int(self.x * factor),
+                    y=int(self.y * factor),
+                    w=int(self.w * factor),
+                    h=int(self.h * factor))
+
+     def __rmul__(self, factor: float) -> 'Box':
+         """Scales the box by a given factor, e.g. when changing resolution.
+
+         :param factor: The factor by which to multiply the box's location and dimensions.
+         :return: The updated box, with location and dimensions rounded to `int`.
+         """
+         return self * factor
+
+     def __truediv__(self, factor: float) -> 'Box':
+         """Scales the box by a given factor, e.g. when changing resolution.
+
+         :param factor: The factor by which to divide the box's location and dimensions.
+         :return: The updated box, with location and dimensions rounded to `int`.
+         """
+         return self * (1. / factor)
+
+     def add_margin(self, margin: int) -> 'Box':
+         """Adds a symmetric margin on all sides of the box.
+
+         :param margin: The amount by which to enlarge the box.
+         :return: A new box enlarged by `margin` on all sides.
+         """
+         return Box(x=self.x - margin,
+                    y=self.y - margin,
+                    w=self.w + 2 * margin,
+                    h=self.h + 2 * margin)
+
+     def clip(self, other: 'Box') -> Optional['Box']:
+         """Clips a box to the interior of another.
+
+         This is useful to constrain a region to the interior of an image.
+
+         :param other: Box representing the new constraints.
+         :return: A new constrained box, or `None` if the boxes do not overlap.
+         """
+         x0 = max(self.x, other.x)
+         y0 = max(self.y, other.y)
+         x1 = min(self.x + self.w, other.x + other.w)
+         y1 = min(self.y + self.h, other.y + other.h)
+         try:
+             return Box(x=x0, y=y0, w=x1 - x0, h=y1 - y0)
+         except ValueError:  # Empty result, boxes don't overlap
+             return None
+
+     def to_slices(self) -> Tuple[slice, slice]:
+         """Converts the box to slices for indexing arrays.
+
+         For example: `my_2d_array[my_box.to_slices()]`.
+
+         :return: A 2-tuple with vertical and horizontal slices.
+         """
+         return (slice(self.y, self.y + self.h),
+                 slice(self.x, self.x + self.w))
+
+     @staticmethod
+     def from_slices(slices: Sequence[slice]) -> 'Box':
+         """Converts a pair of vertical and horizontal slices into a box.
+
+         :param slices: A length-2 sequence containing vertical and horizontal `slice` objects.
+         :return: A box with corresponding location and dimensions.
+         """
+         vert_slice, horz_slice = slices
+         return Box(x=horz_slice.start,
+                    y=vert_slice.start,
+                    w=horz_slice.stop - horz_slice.start,
+                    h=vert_slice.stop - vert_slice.start)
+
+
+ def get_bounding_box(mask: np.ndarray) -> Box:
+     """Extracts a bounding box from a binary 2D array.
+
+     :param mask: A 2D array with 0 (or `False`) as background and >0 (or `True`) as foreground.
+     :return: The smallest box covering all non-zero elements of `mask`.
+     :raises TypeError: When the input mask has more than two dimensions.
+     :raises RuntimeError: When all elements in the mask are zero.
+     """
+     if mask.ndim != 2:
+         raise TypeError(f"Expected a 2D array but got an array with shape {mask.shape}")
+
+     slices = ndimage.find_objects(mask > 0)
+     if not slices:
+         raise RuntimeError("The input mask is empty")
+     assert len(slices) == 1
+
+     return Box.from_slices(slices[0])
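
A minimal usage sketch of Box and get_bounding_box, e.g. for cropping a slide thumbnail to its tissue region; the mask shape and margin below are made-up illustrative values:

import numpy as np
from gigapath.preprocessing.data.box_utils import Box, get_bounding_box

# Locate the foreground of a binary mask, pad it, and clip it back to
# the image bounds before cropping with plain array indexing.
mask = np.zeros((100, 100), dtype=bool)
mask[20:60, 30:80] = True

box = get_bounding_box(mask)                        # Box(x=30, y=20, w=50, h=40)
padded = box.add_margin(16)                         # may overshoot the borders
clipped = padded.clip(Box(x=0, y=0, w=100, h=100))  # Box(x=14, y=4, w=82, h=72)
crop = mask[clipped.to_slices()]                    # shape (72, 82)

Note that clip returns None when the two boxes do not overlap, so callers should check the result before indexing.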