ChocoLord commited on
Commit
ae436d1
·
1 Parent(s): 6a8432c

Add Streamlit app

Browse files
Files changed (4) hide show
  1. Dockerfile +16 -0
  2. README.md +1 -0
  3. app.py +100 -0
  4. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ ENV PYTHONDONTWRITEBYTECODE=1
6
+ ENV PYTHONUNBUFFERED=1
7
+ ENV PIP_NO_CACHE_DIR=1
8
+
9
+ COPY requirements.txt .
10
+ RUN pip install --upgrade pip && pip install -r requirements.txt
11
+
12
+ COPY . .
13
+
14
+ EXPOSE 7860
15
+
16
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -4,6 +4,7 @@ emoji: 🏢
4
  colorFrom: red
5
  colorTo: green
6
  sdk: docker
 
7
  pinned: false
8
  short_description: Classifies arxiv paper
9
  ---
 
4
  colorFrom: red
5
  colorTo: green
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  short_description: Classifies arxiv paper
10
  ---
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import pandas as pd
5
+ import streamlit as st
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ import plotly.express as px
9
+
10
+ MODEL_REPO = os.getenv("MODEL_REPO", "ChocoLord/paper-classifier-model")
11
+ MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
12
+ TOP_P = float(os.getenv("TOP_P", "0.95"))
13
+
14
+ st.set_page_config(page_title="Paper classifier", layout="wide")
15
+ st.title("Paper classifier")
16
+
17
+ @st.cache_resource
18
+ def load_artifacts():
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
20
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
21
+ model.eval()
22
+
23
+ id2label = model.config.id2label
24
+ if id2label is None or len(id2label) == 0:
25
+ raise ValueError("Model config must contain id2label.")
26
+
27
+ id2label = {int(k): v for k, v in id2label.items()} if not isinstance(list(id2label.keys())[0], int) else id2label
28
+ return tokenizer, model, id2label
29
+
30
+ tokenizer, model, id2label = load_artifacts()
31
+
32
+ def predict(title: str, summary: str):
33
+ title = title or ""
34
+ summary = summary or ""
35
+ text = f"{title}\n{summary}".strip()
36
+
37
+ inputs = tokenizer(
38
+ text,
39
+ truncation=True,
40
+ padding="max_length",
41
+ max_length=MAX_LENGTH,
42
+ return_tensors="pt",
43
+ )
44
+
45
+ with torch.no_grad():
46
+ logits = model(**inputs).logits
47
+ probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
48
+
49
+ labels = [id2label[i] for i in range(len(probs))]
50
+ df = pd.DataFrame({
51
+ "class_name": labels,
52
+ "predicted_proba": probs,
53
+ }).sort_values("predicted_proba", ascending=False).reset_index(drop=True)
54
+
55
+ df["cumsum"] = df["predicted_proba"].cumsum()
56
+ cutoff_idx = int(np.searchsorted(df["cumsum"].values, TOP_P, side="left"))
57
+ selected_df = df.iloc[:cutoff_idx + 1].copy()
58
+
59
+ return df, selected_df
60
+
61
+ title = st.text_input("Title")
62
+ summary = st.text_area("Summary", height=250)
63
+
64
+ n_value = st.number_input("Max classes to display in text output", min_value=1, max_value=100, value=20, step=1)
65
+
66
+ if st.button("Classify", type="primary"):
67
+ if not title.strip() and not summary.strip():
68
+ st.warning("Enter title and/or summary.")
69
+ else:
70
+ df, selected_df = predict(title, summary)
71
+
72
+ st.subheader("Selected classes")
73
+ st.write(
74
+ f"Top classes whose cumulative predicted probability reaches at least {TOP_P:.2f}. "
75
+ f"Selected {len(selected_df)} classes with total probability {selected_df['predicted_proba'].sum():.4f}."
76
+ )
77
+
78
+ text_df = selected_df.head(int(n_value)).copy()
79
+ lines = [
80
+ f"{i+1}. {row.class_name} — {row.predicted_proba:.4f}"
81
+ for i, row in text_df.iterrows()
82
+ ]
83
+ st.text("\n".join(lines))
84
+
85
+ st.subheader("Probability bar chart")
86
+ fig = px.bar(
87
+ df,
88
+ x="class_name",
89
+ y="predicted_proba",
90
+ hover_data=["cumsum"],
91
+ )
92
+ fig.update_layout(
93
+ xaxis_title="Class",
94
+ yaxis_title="Predicted probability",
95
+ xaxis_tickangle=-45,
96
+ )
97
+ st.plotly_chart(fig, use_container_width=True)
98
+
99
+ with st.expander("Full sorted predictions"):
100
+ st.dataframe(df, use_container_width=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ pandas
5
+ numpy
6
+ plotly
7
+ sentencepiece
8
+ safetensors