Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # app.py | |
| # reCAPTCHA‑style 3×3 Demo (Streamlit) — Proof of Concept | |
| # -------------------------------------------------------- | |
# - Build challenges from TSVs (columns: index, image [base64], answer); a second "modified" TSV supplies alternate images matched by index
# - Same compact 3×3 layout (tiles resized to 256×256) for EVERY challenge
| # - Manual mode: clickable tiles with baked‑in border + ✓ (works inside iframe) | |
| # - Model modes: same layout (static), then run adapters | |
| from __future__ import annotations | |
| import io | |
| import re | |
| import base64 | |
| import random | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Callable, Optional, Tuple, Union | |
| import streamlit as st | |
| from PIL import Image, ImageDraw | |
| import pandas as pd | |
| from io import BytesIO | |
| import base64 | |
| from config import * | |
| from utils import * | |
| from adapter import * | |
# -----------------------------
# Constants & Utilities
# -----------------------------
# Side lengths (pixels) every tile is resized to before it is drawn in a grid.
IM_HEIGHT = 256
IM_WIDTH = 256
class ManualAdapter(BaseAdapter):
    """Adapter whose "answer" is simply the tiles the user picked by hand.

    No model is involved: the 1-based tile indices collected from the UI are
    handed back, sorted, as the selection.
    """

    name = "Manual"

    def __init__(self, manual_selection: List[int]):
        # Stored as given; solve() hands back a sorted copy.
        self.manual_selection = manual_selection

    def solve(self, images, category, prompt_type, available_categories):
        """Return the stored selection in ascending order with empty raw outputs.

        ``images``, ``category``, ``prompt_type`` and ``available_categories``
        are accepted (presumably to mirror the BaseAdapter signature) but unused.
        """
        picked = list(self.manual_selection)
        picked.sort()
        return InferenceResult(selected_ids=picked, raw_outputs={})
class LLMadapter(BaseAdapter):
    """Forward prompt/image pairs to a provider-specific adapter.

    Args:
        provider: provider identifier; must be a member of ``BaseAdapter.providers``.
        model_name: model identifier passed to the provider adapter's constructor.
        system: optional system prompt forwarded with every ``generate`` call.

    Raises:
        BaseAdapterError: if ``provider`` is not a supported provider.
    """

    def __init__(self, provider, model_name, system: Optional[str] = None):
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would let an unknown provider through.
        if provider not in BaseAdapter.providers:
            raise BaseAdapterError(f"Unsupported provider: {provider}")
        self.adapter = LLMadapter.get_provider_class(provider)(model_name)
        self.system = system

    def generate(self, prompt, image):
        """Send one prompt/image pair to the wrapped adapter and return its output."""
        return self.adapter.generate(prompt=prompt, image=image, system=self.system)

    @staticmethod
    def get_provider_class(provider):
        """Return the adapter class for ``provider`` (case/whitespace-insensitive).

        Declared ``@staticmethod`` so it works when called on the class or on
        an instance (it takes no ``self``).

        Raises:
            BaseAdapterError: if the normalized name matches no known provider.
        """
        p = provider.lower().strip()
        if p == BaseAdapter.OPENAI:
            return OpenaiAdapter
        if p == BaseAdapter.ANTHROPIC:
            return AnthropicAdapter
        if p == BaseAdapter.GEMINI:
            return GeminiAdapter
        if p == BaseAdapter.GROK:
            return GrokAdapter
        if p == BaseAdapter.MISTRAL:
            return MistralAdapter
        if p == BaseAdapter.COHERE:
            return CohereAdapter
        if p == BaseAdapter.TOGETHER:
            return TogetherAdapter
        raise BaseAdapterError(f"Unsupported provider: {p}")
| # ----------------------------- | |
| # Data loading & challenge sampling | |
| # ----------------------------- | |
| def make_challenge(df: pd.DataFrame, target: str | None, pos_fraction: float = 0.45): | |
| cats = sorted(df["answer_norm"].unique()) | |
| if not cats: raise ValueError("No categories found in TSV 'answer' column") | |
| if target is None or target == "__RANDOM__": | |
| target = random.choice(cats) | |
| pos = df[df["answer_norm"] == target] | |
| neg = df[df["answer_norm"] != target] | |
| if len(pos) == 0: | |
| sampled = df.sample(min(9, len(df))) | |
| else: | |
| n_pos = max(1, min(len(pos), int(round(9 * pos_fraction)))) | |
| n_neg = max(0, 9 - n_pos) | |
| pos_s = pos.sample(min(n_pos, len(pos))) | |
| neg_s = neg.sample(min(n_neg, len(neg))) if n_neg > 0 and len(neg) > 0 else df.iloc[0:0] | |
| sampled = pd.concat([pos_s, neg_s]).sample(frac=1.0) | |
| if len(sampled) < 9 and len(df) > len(sampled): | |
| extra = df.drop(sampled.index).sample(min(9 - len(sampled), len(df) - len(sampled))) | |
| sampled = pd.concat([sampled, extra]).sample(frac=1.0) | |
| sampled = sampled.head(9).copy() | |
| ids = sampled["index"].astype(str).tolist() | |
| answers = sampled["answer_norm"].tolist() | |
| images = [decode_base64_image(b) for b in sampled["image"].tolist()] | |
| return images, answers, target, ids | |
| # ----------------------------- | |
| # UI helpers — consistent 3×3 layout | |
| # ----------------------------- | |
| from PIL import ImageDraw | |
def bake_selection(img, selected: bool, color=(37, 99, 235), thickness: int = 8):
    """Draw a "selected" border and ✓ badge into a tile image.

    Args:
        img: PIL image for one tile.
        selected: when False, ``img`` itself is returned untouched.
        color: RGB colour for the border and the badge background.
        thickness: maximum border width in pixels; reduced for small tiles.

    Returns:
        ``img`` unchanged when not selected; otherwise a new image with the
        overlay. Images not already in RGB/RGBA are converted first, because
        Pillow rejects an RGB 3-tuple colour on single-band modes such as "L"
        or "1" (and would mis-colour CMYK/YCbCr).
    """
    if not selected:
        return img
    if img.mode in ("RGB", "RGBA"):
        im = img.copy()
    else:
        # Keep transparency when the source carries it; otherwise plain RGB.
        has_alpha = "A" in img.mode or "transparency" in img.info
        im = img.convert("RGBA" if has_alpha else "RGB")
    d = ImageDraw.Draw(im)
    w, h = im.size
    # Adaptive thickness: ~1/32 of the longer side, clamped to [2, thickness].
    t = max(2, min(thickness, max(w, h) // 32))
    for k in range(t):
        d.rectangle([k, k, w - 1 - k, h - 1 - k], outline=color, width=1)
    # ✓ badge in the top-right corner.
    r = max(12, min(22, w // 12))
    x, y = w - r - 8, 8
    d.ellipse([x, y, x + r, y + r], fill=color)
    tick_width = max(2, r // 6)
    white = (255, 255, 255)
    d.line([x + r * 0.25, y + r * 0.55, x + r * 0.45, y + r * 0.75], fill=white, width=tick_width)
    d.line([x + r * 0.45, y + r * 0.75, x + r * 0.80, y + r * 0.30], fill=white, width=tick_width)
    return im
def render_grid_clickable(images, selected_ids: set):
    """Render the tiles as a clickable 3-column grid and report a click.

    Each tile is resized to IM_WIDTH×IM_HEIGHT, has its selection state baked
    into the pixels via bake_selection, and is sent to the component as a PNG
    data URI. Needs the optional ``st_clickable_images`` package; the
    ImportError from a missing package propagates to the caller.

    Args:
        images: tile images in grid order.
        selected_ids: 1-based indices of the tiles currently selected.

    Returns:
        The 0-based index of the clicked tile, or None when the component
        reports no click (a non-int or negative value).
    """
    from st_clickable_images import clickable_images

    data_uris = []
    for i, im in enumerate(images, start=1):
        # PIL's resize takes (width, height) — same order as render_grid_static.
        tile = im.resize((IM_WIDTH, IM_HEIGHT))
        vis = bake_selection(tile, i in selected_ids)  # border baked into pixels
        buf = io.BytesIO()
        vis.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        data_uris.append("data:image/png;base64," + b64)

    clicked = clickable_images(
        data_uris,
        titles=[str(i) for i in range(1, len(data_uris) + 1)],
        div_style={
            "display": "grid",
            "gridTemplateColumns": "repeat(3, max-content)",
            "gap": "6px",
            "justifyContent": "start",
            "width": "fit-content",
        },
        img_style={
            "width": "auto",
            "height": "auto",
            "maxWidth": "100%",
            "borderRadius": "8px",
            "boxSizing": "border-box",
            "cursor": "pointer",
        },
        # The caller bumps click_nonce after handling each click; a new key
        # presumably gives a fresh component that won't re-report that click.
        key=f"tile_clicks_{st.session_state.click_nonce}",
    )
    return clicked if isinstance(clicked, int) and clicked >= 0 else None
def _render_grid_static_columns(images: List[Image.Image], selected_ids: set):
    """Alternative static grid built from ``st.columns`` (3 tiles per row).

    NOTE(review): this was also named ``render_grid_static``, but the HTML-based
    ``render_grid_static`` that follows it in this module rebinds that name, so
    this version could never be called (ruff F811). It is kept under a private
    name as an alternative layout; no call sites for it exist in the code
    reviewed.

    Args:
        images: tile images in grid order.
        selected_ids: 1-based indices of tiles to draw as selected.
    """
    for row in chunk(list(enumerate(images, start=1)), 3):
        # A fresh set of columns per row keeps tiles aligned row by row.
        cols = st.columns(3, gap="small")
        for col, (idx, im) in zip(cols, row):
            with col:
                # Resize before baking so the border/badge thickness matches the
                # other grid renderers instead of being rescaled with the image.
                tile = im.resize((IM_WIDTH, IM_HEIGHT))
                st.image(bake_selection(tile, idx in selected_ids), caption=str(idx))
def render_grid_static(images, selected_ids: set):
    """Render a non-interactive 3-column grid of tiles as inline HTML.

    Each tile is resized to IM_WIDTH×IM_HEIGHT, has its selection state baked
    in via bake_selection, and is embedded as a PNG data URI captioned with its
    1-based index.

    Args:
        images: tile images in grid order.
        selected_ids: 1-based indices of tiles to draw as selected.
    """
    thumbs = []
    for i, im in enumerate(images, 1):
        tile = im.resize((IM_WIDTH, IM_HEIGHT))  # PIL size is (width, height)
        vis = bake_selection(tile, i in selected_ids)
        buf = io.BytesIO()
        vis.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        thumbs.append(
            f'<figure><img src="data:image/png;base64,{b64}">'
            f"<figcaption>{i}</figcaption></figure>"
        )
    # Every CSS rule is scoped to .captcha-static-grid: st.markdown injects this
    # <style> into the whole page, so bare img/figure/figcaption selectors would
    # restyle every image and caption in the app.
    # The HTML has no leading indentation or blank lines, because Markdown
    # treats indented lines as code blocks and a blank line ends an HTML block.
    html = (
        '<div class="captcha-static-grid" style="display:grid;'
        "grid-template-columns:repeat(3, max-content);"
        'gap:6px;justify-content:start;width:fit-content;">'
        + "".join(thumbs)
        + "</div>"
        "<style>"
        ".captcha-static-grid figure{margin:0;}"
        ".captcha-static-grid figcaption{text-align:center;font-size:0.8rem;margin-top:0.2rem;}"
        ".captcha-static-grid img{border-radius:8px;box-sizing:border-box;}"
        "</style>"
    )
    st.markdown(html, unsafe_allow_html=True)
# -----------------------------
# Streamlit App
# -----------------------------
# Page config goes first: Streamlit requires set_page_config before other
# element calls (at least in many versions).
st.set_page_config(page_title="reCAPTCHA‑style 3×3 — PoC", layout="wide")

# Page-wide CSS: tighter column gaps and natural-size output from st.image.
_COMPACT_CSS = """
<style>
[data-testid="stHorizontalBlock"] { gap: 0.4rem !important; }
div[data-testid="stImage"] img { width: auto !important; max-width: none !important; height: auto; }
div[data-testid="stImage"] figure { width: fit-content !important; margin: 0.1rem auto; }
div[data-testid="stImage"] figcaption { margin-top: 0.2rem !important; }
</style>
"""
st.markdown(_COMPACT_CSS, unsafe_allow_html=True)

st.title("reCAPTCHA‑style 3×3 Demo — Proof of Concept")
for _caption in (
    "Generate a challenge from TSV, then solve manually or with a model adapter.",
    "Click run solver below to see the result for either choice 'Original' or 'Modified'.",
    "IT MAY TAKE ABOUT 10-20 SECONDS TO SOLVE THE CHALLENGE THROUGH API CALLS, VARIES BASED ON LLM CHOICE.",
):
    st.caption(_caption)
# Session state: seed every key the app reads, keeping any value that
# survived from an earlier rerun of this session.
_SESSION_DEFAULTS = {
    "dataset": None,
    "dataset_modified": None,
    "categories": [],
    "challenge_images_original": [],
    "challenge_images_modified": [],
    "challenge_answers": [],
    "challenge_target": None,
    "challenge_ids": [],
    "tile_selected": set(),
    "click_nonce": 0,
    "last_clicked_processed": -1,
    "auto_selected_ids": set(),
    "image_view": "Original",       # current radio selection
    "last_image_view": "Original",  # previous radio selection
}
for _name, _value in _SESSION_DEFAULTS.items():
    if _name not in st.session_state:
        st.session_state[_name] = _value

# Placeholder slot for the grid (not used further in the code shown).
grid_ph = st.empty()
# Datasets: original challenge images plus their modified counterparts.
# load_private_tsv runs on every script run (any caching lives inside it).
df_base = load_private_tsv("imageaction__recaptcha_dataset.tsv")
df_mod = load_private_tsv("imageaction__captcha@SPEC-1de6b70ae2f0.tsv")
st.session_state.dataset, st.session_state.dataset_modified = df_base, df_mod
st.session_state.categories = sorted(df_base["answer_norm"].unique())

# Solver / target defaults, applied only on the first run of a session.
for _name, _value in (
    ("provider", BaseAdapter.OPENAI),    # default provider = OpenAI
    ("model", "gpt-5-2025-08-07"),       # default OpenAI model
    ("target_category", "bus"),
):
    if _name not in st.session_state:
        st.session_state[_name] = _value
# Sidebar: challenge settings (target + prompt type) and solver choice.
with st.sidebar:
    st.subheader("Challenge Settings")
    target_mode = st.selectbox("Target category mode", ["Pick specific", "Random each time"], index=0)
    if target_mode == "Pick specific":
        cats = st.session_state.categories if st.session_state.categories else ["(load TSV first)"]
        # Pre-select the session's default target (seeded as "bus") when present,
        # instead of hard-coding the same label a second time.
        default_cat = st.session_state.get("target_category", "bus")  # normalized label
        default_idx = cats.index(default_cat) if default_cat in cats else 0
        target_category = st.selectbox(
            "Target category",
            cats,
            index=default_idx,
        )
        # The "(load TSV first)" placeholder is not a real category.
        chosen_target = target_category if st.session_state.categories else None
    else:
        chosen_target = "__RANDOM__"

    prompt_labels = list(PROMPT_TYPES.keys())
    # Prefer the second prompt type, but don't crash if only one is defined.
    prompt_type_label = st.selectbox(
        "Prompt type", prompt_labels, index=min(1, max(len(prompt_labels) - 1, 0))
    )
    prompt_type = PROMPT_TYPES[prompt_type_label]

    st.markdown("---")
    st.subheader("Solver")

    # 1) Provider first (Manual + all configured providers).
    provider_options = ["Manual"] + list(MODEL_PROVIDERS.keys())
    # A stale/unknown provider falls back to OpenAI, or to Manual when OpenAI
    # is not configured — otherwise .index() below would raise ValueError.
    if st.session_state.provider not in provider_options:
        st.session_state.provider = (
            BaseAdapter.OPENAI if BaseAdapter.OPENAI in provider_options else "Manual"
        )
    provider_idx = provider_options.index(st.session_state.provider)
    st.session_state.provider = st.selectbox(
        "Provider",
        provider_options,
        index=provider_idx,
    )

    # 2) Model: only meaningful when provider != Manual.
    if st.session_state.provider == "Manual":
        st.session_state.model = None
        st.selectbox("Model", ["(not required in Manual mode)"], index=0, disabled=True)
        st.caption("Manual mode: click tiles to select. No model needed.")
    else:
        models_for_provider = MODEL_PROVIDERS.get(st.session_state.provider, [])
        if st.session_state.provider == BaseAdapter.OPENAI and "gpt-5-2025-08-07" in models_for_provider:
            # OpenAI: prefer the default gpt-5 model when the current one is invalid.
            if st.session_state.model not in models_for_provider:
                st.session_state.model = "gpt-5-2025-08-07"
        elif st.session_state.model not in models_for_provider and models_for_provider:
            # Generic fallback for other providers: first listed model.
            st.session_state.model = models_for_provider[0]

        if not models_for_provider:
            st.session_state.model = None
            st.selectbox("Model", ["(no models for this provider)"], index=0, disabled=True)
        else:
            model_idx = models_for_provider.index(st.session_state.model)
            st.session_state.model = st.selectbox(
                "Model",
                models_for_provider,
                index=model_idx,
            )
# Generate new challenge
colA, colB = st.columns([1, 2])
with colA:
    gen = st.button("🎲 Generate new challenge", use_container_width=True, disabled=(st.session_state.dataset is None))
if gen:
    with st.spinner("Sampling images…"):
        images_orig, answers, tgt, ids = make_challenge(st.session_state.dataset, chosen_target)
        st.session_state.challenge_images_original = images_orig
        st.session_state.challenge_answers = answers
        st.session_state.challenge_target = tgt
        st.session_state.challenge_ids = ids
        st.session_state.tile_selected = set()
        st.session_state.last_clicked_processed = -1
        # Bump rather than reset: a never-used key guarantees the clickable grid
        # starts fresh and cannot replay a click from the previous challenge.
        st.session_state.click_nonce += 1
        st.session_state.auto_selected_ids = set()

        # Build modified images in the SAME ORDER by id (if a modified dataset exists).
        st.session_state.challenge_images_modified = []
        df_modified = st.session_state.dataset_modified
        if df_modified is not None:
            # make_challenge returns ids as str, so key the lookup by str as well;
            # a numeric "index" column would otherwise never match and every tile
            # would silently fall back to the original image.
            mod_map = dict(zip(df_modified["index"].astype(str), df_modified["image"]))
            modified, missing = [], []
            for pos, _id in enumerate(ids):
                b64 = mod_map.get(str(_id))
                if b64 is None or pd.isna(b64):
                    missing.append(_id)
                    modified.append(images_orig[pos])  # fall back to the original tile
                else:
                    modified.append(decode_base64_image(b64))
            st.session_state.challenge_images_modified = modified
            if missing:
                st.warning(f"Modified TSV is missing {len(missing)} ids used in this challenge; those tiles fall back to original.")
    st.success("New challenge ready. Target: " + str(st.session_state.challenge_target))
# Main area: grid, solver controls, and evaluation for the current challenge.
if st.session_state.challenge_images_original:
    st.subheader("3×3 Grid — Target: **" + str(st.session_state.challenge_target) + "** (Indices 1..9)")

    # Toggle between Original and Modified (Modified only when available).
    options = ["Original"]
    if st.session_state.challenge_images_modified:
        options.append("Modified")
    st.session_state.image_view = st.radio(
        "Image set",
        options,
        horizontal=True,
        index=options.index(st.session_state.image_view) if st.session_state.image_view in options else 0,
    )

    # Switching Original ↔ Modified is treated as a new puzzle view.
    prev_view = st.session_state.get("last_image_view", "Original")
    if st.session_state.image_view != prev_view:
        st.session_state.last_image_view = st.session_state.image_view
        st.session_state.tile_selected = set()
        st.session_state.auto_selected_ids = set()
        # Bump (not reset) so the clickable grid gets a never-used key.
        st.session_state.click_nonce += 1

    images_to_show = (st.session_state.challenge_images_modified
                      if st.session_state.image_view == "Modified" and st.session_state.challenge_images_modified
                      else st.session_state.challenge_images_original)

    if st.session_state.provider == "Manual":
        try:
            clicked = render_grid_clickable(images_to_show, st.session_state.tile_selected)
        except ImportError:
            # Only a missing optional package should trigger the static fallback;
            # any other error must surface instead of hiding behind this hint.
            st.info("Install optional dependency: pip install st-clickable-images")
            render_grid_static(images_to_show, st.session_state.tile_selected)
        else:
            if clicked is not None:
                tile_id = clicked + 1  # component index is 0-based; tiles are 1-based
                if tile_id in st.session_state.tile_selected:
                    st.session_state.tile_selected.remove(tile_id)
                else:
                    st.session_state.tile_selected.add(tile_id)
                st.session_state.click_nonce += 1
                st.rerun()
    else:
        render_grid_static(images_to_show, st.session_state.auto_selected_ids)
    st.markdown("---")

    # Prompts Preview
    st.subheader("Prompts Preview")
    cats_for_prompt = st.session_state.categories if st.session_state.categories else []
    if prompt_type == 1:
        st.code(build_prompt_1(st.session_state.challenge_target))
    elif prompt_type == 2:
        st.code(build_prompt_2(cats_for_prompt))
    else:
        raise ValueError(f"Unsupported prompt type: {prompt_type!r}")

    if st.button("Run Solver", use_container_width=True):
        # Infer on exactly the image set currently displayed.
        images_for_inference = images_to_show
        target_cat = st.session_state.challenge_target
        gt_labels = st.session_state.challenge_answers
        with st.spinner("Running solver…"):
            if prompt_type == 1:
                prompt = build_prompt_1(target_cat)
                output_parse_fn = parse_prompt_1
            elif prompt_type == 2:
                prompt = build_prompt_2(cats_for_prompt)
                output_parse_fn = parse_prompt_2
            else:
                raise ValueError(f"Unsupported prompt type: {prompt_type!r}")

            if st.session_state.provider == "Manual":
                picked = st.session_state.tile_selected
                # A manual pick is a positive prediction for the target, whatever
                # the tile's true label is, so wrong picks count as false positives.
                preds = [(i + 1) in picked for i in range(len(gt_labels))]
                raw_preds = [target_cat if hit else "Other" for hit in preds]
            else:
                # Built only when the solver runs, so ordinary reruns skip
                # constructing the provider adapter.
                adapter = LLMadapter(st.session_state.provider, st.session_state.model)
                preds, raw_preds = [], []
                for img in images_for_inference:
                    result = adapter.generate(prompt=prompt, image=encode_base64_image(img))
                    raw_preds.append(result)
                    preds.append(output_parse_fn(result, target_cat))

        selected_ids = [i + 1 for i, hit in enumerate(preds) if hit]
        st.session_state.auto_selected_ids = set(selected_ids) if st.session_state.provider != "Manual" else set()
        st.success("Done.")
        st.subheader("Selected IDs")
        st.write(selected_ids)

        if st.session_state.provider != "Manual":
            st.subheader("Prediction overlay")
            render_grid_static(images_for_inference, st.session_state.auto_selected_ids)

        # Evaluation uses the *original* ground-truth labels (ids don't change between views).
        challenge_gt = [label == target_cat for label in gt_labels]
        tp = sum(1 for gt, pred in zip(challenge_gt, preds) if gt and pred)
        fn = sum(1 for gt, pred in zip(challenge_gt, preds) if gt and not pred)
        true_count = sum(challenge_gt)
        st.subheader(f"Recall: {tp/(tp+fn) if (tp+fn) else 0.0} # Found {tp}/{true_count}")

        if raw_preds:
            st.subheader("Raw Model Outputs")
            for label, pred in zip(gt_labels, raw_preds):
                st.markdown(f"**Category: {label} — Expected: {label == target_cat}**")
                st.code(f"Prediction: {pred}", language="text")

    with st.expander("Debug: ground‑truth categories per tile", expanded=False):
        grid_truth = [f"{i}: {lbl}" for i, lbl in enumerate(st.session_state.challenge_answers, start=1)]
        st.write(", ".join(grid_truth))
else:
    # Datasets are loaded automatically; the only step left is generating a grid.
    st.info("Click '🎲 Generate new challenge' to begin.")