Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,13 +1,3 @@
|
|
| 1 |
-
# Upgrade pip to avoid dependency issues
|
| 2 |
-
pip install --upgrade pip
|
| 3 |
-
|
| 4 |
-
# Install PyTorch and related libraries
|
| 5 |
-
pip install torch torchvision torchaudio
|
| 6 |
-
|
| 7 |
-
# Verify installation
|
| 8 |
-
python -c "import torch; print(torch.__version__)"
|
| 9 |
-
|
| 10 |
-
|
| 11 |
import streamlit as st
|
| 12 |
import pandas as pd
|
| 13 |
import torch
|
|
@@ -15,19 +5,24 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
| 15 |
|
| 16 |
# Load the fine-tuned model
|
| 17 |
MODEL_NAME = "dinusha11/finetuned-distilbert-news"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 19 |
-
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
|
| 20 |
|
| 21 |
-
# Define label mapping
|
| 22 |
LABEL_MAPPING = {0: "Business", 1: "Opinion", 2: "Sports", 3: "Political_gossip", 4: "World_news"}
|
| 23 |
|
| 24 |
-
# Function to classify text
|
| 25 |
-
def
|
| 26 |
-
inputs = tokenizer(
|
|
|
|
| 27 |
with torch.no_grad():
|
| 28 |
outputs = model(**inputs)
|
| 29 |
-
|
| 30 |
-
return LABEL_MAPPING[
|
| 31 |
|
| 32 |
# Streamlit UI
|
| 33 |
st.title("News Classification App")
|
|
@@ -38,18 +33,19 @@ uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
|
|
| 38 |
|
| 39 |
if uploaded_file:
|
| 40 |
df = pd.read_csv(uploaded_file)
|
| 41 |
-
|
| 42 |
if "text" not in df.columns:
|
| 43 |
st.error("CSV must contain a 'text' column.")
|
| 44 |
else:
|
| 45 |
-
# Preprocess text
|
| 46 |
df["text"] = df["text"].fillna("").str.strip().str.lower()
|
| 47 |
-
|
| 48 |
-
# Apply classification
|
| 49 |
-
df["class"] = df["text"].
|
| 50 |
-
|
| 51 |
# Download output CSV
|
| 52 |
output_csv = df.to_csv(index=False).encode("utf-8")
|
| 53 |
st.download_button("Download Results", data=output_csv, file_name="output.csv", mime="text/csv")
|
| 54 |
|
| 55 |
st.write("Classification Complete! Download your file above.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import torch
|
|
|
|
| 5 |
|
| 6 |
# Load the fine-tuned model
|
| 7 |
MODEL_NAME = "dinusha11/finetuned-distilbert-news"
|
| 8 |
+
|
| 9 |
+
# Ensure model runs on CPU if GPU is unavailable
|
| 10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
+
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 13 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
|
| 14 |
|
| 15 |
+
# Define label mapping (Ensure this matches your model's label order)
|
| 16 |
LABEL_MAPPING = {0: "Business", 1: "Opinion", 2: "Sports", 3: "Political_gossip", 4: "World_news"}
|
| 17 |
|
| 18 |
+
# Function to classify a batch of text
|
| 19 |
+
def classify_texts(texts):
|
| 20 |
+
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 21 |
+
inputs = {key: val.to(device) for key, val in inputs.items()} # Move inputs to device
|
| 22 |
with torch.no_grad():
|
| 23 |
outputs = model(**inputs)
|
| 24 |
+
predicted_classes = torch.argmax(outputs.logits, dim=1).cpu().numpy()
|
| 25 |
+
return [LABEL_MAPPING[pred] for pred in predicted_classes]
|
| 26 |
|
| 27 |
# Streamlit UI
|
| 28 |
st.title("News Classification App")
|
|
|
|
| 33 |
|
| 34 |
if uploaded_file:
|
| 35 |
df = pd.read_csv(uploaded_file)
|
| 36 |
+
|
| 37 |
if "text" not in df.columns:
|
| 38 |
st.error("CSV must contain a 'text' column.")
|
| 39 |
else:
|
| 40 |
+
# Preprocess text (handle missing values, strip spaces, and convert to lowercase)
|
| 41 |
df["text"] = df["text"].fillna("").str.strip().str.lower()
|
| 42 |
+
|
| 43 |
+
# Apply batch classification
|
| 44 |
+
df["class"] = classify_texts(df["text"].tolist())
|
| 45 |
+
|
| 46 |
# Download output CSV
|
| 47 |
output_csv = df.to_csv(index=False).encode("utf-8")
|
| 48 |
st.download_button("Download Results", data=output_csv, file_name="output.csv", mime="text/csv")
|
| 49 |
|
| 50 |
st.write("Classification Complete! Download your file above.")
|
| 51 |
+
|