reverse-prompter / src /streamlit_app.py
dejanseo's picture
Update src/streamlit_app.py
a169838 verified
Raw
History Blame Contribute Delete
10.2 kB
#!/usr/bin/env python3
"""Streamlit app to reconstruct prompts from AI assistant responses."""
import html as html_lib
import os
import re
import requests
import torch
import pandas as pd
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub identifier handed to ``from_pretrained`` for both the model and tokenizer.
MODEL_ID = "dejanseo/reverse-prompter"

# Delimiter appended to the pasted response before it is tokenized; the model
# continues generating after it.
SEPARATOR = "\n###\n"

# Decoding grid: every (top_k, penalty_alpha) pair becomes one set of keyword
# arguments that is forwarded to ``model.generate``. The outer loop is over
# ``top_k`` so that configs sharing a ``top_k`` are adjacent.
_TOP_K_GRID = (2, 4, 6, 15)
_PENALTY_ALPHA_GRID = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
CONTRASTIVE_CONFIGS = [
    dict(penalty_alpha=alpha, top_k=top_k)
    for top_k in _TOP_K_GRID
    for alpha in _PENALTY_ALPHA_GRID
]
@st.cache_resource
def load_model():
    """Load the reverse-prompter model and its tokenizer.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    are reused instead of being reloaded on each script run.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded with
        ``torch_dtype="bfloat16"``, moved to CUDA when available (CPU
        otherwise) and put in eval mode.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    network = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="bfloat16")
    network = network.to(target_device)
    network = network.eval()

    vocab = AutoTokenizer.from_pretrained(MODEL_ID)
    return network, vocab
def analyze_output(model, tokenizer, input_ids, prefix_len):
    """Score the generated tail of a token sequence under the model.

    The sequence is run through the model once. The logit at position ``i``
    is used to score the token at position ``i + 1``, so the logits are
    sliced one step earlier than the targets.

    Args:
        model: Callable such that ``model(batch)`` returns an object with a
            ``.logits`` tensor of shape ``(batch, seq, vocab)``.
        tokenizer: Object exposing ``decode(list_of_ids) -> str``.
        input_ids: 1-D tensor of token ids: the prefix followed by the
            generated tokens.
        prefix_len: Number of leading tokens that belong to the prefix.
            Must be at least 1 so the first generated token has a
            preceding position to be predicted from.

    Returns:
        ``(perplexity, tokens, probs)`` where ``perplexity`` is a float over
        the generated tokens only, ``tokens`` is each generated token decoded
        on its own, and ``probs`` is the model probability of each of those
        tokens. When there are no generated tokens the result is
        ``(inf, [], [])`` so the entry sorts last instead of producing NaN.

    Raises:
        ValueError: If ``prefix_len`` is smaller than 1.
    """
    if prefix_len < 1:
        raise ValueError("prefix_len must be at least 1")

    targets = input_ids[prefix_len:]
    if targets.numel() == 0:
        # Nothing to score; a mean over zero tokens would be NaN.
        return float("inf"), [], []

    with torch.no_grad():
        outputs = model(input_ids.unsqueeze(0))
        # Upcast before the softmax: the model runs in bfloat16, whose short
        # mantissa visibly distorts log-probabilities and the perplexity that
        # is later used to rank candidates.
        logits = outputs.logits[0, prefix_len - 1:-1].float()
        log_probs = torch.log_softmax(logits, dim=-1)
        # Pick the log-probability of each actual next token.
        token_log_probs = log_probs.gather(-1, targets.long().unsqueeze(-1)).squeeze(-1)
        perplexity = torch.exp(-token_log_probs.mean()).item()
        probs = torch.exp(token_log_probs).tolist()

    tokens = [tokenizer.decode([t]) for t in targets.tolist()]
    return perplexity, tokens, probs
def render_colored(tokens, probs):
    """Render tokens as HTML spans whose opacity tracks their probability.

    Each token is HTML-escaped, newlines become spaces, carriage returns are
    dropped, and any character outside printable ASCII or U+00A0..U+FFFF is
    removed. Opacity is mapped linearly from 0.15 (probability 0) to 1.0
    (probability 1).

    Args:
        tokens: Iterable of token strings.
        probs: Iterable of probabilities, paired with ``tokens`` by position.

    Returns:
        The concatenated ``<span>`` markup as one string.
    """

    def _span(piece, confidence):
        alpha = 0.15 + 0.85 * confidence
        safe = html_lib.escape(piece)
        safe = safe.replace("\n", " ")
        safe = safe.replace("\r", "")
        safe = re.sub(r'[^\x20-\x7E\u00A0-\uFFFF]', '', safe)
        return f'<span style="opacity:{alpha:.2f}">{safe}</span>'

    return "".join(_span(piece, confidence) for piece, confidence in zip(tokens, probs))
def render_result(tokens, probs):
    """Wrap the opacity-coded spans for one prompt in a bottom-spaced ``<div>``."""
    body = render_colored(tokens, probs)
    return '<div style="margin-bottom:8px">' + body + '</div>'
def fetch_url_markdown(url):
    """Fetch a page through the DataForSEO content-parsing API as markdown.

    Credentials are read from the ``DATAFORSEO_LOGIN`` and
    ``DATAFORSEO_PASSWORD`` environment variables.

    Args:
        url: Address of the page to scrape.

    Returns:
        The ``page_as_markdown`` text of the first returned item
        (at least 100 characters long).

    Raises:
        ValueError: If credentials are missing, the HTTP status is not 200,
            the first task does not report status code 20000, no items are
            returned, or the markdown is shorter than 100 characters.
    """
    username = os.getenv("DATAFORSEO_LOGIN", "")
    secret = os.getenv("DATAFORSEO_PASSWORD", "")
    if not (username and secret):
        raise ValueError("DATAFORSEO_LOGIN / DATAFORSEO_PASSWORD not set in Spaces secrets")

    payload = [{"url": url, "enable_javascript": True, "markdown_view": True}]
    response = requests.post(
        "https://api.dataforseo.com/v3/on_page/content_parsing/live",
        json=payload,
        auth=(username, secret),
        timeout=30,
    )
    if response.status_code != 200:
        raise ValueError(f"API returned status {response.status_code}")

    task_list = response.json().get("tasks", [])
    if not task_list or task_list[0].get("status_code") != 20000:
        detail = task_list[0].get('status_message') if task_list else 'no tasks'
        raise ValueError(f"API task error: {detail}")

    # A missing or null "result" is treated like a result without items.
    result_list = task_list[0].get("result") or [{}]
    page_items = result_list[0].get("items")
    if not page_items:
        raise ValueError("No items returned from API")

    markdown = page_items[0].get("page_as_markdown") or ""
    if len(markdown) < 100:
        raise ValueError("Page content too short or empty")
    return markdown
# Page chrome: wide layout, linked logo, title and a short explainer.
st.set_page_config(layout="wide")
st.logo(
    "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    size="LARGE",
    link="https://dejan.ai/",
)
st.subheader("Reverse Prompting")
st.caption(
    "Paste any AI-generated text and this tool will reverse-engineer the most likely prompts that produced it. It runs the text through a fine-tuned language model across multiple decoding configurations and ranks the reconstructed prompts by model confidence."
)

# Module-level handles used by the generation code further down.
model, tokenizer = load_model()
# Sample response used by the "Quick Test" button; lives next to this script.
TEST_FILE = os.path.join(os.path.dirname(__file__), "test.md")

# One-shot flags that carry a request across a rerun. They exist because the
# text area's value may only be assigned BEFORE the widget is created.
if "quick_test" not in st.session_state:
    st.session_state.quick_test = False
if "fetched_text" not in st.session_state:
    st.session_state.fetched_text = ""
if "run_after_fetch" not in st.session_state:
    st.session_state.run_after_fetch = False

# Text scraped in URL mode on the previous run: move it into the text area
# and schedule a reconstruction for this run.
if st.session_state.fetched_text:
    st.session_state.text_input = st.session_state.fetched_text
    st.session_state.fetched_text = ""
    st.session_state.run_after_fetch = True

# Quick Test requested on the previous run: preload the sample file.
if st.session_state.quick_test:
    try:
        with open(TEST_FILE, encoding="utf-8") as f:
            st.session_state.text_input = f.read()
    except (OSError, UnicodeDecodeError) as exc:
        # Reset the flag here; otherwise the failing read would be retried
        # (and fail again) on every subsequent rerun of the session.
        st.session_state.quick_test = False
        st.warning(f"Quick Test sample could not be loaded: {exc}")
def _clear_text_input():
    """Button callback that empties the pasted-text area.

    Streamlit forbids assigning to a widget's session-state key after the
    widget has been instantiated in the current run, so the reset is done in
    an ``on_click`` callback, which runs before the script is re-executed.
    """
    st.session_state.text_input = ""


tab_paste, tab_url = st.tabs(["Paste", "URL Mode \u1D49\u02E3\u1D56\u1D49\u02B3\u2071\u1D50\u1D49\u207F\u1D57\u1D43\u02E1"])

with tab_paste:
    text_input = st.text_area("Paste an AI assistant response", height=200, key="text_input")
    col_btn1, col_btn2, col_btn3, _ = st.columns([1, 1, 1, 3])
    run = col_btn1.button("Reconstruct Prompts", type="primary", width="stretch")
    quick = col_btn2.button("Quick Test", type="secondary", width="stretch")
    # The click itself triggers the rerun; the callback has already cleared
    # the text by the time the text area above is rebuilt.
    clear = col_btn3.button("Clear", type="secondary", width="stretch", on_click=_clear_text_input)

# Quick Test is a two-step flow: raise the flag and rerun so the sample file
# can be written into the text area before the widget is created.
if quick and not st.session_state.quick_test:
    st.session_state.quick_test = True
    st.rerun()

# A pending Quick Test or a completed URL fetch counts as pressing "Reconstruct".
run = run or st.session_state.quick_test or st.session_state.run_after_fetch

# Both flags are one-shot: consume them once they have been folded into `run`.
if st.session_state.quick_test:
    st.session_state.quick_test = False
if st.session_state.run_after_fetch:
    st.session_state.run_after_fetch = False
with tab_url:
    st.caption("This experimental feature reveals what prompts would lead to generating a page such as the one you entered. While *not the intended use of the model*, it's certainly an interesting feature for exploring semantic make-up of the page.")
    url_col, url_btn_col = st.columns([5, 1])
    url_input = url_col.text_input("URL to scrape", label_visibility="collapsed", placeholder="Enter a URL to scrape")
    fetch = url_btn_col.button("Fetch", type="primary", width="stretch")

    if fetch and url_input.strip():
        url = url_input.strip()
        # Normalise bare hosts: default to https and make sure a path is present.
        if not url.startswith(("http://", "https://")):
            url = "https://" + url
        if "/" not in url.split("//", 1)[-1]:
            url += "/"

        with st.spinner("Fetching page..."):
            # Only the network call is guarded. The rerun is kept out of the
            # try body so its control-flow signal can never be swallowed by
            # the broad handler.
            try:
                fetched_markdown = fetch_url_markdown(url)
            except Exception:
                st.warning("This page couldn't be scraped. The site may be blocking automated access or the page has insufficient content. Try pasting the content manually in the text area.")
            else:
                # Hand the text to the next run, which copies it into the
                # text area before that widget is created.
                st.session_state.fetched_text = fetched_markdown
                st.rerun()
if run and text_input.strip():
    # The pasted response followed by SEPARATOR is the prefix the model
    # continues from; everything generated after it is a candidate prompt.
    prompt = text_input.strip() + SEPARATOR
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prefix_len = inputs["input_ids"].shape[-1]
    eos_id = tokenizer.eos_token_id

    # Reserve the slot for the summary tables above the live output; it is
    # only filled once all configurations have been tried.
    tables_container = st.container()
    st.subheader("Word Frequency Heatmap")
    results = []
    seen = set()
    failures = []  # (config, exception) for every configuration that raised
    progress = st.progress(0)
    output_container = st.empty()

    for i, config in enumerate(CONTRASTIVE_CONFIGS):
        progress.progress((i + 1) / len(CONTRASTIVE_CONFIGS))
        try:
            outputs = model.generate(**inputs, max_new_tokens=256, trust_remote_code=True, num_return_sequences=1, **config)
            gen_ids = outputs[0][prefix_len:]
            # Cut at the first EOS so trailing tokens do not affect scoring.
            if eos_id is not None and (gen_ids == eos_id).any():
                gen_ids = gen_ids[:(gen_ids == eos_id).nonzero(as_tuple=True)[0][0]]
            text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
            if text and text not in seen:
                seen.add(text)
                trimmed = torch.cat([outputs[0][:prefix_len], gen_ids])
                ppl, tokens, probs = analyze_output(model, tokenizer, trimmed, prefix_len)
                results.append((text, ppl, tokens, probs, config))
                # Refresh the live view with the current ten best candidates.
                top = sorted(results, key=lambda x: x[1])[:10]
                live_html = "".join(render_result(t, p) for _, _, t, p, _ in top)
                output_container.markdown(live_html, unsafe_allow_html=True)
        except Exception as exc:
            # Best effort: one failing configuration must not abort the run,
            # but it is recorded so the user is told about it afterwards.
            failures.append((config, exc))
    progress.empty()

    if failures:
        last_config, last_error = failures[-1]
        summary = (
            f"{len(failures)} of {len(CONTRASTIVE_CONFIGS)} decoding configurations failed and were skipped. "
            f"Last error ({last_config}): {last_error}"
        )
        if results:
            st.warning(summary)
        else:
            st.error(summary)
    elif not results:
        st.info("No prompts could be reconstructed from this text.")

    results.sort(key=lambda x: x[1])
    if results:
        top10 = results[:10]
        col_prompts, col_phrases = tables_container.columns(2)

        with col_prompts:
            st.subheader("Reconstructed Prompts")
            st.caption("Top 10 most likely prompts ranked by lowest perplexity. Token opacity reflects the model's confidence in each word.")
            df_top = pd.DataFrame(
                [(text, round(ppl, 2), [round(p * 100) for p in probs]) for text, ppl, _, probs, _ in top10],
                columns=["Prompt", "Perplexity", "Confidence"],
            )
            st.dataframe(df_top, width="content", hide_index=True, column_config={
                "Prompt": st.column_config.TextColumn(width=None),
                "Perplexity": st.column_config.NumberColumn(width=None),
                "Confidence": st.column_config.BarChartColumn(y_min=0, y_max=100),
            })

        with col_phrases:
            st.subheader("Key Phrases")
            st.caption("The most important phrases scored by balance between length and frequency.")
            # Lower-case, strip punctuation and split each prompt into words.
            texts = [re.sub(r'[^\w\s]', '', text.lower()).split() for text, _, _, _, _ in top10]

            # Map every word n-gram (n >= 2) to the set of prompts containing it.
            phrase_hits = {}
            for idx, words in enumerate(texts):
                for length in range(2, len(words) + 1):
                    for start in range(len(words) - length + 1):
                        phrase = " ".join(words[start:start + length])
                        phrase_hits.setdefault(phrase, set()).add(idx)

            # Keep phrases found in at least two prompts;
            # score = number of prompts * number of words.
            shared = [(p, len(ids), len(p.split()), len(ids) * len(p.split())) for p, ids in phrase_hits.items() if len(ids) >= 2]
            shared.sort(key=lambda x: x[3], reverse=True)

            # Drop a phrase when it is contained in an already-kept phrase
            # that scores at least as high.
            filtered = []
            for phrase, count, length, score in shared:
                if not any(phrase in longer and score <= lscore for longer, _, _, lscore in filtered):
                    filtered.append((phrase, count, length, score))

            if filtered:
                # Scores are shown relative to the best phrase (= 100%).
                max_score = filtered[0][3]
                rows = [(p, round(s / max_score * 100)) for p, _, _, s in filtered[:20]]
                df = pd.DataFrame(rows, columns=["Phrase", "Score"])
                st.dataframe(df, width="content", hide_index=True, column_config={
                    "Score": st.column_config.ProgressColumn(format="%d%%", min_value=0, max_value=100),
                })