shimmyshimmer committed on
Commit ec38206 · verified · 1 Parent(s): af03e36

Update modeling_deepseekocr.py

Files changed (1):
  1. modeling_deepseekocr.py +53 -48

modeling_deepseekocr.py CHANGED
@@ -1,28 +1,31 @@
- from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
- from .configuration_deepseek_v2 import DeepseekV2Config
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ import os
+ import math
+ import re
+ from tqdm import tqdm
+ from abc import ABC
  from typing import List, Optional, Tuple, Union
- from transformers.cache_utils import Cache
- import requests
+
+ from addict import Dict
  from PIL import Image, ImageOps, ImageDraw, ImageFont
- from io import BytesIO
+ import numpy as np
+
  import torch
  import torch.nn as nn
  from torch.nn import CrossEntropyLoss
  from torchvision import transforms
- from torchvision.transforms.functional import InterpolationMode
- import os
- from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
- from addict import Dict
+
+ from transformers.cache_utils import Cache
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers import DeepseekV2Model, DeepseekV2ForCausalLM
+ from transformers import DeepseekV2Config
+ from transformers.models.deepseek_v2.modeling_deepseek_v2 import (
+     DeepseekV2Attention, DeepseekV2MLP, DeepseekV2MoE, DeepseekV2RMSNorm, DeepseekV2DecoderLayer)
+ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
  from transformers import TextStreamer
+ from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector
  from .conversation import get_conv_template
- from abc import ABC
- import math
- import re
- from tqdm import tqdm
- import numpy as np
- import time

+ torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

  def load_image(image_path):
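Note: `torch_dtype` now falls back to fp16 where bf16 is unavailable. A minimal standalone sketch of the same selection, with an extra CUDA-availability guard that is our addition, not part of the commit:

    import torch

    def pick_dtype() -> torch.dtype:
        # Prefer bf16 on GPUs that report support for it; otherwise fp16.
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16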
 
@@ -348,6 +351,23 @@ class NoEOSTextStreamer(TextStreamer):
          print(text, flush=True, end="")


+ def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int):
+     nn.Module.__init__(self)
+     self.hidden_size = config.hidden_size
+
+     if config.use_mla:
+         self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
+     else:
+         config.head_dim = config.hidden_size // config.num_attention_heads
+         self.self_attn = LlamaAttention(config, layer_idx)
+     self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)
+
+     self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+     self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+ DeepseekV2DecoderLayer.__init__ = decoder_layer_init
+
  class DeepseekOCRConfig(DeepseekV2Config):
      model_type = "DeepseekOCR"
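The new block rebinds `DeepseekV2DecoderLayer.__init__` so every layer built afterwards chooses its attention module from `config.use_mla`. A toy sketch of the same class-level patching pattern (names here are illustrative, not from the commit):

    import torch.nn as nn

    class Layer(nn.Module):
        def __init__(self):
            super().__init__()
            self.kind = "original"

    def patched_init(self):
        nn.Module.__init__(self)  # init the base class directly, as above
        self.kind = "patched"

    Layer.__init__ = patched_init
    assert Layer().kind == "patched"  # future instances use the new __init__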
 
@@ -366,8 +386,7 @@ class DeepseekOCRModel(DeepseekV2Model):
          self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
          self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

-
-
+         self.rotary_emb = LlamaRotaryEmbedding(config=config)

      def forward(
          self,
@@ -387,12 +406,11 @@ class DeepseekOCRModel(DeepseekV2Model):



-
          if inputs_embeds is None:
              # inputs_embeds = self.embed_tokens(input_ids)
              inputs_embeds = self.get_input_embeddings()(input_ids)

-
+         inputs_embeds = inputs_embeds.clone()

          sam_model = getattr(self, 'sam_model', None)
          # sam_model = self.sam_model
@@ -475,10 +493,6 @@ class DeepseekOCRModel(DeepseekV2Model):
                  global_features_2 = vision_model(image_ori, global_features_1)
                  global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                  global_features = self.projector(global_features)
-                 print('=====================')
-                 print('BASE: ', global_features.shape)
-                 print('NO PATCHES')
-                 print('=====================')
                  _, hw, n_dim = global_features.shape
                  h = w = int(hw ** 0.5)

@@ -496,17 +510,17 @@ class DeepseekOCRModel(DeepseekV2Model):
                      images_in_this_batch.append(global_local_features)


-             # print(inputs_embeds.shape)
-
              if images_in_this_batch:
                  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
-                 # exit()
-
-                 inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                 images_in_this_batch = images_in_this_batch.to(
+                     device=inputs_embeds.device, dtype=inputs_embeds.dtype
+                 )
+                 mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)  # bool [T, 1]
+                 updated_row = inputs_embeds[idx].masked_scatter(mask, images_in_this_batch)
+                 inputs_embeds[idx] = updated_row

              idx += 1

-
          return super(DeepseekOCRModel, self).forward(
              input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
              inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
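The in-place `masked_scatter_` on a CUDA-pinned mask becomes an out-of-place `masked_scatter` on device- and dtype-aligned tensors, written back by row assignment (hence the `inputs_embeds.clone()` earlier in forward). A minimal shape-level sketch of the splice, with illustrative sizes:

    import torch

    T, D = 8, 4                                 # sequence length, hidden size
    row = torch.zeros(T, D)                     # one row of inputs_embeds
    mask = torch.zeros(T, 1, dtype=torch.bool)  # image-token positions
    mask[2:5] = True                            # three image tokens
    vision = torch.ones(3, D)                   # projected vision features
    updated = row.masked_scatter(mask, vision)  # out-of-place copy
    assert updated[2:5].eq(1).all() and row.eq(0).all()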
@@ -528,8 +542,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

-         # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
          # Initialize weights and apply final processing
          self.post_init()

@@ -578,10 +590,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

          )

-
-
-         # print(transformer_outputs)
-
          hidden_states = outputs[0]
          logits = self.lm_head(hidden_states)
          logits = logits.float()
@@ -622,8 +630,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          if past_key_values is not None:
              if isinstance(past_key_values, Cache):
                  cache_length = past_key_values.get_seq_length()
-                 past_length = past_key_values.seen_tokens
-                 max_cache_length = past_key_values.get_max_length()
+                 past_length = past_key_values.get_seq_length()
+                 max_cache_length = None
              else:
                  cache_length = past_length = past_key_values[0][0].shape[2]
                  max_cache_length = None
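`Cache.seen_tokens` and `Cache.get_max_length()` were deprecated and later removed in newer transformers releases; the hunk above sidesteps both by reusing `get_seq_length()` and pinning `max_cache_length` to None. A hedged, version-tolerant sketch (assuming only that `get_seq_length()` exists):

    def cache_lengths(past_key_values):
        cache_length = past_length = past_key_values.get_seq_length()
        # Fall back gracefully when get_max_length() is absent.
        get_max = getattr(past_key_values, "get_max_length", None)
        max_cache_length = get_max() if callable(get_max) else None
        return cache_length, past_length, max_cache_length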
@@ -799,9 +807,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):



-             images_list.append(image_transform(global_view).to(torch.bfloat16))
+             images_list.append(image_transform(global_view).to(torch_dtype))

-             # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
+             # global_view_tensor = image_transform(global_view).to(torch_dtype)

              width_crop_num, height_crop_num = crop_ratio

@@ -812,7 +820,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          """process the local views"""

          for i in range(len(images_crop_raw)):
-             images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+             images_crop_list.append(image_transform(images_crop_raw[i]).to(torch_dtype))

          if image_size == 640:
              valid_img_tokens += len(images_crop_list) * 100
@@ -846,7 +854,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
              # else:
              global_view = ImageOps.pad(image, (image_size, image_size),
                                         color=tuple(int(x * 255) for x in image_transform.mean))
-             images_list.append(image_transform(global_view).to(torch.bfloat16))
+             images_list.append(image_transform(global_view).to(torch_dtype))

              if base_size == 1024:
                  valid_img_tokens += int(256 * ratio)
@@ -888,9 +896,6 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

          input_ids = torch.LongTensor(tokenized_str)

-
-
-
          images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)


@@ -911,7 +916,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

          if not eval_mode:
              streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-             with torch.autocast("cuda", dtype=torch.bfloat16):
+             with torch.autocast("cuda", dtype=torch_dtype):
                  with torch.no_grad():
                      output_ids = self.generate(
                          input_ids.unsqueeze(0).cuda(),
@@ -929,7 +934,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                      )

          else:
-             with torch.autocast("cuda", dtype=torch.bfloat16):
+             with torch.autocast("cuda", dtype=torch_dtype):
                  with torch.no_grad():
                      output_ids = self.generate(
                          input_ids.unsqueeze(0).cuda(),
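Both generation paths now autocast to the fallback dtype instead of hard-coding bf16. A self-contained sketch of the same pattern, using a stand-in module rather than the OCR model (requires a CUDA device):

    import torch

    model = torch.nn.Linear(4, 4).cuda()   # stand-in for the OCR model
    x = torch.randn(2, 4, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        with torch.no_grad():
            y = model(x)                   # matmul runs in reduced precision
    assert y.dtype == torch.float16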
@@ -1034,4 +1039,4 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          plt.savefig(f'{output_path}/geo.jpg')
          plt.close()

-         result.save(f"{output_path}/result_with_boxes.jpg")
+         result.save(f"{output_path}/result_with_boxes.jpg")