# query-volume-prediction / src/streamlit_app.py
# Author: dejanseo
# Last commit: "Update src/streamlit_app.py" (58b6626, verified)
import streamlit as st
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Human-readable label for each class index the classifier can predict,
# ordered from the lowest (0) to the highest (6) search-volume bucket.
BUCKETS = dict(
    enumerate(
        [
            "Minimal",
            "Very Low",
            "Low",
            "Medium",
            "Medium-High",
            "High",
            "Very High",
        ]
    )
)

# Identifier handed to `from_pretrained` for both the model and the tokenizer
# (a Hugging Face Hub repo id in "owner/name" form).
MODEL_PATH = "dejanseo/query-volume-prediction"
@st.cache_resource
def load_model():
    """Load the sequence classifier and its tokenizer from MODEL_PATH.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded
    objects across script reruns instead of reloading the weights each time.

    Returns:
        tuple: ``(model, tokenizer)``; the model is already in eval mode
        (dropout and similar training-only layers disabled).
    """
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    # nn.Module.eval() returns the module itself, so it can be chained here.
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).eval()
    return classifier, tok
def predict_batch(
    queries: list[str],
    model,
    tokenizer,
    batch_size: int = 64,
    max_length: int = 32,
) -> np.ndarray:
    """Return softmax class probabilities for every query.

    Queries are tokenized and scored in chunks of at most ``batch_size`` so a
    very large paste does not build one huge tensor in memory; the
    concatenated result is the same as scoring them all at once.

    Args:
        queries: Query strings to classify; output rows follow this order.
        model: Sequence-classification model whose output exposes ``.logits``
            (assumes a Hugging Face model with ``config.num_labels``).
        tokenizer: Tokenizer matching *model*.
        batch_size: Maximum number of queries sent through the model at once.
        max_length: Token length every query is truncated / padded to.

    Returns:
        Array of shape ``(len(queries), num_labels)``; each row sums to 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not queries:
        # Avoid running the tokenizer/model on an empty batch; return a
        # correctly shaped empty result instead.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(queries), batch_size):
            inputs = tokenizer(
                queries[start:start + batch_size],
                # Fixed-length padding keeps every chunk the same shape as
                # the original single-batch call produced.
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            logits = model(**inputs).logits
            # .cpu() makes .numpy() safe even if the model lives on a GPU.
            chunks.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks, axis=0)
# Page configuration is kept as the very first Streamlit call.
st.set_page_config(layout="centered", page_title="Query Volume Predictor")

# CSS injected into the page to hide Streamlit's "#MainMenu" element.
hide_streamlit_style = "\n<style>\n#MainMenu {visibility: hidden;}\n</style>\n"
st.markdown(body=hide_streamlit_style, unsafe_allow_html=True)

st.logo(
    "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large",
)
st.subheader("Query Volume Predictor")

# load_model is wrapped in st.cache_resource, so reruns reuse the same objects.
model, tokenizer = load_model()

st.markdown("Enter one query per line:")
queries_text = st.text_area(
    label="Queries",
    placeholder="\n".join(["iphone case", "wireless earbuds", "left handed scissors"]),
    height=200,
    label_visibility="collapsed",
)
# Session-state key holding (input text, results DataFrame) from the most
# recent successful prediction.
_RESULTS_KEY = "query_volume_results"


def _parse_queries(text: str) -> list[str]:
    """Return the non-blank lines of *text*, each stripped of surrounding whitespace."""
    # splitlines() also copes with "\r\n" and other line separators that show
    # up when pasting from spreadsheets or Windows editors.
    return [line.strip() for line in text.splitlines() if line.strip()]


def _build_results_frame(queries: list[str], all_probs: np.ndarray) -> pd.DataFrame:
    """Turn per-query class probabilities into the results table.

    Args:
        queries: Scored queries, in the same order as the rows of *all_probs*.
        all_probs: Class-probability matrix with one row per query; expected
            to have one column per BUCKETS entry.

    Returns:
        DataFrame with columns ``query``, ``classification`` (bucket label),
        ``volume`` (bucket index mapped onto (0, 1] for the progress bar) and
        ``confidence`` (probability of the winning bucket, in percent).
    """
    pred_classes = all_probs.argmax(axis=1)
    top_probs = all_probs[np.arange(len(queries)), pred_classes].astype(float)
    return pd.DataFrame(
        {
            "query": queries,
            "classification": [BUCKETS[int(c)] for c in pred_classes],
            # Scale by the number of buckets rather than a hard-coded 7 so the
            # bar stays correct if the label set changes.
            "volume": (pred_classes + 1) / len(BUCKETS),
            "confidence": top_probs * 100,
        }
    )


def _render_results(df: pd.DataFrame) -> None:
    """Display the results table and offer it as a CSV download."""
    st.dataframe(
        df[["query", "classification", "volume", "confidence"]],
        column_config={
            "query": st.column_config.TextColumn("Query"),
            "classification": st.column_config.TextColumn("Classification"),
            "volume": st.column_config.ProgressColumn(
                "Volume",
                format="",
                min_value=0,
                max_value=1,
            ),
            "confidence": st.column_config.ProgressColumn(
                "Confidence",
                format="%.0f%%",
                min_value=0,
                max_value=100,
            ),
        },
        use_container_width=True,
        hide_index=True,
    )
    # The CSV leaves out the derived volume bar and renders confidence as "NN%".
    csv_df = df[["query", "classification", "confidence"]].copy()
    csv_df["confidence"] = csv_df["confidence"].map(lambda x: f"{x:.0f}%")
    st.download_button(
        label="Download CSV",
        data=csv_df.to_csv(index=False),
        file_name="predictions.csv",
        mime="text/csv",
    )


if st.button("Predict", type="primary"):
    queries = _parse_queries(queries_text)
    if queries:
        with st.spinner(f"Processing {len(queries)} queries..."):
            all_probs = predict_batch(queries, model, tokenizer)
        st.session_state[_RESULTS_KEY] = (
            queries_text,
            _build_results_frame(queries, all_probs),
        )
    else:
        # Give feedback instead of silently rendering nothing.
        st.session_state.pop(_RESULTS_KEY, None)
        st.warning("Please enter at least one query.")

# Streamlit reruns this script on every widget interaction, including a click
# on the download button, and st.button() is True only on the run triggered by
# its own click. Persisting the results in session_state keeps the table on
# screen after a download; they are shown only while the input text still
# matches what was scored, so edited text never displays stale predictions.
_saved = st.session_state.get(_RESULTS_KEY)
if _saved is not None and _saved[0] == queries_text:
    _render_results(_saved[1])