Ginidu2003 committed on
Commit
ab4f49e
·
verified ·
1 Parent(s): 8ef9c08

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +24 -23
src/streamlit_app.py CHANGED
@@ -8,31 +8,38 @@ from nltk.stem import WordNetLemmatizer
8
  import re
9
  import string
10
 
11
- import os
12
- from huggingface_hub import login
13
 
14
- hf_token = os.getenv("HF_TOKEN")
15
- if hf_token:
16
- login(hf_token)
17
 
18
- # ====================== PREPROCESSING (Same as Task 2) ======================
19
 
20
- # ====================== LOAD FINE-TUNED MODEL ======================
21
- @st.cache_resource
22
  def load_model():
23
- model_name = "Ginidu2003/Distilbert-Base-News-classifier" # ← Your exact model name
24
- return pipeline(
25
- "text-classification",
26
- model=model_name,
27
- device=0 if torch.cuda.is_available() else -1
28
- )
 
 
 
 
 
 
 
 
29
 
30
  classifier = load_model()
31
 
32
- # ====================== STREAMLIT APP ======================
33
  st.title("πŸ“° Daily Mirror News Classifier")
34
  st.subheader("Classify news into Business, Opinion, Political Gossip, Sports, or World News")
35
 
 
 
 
36
  st.markdown("**Upload a CSV file** with a column named `content`")
37
 
38
  uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
@@ -46,29 +53,23 @@ if uploaded_file is not None:
46
  if 'content' not in df.columns:
47
  st.error("Your CSV must have a column named 'content'")
48
  else:
49
- with st.spinner("Preprocessing and classifying..."):
50
- # Apply same preprocessing as Task 2
51
  #df['clean_content'] = df['content'].apply(preprocess_text)
52
 
53
- # Classify
54
  predictions = []
55
  for text in df['content']:
56
- if text.strip() == "":
57
  predictions.append("Unknown")
58
  else:
59
  result = classifier(text)[0]
60
  predictions.append(result['label'])
61
 
62
  df['class'] = predictions
63
-
64
- # Drop helper column
65
  #df = df.drop(columns=['clean_content'], errors='ignore')
66
 
67
  st.success("βœ… Classification completed!")
68
- st.write("### Preview of classified data")
69
  st.dataframe(df.head())
70
 
71
- # Download button
72
  csv = df.to_csv(index=False).encode('utf-8')
73
  st.download_button(
74
  label="πŸ“₯ Download output.csv",
 
8
  import re
9
  import string
10
 
11
+ st.set_page_config(page_title="Daily Mirror News Classifier", page_icon="πŸ“°")
 
12
 
13
+ # ====================== PREPROCESSING ======================
 
 
14
 
 
15
 
16
+ # ====================== LOAD MODEL (with better error handling) ======================
17
+ @st.cache_resource(show_spinner=False)
18
  def load_model():
19
+ model_name = "Ginidu2003/Distilbert-Base-News-classifier" # ← Make sure this is exact
20
+ try:
21
+ pipe = pipeline(
22
+ "text-classification",
23
+ model=model_name,
24
+ device=0 if torch.cuda.is_available() else -1
25
+ )
26
+ st.success(f"βœ… Model loaded successfully: {model_name}")
27
+ return pipe
28
+ except Exception as e:
29
+ st.error(f"❌ Failed to load model: {model_name}")
30
+ st.error(f"Error: {str(e)}")
31
+ st.info("Make sure the model is Public and the name is correct.")
32
+ return None
33
 
34
  classifier = load_model()
35
 
36
+ # ====================== APP ======================
37
  st.title("πŸ“° Daily Mirror News Classifier")
38
  st.subheader("Classify news into Business, Opinion, Political Gossip, Sports, or World News")
39
 
40
+ if classifier is None:
41
+ st.stop()
42
+
43
  st.markdown("**Upload a CSV file** with a column named `content`")
44
 
45
  uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
 
53
  if 'content' not in df.columns:
54
  st.error("Your CSV must have a column named 'content'")
55
  else:
56
+ with st.spinner("Classifying news..."):
 
57
  #df['clean_content'] = df['content'].apply(preprocess_text)
58
 
 
59
  predictions = []
60
  for text in df['content']:
61
+ if not text.strip():
62
  predictions.append("Unknown")
63
  else:
64
  result = classifier(text)[0]
65
  predictions.append(result['label'])
66
 
67
  df['class'] = predictions
 
 
68
  #df = df.drop(columns=['clean_content'], errors='ignore')
69
 
70
  st.success("βœ… Classification completed!")
 
71
  st.dataframe(df.head())
72
 
 
73
  csv = df.to_csv(index=False).encode('utf-8')
74
  st.download_button(
75
  label="πŸ“₯ Download output.csv",