席亚东 commited on
Commit ·
fc6a062
1
Parent(s): 05e5527
fix the bug in inferen.py
Browse files- inference.py +4 -4
inference.py
CHANGED
|
@@ -87,10 +87,10 @@ class Inference(object):
|
|
| 87 |
|
| 88 |
model_path = args.path
|
| 89 |
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
| 90 |
-
checkpoint["model"].update(model_path.replace("best.pt", "best_part_2.pt"))
|
| 91 |
-
checkpoint["model"].update(model_path.replace("best.pt", "best_part_3.pt"))
|
| 92 |
torch.save(checkpoint, model_path)
|
| 93 |
-
|
| 94 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
| 95 |
cfg_args = eval(str(state["cfg"]))["model"]
|
| 96 |
del cfg_args["_name"]
|
|
@@ -178,4 +178,4 @@ class Inference(object):
|
|
| 178 |
score = hypo['score'] / math.log(2) # convert to base 2
|
| 179 |
tmp_res.append([detok_hypo_str, score])
|
| 180 |
final_results.append(tmp_res)
|
| 181 |
-
return final_results
|
|
|
|
| 87 |
|
| 88 |
model_path = args.path
|
| 89 |
checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
|
| 90 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_2.pt")))
|
| 91 |
+
checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_3.pt")))
|
| 92 |
torch.save(checkpoint, model_path)
|
| 93 |
+
|
| 94 |
state = torch.load(args.path, map_location=torch.device("cpu"))
|
| 95 |
cfg_args = eval(str(state["cfg"]))["model"]
|
| 96 |
del cfg_args["_name"]
|
|
|
|
| 178 |
score = hypo['score'] / math.log(2) # convert to base 2
|
| 179 |
tmp_res.append([detok_hypo_str, score])
|
| 180 |
final_results.append(tmp_res)
|
| 181 |
+
return final_results
|