Ginidu2003 commited on
Commit
6472126
Β·
verified Β·
1 Parent(s): 191b0d0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +13 -25
src/streamlit_app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import torch
4
  from transformers import pipeline
5
  import nltk
6
  from nltk.corpus import stopwords
@@ -8,8 +8,6 @@ from nltk.stem import WordNetLemmatizer
8
  import re
9
  import string
10
 
11
- st.set_page_config(page_title="Daily Mirror News Classifier", page_icon="πŸ“°")
12
-
13
  # ====================== PREPROCESSING ======================
14
  nltk.download('stopwords', quiet=True)
15
  nltk.download('wordnet', quiet=True)
@@ -30,31 +28,18 @@ def preprocess_text(text):
30
  return ' '.join(tokens)
31
 
32
  # ====================== LOAD MODEL ======================
33
- @st.cache_resource(show_spinner=False)
34
  def load_model():
35
- model_name = "Ginidu2003/Distilbert-Base-News-classifier"
36
- hf_token = st.secrets.get("HF_TOKEN") # Reads the secret you added
37
-
38
- try:
39
- pipe = pipeline(
40
- "text-classification",
41
- model=model_name,
42
- token=hf_token, # ← This fixes most 403 errors
43
- device=0 if torch.cuda.is_available() else -1
44
- )
45
- st.success("βœ… Model loaded successfully!")
46
- return pipe
47
- except Exception as e:
48
- st.error("❌ Failed to load model")
49
- st.error(str(e))
50
- return None
51
 
52
  classifier = load_model()
53
 
54
- if classifier is None:
55
- st.stop()
56
-
57
- # ====================== APP ======================
58
  st.title("πŸ“° Daily Mirror News Classifier")
59
  st.subheader("Classify news into Business, Opinion, Political Gossip, Sports, or World News")
60
 
@@ -64,6 +49,7 @@ uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
64
 
65
  if uploaded_file is not None:
66
  df = pd.read_csv(uploaded_file)
 
67
  st.write("### Preview of uploaded data")
68
  st.dataframe(df.head())
69
 
@@ -75,7 +61,7 @@ if uploaded_file is not None:
75
 
76
  predictions = []
77
  for text in df['clean_content']:
78
- if not text.strip():
79
  predictions.append("Unknown")
80
  else:
81
  result = classifier(text)[0]
@@ -85,8 +71,10 @@ if uploaded_file is not None:
85
  df = df.drop(columns=['clean_content'], errors='ignore')
86
 
87
  st.success("βœ… Classification completed!")
 
88
  st.dataframe(df.head())
89
 
 
90
  csv = df.to_csv(index=False).encode('utf-8')
91
  st.download_button(
92
  label="πŸ“₯ Download output.csv",
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import torch # ← This was missing
4
  from transformers import pipeline
5
  import nltk
6
  from nltk.corpus import stopwords
 
8
  import re
9
  import string
10
 
 
 
11
  # ====================== PREPROCESSING ======================
12
  nltk.download('stopwords', quiet=True)
13
  nltk.download('wordnet', quiet=True)
 
28
  return ' '.join(tokens)
29
 
30
  # ====================== LOAD MODEL ======================
31
+ @st.cache_resource
32
  def load_model():
33
+ model_name = "Ginidu2003/Distilbert-Base-News-classifier" # ← Change if your model name is different
34
+ return pipeline(
35
+ "text-classification",
36
+ model=model_name,
37
+ device=0 if torch.cuda.is_available() else -1
38
+ )
 
 
 
 
 
 
 
 
 
 
39
 
40
  classifier = load_model()
41
 
42
+ # ====================== STREAMLIT APP ======================
 
 
 
43
  st.title("πŸ“° Daily Mirror News Classifier")
44
  st.subheader("Classify news into Business, Opinion, Political Gossip, Sports, or World News")
45
 
 
49
 
50
  if uploaded_file is not None:
51
  df = pd.read_csv(uploaded_file)
52
+
53
  st.write("### Preview of uploaded data")
54
  st.dataframe(df.head())
55
 
 
61
 
62
  predictions = []
63
  for text in df['clean_content']:
64
+ if text.strip() == "":
65
  predictions.append("Unknown")
66
  else:
67
  result = classifier(text)[0]
 
71
  df = df.drop(columns=['clean_content'], errors='ignore')
72
 
73
  st.success("βœ… Classification completed!")
74
+ st.write("### Preview of classified data")
75
  st.dataframe(df.head())
76
 
77
+ # Download button
78
  csv = df.to_csv(index=False).encode('utf-8')
79
  st.download_button(
80
  label="πŸ“₯ Download output.csv",