#!/usr/bin/env python3
"""Streamlit app to reconstruct prompts from AI assistant responses."""
import html as html_lib
import logging
import os
import re
from collections import defaultdict

import pandas as pd
import requests
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

MODEL_ID = "dejanseo/reverse-prompter"
# Appended to the pasted response; the generation loop treats whatever the
# model writes after it as the reconstructed prompt.
SEPARATOR = "\n###\n"
# One decoding run per (top_k, penalty_alpha) pair: 4 x 6 = 24 runs.
CONTRASTIVE_CONFIGS = [
    {"penalty_alpha": a, "top_k": k}
    for k in [2, 4, 6, 15]
    for a in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
]
# Characters removed from rendered tokens: C0/C1 control characters, DEL and
# anything outside the Basic Multilingual Plane.
_UNRENDERABLE_CHARS = re.compile(r'[^\x20-\x7E\u00A0-\uFFFF]')
# Punctuation stripped before splitting prompts into words for key phrases.
_PUNCTUATION = re.compile(r'[^\w\s]')


@st.cache_resource
def load_model():
    """Load the reverse-prompter model and tokenizer.

    Cached with ``st.cache_resource`` so the weights are loaded once and reused
    across reruns and sessions.

    Returns:
        ``(model, tokenizer)``; the model is in bfloat16, in eval mode, on CUDA
        when available and on CPU otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="bfloat16")
        .to(device)
        .eval()
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return model, tokenizer


def analyze_output(model, tokenizer, input_ids, prefix_len):
    """Score the generated part of a sequence under the model.

    Args:
        model: causal LM used to compute the logits.
        tokenizer: tokenizer used to decode each generated token for display.
        input_ids: 1-D tensor holding the prompt ids followed by the generated ids.
        prefix_len: number of leading ids in ``input_ids`` that belong to the prompt.

    Returns:
        ``(perplexity, tokens, probs)``: the perplexity of the generated ids,
        each generated token decoded to text, and each token's probability.
    """
    with torch.no_grad():
        outputs = model(input_ids.unsqueeze(0))
    # Logits at position i predict token i + 1, hence the one-step shift.
    # Upcast to float32: bfloat16 keeps only ~3 significant digits, too coarse
    # for log-probabilities and the perplexity ranking built from them.
    logits = outputs.logits[0, prefix_len - 1:-1].float()
    targets = input_ids[prefix_len:]
    log_probs = torch.log_softmax(logits, dim=-1)
    positions = torch.arange(len(targets), device=targets.device)
    token_log_probs = log_probs[positions, targets]
    perplexity = torch.exp(-token_log_probs.mean()).item()
    probs = torch.exp(token_log_probs).tolist()
    tokens = [tokenizer.decode([t]) for t in targets.tolist()]
    return perplexity, tokens, probs


def render_colored(tokens, probs):
    """Render tokens as HTML spans whose opacity follows the model's confidence.

    Each token is HTML-escaped, has newlines turned into spaces and carriage
    returns / unrenderable characters removed, then is wrapped in a span with
    opacity ``0.15 + 0.85 * prob`` so unlikely tokens fade but stay visible.
    """
    parts = []
    for token, prob in zip(tokens, probs):
        opacity = 0.15 + 0.85 * prob
        clean = html_lib.escape(token).replace("\n", " ").replace("\r", "")
        clean = _UNRENDERABLE_CHARS.sub("", clean)
        parts.append(f'<span style="opacity:{opacity:.2f}">{clean}</span>')
    return "".join(parts)


def render_result(tokens, probs):
    """Render one reconstructed prompt as its own block of colored tokens."""
    colored = render_colored(tokens, probs)
    return f'<div style="margin-bottom:0.75em">{colored}</div>'


def fetch_url_markdown(url):
    """Fetch a web page as Markdown through the DataForSEO content-parsing API.

    Credentials are read from the ``DATAFORSEO_LOGIN`` and
    ``DATAFORSEO_PASSWORD`` environment variables.

    Returns:
        The page rendered as Markdown (at least 100 characters).

    Raises:
        ValueError: credentials are missing, the HTTP status is not 200, the
            API task did not succeed, no items were returned, or the page text
            is shorter than 100 characters.
        requests.RequestException: network failure or the 30 s timeout.
    """
    login = os.getenv("DATAFORSEO_LOGIN", "")
    password = os.getenv("DATAFORSEO_PASSWORD", "")
    if not login or not password:
        raise ValueError("DATAFORSEO_LOGIN / DATAFORSEO_PASSWORD not set in Spaces secrets")
    resp = requests.post(
        "https://api.dataforseo.com/v3/on_page/content_parsing/live",
        json=[{"url": url, "enable_javascript": True, "markdown_view": True}],
        auth=(login, password),
        timeout=30,
    )
    if resp.status_code != 200:
        raise ValueError(f"API returned status {resp.status_code}")
    data = resp.json()
    # `or []` / `or {}` also cover keys that are present but null.
    tasks = data.get("tasks") or []
    if not tasks or tasks[0].get("status_code") != 20000:
        raise ValueError(f"API task error: {tasks[0].get('status_message') if tasks else 'no tasks'}")
    task_results = tasks[0].get("result") or []
    items = (task_results[0] or {}).get("items") if task_results else None
    if not items:
        raise ValueError("No items returned from API")
    text = items[0].get("page_as_markdown") or ""
    if len(text) < 100:
        raise ValueError("Page content too short or empty")
    return text


def _clear_text_input():
    """``on_click`` callback for the Clear button: empty the text area.

    Assigning a widget's session-state key after the widget was created in the
    same run raises ``StreamlitAPIException``; callbacks run before the rerun,
    so the assignment is allowed here.
    """
    st.session_state.text_input = ""


def _key_phrase_rows(texts, limit=20):
    """Score word n-grams (n >= 2) that recur across several prompts.

    Each text is lower-cased, stripped of punctuation and split on whitespace.
    A phrase's score is (number of texts containing it) * (its word count);
    phrases found in fewer than two texts are ignored. Scanning from the highest
    score down, a phrase is dropped when it occurs as a whole-word run inside a
    phrase already kept with a score at least as high.

    Returns:
        Up to ``limit`` ``(phrase, percent)`` rows, highest score first, where
        percent is relative to the best score. Empty when no phrase recurs.
    """
    phrase_hits = defaultdict(set)
    for idx, text in enumerate(texts):
        words = _PUNCTUATION.sub("", text.lower()).split()
        for length in range(2, len(words) + 1):
            for start in range(len(words) - length + 1):
                phrase_hits[" ".join(words[start:start + length])].add(idx)

    shared = [
        (phrase, len(ids) * len(phrase.split()))
        for phrase, ids in phrase_hits.items()
        if len(ids) >= 2
    ]
    shared.sort(key=lambda item: item[1], reverse=True)

    kept = []  # (phrase, score, phrase padded with spaces for whole-word matching)
    for phrase, score in shared:
        padded = f" {phrase} "
        # Pad with spaces so "is a" does not match inside "this a".
        if not any(padded in longer and score <= lscore for _, lscore, longer in kept):
            kept.append((phrase, score, padded))
    if not kept:
        return []
    max_score = kept[0][1]
    return [(phrase, round(score / max_score * 100)) for phrase, score, _ in kept[:limit]]


st.set_page_config(layout="wide")
st.logo(
    "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    size="large",
    link="https://dejan.ai/",
)
st.subheader("Reverse Prompting")
st.caption(
    "Paste any AI-generated text and this tool will reverse-engineer the most likely "
    "prompts that produced it. It runs the text through a fine-tuned language model "
    "across multiple decoding configurations and ranks the reconstructed prompts by "
    "model confidence."
)

model, tokenizer = load_model()
TEST_FILE = os.path.join(os.path.dirname(__file__), "test.md")

if "quick_test" not in st.session_state:
    st.session_state.quick_test = False
if "fetched_text" not in st.session_state:
    st.session_state.fetched_text = ""
if "run_after_fetch" not in st.session_state:
    st.session_state.run_after_fetch = False

# Seed the text area from the previous run's fetch / quick test. This must
# happen before the text_area widget with key "text_input" is created below.
if st.session_state.fetched_text:
    st.session_state.text_input = st.session_state.fetched_text
    st.session_state.fetched_text = ""
    st.session_state.run_after_fetch = True

if st.session_state.quick_test:
    with open(TEST_FILE, encoding="utf-8") as f:
        st.session_state.text_input = f.read()

tab_paste, tab_url = st.tabs(["Paste", "URL Mode \u1D49\u02E3\u1D56\u1D49\u02B3\u2071\u1D50\u1D49\u207F\u1D57\u1D43\u02E1"])

with tab_paste:
    text_input = st.text_area("Paste an AI assistant response", height=200, key="text_input")
    col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
    run = col_btn1.button("Reconstruct Prompts", type="primary", width="stretch")
    quick = col_btn2.button("Quick Test", type="secondary", width="stretch")
    col_btn3.button("Clear", type="secondary", width="stretch", on_click=_clear_text_input)
    if quick and not st.session_state.quick_test:
        # Rerun so the test file is loaded into the text area before it is drawn.
        st.session_state.quick_test = True
        st.rerun()
    run = run or st.session_state.quick_test or st.session_state.run_after_fetch
    st.session_state.quick_test = False
    st.session_state.run_after_fetch = False

with tab_url:
    st.caption(
        "This experimental feature reveals what prompts would lead to generating a page "
        "such as the one you entered. While *not the intended use of the model*, it's "
        "certainly an interesting feature for exploring semantic make-up of the page."
    )
    url_col, url_btn_col = st.columns([5, 1])
    url_input = url_col.text_input(
        "URL to scrape", label_visibility="collapsed", placeholder="Enter a URL to scrape"
    )
    fetch = url_btn_col.button("Fetch", type="primary", width="stretch")
    if fetch and url_input.strip():
        url = url_input.strip()
        if not url.startswith(("http://", "https://")):
            url = "https://" + url
        if "/" not in url.split("//", 1)[-1]:
            url += "/"
        fetched = ""
        with st.spinner("Fetching page..."):
            try:
                fetched = fetch_url_markdown(url)
            except Exception:
                # Best effort: any scrape/API failure falls back to manual paste.
                logger.exception("Failed to fetch %s", url)
                st.warning("This page couldn't be scraped. The site may be blocking automated access or the page has insufficient content. Try pasting the content manually in the text area.")
        # Kept outside the try so st.rerun()'s control-flow exception can never
        # be caught by the broad handler above.
        if fetched:
            st.session_state.fetched_text = fetched
            st.rerun()

if run and text_input.strip():
    prompt = text_input.strip() + SEPARATOR
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prefix_len = inputs["input_ids"].shape[-1]
    # Reserved now so the result tables render above the streamed heatmap.
    tables_container = st.container()
    st.subheader("Word Frequency Heatmap")
    results = []
    seen = set()
    failures = 0
    eos_id = tokenizer.eos_token_id
    progress = st.progress(0)
    output_container = st.empty()
    for i, config in enumerate(CONTRASTIVE_CONFIGS):
        progress.progress((i + 1) / len(CONTRASTIVE_CONFIGS))
        try:
            # NOTE(review): trust_remote_code is presumably required because newer
            # transformers releases serve contrastive search (penalty_alpha/top_k)
            # as Hub custom-generate code — confirm against the pinned version.
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                trust_remote_code=True,
                num_return_sequences=1,
                **config,
            )
            gen_ids = outputs[0][prefix_len:]
            if eos_id is not None:
                eos_positions = (gen_ids == eos_id).nonzero(as_tuple=True)[0]
                if len(eos_positions):
                    gen_ids = gen_ids[:int(eos_positions[0])]
            text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
            if not text or text in seen:
                continue
            seen.add(text)
            trimmed = torch.cat([outputs[0][:prefix_len], gen_ids])
            ppl, tokens, probs = analyze_output(model, tokenizer, trimmed, prefix_len)
        except Exception:
            # One bad decoding config should not abort the others, but it must
            # not vanish silently either.
            failures += 1
            logger.exception("Generation failed for config %s", config)
            continue
        results.append((text, ppl, tokens, probs, config))
        top = sorted(results, key=lambda r: r[1])[:10]
        output_container.markdown(
            "".join(render_result(t, p) for _, _, t, p, _ in top),
            unsafe_allow_html=True,
        )
    progress.empty()
    if failures:
        st.warning(f"{failures} of {len(CONTRASTIVE_CONFIGS)} decoding configurations failed and were skipped.")

    results.sort(key=lambda r: r[1])
    if results:
        top10 = results[:10]
        col_prompts, col_phrases = tables_container.columns(2)
        with col_prompts:
            st.subheader("Reconstructed Prompts")
            st.caption("Top 10 most likely prompts ranked by lowest perplexity. Token opacity reflects the model's confidence in each word.")
            df_top = pd.DataFrame(
                [(text, round(ppl, 2), [round(p * 100) for p in probs]) for text, ppl, _, probs, _ in top10],
                columns=["Prompt", "Perplexity", "Confidence"],
            )
            st.dataframe(
                df_top,
                width="content",
                hide_index=True,
                column_config={
                    "Prompt": st.column_config.TextColumn(width=None),
                    "Perplexity": st.column_config.NumberColumn(width=None),
                    "Confidence": st.column_config.BarChartColumn(y_min=0, y_max=100),
                },
            )
        with col_phrases:
            st.subheader("Key Phrases")
            st.caption("The most important phrases scored by balance between length and frequency.")
            phrase_rows = _key_phrase_rows([text for text, _, _, _, _ in top10])
            if phrase_rows:
                df = pd.DataFrame(phrase_rows, columns=["Phrase", "Score"])
                st.dataframe(
                    df,
                    width="content",
                    hide_index=True,
                    column_config={
                        "Score": st.column_config.ProgressColumn(format="%d%%", min_value=0, max_value=100),
                    },
                )