|
|
import streamlit as st |
|
|
import torch |
|
|
import json |
|
|
import os |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Classifier output index -> search-volume range shown to the user.
# The list order defines the index, so keep it aligned with the model's labels.
LABEL_MAPPING = dict(
    enumerate(
        [
            "51-100",
            "101-150",
            "151-200",
            "201-250",
            "251-500",
            "501-1000",
            "1001-2000",
            "2001-5000",
            "5001-10000",
            "10001-100000",
            "100001-200000",
            "200001+",
        ]
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Browser-tab title/icon and a centered single-column layout.
# NOTE(review): the page_icon glyph appears mis-encoded in this source ("π");
# the intended emoji cannot be recovered from here — confirm and replace.
st.set_page_config(page_title="KVE", page_icon="π", layout="centered")

# Inject CSS that hides the element with id "MainMenu".
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# Branding logo (large size) that links to the company site.
st.logo(
    link="https://dejan.ai/",
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    size="large",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Fetch the fine-tuned query classifier and its tokenizer from the Hub.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects
    on later reruns instead of reloading them each time.

    Returns:
        A ``(model, tokenizer)`` pair; the model is already in eval mode.
    """
    repo_id = "dejanseo/qve-klook"
    tok = AutoTokenizer.from_pretrained(repo_id)
    clf = AutoModelForSequenceClassification.from_pretrained(repo_id)
    clf.eval()  # inference only (e.g. dropout disabled)
    return clf, tok
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(query, model, tokenizer, *, max_length=20):
    """Classify ``query`` into one of the model's search-volume classes.

    Args:
        query: The search query text.
        model: A sequence-classification model; calling it with the
            tokenizer's output must return an object with a ``logits``
            tensor whose first axis is the batch.
        tokenizer: The tokenizer matching ``model``.
        max_length: Pad/truncate the tokenized query to exactly this many
            tokens. Defaults to 20, the value this app has always used.

    Returns:
        ``(predicted_class, confidence, probabilities)``:
        ``predicted_class`` is a plain ``int`` class index,
        ``confidence`` is that class's probability as a ``float``, and
        ``probabilities`` is a 1-D NumPy array with one entry per class.
    """
    inputs = tokenizer(
        query,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Select the single batch row explicitly. squeeze() would also drop the
    # class axis for a one-label model, leaving a 0-d array that cannot be
    # indexed by predicted_class below.
    probabilities = torch.softmax(outputs.logits[0], dim=-1).cpu().numpy()
    # Plain Python scalars: usable as dict keys / JSON values without NumPy types.
    predicted_class = int(np.argmax(probabilities))
    confidence = float(probabilities[predicted_class])
    return predicted_class, confidence, probabilities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_confidence_indicator(confidence, *, high=0.8, moderate=0.5):
    """Map a prediction confidence to a traffic-light badge.

    Args:
        confidence: Probability of the predicted class.
        high: Confidence strictly above this is reported as high.
        moderate: Confidence strictly above this (but not above ``high``)
            is reported as moderate; anything else is low.

    Returns:
        ``(emoji, label, hex_color)`` used to style the result card.
    """
    # The badge emoji were previously mis-encoded (UTF-8 read as ISO-8859-7,
    # e.g. "π’"); these are the intended code points.
    if confidence > high:
        return "🟢", "High Confidence", "#4CAF50"
    if confidence > moderate:
        return "🟡", "Moderate Confidence", "#FFC107"
    return "🔴", "Low Confidence", "#f44336"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _clear_query():
    """Reset the query text box; used as the Clear button's ``on_click`` callback.

    A keyed ``st.text_input`` keeps its value across reruns, so rerunning alone
    does not clear it. Callbacks run before the script re-executes, which is the
    supported point at which a widget's value may be overwritten through
    ``st.session_state``.
    """
    st.session_state["input_text"] = ""


def _render_prediction(predicted_label, emoji, color):
    """Render the highlighted card showing the predicted label and badge."""
    result_html = f"""
    <div style="
        background: linear-gradient(135deg, {color}22 0%, {color}11 100%);
        border-left: 4px solid {color};
        padding: 20px;
        border-radius: 10px;
        margin: 10px 0;
    ">
        <h2 style="margin: 0; color: {color};">
            {emoji} {predicted_label}
        </h2>
    </div>
    """
    st.markdown(result_html, unsafe_allow_html=True)


def _render_probabilities(probabilities, predicted_class, color, top_k=10):
    """Render the ``top_k`` most probable classes as ranked, labelled bars.

    The predicted class is drawn bold with ``color``; other bars are grey.
    """
    st.markdown("Search Volume Probabilities")
    # argsort is ascending; reverse it so the most probable class comes first.
    top_idx = np.argsort(probabilities)[::-1][:top_k]
    for rank, idx in enumerate(top_idx, 1):
        is_predicted = idx == predicted_class
        label_name = LABEL_MAPPING.get(idx, f"Class {idx}")
        prob = probabilities[idx]
        bar_width = int(prob * 100)
        bar_color = color if is_predicted else "#B0BEC5"
        weight = "bold" if is_predicted else "normal"
        st.markdown(
            f"""
            <div style="margin-bottom: 8px;">
                <div style="display: flex; align-items: center;">
                    <span style="width: 30px;">#{rank}</span>
                    <span style="flex: 1; font-weight: {weight};">
                        {label_name}
                    </span>
                    <span style="width: 60px; text-align: right;">{prob:.1%}</span>
                </div>
                <div style="background: #E0E0E0; height: 4px; border-radius: 2px; margin-top: 2px;">
                    <div style="background: {bar_color}; width: {bar_width}%; height: 100%; border-radius: 2px;"></div>
                </div>
            </div>
            """,
            unsafe_allow_html=True,
        )


def main():
    """Render the query classifier UI and handle the current interaction.

    Streamlit re-executes the script (and thus this function) on each user
    interaction. Any exception from loading the model or predicting is shown
    via ``st.error`` instead of crashing the page.
    """
    try:
        with st.spinner("Loading model..."):
            model, tokenizer = load_model()

        query = st.text_input(
            "Enter search query to classify:",
            placeholder="Type your search query here...",
            key="input_text",
        )

        # Third column only provides spacing so the two buttons stay narrow.
        col1, col2, _ = st.columns([1, 1, 2])
        with col1:
            # NOTE(review): this label's emoji appears mis-encoded in the source
            # ("π"); the intended glyph cannot be recovered — confirm and replace.
            classify_btn = st.button("π Classify", type="primary", use_container_width=True)
        with col2:
            # Clearing must happen in a callback; a plain st.rerun() leaves the
            # keyed text_input's value in place.
            st.button("🗑️ Clear", on_click=_clear_query, use_container_width=True)

        # Treat whitespace-only input the same as empty input.
        cleaned_query = query.strip()
        if classify_btn and cleaned_query:
            with st.spinner("Analyzing..."):
                predicted_class, confidence, probabilities = predict(cleaned_query, model, tokenizer)

            predicted_label = LABEL_MAPPING.get(predicted_class, f"Class {predicted_class}")
            emoji, _, color = get_confidence_indicator(confidence)

            _render_prediction(predicted_label, emoji, color)
            _render_probabilities(probabilities, predicted_class, color)
        elif classify_btn:
            st.warning("⚠️ Please enter some text to classify")

    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        # NOTE(review): the "β" prefix appears to be a mis-encoded emoji whose
        # original glyph cannot be recovered from here.
        st.error(f"β Error: {str(e)}")


if __name__ == "__main__":
    main()