Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,13 +14,40 @@ def prepare_model():
|
|
| 14 |
return (tokenizer, model)
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def process(text):
|
| 18 |
"""
|
| 19 |
Translate incoming text to tokens and classify it
|
| 20 |
"""
|
| 21 |
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
|
| 22 |
result = pipe(text)[0]
|
| 23 |
-
return result
|
| 24 |
|
| 25 |
|
| 26 |
tokenizer, model = prepare_model()
|
|
@@ -105,4 +132,4 @@ text = "\n".join([title, abstract])
|
|
| 105 |
## Output
|
| 106 |
|
| 107 |
if len(text.strip()) > 0:
|
| 108 |
-
st.markdown(f"
|
|
|
|
| 14 |
return (tokenizer, model)
|
| 15 |
|
| 16 |
|
| 17 |
+
def top_pct(preds, threshold=0.95):
|
| 18 |
+
"""
|
| 19 |
+
Output top predictions and their scores
|
| 20 |
+
"""
|
| 21 |
+
preds = sorted(preds, key=lambda x: -x["score"])
|
| 22 |
+
|
| 23 |
+
cum_score = 0
|
| 24 |
+
for i, item in enumerate(preds):
|
| 25 |
+
cum_score += item["score"]
|
| 26 |
+
if cum_score >= threshold:
|
| 27 |
+
break
|
| 28 |
+
|
| 29 |
+
preds = preds[: (i + 1)]
|
| 30 |
+
|
| 31 |
+
return preds
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def format_predictions(preds) -> str:
|
| 35 |
+
"""
|
| 36 |
+
Prepare predictions and their scores for printing to the user
|
| 37 |
+
"""
|
| 38 |
+
out = ""
|
| 39 |
+
for i, item in enumerate(preds):
|
| 40 |
+
out += f"{i+1}. **{item['label']}** *(score {item['score']:.2f})*\n"
|
| 41 |
+
return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
def process(text):
|
| 45 |
"""
|
| 46 |
Translate incoming text to tokens and classify it
|
| 47 |
"""
|
| 48 |
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
|
| 49 |
result = pipe(text)[0]
|
| 50 |
+
return format_predictions(top_pct(result))
|
| 51 |
|
| 52 |
|
| 53 |
tokenizer, model = prepare_model()
|
|
|
|
| 132 |
## Output
|
| 133 |
|
| 134 |
if len(text.strip()) > 0:
|
| 135 |
+
st.markdown(f"{process(text)}", unsafe_allow_html=True)
|