|
|
import logging
import time

import streamlit as st
import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
|
|
|
|
|
|
|
|
# Page-level settings (browser tab title, favicon, content width); issued
# before any other Streamlit command in this script.
st.set_page_config(layout="centered", page_title="Query Fan-Out", page_icon="🔍")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the fine-tuned mT5 query fan-out model and its tokenizer.

    Wrapped in ``st.cache_resource`` so the download and initialisation run
    once and the same objects are reused on later reruns.

    Returns:
        tuple: ``(tokenizer, model)``. The model is in eval mode and sits on
        the CUDA device when one is available, otherwise on the CPU.
    """
    model_name = "dejanseo/query-fanout"

    with st.spinner("Loading model... This may take a minute on first run."):
        # ``token=True`` sends the locally saved Hugging Face credential; it
        # supersedes the deprecated ``use_auth_token`` argument.
        tokenizer = MT5Tokenizer.from_pretrained(model_name, token=True)
        model = MT5ForConditionalGeneration.from_pretrained(model_name, token=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()

    return tokenizer, model
|
|
|
|
|
def generate_expansions(url, query, tokenizer, model, num_return_sequences=5):
    """Generate diverse rewrites ("fan-outs") of *query* in the context of *url*.

    Decoding uses diverse (group) beam search: ``num_return_sequences`` beam
    groups of two beams each, with a diversity penalty between groups.

    Args:
        url: Domain or URL given to the model as context.
        query: Search query to expand.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Seq2seq model returned by ``load_model``; inputs are moved to
            ``model.device``.
        num_return_sequences: Number of candidate sequences to generate.

    Returns:
        list[str]: Whitespace-stripped expansions in generation order, with
        empty strings, echoes of *query*, and duplicates removed (both checks
        are case-insensitive; the first spelling seen is kept). The list can
        be shorter than ``num_return_sequences``.

    Raises:
        ValueError: If ``num_return_sequences`` is less than 1.
    """
    if num_return_sequences < 1:
        raise ValueError("num_return_sequences must be at least 1")

    input_text = f"For URL: {url} diversify query: {query}"

    # NOTE(review): the prompt is truncated to 32 tokens from the right, so a
    # very long URL can cut off part of the query. 32 presumably mirrors the
    # fine-tuning setup — confirm before changing it.
    inputs = tokenizer(
        input_text,
        max_length=32,
        truncation=True,
        return_tensors="pt",
    )
    # Follow the model to whatever device load_model() placed it on.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Output length is capped by max_new_tokens alone; also passing
            # max_length makes generate() warn and ignore it.
            max_new_tokens=16,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,
            num_beam_groups=num_return_sequences,
            diversity_penalty=0.5,
            # Deterministic beam search; sampling knobs such as temperature
            # would be ignored, so none are passed.
            do_sample=False,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            forced_eos_token_id=tokenizer.eos_token_id,
        )

    # Generated ids can fall outside the tokenizer's vocabulary (presumably
    # the model's output layer is larger than the tokenizer — confirm). Map
    # them to the pad token, which skip_special_tokens then drops.
    vocab_size = len(tokenizer)
    outputs = torch.where(outputs < vocab_size, outputs, tokenizer.pad_token_id)

    query_key = query.strip().casefold()
    seen = set()
    expansions = []
    for sequence in outputs:
        try:
            text = tokenizer.decode(sequence, skip_special_tokens=True).strip()
        except Exception:
            # Best effort: one undecodable sequence should not sink the batch,
            # but leave a trace instead of failing silently.
            logging.getLogger(__name__).warning(
                "Could not decode a generated sequence", exc_info=True
            )
            continue

        key = text.casefold()
        if not text or key == query_key or key in seen:
            continue
        seen.add(key)
        expansions.append(text)

    return expansions
|
|
|
|
|
|
|
|
# --- Header ---------------------------------------------------------------
st.title("Query Fan-Out")
st.markdown("Query fan-out model trained by [DEJAN AI](https://dejan.ai/).")

# load_model is wrapped in st.cache_resource, so reruns reuse the same objects.
tokenizer, model = load_model()

# --- Inputs ---------------------------------------------------------------
# Narrow column for the URL context, wider one for the query (1:2 split).
# Widgets are created directly on their containers rather than via `with`.
url_column, query_column = st.columns([1, 2])

url = url_column.text_input(
    "URL Context",
    value="dejan.ai",
    help="The domain or URL that provides context for the query",
)
query = query_column.text_input(
    "Search Query",
    value="AI search SEO Agency",
    help="The search query you want to expand",
)

advanced_options = st.expander("Advanced Options")
num_expansions = advanced_options.slider(
    "Number of expansions",
    min_value=1,
    max_value=10,
    value=10,
    help="How many query variations to generate",
)
|
|
|
|
|
|
|
|
# --- Generation -----------------------------------------------------------
# Runs only on the rerun triggered by clicking GO.
if st.button("GO", type="primary"):
    # Strip surrounding whitespace so blank-looking inputs are rejected and
    # stray spaces never reach the model prompt.
    clean_url = url.strip()
    clean_query = query.strip()

    if clean_url and clean_query:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()

        with st.spinner("Generating fan-out queries..."):
            expansions = generate_expansions(
                clean_url,
                clean_query,
                tokenizer,
                model,
                num_return_sequences=num_expansions,
            )

        generation_time = time.perf_counter() - start_time

        if expansions:
            st.markdown("### 📝 Generated Query Fan-Outs")

            st.markdown(f"**Original:** `{clean_query}`")
            st.markdown("**Fan-Outs:**")

            for i, expansion in enumerate(expansions, 1):
                st.markdown(f"{i}. `{expansion}`")

            # Summary metrics side by side.
            st.markdown("---")
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Fan-Outs Generated", len(expansions))
            with col2:
                st.metric("Generation Time", f"{generation_time:.2f}s")
        else:
            st.warning("No valid fan-outs generated. Try a different query.")
    else:
        st.error("Please enter both URL and query.")
|
|
|
|
|
|
|
|
# --- Footer ---------------------------------------------------------------
# Horizontal rule, then a centred link to the write-up on how the model was trained.
st.markdown("---")
st.markdown(
    body="""
<div style='text-align: center'>
    <p><a href='https://dejan.ai/blog/training-a-query-fan-out-model/'>Training Process</a></p>
</div>
""",
    unsafe_allow_html=True,
)