Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Streamlit app to reconstruct prompts from AI assistant responses.""" | |
| import html as html_lib | |
| import os | |
| import re | |
| import requests | |
| import torch | |
| import pandas as pd | |
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face Hub id of the fine-tuned causal LM used to reconstruct prompts.
MODEL_ID = "dejanseo/reverse-prompter"

# Delimiter appended to the pasted response before generation starts.
SEPARATOR = "\n###\n"

# Values swept during generation; each (top_k, penalty_alpha) pair becomes one
# set of keyword arguments forwarded to ``model.generate``.
_TOP_K_VALUES = (2, 4, 6, 15)
_PENALTY_ALPHAS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6)

# Ordered with top_k as the outer loop and penalty_alpha as the inner loop.
CONTRASTIVE_CONFIGS = [
    {"penalty_alpha": alpha, "top_k": top_k}
    for top_k in _TOP_K_VALUES
    for alpha in _PENALTY_ALPHAS
]
@st.cache_resource
def load_model():
    """Load the reverse-prompting model and its tokenizer once per process.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` each button click would reload the model
    weights. The cached pair is shared across reruns and sessions.

    Returns:
        Tuple ``(model, tokenizer)``. The model is loaded in bfloat16, moved to
        CUDA when available (CPU otherwise) and switched to eval mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="bfloat16")
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return model, tokenizer
def analyze_output(model, tokenizer, input_ids, prefix_len):
    """Score the continuation part of a token sequence under the model.

    One forward pass is run over the whole sequence; only the tokens from
    index ``prefix_len`` onward are evaluated.

    Args:
        model: Callable taking a batched id tensor and returning an object
            with a ``logits`` attribute.
        tokenizer: Object whose ``decode`` turns a list of ids into text.
        input_ids: 1-D tensor of token ids (context followed by continuation).
        prefix_len: Number of leading ids that are context, not scored.

    Returns:
        Tuple ``(perplexity, tokens, probs)``: the perplexity of the
        continuation, its tokens decoded one by one, and the probability the
        model assigned to each of those tokens.
    """
    with torch.no_grad():
        batch_logits = model(input_ids.unsqueeze(0)).logits

    # The logit at position i predicts token i + 1, hence the one-step shift.
    continuation_ids = input_ids[prefix_len:]
    step_log_probs = torch.log_softmax(batch_logits[0, prefix_len - 1:-1], dim=-1)
    chosen_log_probs = step_log_probs.gather(-1, continuation_ids.unsqueeze(-1)).squeeze(-1)

    perplexity = torch.exp(-chosen_log_probs.mean()).item()
    probs = chosen_log_probs.exp().tolist()
    tokens = [tokenizer.decode([token_id]) for token_id in continuation_ids.tolist()]
    return perplexity, tokens, probs
def render_colored(tokens, probs):
    """Render tokens as HTML spans whose opacity tracks their probability.

    Args:
        tokens: Token strings to display.
        probs: Probability for each token, paired positionally with
            ``tokens``; extra entries on either side are ignored.

    Returns:
        The concatenated ``<span>`` markup. Opacity runs from 0.15 (prob 0)
        to 1.00 (prob 1).
    """

    def _span(token, prob):
        # Escape first, then flatten line breaks and drop everything outside
        # printable ASCII and U+00A0..U+FFFF.
        text = html_lib.escape(token)
        text = text.replace("\n", " ").replace("\r", "")
        text = re.sub(r'[^\x20-\x7E\u00A0-\uFFFF]', '', text)
        alpha = 0.15 + 0.85 * prob
        return f'<span style="opacity:{alpha:.2f}">{text}</span>'

    return "".join(_span(token, prob) for token, prob in zip(tokens, probs))
def render_result(tokens, probs):
    """Wrap the opacity-coded spans of one prompt in a bottom-spaced ``<div>``.

    Args:
        tokens: Token strings of the prompt.
        probs: Per-token probabilities, forwarded to ``render_colored``.

    Returns:
        HTML string for a single result row.
    """
    return '<div style="margin-bottom:8px">' + render_colored(tokens, probs) + '</div>'
def fetch_url_markdown(url):
    """Fetch a page as Markdown through the DataForSEO content-parsing API.

    Credentials are read from the ``DATAFORSEO_LOGIN`` and
    ``DATAFORSEO_PASSWORD`` environment variables.

    Args:
        url: Address of the page to scrape.

    Returns:
        The page's Markdown text, at least 100 characters long.

    Raises:
        ValueError: If credentials are missing, the HTTP status is not 200,
            the first task is absent or its status code is not 20000, no
            items come back, or the Markdown is shorter than 100 characters.
    """
    credentials = (
        os.getenv("DATAFORSEO_LOGIN", ""),
        os.getenv("DATAFORSEO_PASSWORD", ""),
    )
    if not all(credentials):
        raise ValueError("DATAFORSEO_LOGIN / DATAFORSEO_PASSWORD not set in Spaces secrets")

    payload = [{"url": url, "enable_javascript": True, "markdown_view": True}]
    response = requests.post(
        "https://api.dataforseo.com/v3/on_page/content_parsing/live",
        json=payload,
        auth=credentials,
        timeout=30,
    )
    if response.status_code != 200:
        raise ValueError(f"API returned status {response.status_code}")

    tasks = response.json().get("tasks", [])
    if not tasks:
        raise ValueError("API task error: no tasks")
    first_task = tasks[0]
    if first_task.get("status_code") != 20000:
        raise ValueError(f"API task error: {first_task.get('status_message')}")

    # A missing or empty "result" falls through to the "no items" error below.
    task_results = first_task.get("result") or [{}]
    items = task_results[0].get("items")
    if not items:
        raise ValueError("No items returned from API")

    markdown = items[0].get("page_as_markdown") or ""
    if len(markdown) < 100:
        raise ValueError("Page content too short or empty")
    return markdown
# ---------------------------------------------------------------------------
# Page chrome, session state and input widgets.
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide")
st.logo("https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png", size="LARGE", link="https://dejan.ai/")
st.subheader("Reverse Prompting")
st.caption("Paste any AI-generated text and this tool will reverse-engineer the most likely prompts that produced it. It runs the text through a fine-tuned language model across multiple decoding configurations and ranks the reconstructed prompts by model confidence.")

model, tokenizer = load_model()

# Sample input used by the "Quick Test" button; expected next to this script.
TEST_FILE = os.path.join(os.path.dirname(__file__), "test.md")

# One-shot flags. Streamlit forbids assigning to a widget's session-state key
# after that widget has been created in the current run, so every action that
# needs to change the text area sets a flag, reruns, and the flag is applied
# here, *before* the text area is instantiated.
if "quick_test" not in st.session_state:
    st.session_state.quick_test = False
if "fetched_text" not in st.session_state:
    st.session_state.fetched_text = ""
if "run_after_fetch" not in st.session_state:
    st.session_state.run_after_fetch = False
if "clear_requested" not in st.session_state:
    st.session_state.clear_requested = False

if st.session_state.fetched_text:
    # Text scraped in URL mode: load it into the text area and run right away.
    st.session_state.text_input = st.session_state.fetched_text
    st.session_state.fetched_text = ""
    st.session_state.run_after_fetch = True

if st.session_state.quick_test:
    try:
        with open(TEST_FILE, encoding="utf-8") as f:
            st.session_state.text_input = f.read()
    except OSError:
        # Without the sample there is nothing to run; cancel the quick test.
        st.session_state.quick_test = False
        st.warning("The quick test sample (test.md) could not be read.")

if st.session_state.clear_requested:
    st.session_state.text_input = ""
    st.session_state.clear_requested = False

tab_paste, tab_url = st.tabs(["Paste", "URL Mode \u1D49\u02E3\u1D56\u1D49\u02B3\u2071\u1D50\u1D49\u207F\u1D57\u1D43\u02E1"])

with tab_paste:
    text_input = st.text_area("Paste an AI assistant response", height=200, key="text_input")
    col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
    run = col_btn1.button("Reconstruct Prompts", type="primary", width="stretch")
    quick = col_btn2.button("Quick Test", type="secondary", width="stretch")
    clear = col_btn3.button("Clear", type="secondary", width="stretch")

if quick and not st.session_state.quick_test:
    st.session_state.quick_test = True
    st.rerun()

if clear:
    # Deferred to the next run: writing st.session_state.text_input here would
    # raise because the text area already exists in this run.
    st.session_state.clear_requested = True
    st.rerun()

# A run is triggered by the button, by a quick test, or by a finished fetch.
run = run or st.session_state.quick_test or st.session_state.run_after_fetch
st.session_state.quick_test = False
st.session_state.run_after_fetch = False

with tab_url:
    st.caption("This experimental feature reveals what prompts would lead to generating a page such as the one you entered. While *not the intended use of the model*, it's certainly an interesting feature for exploring semantic make-up of the page.")
    url_col, url_btn_col = st.columns([5, 1])
    url_input = url_col.text_input("URL to scrape", label_visibility="collapsed", placeholder="Enter a URL to scrape")
    fetch = url_btn_col.button("Fetch", type="primary", width="stretch")

    if fetch and url_input.strip():
        # Normalise the address: default to https and make sure a path exists.
        url = url_input.strip()
        if not url.startswith(("http://", "https://")):
            url = "https://" + url
        if "/" not in url.split("//", 1)[-1]:
            url += "/"

        with st.spinner("Fetching page..."):
            try:
                st.session_state.fetched_text = fetch_url_markdown(url)
            except Exception:
                # UI boundary: any scraping failure becomes a friendly hint.
                st.warning("This page couldn't be scraped. The site may be blocking automated access or the page has insufficient content. Try pasting the content manually in the text area.")
            else:
                # Rerun so the fetched text is applied before the text area is built.
                st.rerun()
# ---------------------------------------------------------------------------
# Reconstruction run: sweep the decoding configurations, rank by perplexity,
# then summarise the best candidates.
# ---------------------------------------------------------------------------
if run and text_input.strip():
    # The pasted response plus the separator is the context; whatever the
    # model generates after it is treated as a candidate prompt.
    prompt = text_input.strip() + SEPARATOR
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prefix_len = inputs["input_ids"].shape[-1]
    eos_id = tokenizer.eos_token_id

    # Created before the heatmap heading so the summary tables, which are only
    # filled in after the sweep, are placed above the live heatmap.
    tables_container = st.container()
    st.subheader("Word Frequency Heatmap")

    results = []   # (text, perplexity, tokens, probs, config)
    seen = set()   # candidate texts already recorded, for de-duplication
    failures = []  # repr() of each exception raised by a configuration
    progress = st.progress(0)
    output_container = st.empty()

    for i, config in enumerate(CONTRASTIVE_CONFIGS):
        progress.progress((i + 1) / len(CONTRASTIVE_CONFIGS))
        try:
            outputs = model.generate(**inputs, max_new_tokens=256, trust_remote_code=True, num_return_sequences=1, **config)
            gen_ids = outputs[0][prefix_len:]
            # Keep only what precedes the first end-of-sequence token.
            if eos_id is not None and (gen_ids == eos_id).any():
                gen_ids = gen_ids[:(gen_ids == eos_id).nonzero(as_tuple=True)[0][0]]
            text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
            if text and text not in seen:
                seen.add(text)
                trimmed = torch.cat([outputs[0][:prefix_len], gen_ids])
                ppl, tokens, probs = analyze_output(model, tokenizer, trimmed, prefix_len)
                results.append((text, ppl, tokens, probs, config))
                # Refresh the live preview with the current ten best candidates.
                top = sorted(results, key=lambda x: x[1])[:10]
                preview_html = "".join(render_result(t, p) for _, _, t, p, _ in top)
                output_container.markdown(preview_html, unsafe_allow_html=True)
        except Exception as exc:
            # Best effort: one failing configuration must not abort the sweep,
            # but the failure is recorded instead of being silently dropped.
            failures.append(repr(exc))

    progress.empty()
    if failures:
        st.warning(
            f"{len(failures)} of {len(CONTRASTIVE_CONFIGS)} decoding configurations failed and were skipped. "
            f"Last error: {failures[-1]}"
        )

    results.sort(key=lambda x: x[1])
    if not results:
        st.info("No prompts could be reconstructed from this text.")
    else:
        top10 = results[:10]
        col_prompts, col_phrases = tables_container.columns(2)

        with col_prompts:
            st.subheader("Reconstructed Prompts")
            st.caption("Top 10 most likely prompts ranked by lowest perplexity. Token opacity reflects the model's confidence in each word.")
            df_top = pd.DataFrame(
                [(text, round(ppl, 2), [round(p * 100) for p in probs]) for text, ppl, _, probs, _ in top10],
                columns=["Prompt", "Perplexity", "Confidence"],
            )
            st.dataframe(df_top, width="content", hide_index=True, column_config={
                "Prompt": st.column_config.TextColumn(width=None),
                "Perplexity": st.column_config.NumberColumn(width=None),
                "Confidence": st.column_config.BarChartColumn(y_min=0, y_max=100),
            })

        with col_phrases:
            st.subheader("Key Phrases")
            st.caption("The most important phrases scored by balance between length and frequency.")

            # Lower-case, punctuation-free word lists, one per top prompt.
            texts = [re.sub(r'[^\w\s]', '', text.lower()).split() for text, _, _, _, _ in top10]

            # Map every word n-gram (n >= 2) to the set of prompts containing it.
            phrase_hits = {}
            for idx, words in enumerate(texts):
                for length in range(2, len(words) + 1):
                    for start in range(len(words) - length + 1):
                        phrase = " ".join(words[start:start + length])
                        phrase_hits.setdefault(phrase, set()).add(idx)

            # Keep phrases shared by at least two prompts; score = prompts * words.
            shared = [(p, len(ids), len(p.split()), len(ids) * len(p.split())) for p, ids in phrase_hits.items() if len(ids) >= 2]
            shared.sort(key=lambda x: x[3], reverse=True)

            # Drop a phrase when an already accepted phrase contains it as a
            # whole-word sequence and scores at least as high. Padding with
            # spaces prevents matches that start or end in the middle of a word.
            filtered = []
            for phrase, count, length, score in shared:
                if not any(f" {phrase} " in f" {longer} " and score <= lscore for longer, _, _, lscore in filtered):
                    filtered.append((phrase, count, length, score))

            if filtered:
                max_score = filtered[0][3]
                rows = [(p, round(s / max_score * 100)) for p, _, _, s in filtered[:20]]
                df = pd.DataFrame(rows, columns=["Phrase", "Score"])
                st.dataframe(df, width="content", hide_index=True, column_config={
                    "Score": st.column_config.ProgressColumn(format="%d%%", min_value=0, max_value=100),
                })