chris1nexus commited on
Commit
d0b2e68
·
1 Parent(s): 4cbaa71

First commit

Browse files
Files changed (6) hide show
  1. adapter.py +764 -0
  2. app.py +476 -0
  3. config.py +104 -0
  4. requirements.txt +7 -2
  5. src/streamlit_app.py +0 -40
  6. utils.py +106 -0
adapter.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Dict, List, Optional, Union
8
+
9
+ import requests
10
+
11
+
12
+ import io
13
+ import re
14
+ import random
15
+ from dataclasses import dataclass
16
+ from typing import List, Dict, Callable, Optional, Tuple, Union
17
+
18
+ import streamlit as st
19
+ from PIL import Image, ImageDraw
20
+ import pandas as pd
21
+ from io import BytesIO
22
+
23
+ class BaseAdapterError(RuntimeError):
24
+ pass
25
+
26
+ @dataclass
27
+ class BaseAdapter:
28
+ provider: str
29
+ model: str
30
+ api_key: Optional[str] = None
31
+ timeout: float = 60.0
32
+ base_url: Optional[str] = None
33
+ extra_headers: Dict[str, str] = field(default_factory=dict)
34
+
35
+ OPENAI = 'openai'
36
+ ANTHROPIC = 'anthropic'
37
+ GEMINI = 'gemini'
38
+ MISTRAL = 'mistral'
39
+ GROK = 'grok'
40
+ COHERE = 'cohere'
41
+ TOGETHER = 'together'
42
+ providers = [OPENAI,
43
+ ANTHROPIC,
44
+ GEMINI,
45
+ MISTRAL,
46
+ GROK,
47
+ COHERE,
48
+ TOGETHER ]
49
+ def __post_init__(self) -> None:
50
+ self.provider = self.provider.lower().strip()
51
+ if self.api_key is None:
52
+ env_keys = {
53
+ "openai": "OPENAI_API_KEY",
54
+ "anthropic": "ANTHROPIC_API_KEY",
55
+ "gemini": "GEMINI_API_KEY",
56
+ "mistral": "MISTRAL_API_KEY",
57
+ "grok" : "XAI_API_KEY",
58
+ "cohere": "COHERE_API_KEY",
59
+ "together": "TOGETHER_API_KEY",
60
+ }
61
+ env_var = env_keys.get(self.provider)
62
+ if env_var:
63
+ self.api_key = os.getenv(env_var)
64
+ if not self.api_key and self.provider not in ("gemini",):
65
+ raise BaseAdapterError(f"Missing api_key for {self.provider}. Set via environment.")
66
+
67
+ @staticmethod
68
+ def list_models(provider: str, api_key: Optional[str] = None, base_url: Optional[str] = None, timeout: float = 60.0) -> List[str]:
69
+ p = provider.lower().strip()
70
+ if p == "openai":
71
+ url = (base_url or "https://api.openai.com") + "/v1/models"
72
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('OPENAI_API_KEY')}"}
73
+ r = requests.get(url, headers=headers, timeout=timeout)
74
+ BaseAdapter._raise_for_status_static(r)
75
+ return [m["id"] for m in r.json().get("data", [])]
76
+ if p == "anthropic":
77
+ return [
78
+ "claude-3-5-sonnet-latest",
79
+ "claude-3-5-haiku-latest",
80
+ "claude-3-opus-latest",
81
+ ]
82
+ if p == "gemini":
83
+ key = api_key or os.getenv("GEMINI_API_KEY")
84
+ if not key:
85
+ raise BaseAdapterError("Missing GEMINI_API_KEY.")
86
+ url = (base_url or "https://generativelanguage.googleapis.com") + f"/v1beta/models?key={key}"
87
+ r = requests.get(url, timeout=timeout)
88
+ BaseAdapter._raise_for_status_static(r)
89
+ return [m["name"] for m in r.json().get("models", [])]
90
+ if p == "grok":
91
+ key = api_key or os.getenv("XAI_API_KEY")
92
+ url = (base_url or "https://api.x.ai") + "/v1/models"
93
+ headers = {"Authorization": f"Bearer {key}"}
94
+ r = requests.get(url, headers=headers, timeout=timeout)
95
+ BaseAdapter._raise_for_status_static(r)
96
+ return [m["id"] for m in r.json().get("data", [])]
97
+ if p == "mistral":
98
+ url = (base_url or "https://api.mistral.ai") + "/v1/models"
99
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('MISTRAL_API_KEY')}"}
100
+ r = requests.get(url, headers=headers, timeout=timeout)
101
+ BaseAdapter._raise_for_status_static(r)
102
+ return [m["id"] for m in r.json().get("data", [])]
103
+ if p == "cohere":
104
+ url = (base_url or "https://api.cohere.ai") + "/v1/models"
105
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('COHERE_API_KEY')}"}
106
+ r = requests.get(url, headers=headers, timeout=timeout)
107
+ BaseAdapter._raise_for_status_static(r)
108
+ return [m["name"] for m in r.json().get("models", [])]
109
+ if p == "together":
110
+ url = (base_url or "https://api.together.xyz") + "/v1/models"
111
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('TOGETHER_API_KEY')}"}
112
+ r = requests.get(url, headers=headers, timeout=timeout)
113
+ BaseAdapter._raise_for_status_static(r)
114
+ return [m["id"] for m in r.json().get("data", [])]
115
+ raise BaseAdapterError(f"Unsupported provider: {p}")
116
+
117
+ # ---------- Utilities ---------- #
118
+
119
+ @staticmethod
120
+ def _raise_for_status_static(response: requests.Response) -> None:
121
+ if 200 <= response.status_code < 300:
122
+ return
123
+ try:
124
+ detail = response.json()
125
+ msg = json.dumps(detail)
126
+ except Exception:
127
+ msg = response.text
128
+ raise BaseAdapterError(f"HTTP {response.status_code}: {msg}")
129
+
130
+ def _raise_for_status(self, response: requests.Response) -> None:
131
+ return self._raise_for_status_static(response)
132
+
133
+ @staticmethod
134
+ def _detect_mime(b: bytes) -> str:
135
+ if len(b) >= 8 and b[:8] == b"\x89PNG\r\n\x1a\n":
136
+ return "image/png"
137
+ if len(b) >= 3 and b[:3] == b"\xff\xd8\xff":
138
+ return "image/jpeg"
139
+ if len(b) >= 6 and b[:6] in (b"GIF87a", b"GIF89a"):
140
+ return "image/gif"
141
+ if len(b) >= 12 and b[8:12] == b"WEBP":
142
+ return "image/webp"
143
+ return "application/octet-stream"
144
+
145
+ @staticmethod
146
+ def _normalize_image(image: Union[str, bytes], default_mime: str = "image/png") -> tuple[str, str, str]:
147
+ """Return (data_url, base64_str, mime_type) for the given image input.
148
+ Accepts bytes, base64 string, data URL, or local file path.
149
+ """
150
+ if isinstance(image, bytes):
151
+ b64 = base64.b64encode(image).decode()
152
+ mime = BaseAdapter._detect_mime(image)
153
+ if mime == "application/octet-stream":
154
+ mime = default_mime
155
+ return f"data:{mime};base64,{b64}", b64, mime
156
+ if isinstance(image, str):
157
+ if image.startswith("data:"):
158
+ header, b64 = image.split(",", 1)
159
+ # data:image/png;base64,XXXX
160
+ mime = header.split(";")[0].split(":", 1)[1] or default_mime
161
+ return image, b64, mime
162
+ if os.path.exists(image):
163
+ with open(image, "rb") as f:
164
+ raw = f.read()
165
+ b64 = base64.b64encode(raw).decode()
166
+ mime = BaseAdapter._detect_mime(raw)
167
+ if mime == "application/octet-stream":
168
+ mime = default_mime
169
+ return f"data:{mime};base64,{b64}", b64, mime
170
+ # assume bare base64 string
171
+ b64 = image
172
+ mime = default_mime
173
+ return f"data:{mime};base64,{b64}", b64, mime
174
+ raise BaseAdapterError("Unsupported image type; pass bytes, path, base64 string, or data URL.")
175
+
176
+
177
+
178
+ class OpenaiAdapter(BaseAdapter):
179
+ provider: str
180
+ model: str
181
+ api_key: Optional[str] = None
182
+ timeout: float = 60.0
183
+ base_url: Optional[str] = None
184
+ extra_headers: Dict[str, str] = field(default_factory=dict)
185
+
186
+ def __init__(self, model_name):
187
+ super().__init__(BaseAdapter.OPENAI, model_name)
188
+
189
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
190
+ url = (self.base_url or "https://api.openai.com") + "/v1/chat/completions"
191
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
192
+ messages = []
193
+ if system:
194
+ messages.append({"role": "system", "content": system})
195
+ content = [{"type": "text", "text": prompt}]
196
+ data_url = None
197
+ if image is not None:
198
+ if not isinstance(image, list):
199
+ image = [image]
200
+
201
+ for img in image:
202
+ data_url, _b64, _mime = self._normalize_image(img, default_mime="image/png")
203
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
204
+ messages.append({"role": "user", "content": content})
205
+ payload = {"model": self.model, "messages": messages}
206
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
207
+ self._raise_for_status(r)
208
+ data = r.json()
209
+ return data["choices"][0]["message"]["content"].strip()
210
+
211
+
212
+ class AnthropicAdapter(BaseAdapter):
213
+ provider: str
214
+ model: str
215
+ api_key: Optional[str] = None
216
+ timeout: float = 60.0
217
+ base_url: Optional[str] = None
218
+ extra_headers: Dict[str, str] = field(default_factory=dict)
219
+
220
+ def __init__(self, model_name):
221
+ super().__init__(BaseAdapter.ANTHROPIC, model_name)
222
+
223
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
224
+ url = (self.base_url or "https://api.anthropic.com") + "/v1/messages"
225
+ headers = {
226
+ "x-api-key": self.api_key or "",
227
+ "anthropic-version": "2023-06-01",
228
+ "content-type": "application/json",
229
+ **self.extra_headers,
230
+ }
231
+ content_items: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
232
+ if image is not None:
233
+ if not isinstance(image, list):
234
+ image = [image]
235
+ for img in image:
236
+ _data_url, b64, mime = self._normalize_image(img, default_mime="image/png")
237
+ content_items.append({
238
+ "type": "image",
239
+ "source": {
240
+ "type": "base64",
241
+ "media_type": mime,
242
+ "data": b64,
243
+ },
244
+ })
245
+ payload: Dict[str, Any] = {
246
+ "model": self.model,
247
+ "max_tokens": kwargs.get("max_tokens", 1024),
248
+ "messages": [{"role": "user", "content": content_items}],
249
+ }
250
+ if system:
251
+ payload["system"] = system
252
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
253
+ self._raise_for_status(r)
254
+ data = r.json()
255
+ parts = data.get("content", [])
256
+ return "".join(p.get("text", "") for p in parts if p.get("type") == "text").strip()
257
+
258
+
259
+
260
+
261
+ class GeminiAdapter(BaseAdapter):
262
+ provider: str
263
+ model: str
264
+ api_key: Optional[str] = None
265
+ timeout: float = 60.0
266
+ base_url: Optional[str] = None
267
+ extra_headers: Dict[str, str] = field(default_factory=dict)
268
+
269
+ def __init__(self, model_name):
270
+ super().__init__(BaseAdapter.GEMINI, model_name)
271
+
272
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
273
+ key = self.api_key or os.getenv("GEMINI_API_KEY")
274
+ if not key:
275
+ raise BaseAdapterError("Missing GEMINI_API_KEY.")
276
+ base = self.base_url or "https://generativelanguage.googleapis.com"
277
+ url = f"{base}/v1beta/models/{self.model}:generateContent?key={key}"
278
+ headers = {"Content-Type": "application/json", **self.extra_headers}
279
+ parts = [{"text": prompt}]
280
+ if image is not None:
281
+ if not isinstance(image,list):
282
+ image = [image]
283
+ for img in image:
284
+ _data_url, b64, mime = self._normalize_image(img, default_mime="image/png")
285
+ parts.append({"inline_data": {"mime_type": mime, "data": b64}})
286
+ contents = [{"role": "user", "parts": parts}]
287
+ if system:
288
+ contents.insert(0, {"role": "system", "parts": [{"text": system}]})
289
+ payload: Dict[str, Any] = {"contents": contents}
290
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
291
+ self._raise_for_status(r)
292
+ data = r.json()
293
+ return data["candidates"][0]["content"]["parts"][0]["text"].strip()
294
+
295
+
296
+ class MistralAdapter(BaseAdapter):
297
+ provider: str
298
+ model: str
299
+ api_key: Optional[str] = None
300
+ timeout: float = 60.0
301
+ base_url: Optional[str] = None
302
+ extra_headers: Dict[str, str] = field(default_factory=dict)
303
+
304
+ def __init__(self, model_name):
305
+ super().__init__(BaseAdapter.MISTRAL, model_name)
306
+
307
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
308
+ url = (self.base_url or "https://api.mistral.ai") + "/v1/chat/completions"
309
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
310
+ messages = []
311
+ if system:
312
+ messages.append({"role": "system", "content": system})
313
+ content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
314
+ if image is not None:
315
+ if not isinstance(image, list):
316
+ image = [image]
317
+ for img in image:
318
+ data_url, _b64, _mime = self._normalize_image(img, default_mime="image/png")
319
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
320
+ messages.append({"role": "user", "content": content})
321
+ payload = {"model": self.model, "messages": messages}
322
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
323
+ self._raise_for_status(r)
324
+ data = r.json()
325
+ return data["choices"][0]["message"]["content"].strip()
326
+
327
+ class GrokAdapter(BaseAdapter):
328
+ provider: str
329
+ model: str
330
+ api_key: Optional[str] = None
331
+ timeout: float = 60.0
332
+ base_url: Optional[str] = None
333
+ extra_headers: Dict[str, str] = field(default_factory=dict)
334
+
335
+ def __init__(self, model_name):
336
+ super().__init__(BaseAdapter.GROK, model_name)
337
+
338
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
339
+ url = (self.base_url or "https://api.x.ai") + "/v1/chat/completions"
340
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
341
+ messages = []
342
+ if system:
343
+ messages.append({"role": "system", "content": system})
344
+ content = [{"type": "text", "text": prompt}]
345
+ data_url = None
346
+ if image is not None:
347
+ if not isinstance(image, list):
348
+ image = [image]
349
+ for img in image:
350
+ data_url, _b64, _mime = self._normalize_image(img, default_mime="image/png")
351
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
352
+ messages.append({"role": "user", "content": content})
353
+ payload = {"model": self.model, "messages": messages}
354
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
355
+ self._raise_for_status(r)
356
+ data = r.json()
357
+ return data["choices"][0]["message"]["content"].strip()
358
+
359
+
360
+
361
+
362
+
363
+ class TogetherAdapter(BaseAdapter):
364
+ provider: str
365
+ model: str
366
+ api_key: Optional[str] = None
367
+ timeout: float = 60.0
368
+ base_url: Optional[str] = None
369
+ extra_headers: Dict[str, str] = field(default_factory=dict)
370
+
371
+ def __init__(self, model_name):
372
+ super().__init__(BaseAdapter.TOGETHER, model_name)
373
+
374
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
375
+ url = (self.base_url or "https://api.together.xyz") + "/v1/chat/completions"
376
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
377
+ messages = []
378
+ if system:
379
+ messages.append({"role": "system", "content": system})
380
+ content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
381
+ if image is not None:
382
+ if not isinstance(image, list):
383
+ image = [image]
384
+ for img in image:
385
+ data_url, _b64, _mime = self._normalize_image(img, default_mime="image/png")
386
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
387
+ messages.append({"role": "user", "content": content})
388
+ payload = {"model": self.model, "messages": messages}
389
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
390
+ self._raise_for_status(r)
391
+ data = r.json()
392
+ return data["choices"][0]["message"]["content"].strip()
393
+
394
+
395
+
396
+ class CohereAdapter(BaseAdapter):
397
+ provider: str
398
+ model: str
399
+ api_key: Optional[str] = None
400
+ timeout: float = 60.0
401
+ base_url: Optional[str] = None
402
+ extra_headers: Dict[str, str] = field(default_factory=dict)
403
+
404
+ def __init__(self, model_name):
405
+ super().__init__(BaseAdapter.COHERE, model_name)
406
+
407
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[List[Union[str, bytes, Image]] ] = None, **kwargs: Any) -> str:
408
+ url = (self.base_url or "https://api.cohere.ai") + "/v1/chat"
409
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
410
+ payload: Dict[str, Any] = {"model": self.model, "message": prompt}
411
+ if system:
412
+ payload["preamble"] = system
413
+ if image is not None:
414
+ if not isinstane(image, list):
415
+ image = [image]
416
+ for img in image:
417
+ data_url, _b64, mime = self._normalize_image(img, default_mime="image/png")
418
+ # Cohere chat supports attachments; we send a data URL to keep it dependency-light
419
+ payload["attachments"] = [
420
+ {
421
+ "type": "image",
422
+ "image_url": data_url,
423
+ "mime_type": mime,
424
+ }
425
+ ]
426
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
427
+ self._raise_for_status(r)
428
+ data = r.json()
429
+ # Cohere responses can be under 'text' or 'message.content'
430
+ return (data.get("text") or data.get("message", {}).get("content", [{}])[0].get("text", "")).strip()
431
+
432
+
433
+
434
+
435
+
436
+
437
+
438
+
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+
453
+
454
+
455
+
456
+
457
+
458
+
459
+
460
+
461
+
462
+
463
+
464
+
465
+
466
+
467
+
468
+
469
+
470
+
471
+
472
+
473
+
474
+
475
+
476
+
477
+
478
+
479
+
480
+
481
+
482
+
483
+
484
+
485
+
486
+
487
+
488
+ '''
489
+
490
+ @dataclass
491
+ class UniLLM:
492
+ provider: str
493
+ model: str
494
+ api_key: Optional[str] = None
495
+ timeout: float = 60.0
496
+ base_url: Optional[str] = None
497
+ extra_headers: Dict[str, str] = field(default_factory=dict)
498
+
499
+ def __post_init__(self) -> None:
500
+ self.provider = self.provider.lower().strip()
501
+ if self.api_key is None:
502
+ env_keys = {
503
+ "openai": "OPENAI_API_KEY",
504
+ "anthropic": "ANTHROPIC_API_KEY",
505
+ "gemini": "GEMINI_API_KEY",
506
+ "mistral": "MISTRAL_API_KEY",
507
+ "cohere": "COHERE_API_KEY",
508
+ "together": "TOGETHER_API_KEY",
509
+ }
510
+ env_var = env_keys.get(self.provider)
511
+ if env_var:
512
+ self.api_key = os.getenv(env_var)
513
+ if not self.api_key and self.provider not in ("gemini",):
514
+ raise UniLLMError(f"Missing api_key for {self.provider}. Set via environment.")
515
+
516
+ # ---------- Public API ---------- #
517
+
518
+ def generate(self, prompt: str, system: Optional[str] = None, image: Optional[Union[str, bytes]] = None, **kwargs: Any) -> str:
519
+ p = self.provider
520
+ if p == "openai":
521
+ return self._openai_chat(prompt, system, image, **kwargs)
522
+ if p == "anthropic":
523
+ return self._anthropic_messages(prompt, system, image, **kwargs)
524
+ if p == "gemini":
525
+ return self._gemini_generate_content(prompt, system, image, **kwargs)
526
+ if p == "mistral":
527
+ return self._mistral_chat(prompt, system, image, **kwargs)
528
+ if p == "cohere":
529
+ return self._cohere_chat(prompt, system, image, **kwargs)
530
+ if p == "together":
531
+ return self._together_chat(prompt, system, image, **kwargs)
532
+ raise UniLLMError(f"Unsupported provider: {p}")
533
+
534
+ @staticmethod
535
+ def list_models(provider: str, api_key: Optional[str] = None, base_url: Optional[str] = None, timeout: float = 60.0) -> List[str]:
536
+ p = provider.lower().strip()
537
+ if p == "openai":
538
+ url = (base_url or "https://api.openai.com") + "/v1/models"
539
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('OPENAI_API_KEY')}"}
540
+ r = requests.get(url, headers=headers, timeout=timeout)
541
+ UniLLM._raise_for_status_static(r)
542
+ return [m["id"] for m in r.json().get("data", [])]
543
+ if p == "anthropic":
544
+ return [
545
+ "claude-3-5-sonnet-latest",
546
+ "claude-3-5-haiku-latest",
547
+ "claude-3-opus-latest",
548
+ ]
549
+ if p == "gemini":
550
+ key = api_key or os.getenv("GEMINI_API_KEY")
551
+ if not key:
552
+ raise UniLLMError("Missing GEMINI_API_KEY.")
553
+ url = (base_url or "https://generativelanguage.googleapis.com") + f"/v1beta/models?key={key}"
554
+ r = requests.get(url, timeout=timeout)
555
+ UniLLM._raise_for_status_static(r)
556
+ return [m["name"] for m in r.json().get("models", [])]
557
+ if p == "mistral":
558
+ url = (base_url or "https://api.mistral.ai") + "/v1/models"
559
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('MISTRAL_API_KEY')}"}
560
+ r = requests.get(url, headers=headers, timeout=timeout)
561
+ UniLLM._raise_for_status_static(r)
562
+ return [m["id"] for m in r.json().get("data", [])]
563
+ if p == "cohere":
564
+ url = (base_url or "https://api.cohere.ai") + "/v1/models"
565
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('COHERE_API_KEY')}"}
566
+ r = requests.get(url, headers=headers, timeout=timeout)
567
+ UniLLM._raise_for_status_static(r)
568
+ return [m["name"] for m in r.json().get("models", [])]
569
+ if p == "together":
570
+ url = (base_url or "https://api.together.xyz") + "/v1/models"
571
+ headers = {"Authorization": f"Bearer {api_key or os.getenv('TOGETHER_API_KEY')}"}
572
+ r = requests.get(url, headers=headers, timeout=timeout)
573
+ UniLLM._raise_for_status_static(r)
574
+ return [m["id"] for m in r.json().get("data", [])]
575
+ raise UniLLMError(f"Unsupported provider: {p}")
576
+
577
+ # ---------- Provider helpers ---------- #
578
+
579
+ def _openai_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str:
580
+ url = (self.base_url or "https://api.openai.com") + "/v1/chat/completions"
581
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
582
+ messages = []
583
+ if system:
584
+ messages.append({"role": "system", "content": system})
585
+ content = [{"type": "text", "text": prompt}]
586
+ data_url = None
587
+ if image is not None:
588
+ data_url, _b64, _mime = self._normalize_image(image, default_mime="image/png")
589
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
590
+ messages.append({"role": "user", "content": content})
591
+ payload = {"model": self.model, "messages": messages}
592
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
593
+ self._raise_for_status(r)
594
+ data = r.json()
595
+ return data["choices"][0]["message"]["content"].strip()
596
+
597
+ def _anthropic_messages(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str:
598
+ url = (self.base_url or "https://api.anthropic.com") + "/v1/messages"
599
+ headers = {
600
+ "x-api-key": self.api_key or "",
601
+ "anthropic-version": "2023-06-01",
602
+ "content-type": "application/json",
603
+ **self.extra_headers,
604
+ }
605
+ content_items: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
606
+ if image is not None:
607
+ _data_url, b64, mime = self._normalize_image(image, default_mime="image/png")
608
+ content_items.append({
609
+ "type": "image",
610
+ "source": {
611
+ "type": "base64",
612
+ "media_type": mime,
613
+ "data": b64,
614
+ },
615
+ })
616
+ payload: Dict[str, Any] = {
617
+ "model": self.model,
618
+ "max_tokens": kwargs.get("max_tokens", 1024),
619
+ "messages": [{"role": "user", "content": content_items}],
620
+ }
621
+ if system:
622
+ payload["system"] = system
623
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
624
+ self._raise_for_status(r)
625
+ data = r.json()
626
+ parts = data.get("content", [])
627
+ return "".join(p.get("text", "") for p in parts if p.get("type") == "text").strip()
628
+
629
+ def _gemini_generate_content(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str:
630
+ key = self.api_key or os.getenv("GEMINI_API_KEY")
631
+ if not key:
632
+ raise UniLLMError("Missing GEMINI_API_KEY.")
633
+ base = self.base_url or "https://generativelanguage.googleapis.com"
634
+ url = f"{base}/v1beta/models/{self.model}:generateContent?key={key}"
635
+ headers = {"Content-Type": "application/json", **self.extra_headers}
636
+ parts = [{"text": prompt}]
637
+ if image is not None:
638
+ _data_url, b64, mime = self._normalize_image(image, default_mime="image/png")
639
+ parts.append({"inline_data": {"mime_type": mime, "data": b64}})
640
+ contents = [{"role": "user", "parts": parts}]
641
+ if system:
642
+ contents.insert(0, {"role": "system", "parts": [{"text": system}]})
643
+ payload: Dict[str, Any] = {"contents": contents}
644
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
645
+ self._raise_for_status(r)
646
+ data = r.json()
647
+ return data["candidates"][0]["content"]["parts"][0]["text"].strip()
648
+
649
+ def _mistral_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str:
650
+ url = (self.base_url or "https://api.mistral.ai") + "/v1/chat/completions"
651
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
652
+ messages = []
653
+ if system:
654
+ messages.append({"role": "system", "content": system})
655
+ content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
656
+ if image is not None:
657
+ data_url, _b64, _mime = self._normalize_image(image, default_mime="image/png")
658
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
659
+ messages.append({"role": "user", "content": content})
660
+ payload = {"model": self.model, "messages": messages}
661
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
662
+ self._raise_for_status(r)
663
+ data = r.json()
664
+ return data["choices"][0]["message"]["content"].strip()
665
+
666
+ def _cohere_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str:
667
+ url = (self.base_url or "https://api.cohere.ai") + "/v1/chat"
668
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
669
+ payload: Dict[str, Any] = {"model": self.model, "message": prompt}
670
+ if system:
671
+ payload["preamble"] = system
672
+ if image is not None:
673
+ data_url, _b64, mime = self._normalize_image(image, default_mime="image/png")
674
+ # Cohere chat supports attachments; we send a data URL to keep it dependency-light
675
+ payload["attachments"] = [
676
+ {
677
+ "type": "image",
678
+ "image_url": data_url,
679
+ "mime_type": mime,
680
+ }
681
+ ]
682
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
683
+ self._raise_for_status(r)
684
+ data = r.json()
685
+ # Cohere responses can be under 'text' or 'message.content'
686
+ return (data.get("text") or data.get("message", {}).get("content", [{}])[0].get("text", "")).strip()
687
+
688
+ def _together_chat(self, prompt: str, system: Optional[str], image: Optional[Union[str, bytes]], **kwargs: Any) -> str:
689
+ url = (self.base_url or "https://api.together.xyz") + "/v1/chat/completions"
690
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", **self.extra_headers}
691
+ messages = []
692
+ if system:
693
+ messages.append({"role": "system", "content": system})
694
+ content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
695
+ if image is not None:
696
+ data_url, _b64, _mime = self._normalize_image(image, default_mime="image/png")
697
+ content.append({"type": "image_url", "image_url": {"url": data_url}})
698
+ messages.append({"role": "user", "content": content})
699
+ payload = {"model": self.model, "messages": messages}
700
+ r = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
701
+ self._raise_for_status(r)
702
+ data = r.json()
703
+ return data["choices"][0]["message"]["content"].strip()
704
+
705
+ # ---------- Utilities ---------- #
706
+
707
+ @staticmethod
708
+ def _raise_for_status_static(response: requests.Response) -> None:
709
+ if 200 <= response.status_code < 300:
710
+ return
711
+ try:
712
+ detail = response.json()
713
+ msg = json.dumps(detail)
714
+ except Exception:
715
+ msg = response.text
716
+ raise UniLLMError(f"HTTP {response.status_code}: {msg}")
717
+
718
+ def _raise_for_status(self, response: requests.Response) -> None:
719
+ return self._raise_for_status_static(response)
720
+
721
+ @staticmethod
722
+ def _detect_mime(b: bytes) -> str:
723
+ if len(b) >= 8 and b[:8] == b"\x89PNG\r\n\x1a\n":
724
+ return "image/png"
725
+ if len(b) >= 3 and b[:3] == b"\xff\xd8\xff":
726
+ return "image/jpeg"
727
+ if len(b) >= 6 and b[:6] in (b"GIF87a", b"GIF89a"):
728
+ return "image/gif"
729
+ if len(b) >= 12 and b[8:12] == b"WEBP":
730
+ return "image/webp"
731
+ return "application/octet-stream"
732
+
733
+ @staticmethod
734
+ def _normalize_image(image: Union[str, bytes], default_mime: str = "image/png") -> tuple[str, str, str]:
735
+ """Return (data_url, base64_str, mime_type) for the given image input.
736
+ Accepts bytes, base64 string, data URL, or local file path.
737
+ """
738
+ if isinstance(image, bytes):
739
+ b64 = base64.b64encode(image).decode()
740
+ mime = UniLLM._detect_mime(image)
741
+ if mime == "application/octet-stream":
742
+ mime = default_mime
743
+ return f"data:{mime};base64,{b64}", b64, mime
744
+ if isinstance(image, str):
745
+ if image.startswith("data:"):
746
+ header, b64 = image.split(",", 1)
747
+ # data:image/png;base64,XXXX
748
+ mime = header.split(";")[0].split(":", 1)[1] or default_mime
749
+ return image, b64, mime
750
+ if os.path.exists(image):
751
+ with open(image, "rb") as f:
752
+ raw = f.read()
753
+ b64 = base64.b64encode(raw).decode()
754
+ mime = BaseAdapter._detect_mime(raw)
755
+ if mime == "application/octet-stream":
756
+ mime = default_mime
757
+ return f"data:{mime};base64,{b64}", b64, mime
758
+ # assume bare base64 string
759
+ b64 = image
760
+ mime = default_mime
761
+ return f"data:{mime};base64,{b64}", b64, mime
762
+ raise BaseAdapter("Unsupported image type; pass bytes, path, base64 string, or data URL.")
763
+
764
+ '''
app.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # reCAPTCHA‑style 3×3 Demo (Streamlit) — Proof of Concept
3
+ # --------------------------------------------------------
4
+ # - Build challenges from a TSV (columns: image [base64], answer)
5
+ # - Same compact, natural‑size 3×3 layout for EVERY challenge
6
+ # - Manual mode: clickable tiles with baked‑in border + ✓ (works inside iframe)
7
+ # - Model modes: same layout (static), then run adapters
8
+
9
+ from __future__ import annotations
10
+ import io
11
+ import re
12
+ import base64
13
+ import random
14
+ from dataclasses import dataclass
15
+ from typing import List, Dict, Callable, Optional, Tuple, Union
16
+
17
+ import streamlit as st
18
+ from PIL import Image, ImageDraw
19
+ import pandas as pd
20
+ from io import BytesIO
21
+
22
+ import base64
23
+
24
+ from config import *
25
+ from utils import *
26
+ from adapter import *
27
+ # -----------------------------
28
+ # Constants & Utilities
29
+ # -----------------------------
30
+
31
+ IM_HEIGHT,IM_WIDTH = 256,256
32
+
33
+
34
+
35
+
36
+ class ManualAdapter(BaseAdapter):
37
+ name = "Manual"
38
+ def __init__(self, manual_selection: List[int]):
39
+ self.manual_selection = manual_selection
40
+ def solve(self, images, category, prompt_type, available_categories):
41
+ return InferenceResult(selected_ids=sorted(self.manual_selection), raw_outputs={})
42
+
43
+
44
+
45
+
46
+ class LLMadapter(BaseAdapter):
47
+
48
+ def __init__(self, provider, model_name, system:Optional[str]=None ):
49
+ assert provider in BaseAdapter.providers
50
+ #model_list = BaseAdapter.list_models(provider)
51
+ #assert model_name in model_list, f'{model_name} not found for provider: {provider}\nAvailable models:\n{model_list}'
52
+ self.adapter = LLMadapter.get_provider_class(provider)(model_name)
53
+ self.system = system
54
+ def generate(self, prompt, image):
55
+ out = self.adapter.generate(prompt=prompt, image=image, system=self.system)
56
+ return out
57
+
58
+
59
+ def get_provider_class(provider):
60
+ p = provider.lower().strip()
61
+ if p == BaseAdapter.OPENAI:
62
+ return OpenaiAdapter
63
+ if p == BaseAdapter.ANTHROPIC:
64
+ return AnthropicAdapter
65
+ if p == BaseAdapter.GEMINI:
66
+ return GeminiAdapter
67
+ if p == BaseAdapter.GROK:
68
+ return GrokAdapter
69
+ if p == BaseAdapter.MISTRAL:
70
+ return MistralAdapter
71
+ if p == BaseAdapter.COHERE:
72
+ return CohereAdapter
73
+ if p == BaseAdapter.TOGETHER:
74
+ return TogetherAdapter
75
+ raise BaseAdapterError(f"Unsupported provider: {p}")
76
+
77
+
78
+
79
+
80
+ # -----------------------------
81
+ # Data loading & challenge sampling
82
+ # -----------------------------
83
+
84
+
85
+
86
+
87
+ def make_challenge(df: pd.DataFrame, target: str | None, pos_fraction: float = 0.45):
88
+ cats = sorted(df["answer_norm"].unique())
89
+ if not cats: raise ValueError("No categories found in TSV 'answer' column")
90
+ if target is None or target == "__RANDOM__":
91
+ target = random.choice(cats)
92
+
93
+ pos = df[df["answer_norm"] == target]
94
+ neg = df[df["answer_norm"] != target]
95
+ if len(pos) == 0:
96
+ sampled = df.sample(min(9, len(df)))
97
+ else:
98
+ n_pos = max(1, min(len(pos), int(round(9 * pos_fraction))))
99
+ n_neg = max(0, 9 - n_pos)
100
+ pos_s = pos.sample(min(n_pos, len(pos)))
101
+ neg_s = neg.sample(min(n_neg, len(neg))) if n_neg > 0 and len(neg) > 0 else df.iloc[0:0]
102
+ sampled = pd.concat([pos_s, neg_s]).sample(frac=1.0)
103
+ if len(sampled) < 9 and len(df) > len(sampled):
104
+ extra = df.drop(sampled.index).sample(min(9 - len(sampled), len(df) - len(sampled)))
105
+ sampled = pd.concat([sampled, extra]).sample(frac=1.0)
106
+
107
+ sampled = sampled.head(9).copy()
108
+ ids = sampled["index"].astype(str).tolist()
109
+ answers = sampled["answer_norm"].tolist()
110
+ images = [decode_base64_image(b) for b in sampled["image"].tolist()]
111
+ return images, answers, target, ids
112
+
113
+
114
+
115
+ # -----------------------------
116
+ # UI helpers — consistent 3×3 layout
117
+ # -----------------------------
118
+ from PIL import ImageDraw
119
+
120
+ def bake_selection(img, selected: bool, color=(37, 99, 235), thickness: int = 8):
121
+ if not selected:
122
+ return img
123
+ im = img.copy()
124
+ d = ImageDraw.Draw(im)
125
+ w, h = im.size
126
+ t = max(2, min(thickness, max(w, h)//32)) # adaptive thickness helps small tiles
127
+ for k in range(t):
128
+ d.rectangle([k, k, w-1-k, h-1-k], outline=color, width=1)
129
+ # Optional: ✓ badge
130
+ r = max(12, min(22, w//12))
131
+ x, y = w - r - 8, 8
132
+ d.ellipse([x, y, x+r, y+r], fill=color)
133
+ d.line([x + r*0.25, y + r*0.55, x + r*0.45, y + r*0.75], fill=(255,255,255), width=max(2, r//6))
134
+ d.line([x + r*0.45, y + r*0.75, x + r*0.80, y + r*0.30], fill=(255,255,255), width=max(2, r//6))
135
+ return im
136
+
137
+ def render_grid_clickable(images, selected_ids: set):
138
+ from st_clickable_images import clickable_images
139
+ data_uris = []
140
+ for i, im in enumerate(images, start=1):
141
+ im = im.resize((IM_HEIGHT,IM_WIDTH))
142
+ vis = bake_selection(im, (i in selected_ids)) # <-- border baked here
143
+ buf = io.BytesIO(); vis.save(buf, format="PNG")
144
+ b64 = base64.b64encode(buf.getvalue()).decode()
145
+ data_uris.append("data:image/png;base64," + b64)
146
+
147
+ clicked = clickable_images(
148
+ data_uris,
149
+ titles=[str(i) for i in range(1, len(data_uris)+1)],
150
+ div_style={
151
+ "display": "grid",
152
+ "gridTemplateColumns": "repeat(3, max-content)",
153
+ "gap": "6px",
154
+ "justifyContent": "start",
155
+ "width": "fit-content",
156
+ },
157
+ img_style={
158
+ "width": "auto",
159
+ "height": "auto",
160
+ "maxWidth": "100%",
161
+ "borderRadius": "8px",
162
+ "boxSizing": "border-box",
163
+ "cursor": "pointer",
164
+ },
165
+ key=f"tile_clicks_{st.session_state.click_nonce}", # <-- important
166
+ )
167
+ return clicked if isinstance(clicked, int) and clicked >= 0 else None
168
+
169
+ def render_grid_static(images: List[Image.Image], selected_ids: set):
170
+ # build rows, 3 tiles per row
171
+ for row in chunk(list(enumerate(images, start=1)), 3):
172
+ cols = st.columns(3, gap="small") # <-- move inside the loop
173
+ for c, (idx, im) in enumerate(row):
174
+ with cols[c]:
175
+ vis = bake_selection(im, (idx in selected_ids))
176
+ # Option A: let Streamlit size it
177
+ #st.image(vis, caption=str(idx))
178
+ # Option B (uniform tiles): uncomment to normalize size
179
+ st.image(vis.resize((IM_WIDTH, IM_HEIGHT)), caption=str(idx))
180
+
181
+ def render_grid_static(images, selected_ids: set):
182
+ thumbs = []
183
+ for i, im in enumerate(images, 1):
184
+ im = im.resize((IM_WIDTH, IM_HEIGHT)) # (width, height)
185
+ vis = bake_selection(im, i in selected_ids)
186
+ buf = io.BytesIO(); vis.save(buf, format="PNG")
187
+ b64 = base64.b64encode(buf.getvalue()).decode()
188
+ thumbs.append(f'<figure><img src="data:image/png;base64,{b64}"><figcaption>{i}</figcaption></figure>')
189
+
190
+ html = f"""
191
+ <div style="
192
+ display:grid;
193
+ grid-template-columns: repeat(3, max-content);
194
+ gap:6px; justify-content:start; width:fit-content;">
195
+ {''.join(thumbs)}
196
+ </div>
197
+ <style>
198
+ figure {{ margin:0; }}
199
+ figcaption {{ text-align:center; font-size:0.8rem; margin-top:0.2rem; }}
200
+ img {{ border-radius:8px; box-sizing:border-box; }}
201
+ </style>
202
+ """
203
+ st.markdown(html, unsafe_allow_html=True)
204
+
205
+ # -----------------------------
206
+ # Streamlit App
207
+ # -----------------------------
208
+
209
+ st.set_page_config(page_title="reCAPTCHA‑style 3×3 — PoC", layout="wide")
210
+
211
+ # Compact layout & natural-size images (Streamlit native widgets)
212
+ st.markdown(
213
+ """
214
+ <style>
215
+ [data-testid="stHorizontalBlock"] { gap: 0.4rem !important; }
216
+ div[data-testid="stImage"] img { width: auto !important; max-width: none !important; height: auto; }
217
+ div[data-testid="stImage"] figure { width: fit-content !important; margin: 0.1rem auto; }
218
+ div[data-testid="stImage"] figcaption { margin-top: 0.2rem !important; }
219
+ </style>
220
+ """,
221
+ unsafe_allow_html=True,
222
+ )
223
+
224
+ st.title("reCAPTCHA‑style 3×3 Demo — Proof of Concept")
225
+ st.caption("Generate a challenge from TSV, then solve manually or with a model adapter.")
226
+
227
+ # Session state
228
+ for key, default in {
229
+ # existing keys...
230
+ "dataset": None,
231
+ "dataset_modified": None, # NEW
232
+ "categories": [],
233
+ "challenge_images_original": [], # NEW
234
+ "challenge_images_modified": [], # NEW
235
+ "challenge_answers": [],
236
+ "challenge_target": None,
237
+ "challenge_ids": [], # NEW
238
+ "tile_selected": set(),
239
+ "click_nonce": 0,
240
+ "last_clicked_processed": -1,
241
+ "auto_selected_ids": set(),
242
+ "image_view": "Original", # NEW: "Original" | "Modified"
243
+ }.items():
244
+ if key not in st.session_state:
245
+ st.session_state[key] = default
246
+
247
+
248
+ # 2) Use a placeholder for the grid
249
+ grid_ph = st.empty()
250
+ # Sidebar
251
+
252
+ # ---- sensible defaults in session ----
253
+ if "provider" not in st.session_state:
254
+ st.session_state.provider = "Manual" # start in Manual mode
255
+ if "model" not in st.session_state:
256
+ st.session_state.model = None
257
+
258
+
259
+ df_base = load_private_tsv("imageaction__recaptcha_dataset.tsv")
260
+ df_mod = load_private_tsv("imageaction__captcha@SPEC-1de6b70ae2f0.tsv")
261
+ st.session_state.dataset = df_base
262
+ st.session_state.dataset_modified = df_mod
263
+ st.session_state.categories = sorted(df_base["answer_norm"].unique())
264
+ # Sidebar
265
+ with st.sidebar:
266
+ st.subheader("Challenge Settings")
267
+
268
+ target_mode = st.selectbox("Target category mode", ["Pick specific", "Random each time"], index=0)
269
+ if target_mode == "Pick specific":
270
+ target_category = st.selectbox(
271
+ "Target category",
272
+ st.session_state.categories if st.session_state.categories else ["(load TSV first)"]
273
+ )
274
+ chosen_target = target_category if st.session_state.categories else None
275
+ else:
276
+ chosen_target = "__RANDOM__"
277
+
278
+ prompt_type_label = st.selectbox("Prompt type", list(PROMPT_TYPES.keys()), index=1)
279
+ prompt_type = PROMPT_TYPES[prompt_type_label]
280
+
281
+ st.markdown("---")
282
+ st.subheader("Solver")
283
+
284
+ # 1) Provider first (include Manual + all providers from your dict)
285
+ provider_options = ["Manual"] + list(MODEL_PROVIDERS.keys())
286
+ try:
287
+ provider_idx = provider_options.index(st.session_state.provider)
288
+ except ValueError:
289
+ provider_idx = 0 # fallback to Manual if prior value is missing
290
+
291
+ st.session_state.provider = st.selectbox("Provider", provider_options, index=provider_idx)
292
+
293
+ # 2) Model (enabled only when provider != Manual)
294
+ if st.session_state.provider == "Manual":
295
+ st.session_state.model = None
296
+ st.selectbox("Model", ["(not required in Manual mode)"], index=0, disabled=True)
297
+ st.caption("Manual mode: click tiles to select. No model needed.")
298
+ else:
299
+ models_for_provider = MODEL_PROVIDERS.get(st.session_state.provider, [])
300
+ # Keep previously selected model if still valid; otherwise default to first/empty
301
+ if not models_for_provider:
302
+ st.session_state.model = None
303
+ st.selectbox("Model", ["(no models available for this provider)"], index=0, disabled=True)
304
+ else:
305
+ if st.session_state.model not in models_for_provider:
306
+ st.session_state.model = models_for_provider[0]
307
+ model_idx = models_for_provider.index(st.session_state.model)
308
+ st.session_state.model = st.selectbox("Model", models_for_provider, index=model_idx)
309
+
310
+
311
+ # Generate new challenge
312
+ colA, colB = st.columns([1,2])
313
+ with colA:
314
+ gen = st.button("🎲 Generate new challenge", use_container_width=True, disabled=(st.session_state.dataset is None))
315
+
316
+ if gen:
317
+ with st.spinner("Sampling images…"):
318
+ images_orig, answers, tgt, ids = make_challenge(st.session_state.dataset, chosen_target)
319
+ st.session_state.challenge_images_original = images_orig
320
+ st.session_state.challenge_answers = answers
321
+ st.session_state.challenge_target = tgt
322
+ st.session_state.challenge_ids = ids
323
+ st.session_state.tile_selected = set()
324
+ st.session_state.last_clicked_processed = -1
325
+ st.session_state.click_nonce = 0
326
+ st.session_state.auto_selected_ids = set()
327
+
328
+ # Build modified images in the SAME ORDER by id (if modified dataset present)
329
+ st.session_state.challenge_images_modified = []
330
+ if st.session_state.dataset_modified is not None:
331
+ mod_map = st.session_state.dataset_modified.set_index("index")["image"].to_dict()
332
+ miss = []
333
+ for _id in ids:
334
+ b64 = mod_map.get(str(_id))
335
+ if b64 is None:
336
+ miss.append(_id)
337
+ # fallback to original tile if missing
338
+ st.session_state.challenge_images_modified.append(
339
+ st.session_state.challenge_images_original[len(st.session_state.challenge_images_modified)]
340
+ )
341
+ else:
342
+ st.session_state.challenge_images_modified.append(decode_base64_image(b64))
343
+ if miss:
344
+ st.warning(f"Modified TSV is missing {len(miss)} ids used in this challenge; those tiles fall back to original.")
345
+ else:
346
+ st.session_state.challenge_images_modified = [] # not available
347
+
348
+ st.success("New challenge ready. Target: " + str(st.session_state.challenge_target))
349
+
350
+ # Main area
351
+ if st.session_state.challenge_images_original:
352
+ st.subheader("3×3 Grid — Target: **" + str(st.session_state.challenge_target) + "** (Indices 1..9)")
353
+
354
+ # Toggle between Original and Modified
355
+ options = ["Original"]
356
+ if st.session_state.challenge_images_modified:
357
+ options.append("Modified")
358
+ st.session_state.image_view = st.radio(
359
+ "Image set", options, horizontal=True, index=0 if st.session_state.image_view not in options else options.index(st.session_state.image_view)
360
+ )
361
+
362
+ images_to_show = (st.session_state.challenge_images_modified
363
+ if st.session_state.image_view == "Modified" and st.session_state.challenge_images_modified
364
+ else st.session_state.challenge_images_original)
365
+
366
+ if st.session_state.provider == "Manual":
367
+ try:
368
+ clicked = render_grid_clickable(images_to_show, st.session_state.tile_selected)
369
+ if clicked is not None:
370
+ tile_id = clicked + 1
371
+ if tile_id in st.session_state.tile_selected:
372
+ st.session_state.tile_selected.remove(tile_id)
373
+ else:
374
+ st.session_state.tile_selected.add(tile_id)
375
+ st.session_state.click_nonce += 1
376
+ st.rerun()
377
+ except Exception:
378
+ st.info("Install optional dependency: pip install st-clickable-images")
379
+ render_grid_static(images_to_show, st.session_state.tile_selected)
380
+ else:
381
+ render_grid_static(images_to_show, st.session_state.auto_selected_ids)
382
+
383
+
384
+
385
+ st.markdown("---")
386
+
387
+ # Build adapter
388
+ if st.session_state.provider == "Manual":
389
+ adapter = ManualAdapter(manual_selection=sorted(st.session_state.tile_selected)) #ADAPTERS[model_choice](manual_selection=sorted(st.session_state.tile_selected))
390
+ else:
391
+ #adapter = MODEL_ADAPTERS[st.session_state.provider](st.session_state.model)
392
+ adapter = LLMadapter(st.session_state.provider, st.session_state.model)
393
+ # Prompts Preview
394
+ st.subheader("Prompts Preview")
395
+ cats_for_prompt = st.session_state.categories if st.session_state.categories else []
396
+ if prompt_type == 1:
397
+ st.code(build_prompt_1(st.session_state.challenge_target))
398
+ elif prompt_type == 2:
399
+ st.code(build_prompt_2(cats_for_prompt))
400
+ else:
401
+ raise Exception()
402
+
403
+
404
+ if st.button("Run Solver", use_container_width=True):
405
+ images_for_inference = (st.session_state.challenge_images_modified
406
+ if st.session_state.image_view == "Modified" and st.session_state.challenge_images_modified
407
+ else st.session_state.challenge_images_original)
408
+
409
+ with st.spinner("Running solver…"):
410
+ if prompt_type == 1:
411
+ prompt = build_prompt_1(st.session_state.challenge_target)
412
+ output_parse_fn = parse_prompt_1
413
+ elif prompt_type == 2:
414
+ prompt = build_prompt_2(cats_for_prompt)
415
+ output_parse_fn = parse_prompt_2
416
+ else:
417
+ raise Exception()
418
+
419
+ preds, raw_preds = [], []
420
+ if st.session_state.provider == 'Manual':
421
+ selected_ids = [i for i in st.session_state.tile_selected]
422
+ raw_preds = [ ans if (i+1) in selected_ids else 'Other' for i,ans in enumerate(st.session_state.challenge_answers) ]
423
+ preds = [ st.session_state.challenge_target == pred for pred in raw_preds ]
424
+ else:
425
+ challenge_images_b64 = [encode_base64_image(img) for img in images_for_inference]
426
+
427
+ for image_b64 in challenge_images_b64:
428
+ result = adapter.generate(prompt=prompt, image=image_b64)
429
+ outcome = output_parse_fn(result)
430
+ raw_preds.append(outcome)
431
+ preds.append(outcome)
432
+
433
+ selected_ids = [i+1 for i, outcome in enumerate(preds) if outcome]
434
+ st.session_state.auto_selected_ids = set(selected_ids) if st.session_state.provider != "Manual" else set()
435
+ st.success("Done.")
436
+ st.subheader("Selected IDs")
437
+ st.write(selected_ids)
438
+
439
+ if st.session_state.provider != "Manual":
440
+ st.subheader("Prediction overlay")
441
+ render_grid_static(images_for_inference, st.session_state.auto_selected_ids)
442
+
443
+ # evaluation uses the *original ground truth labels* (ids don’t change)
444
+ challenge_gt = [ans == st.session_state.challenge_target for ans in st.session_state.challenge_answers]
445
+ challenge_pairs = list(zip(challenge_gt, preds))
446
+ tp = sum(pred == gt for gt, pred in challenge_pairs if gt)
447
+ true_count = sum(gt for gt, _ in challenge_pairs)
448
+ fn = sum(gt != pred for gt, pred in challenge_pairs if gt)
449
+ fp = sum(pred != gt for gt, pred in challenge_pairs if not gt)
450
+ tn = sum(pred == gt for gt, pred in challenge_pairs if not gt)
451
+
452
+ st.subheader(f"Recall: {tp/(tp+fn) if (tp+fn) else 0.0} # Found {tp}/{true_count}")
453
+ if raw_preds:
454
+ st.subheader("Raw Model Outputs")
455
+ for idx, (gt, pred) in enumerate(zip(st.session_state.challenge_answers, raw_preds)):
456
+ st.markdown(f"**Category: {gt} — Expected: {gt == st.session_state.challenge_target}**")
457
+ st.code(f"Prediction: {pred}", language="text")
458
+
459
+
460
+ with st.expander("Debug: ground‑truth categories per tile", expanded=False):
461
+ grid_truth = [str(i) + ": " + lbl for i, lbl in enumerate(st.session_state.challenge_answers, start=1)]
462
+ st.write(", ".join(grid_truth))
463
+ else:
464
+ st.info("Upload a TSV on the left and click 'Generate new challenge' to begin.")
465
+
466
+
467
+ # -----------------------------
468
+ # Integrations Guide (trimmed)
469
+ # -----------------------------
470
+ with st.expander("Integrations Guide: Wiring real models", expanded=False):
471
+ st.markdown(
472
+ """
473
+ Replace the mock call functions with real SDK calls (OpenAI/Anthropic/HF).
474
+ For CLIP zero‑shot, wire a predict_fn that returns (label, score) per image.
475
+ """
476
+ )
config.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from adapter import *
4
+
5
+
6
+
7
+ HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
8
+ PRIVATE_DATASET_REPO = "Chris1/recaptcha_datasets"
9
+
10
+ PROMPT_TYPES = {
11
+ "Binary per tile (yes/no)": 1,
12
+ "Multiclass per tile (class name)": 2,
13
+ }
14
+
15
+ MODEL_PROVIDERS = [
16
+ "Manual",
17
+ BaseAdapter.OPENAI,
18
+ BaseAdapter.ANTHROPIC,
19
+ BaseAdapter.GEMINI,
20
+ BaseAdapter.MISTRAL,
21
+ BaseAdapter.GROK,
22
+ ]
23
+
24
+ MISTRAL_MODELS = ['mistral-medium-latest']
25
+
26
+ GROK_MODELS = [
27
+ 'grok-4-0709',
28
+ 'grok-4-fast-reasoning'
29
+ ]
30
+
31
+ ANTHROPIC_MODELS = [
32
+ 'claude-4-opus-20250514',
33
+ 'claude-opus-4-1-20250805',
34
+ 'claude-sonnet-4-5-20250929',
35
+ 'claude-haiku-4-5-20251001',
36
+ 'claude-4-sonnet-20250514']
37
+
38
+ GEMINI_MODELS = [
39
+ 'gemini-1.0-pro',
40
+ 'gemini-1.5-pro',
41
+ 'gemini-1.5-flash',
42
+ 'gemini-1.5-pro-002',
43
+ 'gemini-2.0-flash',
44
+ 'gemini-2.0-flash-lite',
45
+ 'gemini-2.5-flash',
46
+ 'gemini-2.5-pro'
47
+ ]
48
+
49
+ OPENAI_MODELS = [
50
+ 'gpt-4o-2024-11-20',
51
+ 'gpt-4o-mini-2024-07-18',
52
+ 'gpt-4.5-preview-2025-02-27',
53
+ 'gpt-4.1-2025-04-14',
54
+ 'gpt-5-2025-08-07',
55
+ 'gpt-5-mini-2025-08-07',
56
+ 'gpt-5-nano-2025-08-07'
57
+ ]
58
+
59
+
60
+ MODEL_PROVIDERS = {
61
+ BaseAdapter.OPENAI : OPENAI_MODELS ,
62
+ BaseAdapter.ANTHROPIC : ANTHROPIC_MODELS,
63
+ BaseAdapter.GEMINI : GEMINI_MODELS,
64
+ BaseAdapter.MISTRAL : MISTRAL_MODELS,
65
+ BaseAdapter.GROK : GROK_MODELS,
66
+ #BaseAdapter.COHERE : [],
67
+ #BaseAdapter.TOGETHER : []
68
+ }
69
+
70
+ MODEL_ADAPTERS = {
71
+ BaseAdapter.OPENAI : OPENAI_MODELS ,
72
+ BaseAdapter.ANTHROPIC : ANTHROPIC_MODELS,
73
+ BaseAdapter.GEMINI : GEMINI_MODELS,
74
+ BaseAdapter.MISTRAL : MISTRAL_MODELS,
75
+ BaseAdapter.GROK : GROK_MODELS,
76
+ #BaseAdapter.COHERE : [],
77
+ #BaseAdapter.TOGETHER : []
78
+ }
79
+
80
+
81
+
82
+
83
+ # -----------------------------
84
+ # Prompt Builders & Parsers
85
+ # -----------------------------
86
+
87
+ def build_prompt_3(category: str) -> str:
88
+ return (
89
+ "Select the images which are of the category '" + category + "' "
90
+ "from left to right, top to bottom, indexed 1 to 9. "
91
+ "Return only the valid numbers separated by spaces or commas."
92
+ )
93
+
94
+
95
+ def build_prompt_1(category: str) -> str:
96
+ return "Is the object in the image a '" + category + "'? Answer with yes or no only."
97
+
98
+
99
+ def build_prompt_2(categories: List[str]) -> str:
100
+ cats = ", ".join(categories)
101
+ return (
102
+ "Predict the category of the provided image among the set of categories: "
103
+ + cats + ". Return exactly and only the class name."
104
+ )
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
- altair
 
2
  pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit
2
+ st-clickable-images
3
  pandas
4
+ Pillow
5
+ huggingface_hub
6
+ openai
7
+ anthropic
8
+ google-generativeai
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import pandas as pd
4
+ import numpy as np
5
+ import string
6
+ from uuid import uuid4
7
+ import os.path as osp
8
+ import base64
9
+ from PIL import Image
10
+ import sys
11
+
12
+
13
+
14
+
15
+ import os
16
+ import pandas as pd
17
+ from huggingface_hub import hf_hub_download
18
+ import streamlit as st
19
+
20
+ from config import *
21
+
22
+
23
+
24
+ def normalize_label(s: str) -> str:
25
+ return " ".join(s.strip().lower().split())
26
+
27
+ @st.cache_data(show_spinner=False)
28
+ def load_private_tsv(filename: str) -> pd.DataFrame:
29
+ """Download a TSV file from a private HF dataset repo."""
30
+ local_path = hf_hub_download(
31
+ repo_id=PRIVATE_DATASET_REPO,
32
+ repo_type="dataset",
33
+ filename=filename,
34
+ token=HF_TOKEN,
35
+ )
36
+ df = pd.read_csv(local_path, sep="\t")
37
+ df = df[["index","image", "answer"]].dropna()
38
+ df["answer_norm"] = df["answer"].str.strip().str.lower()
39
+ # enforce string ids to avoid type mismatches
40
+ df["index"] = df["index"].astype(str)
41
+ return df
42
+
43
+
44
+ def load_dataset_from_tsv(upload) -> pd.DataFrame:
45
+ df = pd.read_csv(upload, sep="\t")
46
+ required = {"index", "image", "answer"}
47
+ missing = required - set(df.columns)
48
+ if missing:
49
+ raise ValueError(f"TSV must contain {sorted(required)}. Missing: {sorted(missing)}")
50
+
51
+ df = df[["index", "image", "answer"]].dropna()
52
+ df["answer_norm"] = df["answer"].apply(normalize_label)
53
+ # enforce string ids to avoid type mismatches
54
+ df["index"] = df["index"].astype(str)
55
+ return df
56
+
57
+ class ParseError(Exception):
58
+ pass
59
+
60
+
61
+
62
+
63
+
64
+ def parse_prompt1_indices(text: str) -> List[int]:
65
+ nums = re.findall(r"[1-9]", text)
66
+ return sorted(set(int(n) for n in nums))
67
+
68
+
69
+ def parse_prompt_1(text: str) -> bool:
70
+ t = normalize_label(text)
71
+ if t in {"yes", "y"}: return True
72
+ if t in {"no", "n"}: return False
73
+ if t.startswith("yes"): return True
74
+ if t.startswith("no"): return False
75
+ raise ParseError("Unclear yes/no response")
76
+
77
+
78
+ def parse_prompt_2(text: str, target: str) -> bool:
79
+ return text == target #normalize_label(text) == normalize_label(target)
80
+
81
+
82
+
83
+
84
+ def chunk(lst, n):
85
+ for i in range(0, len(lst), n):
86
+ yield lst[i:i+n]
87
+
88
+
89
+
90
+
91
+ def encode_base64_image(image: Image.Image) -> str:
92
+ buf = io.BytesIO()
93
+ image.save(buf, format="PNG") # or "PNG"/"WEBP" as you choose
94
+ img_bytes = buf.getvalue()
95
+ data_b64 = base64.b64encode(img_bytes).decode("ascii")
96
+ return data_b64
97
+
98
+ def decode_base64_image(b64: str) -> Image.Image:
99
+ if "," in b64 and b64.strip().lower().startswith("data:"):
100
+ b64 = b64.split(",", 1)[1]
101
+ data = base64.b64decode(b64)
102
+ return Image.open(io.BytesIO(data)).convert("RGB")
103
+
104
+
105
+
106
+