# qf-large / src/streamlit_app.py — Streamlit app for the DEJAN query fan-out model
# Hugging Face Space metadata: uploaded by dejanseo ("Update src/streamlit_app.py",
# commit a033094, verified)
import streamlit as st
import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import time
# Streamlit page configuration — must be the first st.* call in the script.
_PAGE_CONFIG = dict(
    page_title="Query Fan-Out",
    page_icon="🔍",
    layout="centered",
)
st.set_page_config(**_PAGE_CONFIG)
# Cache the model and tokenizer so they load only once per server process
@st.cache_resource
def load_model():
    """Load the fine-tuned mT5 model and tokenizer.

    Returns:
        tuple: ``(tokenizer, model)`` where ``model`` is in eval mode and
        moved to the GPU when one is available.
    """
    model_name = "dejanseo/query-fanout"
    with st.spinner("Loading model... This may take a minute on first run."):
        # `use_auth_token` is deprecated (and removed in newer transformers
        # releases); `token=True` is the supported way to use the stored
        # Hugging Face token for gated/private repos.
        tokenizer = MT5Tokenizer.from_pretrained(model_name, token=True)
        model = MT5ForConditionalGeneration.from_pretrained(model_name, token=True)
        # Move to GPU if available
        if torch.cuda.is_available():
            model = model.cuda()
        model.eval()
    return tokenizer, model
def generate_expansions(url, query, tokenizer, model, num_return_sequences=5):
    """Generate diverse query expansions for *query* in the context of *url*.

    Args:
        url: Domain or URL string providing context for the query.
        query: The search query to fan out.
        tokenizer: mT5 tokenizer matching *model*.
        model: Fine-tuned ``MT5ForConditionalGeneration`` instance.
        num_return_sequences: Number of candidate expansions to request.

    Returns:
        list[str]: Unique, non-empty expansions that differ from *query*,
        in generation order. May contain fewer than
        ``num_return_sequences`` items after filtering.
    """
    # Format input exactly as during training
    input_text = f"For URL: {url} diversify query: {query}"
    inputs = tokenizer(
        input_text,
        max_length=32,
        truncation=True,
        return_tensors="pt",
    )
    # Follow the model's device instead of re-checking CUDA availability;
    # this also works if the model was placed on a non-default device.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Diverse beam search: beam groups + diversity penalty push candidates
    # apart. Two fixes vs. the original call:
    #   - `temperature` removed: it has no effect with do_sample=False and
    #     makes transformers emit a warning.
    #   - `max_length` removed: it conflicted with `max_new_tokens` (which
    #     transformers prefers, with a warning); one length cap is enough.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,  # must be divisible by num_beam_groups
            num_beam_groups=num_return_sequences,  # beam groups for diversity
            diversity_penalty=0.5,  # encourage diverse outputs
            do_sample=False,  # deterministic beam search for quality
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            forced_eos_token_id=tokenizer.eos_token_id,  # force proper EOS
            max_new_tokens=16,
        )
    # mT5 quirk: the model's vocab (250112) is larger than the tokenizer's
    # (250100); clip out-of-vocabulary ids to pad before decoding.
    tokenizer_vocab_size = len(tokenizer)
    outputs_clipped = torch.where(
        outputs < tokenizer_vocab_size,
        outputs,
        tokenizer.pad_token_id,
    )
    # Decode, dropping empty strings and echoes of the original query
    expansions = []
    for output in outputs_clipped:
        try:
            text = tokenizer.decode(output, skip_special_tokens=True)
        except Exception:
            # Skip any sequence the tokenizer cannot decode
            continue
        if text and text != query:
            expansions.append(text)
    # De-duplicate while preserving generation order
    seen = set()
    unique_expansions = []
    for exp in expansions:
        if exp not in seen:
            seen.add(exp)
            unique_expansions.append(exp)
    return unique_expansions
# ---------------------------------------------------------------- main UI
st.title("Query Fan-Out")
st.markdown("Query fan-out model trained by [DEJAN AI](https://dejan.ai/).")

# Fetch the cached model + tokenizer up front
tokenizer, model = load_model()

# Input widgets: URL context on the left, query on the right
url_col, query_col = st.columns([1, 2])
with url_col:
    url = st.text_input(
        "URL Context",
        value="dejan.ai",
        help="The domain or URL that provides context for the query",
    )
with query_col:
    query = st.text_input(
        "Search Query",
        value="AI search SEO Agency",
        help="The search query you want to expand",
    )

# Optional tuning controls
with st.expander("Advanced Options"):
    num_expansions = st.slider(
        "Number of expansions",
        min_value=1,
        max_value=10,
        value=10,
        help="How many query variations to generate",
    )

# Run generation when the user clicks GO
if st.button("GO", type="primary"):
    if not (url and query):
        st.error("Please enter both URL and query.")
    else:
        started_at = time.time()
        with st.spinner("Generating fan-out queries..."):
            expansions = generate_expansions(
                url,
                query,
                tokenizer,
                model,
                num_return_sequences=num_expansions,
            )
        elapsed = time.time() - started_at

        if not expansions:
            st.warning("No valid fan-outs generated. Try a different query.")
        else:
            st.markdown("### 📝 Generated Query Fan-Outs")
            # Show the original query for reference, then each fan-out
            st.markdown(f"**Original:** `{query}`")
            st.markdown("**Fan-Outs:**")
            for rank, expansion in enumerate(expansions, 1):
                st.markdown(f"{rank}. `{expansion}`")
            # Summary stats
            st.markdown("---")
            count_col, time_col = st.columns(2)
            with count_col:
                st.metric("Fan-Outs Generated", len(expansions))
            with time_col:
                st.metric("Generation Time", f"{elapsed:.2f}s")

# Footer with a link to the training write-up
st.markdown("---")
st.markdown(
    """
<div style='text-align: center'>
<p><a href='https://dejan.ai/blog/training-a-query-fan-out-model/'>Training Process</a></p>
</div>
""",
    unsafe_allow_html=True,
)