Update: support the MPS (Apple Metal) backend
Browse files
- modeling_deepseekocr2.py +25 -34
modeling_deepseekocr2.py
CHANGED
|
@@ -901,41 +901,32 @@ class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM):
|
|
| 901 |
|
| 902 |
|
| 903 |
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
input_ids.unsqueeze(0).cuda(),
|
| 910 |
-
images=[(images_crop.cuda(), images_ori.cuda())],
|
| 911 |
-
images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
|
| 912 |
-
images_spatial_crop = images_spatial_crop,
|
| 913 |
-
# do_sample=False,
|
| 914 |
-
# num_beams = 1,
|
| 915 |
-
temperature=0.0,
|
| 916 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 917 |
-
streamer=streamer,
|
| 918 |
-
max_new_tokens=8192,
|
| 919 |
-
no_repeat_ngram_size = 20,
|
| 920 |
-
use_cache = True
|
| 921 |
-
)
|
| 922 |
-
|
| 923 |
else:
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 939 |
|
| 940 |
|
| 941 |
if '<image>' in conversation[0]['content'] and eval_mode:
|
|
|
|
| 901 |
|
| 902 |
|
| 903 |
|
| 904 |
+
# Initialization
|
| 905 |
+
if torch.backends.mps.is_available():
|
| 906 |
+
device = torch.device("mps")
|
| 907 |
+
elif torch.cuda.is_available():
|
| 908 |
+
device = torch.device("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 909 |
else:
|
| 910 |
+
device = torch.device("cpu")
|
| 911 |
+
|
| 912 |
+
dtype = torch.float16 if device.type == "mps" else torch.float32
|
| 913 |
+
# Execution Block
|
| 914 |
+
streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) if not eval_mode else None
|
| 915 |
+
|
| 916 |
+
with torch.no_grad():
|
| 917 |
+
output_ids = self.generate(
|
| 918 |
+
input_ids.unsqueeze(0).to(device=device, dtype=torch.long),
|
| 919 |
+
images=[(images_crop.to(device=device, dtype=dtype),
|
| 920 |
+
images_ori.to(device=device, dtype=dtype))],
|
| 921 |
+
images_seq_mask=images_seq_mask.unsqueeze(0).to(device=device, dtype=dtype),
|
| 922 |
+
images_spatial_crop=images_spatial_crop,
|
| 923 |
+
temperature=0.0,
|
| 924 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 925 |
+
streamer=streamer,
|
| 926 |
+
max_new_tokens=8192,
|
| 927 |
+
no_repeat_ngram_size=20 if not eval_mode else 35,
|
| 928 |
+
use_cache=True
|
| 929 |
+
)
|
| 930 |
|
| 931 |
|
| 932 |
if '<image>' in conversation[0]['content'] and eval_mode:
|