qve / src /streamlit_app.py
dejanseo's picture
Update src/streamlit_app.py
0ce344e verified
import streamlit as st
import torch
import json
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
# ==========================
# LABEL MAPPING
# ==========================
# Model output index -> search-volume bucket shown to the user.
# Position i in the list below is the label for class index i, so the
# order of the entries must match the model's class order.
LABEL_MAPPING = dict(
    enumerate(
        [
            "51-100",
            "101-150",
            "151-200",
            "201-250",
            "251-500",
            "501-1000",
            "1001-2000",
            "2001-5000",
            "5001-10000",
            "10001-100000",
            "100001-200000",
            "200001+",
        ]
    )
)
# ==========================
# PAGE CONFIG
# ==========================
# Browser-tab title/icon and page layout. The icon literal was previously
# mis-encoded text ("πŸ“Š") rather than the intended 📊 emoji.
st.set_page_config(
    page_title="KVE",
    page_icon="📊",
    layout="centered",
)

# Hide the element with id "MainMenu" (Streamlit's main menu) via injected CSS.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# Brand logo in the app chrome, linking back to dejan.ai.
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
    size="large",
)
# ==========================
# LOAD MODEL
# ==========================
@st.cache_resource
def load_model():
    """Fetch the query classifier and its tokenizer from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` so the loaded objects are reused across
    reruns instead of being downloaded and built again each time.

    Returns:
        tuple: ``(model, tokenizer)``; the model has been put in eval mode.
    """
    repo_id = "dejanseo/qve-klook"
    tok = AutoTokenizer.from_pretrained(repo_id)
    clf = AutoModelForSequenceClassification.from_pretrained(repo_id)
    clf.eval()  # inference only: disables training-time behavior such as dropout
    return clf, tok
# ==========================
# PREDICTION
# ==========================
def predict(query, model, tokenizer):
    """Classify a single search query into a search-volume bucket index.

    The text is tokenized with fixed-length padding/truncation to 20 tokens,
    passed through the model without gradient tracking, and the logits are
    converted to probabilities with a softmax over the last dimension.

    Args:
        query: Search query text.
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: Callable tokenizer compatible with ``model``.

    Returns:
        tuple: ``(predicted_class, confidence, probabilities)``: the argmax
        class index, its probability, and the squeezed NumPy probability
        vector.
    """
    encoded = tokenizer(
        query,
        padding="max_length",
        truncation=True,
        max_length=20,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**encoded).logits
        # squeeze() drops the singleton batch dimension of this one-query batch.
        probs = torch.softmax(logits, dim=-1).squeeze().numpy()
    best = np.argmax(probs)
    return best, probs[best], probs
# ==========================
# CONFIDENCE INDICATOR
# ==========================
def get_confidence_indicator(confidence):
    """Map the predicted class's probability to a display badge.

    Args:
        confidence: Probability of the predicted class.

    Returns:
        tuple: ``(emoji, label, hex_color)``:
        - above 0.8: green circle, "High Confidence", ``#4CAF50``
        - above 0.5: yellow circle, "Moderate Confidence", ``#FFC107``
        - otherwise: red circle, "Low Confidence", ``#f44336``

        The comparisons are strict, so exactly 0.8 is "Moderate" and exactly
        0.5 is "Low".
    """
    # The emoji literals were previously mis-encoded multi-character garbage
    # instead of these single-code-point emojis.
    if confidence > 0.8:
        return "🟢", "High Confidence", "#4CAF50"
    if confidence > 0.5:
        return "🟡", "Moderate Confidence", "#FFC107"
    return "🔴", "Low Confidence", "#f44336"
# ==========================
# MAIN APP
# ==========================
def _clear_query():
    """Button callback: empty the query text box (session key "input_text")."""
    # Widget state can be written from a callback because callbacks run
    # before the script re-executes and recreates the widget.
    st.session_state["input_text"] = ""


def _render_prediction(label, emoji, color):
    """Render the predicted bucket as a colored result card."""
    card_html = (
        f'<div style="background: linear-gradient(135deg, {color}22 0%, {color}11 100%); '
        f'border-left: 4px solid {color}; padding: 20px; border-radius: 10px; margin: 10px 0;">'
        f'<h2 style="margin: 0; color: {color};">{emoji} {label}</h2>'
        "</div>"
    )
    st.markdown(card_html, unsafe_allow_html=True)


def _render_probability_bars(probabilities, predicted_class, color, top_k=10):
    """Render the ``top_k`` most probable buckets as ranked rows with bars."""
    ranked = np.argsort(probabilities)[::-1][:top_k]  # descending probability
    for rank, idx in enumerate(ranked, 1):
        is_top = idx == predicted_class
        prob = probabilities[idx]
        label_name = LABEL_MAPPING.get(idx, f"Class {idx}")
        bar_color = color if is_top else "#B0BEC5"
        weight = "bold" if is_top else "normal"
        row_html = (
            '<div style="margin-bottom: 8px;">'
            '<div style="display: flex; align-items: center;">'
            f'<span style="width: 30px;">#{rank}</span>'
            f'<span style="flex: 1; font-weight: {weight};">{label_name}</span>'
            f'<span style="width: 60px; text-align: right;">{prob:.1%}</span>'
            "</div>"
            '<div style="background: #E0E0E0; height: 4px; border-radius: 2px; margin-top: 2px;">'
            f'<div style="background: {bar_color}; width: {int(prob * 100)}%; '
            'height: 100%; border-radius: 2px;"></div>'
            "</div>"
            "</div>"
        )
        st.markdown(row_html, unsafe_allow_html=True)


def main():
    """Render the app: load the model, read a query, show the predicted bucket.

    Any exception raised while loading the model or classifying is shown on
    the page via ``st.error`` instead of propagating.
    """
    try:
        with st.spinner("Loading model..."):
            model, tokenizer = load_model()

        query = st.text_input(
            "Enter search query to classify:",
            placeholder="Type your search query here...",
            key="input_text",
        )
        col1, col2, _ = st.columns([1, 1, 2])
        with col1:
            classify_btn = st.button("🔍 Classify", type="primary", use_container_width=True)
        with col2:
            # Clear via callback: a bare st.rerun() keeps the text box's value,
            # because it lives in session state under "input_text".
            st.button("🗑️ Clear", use_container_width=True, on_click=_clear_query)

        if not classify_btn:
            return

        # Whitespace-only input is treated the same as empty input.
        query = query.strip()
        if not query:
            st.warning("⚠️ Please enter some text to classify")
            return

        with st.spinner("Analyzing..."):
            predicted_class, confidence, probabilities = predict(query, model, tokenizer)
        predicted_label = LABEL_MAPPING.get(predicted_class, f"Class {predicted_class}")
        emoji, _, color = get_confidence_indicator(confidence)

        _render_prediction(predicted_label, emoji, color)
        st.markdown("Search Volume Probabilities")
        _render_probability_bars(probabilities, predicted_class, color)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the page.
        st.error(f"❌ Error: {str(e)}")
# Build the UI only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()