Commit
·
11d47e4
1
Parent(s):
8f1ccbe
Use mps device in eval.py
Browse files
eval.py
CHANGED
|
@@ -33,7 +33,7 @@ if __name__ == '__main__':
|
|
| 33 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
|
| 34 |
parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
|
| 35 |
help="img path")
|
| 36 |
-
parser.add_argument('--random_state', default=
|
| 37 |
|
| 38 |
args = parser.parse_args()
|
| 39 |
os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
|
|
@@ -51,9 +51,10 @@ if __name__ == '__main__':
|
|
| 51 |
vocab = processor.tokenizer.get_vocab()
|
| 52 |
|
| 53 |
vocab_inp = {vocab[key]: key for key in vocab}
|
|
|
|
| 54 |
model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
|
| 55 |
model.eval()
|
| 56 |
-
model.
|
| 57 |
|
| 58 |
vocab = processor.tokenizer.get_vocab()
|
| 59 |
vocab_inp = {vocab[key]: key for key in vocab}
|
|
@@ -67,7 +68,7 @@ if __name__ == '__main__':
|
|
| 67 |
pixel_values = processor([img], return_tensors="pt").pixel_values
|
| 68 |
|
| 69 |
with torch.no_grad():
|
| 70 |
-
generated_ids = model.generate(pixel_values[:, :, :].
|
| 71 |
|
| 72 |
generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
|
| 73 |
pred_str.append(generated_text)
|
|
|
|
| 33 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
|
| 34 |
parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
|
| 35 |
help="img path")
|
| 36 |
+
parser.add_argument('--random_state', default=10086, type=int, help="用于训练集划分的随机数")
|
| 37 |
|
| 38 |
args = parser.parse_args()
|
| 39 |
os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
|
|
|
|
| 51 |
vocab = processor.tokenizer.get_vocab()
|
| 52 |
|
| 53 |
vocab_inp = {vocab[key]: key for key in vocab}
|
| 54 |
+
mps_device = torch.device("mps")
|
| 55 |
model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
|
| 56 |
model.eval()
|
| 57 |
+
model.to(mps_device)
|
| 58 |
|
| 59 |
vocab = processor.tokenizer.get_vocab()
|
| 60 |
vocab_inp = {vocab[key]: key for key in vocab}
|
|
|
|
| 68 |
pixel_values = processor([img], return_tensors="pt").pixel_values
|
| 69 |
|
| 70 |
with torch.no_grad():
|
| 71 |
+
generated_ids = model.generate(pixel_values[:, :, :].to(mps_device))
|
| 72 |
|
| 73 |
generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
|
| 74 |
pred_str.append(generated_text)
|