Songyou commited on
Commit
f93cc1d
·
verified ·
1 Parent(s): a971216

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -2
generate.py CHANGED
@@ -63,8 +63,7 @@ class GenerateRunner():
63
  # 加载模型
64
  file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
65
  if opt.model_choice == 'transformer':
66
- # self.model = EncoderDecoder.load_from_file(file_name)
67
- self.model = EncoderDecoder.load_from_file(file_name, map_location=torch.device('cpu'))
68
  self.model.to(self.device)
69
  self.model.eval()
70
  elif opt.model_choice == 'seq2seq':
 
63
  # 加载模型
64
  file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
65
  if opt.model_choice == 'transformer':
66
+ self.model = EncoderDecoder.load_from_file(file_name)
 
67
  self.model.to(self.device)
68
  self.model.eval()
69
  elif opt.model_choice == 'seq2seq':