Gosse Minnema
commited on
Commit
·
adbf6a0
1
Parent(s):
34cc06e
Allow user to set confidence threshold for frames + roles, default to 0.95
Browse files
sociolome/lome_webserver.py
CHANGED
|
@@ -47,7 +47,7 @@ def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, A
|
|
| 47 |
labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
|
| 48 |
return labels
|
| 49 |
|
| 50 |
-
def make_prediction(sentence, spacy_model, predictor):
|
| 51 |
spacy_doc = spacy_model(sentence)
|
| 52 |
tokens = [t.text for t in spacy_doc]
|
| 53 |
tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens)
|
|
@@ -59,7 +59,7 @@ def make_prediction(sentence, spacy_model, predictor):
|
|
| 59 |
continue
|
| 60 |
if frm.upper() == frm:
|
| 61 |
continue
|
| 62 |
-
if fr_proba.max() !=
|
| 63 |
continue
|
| 64 |
|
| 65 |
arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
|
|
@@ -70,7 +70,7 @@ def make_prediction(sentence, spacy_model, predictor):
|
|
| 70 |
"roles": [
|
| 71 |
{"boundary": bnd, "label": label}
|
| 72 |
for bnd, label, probas in zip(arg_spans, arg_labels, label_probas)
|
| 73 |
-
if label != "Target" and max(probas) ==
|
| 74 |
]
|
| 75 |
}
|
| 76 |
|
|
@@ -96,9 +96,10 @@ app = Flask(__name__)
|
|
| 96 |
@app.route("/analyze")
|
| 97 |
def analyze():
|
| 98 |
text = request.args.get("text")
|
|
|
|
| 99 |
analyses = []
|
| 100 |
for sentence in text.split("\n"):
|
| 101 |
-
analyses.append(make_prediction(sentence, nlp, predictor))
|
| 102 |
|
| 103 |
return jsonify({
|
| 104 |
"result": "OK",
|
|
|
|
| 47 |
labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
|
| 48 |
return labels
|
| 49 |
|
| 50 |
+
def make_prediction(sentence, spacy_model, predictor, confidence_threshold):
|
| 51 |
spacy_doc = spacy_model(sentence)
|
| 52 |
tokens = [t.text for t in spacy_doc]
|
| 53 |
tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens)
|
|
|
|
| 59 |
continue
|
| 60 |
if frm.upper() == frm:
|
| 61 |
continue
|
| 62 |
+
if fr_proba.max() != confidence_threshold:
|
| 63 |
continue
|
| 64 |
|
| 65 |
arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
|
|
|
|
| 70 |
"roles": [
|
| 71 |
{"boundary": bnd, "label": label}
|
| 72 |
for bnd, label, probas in zip(arg_spans, arg_labels, label_probas)
|
| 73 |
+
if label != "Target" and max(probas) == confidence_threshold
|
| 74 |
]
|
| 75 |
}
|
| 76 |
|
|
|
|
| 96 |
@app.route("/analyze")
|
| 97 |
def analyze():
|
| 98 |
text = request.args.get("text")
|
| 99 |
+
confidence_threshold = float(request.args.get("confidence_threshold", 0.95))
|
| 100 |
analyses = []
|
| 101 |
for sentence in text.split("\n"):
|
| 102 |
+
analyses.append(make_prediction(sentence, nlp, predictor, confidence_threshold))
|
| 103 |
|
| 104 |
return jsonify({
|
| 105 |
"result": "OK",
|