# streamlit_app.py
"""Streamlit front-end for the ``dejanseo/query-fanout`` mT5 model.

Given a URL and a search query, the app generates related ("fan-out")
queries in one of two modes:

* Quick Fan-Out (default): diverse group beam search returning ~10 results.
* High Effort / Deep Analysis: repeated stochastic sampling in batches; results
  are de-duplicated (case/whitespace-insensitively) and ranked by average
  per-token log-probability.
"""
import math
import time
import torch
import streamlit as st
from typing import List, Tuple, Dict, Any, Optional
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch.nn.functional as F
import pandas as pd

# ------------------ CONSTANTS ------------------
MODEL_PATH = "dejanseo/query-fanout"
CACHE_DIR = "/app/cache/huggingface"
MAX_INPUT_LENGTH = 32   # prompt tokens kept after truncation
MAX_TARGET_LENGTH = 16  # generation length budget (tokens)

# --- BATCHING CONFIGURATION ---
TOTAL_DESIRED_CANDIDATES = 200
GENERATION_BATCH_SIZE = 10

# ------------------ HARDCODED SETTINGS ------------------
GENERATION_CONFIG: Dict[str, Any] = {
    "temperature": 1.10,
    "top_p": 0.98,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10,
    "seed": 42,
    "sort_by": "logp/len",
}


# ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    """Load tokenizer and model once per server process (cached by Streamlit).

    Returns:
        ``(tokenizer, model, device)``; the model is moved to CUDA when
        available, otherwise CPU, and put in eval mode.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model.to(device)
    model.eval()
    return tok, model, device


# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    """Encode the model prompt for ``url``/``query`` and move tensors to ``device``.

    The prompt template must match the one the model was fine-tuned on.
    """
    txt = f"For URL: {url} diversify query: {query}"
    enc = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    return {k: v.to(device) for k, v in enc.items()}


def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    """Decode a batch of token-id sequences, dropping special tokens."""
    return tok.batch_decode(seqs, skip_special_tokens=True)


def avg_logprobs_from_generate(tok: MT5Tokenizer, gen: Any) -> List[float]:
    """Average per-token log-probability of each generated sequence.

    Args:
        tok: tokenizer supplying the EOS/pad ids.
        gen: output of ``model.generate(..., return_dict_in_generate=True,
            output_scores=True)``.

    Returns:
        One float per sequence. Tokens after the first EOS and pad tokens are
        excluded; the EOS token itself is counted. Returns NaNs when the
        output carries no scores.

    Note: ``gen.scores`` holds the *processed* logits (after temperature,
    top-p, repetition penalty, ...), so these are log-probs under the
    sampling distribution, not the raw model distribution.
    """
    seqs = gen.sequences
    nseq = seqs.size(0)
    scores = getattr(gen, "scores", None)
    # hasattr() alone is not enough: the attribute exists but is None when
    # output_scores was not requested.
    if not scores:
        return [float("nan")] * nseq

    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id
    dev = scores[0].device
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=dev)
    count = torch.zeros(nseq, dtype=torch.float32, device=dev)
    finished = torch.zeros(nseq, dtype=torch.bool, device=dev)

    for t, step_logits in enumerate(scores):
        # seqs[:, 0] is the decoder start token, so step t produced seqs[:, t + 1].
        step_tok = seqs[:, t + 1]
        valid = step_tok.ne(pad_id) & ~finished
        if valid.any():
            step_logprobs = F.log_softmax(step_logits.float(), dim=-1)
            gathered = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, gathered, torch.zeros_like(gathered))
            count += valid.float()
        finished |= step_tok.eq(eos_id)

    # Avoid division by zero for sequences with no counted tokens.
    count = count.clamp(min=1.0)
    return (sum_logp / count).tolist()


def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p,
                      no_repeat_ngram_size, repetition_penalty,
                      bad_words_ids: Optional[List[List[int]]] = None):
    """Sample ``top_n`` sequences for ``inputs``.

    Returns:
        ``(texts, avg_logprobs)``, two parallel lists of length ``top_n``.
    """
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # Only pass optional processors when they actually do something.
    if int(no_repeat_ngram_size) > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if float(repetition_penalty) != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    if bad_words_ids:
        kwargs["bad_words_ids"] = bad_words_ids
    with torch.no_grad():
        gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)


def normalize_text(s: str) -> str:
    """Lower-case and collapse all whitespace runs to single spaces."""
    return " ".join(s.strip().lower().split())


# --- Beam-based quick function (from old script) ---
def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer,
                             model: MT5ForConditionalGeneration, device: torch.device,
                             num_return_sequences: int = 10) -> List[str]:
    """Quick fan-out via diverse (group) beam search.

    Returns:
        Up to ``num_return_sequences`` expansions in model order, excluding
        empty results, echoes of ``query``, and duplicates (compared after
        ``normalize_text``).
    """
    inputs = build_inputs(tok, url, query, device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Only max_new_tokens: passing max_length too is contradictory and
            # max_new_tokens takes precedence anyway. No temperature either:
            # it is ignored (with a warning) when do_sample=False.
            max_new_tokens=MAX_TARGET_LENGTH,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,
            num_beam_groups=num_return_sequences,
            diversity_penalty=0.5,
            do_sample=False,
            early_stopping=True,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
            forced_eos_token_id=tok.eos_token_id,
        )

    query_norm = normalize_text(query)
    seen = set()
    expansions: List[str] = []
    for text in decode_sequences(tok, outputs):
        norm = normalize_text(text)
        if not norm or norm == query_norm or norm in seen:
            continue
        seen.add(norm)
        expansions.append(text)
    return expansions


def _build_bad_words_ids(tok: MT5Tokenizer, texts) -> Optional[List[List[int]]]:
    """Token-id sequences banning exact re-generation of ``texts`` (None if empty)."""
    if not texts:
        return None
    # No padding: pad ids appended to shorter texts would turn each banned
    # sequence into "text + <pad>...", which never occurs during sampling.
    encoded = tok(
        sorted(texts),  # deterministic order
        add_special_tokens=False,
        truncation=True,
        max_length=MAX_TARGET_LENGTH,
    )["input_ids"]
    ids = [seq for seq in encoded if seq]
    return ids or None


def _sort_score(score: float) -> float:
    """Sort key that pushes NaN scores to the end of a descending sort."""
    return float("-inf") if math.isnan(score) else score


def _run_deep_analysis(tok, model, device, url: str, query: str) -> List[str]:
    """Sample candidates in seeded batches, then de-duplicate and rank them."""
    cfg = GENERATION_CONFIG
    inputs = build_inputs(tok, url, query, device)

    all_texts: List[str] = []
    all_scores: List[float] = []
    seen_texts_for_bad_words = set()
    num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
    progress_bar = st.progress(0)

    for i in range(num_batches):
        # A distinct, reproducible seed per batch.
        current_seed = cfg["seed"] + i
        torch.manual_seed(current_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(current_seed)

        batch_texts, batch_scores = sampling_generate(
            tok, model, device, inputs,
            top_n=GENERATION_BATCH_SIZE,
            temperature=float(cfg["temperature"]),
            top_p=float(cfg["top_p"]),
            no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
            repetition_penalty=float(cfg["repetition_penalty"]),
            bad_words_ids=_build_bad_words_ids(tok, seen_texts_for_bad_words),
        )
        all_texts.extend(batch_texts)
        all_scores.extend(batch_scores)
        seen_texts_for_bad_words.update(t for t in batch_texts if t.strip())
        progress_bar.progress((i + 1) / num_batches)

    # Deduplicate (case/whitespace-insensitive), drop echoes of the query.
    query_norm = normalize_text(query)
    final_enriched = []
    final_seen_normalized = set()
    for txt, sc in zip(all_texts, all_scores):
        norm = normalize_text(txt)
        if norm and norm not in final_seen_normalized and norm != query_norm:
            final_seen_normalized.add(norm)
            final_enriched.append({"logp/len": sc, "text": txt})

    if cfg["sort_by"] == "logp/len":
        final_enriched.sort(key=lambda x: _sort_score(x["logp/len"]), reverse=True)
    return [item["text"] for item in final_enriched[:TOTAL_DESIRED_CANDIDATES]]


def _show_results(texts: List[str], empty_message: str) -> None:
    """Render ``texts`` as a 1-indexed table, or a warning when empty."""
    if not texts:
        st.warning(empty_message)
        return
    df = pd.DataFrame(texts, columns=["Generated Query"])
    df.index = range(1, len(df) + 1)
    st.dataframe(df, use_container_width=True)


# ------------------ STREAMLIT APP ------------------
st.set_page_config(
    page_title="Query Fan-Out by DEJAN AI",
    page_icon="🔎",
    layout="wide",
)
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
)

tok, model, device = load_model()

st.title("Query Fanout Generator")
st.markdown("Enter a URL and a query to generate a diverse set of related queries.")

# Inputs
col1, col2 = st.columns(2)
with col1:
    url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
with col2:
    query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")

# Mode + single Run button
mode_high_effort = st.toggle(
    "High Effort",
    value=False,
    help="On = Deep Analysis (stochastic sampling, large batch). Off = Quick Fan-Out (beam-based).",
)
run_btn = st.button("Generate", type="primary")

if run_btn:
    if not query.strip():
        st.warning("Please enter a query.")
    elif mode_high_effort:
        # ---- Deep Analysis path (sampling, large batches) ----
        with st.spinner("Generating queries..."):
            results = _run_deep_analysis(tok, model, device, url, query)
        _show_results(results, "No queries were generated. Try a different input.")
    else:
        # ---- Quick Fan-Out path (beam-based, small and simple) ----
        with st.spinner("Generating quick fan-out..."):
            results = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
        _show_results(results, "No valid fan-outs generated. Try a different query.")