| import streamlit as st |
| import torch |
| from transformers import MT5ForConditionalGeneration, MT5Tokenizer |
| import time |
|
|
| |
# Streamlit page configuration.
# NOTE: st.set_page_config must be the FIRST Streamlit command executed in the
# script — keep this call ahead of any other st.* usage.
st.set_page_config(
    page_title="Query Fan-Out",
    page_icon="🔍",
    layout="centered"
)
|
|
| |
@st.cache_resource
def load_model():
    """Load and cache the fine-tuned mT5 fan-out model and its tokenizer.

    Returns:
        tuple: ``(tokenizer, model)`` where the model is moved to GPU when
        one is available and switched to eval mode. Wrapped in
        ``st.cache_resource`` so the download/initialisation happens once
        per Streamlit server process.
    """
    model_name = "dejanseo/query-fanout"

    with st.spinner("Loading model... This may take a minute on first run."):
        # `use_auth_token` is deprecated in transformers (removed in v5);
        # `token=True` is the supported way to authenticate for gated repos.
        tokenizer = MT5Tokenizer.from_pretrained(model_name, token=True)
        model = MT5ForConditionalGeneration.from_pretrained(model_name, token=True)

    if torch.cuda.is_available():
        model = model.cuda()

    # Inference only: disable dropout and other train-time behavior.
    model.eval()

    return tokenizer, model
|
|
def generate_expansions(url, query, tokenizer, model, num_return_sequences=5):
    """Generate diverse query expansions ("fan-outs") for a query in URL context.

    Args:
        url: Domain/URL string providing context for the query.
        query: The search query to expand.
        tokenizer: Tokenizer matching ``model`` (mT5-style).
        model: Seq2seq model exposing ``generate`` (mT5-style).
        num_return_sequences: How many candidate expansions to request.

    Returns:
        list[str]: unique, non-empty decoded expansions that differ from
        ``query``, in generation (ranking) order. May contain fewer than
        ``num_return_sequences`` items after filtering and deduplication.
    """
    # Prompt format the model was fine-tuned on.
    input_text = f"For URL: {url} diversify query: {query}"

    inputs = tokenizer(
        input_text,
        max_length=32,
        truncation=True,
        return_tensors="pt"
    )

    # Move input tensors onto the GPU alongside the model, if present.
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    with torch.no_grad():
        # Diverse (group) beam search; num_beams is divisible by
        # num_beam_groups by construction (2n / n).
        # FIX: `temperature` removed — it is ignored when do_sample=False and
        # newer transformers versions warn/error on it. `max_length` removed —
        # passing it together with `max_new_tokens` triggers a warning and
        # `max_new_tokens` takes precedence anyway.
        outputs = model.generate(
            **inputs,
            max_new_tokens=16,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,
            num_beam_groups=num_return_sequences,
            diversity_penalty=0.5,
            do_sample=False,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            forced_eos_token_id=tokenizer.eos_token_id,
        )

    # The model can emit ids outside the tokenizer vocabulary (e.g. from
    # resized embeddings); clamp those to pad so decode cannot fail on them.
    vocab_size = len(tokenizer)
    outputs = torch.where(
        outputs < vocab_size,
        outputs,
        torch.full_like(outputs, tokenizer.pad_token_id)
    )

    expansions = []
    for seq in outputs:
        try:
            text = tokenizer.decode(seq, skip_special_tokens=True)
        except Exception:
            # Best-effort: skip any sequence that still fails to decode.
            continue
        # Drop empty decodes and verbatim echoes of the input query.
        if text and text != query:
            expansions.append(text)

    # Deduplicate while preserving generation order.
    return list(dict.fromkeys(expansions))
|
|
| |
# --- Page header -------------------------------------------------------------
st.title("Query Fan-Out")
st.markdown("Query fan-out model trained by [DEJAN AI](https://dejan.ai/).")

# Load the model/tokenizer up front (served from st.cache_resource after the
# first run) so the first button press doesn't pay the load cost.
tokenizer, model = load_model()
|
|
| |
# --- Inputs: URL context (narrow column) and query (wide column) -------------
col1, col2 = st.columns([1, 2])

with col1:
    url = st.text_input(
        "URL Context",
        value="dejan.ai",
        help="The domain or URL that provides context for the query"
    )

with col2:
    query = st.text_input(
        "Search Query",
        value="AI search SEO Agency",
        help="The search query you want to expand"
    )
|
|
| |
# Optional generation tuning, collapsed by default.
with st.expander("Advanced Options"):
    num_expansions = st.slider(
        "Number of expansions",
        min_value=1,
        max_value=10,
        value=10,
        help="How many query variations to generate"
    )
|
|
| |
# --- Generate on demand ------------------------------------------------------
if st.button("GO", type="primary"):
    if url and query:
        start_time = time.time()

        with st.spinner("Generating fan-out queries..."):
            expansions = generate_expansions(
                url,
                query,
                tokenizer,
                model,
                num_return_sequences=num_expansions
            )

        generation_time = time.time() - start_time

        # generate_expansions filters empties/echoes and deduplicates, so it
        # may return fewer items than requested — or none at all.
        if expansions:
            st.markdown("### 📝 Generated Query Fan-Outs")

            st.markdown(f"**Original:** `{query}`")
            st.markdown("**Fan-Outs:**")

            for i, expansion in enumerate(expansions, 1):
                st.markdown(f"{i}. `{expansion}`")

            # Summary metrics row.
            st.markdown("---")
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Fan-Outs Generated", len(expansions))
            with col2:
                st.metric("Generation Time", f"{generation_time:.2f}s")
        else:
            st.warning("No valid fan-outs generated. Try a different query.")
    else:
        st.error("Please enter both URL and query.")
|
|
| |
# --- Footer ------------------------------------------------------------------
st.markdown("---")
st.markdown(
    """
    <div style='text-align: center'>
    <p><a href='https://dejan.ai/blog/training-a-query-fan-out-model/'>Training Process</a></p>
    </div>
    """,
    unsafe_allow_html=True
)