import ast
import math
import os
import re
from abc import ABC
from io import BytesIO
from typing import List, Optional, Tuple, Union

import numpy as np
import requests
import torch
import torch.nn as nn
from addict import Dict
from PIL import Image, ImageDraw, ImageFont, ImageOps
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from tqdm import tqdm
from transformers import TextStreamer
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from .configuration_deepseek_v2 import DeepseekV2Config
from .conversation import get_conv_template
from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector
from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM

# NOTE: `Dict` in the annotations below is `addict.Dict` (typing.Dict is not
# imported in this file); it is also used as a constructor for the projector
# configuration.


def load_image(image_path):
    """Open an image and apply its EXIF orientation.

    Args:
        image_path: Anything accepted by ``PIL.Image.open``.

    Returns:
        The EXIF-transposed image. If opening or transposing fails, the error
        is printed and the image is re-opened without the EXIF step; ``None``
        is returned when that second attempt fails as well.
    """
    try:
        image = Image.open(image_path)
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"error: {e}")
        try:
            return Image.open(image_path)
        except Exception:
            return None


def re_match(text):
    """Find all ``<|ref|>label<|/ref|><|det|>coords<|/det|>`` spans in ``text``.

    Returns:
        A 3-tuple ``(matches, image_spans, other_spans)`` where ``matches`` is
        the list of ``(full_span, label, coords)`` tuples from ``re.findall``,
        ``image_spans`` are the full spans whose label is exactly ``image`` and
        ``other_spans`` are all remaining full spans.
    """
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)

    mathes_image = []
    mathes_other = []
    for full_span, _label, _coords in matches:
        if '<|ref|>image<|/ref|>' in full_span:
            mathes_image.append(full_span)
        else:
            mathes_other.append(full_span)
    return matches, mathes_image, mathes_other


def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse one ``re_match`` tuple into ``(label, coordinate_list)``.

    The coordinate text comes from model output, so it is parsed with
    ``ast.literal_eval`` (Python literals only) rather than ``eval``.

    Args:
        ref_text: A ``(full_span, label, coords_text)`` tuple.
        image_width: Unused; kept for interface compatibility.
        image_height: Unused; kept for interface compatibility.

    Returns:
        ``(label, coords)`` on success, or ``None`` (after printing the error)
        when the tuple is malformed or the coordinates are not a valid literal.
    """
    try:
        label_type = ref_text[1]
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None
    return (label_type, cor_list)


def draw_bounding_boxes(image, refs, ouput_path):
    """Draw labelled boxes for every reference and save cropped ``image`` regions.

    Box coordinates are scaled by ``value / 999 * image_dimension``. Regions
    whose label is ``image`` are additionally cropped from the original image
    and written to ``{ouput_path}/images/{n}.jpg``. Drawing is best-effort: a
    reference that fails to parse or draw is skipped.

    Args:
        image: Source PIL image (not modified).
        refs: Iterable of ``re_match`` tuples.
        ouput_path: Directory that already contains an ``images`` sub-directory.

    Returns:
        A copy of ``image`` with the outlines, labels and translucent overlay
        pasted on top.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills are drawn on a separate RGBA layer and pasted at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0
    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if not result:
                continue
            label_type, points_list = result

            color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
            color_a = color + (20, )
            for points in points_list:
                x1, y1, x2, y2 = points
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == 'image':
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    # The index advances even when saving failed so that file
                    # names stay aligned with the order of the image spans.
                    img_idx += 1

                try:
                    outline_width = 4 if label_type == 'title' else 2
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                    text_x = x1
                    text_y = max(0, y1 - 15)

                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                   fill=(255, 255, 255, 30))
                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception:
                    # Deliberately best-effort: a box that cannot be drawn is skipped.
                    pass
        except Exception:
            # Deliberately best-effort: skip references with malformed coordinates.
            continue

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw


def process_image_with_refs(image, ref_texts, output_path):
    """Return ``image`` annotated via :func:`draw_bounding_boxes`."""
    return draw_bounding_boxes(image, ref_texts, output_path)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose aspect ratio is closest to ``aspect_ratio``.

    On an exact tie, the later candidate replaces the current best when the
    image area exceeds half of the candidate grid's pixel area
    (``0.5 * image_size**2 * cols * rows``).

    Returns:
        The chosen ``(cols, rows)`` tuple, or ``(1, 1)`` if ``target_ratios``
        is empty.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        cols, rows = ratio[0], ratio[1]
        ratio_diff = abs(aspect_ratio - cols / rows)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * cols * rows:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False):
    """Resize ``image`` to a tile grid and split it into square tiles.

    Candidate grids are all ``(cols, rows)`` with ``min_num <= cols*rows <=
    max_num``; the one closest to the image's aspect ratio is used.

    Args:
        image: PIL image to split.
        min_num: Minimum number of tiles.
        max_num: Maximum number of tiles.
        image_size: Edge length of each square tile in pixels.
        use_thumbnail: When true and more than one tile was produced, append a
            single ``image_size`` x ``image_size`` resize of the whole image.

    Returns:
        ``(tiles, (cols, rows))`` with tiles in row-major order.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate the candidate grids, smallest tile count first.
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    cols, rows = target_aspect_ratio[0], target_aspect_ratio[1]
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        col, row = i % cols, i // cols
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio


def normalize_transform(mean, std):
    """Build a ``transforms.Normalize`` from optional ``mean`` / ``std``.

    A missing ``mean`` defaults to zeros and a missing ``std`` to ones (sized
    after the other argument).

    Returns:
        ``None`` when both arguments are ``None``, otherwise the transform.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)


def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """
    Applies the SFT template to conversation.

    Args:
        conversations (List[Dict]): A List of messages.
        sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
        system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

    Returns:
        sft_prompt (str): The formatted text.
    """
    conv = get_conv_template(sft_format)
    conv.set_system_message(system_prompt)
    for message in conversations:
        conv.append_message(message["role"], message["content"].strip())
    return conv.get_prompt().strip()


def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without special tokens, optionally adding BOS/EOS ids.

    The BOS and EOS ids are hard-coded to ``0`` and ``1`` respectively.
    """
    bos_id = 0
    eos_id = 1

    t = tokenizer.encode(text, add_special_tokens=False)
    if bos:
        t = [bos_id] + t
    if eos:
        t = t + [eos_id]
    return t


def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced by the ``"images"`` key of each message.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages.
            Messages without an ``"images"`` key are skipped. An example is :
            [
                {
                    "role": "User",
                    "content": "Extract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images, converted to RGB.

    Raises:
        ValueError: if an image cannot be opened (``load_image`` returned ``None``).
    """
    pil_images = []
    for message in conversations:
        if "images" not in message:
            continue
        for image_path in message["images"]:
            pil_img = load_image(image_path)
            if pil_img is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise ValueError(f"could not load image: {image_path}")
            pil_images.append(pil_img.convert("RGB"))
    return pil_images


class BaseTransform(ABC):
    """Minimal interface for image transforms; every hook is a no-op here."""

    def set_rng(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        pass

    @property
    def default_shape(self):
        raise NotImplementedError


class BasicImageTransform(BaseTransform):
    """``ToTensor`` followed by an optional per-channel normalization."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        """Build the pipeline.

        Args:
            mean: Per-channel mean; also read by callers via ``self.mean``.
            std: Per-channel standard deviation.
            normalize: When false, an ``nn.Identity`` takes the place of the
                normalization step.
        """
        self.mean = mean
        self.std = std

        transform_pipelines = [transforms.ToTensor()]

        # `normalize_transform` returns None when both mean and std are None.
        norm_step = normalize_transform(mean, std) if normalize else nn.Identity()
        if norm_step is not None:
            transform_pipelines.append(norm_step)

        self.transform = transforms.Compose(transform_pipelines)

    def __call__(self, x):
        return self.transform(x)


class NoEOSTextStreamer(TextStreamer):
    """``TextStreamer`` that prints a newline in place of the decoded EOS token."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        text = text.replace(eos_text, "\n")
        print(text, flush=True, end="")


class DeepseekOCR2Config(DeepseekV2Config):
    model_type = "DeepseekOCR2"


class DeepseekOCR2Model(DeepseekV2Model):
    """DeepseekV2 backbone whose image-token embeddings are filled by a vision stack.

    The vision stack is ``sam_model`` -> ``qwen2_model`` -> ``projector``; a
    learned ``view_seperator`` row is appended after each image's features.
    """

    config_class = DeepseekOCR2Config

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCR2Model, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.qwen2_model = build_qwen2_decoder_as_encoder()
        n_embed = 1280
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # The attribute name (including its spelling) is part of the
        # module's parameter naming and must not be renamed.
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

    def _encode_view(self, pixels):
        """Run the vision stack on ``pixels`` and flatten to ``(-1, feature_dim)``."""
        features = self.projector(self.qwen2_model(self.sam_model(pixels)))
        return features, features.view(-1, features.shape[-1])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, scatter image features into them, then run the backbone.

        Args:
            images: Per-sample ``(patches, full_view)`` pairs. An all-zero
                ``patches`` tensor means "no local crops"; an all-zero first
                ``full_view`` disables the vision path entirely.
            images_seq_mask: Per-sample boolean mask over sequence positions
                that receive image features.
            images_spatial_crop: Iterated in lockstep with ``images``; its
                values are not read here.

        The remaining arguments are forwarded unchanged to
        ``DeepseekV2Model.forward`` (with ``input_ids=None``, since the
        embeddings are passed instead).
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            seq_len = input_ids.shape[1]
        else:
            # `input_ids` may legitimately be None when embeddings are supplied.
            seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        sam_model = getattr(self, 'sam_model', None)

        # The vision path is skipped on single-token decoding steps (unless
        # training). Short-circuit order matters: the image tensors are only
        # inspected when the cheaper checks pass.
        run_vision = (
            sam_model is not None
            and (seq_len != 1 or self.training)
            and images is not None
            and torch.sum(images[0][1]).item() != 0
        )

        if run_vision:
            for idx, (image, _crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches = image[0]
                image_ori = image[1]

                # The vision stack runs without gradient tracking.
                with torch.no_grad():
                    global_raw, global_features = self._encode_view(image_ori)

                    if torch.sum(patches).item() != 0:
                        local_raw, local_features = self._encode_view(patches)

                        print('=====================')
                        print('BASE: ', global_raw.shape)
                        print('PATCHES: ', local_raw.shape)
                        print('=====================')

                        # Order: local crops, then the full view, then the separator.
                        image_features = torch.cat(
                            [local_features, global_features, self.view_seperator[None, :]], dim=0)
                    else:
                        print('=====================')
                        print('BASE: ', global_raw.shape)
                        print('NO PATCHES')
                        print('=====================')

                        image_features = torch.cat(
                            [global_features, self.view_seperator[None, :]], dim=0)

                # Use the embeddings' own device rather than assuming CUDA.
                mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)
                inputs_embeds[idx].masked_scatter_(mask, image_features)

        return super(DeepseekOCR2Model, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )


class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM):
    """Causal-LM head on top of :class:`DeepseekOCR2Model`, plus an ``infer`` helper."""

    config_class = DeepseekOCR2Config

    def __init__(self, config):
        # Intentionally skips DeepseekV2ForCausalLM.__init__ (calls the next
        # class in the MRO instead) and builds the OCR model/head here.
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCR2Model(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the model and LM head; compute next-token loss when ``labels`` is given.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *rest_of_model_outputs)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim inputs to the tokens not yet covered by the cache and forward image kwargs.

        Returns:
            A dict with ``input_ids`` (or ``inputs_embeds`` on the first step
            when they were supplied), ``position_ids``, ``past_key_values``,
            ``use_cache``, ``attention_mask`` and the ``images``,
            ``images_seq_mask``, ``images_spatial_crop`` entries from ``kwargs``.
        """
        # Omit tokens covered by past_key_values
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # The previously computed `cache_position` was never used and its
        # computation dereferenced `position_ids` even when that was None, so
        # it has been removed.

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        This replaces ``reset_parameters`` on ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` globally (class-level monkeypatch).
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(self, tokenizer, prompt='', image_file='', output_path='', base_size=1024, image_size=640,
              crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        """Run generation for one prompt (optionally with one image) on CUDA.

        Args:
            tokenizer: Tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
            prompt: User prompt; required.
            image_file: Path of the image to load, or empty for a text-only prompt.
            output_path: Directory for saved artefacts; created when non-empty.
            base_size: Edge length of the padded global view in crop mode.
            image_size: Edge length used for the local-crop token count in crop
                mode, and for the single view otherwise.
            crop_mode: When true, images larger than 768 px on either side are
                additionally split into local crops.
            test_compress: Print token statistics for the generated text.
            save_results: Write ``result.mmd``, ``result_with_boxes.jpg``, the
                cropped ``images/*.jpg`` and, for outputs containing
                ``line_type``, ``geo.jpg`` into ``output_path``.
            eval_mode: Generate without streaming and return the decoded text.

        Returns:
            The decoded, stripped output text when ``eval_mode`` is true,
            otherwise ``None``.

        Raises:
            AssertionError: if ``prompt`` is empty.
            ValueError: if ``save_results`` is set without an ``output_path``,
                or an image cannot be loaded.
        """
        self.disable_torch_init()

        if save_results and not output_path:
            raise ValueError('output_path is required when save_results=True')
        if output_path:
            # os.makedirs('') raises, so only create directories for a real path.
            os.makedirs(output_path, exist_ok=True)
            os.makedirs(f'{output_path}/images', exist_ok=True)

        if not prompt:
            # Explicit raise instead of `assert False`, which is stripped under -O.
            raise AssertionError('prompt is none!')

        user_message = {
            "role": "<|User|>",
            "content": f'{prompt}',
        }
        if image_file:
            user_message["images"] = [f'{image_file}']
        conversation = [
            user_message,
            {"role": "<|Assistant|>", "content": ""},
        ]

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0
        ratio = 1

        # Text-only prompts have no image to draw on or measure.
        image_draw = None
        w = h = None
        if images:
            image_draw = images[0].copy()
            w, h = image_draw.size
            ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        # NOTE(review): this literal was empty in the reviewed source, which
        # makes `str.split` raise "ValueError: empty separator". '<image>' is
        # presumed to be the placeholder paired with `image_token_id` below;
        # confirm against the tokenizer vocabulary.
        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                images_crop_raw = []
                if image.size[0] <= 768 and image.size[1] <= 768:
                    crop_ratio = [1, 1]
                else:
                    # NOTE(review): dynamic_preprocess uses its own default tile
                    # size (768) while the token count below uses `image_size`;
                    # the two only agree when image_size == 768. Left unchanged.
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                # process the global view
                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    # process the local views
                    for crop in images_crop_raw:
                        images_crop_list.append(image_transform(crop).to(torch.bfloat16))

                if image_size == 768:
                    valid_img_tokens += len(images_crop_list) * 144

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # add image tokens: global view, one separator, then local crops
                tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)
            else:
                # process the global view
                if image_size <= 768:
                    print('directly resize')
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)
                elif base_size == 768:
                    valid_img_tokens += int(144 * 1)

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                # add image tokens: the single view plus one separator
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = ([image_token_id] * num_queries) * num_queries
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        # process the last text split
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        # add the bos token
        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            # All-zero placeholders make the model skip the vision path.
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        # The two modes differ only in streaming and the n-gram repeat window.
        generate_kwargs = dict(
            images=[(images_crop.cuda(), images_ori.cuda())],
            images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
            images_spatial_crop=images_spatial_crop,
            temperature=0.0,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=8192,
            no_repeat_ngram_size=35 if eval_mode else 20,
            use_cache=True,
        )
        if not eval_mode:
            generate_kwargs["streamer"] = NoEOSTextStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=False)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            with torch.no_grad():
                output_ids = self.generate(input_ids.unsqueeze(0).cuda(), **generate_kwargs)

        stop_str = '<|end▁of▁sentence|>'
        prompt_length = input_ids.shape[0]

        def decode_new_tokens():
            # Decode only the tokens generated after the prompt.
            return tokenizer.decode(output_ids[0, prompt_length:])

        def strip_stop(text):
            if text.endswith(stop_str):
                text = text[:-len(stop_str)]
            return text.strip()

        if eval_mode:
            return strip_stop(decode_new_tokens())

        if test_compress and image_draw is not None:
            outputs = decode_new_tokens()
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('=' * 50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            if valid_img_tokens:
                print('compression ratio: ', round(pure_texts_outputs_token_length / valid_img_tokens, 2))
            else:
                # No token estimate exists for this size combination.
                print('compression ratio: ', 'n/a')
            print('=' * 50)

        if save_results and image_draw is not None:
            outputs = decode_new_tokens()

            print('=' * 15 + 'save results:' + '=' * 15)

            outputs = strip_stop(outputs)

            matches_ref, matches_images, mathes_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt

                # Model output is untrusted text: parse it as a Python literal
                # (once) instead of calling eval() on it.
                geometry = ast.literal_eval(outputs)['Line']
                lines = geometry['line']
                line_type = geometry['line_type']
                endpoints = geometry['line_endpoint']

                fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for idx, line in enumerate(lines):
                    try:
                        segment = line.split(' -- ')
                        p0 = ast.literal_eval(segment[0])
                        p1 = ast.literal_eval(segment[-1])

                        # Both line types are currently drawn identically; the
                        # lookup is kept so a short `line_type` list still
                        # skips the segment as before.
                        line_type[idx]
                        ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                        ax.scatter(p0[0], p0[1], s=5, color='k')
                        ax.scatter(p1[0], p1[1], s=5, color='k')
                    except Exception:
                        # Best-effort: skip segments that cannot be parsed or drawn.
                        pass

                for endpoint in endpoints:
                    label = endpoint.split(': ')[0]
                    (x, y) = ast.literal_eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")