# news_classification_UI / src/streamlit_app.py
# Streamlit app: batch news classification plus extractive QA over uploaded CSVs.
import re

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
# Streamlit page chrome: wide layout, collapsed sidebar, newspaper branding.
st.set_page_config(
    page_title="News Intelligence Studio",
    page_icon="📰",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Fine-tuned text-classification model pulled from the Hugging Face Hub.
MODEL_NAME = "Akilashamnaka12/news-classifier-model"
# Extractive question-answering model used over the article text.
QA_MODEL = "distilbert-base-cased-distilled-squad"
# Maximum number of CSV rows scanned as QA context candidates.
MAX_CONTEXT_ROWS = 8
def inject_styles() -> None:
    """Inject the app's custom CSS theme (editorial / newspaper look) into the page.

    Styles Streamlit's own containers plus the custom HTML sections rendered
    below (nav, hero, panels, metric cards, story cards, CTA, footer).
    Called once at startup; relies on ``unsafe_allow_html=True``.
    """
    st.markdown(
        """
        <style>
        /* Design tokens shared by all custom sections. */
        :root {
            --bg: #f5f4ef;
            --paper: #fbfaf6;
            --ink: #1a1a18;
            --muted: #6d6a63;
            --line: rgba(26,26,24,0.12);
            --soft: #ece8db;
            --accent: #121212;
            --gradient-a: rgba(127, 177, 183, 0.65);
            --gradient-b: rgba(30, 44, 58, 0.35);
            --gradient-c: rgba(220, 191, 151, 0.35);
        }
        .stApp {
            background: var(--bg);
            color: var(--ink);
        }
        .block-container {
            padding-top: 1.2rem;
            padding-bottom: 4rem;
            max-width: 1240px;
        }
        /* Hide Streamlit's default header and sidebar for a clean canvas. */
        header[data-testid="stHeader"] {
            background: transparent;
        }
        [data-testid="stSidebar"] {
            display: none;
        }
        div[data-testid="stFileUploaderDropzone"] {
            background: rgba(255,255,255,0.55);
            border: 1px dashed rgba(26,26,24,0.18);
            border-radius: 22px;
            min-height: 220px;
        }
        /* Top navigation strip. */
        .nav {
            display: flex;
            justify-content: space-between;
            align-items: center;
            padding: 0.4rem 0 1rem 0;
            border-bottom: 1px solid var(--line);
            margin-bottom: 1rem;
            font-size: 0.92rem;
            color: var(--muted);
        }
        .nav-links {
            display: flex;
            gap: 1.4rem;
            flex-wrap: wrap;
        }
        .brand {
            font-weight: 700;
            letter-spacing: 0.02em;
            color: var(--ink);
        }
        /* Hero banner with layered radial/linear gradients. */
        .hero {
            position: relative;
            overflow: hidden;
            border-radius: 30px;
            min-height: 460px;
            padding: 3.5rem 3rem;
            margin: 1rem 0 2rem 0;
            background:
                radial-gradient(circle at 20% 20%, rgba(255,255,255,0.62), transparent 32%),
                radial-gradient(circle at 75% 28%, rgba(11, 27, 42, 0.35), transparent 18%),
                radial-gradient(circle at 82% 82%, rgba(232, 196, 154, 0.5), transparent 20%),
                linear-gradient(135deg, var(--gradient-a), var(--gradient-b) 48%, var(--gradient-c));
            box-shadow: 0 10px 30px rgba(0,0,0,0.05);
        }
        .eyebrow {
            text-transform: uppercase;
            font-size: 0.78rem;
            letter-spacing: 0.18em;
            color: rgba(255,255,255,0.82);
            margin-bottom: 1rem;
        }
        .hero h1,
        .section-title {
            font-family: Georgia, 'Times New Roman', serif;
            font-weight: 400;
            letter-spacing: -0.02em;
        }
        .hero h1 {
            font-size: clamp(3rem, 7vw, 5.4rem);
            line-height: 0.92;
            color: #fffdf8;
            max-width: 700px;
            margin: 0;
        }
        .hero p {
            max-width: 520px;
            color: rgba(255,255,255,0.84);
            font-size: 1.03rem;
            line-height: 1.7;
            margin-top: 1rem;
        }
        .hero-chip-row {
            display: flex;
            gap: 0.7rem;
            flex-wrap: wrap;
            margin-top: 1.6rem;
        }
        .chip {
            border: 1px solid rgba(255,255,255,0.24);
            background: rgba(255,255,255,0.12);
            color: white;
            padding: 0.62rem 0.9rem;
            border-radius: 999px;
            font-size: 0.84rem;
            backdrop-filter: blur(6px);
        }
        /* Shared rounded-card base for panels and cards. */
        .panel,
        .soft-panel,
        .metric-card,
        .story-card {
            border-radius: 26px;
            overflow: hidden;
        }
        .panel {
            background: rgba(255,255,255,0.5);
            border: 1px solid rgba(26,26,24,0.08);
            padding: 1.25rem;
        }
        .soft-panel {
            background: #e7e1cf;
            border: 1px solid rgba(26,26,24,0.06);
            padding: 1.5rem;
        }
        .section-head {
            display: flex;
            justify-content: space-between;
            gap: 1rem;
            align-items: end;
            margin: 2.5rem 0 1.2rem 0;
        }
        .section-title {
            font-size: clamp(1.8rem, 3vw, 3rem);
            line-height: 1;
            margin: 0;
        }
        .section-copy {
            max-width: 520px;
            color: var(--muted);
            font-size: 0.96rem;
            line-height: 1.7;
        }
        /* Four-up metric cards (responsive: 2-up then 1-up below). */
        .metric-grid {
            display: grid;
            grid-template-columns: repeat(4, minmax(0,1fr));
            gap: 1rem;
            margin-top: 1rem;
        }
        .metric-card {
            background: #f7f5ee;
            border: 1px solid rgba(26,26,24,0.08);
            padding: 1rem 1.1rem;
            min-height: 130px;
        }
        .metric-label {
            color: var(--muted);
            font-size: 0.84rem;
            margin-bottom: 1rem;
        }
        .metric-value {
            font-size: 2rem;
            line-height: 1;
            margin-bottom: 0.35rem;
            font-weight: 600;
        }
        .story-card {
            position: relative;
            min-height: 170px;
            padding: 1.2rem;
            color: #fffaf2;
            background:
                linear-gradient(180deg, rgba(0,0,0,0.05), rgba(0,0,0,0.55)),
                linear-gradient(135deg, rgba(48,93,112,0.8), rgba(24,24,24,0.75), rgba(176,103,77,0.65));
            border: 1px solid rgba(255,255,255,0.08);
        }
        .story-card h4 {
            margin: 0;
            font-size: 1.2rem;
            line-height: 1.2;
            font-family: Georgia, 'Times New Roman', serif;
            font-weight: 400;
        }
        .story-card p {
            font-size: 0.9rem;
            color: rgba(255,255,255,0.82);
            line-height: 1.6;
            margin-top: 0.8rem;
        }
        /* Call-to-action block shown before any CSV is uploaded. */
        .cta {
            text-align: center;
            padding: 3rem 2rem;
            margin-top: 2rem;
            border-radius: 28px;
            background: #dfe7d7;
            border: 1px solid rgba(26,26,24,0.08);
        }
        .cta h2 {
            font-family: Georgia, 'Times New Roman', serif;
            font-size: clamp(2rem, 4vw, 3.2rem);
            font-weight: 400;
            margin: 0 0 0.7rem 0;
        }
        .foot {
            border-top: 1px solid var(--line);
            margin-top: 2.5rem;
            padding-top: 1.2rem;
            color: var(--muted);
            font-size: 0.88rem;
        }
        /* Force readable label colors over Streamlit's theme defaults. */
        label, .stRadio label, .stCaption, [data-testid="stCaptionContainer"] {
            color: #4f4b45 !important;
            opacity: 1 !important;
        }
        [data-testid="stMarkdownContainer"] p {
            color: #4f4b45;
        }
        div[role="radiogroup"] label {
            color: #2b2925 !important;
            font-weight: 500;
        }
        @media (max-width: 900px) {
            .hero {
                min-height: 360px;
                padding: 2rem 1.4rem;
            }
            .metric-grid {
                grid-template-columns: repeat(2, minmax(0,1fr));
            }
        }
        @media (max-width: 640px) {
            .metric-grid {
                grid-template-columns: 1fr;
            }
            .section-head {
                flex-direction: column;
                align-items: start;
            }
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Load and cache the classification pipeline and an extractive QA callable.

    Returns:
        tuple: ``(classifier, qa_fn)`` where ``classifier`` is a Hugging Face
        text-classification pipeline and ``qa_fn(question, context)`` returns
        ``{"answer": str, "score": float}``.
    """
    classifier = pipeline(
        "text-classification",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        truncation=True,
    )

    # Load QA components manually instead of pipeline("question-answering")
    # so truncation and span scoring stay under our control.
    tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
    model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL)
    model.eval()  # inference only: disable dropout

    def qa_fn(question, context):
        inputs = tokenizer(
            question, context, return_tensors="pt", truncation=True, max_length=512
        )
        with torch.no_grad():
            outputs = model(**inputs)
        start_probs = outputs.start_logits.softmax(dim=-1)[0]
        end_probs = outputs.end_logits.softmax(dim=-1)[0]
        start = int(start_probs.argmax())
        # BUG FIX: an unconstrained end argmax can land before `start`,
        # producing an empty answer span. Restrict the end search to
        # positions >= start so the decoded span is always non-empty.
        end = start + int(end_probs[start:].argmax()) + 1
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start:end])
        ).strip()
        # Joint span confidence: product of the start and end probabilities
        # (previously only the start probability was used).
        score = float(start_probs[start] * end_probs[end - 1])
        return {"answer": answer, "score": score}

    return classifier, qa_fn
def preprocess_text(text: str) -> str:
    """Normalise raw article text: strip URLs, then collapse whitespace runs.

    Accepts any value (coerced to ``str``) and always returns a trimmed string.
    """
    without_urls = re.sub(r"http\S+|www\.\S+", " ", str(text))
    collapsed = re.sub(r"\s+", " ", without_urls)
    return collapsed.strip()
def get_text_column(df: pd.DataFrame) -> str:
    """Return the name of the article-text column in *df*.

    Matching is case-insensitive and tolerates stray surrounding whitespace
    in the header (e.g. ``"Content "``, a common CSV-export artefact).

    Raises:
        ValueError: if no 'content' column is present.
    """
    normalized = {str(col).strip().lower(): col for col in df.columns}
    if "content" in normalized:
        return normalized["content"]
    raise ValueError("CSV must contain a 'content' column.")
def predict_classes(df: pd.DataFrame, text_col: str, classifier):
    """Classify one text column of *df* row by row.

    Returns a tuple ``(texts, labels, scores)``: the cleaned texts, the
    predicted labels, and confidence scores rounded to 4 decimals — all
    aligned by row.
    """
    cleaned = [preprocess_text(value) for value in df[text_col].fillna("").astype(str)]
    predictions = classifier(cleaned, batch_size=16)
    labels = []
    scores = []
    for prediction in predictions:
        labels.append(prediction.get("label", "Unknown"))
        scores.append(round(float(prediction.get("score", 0.0)), 4))
    return cleaned, labels, scores
def dataframe_to_csv_bytes(df: pd.DataFrame) -> bytes:
    """Serialise *df* (without its index) to UTF-8 encoded CSV bytes."""
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, "utf-8")
# ---- App bootstrap: apply the CSS theme, then warm up the cached models. ----
inject_styles()
with st.spinner("Loading models..."):
    classifier, qa_pipeline = load_pipelines()
st.markdown(
"""
<div class="nav">
<div class="brand">News Intelligence Studio</div>
<div class="nav-links">
<span>Classification</span>
<span>Question Answering</span>
<span>Insights</span>
<span>Local Streamlit</span>
</div>
</div>
""",
unsafe_allow_html=True,
)
st.markdown(
"""
<section class="hero">
<div class="eyebrow">Powered by Hugging Face</div>
<h1>Intelligence that reads your news operations</h1>
<p>
Upload a CSV, classify every news excerpt with your fine-tuned model,
explore the predicted distribution, and ask grounded questions from the
article content in one polished Streamlit workspace.
</p>
<div class="hero-chip-row">
<div class="chip">Model: Akilashamnaka12/news-classifier-model</div>
<div class="chip">CSV in → output.csv out</div>
<div class="chip">Local-first Streamlit experience</div>
</div>
</section>
""",
unsafe_allow_html=True,
)
# ---- Two-column layout: CSV uploader on the left, QA panel on the right. ----
left, right = st.columns([1.15, 0.85], gap="large")

# Defaults so later code can safely reference these before any interaction.
uploaded_file = None
question = ""
context_mode = "Use first few records"
# NOTE(review): this placeholder is never written to afterwards — confirm before removing.
answer_box = right.empty()

with left:
    st.markdown('<div class="panel">', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    st.caption("Expected column: content")
    st.markdown("</div>", unsafe_allow_html=True)

# Classification state shared by the QA panel and the results section below.
result_df = None
filtered_df = None
selected_class = "All"
text_col = None
# ---- Classification: parse the CSV and predict a class per row. ----
if uploaded_file is not None:
    try:
        raw_df = pd.read_csv(uploaded_file)
        text_col = get_text_column(raw_df)
        texts, labels, scores = predict_classes(raw_df.copy(), text_col, classifier)
        # Keep the original columns; overwrite the text column with its
        # cleaned form and append the predicted label and confidence.
        result_df = raw_df.copy()
        result_df[text_col] = texts
        result_df["class"] = labels
        result_df["confidence"] = scores
        classes = sorted(result_df["class"].dropna().unique().tolist())
        selected_class = left.selectbox("Filter predictions", ["All"] + classes, index=0)
        filtered_df = (
            result_df
            if selected_class == "All"
            else result_df[result_df["class"] == selected_class]
        )
    except Exception as exc:
        # Surface parsing/model errors in the UI rather than crashing the app.
        st.error(f"Could not process the file: {exc}")
# ---- Question answering: scan the first rows and keep the best-scoring span. ----
with right:
    st.markdown('<div class="panel">', unsafe_allow_html=True)
    st.subheader("Ask questions from the uploaded news")
    question = st.text_input("Type your question")
    st.caption("Ask things like: What happened in sports? What caused flooding in Colombo?")
    context_mode = st.radio(
        "Context source",
        ["Use first few records", "Use selected class only"],
        horizontal=True,
    )
    if uploaded_file is not None and result_df is not None and question:
        try:
            qa_source_df = result_df.copy()
            # Optionally narrow the context to the class chosen in the left panel.
            if context_mode == "Use selected class only" and selected_class not in (None, "All"):
                qa_source_df = qa_source_df[qa_source_df["class"] == selected_class]
            # Only the first MAX_CONTEXT_ROWS non-blank rows are considered.
            candidate_rows = qa_source_df[text_col].fillna("").astype(str).head(MAX_CONTEXT_ROWS).tolist()
            candidate_rows = [row for row in candidate_rows if row.strip()]
            if candidate_rows:
                # Run QA per row and keep the highest-confidence answer.
                best_answer = None
                best_score = -1.0
                best_context = ""
                for row_text in candidate_rows:
                    result = qa_pipeline(
                        question=question,
                        context=row_text
                    )
                    score = float(result.get("score", 0.0))
                    if score > best_score:
                        best_score = score
                        best_answer = result.get("answer", "No answer found.")
                        best_context = row_text
                st.markdown("---")
                st.markdown("### Answer")
                st.success(best_answer)
                st.caption(f"Confidence: {best_score:.4f}")
                with st.expander("Show context used"):
                    st.write(best_context)
            else:
                st.warning("No usable context found.")
        except Exception as e:
            st.error(f"Error generating answer: {e}")
    st.markdown("</div>", unsafe_allow_html=True)
# ---- Results: metrics, charts, download, feature cards, and records table.
#      Falls back to a call-to-action card when no CSV has been processed. ----
if result_df is not None:
    st.markdown(
        """
        <div class="section-head">
          <div>
            <div class="section-title">Continuously test and explore output</div>
          </div>
          <div class="section-copy">
            Once a file is uploaded, the app predicts a class for each row,
            adds a confidence score, and prepares an exportable output.csv.
          </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    # Headline metrics; guard against an empty frame before mode()/mean().
    top_class = result_df["class"].mode().iat[0] if not result_df.empty else "N/A"
    avg_conf = f"{result_df['confidence'].mean():.2%}" if not result_df.empty else "0%"
    st.markdown(
        f"""
        <div class="metric-grid">
          <div class="metric-card">
            <div class="metric-label">Uploaded records</div>
            <div class="metric-value">{len(result_df)}</div>
            <div>Rows processed from your CSV</div>
          </div>
          <div class="metric-card">
            <div class="metric-label">Detected classes</div>
            <div class="metric-value">{result_df['class'].nunique()}</div>
            <div>Unique labels predicted by the model</div>
          </div>
          <div class="metric-card">
            <div class="metric-label">Top predicted class</div>
            <div class="metric-value">{top_class}</div>
            <div>Most frequent label in the batch</div>
          </div>
          <div class="metric-card">
            <div class="metric-label">Average confidence</div>
            <div class="metric-value">{avg_conf}</div>
            <div>Mean prediction confidence score</div>
          </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    # Distribution chart (left) and CSV export (right).
    col_a, col_b = st.columns([1.05, 0.95], gap="large")
    with col_a:
        st.markdown('<div class="soft-panel">', unsafe_allow_html=True)
        st.subheader("Predicted class distribution")
        st.bar_chart(result_df["class"].value_counts())
        st.markdown("</div>", unsafe_allow_html=True)
    with col_b:
        st.markdown('<div class="soft-panel">', unsafe_allow_html=True)
        st.subheader("Download ready")
        st.write(
            "Your exported file includes the original columns, the predicted class, and the confidence score."
        )
        st.download_button(
            label="Download output.csv",
            data=dataframe_to_csv_bytes(result_df),
            file_name="output.csv",
            mime="text/csv",
            use_container_width=True,
        )
        st.markdown("</div>", unsafe_allow_html=True)
    st.markdown(
        """
        <div class="section-head">
          <div>
            <div class="section-title">Built for the real world</div>
          </div>
          <div class="section-copy">
            Below are presentation-friendly feature cards. They help your app
            feel more like a polished product during the live demo.
          </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    # Static marketing cards; purely presentational.
    story_cols = st.columns(4, gap="small")
    stories = [
        (
            "Scalable batch classification",
            "Upload larger CSV files and label each record in a single flow.",
        ),
        (
            "Grounded question answering",
            "Ask focused questions using article content as context.",
        ),
        (
            "Confidence-aware review",
            "Inspect how certain the model is before exporting the final sheet.",
        ),
        (
            "Presentation-ready interface",
            "A clean editorial design that feels stronger than a default dashboard.",
        ),
    ]
    # NOTE(review): 'copy' shadows the stdlib module name (not imported here).
    for col, (title, copy) in zip(story_cols, stories):
        with col:
            st.markdown(
                f'<div class="story-card"><h4>{title}</h4><p>{copy}</p></div>',
                unsafe_allow_html=True,
            )
    st.markdown(
        """
        <div class="section-head">
          <div>
            <div class="section-title">Records</div>
          </div>
          <div class="section-copy">
            Review the classified rows before downloading the final output.
          </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    # Shows the class-filtered view chosen via the left panel's selectbox.
    st.dataframe(filtered_df, use_container_width=True, height=360)
else:
    st.markdown(
        """
        <div class="cta">
          <h2>Intelligence that runs your news workflow</h2>
          <p>Upload a CSV to activate classification, analytics, downloadable results, and grounded Q&A.</p>
        </div>
        """,
        unsafe_allow_html=True,
    )
st.markdown(
"""
<div class="foot">
Local run command: <code>python -m streamlit run app.py</code><br>
Make sure your CSV contains a <code>content</code> column, and keep the preprocessing function aligned with Section 01.
</div>
""",
unsafe_allow_html=True,
)