dinusha11 commited on
Commit
1eee1d7
·
verified ·
1 Parent(s): 365b023

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -1,13 +1,3 @@
1
- # Upgrade pip to avoid dependency issues
2
- pip install --upgrade pip
3
-
4
- # Install PyTorch and related libraries
5
- pip install torch torchvision torchaudio
6
-
7
- # Verify installation
8
- python -c "import torch; print(torch.__version__)"
9
-
10
-
11
  import streamlit as st
12
  import pandas as pd
13
  import torch
@@ -15,19 +5,24 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
 
16
  # Load the fine-tuned model
17
  MODEL_NAME = "dinusha11/finetuned-distilbert-news"
 
 
 
 
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
20
 
21
- # Define label mapping
22
  LABEL_MAPPING = {0: "Business", 1: "Opinion", 2: "Sports", 3: "Political_gossip", 4: "World_news"}
23
 
24
- # Function to classify text
25
- def classify_text(text):
26
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
 
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
- predicted_class = torch.argmax(outputs.logits, dim=1).item()
30
- return LABEL_MAPPING[predicted_class]
31
 
32
  # Streamlit UI
33
  st.title("News Classification App")
@@ -38,18 +33,19 @@ uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
38
 
39
  if uploaded_file:
40
  df = pd.read_csv(uploaded_file)
41
-
42
  if "text" not in df.columns:
43
  st.error("CSV must contain a 'text' column.")
44
  else:
45
- # Preprocess text
46
  df["text"] = df["text"].fillna("").str.strip().str.lower()
47
-
48
- # Apply classification
49
- df["class"] = df["text"].apply(classify_text)
50
-
51
  # Download output CSV
52
  output_csv = df.to_csv(index=False).encode("utf-8")
53
  st.download_button("Download Results", data=output_csv, file_name="output.csv", mime="text/csv")
54
 
55
  st.write("Classification Complete! Download your file above.")
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
 
5
 
6
  # Load the fine-tuned model
7
  MODEL_NAME = "dinusha11/finetuned-distilbert-news"
8
+
9
+ # Ensure model runs on CPU if GPU is unavailable
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
14
 
15
+ # Define label mapping (Ensure this matches your model's label order)
16
  LABEL_MAPPING = {0: "Business", 1: "Opinion", 2: "Sports", 3: "Political_gossip", 4: "World_news"}
17
 
18
+ # Function to classify a batch of text
19
+ def classify_texts(texts):
20
+ inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
21
+ inputs = {key: val.to(device) for key, val in inputs.items()} # Move inputs to device
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
+ predicted_classes = torch.argmax(outputs.logits, dim=1).cpu().numpy()
25
+ return [LABEL_MAPPING[pred] for pred in predicted_classes]
26
 
27
  # Streamlit UI
28
  st.title("News Classification App")
 
33
 
34
  if uploaded_file:
35
  df = pd.read_csv(uploaded_file)
36
+
37
  if "text" not in df.columns:
38
  st.error("CSV must contain a 'text' column.")
39
  else:
40
+ # Preprocess text (handle missing values, strip spaces, and convert to lowercase)
41
  df["text"] = df["text"].fillna("").str.strip().str.lower()
42
+
43
+ # Apply batch classification
44
+ df["class"] = classify_texts(df["text"].tolist())
45
+
46
  # Download output CSV
47
  output_csv = df.to_csv(index=False).encode("utf-8")
48
  st.download_button("Download Results", data=output_csv, file_name="output.csv", mime="text/csv")
49
 
50
  st.write("Classification Complete! Download your file above.")
51
+