Spaces:
Runtime error
Runtime error
Commit
·
bbf7aff
1
Parent(s):
fd97dee
inference code
Browse files
app.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
|
|
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
# load the dataset and
|
| 6 |
# use the patent number, abstract and claim columns for UI
|
| 7 |
with st.spinner("Setting up the app..."):
|
|
@@ -16,11 +20,6 @@ with st.spinner("Setting up the app..."):
|
|
| 16 |
val_filing_end_date="2016-01-31",
|
| 17 |
)
|
| 18 |
|
| 19 |
-
# widget for selecting our finetuned langugae model
|
| 20 |
-
language_model_path = "juliaannjose/finetuned_model"
|
| 21 |
-
|
| 22 |
-
# pass the model to transformers pipeline - model selection component.
|
| 23 |
-
classifier_model = pipeline(model=language_model_path)
|
| 24 |
|
| 25 |
# drop down menu with patent numbers
|
| 26 |
_patent_id = st.selectbox(
|
|
@@ -28,19 +27,39 @@ _patent_id = st.selectbox(
|
|
| 28 |
dataset_dict["train"]["patent_number"],
|
| 29 |
)
|
| 30 |
|
|
|
|
| 31 |
# display abstract and claim
|
| 32 |
@st.cache(persist=True)
|
| 33 |
def get_abs_claim(_patent_id):
|
| 34 |
# get abstract and claim corresponding to this patent id
|
| 35 |
_abstract = dataset_dict["train"][["patent_number"] == _patent_id]["abstract"]
|
| 36 |
_claim = dataset_dict["train"][["patent_number"] == _patent_id]["claims"]
|
| 37 |
-
return _abstract,_claim
|
|
|
|
| 38 |
|
| 39 |
-
_abstract,_claim = get_abs_claim(_patent_id)
|
| 40 |
st.write(_abstract)
|
| 41 |
st.write(_claim)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
# when submit button clicked, run the model and get result
|
| 44 |
if st.button("Submit"):
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 4 |
from datasets import load_dataset
|
| 5 |
|
| 6 |
+
# finetuned model
|
| 7 |
+
language_model_path = "juliaannjose/finetuned_model"
|
| 8 |
+
|
| 9 |
# load the dataset and
|
| 10 |
# use the patent number, abstract and claim columns for UI
|
| 11 |
with st.spinner("Setting up the app..."):
|
|
|
|
| 20 |
val_filing_end_date="2016-01-31",
|
| 21 |
)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# drop down menu with patent numbers
|
| 25 |
_patent_id = st.selectbox(
|
|
|
|
| 27 |
dataset_dict["train"]["patent_number"],
|
| 28 |
)
|
| 29 |
|
| 30 |
+
|
| 31 |
# display abstract and claim
|
| 32 |
@st.cache(persist=True)
|
| 33 |
def get_abs_claim(_patent_id):
|
| 34 |
# get abstract and claim corresponding to this patent id
|
| 35 |
_abstract = dataset_dict["train"][["patent_number"] == _patent_id]["abstract"]
|
| 36 |
_claim = dataset_dict["train"][["patent_number"] == _patent_id]["claims"]
|
| 37 |
+
return _abstract, _claim
|
| 38 |
+
|
| 39 |
|
| 40 |
+
_abstract, _claim = get_abs_claim(_patent_id)
|
| 41 |
st.write(_abstract)
|
| 42 |
st.write(_claim)
|
| 43 |
|
| 44 |
+
input_text = _abstract + _claim
|
| 45 |
+
|
| 46 |
+
# model and tokenizer initialization
|
| 47 |
+
tokenizer = AutoTokenizer.from_pretrained(language_model_path)
|
| 48 |
+
inputs = tokenizer(
|
| 49 |
+
input_text,
|
| 50 |
+
truncation=True,
|
| 51 |
+
padding=True,
|
| 52 |
+
return_tensors="pt",
|
| 53 |
+
)
|
| 54 |
+
model = AutoModelForSequenceClassification.from_pretrained(language_model_path)
|
| 55 |
+
|
| 56 |
+
# get predictions
|
| 57 |
+
id2label = {0: "REJECTED", 1: "ACCEPTED"}
|
| 58 |
# when submit button clicked, run the model and get result
|
| 59 |
if st.button("Submit"):
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
logits = model(**inputs).logits
|
| 62 |
+
|
| 63 |
+
predicted_class_id = logits.argmax().item()
|
| 64 |
+
pred_label = id2label[predicted_class_id]
|
| 65 |
+
st.write(pred_label)
|