gtang666 committed
Commit 3246934 · verified · 1 Parent(s): d427b8a

Upload InternVL/modeling_internvl_chat.py with huggingface_hub

Files changed (1):
  1. InternVL/modeling_internvl_chat.py +1185 -0
InternVL/modeling_internvl_chat.py ADDED
@@ -0,0 +1,1185 @@
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from concurrent.futures import ThreadPoolExecutor
+ from config.configu import *
+ from models.model import *
+ from models.similarity import *
+ from sklearn.cluster import KMeans
+ from utils.utils import *
+ import warnings
+ from typing import Any, List, Optional, Tuple, Union
+ import torch
+ import random
+ import torch.utils.checkpoint
+ import transformers
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
+                           LlamaTokenizer)
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import ModelOutput, logging
+
+ from .configuration_internvl_chat import InternVLChatConfig
+ from .conversation import get_conv_template
+ from .modeling_intern_vit import InternVisionModel
+ from .modeling_internlm2 import InternLM2ForCausalLM
+
+ logger = logging.get_logger(__name__)
+
+
+ def coord_transform(box, return_4=True):
+     """Convert between [[x1, y1], [x2, y2]] and [x1, y1, x2, y2] box formats."""
+     if return_4:
+         return [box[0][0], box[0][1], box[1][0], box[1][1]]
+     else:
+         return [[box[0], box[1]], [box[2], box[3]]]
+
+
+ def insert_zeros(input_ids, attention_mask, num_zeros=5):
+     device = input_ids.device  # remember the original device
+     input_ids = input_ids.cpu().clone()  # move the tensors to CPU and clone them
+     attention_mask = attention_mask.cpu().clone()
+
+     for _ in range(num_zeros):
+         # pick a random insertion position
+         insert_pos = random.randint(0, input_ids.size(1))
+
+         # insert a 0 into input_ids (keep batch size and dtype consistent)
+         zeros = torch.zeros((input_ids.size(0), 1), dtype=input_ids.dtype)
+         input_ids = torch.cat((input_ids[:, :insert_pos], zeros, input_ids[:, insert_pos:]), dim=1)
+
+         # insert a 1 into attention_mask
+         ones = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype)
+         attention_mask = torch.cat((attention_mask[:, :insert_pos], ones, attention_mask[:, insert_pos:]), dim=1)
+
+     # move the tensors back to the original device
+     input_ids = input_ids.to(device)
+     attention_mask = attention_mask.to(device)
+
+     return input_ids, attention_mask
+
+
+ def add_Gaussian_noise(input_embeds, rate=1e-1):
+     """Add Gaussian noise matched to the embeddings' mean/std, scaled by `rate`."""
+     device = input_embeds.device
+     input_embeds = input_embeds.cpu().clone()
+
+     mean = input_embeds.mean()
+     std = input_embeds.std()
+     noise = torch.randn(input_embeds.size()) * std + mean
+     noisy_input_embeds = input_embeds + rate * noise
+
+     noisy_input_embeds = noisy_input_embeds.to(device)
+     noisy_input_embeds = noisy_input_embeds.to(torch.bfloat16)
+
+     return noisy_input_embeds
+
+
+ def version_cmp(v1, v2, op='eq'):
+     import operator
+
+     from packaging import version
+     op_func = getattr(operator, op)
+     return op_func(version.parse(v1), version.parse(v2))
+
+
+ def most_frequent_rgb(image_array):
+     """Find the most frequent RGB value in an image (used to fill masked regions)."""
+     # Flatten the image array to a 2D array where each row is an RGB tuple
+     pixels = image_array.reshape(-1, image_array.shape[-1])
+
+     # Use np.unique with return_counts to find unique rows and their counts
+     unique_pixels, counts = np.unique(pixels, axis=0, return_counts=True)
+
+     # Find the index of the most frequent pixel
+     most_frequent_index = np.argmax(counts)
+
+     # Get the most frequent pixel and its count
+     most_frequent_pixel = unique_pixels[most_frequent_index]
+     frequency = counts[most_frequent_index]
+     return most_frequent_pixel, frequency
+
+
+ def most_frequent_rgb_fast(image_array):
+     """Quickly find the most frequent RGB value in an image, without its frequency."""
+     # Pack each RGB triple into a single integer: R * 256**2 + G * 256 + B.
+     # Cast to int64 first so uint8 arithmetic cannot overflow.
+     flattened = image_array.reshape(-1, 3).astype(np.int64)
+     rgb_ints = flattened[:, 0] * 256**2 + flattened[:, 1] * 256 + flattened[:, 2]
+
+     # Count how often each packed RGB value occurs
+     counts = np.bincount(rgb_ints)
+
+     # Find the packed value with the highest count
+     most_frequent_index = np.argmax(counts)
+
+     # Unpack the integer back into RGB components
+     r = (most_frequent_index // 256**2) % 256
+     g = (most_frequent_index // 256) % 256
+     b = most_frequent_index % 256
+
+     return (r, g, b)
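+
+ # A quick worked example of the packing above (illustrative only): the pixel
+ # (R, G, B) = (1, 2, 3) packs to 1 * 65536 + 2 * 256 + 3 = 66051, and unpacking
+ # recovers r = 66051 // 65536 % 256 = 1, g = 66051 // 256 % 256 = 2, b = 66051 % 256 = 3.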
+
+
+ def mask_area(image_array, coords, color):
+     """Mask each bounding box in `coords` by filling it with `color`."""
+     # Each coord is a bounding box (x1, y1, x2, y2)
+     # color = average_rgb(modified_image)
+     for coord in coords:
+         x1, y1, x2, y2 = coord
+         image_array[y1:y2, x1:x2] = color  # fill the box with the background color
+     return image_array
+
+
+ class InternVLChatModel(PreTrainedModel):
+     config_class = InternVLChatConfig
+     main_input_name = 'pixel_values'
+     _supports_flash_attn_2 = True
+     _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer']
+
+     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
+         super().__init__(config)
+
+         assert version_cmp(transformers.__version__, '4.36.2', 'ge')
+         image_size = config.force_image_size or config.vision_config.image_size
+         patch_size = config.vision_config.patch_size
+         self.patch_size = patch_size
+         self.select_layer = config.select_layer
+         self.template = config.template
+         # TODO: change the number of image tokens
+         self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
+         # self.num_image_token = 3
+         self.downsample_ratio = config.downsample_ratio
+         self.ps_version = config.ps_version
+
+         # Per-token normalization stats plus the CalliAlign components
+         # (perceiver resampler and column sorter used by calli_align)
+         self.mu_sigma = torch.load(NORM_PARAMS_PATH)['weight']
+         self.mu = self.mu_sigma[:, 0].reshape((-1, 1))
+         self.sigma = self.mu_sigma[:, 1].reshape((-1, 1))  # [vocab_size, 1]
+         self.normed_emb, self.mu_sigma = self.load_normed_tok_embeddings(load_checkboard=True)
+         self.resampler = load_perceiver_resampler_2(PERCEIVER_CHECKPOINT, num_layers=4)
+
+         self.sorter = load_orderformer(ORDERFORMER_CHECKPOINT)
+
+         logger.info(f'num_image_token: {self.num_image_token}')
+         logger.info(f'ps_version: {self.ps_version}')
+         # print('vision_model', vision_model)
+         # print('language_model', language_model)
+         # print('config.llm_config.architectures[0]', config.llm_config.architectures[0])
+         if vision_model is not None:
+             self.vision_model = vision_model
+         else:
+             self.vision_model = InternVisionModel(config.vision_config)
+         if language_model is not None:
+             self.language_model = language_model
+         else:
+             if config.llm_config.architectures[0] == 'LlamaForCausalLM':
+                 self.language_model = LlamaForCausalLM(config.llm_config)
+             elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
+                 self.language_model = InternLM2ForCausalLM(config.llm_config)
+             else:
+                 raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
+
+         vit_hidden_size = config.vision_config.hidden_size
+         llm_hidden_size = config.llm_config.hidden_size
+
+         self.mlp1 = nn.Sequential(
+             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+             nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
+             nn.GELU(),
+             nn.Linear(llm_hidden_size, llm_hidden_size)
+         )
+
+         self.img_context_token_id = None
+         self.conv_template = get_conv_template(self.template)
+         self.system_message = self.conv_template.system_message
+
+     def load_normed_tok_embeddings(self, vocab_size=92553, llm_hidden_size=4096, load_checkboard=False):
+         tok_embeddings = nn.Embedding(vocab_size, llm_hidden_size, padding_idx=2).to_empty(device=torch.device('cuda')).to(torch.bfloat16)
+         tok_embeddings.load_state_dict(torch.load(NORM_TOK_EMBEDDING_PATH, weights_only=True, map_location="cpu"))
+         if load_checkboard:
+             checkboard_norm = torch.load(NORM_PARAMS_PATH)  # (vocab_size, 2): per-row (mu, sigma); de-normalize as pred * sigma + mu
+             return tok_embeddings, checkboard_norm['weight']
+         return tok_embeddings
206
+ def forward(
207
+ self,
208
+ pixel_values: torch.FloatTensor,
209
+ input_ids: torch.LongTensor = None,
210
+ attention_mask: Optional[torch.Tensor] = None,
211
+ position_ids: Optional[torch.LongTensor] = None,
212
+ image_flags: Optional[torch.LongTensor] = None,
213
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
214
+ labels: Optional[torch.LongTensor] = None,
215
+ use_cache: Optional[bool] = None,
216
+ output_attentions: Optional[bool] = None,
217
+ output_hidden_states: Optional[bool] = None,
218
+ return_dict: Optional[bool] = None,
219
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
220
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
221
+
222
+ image_flags = image_flags.squeeze(-1)
223
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
224
+
225
+ vit_embeds = self.extract_feature(pixel_values)
226
+ vit_embeds = vit_embeds[image_flags == 1]
227
+ vit_batch_size = pixel_values.shape[0]
228
+
229
+ B, N, C = input_embeds.shape
230
+ input_embeds = input_embeds.reshape(B * N, C)
231
+
232
+ if torch.distributed.get_rank() == 0:
233
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
234
+
235
+ input_ids = input_ids.reshape(B * N)
236
+ selected = (input_ids == self.img_context_token_id)
237
+ try:
238
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
239
+ except Exception as e:
240
+ vit_embeds = vit_embeds.reshape(-1, C)
241
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
242
+ f'vit_embeds.shape={vit_embeds.shape}')
243
+ n_token = selected.sum()
244
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
245
+
246
+ input_embeds = input_embeds.reshape(B, N, C)
247
+
248
+ outputs = self.language_model(
249
+ inputs_embeds=input_embeds,
250
+ attention_mask=attention_mask,
251
+ position_ids=position_ids,
252
+ past_key_values=past_key_values,
253
+ use_cache=use_cache,
254
+ output_attentions=output_attentions,
255
+ output_hidden_states=output_hidden_states,
256
+ return_dict=return_dict,
257
+ )
258
+ logits = outputs.logits
259
+
260
+ loss = None
261
+ if labels is not None:
262
+ # Shift so that tokens < n predict n
263
+ shift_logits = logits[..., :-1, :].contiguous()
264
+ shift_labels = labels[..., 1:].contiguous()
265
+ # Flatten the tokens
266
+ loss_fct = CrossEntropyLoss()
267
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
268
+ shift_labels = shift_labels.view(-1)
269
+ # Enable model parallelism
270
+ shift_labels = shift_labels.to(shift_logits.device)
271
+ loss = loss_fct(shift_logits, shift_labels)
272
+
273
+ if not return_dict:
274
+ output = (logits,) + outputs[1:]
275
+ return (loss,) + output if loss is not None else output
276
+
277
+ return CausalLMOutputWithPast(
278
+ loss=loss,
279
+ logits=logits,
280
+ past_key_values=outputs.past_key_values,
281
+ hidden_states=outputs.hidden_states,
282
+ attentions=outputs.attentions,
283
+ )
+
+     def pixel_shuffle(self, x, scale_factor=0.5):
+         n, w, h, c = x.size()
+         # N, W, H, C --> N, W, H * scale, C // scale
+         x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+         # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+         x = x.permute(0, 2, 1, 3).contiguous()
+         # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+         x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+                    int(c / (scale_factor * scale_factor)))
+         if self.ps_version == 'v1':
+             warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
+                           'which results in a transposed image.')
+         else:
+             x = x.permute(0, 2, 1, 3).contiguous()
+         return x
+
+     def extract_feature(self, pixel_values):
+         if self.select_layer == -1:
+             vit_embeds = self.vision_model(
+                 pixel_values=pixel_values,
+                 output_hidden_states=False,
+                 return_dict=True).last_hidden_state
+         else:
+             vit_embeds = self.vision_model(
+                 pixel_values=pixel_values,
+                 output_hidden_states=True,
+                 return_dict=True).hidden_states[self.select_layer]
+         vit_embeds = vit_embeds[:, 1:, :]  # drop the CLS token
+
+         h = w = int(vit_embeds.shape[1] ** 0.5)
+         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+         vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+
+         vit_embeds = self.mlp1(vit_embeds)
+         return vit_embeds
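+
+     # Shape walk-through for extract_feature (a sketch assuming the common
+     # InternVL-1.5 settings image_size=448, patch_size=14, downsample_ratio=0.5;
+     # the actual values come from the config, not this file):
+     #   ViT patches per tile: (448 // 14) ** 2 = 1024
+     #   after pixel_shuffle:  1024 * 0.5 ** 2 = 256 tokens per tile,
+     # which matches num_image_token as computed in __init__.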
+
+     @torch.no_grad()
+     def calli_align(self, img_path, detect_model, drop_zero=False, use_hard_vector_quant=False, save_path=None, verbose=False):
+         def dynamic_read(img_path, mode='c'):
+             # If given a string (file path), read it with cv2
+             if isinstance(img_path, str):
+                 img = cv2.imread(img_path)
+                 if img is None:
+                     try:
+                         img = Image.open(img_path).convert("RGB")
+                         img = np.array(img)
+                     except Exception:
+                         raise ValueError(f"Image at path {img_path} could not be loaded.")
+             # If given a PIL.Image.Image, convert it to cv2 format
+             elif isinstance(img_path, Image.Image):
+                 img = np.array(img_path)  # PIL to numpy array
+                 # OpenCV expects BGR, so convert RGB to BGR
+                 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+             else:
+                 raise TypeError(f"Unsupported image type: {type(img_path)}")
+             if mode == 'i':
+                 img = Image.fromarray(img).convert("RGB")
+             return img
+         import time
+
+         def iterative_only_boxes(model, jpg_path):
+             image = dynamic_read(jpg_path)
+             image_array = np.array(image)
+             h, w, channels = image.shape
+             boxes = []
+
+             color = most_frequent_rgb_fast(image_array)
+             # Detect characters, mask the detected regions, and re-run detection
+             # until a pass returns few enough boxes (<= 250), i.e. the detector
+             # is no longer saturated.
+             while True:
+                 res = model(image_array, verbose=False)[0]
+
+                 to_be_masked = []
+                 for box in res.boxes:
+                     xyxy = box.xyxy.squeeze().tolist()
+                     x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
+                     to_be_masked.append([x1, y1, x2, y2])
+                 boxes.extend(to_be_masked)
+                 if len(to_be_masked) > 250:
+                     image_array = mask_area(image_array, to_be_masked, color)
+                 else:
+                     break
+
+             # Clip boxes to the image bounds, as [[x1, y1], [x2, y2]] pairs
+             boxes = [[[max(item[0], 0), max(item[1], 0)], [min(item[2], w), min(item[3], h)]] for item in boxes]
+
+             # Remove near-duplicate boxes (IoU > 0.8) collected across passes
+             i = 0
+             length = len(boxes)
+             while i < length:
+                 j = 0
+                 main_box = boxes[i]
+                 while j < length:
+                     if i == j:
+                         j += 1
+                         continue
+                     iou = calculate_iou(coord_transform(main_box), coord_transform(boxes[j]))
+                     if iou > 0.8:
+                         rm = boxes[j]
+                         boxes.remove(rm)
+                         if j < i:
+                             i -= 1
+                         length -= 1
+                         j -= 1
+                     j += 1
+                 i += 1
+
+             return boxes
+         def char2col_with_kmeans(jpg_path, boxes, verbose=False):
+             # modified
+             def kmeans_boxes(bounding_boxes):
+                 areas = [(box[1][0] - box[0][0]) * (box[1][1] - box[0][1]) for box in bounding_boxes]
+
+                 # Convert to a numpy array
+                 areas = np.array(areas).reshape(-1, 1)
+
+                 # Cluster the areas into two groups with KMeans
+                 kmeans = KMeans(n_clusters=2, random_state=0).fit(areas)
+
+                 # Get the label of each bounding box
+                 labels = kmeans.labels_
+
+                 # Split the bounding boxes into two groups by label
+                 group_0 = []
+                 group_1 = []
+
+                 for i, label in enumerate(labels):
+                     if label == 0:
+                         group_0.append(bounding_boxes[i])
+                     else:
+                         group_1.append(bounding_boxes[i])
+
+                 group_0 = sorted(group_0, key=lambda x: (x[1][0] - x[0][0]), reverse=True)
+                 group_1 = sorted(group_1, key=lambda x: (x[1][0] - x[0][0]), reverse=True)
+
+                 # The group with the widest box is treated as the main text, the
+                 # other as the signature/colophon (luokuan).
+                 if (group_1[0][1][0] - group_1[0][0][0]) > (group_0[0][1][0] - group_0[0][0][0]):  # and len(group_1) > 0.8 * len(group_0)  # group 1 is main text, group 0 luokuan
+                     g1_hs = np.array([x[1][1] - x[0][1] for x in group_1]).mean()
+                     thr1 = 1 * (group_1[-1][1][0] - group_1[-1][0][0])
+                     thr2 = 0.8 * g1_hs
+                     # luokuan_mean_area = np.array([(ele[1][0] - ele[0][0]) * (ele[1][1] - ele[0][1]) for ele in group_0]).mean()
+                     new_0 = []
+                     for ele in group_0:
+                         area = (ele[1][0] - ele[0][0]) * (ele[1][1] - ele[0][1])
+                         if (ele[1][0] - ele[0][0]) >= thr1 or (ele[1][1] - ele[0][1]) >= thr2 or (areas.min() / area <= 1 / 5 and areas.mean() / area <= 1.3):
+                             group_1.append(ele)
+                         else:
+                             new_0.append(ele)
+
+                     grouped_luokuan = merge_boxes(new_0.copy())
+
+                     final_ = []
+                     for ele in new_0:
+                         if ele in grouped_luokuan:
+                             group_1.append(ele)
+                         else:
+                             final_.append(ele)
+                     group_0 = final_
+
+                 elif (group_0[0][1][0] - group_0[0][0][0]) > (group_1[0][1][0] - group_1[0][0][0]):  # and len(group_0) > 0.8 * len(group_1)
+                     g0_hs = np.array([x[1][1] - x[0][1] for x in group_0]).mean()
+                     thr1 = 1 * (group_0[-1][1][0] - group_0[-1][0][0])
+                     thr2 = 0.8 * g0_hs
+                     # luokuan_mean_area = np.array([(ele[1][0] - ele[0][0]) * (ele[1][1] - ele[0][1]) for ele in group_1]).mean()
+                     new_1 = []
+                     for ele in group_1:
+                         area = (ele[1][0] - ele[0][0]) * (ele[1][1] - ele[0][1])
+                         if (ele[1][0] - ele[0][0]) >= thr1 or (ele[1][1] - ele[0][1]) >= thr2 or (areas.min() / area <= 1 / 5 and areas.mean() / area <= 1.3):
+                             group_0.append(ele)
+                         else:
+                             new_1.append(ele)
+
+                     grouped_luokuan = merge_boxes(new_1.copy())
+
+                     final_ = []
+                     for ele in new_1:
+                         if ele in grouped_luokuan:
+                             group_0.append(ele)
+                         else:
+                             final_.append(ele)
+                     group_1 = final_
+
+                 return group_0, group_1
+
+             def toint(lst):
+                 if len(lst) == 2:
+                     return [[int(lst[0][0]), int(lst[0][1])], [int(lst[1][0]), int(lst[1][1])]]
+                 else:
+                     return [int(lst[0]), int(lst[1]), int(lst[2]), int(lst[3])]
+
+             img = dynamic_read(jpg_path)
+             h, w, channels = img.shape
+
+             normalized_boxes = [[[item[0][0] / w, item[0][1] / h], [item[1][0] / w, item[1][1] / h]] for item in boxes]
+             S = np.array([(item[0][0] - item[1][0]) * (item[0][1] - item[1][1]) for item in normalized_boxes])
+             # print(np.max(S) - np.min(S), h, w)
+             # print(boxes)
+             # print(normalized_boxes)
+
+             coef_var = np.std(S) / np.mean(S)
+             boxes2class = None
+             col2class = None
+
+             # Only split into main text / signature when the box areas vary a lot
+             if coef_var > 0.66 and S.min() / S.mean() <= 1 / 8:
+                 boxes1, boxes2 = kmeans_boxes(normalized_boxes)
+
+                 boxes1 = [[[item[0][0] * w, item[0][1] * h], [item[1][0] * w, item[1][1] * h]] for item in boxes1]
+                 boxes2 = [[[item[0][0] * w, item[0][1] * h], [item[1][0] * w, item[1][1] * h]] for item in boxes2]
+                 columns1 = merge_boxes(boxes1.copy())
+                 columns2 = merge_boxes(boxes2.copy())
+
+                 columns = columns1 + columns2
+                 boxes2class = {1: [toint(item) for item in boxes1], 2: [toint(item) for item in boxes2]}
+                 col2class = {1: [toint(item) for item in columns1], 2: [toint(item) for item in columns2]}
+             else:
+                 columns = merge_boxes(boxes.copy())
+
+             results = {"imageHeight": h, "imageWidth": w, "shapes": [{"points": toint(col)} for col in columns],
+                        "boxes2class": boxes2class, "col2class": col2class}
+
+             # print("saving results...")
+             # if verbose:
+             #     frame = dynamic_read(jpg_path)
+             #     name = jpg_path.split("/")[-1]
+             #     os.makedirs("./detect_boxes_char2col/result_merge", exist_ok=True)
+             #     for i, box in enumerate(results['shapes']):
+             #         xyxy = box['points']
+             #         x1, y1, x2, y2 = int(xyxy[0][0]), int(xyxy[0][1]), int(xyxy[1][0]), int(xyxy[1][1])
+             #         colo = (255, 0, 0)
+             #         cv2.rectangle(frame, (x1, y1), (x2, y2), thickness=2, color=colo, lineType=cv2.LINE_AA)
+             #         # put labels
+             #         if boxes2class is not None:
+             #             if xyxy in col2class[1]:
+             #                 cv2.putText(frame, str(1), ((x1 + x2) // 2, (y1 + y2) // 2), cv2.FONT_HERSHEY_SIMPLEX, 1.5, colo, thickness=2, lineType=cv2.LINE_AA)
+             #             elif xyxy in col2class[2]:
+             #                 cv2.putText(frame, str(2), ((x1 + x2) // 2, (y1 + y2) // 2), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 30, 235), thickness=2, lineType=cv2.LINE_AA)
+             #     cv2.imwrite("./detect_boxes_char2col/result_merge" + name, frame)
+             return results
+
+         def sort_boxes(jpg, detector, model, thres=0.8):
+             boxes = iterative_only_boxes(detector, jpg)
+
+             data = char2col_with_kmeans(jpg, boxes, verbose=False)
+
+             # Order the columns, then sort the characters in each column top-to-bottom
+             res = model.predict(data, jpg)
+             final_results = []
+             for idx, col in res.items():
+                 lst = []
+                 for item in boxes:
+                     ratio = calculate_iou(col, [item[0][0], item[0][1], item[1][0], item[1][1]], mini=True)
+                     if ratio >= thres:
+                         lst.append([item[0][0], item[0][1], item[1][0], item[1][1]])
+                 lst = sorted(lst, key=lambda item: (item[1] + item[3]) / 2)
+                 final_results.extend(lst)
+             # print(len(boxes), len(res), len(final_results))
+             return final_results
+
+         if img_path is None:
+             return None, None
+
+         st = time.time()
+         boxes = sort_boxes(img_path, detect_model, self.sorter)
+         ed = time.time()
+         if verbose:
+             print(f"YOLO+Orderformer {ed - st:.2f}s")
+         if save_path is not None:
+             frame = dynamic_read(img_path)
+             name = img_path.split("/")[-1]
+             for i, box in enumerate(boxes):
+                 xyxy = box
+                 x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
+                 colo = (255, 0, 0)
+                 cv2.rectangle(frame, (x1, y1), (x2, y2), thickness=2, color=colo, lineType=cv2.LINE_AA)
+                 # put labels
+                 cv2.putText(frame, str(i + 1), ((x1 + x2) // 2, (y1 + y2) // 2), cv2.FONT_HERSHEY_SIMPLEX, 1.5, colo, thickness=2, lineType=cv2.LINE_AA)
+             print(save_path + "ordered_result_" + name)
+             cv2.imwrite(save_path + "ordered_result_" + name, frame)
+
+         st = time.time()
+         pixel_values = []
+         img = np.array(dynamic_read(img_path, mode='i').convert("RGB"))
+
+         for xyxy in boxes:
+             x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
+             sub_img = Image.fromarray(img[y1:y2, x1:x2])
+             pixel_values.append(load_image_2(sub_img).to(torch.bfloat16).cuda())
+         ed1 = time.time()
+         results = torch.cat(pixel_values)
+
+         image_embeddings = self.extract_feature(results)
+         ed2 = time.time()
+         output = self.resampler(image_embeddings)
+         ed3 = time.time()
+
+         # TODO: the indices could be mapped back (to tokens)
+         # (B, 3); with use_hard_vector_quant this is now dynamic: entries whose max
+         # cosine similarity is <= thresh get replaced by their vector-quantized embedding
+         outs = vq_cos_sim(self.normed_emb, output, use_hard_vector_quant)
+
+         ed4 = time.time()
+         if verbose:
+             print(f"Get pixel values {ed1 - st:.2f}s")
+             print(f"extract feat {ed2 - ed1:.2f}s")
+             print(f"Resampler forward {ed3 - ed2:.2f}s")
+             print(f"vq cos sim {ed4 - ed3:.2f}s")
+         if use_hard_vector_quant:
+             indices, cos_sim_values = outs
+             # #### DEFINE THRESH!!!
+             thresh = 0.5
+         else:
+             indices = outs
+
+         if use_hard_vector_quant:
+             print("Dynamic vector quantization...")
+             below_mask = (cos_sim_values <= thresh).to(torch.bfloat16).unsqueeze(-1)
+             output = output * (1 - below_mask) + self.normed_emb.weight[indices] * below_mask
+
+         flattened_output = output.view(-1, output.shape[-1])
+         flattened_indices = indices.view(-1)
+
+         if drop_zero:
+             filtered_indices = flattened_indices[flattened_indices != 0]
+             filtered_output = flattened_output[flattened_indices != 0]
+
+             sigma_flat = self.sigma[filtered_indices]  # shape (num_kept, 1)
+             mu_flat = self.mu[filtered_indices]
+
+             sigma_flat = sigma_flat.expand(-1, filtered_output.shape[-1])
+             mu_flat = mu_flat.expand(-1, filtered_output.shape[-1])
+             back_to_origin_flat = filtered_output * sigma_flat + mu_flat
+         else:
+             sigma_flat = self.sigma[flattened_indices]
+             mu_flat = self.mu[flattened_indices]
+             sigma_flat = sigma_flat.expand(-1, flattened_output.shape[-1])
+             mu_flat = mu_flat.expand(-1, flattened_output.shape[-1])
+             back_to_origin_flat = flattened_output * sigma_flat + mu_flat
+
+         return back_to_origin_flat, indices
+
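+     # calli_align pipeline recap (a descriptive summary of the code above, not new
+     # behavior): detect character boxes with the YOLO-style detect_model -> group
+     # them into columns and order them with self.sorter (Orderformer) -> crop each
+     # character and encode it with the ViT -> compress to a few tokens with the
+     # perceiver resampler -> snap to the nearest normalized token embedding via
+     # cosine similarity (vq_cos_sim) -> de-normalize with the per-token (mu, sigma).
+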
+     def find_coordinates(self, text):
+         import re
+
+         # Pull every integer out of the text (use float() instead if floats are needed)
+         numbers = re.findall(r'\d+', text)
+         numbers = [int(num) for num in numbers]
+         return numbers
+
+     def chat_ocr(self, tokenizer, detect_model, img_path, questions, generation_config, num_patches_list=None,
+                  history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                  IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', ALIGNED_TOKEN="[UNUSED_TOKEN_140]", verbose=False, image_counts=None, batch=False,
+                  use_p=True, drop_zero=False, hard_vq=False, repetition_penalty=1.5, region_wise=False):
+
+         pixel_values = None
+         if img_path is not None:
+             try:
+                 if region_wise:
+                     img = np.array(Image.open(img_path).convert("RGB"))
+                     coord = self.find_coordinates(questions)
+                     x1, x2, y1, y2 = coord
+                     sub_img = Image.fromarray(img[y1:y2, x1:x2])
+
+                     questions = "输出图片中所有文字:"  # "output all the text in the image:"
+                     pixel_values = load_image(sub_img).to(torch.bfloat16).to(torch.device("cuda"))
+                 else:
+                     pixel_values = load_image(img_path).to(torch.bfloat16).to(torch.device("cuda"))
+             except Exception:
+                 raise FileNotFoundError
+         if use_p:
+             import time
+             st = time.time()
+             if region_wise:
+                 try:
+                     out_tokens, indices = self.calli_align(sub_img, detect_model, drop_zero=drop_zero, use_hard_vector_quant=hard_vq, verbose=verbose)
+                 except Exception:
+                     return "检测失败"  # "detection failed"
+             else:
+                 out_tokens, indices = self.calli_align(img_path, detect_model, drop_zero=drop_zero, use_hard_vector_quant=hard_vq, verbose=verbose)  # , tokenizer=tokenizer)
+             if verbose:
+                 print(f"Calli Align: {time.time() - st:.2f}s")
+             # drop the redundant zeros
+             # keep `indices` around: we may want to evaluate CalliAlign on its own later
+         if pixel_values is None:
+             question = questions
+
+         if pixel_values is not None and '<image>' not in questions:
+             question = '<image>\n' + questions
+             # question = questions
+         elif history is None and pixel_values is None:
+             question = questions
+         elif '<image>' in questions:
+             question = questions
+
+         if history is None and use_p and '[UNUSED_TOKEN_140]' not in question:
+             # append one aligned-token placeholder per CalliAlign output embedding
+             question = question + '[UNUSED_TOKEN_140]' * out_tokens.shape[0]
+         if num_patches_list is None:
+             num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+         self.img_context_token_id = img_context_token_id
+
+         template = get_conv_template(self.template)
+         template.system_message = self.system_message
+         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+         history = [] if history is None else history
+         for (old_question, old_answer) in history:
+             template.append_message(template.roles[0], old_question)
+             template.append_message(template.roles[1], old_answer)
+         template.append_message(template.roles[0], question)
+         template.append_message(template.roles[1], None)
+         query = template.get_prompt()
+
+         for num_patches in num_patches_list:
+             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+             query = query.replace('<image>', image_tokens, 1)
+
+         model_inputs = tokenizer(query, return_tensors='pt')
+         input_ids = model_inputs['input_ids'].cuda()
+         attention_mask = model_inputs['attention_mask'].cuda()
+
+         generation_config['eos_token_id'] = eos_token_id
+
+         if use_p:
+             generation_output = self.generate_ocr(
+                 pixel_values=pixel_values,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 reference_embeds=out_tokens,
+                 repetition_penalty=repetition_penalty,
+                 **generation_config
+             )
+         else:
+             generation_output = self.generate_ocr(
+                 pixel_values=pixel_values,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 repetition_penalty=repetition_penalty,
+                 **generation_config
+             )
+         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+         response = response.split(template.sep)[0].strip()
+         history.append((question, response))
+         if return_history:
+             return response, history
+         else:
+             query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+             query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+             return response
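+
+     # How chat_ocr wires CalliAlign into generation (a summary of the code above):
+     # one ALIGNED_TOKEN ([UNUSED_TOKEN_140]) placeholder is appended to the prompt
+     # per CalliAlign output embedding, and generate_ocr later overwrites the input
+     # embeddings at those placeholder positions with `reference_embeds`, in the same
+     # way <IMG_CONTEXT> slots are overwritten with ViT features.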
+
+     def dynamic_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
+                      history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                      IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None, batch=False, use_p=True):
+         if use_p:
+             self.num_image_token = 3
+         if batch:
+             assert isinstance(questions, list) and len(questions) > 0 and isinstance(questions[0], str)
+             if history is not None or return_history:
+                 print('Now multi-turn chat is not supported in batch_chat.')
+                 raise NotImplementedError
+
+             if image_counts is not None:
+                 num_patches_list = image_counts
+                 print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
+             img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+             self.img_context_token_id = img_context_token_id
+
+             if verbose and pixel_values is not None:
+                 image_bs = pixel_values.shape[0]
+                 print(f'dynamic ViT batch size: {image_bs}')
+
+             queries = []
+             for idx, num_patches in enumerate(num_patches_list):
+                 question = questions[idx]
+                 if pixel_values is not None and '<image>' not in question:
+                     question = '<image>\n' + question
+                 template = get_conv_template(self.template)
+                 template.append_message(template.roles[0], question)
+                 template.append_message(template.roles[1], None)
+                 query = template.get_prompt()
+
+                 image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+                 query = query.replace('<image>', image_tokens, 1)
+                 queries.append(query)
+
+             # print(query)
+             tokenizer.padding_side = 'left'
+             model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+             input_ids = model_inputs['input_ids'].cuda()
+             attention_mask = model_inputs['attention_mask'].cuda()
+             eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+             generation_config['eos_token_id'] = eos_token_id
+             if use_p:
+                 generation_output = self.generate(
+                     pixel_values=pixel_values,
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     **generation_config
+                 )
+             else:
+                 generation_output = self.generate_origin(
+                     pixel_values=pixel_values,
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     **generation_config
+                 )
+             responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
+             responses = [response.split(template.sep)[0].strip() for response in responses]
+             return responses
+         else:
+             assert isinstance(questions, str)
+             if num_patches_list is None:
+                 num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+             assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+             img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+             self.img_context_token_id = img_context_token_id
+
+             template = get_conv_template(self.template)
+             template.system_message = self.system_message
+             eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+             history = [] if history is None else history
+             for (old_question, old_answer) in history:
+                 template.append_message(template.roles[0], old_question)
+                 template.append_message(template.roles[1], old_answer)
+             template.append_message(template.roles[0], questions)
+             template.append_message(template.roles[1], None)
+             query = template.get_prompt()
+
+             if verbose and pixel_values is not None:
+                 image_bs = pixel_values.shape[0]
+                 print(f'dynamic ViT batch size: {image_bs}')
+
+             # print('num_image_token', self.num_image_token)
+             # print('num_patches_list', num_patches_list)
+
+             # Build the InternLM2 chat prompt by hand; the Chinese system prompt
+             # identifies the model as InternVL and is kept verbatim
+             query = f"""<|im_start|>system你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|im_end|>\n<|im_start|>user{questions}"""
+             query = query + '<image>'
+             for num_patches in num_patches_list:
+                 # image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+                 image_tokens = IMG_CONTEXT_TOKEN * self.num_image_token
+                 # print('tokens_num', len(image_tokens))
+                 query = query.replace('<image>', image_tokens, 1)
+
+             query += "<|im_end|>\n<|im_start|>assistant"
+             # print(query)
+             model_inputs = tokenizer(query, return_tensors='pt')
+
+             input_ids = model_inputs['input_ids'].cuda()
+             attention_mask = model_inputs['attention_mask'].cuda()
+
+             generation_config['eos_token_id'] = eos_token_id
+             if use_p:
+                 generation_output = self.generate(
+                     pixel_values=pixel_values,
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     **generation_config
+                 )
+             else:
+                 generation_output = self.generate_origin(
+                     pixel_values=pixel_values,
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     **generation_config
+                 )
+             response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+             response = response.split(template.sep)[0].strip()
+             history.append((questions, response))
+             if return_history:
+                 return response, history
+             else:
+                 query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+                 query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+                 if verbose:
+                     print(query_to_print, response)
+                 return response
+
+     def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
+                    history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
+         if history is not None or return_history:
+             print('Now multi-turn chat is not supported in batch_chat.')
+             raise NotImplementedError
+
+         if image_counts is not None:
+             num_patches_list = image_counts
+             print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
+         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+         self.img_context_token_id = img_context_token_id
+
+         if verbose and pixel_values is not None:
+             image_bs = pixel_values.shape[0]
+             print(f'dynamic ViT batch size: {image_bs}')
+
+         queries = []
+         for idx, num_patches in enumerate(num_patches_list):
+             question = questions[idx]
+             if pixel_values is not None and '<image>' not in question:
+                 question = '<image>\n' + question
+             template = get_conv_template(self.template)
+             template.append_message(template.roles[0], question)
+             template.append_message(template.roles[1], None)
+             query = template.get_prompt()
+
+             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+             query = query.replace('<image>', image_tokens, 1)
+             queries.append(query)
+
+         # print(query)
+         tokenizer.padding_side = 'left'
+         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+         input_ids = model_inputs['input_ids'].cuda()
+         attention_mask = model_inputs['attention_mask'].cuda()
+         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+         generation_config['eos_token_id'] = eos_token_id
+         generation_output = self.generate_origin(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             **generation_config
+         )
+         responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
+         responses = [response.split(template.sep)[0].strip() for response in responses]
+         return responses
+
+     # When InternVL is called, this is the function that gets invoked.
+     def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
+              num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+              verbose=False):
+         # self.num_image_token = 3
+         # original_question = question
+         if history is None and pixel_values is not None and '<image>' not in question:
+             question = '<image>\n' + question
+
+         if num_patches_list is None:
+             num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+         self.img_context_token_id = img_context_token_id
+
+         template = get_conv_template(self.template)
+         template.system_message = self.system_message
+         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+         history = [] if history is None else history
+         for (old_question, old_answer) in history:
+             template.append_message(template.roles[0], old_question)
+             template.append_message(template.roles[1], old_answer)
+         template.append_message(template.roles[0], question)
+         template.append_message(template.roles[1], None)
+         query = template.get_prompt()
+
+         if verbose and pixel_values is not None:
+             image_bs = pixel_values.shape[0]
+
+         for num_patches in num_patches_list:
+             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+             query = query.replace('<image>', image_tokens, 1)
+             print(num_patches, self.num_image_token)
+             print(pixel_values.shape[0])
+
+         model_inputs = tokenizer(query, return_tensors='pt')
+
+         input_ids = model_inputs['input_ids'].cuda()
+         attention_mask = model_inputs['attention_mask'].cuda()
+
+         generation_config['eos_token_id'] = eos_token_id
+         generation_output = self.generate_origin(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             **generation_config
+         )
+         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+         response = response.split(template.sep)[0].strip()
+         history.append((question, response))
+         if return_history:
+             return response, history
+         else:
+             query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+             query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+             if verbose:
+                 print(query_to_print, response)
+             return response
+
+     @torch.no_grad()
+     def generate_origin(
+             self,
+             pixel_values: Optional[torch.FloatTensor] = None,
+             input_ids: Optional[torch.FloatTensor] = None,
+             attention_mask: Optional[torch.LongTensor] = None,
+             visual_features: Optional[torch.FloatTensor] = None,
+             generation_config: Optional[GenerationConfig] = None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+             **generate_kwargs,
+     ) -> torch.LongTensor:
+
+         assert self.img_context_token_id is not None
+         if pixel_values is not None:
+             if visual_features is not None:
+                 vit_embeds = visual_features
+             else:
+                 vit_embeds = self.extract_feature(pixel_values)
+             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+             B, N, C = input_embeds.shape
+             input_embeds = input_embeds.reshape(B * N, C)
+
+             input_ids = input_ids.reshape(B * N)
+             selected = (input_ids == self.img_context_token_id)
+             assert selected.sum() != 0
+             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+             print("ID: ", self.img_context_token_id)
+             input_embeds = input_embeds.reshape(B, N, C)
+         else:
+             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         outputs = self.language_model.generate(
+             inputs_embeds=input_embeds,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             use_cache=True,
+             **generate_kwargs,
+         )
+
+         return outputs
+
+     @torch.no_grad()
+     def generate_ocr(
+             self,
+             pixel_values: Optional[torch.FloatTensor] = None,
+             input_ids: Optional[torch.FloatTensor] = None,
+             attention_mask: Optional[torch.LongTensor] = None,
+             visual_features: Optional[torch.FloatTensor] = None,
+             generation_config: Optional[GenerationConfig] = None,
+             reference_embeds=None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+             repetition_penalty=1.5,
+             **generate_kwargs,
+     ) -> torch.LongTensor:
+
+         assert self.img_context_token_id is not None
+         if pixel_values is not None:
+             if visual_features is not None:
+                 vit_embeds = visual_features
+             else:
+                 vit_embeds = self.extract_feature(pixel_values)
+             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+             B, N, C = input_embeds.shape
+             input_embeds = input_embeds.reshape(B * N, C)
+
+             input_ids = input_ids.reshape(B * N)
+             selected = (input_ids == self.img_context_token_id)
+             assert selected.sum() != 0
+             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+             if reference_embeds is not None:
+                 # 92537 is the id of the aligned-token placeholder (ALIGNED_TOKEN in chat_ocr)
+                 selected = (input_ids == 92537)
+                 assert selected.sum() != 0
+                 input_embeds[selected] = reference_embeds.reshape(-1, C).to(input_embeds.device)
+
+             input_embeds = input_embeds.reshape(B, N, C)
+         else:
+             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         outputs = self.language_model.generate(
+             inputs_embeds=input_embeds,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             use_cache=True,
+             repetition_penalty=repetition_penalty,
+             **generate_kwargs,
+         )
+
+         return outputs
+
+     @torch.no_grad()
+     def generate(
+             self,
+             pixel_values: Optional[torch.FloatTensor] = None,
+             input_ids: Optional[torch.FloatTensor] = None,
+             attention_mask: Optional[torch.LongTensor] = None,
+             visual_features: Optional[torch.FloatTensor] = None,
+             generation_config: Optional[GenerationConfig] = None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+             **generate_kwargs,
+     ) -> torch.LongTensor:
+
+         assert self.img_context_token_id is not None
+         if pixel_values is not None:
+             if visual_features is not None:
+                 vit_embeds = visual_features
+             else:
+                 vit_embeds = self.extract_feature(pixel_values)
+
+             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+             vit_embeds = self.resampler(vit_embeds)
+
+             mu = self.mu_sigma[:, 0].reshape((-1, 1))
+             sigma = self.mu_sigma[:, 1].reshape((-1, 1))
+
+             # Snap each resampled embedding to its nearest token embedding ...
+             indices = vq_cos_sim(self.normed_emb, vit_embeds).reshape((-1,))
+
+             # ... then de-normalize with that token's (mu, sigma)
+             vit_embeds = vit_embeds.reshape((-1, vit_embeds.shape[-1])) * sigma[indices][:] + mu[indices][:]
+
+             B, N, C = input_embeds.shape
+             input_embeds = input_embeds.reshape(B * N, C)
+
+             input_ids = input_ids.reshape(B * N)
+             selected = (input_ids == self.img_context_token_id)
+
+             assert selected.sum() != 0
+
+             input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+             input_embeds = input_embeds.reshape(B, N, C)
+         else:
+             input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         outputs = self.language_model.generate(
+             inputs_embeds=input_embeds,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             use_cache=True,
+             **generate_kwargs,
+         )
+
+         return outputs
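
A minimal usage sketch for this file, assuming a hypothetical local checkpoint directory containing this code, a YOLO-style character detector compatible with calli_align, and the load_image helper pulled in above via the wildcard utils import; adapt these names as needed, since the commit itself does not fix any of them.

import torch
from transformers import AutoModel, AutoTokenizer

path = "path/to/checkpoint"  # hypothetical; this commit does not pin a repo id
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).cuda().eval()

detector = ...  # assumed: a YOLO-style character detector, e.g. ultralytics YOLO(weights)
response = model.chat_ocr(tokenizer, detector, "calligraphy.jpg",
                          "输出图片中所有文字:",  # "output all the text in the image:"
                          generation_config=dict(max_new_tokens=512, do_sample=False))
print(response)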