AlienKevin committed on
Commit
11d47e4
·
1 Parent(s): 8f1ccbe

Use mps device in eval.py

Browse files
Files changed (1) hide show
  1. eval.py +4 -3
eval.py CHANGED
@@ -33,7 +33,7 @@ if __name__ == '__main__':
33
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
34
  parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
35
  help="img path")
36
- parser.add_argument('--random_state', default=None, type=int, help="用于训练集划分的随机数")
37
 
38
  args = parser.parse_args()
39
  os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
@@ -51,9 +51,10 @@ if __name__ == '__main__':
51
  vocab = processor.tokenizer.get_vocab()
52
 
53
  vocab_inp = {vocab[key]: key for key in vocab}
 
54
  model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
55
  model.eval()
56
- model.cuda()
57
 
58
  vocab = processor.tokenizer.get_vocab()
59
  vocab_inp = {vocab[key]: key for key in vocab}
@@ -67,7 +68,7 @@ if __name__ == '__main__':
67
  pixel_values = processor([img], return_tensors="pt").pixel_values
68
 
69
  with torch.no_grad():
70
- generated_ids = model.generate(pixel_values[:, :, :].cuda())
71
 
72
  generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
73
  pred_str.append(generated_text)
 
33
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
34
  parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
35
  help="img path")
36
+ parser.add_argument('--random_state', default=10086, type=int, help="用于训练集划分的随机数")
37
 
38
  args = parser.parse_args()
39
  os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
 
51
  vocab = processor.tokenizer.get_vocab()
52
 
53
  vocab_inp = {vocab[key]: key for key in vocab}
54
+ mps_device = torch.device("mps")
55
  model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
56
  model.eval()
57
+ model.to(mps_device)
58
 
59
  vocab = processor.tokenizer.get_vocab()
60
  vocab_inp = {vocab[key]: key for key in vocab}
 
68
  pixel_values = processor([img], return_tensors="pt").pixel_values
69
 
70
  with torch.no_grad():
71
+ generated_ids = model.generate(pixel_values[:, :, :].to(mps_device))
72
 
73
  generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
74
  pred_str.append(generated_text)