"""Streamlit app: generate query "fan-out" expansions with a fine-tuned mT5 model.

Given a URL context and a search query, the app produces up to N diverse
query variations using diverse beam search, then displays them with timing
stats. Model: dejanseo/query-fanout (gated; requires an HF auth token).
"""

import time

import streamlit as st
import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

# Page config
st.set_page_config(
    page_title="Query Fan-Out",
    page_icon="🔍",
    layout="centered",
)


# Cache the model and tokenizer across reruns (loaded once per process).
@st.cache_resource
def load_model():
    """Load the fine-tuned mT5 model and tokenizer.

    Returns:
        tuple: (tokenizer, model) — model is in eval mode, on GPU if available.
    """
    model_name = "dejanseo/query-fanout"
    with st.spinner("Loading model... This may take a minute on first run."):
        # `token=True` reads the locally saved HF auth token
        # (`use_auth_token` is deprecated in recent transformers versions).
        tokenizer = MT5Tokenizer.from_pretrained(model_name, token=True)
        model = MT5ForConditionalGeneration.from_pretrained(model_name, token=True)

        # Move to GPU if available
        if torch.cuda.is_available():
            model = model.cuda()
        model.eval()

    return tokenizer, model


def generate_expansions(url, query, tokenizer, model, num_return_sequences=5):
    """Generate diverse query expansions using the model.

    Args:
        url: Domain/URL string providing context for the query.
        query: The original search query to diversify.
        tokenizer: The mT5 tokenizer returned by load_model().
        model: The mT5 model returned by load_model().
        num_return_sequences: How many candidate expansions to generate.

    Returns:
        list[str]: Unique, non-empty expansions (original query excluded),
        in generation order. May be shorter than num_return_sequences.
    """
    # Format input exactly as during training.
    input_text = f"For URL: {url} diversify query: {query}"

    # Tokenize (inputs are short; 32 tokens is enough for URL + query).
    inputs = tokenizer(
        input_text,
        max_length=32,
        truncation=True,
        return_tensors="pt",
    )

    # Keep inputs on the same device as the model.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate with diverse beam search. Note:
    # - diverse beam search requires do_sample=False;
    #   temperature is therefore omitted (it would be ignored).
    # - only max_new_tokens is passed (passing max_length as well is
    #   redundant and triggers a transformers warning).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,   # More beams for diversity
            num_beam_groups=num_return_sequences, # Beam groups for diversity
            diversity_penalty=0.5,                # Encourage diverse outputs
            do_sample=False,                      # Beam search for quality
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            forced_eos_token_id=tokenizer.eos_token_id,  # Force proper EOS
            max_new_tokens=16,
        )

    # Handle mT5's vocabulary size mismatch issue:
    # the model has vocab_size=250112 but the tokenizer only knows 250100
    # tokens, so clip any out-of-range ids to pad before decoding.
    tokenizer_vocab_size = len(tokenizer)  # 250100 for mT5
    outputs_clipped = torch.where(
        outputs < tokenizer_vocab_size,
        outputs,
        tokenizer.pad_token_id,
    )

    # Decode all sequences, dropping empty results and echoes of the query.
    expansions = []
    for output in outputs_clipped:
        try:
            expansion = tokenizer.decode(output, skip_special_tokens=True)
        except Exception:
            # Skip any problematic sequences (best-effort decode).
            continue
        if expansion and expansion != query:
            expansions.append(expansion)

    # Remove duplicates while preserving order.
    return list(dict.fromkeys(expansions))


# ---------------------------------------------------------------------------
# Main UI
# ---------------------------------------------------------------------------
st.title("Query Fan-Out")
st.markdown("Query fan-out model trained by [DEJAN AI](https://dejan.ai/).")

# Load model
tokenizer, model = load_model()

# Input section
col1, col2 = st.columns([1, 2])
with col1:
    url = st.text_input(
        "URL Context",
        value="dejan.ai",
        help="The domain or URL that provides context for the query",
    )
with col2:
    query = st.text_input(
        "Search Query",
        value="AI search SEO Agency",
        help="The search query you want to expand",
    )

# Advanced options
with st.expander("Advanced Options"):
    num_expansions = st.slider(
        "Number of expansions",
        min_value=1,
        max_value=10,
        value=10,
        help="How many query variations to generate",
    )

# Generate button
if st.button("GO", type="primary"):
    if url and query:
        start_time = time.time()
        with st.spinner("Generating fan-out queries..."):
            expansions = generate_expansions(
                url, query, tokenizer, model,
                num_return_sequences=num_expansions,
            )
        generation_time = time.time() - start_time

        # Display results
        if expansions:
            st.markdown("### 📝 Generated Query Fan-Outs")

            # Show original query for reference
            st.markdown(f"**Original:** `{query}`")
            st.markdown("**Fan-Outs:**")

            # Display each expansion
            for i, expansion in enumerate(expansions, 1):
                st.markdown(f"{i}. `{expansion}`")

            # Show stats
            st.markdown("---")
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Fan-Outs Generated", len(expansions))
            with col2:
                st.metric("Generation Time", f"{generation_time:.2f}s")
        else:
            st.warning("No valid fan-outs generated. Try a different query.")
    else:
        st.error("Please enter both URL and query.")

# Footer
st.markdown("---")
st.markdown(
    """
    """,
    unsafe_allow_html=True,
)