Update modeling_deepseekocr.py
Browse files- modeling_deepseekocr.py +53 -48
modeling_deepseekocr.py
CHANGED
|
@@ -1,28 +1,31 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
| 4 |
from typing import List, Optional, Tuple, Union
|
| 5 |
-
|
| 6 |
-
import
|
| 7 |
from PIL import Image, ImageOps, ImageDraw, ImageFont
|
| 8 |
-
|
|
|
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
from torch.nn import CrossEntropyLoss
|
| 12 |
from torchvision import transforms
|
| 13 |
-
|
| 14 |
-
import
|
| 15 |
-
from .
|
| 16 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from transformers import TextStreamer
|
|
|
|
| 18 |
from .conversation import get_conv_template
|
| 19 |
-
from abc import ABC
|
| 20 |
-
import math
|
| 21 |
-
import re
|
| 22 |
-
from tqdm import tqdm
|
| 23 |
-
import numpy as np
|
| 24 |
-
import time
|
| 25 |
|
|
|
|
| 26 |
|
| 27 |
def load_image(image_path):
|
| 28 |
|
|
@@ -348,6 +351,23 @@ class NoEOSTextStreamer(TextStreamer):
|
|
| 348 |
print(text, flush=True, end="")
|
| 349 |
|
| 350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
class DeepseekOCRConfig(DeepseekV2Config):
|
| 352 |
model_type = "DeepseekOCR"
|
| 353 |
|
|
@@ -366,8 +386,7 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 366 |
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
|
| 367 |
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
|
| 368 |
|
| 369 |
-
|
| 370 |
-
|
| 371 |
|
| 372 |
def forward(
|
| 373 |
self,
|
|
@@ -387,12 +406,11 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 387 |
|
| 388 |
|
| 389 |
|
| 390 |
-
|
| 391 |
if inputs_embeds is None:
|
| 392 |
# inputs_embeds = self.embed_tokens(input_ids)
|
| 393 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 394 |
|
| 395 |
-
|
| 396 |
|
| 397 |
sam_model = getattr(self, 'sam_model', None)
|
| 398 |
# sam_model = self.sam_model
|
|
@@ -475,10 +493,6 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 475 |
global_features_2 = vision_model(image_ori, global_features_1)
|
| 476 |
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
| 477 |
global_features = self.projector(global_features)
|
| 478 |
-
print('=====================')
|
| 479 |
-
print('BASE: ', global_features.shape)
|
| 480 |
-
print('NO PATCHES')
|
| 481 |
-
print('=====================')
|
| 482 |
_, hw, n_dim = global_features.shape
|
| 483 |
h = w = int(hw ** 0.5)
|
| 484 |
|
|
@@ -496,17 +510,17 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 496 |
images_in_this_batch.append(global_local_features)
|
| 497 |
|
| 498 |
|
| 499 |
-
# print(inputs_embeds.shape)
|
| 500 |
-
|
| 501 |
if images_in_this_batch:
|
| 502 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
idx += 1
|
| 508 |
|
| 509 |
-
|
| 510 |
return super(DeepseekOCRModel, self).forward(
|
| 511 |
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
|
| 512 |
inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
|
|
@@ -528,8 +542,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 528 |
|
| 529 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 530 |
|
| 531 |
-
# self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 532 |
-
|
| 533 |
# Initialize weights and apply final processing
|
| 534 |
self.post_init()
|
| 535 |
|
|
@@ -578,10 +590,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 578 |
|
| 579 |
)
|
| 580 |
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
# print(transformer_outputs)
|
| 584 |
-
|
| 585 |
hidden_states = outputs[0]
|
| 586 |
logits = self.lm_head(hidden_states)
|
| 587 |
logits = logits.float()
|
|
@@ -622,8 +630,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 622 |
if past_key_values is not None:
|
| 623 |
if isinstance(past_key_values, Cache):
|
| 624 |
cache_length = past_key_values.get_seq_length()
|
| 625 |
-
past_length = past_key_values.
|
| 626 |
-
max_cache_length =
|
| 627 |
else:
|
| 628 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 629 |
max_cache_length = None
|
|
@@ -799,9 +807,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 799 |
|
| 800 |
|
| 801 |
|
| 802 |
-
images_list.append(image_transform(global_view).to(
|
| 803 |
|
| 804 |
-
# global_view_tensor = image_transform(global_view).to(
|
| 805 |
|
| 806 |
width_crop_num, height_crop_num = crop_ratio
|
| 807 |
|
|
@@ -812,7 +820,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 812 |
"""process the local views"""
|
| 813 |
|
| 814 |
for i in range(len(images_crop_raw)):
|
| 815 |
-
images_crop_list.append(image_transform(images_crop_raw[i]).to(
|
| 816 |
|
| 817 |
if image_size == 640:
|
| 818 |
valid_img_tokens += len(images_crop_list) * 100
|
|
@@ -846,7 +854,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 846 |
# else:
|
| 847 |
global_view = ImageOps.pad(image, (image_size, image_size),
|
| 848 |
color=tuple(int(x * 255) for x in image_transform.mean))
|
| 849 |
-
images_list.append(image_transform(global_view).to(
|
| 850 |
|
| 851 |
if base_size == 1024:
|
| 852 |
valid_img_tokens += int(256 * ratio)
|
|
@@ -888,9 +896,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 888 |
|
| 889 |
input_ids = torch.LongTensor(tokenized_str)
|
| 890 |
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
|
| 895 |
|
| 896 |
|
|
@@ -911,7 +916,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 911 |
|
| 912 |
if not eval_mode:
|
| 913 |
streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
|
| 914 |
-
with torch.autocast("cuda", dtype=
|
| 915 |
with torch.no_grad():
|
| 916 |
output_ids = self.generate(
|
| 917 |
input_ids.unsqueeze(0).cuda(),
|
|
@@ -929,7 +934,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 929 |
)
|
| 930 |
|
| 931 |
else:
|
| 932 |
-
with torch.autocast("cuda", dtype=
|
| 933 |
with torch.no_grad():
|
| 934 |
output_ids = self.generate(
|
| 935 |
input_ids.unsqueeze(0).cuda(),
|
|
@@ -1034,4 +1039,4 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 1034 |
plt.savefig(f'{output_path}/geo.jpg')
|
| 1035 |
plt.close()
|
| 1036 |
|
| 1037 |
-
result.save(f"{output_path}/result_with_boxes.jpg")
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import re
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from abc import ABC
|
| 6 |
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from addict import Dict
|
| 9 |
from PIL import Image, ImageOps, ImageDraw, ImageFont
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
from torch.nn import CrossEntropyLoss
|
| 15 |
from torchvision import transforms
|
| 16 |
+
|
| 17 |
+
from transformers.cache_utils import Cache
|
| 18 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 19 |
+
from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
|
| 20 |
+
from transformers import DeepseekV2Config
|
| 21 |
+
from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
|
| 22 |
+
DeepseekV2Attention, DeepseekV2MLP, DeepseekV2MoE, DeepseekV2RMSNorm, DeepseekV2DecoderLayer)
|
| 23 |
+
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
|
| 24 |
from transformers import TextStreamer
|
| 25 |
+
from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
|
| 26 |
from .conversation import get_conv_template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 29 |
|
| 30 |
def load_image(image_path):
|
| 31 |
|
|
|
|
| 351 |
print(text, flush=True, end="")
|
| 352 |
|
| 353 |
|
| 354 |
+
def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int):
|
| 355 |
+
nn.Module.__init__(self)
|
| 356 |
+
self.hidden_size = config.hidden_size
|
| 357 |
+
|
| 358 |
+
if config.use_mla:
|
| 359 |
+
self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
|
| 360 |
+
else:
|
| 361 |
+
config.head_dim = config.hidden_size // config.num_attention_heads
|
| 362 |
+
self.self_attn = LlamaAttention(config, layer_idx)
|
| 363 |
+
self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)
|
| 364 |
+
|
| 365 |
+
self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 366 |
+
self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
DeepseekV2DecoderLayer.__init__ = decoder_layer_init
|
| 370 |
+
|
| 371 |
class DeepseekOCRConfig(DeepseekV2Config):
|
| 372 |
model_type = "DeepseekOCR"
|
| 373 |
|
|
|
|
| 386 |
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
|
| 387 |
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
|
| 388 |
|
| 389 |
+
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
|
|
|
| 390 |
|
| 391 |
def forward(
|
| 392 |
self,
|
|
|
|
| 406 |
|
| 407 |
|
| 408 |
|
|
|
|
| 409 |
if inputs_embeds is None:
|
| 410 |
# inputs_embeds = self.embed_tokens(input_ids)
|
| 411 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 412 |
|
| 413 |
+
inputs_embeds = inputs_embeds.clone()
|
| 414 |
|
| 415 |
sam_model = getattr(self, 'sam_model', None)
|
| 416 |
# sam_model = self.sam_model
|
|
|
|
| 493 |
global_features_2 = vision_model(image_ori, global_features_1)
|
| 494 |
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
| 495 |
global_features = self.projector(global_features)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
_, hw, n_dim = global_features.shape
|
| 497 |
h = w = int(hw ** 0.5)
|
| 498 |
|
|
|
|
| 510 |
images_in_this_batch.append(global_local_features)
|
| 511 |
|
| 512 |
|
|
|
|
|
|
|
| 513 |
if images_in_this_batch:
|
| 514 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 515 |
+
images_in_this_batch = images_in_this_batch.to(
|
| 516 |
+
device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
| 517 |
+
)
|
| 518 |
+
mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device) # bool [T, 1]
|
| 519 |
+
updated_row = inputs_embeds[idx].masked_scatter(mask, images_in_this_batch)
|
| 520 |
+
inputs_embeds[idx] = updated_row
|
| 521 |
|
| 522 |
idx += 1
|
| 523 |
|
|
|
|
| 524 |
return super(DeepseekOCRModel, self).forward(
|
| 525 |
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
|
| 526 |
inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
|
|
|
|
| 542 |
|
| 543 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 544 |
|
|
|
|
|
|
|
| 545 |
# Initialize weights and apply final processing
|
| 546 |
self.post_init()
|
| 547 |
|
|
|
|
| 590 |
|
| 591 |
)
|
| 592 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
hidden_states = outputs[0]
|
| 594 |
logits = self.lm_head(hidden_states)
|
| 595 |
logits = logits.float()
|
|
|
|
| 630 |
if past_key_values is not None:
|
| 631 |
if isinstance(past_key_values, Cache):
|
| 632 |
cache_length = past_key_values.get_seq_length()
|
| 633 |
+
past_length = past_key_values.get_seq_length()
|
| 634 |
+
max_cache_length = None
|
| 635 |
else:
|
| 636 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 637 |
max_cache_length = None
|
|
|
|
| 807 |
|
| 808 |
|
| 809 |
|
| 810 |
+
images_list.append(image_transform(global_view).to(torch_dtype))
|
| 811 |
|
| 812 |
+
# global_view_tensor = image_transform(global_view).to(torch_dtype)
|
| 813 |
|
| 814 |
width_crop_num, height_crop_num = crop_ratio
|
| 815 |
|
|
|
|
| 820 |
"""process the local views"""
|
| 821 |
|
| 822 |
for i in range(len(images_crop_raw)):
|
| 823 |
+
images_crop_list.append(image_transform(images_crop_raw[i]).to(torch_dtype))
|
| 824 |
|
| 825 |
if image_size == 640:
|
| 826 |
valid_img_tokens += len(images_crop_list) * 100
|
|
|
|
| 854 |
# else:
|
| 855 |
global_view = ImageOps.pad(image, (image_size, image_size),
|
| 856 |
color=tuple(int(x * 255) for x in image_transform.mean))
|
| 857 |
+
images_list.append(image_transform(global_view).to(torch_dtype))
|
| 858 |
|
| 859 |
if base_size == 1024:
|
| 860 |
valid_img_tokens += int(256 * ratio)
|
|
|
|
| 896 |
|
| 897 |
input_ids = torch.LongTensor(tokenized_str)
|
| 898 |
|
|
|
|
|
|
|
|
|
|
| 899 |
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
|
| 900 |
|
| 901 |
|
|
|
|
| 916 |
|
| 917 |
if not eval_mode:
|
| 918 |
streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
|
| 919 |
+
with torch.autocast("cuda", dtype=torch_dtype):
|
| 920 |
with torch.no_grad():
|
| 921 |
output_ids = self.generate(
|
| 922 |
input_ids.unsqueeze(0).cuda(),
|
|
|
|
| 934 |
)
|
| 935 |
|
| 936 |
else:
|
| 937 |
+
with torch.autocast("cuda", dtype=torch_dtype):
|
| 938 |
with torch.no_grad():
|
| 939 |
output_ids = self.generate(
|
| 940 |
input_ids.unsqueeze(0).cuda(),
|
|
|
|
| 1039 |
plt.savefig(f'{output_path}/geo.jpg')
|
| 1040 |
plt.close()
|
| 1041 |
|
| 1042 |
+
result.save(f"{output_path}/result_with_boxes.jpg")
|