whytimmy commited on
Commit
84aba52
·
verified ·
1 Parent(s): f026eaa

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +9 -5
src/streamlit_app.py CHANGED
@@ -2,6 +2,10 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
  import json
 
 
 
 
5
 
6
  st.set_page_config(
7
  page_title="Arxiv Classifier",
@@ -14,15 +18,15 @@ st.title("Arxiv Classifier")
14
 
15
  @st.cache_resource
16
  def load_model():
17
- model = AutoModelForSequenceClassification.from_pretrained("./arxiv_dir")
18
- tokenizer = AutoTokenizer.from_pretrained("./arxiv_dir")
19
- with open("./arxiv_dir/id2tag.json") as f:
20
  id2tag = {int(k): v for k, v in json.load(f).items()}
21
- with open("./arxiv_dir/tag2name.json") as f:
22
  tag2name = json.load(f)
23
  model.eval()
24
  return model, tokenizer, id2tag, tag2name
25
-
26
  model, tokenizer, id2tag, tag2name = load_model()
27
 
28
  def predict_top95(title, summary=None):
 
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
  import json
5
+ import os
6
+
7
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
8
+ MODEL_DIR = os.path.join(BASE_DIR, "arxiv_dir")
9
 
10
  st.set_page_config(
11
  page_title="Arxiv Classifier",
 
18
 
19
  @st.cache_resource
20
  def load_model():
21
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
23
+ with open(os.path.join(MODEL_DIR, "id2tag.json")) as f:
24
  id2tag = {int(k): v for k, v in json.load(f).items()}
25
+ with open(os.path.join(MODEL_DIR, "tag2name.json")) as f:
26
  tag2name = json.load(f)
27
  model.eval()
28
  return model, tokenizer, id2tag, tag2name
29
+
30
  model, tokenizer, id2tag, tag2name = load_model()
31
 
32
  def predict_top95(title, summary=None):