MurDanya committed on
Commit
f0de0e1
·
verified ·
1 Parent(s): f21e94b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -13
app.py CHANGED
@@ -1,18 +1,23 @@
1
- # app.py
2
  import streamlit as st
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
  import torch
5
  import numpy as np
6
  import json
7
 
8
  @st.cache_resource
9
  def load_model():
10
- model = AutoModelForSequenceClassification.from_pretrained("MurDanya/ml-course-article-classifier-scibert")
11
- tokenizer = AutoTokenizer.from_pretrained("MurDanya/ml-course-article-classifier-scibert")
12
- with open("labels.json") as f:
13
- id2label = json.load(f)
14
- id2label = {int(idx): label for idx, label in id2label.items()}
15
- return tokenizer, model, id2label
 
 
 
 
 
16
 
17
  def get_top95(labels, probs):
18
  sorted_indices = torch.argsort(probs, descending=True)
@@ -20,7 +25,7 @@ def get_top95(labels, probs):
20
  sorted_labels = [labels[i.item()] for i in sorted_indices]
21
 
22
  cumulative = torch.cumsum(sorted_probs, dim=0)
23
- cutoff = torch.where(cumulative >= 0.95)[0]
24
  last_idx = cutoff[0].item() + 1 if len(cutoff) > 0 else len(sorted_probs)
25
 
26
  return list(zip(sorted_labels[:last_idx], sorted_probs[:last_idx].tolist()))
@@ -37,9 +42,9 @@ if st.button("Classify"):
37
  if not title and not abstract:
38
  st.warning("Please enter at least the title.")
39
  else:
40
- tokenizer, model, id2label = load_model()
41
-
42
- text = title + ". " + abstract if abstract else title
43
  inputs = tokenizer(text, return_tensors="pt", truncation=True)
44
  with torch.no_grad():
45
  outputs = model(**inputs)
@@ -47,6 +52,5 @@ if st.button("Classify"):
47
 
48
  top_labels = get_top95(id2label, probs)
49
 
50
- st.subheader("📚 Top topics (95% confidence)")
51
  for label, prob in top_labels:
52
- st.markdown(f"- **{label}**: {prob:.3f}")
 
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ from huggingface_hub import hf_hub_download
4
  import torch
5
  import numpy as np
6
  import json
7
 
8
  @st.cache_resource
9
  def load_model():
10
+ repo_id = "MurDanya/ml-course-article-classifier-distilbert"
11
+ model = AutoModelForSequenceClassification.from_pretrained(repo_id)
12
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
13
+
14
+ file_path = hf_hub_download(repo_id, "labels.json")
15
+ with open(file_path) as f:
16
+ labels = json.load(f)
17
+ id2label = {int(idx): label for idx, label in labels['id2label'].items()}
18
+ categories = labels['categories']
19
+
20
+ return tokenizer, model, id2label, categories
21
 
22
  def get_top95(labels, probs):
23
  sorted_indices = torch.argsort(probs, descending=True)
 
25
  sorted_labels = [labels[i.item()] for i in sorted_indices]
26
 
27
  cumulative = torch.cumsum(sorted_probs, dim=0)
28
+ cutoff = torch.where(cumulative >= 0.8)[0]
29
  last_idx = cutoff[0].item() + 1 if len(cutoff) > 0 else len(sorted_probs)
30
 
31
  return list(zip(sorted_labels[:last_idx], sorted_probs[:last_idx].tolist()))
 
42
  if not title and not abstract:
43
  st.warning("Please enter at least the title.")
44
  else:
45
+ tokenizer, model, id2label, categories = load_model()
46
+
47
+ text = title + " - " + abstract if abstract else title
48
  inputs = tokenizer(text, return_tensors="pt", truncation=True)
49
  with torch.no_grad():
50
  outputs = model(**inputs)
 
52
 
53
  top_labels = get_top95(id2label, probs)
54
 
 
55
  for label, prob in top_labels:
56
+ print(f"- **{categories[label]} ({label})**: {prob * 100:.1f}%")