Add scores in output
Browse files- pipeline.py +12 -2
pipeline.py
CHANGED
|
@@ -25,6 +25,8 @@ class ExtSummPipeline(Pipeline):
|
|
| 25 |
strategy_args : any
|
| 26 |
Parameters of the strategy.
|
| 27 |
|
|
|
|
|
|
|
| 28 |
Outputs
|
| 29 |
-------
|
| 30 |
selected_sents : list[str]
|
|
@@ -32,6 +34,8 @@ class ExtSummPipeline(Pipeline):
|
|
| 32 |
|
| 33 |
selected_idxs : list[int]
|
| 34 |
List of the indexes of the selected sentences in the original input
|
|
|
|
|
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
|
|
@@ -44,6 +48,8 @@ class ExtSummPipeline(Pipeline):
|
|
| 44 |
postprocess_kwargs["strategy"] = kwargs["strategy"]
|
| 45 |
if "strategy_args" in kwargs:
|
| 46 |
postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
|
|
|
|
|
|
|
| 47 |
|
| 48 |
return {}, {}, postprocess_kwargs
|
| 49 |
|
|
@@ -95,7 +101,11 @@ class ExtSummPipeline(Pipeline):
|
|
| 95 |
return { "predictions": out_predictions, "sentences": sentences }
|
| 96 |
|
| 97 |
|
| 98 |
-
def postprocess(self, args, strategy: str="count", strategy_args=3):
|
| 99 |
predictions = args["predictions"]
|
| 100 |
sentences = args["sentences"]
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
strategy_args : any
|
| 26 |
Parameters of the strategy.
|
| 27 |
|
| 28 |
+
out_scores : bool
|
| 29 |
+
If True, the score for each sentence is returned.
|
| 30 |
Outputs
|
| 31 |
-------
|
| 32 |
selected_sents : list[str]
|
|
|
|
| 34 |
|
| 35 |
selected_idxs : list[int]
|
| 36 |
List of the indexes of the selected sentences in the original input
|
| 37 |
+
|
| 38 |
+
sents_scores : Tensor (optional)
|
| 39 |
"""
|
| 40 |
|
| 41 |
|
|
|
|
| 48 |
postprocess_kwargs["strategy"] = kwargs["strategy"]
|
| 49 |
if "strategy_args" in kwargs:
|
| 50 |
postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
|
| 51 |
+
if "out_scores" in kwargs:
|
| 52 |
+
postprocess_kwargs["out_scores"] = kwargs["out_scores"]
|
| 53 |
|
| 54 |
return {}, {}, postprocess_kwargs
|
| 55 |
|
|
|
|
| 101 |
return { "predictions": out_predictions, "sentences": sentences }
|
| 102 |
|
| 103 |
|
| 104 |
+
def postprocess(self, args, strategy: str="count", strategy_args=3, out_scores=False):
|
| 105 |
predictions = args["predictions"]
|
| 106 |
sentences = args["sentences"]
|
| 107 |
+
out = select(sentences, predictions, strategy, strategy_args)
|
| 108 |
+
|
| 109 |
+
if out_scores: out += (predictions,)
|
| 110 |
+
|
| 111 |
+
return out
|