Spaces:
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from datasets import load_dataset | |
# Hub repo id (or local path) of the fine-tuned checkpoint; both the tokenizer
# and the sequence-classification model are loaded from it.
language_model_path = "juliaannjose/finetuned_model"
# Load the HUPD "sample" configuration; the patent number, abstract and claims
# columns of its train split feed the UI below.
#
# Streamlit re-executes this whole script on every widget interaction, so the
# load is wrapped in Streamlit's resource cache to run once per process.
# `st.cache_resource` requires Streamlit >= 1.18; on older versions the
# function is left uncached and the dataset is reloaded each run, as before.
def _load_hupd_sample():
    """Return the HUPD sample ``DatasetDict`` for the January 2016 filing window.

    The date keyword arguments are forwarded to the HUPD loading script, which
    presumably uses them to build the train/validation splits (train: filed
    2016-01-01..2016-01-21, validation: 2016-01-22..2016-01-31) — confirm
    against the HUPD dataset card.
    """
    return load_dataset(
        "HUPD/hupd",
        name="sample",
        # NOTE(review): a `blob/` URL on the Hub serves an HTML viewer page,
        # while `resolve/` serves the raw file. Confirm which form the HUPD
        # loading script expects.
        data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
        icpr_label=None,
        train_filing_start_date="2016-01-01",
        train_filing_end_date="2016-01-21",
        val_filing_start_date="2016-01-22",
        val_filing_end_date="2016-01-31",
    )


_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    # show_spinner=False: the explicit st.spinner below already informs the user.
    _load_hupd_sample = _cache_resource(show_spinner=False)(_load_hupd_sample)

with st.spinner("Setting up the app..."):
    dataset_dict = _load_hupd_sample()
# Drop-down offering every patent number in the train split; Streamlit returns
# the currently selected value.
_patent_id = st.selectbox(
    label="Select the Patent Number",
    options=dataset_dict["train"]["patent_number"],
)
# Look up the abstract and claims for a selected patent number.
def get_abs_claim(_patent_id, split=None):
    """Return the abstract and claims of the row whose patent number matches.

    Args:
        _patent_id: Patent number to look up (the value chosen in the
            selectbox).
        split: Table to search. It must be indexable first by column name and
            then by row position, e.g. a ``datasets.Dataset`` or a dict of
            lists. Defaults to ``dataset_dict["train"]``.

    Returns:
        ``(abstract, claims)`` taken from the first row whose
        ``patent_number`` equals ``_patent_id``.

    Raises:
        ValueError: If no row has that patent number.
    """
    if split is None:
        split = dataset_dict["train"]
    # Boolean-mask indexing (``split[... == id]``) is not a row filter here, so
    # find the row position by value, then read that position from each
    # column.
    patent_numbers = list(split["patent_number"])
    try:
        row = patent_numbers.index(_patent_id)
    except ValueError:
        raise ValueError(f"Patent number {_patent_id!r} not found") from None
    return split["abstract"][row], split["claims"][row]
# Show the selected application's abstract and claims, then build the
# classifier input from both.
_abstract, _claim = get_abs_claim(_patent_id)
st.write(_abstract)
st.write(_claim)
# Join with a space so the last word of the abstract cannot fuse with the
# first word of the claims when the text is tokenized.
input_text = _abstract + " " + _claim
# Load the fine-tuned tokenizer and sequence-classification model.
#
# Streamlit reruns the script on every interaction. Caching keeps both objects
# loaded once per process on Streamlit >= 1.18 (which provides
# `st.cache_resource`); older versions fall back to reloading each run, as
# before.
def _load_tokenizer_and_model(model_path):
    """Return ``(tokenizer, model)`` loaded from ``model_path``."""
    return (
        AutoTokenizer.from_pretrained(model_path),
        AutoModelForSequenceClassification.from_pretrained(model_path),
    )


_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_tokenizer_and_model = _cache_resource(show_spinner=False)(
        _load_tokenizer_and_model
    )

tokenizer, model = _load_tokenizer_and_model(language_model_path)

# Encode the text as PyTorch tensors, truncated to the tokenizer's maximum
# length.
inputs = tokenizer(
    input_text,
    truncation=True,
    padding=True,
    return_tensors="pt",
)
# Maps the classifier's predicted index to the decision shown to the user.
id2label = {0: "REJECTED", 1: "ACCEPTED"}

# Inference happens only on the rerun triggered by pressing "Submit".
if st.button("Submit"):
    # Forward pass only; gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Index of the highest logit (taken over the flattened tensor).
    predicted_class_id = logits.argmax().item()
    pred_label = id2label[predicted_class_id]
    st.write(pred_label)