| import streamlit as st |
| import torch |
| import pandas as pd |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| MODEL_ID = "dejanseo/ecommerce-query-volume-classifier" |
|
|
| LABELS = ["very_low", "low", "medium", "high", "very_high"] |
|
|
| DISPLAY_LABELS = { |
| "very_high": "Very High", |
| "high": "High", |
| "medium": "Medium", |
| "low": "Low", |
| "very_low": "Very Low", |
| } |
|
|
| VOLUME_RANGES = { |
| "very_high": "10,000+", |
| "high": "1,000β9,999", |
| "medium": "100β999", |
| "low": "10β99", |
| "very_low": "<10", |
| } |
|
|
| CLASS_COLORS = { |
| "Very High": "background-color: rgba(40, 167, 69, 0.1); color: #28a745;", |
| "High": "background-color: rgba(113, 194, 133, 0.1); color: #71c285;", |
| "Medium": "background-color: rgba(255, 193, 7, 0.1); color: #c79100;", |
| "Low": "background-color: rgba(253, 126, 20, 0.1); color: #fd7e14;", |
| "Very Low": "background-color: rgba(220, 53, 69, 0.1); color: #dc3545;", |
| } |
|
|
|
|
@st.cache_resource
def load_model():
    """Fetch the tokenizer/model pair from the Hugging Face Hub.

    Decorated with ``st.cache_resource`` so the weights are loaded once per
    server process and shared across sessions and reruns.

    Returns:
        tuple: ``(tokenizer, model)`` with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    clf = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()
    return tok, clf
|
|
|
|
def predict(queries, tokenizer, model):
    """Classify a batch of search queries into volume buckets.

    Args:
        queries: list of query strings.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: sequence-classification model whose logit order matches LABELS.

    Returns:
        list[dict]: one record per query with keys ``"query"``, ``"label"``
        (entry of LABELS), ``"confidence"`` (probability of the top label),
        and ``"distribution"`` (label -> probability for all classes).
    """
    encoded = tokenizer(
        queries,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=32,  # queries are short; 32 tokens is plenty
    )
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        probabilities = torch.softmax(model(**encoded).logits, dim=-1)

    records = []
    for query, row in zip(queries, probabilities):
        top = int(torch.argmax(row).item())
        records.append(
            {
                "query": query,
                "label": LABELS[top],
                "confidence": row[top].item(),
                "distribution": dict(zip(LABELS, (p.item() for p in row))),
            }
        )
    return records
|
|
|
|
def style_class_cells(val):
    """Styler callback: return the inline CSS for a "Class" cell value.

    Unknown values get no styling (empty string).
    """
    try:
        return CLASS_COLORS[val]
    except KeyError:
        return ""
|
|
|
|
# --- Page setup -------------------------------------------------------------

# FIX: page_icon was mojibake'd to "π" in the pasted source — 0xF0 is pi in
# ISO-8859-7 and the first UTF-8 byte of the bar-chart emoji; restore "📊".
st.set_page_config(page_title="Query Volume Classifier", page_icon="📊", layout="wide")

# Global font override: load Montserrat and force it onto all Streamlit widgets.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;500;600;700&display=swap');
html, body, [class*="css"], .stMarkdown, .stTextInput, .stTextArea,
.stMetric, .stDataFrame, h1, h2, h3, h4, h5, h6, p, span, div, label {
    font-family: 'Montserrat', sans-serif !important;
    font-weight: 500;
}
</style>
""", unsafe_allow_html=True)

# Branding logo in the sidebar/header, linking back to the company site.
st.logo(
    "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large",
)

st.subheader("eCommerce Query Volume Classifier")
st.caption("Predicts search volume class for product queries.")

# Loaded once per process thanks to @st.cache_resource on load_model().
tokenizer, model = load_model()

tab_single, tab_batch = st.tabs(["Single Query", "Batch"])
|
|
with tab_single:
    query = st.text_input("Enter a product search query", placeholder="e.g. wireless mouse, airpods, replacement gasket for instant pot")
    # Only run when the button is pressed AND the input is non-blank.
    if st.button("Classify", key="single") and query.strip():
        result = predict([query.strip()], tokenizer, model)[0]
        label = result["label"]
        display_label = DISPLAY_LABELS[label]
        volume_range = VOLUME_RANGES[label]

        col1, col2, col3 = st.columns(3)
        col1.metric("Predicted Class", display_label)
        col2.metric("Estimated Sessions", volume_range)
        col3.metric("Confidence", f"{result['confidence']:.1%}")

        st.markdown("**Class probabilities**")
        # Highest class first. Uses reversed(LABELS) instead of a duplicated
        # hard-coded list so the order stays in sync with the module constant.
        for lbl in reversed(LABELS):
            prob = result["distribution"][lbl]
            # FIX: the separator had been mojibake'd to "β"; restored to "–".
            st.progress(prob, text=f"{DISPLAY_LABELS[lbl]} – {prob:.1%}")
|
|
with tab_batch:
    st.markdown("Enter one query per line:")
    text = st.text_area("Queries", height=200, placeholder="laptop\nairpods\norganic flurb capsules\nreplacement gasket for instant pot duo 8 quart")
    if st.button("Classify", key="batch") and text.strip():
        # One query per non-blank line, whitespace-trimmed.
        queries = [q.strip() for q in text.strip().splitlines() if q.strip()]
        if len(queries) > 500:
            # Soft cap to keep a single forward pass reasonably sized.
            st.warning("Please enter 500 queries or fewer.")
        else:
            results = predict(queries, tokenizer, model)
            df = pd.DataFrame([
                {
                    "Query": r["query"],
                    "Class": DISPLAY_LABELS[r["label"]],
                    "Sessions": VOLUME_RANGES[r["label"]],
                    "Confidence": r["confidence"],  # fraction in [0, 1]
                }
                for r in results
            ])

            # Color-code the Class column (Styler.map styles element-wise).
            styled = df.style.map(style_class_cells, subset=["Class"])

            st.dataframe(
                styled,
                column_config={
                    # BUG FIX: Confidence is a 0-1 fraction, but the printf
                    # format "%.1f%%" rendered e.g. 0.85 as "0.9%". The
                    # "percent" preset multiplies by 100 before formatting.
                    "Confidence": st.column_config.ProgressColumn(
                        "Confidence",
                        format="percent",
                        min_value=0,
                        max_value=1,
                    ),
                },
                use_container_width=True,
                hide_index=True,
            )