| | import streamlit as st |
| | import torch |
| | import torch.nn.functional as F |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import pandas as pd |
| | import os |
| |
|
| | |
# Basic page chrome: wide layout, collapsed sidebar, no extra menu items.
_PAGE_CONFIG = {
    "page_title": "QDF Classifier",
    "page_icon": "🔍",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
    "menu_items": None,
}
st.set_page_config(**_PAGE_CONFIG)
| |
|
MODEL_ID = "dejanseo/QDF-large"
HF_TOKEN = os.getenv("HF_TOKEN")  # optional Hugging Face auth token; None if unset


@st.cache_resource(show_spinner="Loading model...")
def _load_model_and_tokenizer():
    """Load the QDF classifier model and tokenizer once per server process.

    Streamlit re-executes the entire script on every user interaction;
    without caching, the model would be re-initialized (and potentially
    re-downloaded) on each rerun. ``st.cache_resource`` keeps one shared
    instance alive across reruns and sessions.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    ).eval()  # inference only — disables dropout etc.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    return model, tokenizer


# Keep the original module-level names so the rest of the script is unchanged.
model, tokenizer = _load_model_and_tokenizer()
| |
|
def classify(prompt: str) -> tuple[int, float]:
    """Classify a single query for QDF (Query Deserves Freshness).

    Args:
        prompt: The search query to classify.

    Returns:
        A ``(pred, confidence)`` tuple: the predicted class index
        (downstream UI code treats 1 as "QDF: Yes") and the softmax
        probability of that class.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,  # clip overly long queries to the model's limit
        padding=True,
        max_length=512,
    )
    # inference_mode is a stricter, slightly faster variant of no_grad
    # for pure inference (no autograd state is recorded at all).
    with torch.inference_mode():
        logits = model(**inputs).logits
    # Use the F alias already imported at the top of the file.
    probs = F.softmax(logits, dim=-1).squeeze().cpu()
    pred = torch.argmax(probs).item()
    confidence = probs[pred].item()
    return pred, confidence
| |
|
| | |
# Global CSS: force the Montserrat font across all elements (including
# Streamlit's generated css-* classes) and hide the default header bar.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');

html, body, div, span, input, label, textarea, button, h1, h2, p, table {
    font-family: 'Montserrat', sans-serif !important;
}

[class^="css-"], [class*=" css-"] {
    font-family: 'Montserrat', sans-serif !important;
}

header {visibility: hidden;}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
| |
|
| | |
# Page header and a short description of what the app does.
st.title("QDF Classifier")
st.write("Built by [**Dejan AI**](https://dejan.ai)")
# Grammar fix in the user-facing copy: "whether a query deserves freshness".
st.write("This classifier determines whether a query deserves freshness.")
| |
|
| | |
# Default example queries. They serve double duty: shown as the textarea
# placeholder, and used as a fallback input when the user submits nothing.
example_text = """how would a cat describe a dog
how to reset a Nest thermostat
write a poem about time
is there a train strike in London today
summarize the theory of relativity
who won the champions league last year
explain quantum computing to a child
weather in tokyo tomorrow
generate a social media post for Earth Day
what is the latest iPhone model"""

user_input = st.text_area(
    label="Enter one search query per line:",
    placeholder=example_text,
)
| |
|
if st.button("Classify"):
    # Use the user's queries when provided; otherwise fall back to the
    # example list shown in the placeholder.
    raw_input = user_input.strip()
    if raw_input:
        # splitlines() also handles \r\n from pasted Windows-style text.
        prompts = [line.strip() for line in raw_input.splitlines() if line.strip()]
    else:
        prompts = [line.strip() for line in example_text.splitlines()]

    if not prompts:
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        results = []

        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
            results.append({
                "Prompt": p,
                "QDF": "Yes" if label == 1 else "No",
                "Confidence": round(conf, 4),
            })
            # Re-render the cumulative table after each prompt so results
            # appear incrementally.
            df = pd.DataFrame(results)
            # st.dataframe instead of st.data_editor: the results are
            # display-only and should not be editable by the user.
            # st.dataframe accepts the same column_config.
            table_placeholder.dataframe(
                df,
                column_config={
                    "Confidence": st.column_config.ProgressColumn(
                        label="Confidence",
                        min_value=0.0,
                        max_value=1.0,
                        format="%.4f",
                    )
                },
                hide_index=True,
            )
        info_box.empty()
| |
|
| | |
# Footer: call-to-action linking to the consultation scheduler.
_FOOTER_HEADING = "Working together."
_FOOTER_TEXT = "[**Schedule a call**](https://dejan.ai/call/) to see how we can help you."
st.subheader(_FOOTER_HEADING)
st.write(_FOOTER_TEXT)
| |
|