"""Streamlit app: predict whether a prompt is likely to trigger search grounding.

Loads DEJAN's query-grounding sequence classifier from the Hugging Face Hub and
exposes a batch UI: one prompt per line in, one Yes/No verdict plus a confidence
score out. Built by Dejan AI (https://dejan.ai).
"""

import os

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Theme configuration - MUST BE FIRST STREAMLIT COMMAND
st.set_page_config(
    page_title="QDG Classifier",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)

MODEL_ID = "dejanseo/query-grounding"
HF_TOKEN = os.getenv("HF_TOKEN")  # only required for gated/private repos
PREFERRED_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _has_meta_params(m: torch.nn.Module) -> bool:
    """Return True if any parameter of *m* still lives on the 'meta' device.

    Meta parameters are placeholders without real storage; running inference
    on them raises, so the loader uses this to decide whether to retry.
    """
    return any(getattr(p, "is_meta", False) for p in m.parameters())


def _first_real_param_device(m: torch.nn.Module) -> torch.device:
    """Return the device of the first materialized (non-meta) parameter.

    Falls back to CPU when every parameter is meta (or the model is empty).
    Used to pick where input tensors should be moved before a forward pass,
    which also works for Accelerate-dispatched (device_map) models.
    """
    for p in m.parameters():
        if not getattr(p, "is_meta", False):
            return p.device
    return torch.device("cpu")


@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Load and cache the classifier and its tokenizer.

    Strategy:
      1. Normal full load (low_cpu_mem_usage=False avoids meta tensors),
         then move to the preferred device.
      2. If meta params remain: on CUDA retry with device_map="auto"
         (Accelerate dispatch — must NOT be followed by .to()); on CPU
         retry without the dtype hint.

    Raises:
        RuntimeError: if parameters are still meta after all fallbacks.
    """
    # Attempt 1: normal full load (no meta), then move to preferred device
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        low_cpu_mem_usage=False,
        torch_dtype="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # If anything is still meta, fallback to device_map loading
    # (do NOT call .to() after that)
    if _has_meta_params(model):
        if torch.cuda.is_available():
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
                torch_dtype="auto",
                device_map="auto",
            )
        else:
            # CPU fallback retry without dtype hint
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
                low_cpu_mem_usage=False,
            )

    # Only call .to() if the model is not dispatched by Accelerate/device_map
    if not hasattr(model, "hf_device_map"):
        if _has_meta_params(model):
            raise RuntimeError(
                "Model parameters are still on the meta device after loading. "
                "This is usually a torch/transformers/accelerate version or memory/offload issue."
            )
        model.to(PREFERRED_DEVICE)

    model.eval()
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()


def classify(prompt: str) -> tuple[int, float]:
    """Classify a single prompt.

    Args:
        prompt: the query/question/prompt text.

    Returns:
        (label, confidence) — label 1 is rendered as "Yes" (grounding likely)
        by the UI below; confidence is the softmax probability of that label.
    """
    exec_device = _first_real_param_device(model)
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, padding=True, max_length=512
    )
    inputs = {k: v.to(exec_device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=-1).squeeze().cpu()
    pred = int(torch.argmax(probs).item())
    confidence = probs[pred].item()
    return pred, confidence


# Font and style overrides (currently an empty placeholder stylesheet)
st.markdown(
    """
""",
    unsafe_allow_html=True,
)

# UI
st.title("QDG Classifier")
st.write("Built by [**Dejan AI**](https://dejan.ai)")
st.write(
    "Google's AI models can request grounding with search to deliver more "
    "accurate or up-to-date answers. This tool predicts whether a search "
    "query, question, or prompt is likely to be grounded."
)

# Placeholder example prompts (also used as the default batch when the
# text area is left empty)
example_text = """how would a cat describe a dog
how to reset a Nest thermostat
write a poem about time
is there a train strike in London today
summarize the theory of relativity
who won the champions league last year
explain quantum computing to a child
weather in tokyo tomorrow
generate a social media post for Earth Day
what is the latest iPhone model"""

user_input = st.text_area(
    "Enter one search query, question, or prompt per line:",
    placeholder=example_text,
)

if st.button("Classify"):
    raw_input = user_input.strip()
    if raw_input:
        prompts = [line.strip() for line in raw_input.split("\n") if line.strip()]
    else:
        # Fall back to the example prompts when nothing was entered
        prompts = [line.strip() for line in example_text.split("\n")]

    if not prompts:
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        results = []
        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
            results.append(
                {
                    "Prompt": p,
                    "Grounding": "Yes" if label == 1 else "No",
                    "Confidence": round(conf, 4),
                }
            )
            # Redraw the table after every prompt so results stream in
            df = pd.DataFrame(results)
            table_placeholder.data_editor(
                df,
                column_config={
                    "Confidence": st.column_config.ProgressColumn(
                        label="Confidence",
                        min_value=0.0,
                        max_value=1.0,
                        format="%.4f",
                    )
                },
                hide_index=True,
            )
        info_box.empty()

        # Promo message shown only after results
        st.subheader("Working together.")
        st.write("[**Schedule a call**](https://dejan.ai/call/)")