| | import streamlit as st |
| | import torch |
| | import numpy as np |
| | import pandas as pd |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| |
|
# Ordinal search-volume buckets: model class index -> human-readable label.
BUCKETS = dict(enumerate([
    "Minimal",
    "Very Low",
    "Low",
    "Medium",
    "Medium-High",
    "High",
    "Very High",
]))
| |
|
# Hugging Face Hub repo id of the fine-tuned query-volume classifier.
MODEL_PATH = "dejanseo/query-volume-prediction"
| |
|
@st.cache_resource
def load_model():
    """Download (or reuse) the classifier and its tokenizer; cached per session.

    Returns:
        (model, tokenizer) with the model switched to eval mode.
    """
    clf = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    clf.eval()  # inference only: freeze dropout / batch-norm behavior
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    return clf, tok
| |
|
def predict_batch(queries: list[str], model, tokenizer):
    """Tokenize *queries* and return softmax class probabilities.

    Args:
        queries: batch of raw query strings.
        model: sequence-classification model producing `.logits`.
        tokenizer: matching Hugging Face tokenizer.

    Returns:
        numpy array of shape (len(queries), num_classes) with per-class
        probabilities for each query.
    """
    encoded = tokenizer(
        queries,
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.softmax(logits, dim=-1).numpy()
| |
|
# --- Page chrome -----------------------------------------------------------
st.set_page_config(page_title="Query Volume Predictor", layout="centered")

# Hide Streamlit's default hamburger menu via injected CSS.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large"
)

st.subheader("Query Volume Predictor")

# Load once; @st.cache_resource makes this a no-op on reruns.
model, tokenizer = load_model()

st.markdown("Enter one query per line:")

# Multi-line input; each non-empty line is treated as one query downstream.
queries_text = st.text_area(
    "Queries",
    placeholder="iphone case\nwireless earbuds\nleft handed scissors",
    height=200,
    label_visibility="collapsed"
)
| |
|
if st.button("Predict", type="primary"):
    # One query per non-empty line; surrounding whitespace is ignored.
    # splitlines() also copes with \r\n input, unlike split("\n").
    queries = [q.strip() for q in queries_text.splitlines() if q.strip()]

    if not queries:
        # Previously empty input failed silently; give explicit feedback.
        st.warning("Please enter at least one query.")
    else:
        with st.spinner(f"Processing {len(queries)} queries..."):
            all_probs = predict_batch(queries, model, tokenizer)

        results = []
        for query, probs in zip(queries, all_probs):
            # Cast to plain int: np.argmax returns np.int64.
            pred_class = int(np.argmax(probs))
            results.append({
                "query": query,
                "classification": BUCKETS[pred_class],
                # Map bucket 0..6 onto (0, 1] for the progress bar.
                "volume": (pred_class + 1) / 7,
                "confidence": float(probs[pred_class]) * 100
            })

        df = pd.DataFrame(results)

        st.dataframe(
            df[["query", "classification", "volume", "confidence"]],
            column_config={
                "query": st.column_config.TextColumn("Query"),
                "classification": st.column_config.TextColumn("Classification"),
                "volume": st.column_config.ProgressColumn(
                    "Volume",
                    format="",
                    min_value=0,
                    max_value=1
                ),
                "confidence": st.column_config.ProgressColumn(
                    "Confidence",
                    format="%.0f%%",
                    min_value=0,
                    max_value=100
                )
            },
            use_container_width=True,
            hide_index=True
        )

        # CSV export: drop the synthetic "volume" column and render
        # confidence as a percent string. (Removed a dead no-op that
        # re-assigned the columns to their existing names.)
        csv_df = df[["query", "classification", "confidence"]].copy()
        csv_df["confidence"] = csv_df["confidence"].apply(lambda x: f"{x:.0f}%")

        st.download_button(
            label="Download CSV",
            data=csv_df.to_csv(index=False),
            file_name="predictions.csv",
            mime="text/csv"
        )