Spaces:
Running
Running
| # streamlit_app.py | |
| import time | |
| import torch | |
| import streamlit as st | |
| from typing import List, Tuple, Dict, Any | |
| from transformers import MT5ForConditionalGeneration, MT5Tokenizer | |
| import torch.nn.functional as F | |
| import pandas as pd | |
# ------------------ CONSTANTS ------------------
# Model id handed to from_pretrained() and the local directory its files are cached in.
MODEL_PATH = "dejanseo/query-fanout"
CACHE_DIR = "/app/cache/huggingface"
# Token budgets: prompt truncation length and generated-query length.
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16

# --- BATCHING CONFIGURATION ---
# Deep Analysis samples TOTAL_DESIRED_CANDIDATES queries, GENERATION_BATCH_SIZE per call.
TOTAL_DESIRED_CANDIDATES = 200
GENERATION_BATCH_SIZE = 10

# ------------------ HARDCODED SETTINGS ------------------
# Sampling parameters for the Deep Analysis path. "seed" is the base seed (offset by
# the batch index) and "sort_by" names the ranking key applied to the results.
GENERATION_CONFIG: Dict[str, Any] = dict(
    temperature=1.10,
    top_p=0.98,
    no_repeat_ngram_size=2,
    repetition_penalty=1.10,
    seed=42,
    sort_by="logp/len",
)
# ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    """Load the tokenizer and model once and place the model on the best device.

    Streamlit re-executes this whole script on every widget interaction. Without
    ``st.cache_resource`` the model would be re-read from disk (and re-copied to
    the GPU) on every click. The cached objects are reused across reruns.

    Returns:
        ``(tokenizer, model, device)``. The model is in eval mode on ``device``,
        which is CUDA when available and CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model.to(device)
    model.eval()  # inference only: disable dropout
    return tok, model, device
# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    """Encode the "For URL: ... diversify query: ..." prompt for generate().

    The prompt is truncated to MAX_INPUT_LENGTH tokens. Every tensor in the
    tokenizer output is moved to *device*, and the result is returned as a plain dict.
    """
    prompt = f"For URL: {url} diversify query: {query}"
    encoding = tok(prompt, truncation=True, max_length=MAX_INPUT_LENGTH, return_tensors="pt")
    on_device = {}
    for name, tensor in encoding.items():
        on_device[name] = tensor.to(device)
    return on_device
def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    """Convert each row of token ids in *seqs* to text, omitting special tokens."""
    decoded: List[str] = tok.batch_decode(seqs, skip_special_tokens=True)
    return decoded
def avg_logprobs_from_generate(tok: "MT5Tokenizer", gen) -> List[float]:
    """Return the mean per-token log-probability of each generated sequence.

    Args:
        tok: tokenizer supplying ``eos_token_id`` and ``pad_token_id``.
        gen: output of ``model.generate(..., return_dict_in_generate=True,
            output_scores=True)``. It must expose ``sequences`` and may expose
            ``scores``, one logits tensor per generation step. Each tensor is
            presumably shaped ``(num_sequences, vocab)``.

    Returns:
        One float per row of ``gen.sequences``: the summed log-softmax value of
        each chosen token, divided by the number of tokens counted.
        - Pad tokens and everything after the first EOS are skipped; the EOS
          itself is counted.
        - A row with no counted token yields 0.0.
        - ``nan`` is returned for every row when no scores are available.

    NOTE: per the transformers docs, ``scores`` are the *processed* logits
    (after temperature/top-p/penalties). This is therefore the log-probability
    under the adjusted sampling distribution, not under the raw model.
    """
    seqs = gen.sequences
    nseq = seqs.size(0)
    scores = getattr(gen, "scores", None)
    # None when generate() ran without output_scores=True; empty when no step ran.
    if not scores:
        return [float("nan")] * nseq

    # Explicit None check: a legitimate id of 0 must not be replaced by the fallback.
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id
    device = scores[0].device
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=device)
    count = torch.zeros(nseq, dtype=torch.float32, device=device)
    finished = torch.zeros(nseq, dtype=torch.bool, device=device)

    # scores[t] is paired with the token at seqs[:, t + 1]; column 0 is the decoder
    # start token, which has no score.
    for t, step_logits in enumerate(scores):
        step_tok = seqs[:, t + 1]
        valid = ~finished
        if pad_id is not None:
            valid &= step_tok.ne(pad_id)
        if valid.any():
            # Upcast so half-precision logits don't lose accuracy in log_softmax.
            step_logprobs = F.log_softmax(step_logits.float(), dim=-1)
            picked = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, picked, torch.zeros_like(picked))
            count += valid.float()
        # Mark after accumulating so the EOS token itself is counted.
        finished |= step_tok.eq(eos_id)

    # Avoid 0/0 for rows where nothing was counted.
    count = count.clamp(min=1.0)
    return (sum_logp / count).tolist()
# --- Stochastic sampling used by the Deep Analysis path ---
def sampling_generate(
    tok,
    model,
    device,
    inputs,
    top_n,
    temperature,
    top_p,
    no_repeat_ngram_size,
    repetition_penalty,
    bad_words_ids: List[List[int]] = None,
):
    """Sample ``top_n`` sequences from already-encoded ``inputs``.

    Optional constraints are forwarded to ``model.generate`` only when they
    would have an effect:
    - an n-gram size greater than 0,
    - a repetition penalty other than 1.0,
    - a non-empty ``bad_words_ids`` list.

    Returns:
        ``(texts, scores)``: the decoded sequences and, for each, its mean
        per-token log-probability from ``avg_logprobs_from_generate``.
    """
    gen_kwargs = {
        "max_length": MAX_TARGET_LENGTH,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "num_return_sequences": top_n,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    ngram = int(no_repeat_ngram_size)
    if ngram > 0:
        gen_kwargs["no_repeat_ngram_size"] = ngram
    penalty = float(repetition_penalty)
    if penalty != 1.0:
        gen_kwargs["repetition_penalty"] = penalty
    if bad_words_ids:
        gen_kwargs["bad_words_ids"] = bad_words_ids

    output = model.generate(**inputs, **gen_kwargs)
    texts = decode_sequences(tok, output.sequences)
    avg_scores = avg_logprobs_from_generate(tok, output)
    return texts, avg_scores
def normalize_text(s: str) -> str:
    """Lower-case *s* and collapse every whitespace run to one space (no edges)."""
    words = s.lower().split()
    return " ".join(words)
# --- Quick Fan-Out: diverse (grouped) beam search ---
def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device, num_return_sequences: int = 10) -> List[str]:
    """Return up to ``num_return_sequences`` distinct expansions of *query*.

    Runs deterministic group beam search: ``num_return_sequences`` groups of two
    beams each, with a diversity penalty of 0.5 between groups.

    Decoded results are filtered and deduplicated on their ``normalize_text``
    form, the same normalization the Deep Analysis path uses:
    - blank results are dropped,
    - results equal to the query are dropped,
    - case/whitespace variants of an earlier result are dropped.
    The first-seen original text is kept.

    Returns:
        The surviving expansions in generation order. The list is empty when
        ``num_return_sequences <= 0``.
    """
    if num_return_sequences <= 0:
        return []
    inputs = build_inputs(tok, url, query, device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens alone bounds the output length. The former extra
            # max_length was overridden by it anyway, and `temperature` is ignored
            # when do_sample=False; both only produced transformers warnings.
            max_new_tokens=MAX_TARGET_LENGTH,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,
            num_beam_groups=num_return_sequences,
            diversity_penalty=0.5,
            do_sample=False,
            early_stopping=True,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
            forced_eos_token_id=tok.eos_token_id,
        )

    query_norm = normalize_text(query)
    seen = set()
    uniq: List[str] = []
    for text in decode_sequences(tok, outputs):
        norm = normalize_text(text)
        if not norm or norm == query_norm or norm in seen:
            continue
        seen.add(norm)
        uniq.append(text)
    return uniq
# ------------------ STREAMLIT APP ------------------
# Page configuration is issued before any other element is rendered.
st.set_page_config(layout="wide", page_title="Query Fan-Out by DEJAN AI", page_icon="🔎")
st.logo(
    link="https://dejan.ai/",
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
)

# Tokenizer, model and device shared by both generation paths below.
tok, model, device = load_model()
st.title("Query Fanout Generator")
st.markdown("Enter a URL and a query to generate a diverse set of related queries.")

# Inputs: URL (context) in the left column, query in the right column.
col1, col2 = st.columns(2)
url = col1.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
query = col2.text_input("Query", value="ai seo agency", help="The search query you want to expand.")

# The toggle selects the generation path; one button triggers it.
mode_high_effort = st.toggle(
    "High Effort",
    value=False,
    help="On = Deep Analysis (stochastic sampling, large batch). Off = Quick Fan-Out (beam-based).",
)
run_btn = st.button("Generate", type="primary")
if run_btn:
    # The same normalization is applied to every candidate, so case/whitespace
    # variants of the input query are filtered out consistently.
    query_norm = normalize_text(query)
    if not query_norm:
        st.warning("Please enter a query to expand.")
    elif mode_high_effort:
        # ---- Deep Analysis path (sampling, large batches) ----
        cfg = GENERATION_CONFIG
        with st.spinner("Generating queries..."):
            inputs = build_inputs(tok, url, query, device)
            all_texts, all_scores = [], []
            seen_texts_for_bad_words = set()
            # Token ids of every distinct output so far. They are passed as
            # bad_words_ids so later batches cannot reproduce an earlier sequence.
            bad_words_ids = []
            num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
            progress_bar = st.progress(0)
            for i in range(num_batches):
                # Distinct but reproducible seed per batch.
                current_seed = cfg["seed"] + i
                torch.manual_seed(current_seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(current_seed)
                batch_texts, batch_scores = sampling_generate(
                    tok, model, device, inputs,
                    top_n=GENERATION_BATCH_SIZE,
                    temperature=float(cfg["temperature"]),
                    top_p=float(cfg["top_p"]),
                    no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
                    repetition_penalty=float(cfg["repetition_penalty"]),
                    bad_words_ids=list(bad_words_ids) or None,
                )
                all_texts.extend(batch_texts)
                all_scores.extend(batch_scores)

                fresh = []
                for txt in batch_texts:
                    if txt and txt not in seen_texts_for_bad_words:
                        seen_texts_for_bad_words.add(txt)
                        fresh.append(txt)
                if fresh:
                    # Tokenize WITHOUT padding: pad ids appended to a shorter banned
                    # sequence make it unmatchable, silently disabling the ban.
                    # Only the new texts are encoded; earlier ids are reused.
                    encoded = tok(fresh, add_special_tokens=False, truncation=True)["input_ids"]
                    bad_words_ids.extend(ids for ids in encoded if ids)
                progress_bar.progress((i + 1) / num_batches)

        # Deduplicate on normalized text, drop the query itself, then rank.
        final_enriched = []
        final_seen_normalized = set()
        for txt, sc in zip(all_texts, all_scores):
            norm = normalize_text(txt)
            if norm and norm not in final_seen_normalized and norm != query_norm:
                final_seen_normalized.add(norm)
                final_enriched.append({"logp/len": sc, "text": txt})
        if cfg["sort_by"] == "logp/len":
            final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
        final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]

        if not final_enriched:
            st.warning("No queries were generated. Try a different input.")
        else:
            output_texts = [item["text"] for item in final_enriched]
            df = pd.DataFrame(output_texts, columns=["Generated Query"])
            df.index = range(1, len(df) + 1)
            st.dataframe(df, use_container_width=True)
    else:
        # ---- Quick Fan-Out path (beam-based, small and simple) ----
        with st.spinner("Generating quick fan-out..."):
            expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
        if expansions:
            df = pd.DataFrame(expansions, columns=["Generated Query"])
            df.index = range(1, len(df) + 1)
            st.dataframe(df, use_container_width=True)
        else:
            st.warning("No valid fan-outs generated. Try a different query.")