ChocoLord commited on
Commit
cc2e31c
·
1 Parent(s): be48c0f

Delete num max classes

Browse files
Files changed (1) hide show
  1. app.py +9 -27
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import os
2
- import json
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
@@ -7,9 +5,9 @@ import torch
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
  import plotly.express as px
9
 
10
- MODEL_REPO = os.getenv("MODEL_REPO", "ChocoLord/paper-classifier-model")
11
- MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
12
- TOP_P = float(os.getenv("TOP_P", "0.95"))
13
 
14
  st.set_page_config(page_title="Paper classifier", layout="wide")
15
  st.title("Paper classifier")
@@ -19,20 +17,12 @@ def load_artifacts():
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
20
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
21
  model.eval()
 
22
 
23
- id2label = model.config.id2label
24
- if id2label is None or len(id2label) == 0:
25
- raise ValueError("Model config must contain id2label.")
26
-
27
- id2label = {int(k): v for k, v in id2label.items()} if not isinstance(list(id2label.keys())[0], int) else id2label
28
- return tokenizer, model, id2label
29
-
30
- tokenizer, model, id2label = load_artifacts()
31
 
32
  def predict(title: str, summary: str):
33
- title = title or ""
34
- summary = summary or ""
35
- text = f"{title}\n{summary}".strip()
36
 
37
  inputs = tokenizer(
38
  text,
@@ -46,7 +36,8 @@ def predict(title: str, summary: str):
46
  logits = model(**inputs).logits
47
  probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
48
 
49
- labels = [id2label[i] for i in range(len(probs))]
 
50
  df = pd.DataFrame({
51
  "class_name": labels,
52
  "predicted_proba": probs,
@@ -61,8 +52,6 @@ def predict(title: str, summary: str):
61
  title = st.text_input("Title")
62
  summary = st.text_area("Summary", height=250)
63
 
64
- n_value = st.number_input("Max classes to display in text output", min_value=1, max_value=100, value=20, step=1)
65
-
66
  if st.button("Classify", type="primary"):
67
  if not title.strip() and not summary.strip():
68
  st.warning("Enter title and/or summary.")
@@ -70,15 +59,9 @@ if st.button("Classify", type="primary"):
70
  df, selected_df = predict(title, summary)
71
 
72
  st.subheader("Selected classes")
73
- st.write(
74
- f"Top classes whose cumulative predicted probability reaches at least {TOP_P:.2f}. "
75
- f"Selected {len(selected_df)} classes with total probability {selected_df['predicted_proba'].sum():.4f}."
76
- )
77
-
78
- text_df = selected_df.head(int(n_value)).copy()
79
  lines = [
80
  f"{i+1}. {row.class_name} — {row.predicted_proba:.4f}"
81
- for i, row in text_df.iterrows()
82
  ]
83
  st.text("\n".join(lines))
84
 
@@ -87,7 +70,6 @@ if st.button("Classify", type="primary"):
87
  df,
88
  x="class_name",
89
  y="predicted_proba",
90
- hover_data=["cumsum"],
91
  )
92
  fig.update_layout(
93
  xaxis_title="Class",
 
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import streamlit as st
 
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  import plotly.express as px
7
 
8
+ MODEL_REPO = "ChocoLord/paper-classifier-model"
9
+ MAX_LENGTH = 512
10
+ TOP_P = 0.95
11
 
12
  st.set_page_config(page_title="Paper classifier", layout="wide")
13
  st.title("Paper classifier")
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
18
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
19
  model.eval()
20
+ return tokenizer, model
21
 
22
+ tokenizer, model = load_artifacts()
 
 
 
 
 
 
 
23
 
24
  def predict(title: str, summary: str):
25
+ text = f"{title or ''}\n{summary or ''}".strip()
 
 
26
 
27
  inputs = tokenizer(
28
  text,
 
36
  logits = model(**inputs).logits
37
  probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
38
 
39
+ labels = [model.config.id2label[i] for i in range(len(probs))]
40
+
41
  df = pd.DataFrame({
42
  "class_name": labels,
43
  "predicted_proba": probs,
 
52
  title = st.text_input("Title")
53
  summary = st.text_area("Summary", height=250)
54
 
 
 
55
  if st.button("Classify", type="primary"):
56
  if not title.strip() and not summary.strip():
57
  st.warning("Enter title and/or summary.")
 
59
  df, selected_df = predict(title, summary)
60
 
61
  st.subheader("Selected classes")
 
 
 
 
 
 
62
  lines = [
63
  f"{i+1}. {row.class_name} — {row.predicted_proba:.4f}"
64
+ for i, row in selected_df.iterrows()
65
  ]
66
  st.text("\n".join(lines))
67
 
 
70
  df,
71
  x="class_name",
72
  y="predicted_proba",
 
73
  )
74
  fig.update_layout(
75
  xaxis_title="Class",