席亚东
commited on
Commit
·
16bf127
1
Parent(s):
ef2abea
fix the bug in inference.py
Browse files- inference.py +7 -2
inference.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch
|
|
| 7 |
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
|
| 9 |
from fairseq import checkpoint_utils, options, tasks, utils
|
|
|
|
| 10 |
|
| 11 |
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
|
| 12 |
|
|
@@ -77,7 +78,12 @@ class Inference(object):
|
|
| 77 |
use_cuda = torch.cuda.is_available() and not args.cpu
|
| 78 |
self.use_cuda = use_cuda
|
| 79 |
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
| 82 |
cfg_args = eval(str(state["cfg"]))["model"]
|
| 83 |
del cfg_args["_name"]
|
|
@@ -97,7 +103,6 @@ class Inference(object):
|
|
| 97 |
"max_batch":eet_batch_size,
|
| 98 |
"full_seq_len":eet_seq_len}
|
| 99 |
print(model_args)
|
| 100 |
-
from eet.fairseq.transformer import EETTransformerDecoder
|
| 101 |
eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
|
| 102 |
dictionary = self.src_dict,args=model_args,
|
| 103 |
config = eet_config,
|
|
|
|
| 7 |
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
|
| 9 |
from fairseq import checkpoint_utils, options, tasks, utils
|
| 10 |
+
from eet.fairseq.transformer import EETTransformerDecoder
|
| 11 |
|
| 12 |
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
|
| 13 |
|
|
|
|
| 78 |
use_cuda = torch.cuda.is_available() and not args.cpu
|
| 79 |
self.use_cuda = use_cuda
|
| 80 |
|
| 81 |
+
model_path = args.path
|
| 82 |
+
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
| 83 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_2.pt")))
|
| 84 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_3.pt")))
|
| 85 |
+
torch.save(checkpoint, model_path)
|
| 86 |
+
|
| 87 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
| 88 |
cfg_args = eval(str(state["cfg"]))["model"]
|
| 89 |
del cfg_args["_name"]
|
|
|
|
| 103 |
"max_batch":eet_batch_size,
|
| 104 |
"full_seq_len":eet_seq_len}
|
| 105 |
print(model_args)
|
|
|
|
| 106 |
eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
|
| 107 |
dictionary = self.src_dict,args=model_args,
|
| 108 |
config = eet_config,
|