| | import streamlit as st |
| | import torch |
| | import torch.nn.functional as F |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import pandas as pd |
| | import os |
| |
|
| | |
# Page setup — must be the first Streamlit call in the script.
# Sidebar starts collapsed and the hamburger menu is removed entirely.
st.set_page_config(
    page_title="QDG Classifier",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None
)
| |
|
# Hugging Face Hub repo id of the query-grounding classifier.
MODEL_ID = "dejanseo/query-grounding"
# Optional Hub access token (needed for private/gated repos); None if unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Preferred placement for the model: GPU when available, otherwise CPU.
PREFERRED_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
| |
|
| | def _has_meta_params(m: torch.nn.Module) -> bool: |
| | for p in m.parameters(): |
| | if getattr(p, "is_meta", False): |
| | return True |
| | return False |
| |
|
| |
|
| | def _first_real_param_device(m: torch.nn.Module) -> torch.device: |
| | for p in m.parameters(): |
| | if not getattr(p, "is_meta", False): |
| | return p.device |
| | return torch.device("cpu") |
| |
|
| |
|
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Load and cache the classifier model and tokenizer once per process.

    Cached with st.cache_resource so Streamlit script reruns reuse the same
    objects. Returns (model, tokenizer) with the model in eval mode, placed
    on PREFERRED_DEVICE unless accelerate assigned its own device map.

    Raises:
        RuntimeError: if weights remain on the 'meta' device after all
            loading strategies have been tried.
    """
    # First attempt: eager load. low_cpu_mem_usage=False avoids the
    # meta-device lazy-initialization path in transformers/accelerate.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        low_cpu_mem_usage=False,
        torch_dtype="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # Some torch/transformers/accelerate combinations still leave weights on
    # the meta device; retry with a strategy suited to the hardware.
    if _has_meta_params(model):
        if torch.cuda.is_available():
            # Let accelerate place the weights on the GPU(s) directly.
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
                torch_dtype="auto",
                device_map="auto",
            )
        else:
            # CPU-only retry.
            # NOTE(review): this branch omits torch_dtype="auto" — presumably
            # deliberate to force the default fp32 on CPU; confirm.
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
                low_cpu_mem_usage=False,
            )

    # If accelerate did not install a device map (hf_device_map), we must
    # move the model ourselves; meta params at this point are unrecoverable.
    if not hasattr(model, "hf_device_map"):
        if _has_meta_params(model):
            raise RuntimeError(
                "Model parameters are still on the meta device after loading. "
                "This is usually a torch/transformers/accelerate version or memory/offload issue."
            )
        model.to(PREFERRED_DEVICE)

    model.eval()
    return model, tokenizer
| |
|
| |
|
# Load once at import time; cached across Streamlit reruns by cache_resource.
model, tokenizer = load_model_and_tokenizer()
| |
|
| |
|
def classify(prompt: str):
    """Classify a single prompt with the module-level model.

    Tokenizes *prompt* (truncated/padded to 512 tokens), runs a forward
    pass on the device holding the model's weights, and returns
    (predicted_class_index, confidence) where confidence is the softmax
    probability of the predicted class.
    """
    exec_device = _first_real_param_device(model)
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    encoded = {name: tensor.to(exec_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model(**encoded)
    scores = torch.softmax(output.logits, dim=-1).squeeze().cpu()
    best = torch.argmax(scores).item()
    return best, scores[best].item()
| |
|
| |
|
| | |
# Inject global CSS: force the Montserrat font across all elements
# (including Streamlit's generated css-* classes) and hide the header bar.
st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');

    html, body, div, span, input, label, textarea, button, h1, h2, p, table {
        font-family: 'Montserrat', sans-serif !important;
    }

    [class^="css-"], [class*=" css-"] {
        font-family: 'Montserrat', sans-serif !important;
    }

    header {visibility: hidden;}
    </style>
""", unsafe_allow_html=True)
| |
|
| | |
# Page header and short explanation of what the tool predicts.
st.title("QDG Classifier")
st.write("Built by [**Dejan AI**](https://dejan.ai)")
st.write("Google's AI models can request grounding with search to deliver more accurate or up-to-date answers. This tool predicts whether a search query, question, or prompt is likely to be grounded.")
| |
|
| | |
# Placeholder examples; also used as the actual input when the user submits
# an empty text area (see the Classify button handler below).
example_text = """how would a cat describe a dog
how to reset a Nest thermostat
write a poem about time
is there a train strike in London today
summarize the theory of relativity
who won the champions league last year
explain quantum computing to a child
weather in tokyo tomorrow
generate a social media post for Earth Day
what is the latest iPhone model"""

# Multi-line input: one prompt per line.
user_input = st.text_area(
    "Enter one search query, question, or prompt per line:",
    placeholder=example_text
)
| |
|
if st.button("Classify"):
    raw_input = user_input.strip()
    if raw_input:
        # One prompt per non-empty line of the user's input.
        prompts = [line.strip() for line in raw_input.split("\n") if line.strip()]
    else:
        # Empty input: fall back to the placeholder examples.
        prompts = [line.strip() for line in example_text.split("\n")]

    if not prompts:
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        results = []  # accumulated rows; table re-rendered after each prompt

        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
            # Label 1 is treated as "grounding requested".
            results.append({
                "Prompt": p,
                "Grounding": "Yes" if label == 1 else "No",
                "Confidence": round(conf, 4)
            })
            # Re-render the full table each iteration so results stream in.
            # NOTE(review): data_editor is used purely for display here
            # (ProgressColumn rendering); confirm edits are intentionally
            # ignored — st.dataframe may otherwise suffice.
            df = pd.DataFrame(results)
            table_placeholder.data_editor(
                df,
                column_config={
                    "Confidence": st.column_config.ProgressColumn(
                        label="Confidence",
                        min_value=0.0,
                        max_value=1.0,
                        format="%.4f"
                    )
                },
                hide_index=True,
            )
        info_box.empty()
| |
|
| | |
# Footer with a contact call-to-action.
st.subheader("Working together.")
st.write("[**Schedule a call**](https://dejan.ai/call/)")
| |
|