Spaces:
Sleeping
Sleeping
Update generate.py
Browse files- generate.py +1 -2
generate.py
CHANGED
|
@@ -63,8 +63,7 @@ class GenerateRunner():
|
|
| 63 |
# 加载模型
|
| 64 |
file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
|
| 65 |
if opt.model_choice == 'transformer':
|
| 66 |
-
|
| 67 |
-
self.model = EncoderDecoder.load_from_file(file_name, map_location=torch.device('cpu'))
|
| 68 |
self.model.to(self.device)
|
| 69 |
self.model.eval()
|
| 70 |
elif opt.model_choice == 'seq2seq':
|
|
|
|
| 63 |
# 加载模型
|
| 64 |
file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
|
| 65 |
if opt.model_choice == 'transformer':
|
| 66 |
+
self.model = EncoderDecoder.load_from_file(file_name)
|
|
|
|
| 67 |
self.model.to(self.device)
|
| 68 |
self.model.eval()
|
| 69 |
elif opt.model_choice == 'seq2seq':
|