WinstonHu commited on
Commit
b3019c8
·
verified ·
1 Parent(s): b54e8ca

Upload folder stage1_qwen25_token_merging to stage_1/token_merging/stage1_qwen25_token_merging

Browse files
stage_1/token_merging/stage1_qwen25_token_merging/projector/projector.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e04644f58e7f4ccdd63fd5a07319846667c81c5b8071b7709f5b804774c87840
3
+ size 40384848
stage_1/token_merging/stage1_qwen25_token_merging/token_merger/merger.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08d3c4692384bad2a12168babda110018e9441148a8b2db397bb84cd8255f631
3
+ size 20784
stage_1/token_merging/stage1_qwen25_token_merging/xtuner_config.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.visualization import Visualizer, WandbVisBackend
7
+ from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
8
+ from torch.optim import AdamW
9
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
10
+ BitsAndBytesConfig, CLIPImageProcessor,
11
+ CLIPVisionModel)
12
+
13
+ from xtuner.dataset import LLaVADataset
14
+ from xtuner.dataset.collate_fns import default_collate_fn
15
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
16
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
17
+ from xtuner.engine.runner import TrainLoop
18
+ from xtuner.model.llava_no_longnet import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+ # from xtuner.model.torchscale.model.create_longnet_for_training import create_longvit_model_fast as create_longnet_vit
21
+ # from xtuner.model.torchscale.model.LongNetVit import gigapath_slide_enc3l1536d
22
+ #######################################################################
23
+ # PART 1 Settings #
24
+ #######################################################################
25
+ # Model
26
+ llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
27
+ # Data
28
+ data_path = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/stage1_morph2.json'
29
+ image_path_list = None
30
+
31
+ prompt_template = PROMPT_TEMPLATE.qwen_chat
32
+
33
+ # 长序列:保持 per_image_length == sample_num
34
+ max_length = 15836
35
+ per_image_length = 10240
36
+ sample_type = 'wsi' # 'wsi' or 'image'
37
+
38
+ # Scheduler & Optimizer (epoch-based)
39
+ batch_size = 1
40
+ accumulative_counts = 400 # 5 * 400 = 2000
41
+ dataloader_num_workers = 5
42
+ seed = 2025
43
+ optim_type = AdamW
44
+ lr = 1e-3
45
+ betas = (0.9, 0.999)
46
+ weight_decay = 0.0 # 适度WD抑制漂移
47
+ max_norm = 1 # 更紧的梯度裁剪
48
+
49
+ # 以 epoch 为主
50
+ max_epochs = 2
51
+ warmup_ratio = 0.05 # 预热占比(相对 max_iters)
52
+
53
+ # Save
54
+ save_steps = 5120
55
+ save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
56
+
57
+ # Evaluate the generation performance during the training
58
+ evaluation_freq = 512
59
+ SYSTEM = ''
60
+ evaluation_images = '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5'
61
+ evaluation_inputs = ['Are the tumor cells organized in a lobulated pattern within the slide?']
62
+
63
+ #######################################################################
64
+ # PART 2 Model & Tokenizer & Image Processor #
65
+ #######################################################################
66
+ tokenizer = dict(
67
+ type=AutoTokenizer.from_pretrained,
68
+ pretrained_model_name_or_path=llm_name_or_path,
69
+ trust_remote_code=True,
70
+ padding_side='right')
71
+
72
+ bnb = dict(
73
+ type=BitsAndBytesConfig,
74
+ load_in_4bit=True,
75
+ load_in_8bit=False,
76
+ llm_int8_threshold=6.0,
77
+ llm_int8_has_fp16_weight=False,
78
+ bnb_4bit_compute_dtype=torch.bfloat16,
79
+ bnb_4bit_use_double_quant=True,
80
+ bnb_4bit_quant_type="nf4",
81
+ )
82
+
83
+ model = dict(
84
+ type=LLaVAModel,
85
+ freeze_llm=True,
86
+ train_stage='1',
87
+ llm=dict(
88
+ type=AutoModelForCausalLM.from_pretrained,
89
+ pretrained_model_name_or_path=llm_name_or_path,
90
+ trust_remote_code=True,
91
+ torch_dtype=torch.bfloat16,
92
+ attn_implementation='flash_attention_2',
93
+ quantization_config=bnb
94
+ ),
95
+
96
+ max_position_embeddings = None, # original 32000 +
97
+ enable_token_merge = True,
98
+ # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
99
+ use_perceiver_resampler=False,
100
+ )
101
+
102
+ #######################################################################
103
+ # PART 3 Dataset & Dataloader #
104
+ #######################################################################
105
+ llava_dataset = dict(
106
+ type=LLaVADataset,
107
+ data_path=data_path,
108
+ image_folder='',
109
+ image_path_list=image_path_list,
110
+ tokenizer=tokenizer,
111
+ dataset_map_fn=llava_map_fn,
112
+ template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
113
+ max_length=max_length,
114
+ per_image_length=per_image_length,
115
+ pad_image_to_square=False,
116
+ sample_num=per_image_length,
117
+ image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
118
+ image_feature_suffix='.h5',
119
+ identifier='_224x224_b20_t15',
120
+ unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides3.csv',
121
+ sample_strategy='linspace', #use linspace
122
+ )
123
+
124
+
125
+ # cying: add: per_image_length=per_image_length,
126
+
127
+ train_dataloader = dict(
128
+ batch_size=batch_size,
129
+ num_workers=dataloader_num_workers,
130
+ pin_memory=True,
131
+ persistent_workers=True,
132
+ prefetch_factor=4,
133
+ dataset=llava_dataset,
134
+ sampler=dict(type=DefaultSampler, shuffle=True),
135
+ collate_fn=dict(type=default_collate_fn)
136
+ )
137
+
138
+
139
+
140
+ #######################################################################
141
+ # PART 4 Scheduler & Optimizer #
142
+ #######################################################################
143
+ # optimizer
144
+ optim_wrapper = dict(
145
+ type=AmpOptimWrapper,
146
+ optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
147
+ paramwise_cfg = dict(
148
+ norm_decay_mult=0.0,
149
+ bias_decay_mult=0.0,
150
+
151
+ ),
152
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
153
+ accumulative_counts=accumulative_counts,
154
+ loss_scale='dynamic',
155
+ dtype='bfloat16',
156
+ )
157
+
158
+ param_scheduler = [
159
+ dict(
160
+ type=LinearLR,
161
+ start_factor=0.01, # 从 1% 的 lr 慢启动
162
+ by_epoch=True,
163
+ begin=0,
164
+ end=warmup_ratio * max_epochs,
165
+ convert_to_iter_based=True # 按 iter 计算
166
+ ),
167
+ dict(
168
+ type=CosineAnnealingLR,
169
+ eta_min=0.0,
170
+ by_epoch=True,
171
+ begin=warmup_ratio * max_epochs,
172
+ end=max_epochs,
173
+ convert_to_iter_based=True
174
+ )
175
+ ]
176
+
177
+
178
+
179
+ # train, val, test setting
180
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
181
+
182
+ #######################################################################
183
+ # PART 5 Runtime #
184
+ #######################################################################
185
+ # Log the dialogue periodically during the training process, optional
186
+ custom_hooks = [
187
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
188
+ dict(
189
+ type=EvaluateChatHook,
190
+ tokenizer=tokenizer,
191
+ every_n_iters=evaluation_freq,
192
+ evaluation_inputs=evaluation_inputs,
193
+ evaluation_images=evaluation_images,
194
+ system=SYSTEM,
195
+ prompt_template=prompt_template),
196
+ dict(type = ThroughputHook)
197
+ ]
198
+
199
+ # configure default hooks
200
+ default_hooks = dict(
201
+ # record the time of every iteration.
202
+ timer=dict(type=IterTimerHook),
203
+ # print log every 10 iterations.
204
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
205
+ # enable the parameter scheduler.
206
+ param_scheduler=dict(type=ParamSchedulerHook),
207
+ # save checkpoint per `save_steps`.
208
+ checkpoint=dict(
209
+ type=CheckpointHook,
210
+ by_epoch=False,
211
+ interval=save_steps,
212
+ max_keep_ckpts=save_total_limit),
213
+ # set sampler seed in distributed evrionment.
214
+ sampler_seed=dict(type=DistSamplerSeedHook),
215
+ )
216
+
217
+ # configure environment
218
+ env_cfg = dict(
219
+ # whether to enable cudnn benchmark
220
+ cudnn_benchmark=False,
221
+ # set multi process parameters
222
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
223
+ # set distributed parameters
224
+ dist_cfg=dict(backend='nccl'),
225
+ )
226
+
227
+
228
+ visualizer = dict(
229
+ type=Visualizer,
230
+ vis_backends=[
231
+ dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_no_longnet_with_language_guide2'))])
232
+ visualizer = None
233
+ # set log level
234
+ log_level = 'INFO'
235
+
236
+ # load from which checkpoint
237
+ load_from = None
238
+
239
+ # whether to resume training from the loaded checkpoint
240
+ resume = False
241
+
242
+ # Defaults to use random seed and disable `deterministic`
243
+ randomness = dict(seed=None, deterministic=False)
244
+
245
+ # set log processor
246
+ log_processor = dict(
247
+ by_epoch=False,
248
+ window_size=1,
249
+ mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
250
+ )