File size: 19,491 Bytes
d0b2e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e6b325
 
d0b2e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
# app.py
# reCAPTCHA‑style 3×3 Demo (Streamlit) — Proof of Concept
# --------------------------------------------------------
# - Build challenges from a TSV (columns: image [base64], answer)
# - Same compact, natural‑size 3×3 layout for EVERY challenge
# - Manual mode: clickable tiles with baked‑in border + ✓ (works inside iframe)
# - Model modes: same layout (static), then run adapters

from __future__ import annotations
import io
import re
import base64
import random
from dataclasses import dataclass
from typing import List, Dict, Callable, Optional, Tuple, Union

import streamlit as st
from PIL import Image, ImageDraw
import pandas as pd
from io import BytesIO

import base64    

from config import *
from utils import *
from adapter import *
# -----------------------------
# Constants & Utilities
# -----------------------------

# Pixel size every grid tile is resized to before it is rendered.
IM_HEIGHT = 256
IM_WIDTH = 256




class ManualAdapter(BaseAdapter):
    """Adapter that returns a selection made by hand instead of by a model.

    The challenge inputs passed to ``solve`` are ignored; the stored
    selection is returned (sorted) with no raw model outputs.
    """

    name = "Manual"

    def __init__(self, manual_selection: List[int]):
        # Tile ids picked in the UI (presumably 1-based grid positions — the
        # app builds them as clicked index + 1).
        self.manual_selection = manual_selection

    def solve(self, images, category, prompt_type, available_categories):
        """Return the stored selection, in ascending order, as the result."""
        ordered_ids = sorted(self.manual_selection)
        return InferenceResult(raw_outputs={}, selected_ids=ordered_ids)




class LLMadapter(BaseAdapter):
    """Adapter that forwards prompts to a provider-specific LLM adapter.

    The concrete adapter class is chosen by ``get_provider_class`` and is
    constructed with the model name; every ``generate`` call passes the
    optional system prompt along.
    """

    def __init__(self, provider, model_name, system: Optional[str] = None):
        """Create the wrapped provider adapter.

        Args:
            provider: Provider id; must be listed in ``BaseAdapter.providers``.
            model_name: Model id passed to the provider adapter's constructor.
            system: Optional system prompt forwarded on every ``generate`` call.

        Raises:
            BaseAdapterError: If *provider* is not supported.
        """
        # Explicit check rather than ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would disable this validation.
        if provider not in BaseAdapter.providers:
            raise BaseAdapterError(f"Unsupported provider: {provider}")
        self.adapter = LLMadapter.get_provider_class(provider)(model_name)
        self.system = system

    def generate(self, prompt, image):
        """Send *prompt* and *image* to the wrapped adapter and return its output."""
        return self.adapter.generate(prompt=prompt, image=image, system=self.system)

    @staticmethod
    def get_provider_class(provider):
        """Map a provider name to its adapter class.

        Matching ignores case and surrounding whitespace.

        Raises:
            BaseAdapterError: If the name matches no known provider.
        """
        # staticmethod: the function takes no ``self``; without the decorator
        # it only worked when called on the class, not on an instance.
        p = provider.lower().strip()
        # Classes are referenced lazily, one comparison at a time, so only the
        # matching adapter name has to resolve.
        if p == BaseAdapter.OPENAI:
            return OpenaiAdapter
        if p == BaseAdapter.ANTHROPIC:
            return AnthropicAdapter
        if p == BaseAdapter.GEMINI:
            return GeminiAdapter
        if p == BaseAdapter.GROK:
            return GrokAdapter
        if p == BaseAdapter.MISTRAL:
            return MistralAdapter
        if p == BaseAdapter.COHERE:
            return CohereAdapter
        if p == BaseAdapter.TOGETHER:
            return TogetherAdapter
        raise BaseAdapterError(f"Unsupported provider: {p}")




# -----------------------------
# Data loading & challenge sampling
# -----------------------------




def make_challenge(df: pd.DataFrame, target: str | None, pos_fraction: float = 0.45):
    cats = sorted(df["answer_norm"].unique())
    if not cats: raise ValueError("No categories found in TSV 'answer' column")
    if target is None or target == "__RANDOM__":
        target = random.choice(cats)

    pos = df[df["answer_norm"] == target]
    neg = df[df["answer_norm"] != target]
    if len(pos) == 0:
        sampled = df.sample(min(9, len(df)))
    else:
        n_pos = max(1, min(len(pos), int(round(9 * pos_fraction))))
        n_neg = max(0, 9 - n_pos)
        pos_s = pos.sample(min(n_pos, len(pos)))
        neg_s = neg.sample(min(n_neg, len(neg))) if n_neg > 0 and len(neg) > 0 else df.iloc[0:0]
        sampled = pd.concat([pos_s, neg_s]).sample(frac=1.0)
        if len(sampled) < 9 and len(df) > len(sampled):
            extra = df.drop(sampled.index).sample(min(9 - len(sampled), len(df) - len(sampled)))
            sampled = pd.concat([sampled, extra]).sample(frac=1.0)

    sampled = sampled.head(9).copy()
    ids = sampled["index"].astype(str).tolist()
    answers = sampled["answer_norm"].tolist()
    images = [decode_base64_image(b) for b in sampled["image"].tolist()]
    return images, answers, target, ids



# -----------------------------
# UI helpers — consistent 3×3 layout
# -----------------------------
from PIL import ImageDraw

def bake_selection(img, selected: bool, color=(37, 99, 235), thickness: int = 8):
    """Return *img* with a selection border and check-mark badge drawn into it.

    Args:
        img: PIL image for one tile.
        selected: When False, *img* itself is returned unchanged.
        color: RGB colour of the border and the badge circle.
        thickness: Upper bound on the border width in pixels; the width used
            is ``max(w, h) // 32`` capped at this value, and never below 2.

    Returns:
        *img* when not selected; otherwise a new image (RGB, or RGBA if the
        input was RGBA). The input image is never modified.
    """
    if not selected:
        return img
    # ImageDraw rejects 3-tuple colours on single-band images ("L", "1", "I",
    # "F") and can run out of palette entries on "P" images, so draw on an RGB
    # copy unless the tile is already RGB/RGBA. convert() returns a new image,
    # so the caller's image is untouched in both branches.
    im = img.copy() if img.mode in ("RGB", "RGBA") else img.convert("RGB")
    draw = ImageDraw.Draw(im)
    w, h = im.size

    # Border: t nested 1-px rectangles; the width adapts so small tiles are
    # not swamped by the frame.
    t = max(2, min(thickness, max(w, h) // 32))
    for k in range(t):
        draw.rectangle([k, k, w - 1 - k, h - 1 - k], outline=color, width=1)

    # Check-mark badge in the top-right corner.
    r = max(12, min(22, w // 12))
    x, y = w - r - 8, 8
    draw.ellipse([x, y, x + r, y + r], fill=color)
    white = (255, 255, 255)
    tick_width = max(2, r // 6)
    draw.line([x + r * 0.25, y + r * 0.55, x + r * 0.45, y + r * 0.75], fill=white, width=tick_width)
    draw.line([x + r * 0.45, y + r * 0.75, x + r * 0.80, y + r * 0.30], fill=white, width=tick_width)
    return im

def render_grid_clickable(images, selected_ids: set):
    """Render *images* as a clickable 3-column grid and report a click.

    Each tile is resized to ``IM_WIDTH`` x ``IM_HEIGHT`` and, when its 1-based
    position is in *selected_ids*, decorated by ``bake_selection``. The border
    is drawn into the pixels because the component renders in an iframe.

    Returns:
        The 0-based tile index reported by ``clickable_images``, or ``None``
        when it reports no click (a negative or non-int value).

    Raises:
        ImportError: If the optional ``st_clickable_images`` package is missing.
    """
    from st_clickable_images import clickable_images

    data_uris = []
    for i, im in enumerate(images, start=1):
        # PIL's resize takes (width, height).
        tile = im.resize((IM_WIDTH, IM_HEIGHT))
        vis = bake_selection(tile, i in selected_ids)
        buf = io.BytesIO()
        vis.save(buf, format="PNG")
        data_uris.append("data:image/png;base64," + base64.b64encode(buf.getvalue()).decode())

    clicked = clickable_images(
        data_uris,
        titles=[str(i) for i in range(1, len(data_uris) + 1)],
        div_style={
            "display": "grid",
            "gridTemplateColumns": "repeat(3, max-content)",
            "gap": "6px",
            "justifyContent": "start",
            "width": "fit-content",
        },
        img_style={
            "width": "auto",
            "height": "auto",
            "maxWidth": "100%",
            "borderRadius": "8px",
            "boxSizing": "border-box",
            "cursor": "pointer",
        },
        # click_nonce is bumped after each handled click, so the key changes
        # and Streamlit mounts a fresh component instance.
        key=f"tile_clicks_{st.session_state.click_nonce}",
    )
    return clicked if isinstance(clicked, int) and clicked >= 0 else None

def render_grid_static_columns(images: List[Image.Image], selected_ids: set):
    """Render *images* as a static grid of three tiles per row using ``st.columns``.

    Alternative to the HTML-based ``render_grid_static`` defined in this
    module; it has its own name so that definition no longer shadows it.
    Tiles are resized to ``IM_WIDTH`` x ``IM_HEIGHT`` *before* the selection
    border is drawn, matching the other grid renderers.

    Args:
        images: Tile images in grid order.
        selected_ids: 1-based positions to draw as selected.
    """
    numbered = list(enumerate(images, start=1))
    for row in chunk(numbered, 3):
        cols = st.columns(3, gap="small")  # a new row of three columns per chunk
        for col, (idx, im) in zip(cols, row):
            with col:
                tile = im.resize((IM_WIDTH, IM_HEIGHT))
                st.image(bake_selection(tile, idx in selected_ids), caption=str(idx))

def render_grid_static(images, selected_ids: set):
    """Render *images* as a non-interactive 3-column HTML grid.

    Each tile is resized to ``IM_WIDTH`` x ``IM_HEIGHT``, decorated by
    ``bake_selection`` when its 1-based position is in *selected_ids*, and
    embedded as a PNG data URI captioned with that position.

    The CSS is scoped to the ``captcha-grid`` class: ``st.markdown`` injects
    the HTML into the page itself, so bare ``figure``/``img`` selectors would
    restyle every image and figure in the app.
    """
    figures = []
    for i, im in enumerate(images, 1):
        vis = bake_selection(im.resize((IM_WIDTH, IM_HEIGHT)), i in selected_ids)  # (width, height)
        buf = io.BytesIO()
        vis.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        figures.append(
            f'<figure><img src="data:image/png;base64,{b64}">'
            f"<figcaption>{i}</figcaption></figure>"
        )

    # No indentation and no blank lines, so Markdown treats the whole string
    # as one raw HTML block.
    html = "\n".join([
        '<div class="captcha-grid" style="display:grid; '
        "grid-template-columns: repeat(3, max-content); "
        'gap:6px; justify-content:start; width:fit-content;">' + "".join(figures),
        "</div>",
        "<style>",
        ".captcha-grid figure { margin:0; }",
        ".captcha-grid figcaption { text-align:center; font-size:0.8rem; margin-top:0.2rem; }",
        ".captcha-grid img { border-radius:8px; box-sizing:border-box; }",
        "</style>",
    ])
    st.markdown(html, unsafe_allow_html=True)

# -----------------------------
# Streamlit App
# -----------------------------

# Compact 3x3 layout: tighter column gaps and natural-size st.image tiles.
_COMPACT_LAYOUT_CSS = """
    <style>
    [data-testid="stHorizontalBlock"] { gap: 0.4rem !important; }
    div[data-testid="stImage"] img { width: auto !important; max-width: none !important; height: auto; }
    div[data-testid="stImage"] figure { width: fit-content !important; margin: 0.1rem auto; }
    div[data-testid="stImage"] figcaption { margin-top: 0.2rem !important; }
    </style>
    """

# set_page_config has to be the first Streamlit call of the script run.
st.set_page_config(page_title="reCAPTCHA‑style 3×3 — PoC", layout="wide")
st.markdown(_COMPACT_LAYOUT_CSS, unsafe_allow_html=True)

st.title("reCAPTCHA‑style 3×3 Demo — Proof of Concept")
st.caption("Generate a challenge from TSV, then solve manually or with a model adapter.")

# Session state: seed each key the app reads on the first run only; values
# stored by earlier reruns are left as they are.
_SESSION_DEFAULTS = {
    "dataset": None,
    "dataset_modified": None,
    "categories": [],
    "challenge_images_original": [],
    "challenge_images_modified": [],
    "challenge_answers": [],
    "challenge_target": None,
    "challenge_ids": [],
    "tile_selected": set(),            # tiles picked by hand (1-based)
    "click_nonce": 0,                  # part of the clickable-grid widget key
    "last_clicked_processed": -1,
    "auto_selected_ids": set(),        # tiles picked by the model solver (1-based)
    "image_view": "Original",          # "Original" | "Modified"
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default


# Solver defaults: start in Manual mode with no model selected.
if "provider" not in st.session_state:
    st.session_state.provider = "Manual"
if "model" not in st.session_state:
    st.session_state.model = None

# Load both datasets on every run. The base TSV provides labels and the
# category list; the modified TSV provides alternative images that are looked
# up by the same "index" ids when a challenge is generated.
df_base = load_private_tsv("imageaction__recaptcha_dataset.tsv")
df_mod = load_private_tsv("imageaction__captcha@SPEC-1de6b70ae2f0.tsv")
st.session_state.dataset = df_base
st.session_state.dataset_modified = df_mod
st.session_state.categories = sorted(df_base["answer_norm"].unique())
# Sidebar: challenge settings (target, prompt type) and solver selection.
with st.sidebar:
    st.subheader("Challenge Settings")

    _cats = st.session_state.categories
    target_mode = st.selectbox("Target category mode", ["Pick specific", "Random each time"], index=0)
    if target_mode != "Pick specific":
        chosen_target = "__RANDOM__"
    else:
        target_category = st.selectbox(
            "Target category",
            _cats if _cats else ["(load TSV first)"]
        )
        # The placeholder option is not a real category, so pass None instead.
        chosen_target = target_category if _cats else None

    prompt_type_label = st.selectbox("Prompt type", list(PROMPT_TYPES), index=1)
    prompt_type = PROMPT_TYPES[prompt_type_label]

    st.markdown("---")
    st.subheader("Solver")

    # Provider comes first: "Manual" plus every configured provider.
    provider_options = ["Manual", *MODEL_PROVIDERS]
    _prev_provider = st.session_state.provider
    # Fall back to Manual when the remembered provider is no longer offered.
    provider_idx = provider_options.index(_prev_provider) if _prev_provider in provider_options else 0
    st.session_state.provider = st.selectbox("Provider", provider_options, index=provider_idx)

    # The model picker is only active for real providers.
    if st.session_state.provider == "Manual":
        st.session_state.model = None
        st.selectbox("Model", ["(not required in Manual mode)"], index=0, disabled=True)
        st.caption("Manual mode: click tiles to select. No model needed.")
    else:
        models_for_provider = MODEL_PROVIDERS.get(st.session_state.provider, [])
        if models_for_provider:
            # Keep the remembered model if this provider offers it, else use the first.
            if st.session_state.model not in models_for_provider:
                st.session_state.model = models_for_provider[0]
            st.session_state.model = st.selectbox(
                "Model",
                models_for_provider,
                index=models_for_provider.index(st.session_state.model),
            )
        else:
            st.session_state.model = None
            st.selectbox("Model", ["(no models available for this provider)"], index=0, disabled=True)


# Generate new challenge
colA, colB = st.columns([1, 2])
with colA:
    gen = st.button("🎲 Generate new challenge", use_container_width=True, disabled=(st.session_state.dataset is None))

if gen:
    with st.spinner("Sampling images…"):
        images_orig, answers, tgt, ids = make_challenge(st.session_state.dataset, chosen_target)
        st.session_state.challenge_images_original = images_orig
        st.session_state.challenge_answers = answers
        st.session_state.challenge_target = tgt
        st.session_state.challenge_ids = ids
        # Clear all selection/click state that belonged to the previous challenge.
        st.session_state.tile_selected = set()
        st.session_state.last_clicked_processed = -1
        st.session_state.click_nonce = 0
        st.session_state.auto_selected_ids = set()

        # Modified images, in the SAME ORDER as the originals, matched by id.
        images_mod = []
        df_modified = st.session_state.dataset_modified
        if df_modified is not None:
            # make_challenge returns ids as strings, so key the lookup by the
            # string form as well; with an integer "index" column the raw keys
            # would never match and every tile would silently fall back.
            mod_map = dict(zip(df_modified["index"].astype(str), df_modified["image"]))
            miss = []
            for _id, original_tile in zip(ids, images_orig):
                b64 = mod_map.get(str(_id))
                if b64 is None:
                    miss.append(_id)
                    images_mod.append(original_tile)  # fall back to the original tile
                else:
                    images_mod.append(decode_base64_image(b64))
            if miss:
                st.warning(f"Modified TSV is missing {len(miss)} ids used in this challenge; those tiles fall back to original.")
        # An empty list means no modified set is available.
        st.session_state.challenge_images_modified = images_mod

    st.success("New challenge ready. Target: " + str(st.session_state.challenge_target))

# Main area
if st.session_state.challenge_images_original:
    st.subheader("3×3 Grid — Target: **" + str(st.session_state.challenge_target) + "** (Indices 1..9)")

    # Toggle between Original and Modified (only offered when a modified set exists).
    options = ["Original"]
    if st.session_state.challenge_images_modified:
        options.append("Modified")
    st.session_state.image_view = st.radio(
        "Image set", options, horizontal=True, index=0 if st.session_state.image_view not in options else options.index(st.session_state.image_view)
    )

    # The same list is displayed here and sent to the solver below.
    images_to_show = (st.session_state.challenge_images_modified
                      if st.session_state.image_view == "Modified" and st.session_state.challenge_images_modified
                      else st.session_state.challenge_images_original)

    if st.session_state.provider == "Manual":
        try:
            clicked = render_grid_clickable(images_to_show, st.session_state.tile_selected)
        except ImportError:
            # Only a missing st-clickable-images package selects the read-only
            # fallback; other errors propagate instead of being misreported as
            # a missing dependency.
            st.info("Install optional dependency: pip install st-clickable-images")
            render_grid_static(images_to_show, st.session_state.tile_selected)
        else:
            if clicked is not None:
                tile_id = clicked + 1  # component index is 0-based; tile ids are 1-based
                if tile_id in st.session_state.tile_selected:
                    st.session_state.tile_selected.remove(tile_id)
                else:
                    st.session_state.tile_selected.add(tile_id)
                st.session_state.click_nonce += 1  # new widget key on the next render
                st.rerun()
    else:
        render_grid_static(images_to_show, st.session_state.auto_selected_ids)

    st.markdown("---")

    # Build the prompt once so the preview shows exactly what the solver sends.
    st.subheader("Prompts Preview")
    cats_for_prompt = st.session_state.categories if st.session_state.categories else []
    if prompt_type == 1:
        prompt = build_prompt_1(st.session_state.challenge_target)
        output_parse_fn = parse_prompt_1
    elif prompt_type == 2:
        prompt = build_prompt_2(cats_for_prompt)
        output_parse_fn = parse_prompt_2
    else:
        raise ValueError(f"Unsupported prompt type: {prompt_type!r}")
    st.code(prompt)

    if st.button("Run Solver", use_container_width=True):
        images_for_inference = images_to_show
        n_tiles = len(st.session_state.challenge_answers)

        with st.spinner("Running solver…"):
            if st.session_state.provider == 'Manual':
                picked = st.session_state.tile_selected
                selected_ids = sorted(picked)
                # Show the true label of picked tiles and 'Other' for the rest.
                raw_preds = [ans if i in picked else 'Other'
                             for i, ans in enumerate(st.session_state.challenge_answers, start=1)]
                # A tile is predicted positive iff the user picked it. Deriving
                # this from raw_preds would score a wrongly picked tile as
                # negative (and an unpicked one as positive if the target is
                # literally named 'Other').
                preds = [i in picked for i in range(1, n_tiles + 1)]
            else:
                # Built only when the solver actually runs, not on every rerun.
                adapter = LLMadapter(st.session_state.provider, st.session_state.model)
                challenge_images_b64 = [encode_base64_image(img) for img in images_for_inference]
                preds, raw_preds = [], []
                for image_b64 in challenge_images_b64:
                    result = adapter.generate(prompt=prompt, image=image_b64)
                    outcome = output_parse_fn(result, st.session_state.challenge_target)
                    raw_preds.append(result)
                    preds.append(outcome)
                selected_ids = [i + 1 for i, outcome in enumerate(preds) if outcome]
        st.session_state.auto_selected_ids = set(selected_ids) if st.session_state.provider != "Manual" else set()
        st.success("Done.")
        st.subheader("Selected IDs")
        st.write(selected_ids)

        if st.session_state.provider != "Manual":
            st.subheader("Prediction overlay")
            render_grid_static(images_for_inference, st.session_state.auto_selected_ids)

        # Evaluation uses the dataset's ground-truth labels (ids, and hence
        # labels, are shared by both image sets). Predictions are judged by
        # truthiness, the same rule used to build selected_ids.
        challenge_gt = [ans == st.session_state.challenge_target for ans in st.session_state.challenge_answers]
        pred_flags = [bool(p) for p in preds]
        tp = sum(gt and pred for gt, pred in zip(challenge_gt, pred_flags))
        fn = sum(gt and not pred for gt, pred in zip(challenge_gt, pred_flags))
        fp = sum(pred and not gt for gt, pred in zip(challenge_gt, pred_flags))
        true_count = sum(challenge_gt)

        st.subheader(f"Recall: {tp/(tp+fn) if (tp+fn) else 0.0}  # Found {tp}/{true_count}")
        st.caption(f"Precision: {tp/(tp+fp) if (tp+fp) else 0.0}  # {fp} false positive(s)")
        if raw_preds:
            st.subheader("Raw Model Outputs")
            for gt, pred in zip(st.session_state.challenge_answers, raw_preds):
                st.markdown(f"**Category: {gt} — Expected: {gt == st.session_state.challenge_target}**")
                st.code(f"Prediction: {pred}", language="text")

    with st.expander("Debug: ground‑truth categories per tile", expanded=False):
        grid_truth = [str(i) + ": " + lbl for i, lbl in enumerate(st.session_state.challenge_answers, start=1)]
        st.write(", ".join(grid_truth))
else:
    # Datasets are loaded automatically at startup; there is no upload step.
    st.info("Click 'Generate new challenge' to begin.")


# -----------------------------
# Integrations Guide (trimmed)
# -----------------------------
# NOTE(review): this text mentions mock call functions and a CLIP predict_fn,
# neither of which appears in this file — confirm it is still accurate.
_INTEGRATIONS_GUIDE_MD = """
        Replace the mock call functions with real SDK calls (OpenAI/Anthropic/HF).
        For CLIP zero‑shot, wire a predict_fn that returns (label, score) per image.
        """

with st.expander("Integrations Guide: Wiring real models", expanded=False):
    st.markdown(_INTEGRATIONS_GUIDE_MD)