JLFaller committed on
Commit
bce55bf
·
verified ·
1 Parent(s): d6b7cfc

Update: support mps

Browse files
Files changed (1) hide show
  1. modeling_deepseekocr2.py +25 -34
modeling_deepseekocr2.py CHANGED
@@ -901,41 +901,32 @@ class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM):
901
 
902
 
903
 
904
- if not eval_mode:
905
- streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
906
- with torch.autocast("cuda", dtype=torch.bfloat16):
907
- with torch.no_grad():
908
- output_ids = self.generate(
909
- input_ids.unsqueeze(0).cuda(),
910
- images=[(images_crop.cuda(), images_ori.cuda())],
911
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
912
- images_spatial_crop = images_spatial_crop,
913
- # do_sample=False,
914
- # num_beams = 1,
915
- temperature=0.0,
916
- eos_token_id=tokenizer.eos_token_id,
917
- streamer=streamer,
918
- max_new_tokens=8192,
919
- no_repeat_ngram_size = 20,
920
- use_cache = True
921
- )
922
-
923
  else:
924
- with torch.autocast("cuda", dtype=torch.bfloat16):
925
- with torch.no_grad():
926
- output_ids = self.generate(
927
- input_ids.unsqueeze(0).cuda(),
928
- images=[(images_crop.cuda(), images_ori.cuda())],
929
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
930
- images_spatial_crop = images_spatial_crop,
931
- # do_sample=False,
932
- # num_beams = 1,
933
- temperature=0.0,
934
- eos_token_id=tokenizer.eos_token_id,
935
- max_new_tokens=8192,
936
- no_repeat_ngram_size = 35,
937
- use_cache = True
938
- )
 
 
 
 
 
939
 
940
 
941
  if '<image>' in conversation[0]['content'] and eval_mode:
 
901
 
902
 
903
 
904
+ # Initialization
905
+ if torch.backends.mps.is_available():
906
+ device = torch.device("mps")
907
+ elif torch.cuda.is_available():
908
+ device = torch.device("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909
  else:
910
+ device = torch.device("cpu")
911
+
912
+ dtype = torch.float16 if device.type == "mps" else torch.float32
913
+ # Execution Block
914
+ streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) if not eval_mode else None
915
+
916
+ with torch.no_grad():
917
+ output_ids = self.generate(
918
+ input_ids.unsqueeze(0).to(device=device, dtype=torch.long),
919
+ images=[(images_crop.to(device=device, dtype=dtype),
920
+ images_ori.to(device=device, dtype=dtype))],
921
+ images_seq_mask=images_seq_mask.unsqueeze(0).to(device=device, dtype=dtype),
922
+ images_spatial_crop=images_spatial_crop,
923
+ temperature=0.0,
924
+ eos_token_id=tokenizer.eos_token_id,
925
+ streamer=streamer,
926
+ max_new_tokens=8192,
927
+ no_repeat_ngram_size=20 if not eval_mode else 35,
928
+ use_cache=True
929
+ )
930
 
931
 
932
  if '<image>' in conversation[0]['content'] and eval_mode: