davidjurgens committed on
Commit
aceb1b2
·
verified ·
1 Parent(s): e6c39d1

Deploy: Potato — Codebook Annotation

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. Dockerfile +41 -0
  3. README.md +42 -5
  4. annotation_output/.gitkeep +0 -0
  5. config.yaml +27 -0
  6. data/codebook-example.json +7 -0
  7. entrypoint.sh +30 -0
  8. layouts/task_layout.html +90 -0
  9. potato/__init__.py +45 -0
  10. potato/__main__.py +3 -0
  11. potato/active_learning_manager.py +1623 -0
  12. potato/adjudication.py +1224 -0
  13. potato/adjudication_export.py +162 -0
  14. potato/admin.py +0 -0
  15. potato/agent_proxy/__init__.py +44 -0
  16. potato/agent_proxy/base.py +138 -0
  17. potato/agent_proxy/coding_proxy.py +466 -0
  18. potato/agent_proxy/echo_proxy.py +55 -0
  19. potato/agent_proxy/http_proxy.py +108 -0
  20. potato/agent_proxy/openai_proxy.py +105 -0
  21. potato/agent_proxy/sandbox.py +76 -0
  22. potato/agent_proxy/session.py +119 -0
  23. potato/agent_runner.py +1008 -0
  24. potato/agent_runner_manager.py +226 -0
  25. potato/agreement.py +278 -0
  26. potato/ai/__init__.py +1 -0
  27. potato/ai/ai_cache.py +1473 -0
  28. potato/ai/ai_endpoint.py +688 -0
  29. potato/ai/ai_help_wrapper.py +203 -0
  30. potato/ai/ai_prompt.py +94 -0
  31. potato/ai/anthropic_endpoint.py +118 -0
  32. potato/ai/anthropic_vision_endpoint.py +405 -0
  33. potato/ai/gemini_endpoint.py +58 -0
  34. potato/ai/huggingface_endpoint.py +62 -0
  35. potato/ai/icl_labeler.py +1110 -0
  36. potato/ai/icl_prompt_builder.py +315 -0
  37. potato/ai/judge.py +265 -0
  38. potato/ai/llm_active_learning.py +733 -0
  39. potato/ai/ollama_endpoint.py +160 -0
  40. potato/ai/ollama_vision_endpoint.py +313 -0
  41. potato/ai/openai_endpoint.py +94 -0
  42. potato/ai/openai_vision_endpoint.py +324 -0
  43. potato/ai/openrouter_endpoint.py +93 -0
  44. potato/ai/prompt/image_annotation.json +44 -0
  45. potato/ai/prompt/likert.json +20 -0
  46. potato/ai/prompt/models_module.py +258 -0
  47. potato/ai/prompt/multiselect.json +20 -0
  48. potato/ai/prompt/number.json +18 -0
  49. potato/ai/prompt/option_highlight.json +8 -0
  50. potato/ai/prompt/radio.json +18 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ potato/static/vendor/font-awesome-6.7.2/webfonts/fa-brands-400.ttf filter=lfs diff=lfs merge=lfs -text
37
+ potato/static/vendor/font-awesome-6.7.2/webfonts/fa-brands-400.woff2 filter=lfs diff=lfs merge=lfs -text
38
+ potato/static/vendor/font-awesome-6.7.2/webfonts/fa-solid-900.ttf filter=lfs diff=lfs merge=lfs -text
39
+ potato/static/vendor/font-awesome-6.7.2/webfonts/fa-solid-900.woff2 filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Create non-root user (HF Spaces requires UID 1000)
4
+ RUN useradd -m -u 1000 potato
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && \
8
+ apt-get install -y --no-install-recommends git && \
9
+ rm -rf /var/lib/apt/lists/*
10
+
11
+ # Set working directory
12
+ WORKDIR /app
13
+
14
+ # Copy requirements first for layer caching
15
+ COPY requirements.txt .
16
+ RUN pip install --no-cache-dir -r requirements.txt gunicorn
17
+
18
+ # Copy the application source
19
+ COPY . .
20
+
21
+ # Install potato
22
+ RUN pip install --no-cache-dir -e .
23
+
24
+ # Copy entrypoint
25
+ COPY entrypoint.sh /entrypoint.sh
26
+ RUN chmod +x /entrypoint.sh
27
+
28
+ # Create directories for output
29
+ RUN mkdir -p /app/annotation_output && \
30
+ chown -R potato:potato /app
31
+
32
+ # Switch to non-root user
33
+ USER potato
34
+
35
+ # HuggingFace Spaces expects port 7860
36
+ EXPOSE 7860
37
+
38
+ ENV POTATO_CONFIG=config.yaml
39
+ ENV PORT=7860
40
+
41
+ ENTRYPOINT ["/entrypoint.sh"]
README.md CHANGED
@@ -1,10 +1,47 @@
1
  ---
2
- title: Codebook
3
- emoji: 🐨
4
- colorFrom: purple
5
- colorTo: yellow
6
  sdk: docker
 
7
  pinned: false
 
 
 
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Potato — Codebook Annotation
3
+ emoji: 🥔
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
+ license: apache-2.0
10
+ tags:
11
+ - annotation
12
+ - potato
13
+ - advanced
14
+ - qda
15
+ - codebook
16
  ---
17
 
18
+ # Potato Codebook Annotation
19
+
20
+ Shared evolving codebook across annotators.
21
+
22
+ A live demo of **[Potato](https://www.potatoannotator.com)** — the free, self-hosted annotation platform for NLP,
23
+ agentic, and GenAI research, configured entirely through YAML.
24
+ Visit **[www.potatoannotator.com](https://www.potatoannotator.com)** for docs, the schema gallery, and more demos.
25
+
26
+ ## Try it out
27
+
28
+ 1. Enter any username to log in (no password required).
29
+ 2. Read the item shown in the main panel.
30
+ 3. Annotate using the schemes on the right.
31
+ 4. Click **Next** to continue.
32
+
33
+ > **Run your own copy:** click the **⋮ → Duplicate this Space** button (top-right) to launch
34
+ > this exact demo in your own account on free hardware — change the data and config to make it yours.
35
+
36
+ > Annotations in this demo are ephemeral. To collect and keep data, deploy your own
37
+ > Space — see the [deployment guide](https://github.com/davidjurgens/potato/blob/master/deployment/huggingface-spaces/deploy.md).
38
+
39
+ ## About Potato
40
+
41
+ Potato supports 20+ annotation types — text, spans, images, audio, video, documents,
42
+ and agent traces — with AI-assisted labeling, quality control, and adjudication.
43
+
44
+ 🥔 **Website: [www.potatoannotator.com](https://www.potatoannotator.com)**  · 
45
+ [Documentation](https://www.potatoannotator.com)  · 
46
+ [GitHub](https://github.com/davidjurgens/potato)  · 
47
+ [All demos](https://github.com/davidjurgens/potato/blob/master/docs/data-export/potato_on_huggingface.md)
annotation_output/.gitkeep ADDED
File without changes
config.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ port: 8000
2
+ annotation_task_name: Codebook Example
3
+ task_dir: .
4
+ output_annotation_dir: annotation_output/codebook-example/
5
+ output_annotation_format: tsv
6
+ data_files:
7
+ - data/codebook-example.json
8
+ item_properties:
9
+ id_key: id
10
+ text_key: text
11
+ user_config:
12
+ allow_all_users: true
13
+ users: []
14
+ alert_time_each_instance: 10000000
15
+ codebook_mode: open
16
+ annotation_schemes:
17
+ - annotation_type: multiselect
18
+ name: themes
19
+ description: Which qualitative themes appear in this excerpt?
20
+ codebook: true
21
+ labels:
22
+ - access barriers
23
+ - cost concerns
24
+ - provider trust
25
+ - wait times
26
+ site_dir: default
27
+ require_no_password: true
data/codebook-example.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [
2
+ {"id": "1", "text": "I couldn't get an appointment for three weeks, and even then the clinic was far from a bus line."},
3
+ {"id": "2", "text": "The doctor explained everything clearly and I felt listened to for the first time."},
4
+ {"id": "3", "text": "My copay went up again this year, so I've been skipping the follow-up visits."},
5
+ {"id": "4", "text": "Front desk staff were dismissive, and I waited over an hour past my scheduled time."},
6
+ {"id": "5", "text": "Telehealth made it much easier, but I still worry whether they have my full history."}
7
+ ]
entrypoint.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ # Configuration
5
+ CONFIG_FILE="${POTATO_CONFIG:-config.yaml}"
6
+ PORT="${PORT:-7860}"
7
+ WORKERS="${GUNICORN_WORKERS:-2}"
8
+ THREADS="${GUNICORN_THREADS:-4}"
9
+ TIMEOUT="${GUNICORN_TIMEOUT:-120}"
10
+
11
+ echo "Starting Potato Demo Space..."
12
+ echo " Config: ${CONFIG_FILE}"
13
+ echo " Port: ${PORT}"
14
+ echo " Workers: ${WORKERS}"
15
+
16
+ # Validate config exists
17
+ if [ ! -f "${CONFIG_FILE}" ]; then
18
+ echo "ERROR: Config file not found: ${CONFIG_FILE}"
19
+ exit 1
20
+ fi
21
+
22
+ # Start with gunicorn using the factory pattern
23
+ exec gunicorn \
24
+ --bind "0.0.0.0:${PORT}" \
25
+ --workers "${WORKERS}" \
26
+ --threads "${THREADS}" \
27
+ --timeout "${TIMEOUT}" \
28
+ --access-logfile - \
29
+ --error-logfile - \
30
+ "potato.flask_server:create_app('${CONFIG_FILE}')"
layouts/task_layout.html ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- CONFIG_HASH: 80d495277c377949581be1771e4d8c3e_c4da35694b3e1122725c151f4a96e755 -->
2
+ <!-- Generated annotation layout file -->
3
+ <!-- This file was automatically generated based on the annotation schemes in your config -->
4
+ <!-- You can customize this file to modify the layout of your annotation interface -->
5
+ <!-- Changes to this file will be preserved across server restarts -->
6
+
7
+ <div class="annotation_schema">
8
+
9
+ <form id="themes" class="annotation-form multiselect shadcn-multiselect-container" action="javascript:void(0)" data-annotation-id="0" data-annotation-type="multiselect" data-schema-name="themes" data-grid-columns="1">
10
+
11
+ <fieldset schema="themes">
12
+ <legend class="shadcn-multiselect-title">Which qualitative themes appear in this excerpt?</legend>
13
+ <div class="shadcn-multiselect-grid" style="grid-template-columns: repeat(1, 1fr);">
14
+ <div class="shadcn-multiselect-item">
15
+ <input class="themes shadcn-multiselect-checkbox annotation-input"
16
+ type="checkbox"
17
+ id="themes_access barriers_checkbox"
18
+ name="themes:::access barriers"
19
+ value="access barriers"
20
+ label_name="access barriers"
21
+ schema="themes"
22
+ onclick="whetherNone(this);registerAnnotation(this)"
23
+ validation="">
24
+ <label for="themes_access barriers_checkbox" schema="themes" class="shadcn-multiselect-label">
25
+ Access Barriers
26
+ </label>
27
+ </div>
28
+
29
+ <div class="shadcn-multiselect-item">
30
+ <input class="themes shadcn-multiselect-checkbox annotation-input"
31
+ type="checkbox"
32
+ id="themes_cost concerns_checkbox"
33
+ name="themes:::cost concerns"
34
+ value="cost concerns"
35
+ label_name="cost concerns"
36
+ schema="themes"
37
+ onclick="whetherNone(this);registerAnnotation(this)"
38
+ validation="">
39
+ <label for="themes_cost concerns_checkbox" schema="themes" class="shadcn-multiselect-label">
40
+ Cost Concerns
41
+ </label>
42
+ </div>
43
+
44
+ <div class="shadcn-multiselect-item">
45
+ <input class="themes shadcn-multiselect-checkbox annotation-input"
46
+ type="checkbox"
47
+ id="themes_provider trust_checkbox"
48
+ name="themes:::provider trust"
49
+ value="provider trust"
50
+ label_name="provider trust"
51
+ schema="themes"
52
+ onclick="whetherNone(this);registerAnnotation(this)"
53
+ validation="">
54
+ <label for="themes_provider trust_checkbox" schema="themes" class="shadcn-multiselect-label">
55
+ Provider Trust
56
+ </label>
57
+ </div>
58
+
59
+ <div class="shadcn-multiselect-item">
60
+ <input class="themes shadcn-multiselect-checkbox annotation-input"
61
+ type="checkbox"
62
+ id="themes_wait times_checkbox"
63
+ name="themes:::wait times"
64
+ value="wait times"
65
+ label_name="wait times"
66
+ schema="themes"
67
+ onclick="whetherNone(this);registerAnnotation(this)"
68
+ validation="">
69
+ <label for="themes_wait times_checkbox" schema="themes" class="shadcn-multiselect-label">
70
+ Wait Times
71
+ </label>
72
+ </div>
73
+
74
+ <div class="shadcn-multiselect-item">
75
+ <input class="themes shadcn-multiselect-checkbox annotation-input"
76
+ type="checkbox"
77
+ id="themes_InspectCode_checkbox"
78
+ name="themes:::InspectCode"
79
+ value="InspectCode"
80
+ label_name="InspectCode"
81
+ schema="themes"
82
+ onclick="whetherNone(this);registerAnnotation(this)"
83
+ validation="">
84
+ <label for="themes_InspectCode_checkbox" schema="themes" class="shadcn-multiselect-label">
85
+ InspectCode
86
+ </label>
87
+ </div>
88
+ </div></fieldset></form>
89
+
90
+ </div>
potato/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Potato Annotation Platform
3
+
4
+ A flexible, web-based platform for text annotation tasks.
5
+
6
+ This package provides a comprehensive annotation system with the following features:
7
+ - Multi-phase annotation workflows (consent, instructions, training, annotation, post-study)
8
+ - Support for various annotation types (labels, spans, text, Likert scales, best-worst scaling)
9
+ - User authentication and session management
10
+ - Active learning capabilities
11
+ - Admin dashboard for monitoring progress
12
+ - Configurable assignment strategies
13
+ - Multi-language and multi-task support
14
+
15
+ Main Components:
16
+ - flask_server: Core Flask application and server logic
17
+ - routes: HTTP route handlers and request processing
18
+ - user_state_management: User progress tracking and state persistence
19
+ - item_state_management: Data item management and assignment
20
+ - authentificaton: User authentication backends
21
+ - admin: Admin dashboard functionality
22
+ - activelearning: Active learning algorithms and model training
23
+
24
+ Usage:
25
+ from potato.flask_server import create_app
26
+ app = create_app()
27
+ app.run()
28
+ """
29
+
30
+ from .flask_server import create_app
31
+
32
+ __version__ = "2.6.0"
33
+ __author__ = "Potato Annotation Platform Team"
34
+ __description__ = "A flexible, web-based platform for text annotation tasks"
35
+
36
+
37
+ def __getattr__(name):
38
+ """Lazy imports for optional heavy dependencies."""
39
+ if name == "load_as_dataset":
40
+ from .datasets_integration import load_as_dataset
41
+ return load_as_dataset
42
+ if name == "load_annotations":
43
+ from .datasets_integration import load_annotations
44
+ return load_annotations
45
+ raise AttributeError(f"module 'potato' has no attribute {name!r}")
potato/__main__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from potato.flask_server import main
2
+
3
+ main()
potato/active_learning_manager.py ADDED
@@ -0,0 +1,1623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced Active Learning Manager with Database Persistence
3
+
4
+ This module provides a comprehensive active learning system with optional
5
+ database persistence, model saving, LLM integration, and multiple query
6
+ strategies including uncertainty sampling, diversity sampling, BADGE, BALD,
7
+ and hybrid combinations.
8
+
9
+ References:
10
+ [1] Ash et al. (2020) "Deep Batch Active Learning by Diverse, Uncertain
11
+ Gradient Lower Bounds" (BADGE). ICLR 2020.
12
+ [2] Houlsby et al. (2011) "Bayesian Active Learning for Classification
13
+ and Preference Learning" (BALD).
14
+ [3] Bayer et al. (2024) "ActiveLLM: Large Language Model-Based Active
15
+ Learning for Textual Few-Shot Scenarios". TACL.
16
+ [4] Yuan et al. (2024) "Hide and Seek in Noise Labels: Noise-Robust
17
+ Collaborative Active Learning" (NoiseAL). ACL 2024.
18
+ [5] Mavromatis et al. (2024) "CoverICL: Selective Annotation for
19
+ In-Context Learning via Active Graph Coverage". EMNLP 2024.
20
+ """
21
+
22
+ import threading
23
+ import logging
24
+ import time
25
+ import os
26
+ import pickle
27
+ import json
28
+ from typing import Dict, List, Optional, Tuple, Any, Union
29
+ from collections import defaultdict, Counter
30
+ import dataclasses
31
+ from dataclasses import dataclass, field, asdict
32
+ from enum import Enum
33
+ import random
34
+ import queue
35
+ from datetime import datetime
36
+ from abc import ABC, abstractmethod
37
+
38
+ from sklearn.pipeline import Pipeline
39
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
40
+ from sklearn.linear_model import LogisticRegression
41
+ from sklearn.ensemble import RandomForestClassifier
42
+ from sklearn.svm import SVC
43
+ from sklearn.metrics import accuracy_score, classification_report
44
+ import numpy as np
45
+
46
+ from potato.item_state_management import ItemStateManager, get_item_state_manager
47
+ from potato.user_state_management import get_user_state_manager
48
+
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+
53
+ class ResolutionStrategy(Enum):
54
+ """Strategies for resolving multiple annotations per instance."""
55
+ MAJORITY_VOTE = "majority_vote"
56
+ RANDOM = "random"
57
+ CONSENSUS = "consensus"
58
+ WEIGHTED_AVERAGE = "weighted_average"
59
+
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # SentenceTransformerVectorizer
63
+ # ---------------------------------------------------------------------------
64
+
65
+ class SentenceTransformerVectorizer:
66
+ """sklearn-compatible wrapper for sentence-transformers.
67
+
68
+ Uses dense embeddings from pre-trained transformer models instead of
69
+ bag-of-words features. Produces 384-dim vectors (for the default model)
70
+ that capture semantic meaning, enabling better classification with
71
+ fewer training examples.
72
+
73
+ The ``sentence-transformers`` package is an **optional** dependency and
74
+ is only imported when this vectorizer is actually used.
75
+ """
76
+
77
+ def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
78
+ self.model_name = model_name
79
+ self._model = None
80
+
81
+ def fit(self, X, y=None):
82
+ from sentence_transformers import SentenceTransformer
83
+ self._model = SentenceTransformer(self.model_name)
84
+ return self
85
+
86
+ def transform(self, X):
87
+ if self._model is None:
88
+ raise RuntimeError("SentenceTransformerVectorizer has not been fitted yet")
89
+ return self._model.encode(list(X), show_progress_bar=False)
90
+
91
+ def fit_transform(self, X, y=None):
92
+ self.fit(X, y)
93
+ return self.transform(X)
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # Query Strategies
98
+ # ---------------------------------------------------------------------------
99
+
100
+ class QueryStrategy(ABC):
101
+ """Base class for active learning query strategies."""
102
+
103
+ @abstractmethod
104
+ def rank(self, texts: List[str], model, vectorizer,
105
+ annotated_texts: Optional[List[str]] = None) -> List[Tuple[int, float]]:
106
+ """Return list of (index, score) sorted by selection priority (highest first)."""
107
+
108
+
109
+ class UncertaintySampling(QueryStrategy):
110
+ """Select instances where classifier is least confident.
111
+
112
+ Selects x* = argmax_x (1 - max_y P(y|x)), i.e., instances where the
113
+ model's best guess has lowest confidence.
114
+ """
115
+
116
+ def rank(self, texts, model, vectorizer, annotated_texts=None):
117
+ try:
118
+ features = vectorizer.transform(texts)
119
+ probas = model.predict_proba(features)
120
+ # Score = 1 - max_prob (higher = more uncertain = higher priority)
121
+ scores = 1.0 - np.max(probas, axis=1)
122
+ ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
123
+ return ranked
124
+ except Exception as e:
125
+ logger.warning(f"UncertaintySampling failed: {e}")
126
+ return [(i, 0.5) for i in range(len(texts))]
127
+
128
+
129
+ class DiversitySampling(QueryStrategy):
130
+ """Select instances that maximize feature-space coverage.
131
+
132
+ Uses cosine distance from already-annotated instances in the vectorized
133
+ feature space. Ensures the training set covers the full data distribution
134
+ rather than over-sampling one region.
135
+ """
136
+
137
+ def rank(self, texts, model, vectorizer, annotated_texts=None):
138
+ from sklearn.metrics.pairwise import cosine_distances
139
+
140
+ try:
141
+ features = vectorizer.transform(texts)
142
+ if hasattr(features, 'toarray'):
143
+ features = features.toarray()
144
+
145
+ if annotated_texts:
146
+ annotated_features = vectorizer.transform(annotated_texts)
147
+ if hasattr(annotated_features, 'toarray'):
148
+ annotated_features = annotated_features.toarray()
149
+ # Score = min cosine distance to any annotated instance
150
+ distances = cosine_distances(features, annotated_features)
151
+ scores = np.min(distances, axis=1)
152
+ else:
153
+ # No annotated texts yet: use distance from centroid
154
+ centroid = np.mean(features, axis=0, keepdims=True)
155
+ scores = cosine_distances(features, centroid).ravel()
156
+
157
+ ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
158
+ return ranked
159
+ except Exception as e:
160
+ logger.warning(f"DiversitySampling failed: {e}")
161
+ return [(i, 0.5) for i in range(len(texts))]
162
+
163
+
164
+ class BadgeStrategy(QueryStrategy):
165
+ """BADGE approximation: uncertainty-weighted diversity.
166
+
167
+ Inspired by Ash et al. (2020) [Ref 1]. Full BADGE uses gradient embeddings
168
+ from neural networks. Our approximation:
169
+ 1. Weight feature vectors by (1 - max_prob) as uncertainty proxy
170
+ 2. Run k-means++ initialization on weighted vectors to select
171
+ diverse-uncertain instances.
172
+ """
173
+
174
+ def rank(self, texts, model, vectorizer, annotated_texts=None):
175
+ try:
176
+ features = vectorizer.transform(texts)
177
+ if hasattr(features, 'toarray'):
178
+ features = features.toarray()
179
+
180
+ probas = model.predict_proba(features)
181
+ uncertainty = 1.0 - np.max(probas, axis=1)
182
+
183
+ # Weight features by uncertainty
184
+ weighted = features * uncertainty[:, np.newaxis]
185
+
186
+ # Use k-means++ initialization to select diverse-uncertain points
187
+ from sklearn.cluster import kmeans_plusplus
188
+ n_clusters = min(len(texts), max(1, len(texts) // 2))
189
+ _, indices = kmeans_plusplus(weighted, n_clusters=n_clusters,
190
+ random_state=42)
191
+
192
+ # Build score: selected centroids get highest scores
193
+ scores = np.zeros(len(texts))
194
+ for rank_pos, idx in enumerate(indices):
195
+ scores[idx] = len(indices) - rank_pos # highest for first-selected
196
+
197
+ # For non-selected, use uncertainty as tiebreaker
198
+ for i in range(len(texts)):
199
+ if scores[i] == 0:
200
+ scores[i] = uncertainty[i] * 0.01
201
+
202
+ ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
203
+ return ranked
204
+ except Exception as e:
205
+ logger.warning(f"BadgeStrategy failed, falling back to uncertainty: {e}")
206
+ return UncertaintySampling().rank(texts, model, vectorizer, annotated_texts)
207
+
208
+
209
+ class BaldStrategy(QueryStrategy):
210
+ """BALD: Bayesian Active Learning by Disagreement.
211
+
212
+ Based on Houlsby et al. (2011) [Ref 2]. Trains an ensemble of classifiers
213
+ with different random seeds/bootstrap samples. Selects instances with
214
+ highest mutual information: H[y|x] - E_theta[H[y|x,theta]], i.e.,
215
+ where the ensemble disagrees most.
216
+ """
217
+
218
+ def __init__(self, n_estimators: int = 5, bootstrap_fraction: float = 0.8):
219
+ self.n_estimators = n_estimators
220
+ self.bootstrap_fraction = bootstrap_fraction
221
+
222
+ def rank(self, texts, model, vectorizer, annotated_texts=None):
223
+ try:
224
+ features = vectorizer.transform(texts)
225
+ if hasattr(features, 'toarray'):
226
+ features = features.toarray()
227
+
228
+ probas = model.predict_proba(features)
229
+ # Predictive entropy of the single model's class distribution
230
+ avg_proba = probas
231
+ entropy_avg = -np.sum(avg_proba * np.log(avg_proba + 1e-10), axis=1)
232
+
233
+ # Single-model fallback: with only one model there is no ensemble
234
+ # disagreement to measure, so we rank by predictive entropy alone. The
235
+ # ensemble models are stored on the manager (see
236
+ # ActiveLearningManager._train_bald_ensemble); rank_with_ensemble() gives the mutual-information ranking.
237
+ scores = entropy_avg
238
+ ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
239
+ return ranked
240
+ except Exception as e:
241
+ logger.warning(f"BaldStrategy failed: {e}")
242
+ return [(i, 0.5) for i in range(len(texts))]
243
+
244
+ def rank_with_ensemble(self, texts, ensemble_models, vectorizer):
245
+ """Rank using actual ensemble disagreement (mutual information)."""
246
+ try:
247
+ features = vectorizer.transform(texts)
248
+ if hasattr(features, 'toarray'):
249
+ features = features.toarray()
250
+
251
+ all_probas = []
252
+ for m in ensemble_models:
253
+ all_probas.append(m.predict_proba(features))
254
+
255
+ all_probas = np.array(all_probas) # (n_estimators, n_samples, n_classes)
256
+
257
+ # Mean prediction across ensemble
258
+ mean_proba = np.mean(all_probas, axis=0) # (n_samples, n_classes)
259
+
260
+ # H[y|x] - entropy of mean prediction
261
+ entropy_mean = -np.sum(mean_proba * np.log(mean_proba + 1e-10), axis=1)
262
+
263
+ # E_theta[H[y|x,theta]] - mean of individual entropies
264
+ individual_entropies = -np.sum(all_probas * np.log(all_probas + 1e-10), axis=2)
265
+ mean_entropy = np.mean(individual_entropies, axis=0)
266
+
267
+ # Mutual information = H[y|x] - E[H[y|x,theta]]
268
+ mutual_info = entropy_mean - mean_entropy
269
+
270
+ ranked = sorted(enumerate(mutual_info), key=lambda x: x[1], reverse=True)
271
+ return ranked
272
+ except Exception as e:
273
+ logger.warning(f"BaldStrategy ensemble ranking failed: {e}")
274
+ return [(i, 0.5) for i in range(len(texts))]
275
+
276
+
277
+ class HybridStrategy(QueryStrategy):
278
+ """Weighted combination of uncertainty and diversity scores.
279
+
280
+ Combines strategies with configurable weights. Default: 0.7 uncertainty +
281
+ 0.3 diversity.
282
+ """
283
+
284
+ def __init__(self, weights: Optional[Dict[str, float]] = None):
285
+ self.weights = weights or {"uncertainty": 0.7, "diversity": 0.3}
286
+
287
+ def rank(self, texts, model, vectorizer, annotated_texts=None):
288
+ try:
289
+ strategies = {}
290
+ if self.weights.get("uncertainty", 0) > 0:
291
+ strategies["uncertainty"] = UncertaintySampling()
292
+ if self.weights.get("diversity", 0) > 0:
293
+ strategies["diversity"] = DiversitySampling()
294
+
295
+ # Collect raw scores from each strategy
296
+ all_scores = {}
297
+ for name, strategy in strategies.items():
298
+ rankings = strategy.rank(texts, model, vectorizer, annotated_texts)
299
+ score_map = {idx: score for idx, score in rankings}
300
+ all_scores[name] = score_map
301
+
302
+ # Normalize each strategy's scores to [0, 1]
303
+ for name in all_scores:
304
+ vals = list(all_scores[name].values())
305
+ min_val, max_val = min(vals), max(vals)
306
+ rng = max_val - min_val if max_val > min_val else 1.0
307
+ all_scores[name] = {
308
+ idx: (s - min_val) / rng for idx, s in all_scores[name].items()
309
+ }
310
+
311
+ # Weighted combination
312
+ combined = {}
313
+ for i in range(len(texts)):
314
+ combined[i] = sum(
315
+ self.weights.get(name, 0) * all_scores.get(name, {}).get(i, 0)
316
+ for name in self.weights
317
+ )
318
+
319
+ ranked = sorted(combined.items(), key=lambda x: x[1], reverse=True)
320
+ return ranked
321
+ except Exception as e:
322
+ logger.warning(f"HybridStrategy failed: {e}")
323
+ return UncertaintySampling().rank(texts, model, vectorizer, annotated_texts)
324
+
325
+
326
+ # Strategy registry
327
+ STRATEGY_REGISTRY = {
328
+ "uncertainty": UncertaintySampling,
329
+ "diversity": DiversitySampling,
330
+ "badge": BadgeStrategy,
331
+ "bald": BaldStrategy,
332
+ "hybrid": HybridStrategy,
333
+ }
334
+
335
+
336
+ def create_query_strategy(config: 'ActiveLearningConfig') -> QueryStrategy:
337
+ """Create a query strategy from config."""
338
+ strategy_name = config.query_strategy
339
+ if strategy_name == "hybrid":
340
+ return HybridStrategy(weights=config.hybrid_weights)
341
+ elif strategy_name == "bald":
342
+ params = config.bald_params
343
+ return BaldStrategy(
344
+ n_estimators=params.get("n_estimators", 5),
345
+ bootstrap_fraction=params.get("bootstrap_fraction", 0.8),
346
+ )
347
+ elif strategy_name in STRATEGY_REGISTRY:
348
+ return STRATEGY_REGISTRY[strategy_name]()
349
+ else:
350
+ logger.warning(f"Unknown strategy '{strategy_name}', falling back to uncertainty")
351
+ return UncertaintySampling()
352
+
353
+
354
+ # ---------------------------------------------------------------------------
355
+ # ICLClassifier wrapper (Phase 5A)
356
+ # ---------------------------------------------------------------------------
357
+
358
class ICLClassifier:
    """Wraps ICLLabeler as an sklearn-compatible classifier for ensemble use.

    Enables combining LLM-based ICL predictions with traditional classifier
    predictions in a hybrid ensemble for active learning scoring.

    Attributes:
        icl_labeler: Object exposing ``label_instance(instance_id, schema_name,
            instance_text)``; its result is read for ``predicted_label`` and
            ``confidence_score``.
        schema_name: Schema name forwarded to the labeler.
        label_names: Ordered label list; column ``j`` of ``predict_proba``
            corresponds to ``label_names[j]``.
        classes_: ``label_names`` as a numpy array (sklearn convention).
    """

    def __init__(self, icl_labeler, schema_name: str, label_names: List[str]):
        self.icl_labeler = icl_labeler
        self.schema_name = schema_name
        self.label_names = label_names
        self.classes_ = np.array(label_names)

    def predict_proba(self, texts: List[str]) -> np.ndarray:
        """Get label probabilities from LLM via ICL.

        Each row starts as a uniform distribution. When the labeler returns a
        known label, the (clamped) confidence is assigned to that label and the
        remainder is split evenly across the other labels, so every row is a
        valid probability distribution.

        Args:
            texts: Texts to score.

        Returns:
            Array of shape ``(len(texts), len(label_names))``. Rows for which
            the labeler failed, returned nothing, returned an unknown label, or
            returned an unusable confidence stay uniform.
        """
        n_classes = len(self.label_names)
        if n_classes == 0:
            # No labels: nothing to distribute (avoids 1.0 / 0).
            return np.zeros((len(texts), 0))

        probas = np.full((len(texts), n_classes), 1.0 / n_classes)

        for i, text in enumerate(texts):
            try:
                prediction = self.icl_labeler.label_instance(
                    instance_id=f"_al_query_{i}",
                    schema_name=self.schema_name,
                    instance_text=text,
                )
                if not prediction or prediction.predicted_label not in self.label_names:
                    continue

                raw_conf = prediction.confidence_score
                if raw_conf is None:
                    continue
                conf = float(raw_conf)
                if conf != conf:  # NaN carries no information; keep uniform
                    continue

                # Clamp so an out-of-range confidence cannot yield negative
                # probabilities or values above 1.
                conf = min(1.0, max(0.0, conf))
                if n_classes == 1:
                    # Only one possible label: its probability must be 1.
                    conf = 1.0

                idx = self.label_names.index(prediction.predicted_label)
                # Distribute: conf to predicted label, (1-conf)/(n-1) to others
                probas[i] = (1.0 - conf) / max(1, n_classes - 1)
                probas[i, idx] = conf
            except Exception as e:
                # Best effort: keep the uniform row, but leave a trace.
                logging.getLogger(__name__).debug(
                    f"ICL prediction failed for query {i}; keeping uniform distribution: {e}"
                )

        return probas
394
+
395
+
396
+ # ---------------------------------------------------------------------------
397
+ # Configuration
398
+ # ---------------------------------------------------------------------------
399
+
400
@dataclass
class ActiveLearningConfig:
    """Enhanced configuration for active learning.

    Fields declared with a ``None`` default are normalized to empty containers
    in ``__post_init__``. ``classifier_params`` / ``vectorizer_params`` are
    merged on top of ``classifier_kwargs`` / ``vectorizer_kwargs`` there.
    """
    enabled: bool = False
    # Dotted class path of the classifier to build
    classifier_name: str = "sklearn.linear_model.LogisticRegression"
    classifier_kwargs: Optional[Dict[str, Any]] = None
    # Dotted class path of the vectorizer, or "sentence-transformers"
    vectorizer_name: str = "sklearn.feature_extraction.text.TfidfVectorizer"
    vectorizer_kwargs: Optional[Dict[str, Any]] = None
    min_annotations_per_instance: int = 1
    min_instances_for_training: int = 10
    max_instances_to_reorder: Optional[int] = None
    resolution_strategy: ResolutionStrategy = ResolutionStrategy.MAJORITY_VOTE
    random_sample_percent: float = 0.2
    update_frequency: int = 5
    schema_names: Optional[List[str]] = None

    # Classifier/vectorizer passthrough params (Phase 1C)
    classifier_params: Dict[str, Any] = field(default_factory=dict)
    vectorizer_params: Dict[str, Any] = field(default_factory=dict)

    # Probability calibration (Phase 1D)
    calibrate_probabilities: bool = True

    # Query strategy (Phase 2)
    query_strategy: str = "uncertainty"
    hybrid_weights: Dict[str, float] = field(
        default_factory=lambda: {"uncertainty": 0.7, "diversity": 0.3}
    )
    bald_params: Dict[str, Any] = field(
        default_factory=lambda: {"n_estimators": 5, "bootstrap_fraction": 0.8}
    )

    # Cold-start (Phase 3)
    cold_start_strategy: str = "random"
    cold_start_batch_size: int = 20

    # ICL ensemble (Phase 5)
    use_icl_ensemble: bool = False
    icl_ensemble_params: Dict[str, Any] = field(default_factory=lambda: {
        "initial_icl_weight": 0.7,
        "final_icl_weight": 0.2,
        "transition_instances": 100,
    })

    # Annotation routing (Phase 5D)
    annotation_routing: bool = False
    routing_thresholds: Dict[str, float] = field(default_factory=lambda: {
        "auto_label_min_confidence": 0.9,
        "show_suggestion_below": 0.5,
    })
    verification_sample_rate: float = 0.2

    # Database persistence
    database_enabled: bool = False
    database_config: Optional[Dict[str, Any]] = None

    # Model persistence
    model_persistence_enabled: bool = False
    model_save_directory: Optional[str] = None
    model_retention_count: int = 2

    # LLM integration
    llm_enabled: bool = False
    llm_config: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Normalize ``None`` containers and merge passthrough params.

        The kwargs dicts are copied before merging so that a dict passed in by
        the caller (possibly shared between several configs) is never mutated.
        """
        self.classifier_kwargs = (
            {} if self.classifier_kwargs is None else dict(self.classifier_kwargs)
        )
        self.vectorizer_kwargs = (
            {} if self.vectorizer_kwargs is None else dict(self.vectorizer_kwargs)
        )
        if self.schema_names is None:
            self.schema_names = []
        if self.database_config is None:
            self.database_config = {}
        if self.llm_config is None:
            self.llm_config = {}
        # Merge classifier_params into classifier_kwargs (params win on conflict)
        if self.classifier_params:
            self.classifier_kwargs.update(self.classifier_params)
        # Merge vectorizer_params into vectorizer_kwargs (params win on conflict)
        if self.vectorizer_params:
            self.vectorizer_kwargs.update(self.vectorizer_params)
482
+
483
+
484
@dataclass
class TrainingMetrics:
    """Metrics for a training run.

    One instance is produced per training attempt, including failed attempts
    (which carry ``error_message`` and an accuracy of 0.0).
    """
    # Annotation schema the attempt was made for
    schema_name: str
    # Elapsed wall-clock time of the attempt, in seconds (time.time() delta)
    training_time: float
    # accuracy_score computed on the training texts themselves; 0.0 on failure
    accuracy: float
    # Number of training texts available to the attempt
    instance_count: int
    # When the metrics object was created
    timestamp: datetime
    # Path of the persisted model file, set after a successful save
    model_file_path: Optional[str] = None
    # Confidence-bin label (e.g. "0.8-1.0") -> percentage of training texts
    confidence_distribution: Optional[Dict[str, float]] = None
    # Human-readable reason when training was skipped or failed
    error_message: Optional[str] = None
495
+
496
+
497
class ModelPersistence:
    """Handles model saving and loading with metadata.

    Models are pickled into ``save_directory`` as
    ``{schema_name}_{instance_count}_{YYYYmmdd}_{HHMMSS}.pkl``. After each save
    only the ``retention_count`` most recently modified files of that schema
    are kept.
    """

    _MODEL_SUFFIX = ".pkl"

    def __init__(self, save_directory: str, retention_count: int = 2):
        self.save_directory = save_directory
        self.retention_count = retention_count
        self.logger = logging.getLogger(__name__)

        # Ensure directory exists
        os.makedirs(save_directory, exist_ok=True)

    def save_model(self, model: Pipeline, schema_name: str, instance_count: int) -> str:
        """Save a trained model with metadata.

        The pickle is written to a temporary file and renamed into place, so a
        failed dump never leaves a truncated ``.pkl`` behind.

        Returns:
            Path of the saved model file.

        Raises:
            Exception: whatever pickling or file I/O raised, after logging it.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{schema_name}_{instance_count}_{timestamp}{self._MODEL_SUFFIX}"
        filepath = os.path.join(self.save_directory, filename)
        partial_path = f"{filepath}.tmp"

        try:
            # Save the complete model (including vectorizer)
            with open(partial_path, 'wb') as f:
                pickle.dump(model, f)
            os.replace(partial_path, filepath)

            self.logger.info(f"Saved model to {filepath}")

            # Clean up old models
            self._cleanup_old_models(schema_name)

            return filepath
        except Exception as e:
            self.logger.error(f"Failed to save model: {e}")
            # Drop the partial file, if one was left behind.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise

    def load_model(self, filepath: str) -> Optional[Pipeline]:
        """Load a saved model; returns ``None`` (after logging) on any failure.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files
        written by ``save_model`` from a directory you trust.
        """
        try:
            with open(filepath, 'rb') as f:
                model = pickle.load(f)

            # TODO: Add schema validation here in the future
            # This is a placeholder for future schema validation enhancement

            self.logger.info(f"Loaded model from {filepath}")
            return model
        except Exception as e:
            self.logger.error(f"Failed to load model from {filepath}: {e}")
            return None

    def _is_model_file_for(self, filename: str, schema_name: str) -> bool:
        """True if ``filename`` has exactly the shape save_model uses for this schema.

        A plain prefix test is not enough: schema "sent" must not claim files
        belonging to schema "sent_v2".
        """
        prefix = f"{schema_name}_"
        if not (filename.startswith(prefix) and filename.endswith(self._MODEL_SUFFIX)):
            return False
        middle = filename[len(prefix):-len(self._MODEL_SUFFIX)]
        parts = middle.split("_")
        # instance_count, date, time -- all digits
        return len(parts) == 3 and all(part.isdigit() for part in parts)

    def _cleanup_old_models(self, schema_name: str):
        """Clean up old models based on retention policy.

        Failures are logged and never propagated.
        """
        try:
            # Find all model files for this schema
            model_files = []

            for filename in os.listdir(self.save_directory):
                if self._is_model_file_for(filename, schema_name):
                    filepath = os.path.join(self.save_directory, filename)
                    model_files.append((filepath, os.path.getmtime(filepath)))

            # Sort by modification time (newest first)
            model_files.sort(key=lambda x: x[1], reverse=True)

            # Remove old models beyond retention count
            for filepath, _ in model_files[self.retention_count:]:
                try:
                    os.remove(filepath)
                    self.logger.info(f"Removed old model: {filepath}")
                except Exception as e:
                    self.logger.warning(f"Failed to remove old model {filepath}: {e}")

        except Exception as e:
            self.logger.error(f"Error during model cleanup: {e}")
568
+
569
+
570
class DatabaseStateManager:
    """Manages database persistence for active learning state.

    The backend hooks and persistence methods below are placeholders: they
    accept their arguments and return empty results without storing anything.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.connection = None
        self._initialize_database()

    def _initialize_database(self):
        """Pick the backend from ``config['type']``, connect, and create tables.

        Any failure is logged and re-raised.
        """
        try:
            # Use the same database system as main Potato application
            use_mysql = self.config.get('type') == 'mysql'
            connect = (
                self._init_mysql_connection if use_mysql
                else self._init_file_based_connection
            )
            connect()

            self._create_tables()
            self.logger.info("Active learning database initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize database: {e}")
            raise

    def _init_mysql_connection(self):
        """Initialize MySQL connection (not implemented yet)."""
        # TODO: Implement MySQL connection
        return None

    def _init_file_based_connection(self):
        """Initialize file-based database connection (not implemented yet)."""
        # TODO: Implement file-based database
        return None

    def _create_tables(self):
        """Create database tables for active learning (not implemented yet)."""
        # TODO: Implement table creation
        return None

    def save_training_metrics(self, metrics: TrainingMetrics):
        """Save training metrics to database (currently a no-op)."""
        # TODO: Implement metrics saving
        return None

    def get_training_history(self, schema_name: Optional[str] = None) -> List[TrainingMetrics]:
        """Get training history from database (currently always empty)."""
        # TODO: Implement history retrieval
        history = []
        return history

    def save_schema_cycling_state(self, current_schema: str, schema_order: List[str]):
        """Save current schema cycling state (currently a no-op)."""
        # TODO: Implement state saving
        return None

    def get_schema_cycling_state(self) -> Tuple[str, List[str]]:
        """Get current schema cycling state (currently ``("", [])``)."""
        # TODO: Implement state retrieval
        current, order = "", []
        return current, order
628
+
629
+
630
class SchemaCycler:
    """Manages cycling through multiple annotation schemes.

    Keeps an index into an ordered list of schema names, advances it
    round-robin, and optionally mirrors the position to a database manager.
    """

    def __init__(self, schema_names: List[str], database_manager: Optional[DatabaseStateManager] = None):
        self.schema_names = self._validate_schemas(schema_names)
        self.database_manager = database_manager
        self.current_index = 0
        self.logger = logging.getLogger(__name__)
        self._lock = threading.Lock()

        # Restore the previous position when persistence is available
        if self.database_manager:
            self._load_state()

    def _validate_schemas(self, schema_names: List[str]) -> List[str]:
        """Return a new list of the given names; raise on 'text' or 'span'."""
        accepted = []
        for candidate in schema_names:
            # Text and span schemes cannot be used for active learning
            if candidate in ['text', 'span']:
                raise ValueError(f"Text and span annotation schemes are not supported for active learning: {candidate}")
            accepted.append(candidate)
        return accepted

    def _load_state(self):
        """Restore the current index from the database; failures are logged."""
        try:
            stored_schema, _stored_order = self.database_manager.get_schema_cycling_state()
            with self._lock:
                if stored_schema in self.schema_names:
                    self.current_index = self.schema_names.index(stored_schema)
        except Exception as e:
            self.logger.warning(f"Failed to load schema cycling state: {e}")

    def get_current_schema(self) -> Optional[str]:
        """Return the schema currently up for training, or None if there are none."""
        if self.schema_names:
            with self._lock:
                return self.schema_names[self.current_index]
        return None

    def advance_schema(self):
        """Move to the next schema (wrapping around) and persist the position."""
        if not self.schema_names:
            return

        with self._lock:
            next_index = (self.current_index + 1) % len(self.schema_names)
            self.current_index = next_index
            active = self.schema_names[next_index]

        if not self.database_manager:
            return

        # Persist outside the lock; a failing save is only logged
        try:
            self.database_manager.save_schema_cycling_state(active, self.schema_names)
        except Exception as e:
            self.logger.warning(f"Failed to save schema cycling state: {e}")

    def get_schema_order(self) -> List[str]:
        """Return a copy of the schema cycling order."""
        return list(self.schema_names)
695
+
696
+
697
+ class ActiveLearningManager:
698
+ """
699
+ Manages active learning operations including classifier training and instance reordering.
700
+
701
+ This class provides thread-safe operations for:
702
+ - Training classifiers on annotated data
703
+ - Predicting confidence scores for unlabeled instances
704
+ - Reordering instances based on configurable query strategies
705
+ - Cold-start LLM-based instance selection
706
+ - ICL/classifier ensemble for improved ranking
707
+ - Noise-aware annotation routing
708
+ - Managing training state and progress
709
+ - Database persistence and model saving
710
+ """
711
+
712
def __init__(self, config: ActiveLearningConfig):
    """Set up state, build the query strategy and helper components, and start
    the background training thread when active learning is enabled."""
    self.config = config
    self.logger = logging.getLogger(__name__)

    # --- concurrency primitives -------------------------------------------
    self._lock = threading.RLock()
    self._training_queue = queue.Queue()
    self._training_thread = None
    self._stop_training = threading.Event()

    # --- counters ----------------------------------------------------------
    self._last_training_time = 0
    self._training_count = 0
    self._last_annotation_count = 0

    # --- per-schema caches, all keyed by schema name -----------------------
    self._models = {}            # trained model
    self._vectorizers = {}       # fitted vectorizer
    self._bald_ensembles = {}    # list of classifiers
    self._annotated_texts = {}   # list of annotated texts
    self._training_metrics = []  # List of TrainingMetrics

    # --- query strategy ----------------------------------------------------
    self._query_strategy = create_query_strategy(config)

    # --- optional collaborators, filled in by _initialize_components() -----
    self.database_manager = None
    self.model_persistence = None
    self.schema_cycler = None
    self._initialize_components()

    # The worker thread only runs when active learning is switched on
    if self.config.enabled:
        self._start_training_thread()
746
+
747
def _initialize_components(self):
    """Create the database manager, model persistence, and schema cycler.

    Database and model persistence are optional: a failure is logged and the
    manager carries on without them. A schema cycler failure is re-raised.
    """
    cfg = self.config

    # Optional: database-backed state
    if cfg.database_enabled:
        try:
            self.database_manager = DatabaseStateManager(cfg.database_config)
        except Exception as e:
            # Continue without database persistence
            self.logger.error(f"Failed to initialize database manager: {e}")

    # Optional: on-disk model persistence (needs a target directory)
    if cfg.model_persistence_enabled and cfg.model_save_directory:
        try:
            self.model_persistence = ModelPersistence(
                cfg.model_save_directory,
                cfg.model_retention_count,
            )
        except Exception as e:
            # Continue without model persistence
            self.logger.error(f"Failed to initialize model persistence: {e}")

    # Required: schema cycler
    try:
        self.schema_cycler = SchemaCycler(cfg.schema_names, self.database_manager)
    except Exception as e:
        self.logger.error(f"Failed to initialize schema cycler: {e}")
        raise  # Schema cycler is critical
774
+
775
+ def _start_training_thread(self):
776
+ """Start the background training thread."""
777
+ if self._training_thread is None or not self._training_thread.is_alive():
778
+ self._training_thread = threading.Thread(target=self._training_worker, daemon=True)
779
+ self._training_thread.start()
780
+ self.logger.info("Active learning training thread started")
781
+
782
+ def _training_worker(self):
783
+ """Background worker for training classifiers."""
784
+ while not self._stop_training.is_set():
785
+ try:
786
+ # Wait for training request
787
+ training_request = self._training_queue.get(timeout=1.0)
788
+ if training_request is None: # Shutdown signal
789
+ break
790
+
791
+ self._perform_training()
792
+ self._training_queue.task_done()
793
+
794
+ except queue.Empty:
795
+ continue
796
+ except Exception as e:
797
+ self.logger.error(f"Error in training worker: {e}")
798
+
799
def _perform_training(self):
    """Perform the actual classifier training.

    Runs one full cycle for the schema currently selected by the schema
    cycler, holding ``self._lock`` throughout: collect training data, train,
    optionally persist the model and metrics, reorder unlabeled instances, and
    advance the cycler. The cycler is advanced only after a successful
    training run. All exceptions are logged and swallowed so a failed cycle
    does not take down the worker.
    """
    with self._lock:
        try:
            self.logger.info("Starting active learning classifier training")
            start_time = time.time()

            # Get current schema for training
            current_schema = self.schema_cycler.get_current_schema()
            if not current_schema:
                self.logger.warning("No schema available for training")
                return

            # Get current annotation state
            item_manager = get_item_state_manager()
            user_manager = get_user_state_manager()

            # Collect training data
            training_data = self._collect_training_data(item_manager, user_manager, current_schema)

            # The collector always returns a dict with its three keys, so the
            # dict itself is never falsy; "no data" means an empty texts list.
            if not training_data or not training_data.get("texts"):
                self.logger.warning(f"No training data available for schema {current_schema}")
                # If in cold-start phase, try LLM-based reordering
                if self.config.cold_start_strategy == "llm" and self.config.llm_enabled:
                    self._cold_start_reorder(item_manager)
                return

            # Train classifier
            model, metrics = self._train_classifier(training_data, current_schema)

            if model:
                self._models[current_schema] = model
                self._annotated_texts[current_schema] = training_data["texts"]

                # Save model if persistence is enabled
                if self.model_persistence:
                    try:
                        model_path = self.model_persistence.save_model(
                            model, current_schema, len(training_data["texts"])
                        )
                        metrics.model_file_path = model_path
                    except Exception as e:
                        self.logger.error(f"Failed to save model: {e}")

                # Save metrics to database if available
                if self.database_manager:
                    try:
                        self.database_manager.save_training_metrics(metrics)
                    except Exception as e:
                        self.logger.error(f"Failed to save metrics: {e}")

                # Reorder instances
                self._reorder_instances(item_manager, current_schema)

                # Advance to next schema
                self.schema_cycler.advance_schema()

                self._training_count += 1
                self._last_training_time = time.time()

                training_duration = time.time() - start_time
                self.logger.info(f"Active learning training completed for schema {current_schema} "
                                 f"(run #{self._training_count}, duration: {training_duration:.2f}s)")
            else:
                self.logger.warning(f"Failed to train model for schema {current_schema}")
                # Try cold-start if not enough data
                if (self.config.cold_start_strategy == "llm"
                        and self.config.llm_enabled
                        and len(training_data.get("texts", [])) < self.config.min_instances_for_training):
                    self._cold_start_reorder(item_manager)

        except Exception as e:
            self.logger.error(f"Error during training: {e}")
            # Continue without failing the entire system
873
+
874
+ def _collect_training_data(self, item_manager: ItemStateManager, user_manager, schema_name: str) -> Dict:
875
+ """Collect training data for a specific schema."""
876
+ training_data = {"texts": [], "labels": [], "instance_ids": []}
877
+
878
+ # Get all user states
879
+ user_states = user_manager.get_all_users()
880
+ self.logger.debug(f"Found {len(user_states)} user states")
881
+
882
+ # Collect annotations per instance
883
+ instance_annotations = defaultdict(list)
884
+
885
+ for user_state in user_states:
886
+ user_annotations = user_state.get_all_annotations()
887
+ self.logger.debug(f"User {user_state.user_id} has {len(user_annotations)} annotations")
888
+ for instance_id, annotations in user_annotations.items():
889
+ # Check if the schema exists in the labels section
890
+ if 'labels' in annotations:
891
+ labels_dict = annotations['labels']
892
+ # Handle Label objects as keys
893
+ for label_obj, value in labels_dict.items():
894
+ if hasattr(label_obj, 'get_schema') and label_obj.get_schema() == schema_name:
895
+ instance_annotations[instance_id].append({
896
+ "label": label_obj.get_name(),
897
+ "value": value,
898
+ "user": user_state.user_id
899
+ })
900
+
901
+ self.logger.debug(f"Collected annotations for {len(instance_annotations)} instances")
902
+
903
+ # Filter instances with sufficient annotations
904
+ for instance_id, annotations in instance_annotations.items():
905
+ if len(annotations) >= self.config.min_annotations_per_instance:
906
+ # Resolve multiple annotations
907
+ resolved_label = self._resolve_annotations(annotations)
908
+ if resolved_label:
909
+ item = item_manager.get_item(instance_id)
910
+ if item:
911
+ text = item.get_text()
912
+ training_data["texts"].append(text)
913
+ training_data["labels"].append(resolved_label)
914
+ training_data["instance_ids"].append(instance_id)
915
+
916
+ self.logger.debug(f"Training data collected: {len(training_data['texts'])} texts, {len(training_data['labels'])} labels")
917
+ return training_data
918
+
919
+ def _resolve_annotations(self, annotations: List[Dict]) -> Optional[str]:
920
+ """Resolve multiple annotations using the configured strategy."""
921
+ if not annotations:
922
+ return None
923
+
924
+ if self.config.resolution_strategy == ResolutionStrategy.MAJORITY_VOTE:
925
+ return self._majority_vote(annotations)
926
+ elif self.config.resolution_strategy == ResolutionStrategy.RANDOM:
927
+ return self._random_selection(annotations)
928
+ elif self.config.resolution_strategy == ResolutionStrategy.CONSENSUS:
929
+ return self._consensus_resolution(annotations)
930
+ else:
931
+ return self._majority_vote(annotations) # Default fallback
932
+
933
+ def _majority_vote(self, annotations: List[Dict]) -> str:
934
+ """Resolve annotations using majority vote with random tie-breaking."""
935
+ label_counts = Counter(ann["label"] for ann in annotations)
936
+ max_count = max(label_counts.values())
937
+ # Find all labels with the maximum count (handles ties)
938
+ tied_labels = [label for label, count in label_counts.items() if count == max_count]
939
+ # Break ties randomly
940
+ return random.choice(tied_labels)
941
+
942
+ def _random_selection(self, annotations: List[Dict]) -> str:
943
+ """Resolve annotations by random selection."""
944
+ return random.choice(annotations)["label"]
945
+
946
+ def _consensus_resolution(self, annotations: List[Dict]) -> Optional[str]:
947
+ """Resolve annotations by consensus (all must agree)."""
948
+ labels = [ann["label"] for ann in annotations]
949
+ if len(set(labels)) == 1:
950
+ return labels[0]
951
+ return None
952
+
953
def _train_classifier(self, training_data: Dict, schema_name: str) -> Tuple[Optional[Pipeline], TrainingMetrics]:
    """Train a classifier for a specific schema.

    Returns ``(model, metrics)``. ``model`` is None when there are too few
    instances, fewer than two distinct labels, or training raised; in those
    cases ``metrics`` carries the error message and an accuracy of 0.0.
    The reported accuracy is measured on the training texts themselves.
    """
    start_time = time.time()

    def failure(message):
        # Shared shape of the "no model" result
        return None, TrainingMetrics(
            schema_name=schema_name,
            training_time=time.time() - start_time,
            accuracy=0.0,
            instance_count=len(training_data["texts"]),
            timestamp=datetime.now(),
            error_message=message
        )

    # Guard 1: enough instances
    if len(training_data["texts"]) < self.config.min_instances_for_training:
        error_msg = f"Insufficient training data for schema {schema_name}: {len(training_data['texts'])} < {self.config.min_instances_for_training}"
        self.logger.warning(error_msg)
        return failure(error_msg)

    # Guard 2: at least two distinct labels
    unique_labels = set(training_data["labels"])
    if len(unique_labels) < 2:
        error_msg = f"Insufficient label diversity for schema {schema_name}: {len(unique_labels)} unique labels"
        self.logger.warning(error_msg)
        return failure(error_msg)

    try:
        # Build and fit the vectorizer + classifier pipeline
        classifier = self._create_classifier()
        vectorizer = self._create_vectorizer()
        pipeline = Pipeline([
            ("vectorizer", vectorizer),
            ("classifier", classifier)
        ])
        pipeline.fit(training_data["texts"], training_data["labels"])

        # Optional probability calibration; on failure keep the plain pipeline
        if self.config.calibrate_probabilities and hasattr(classifier, 'predict_proba'):
            num_samples = len(training_data["texts"])
            if num_samples >= 5:
                try:
                    from sklearn.calibration import CalibratedClassifierCV
                    cv_folds = min(3, num_samples // 2)
                    if cv_folds >= 2:
                        calibrated = CalibratedClassifierCV(
                            pipeline, cv=cv_folds, method='isotonic'
                        )
                        calibrated.fit(training_data["texts"], training_data["labels"])
                        pipeline = calibrated
                        self.logger.debug(f"Applied probability calibration with {cv_folds}-fold CV")
                except Exception as e:
                    self.logger.warning(f"Calibration failed, using uncalibrated model: {e}")

        # Keep the vectorizer separately for strategy use; a calibrated wrapper
        # has no named_steps, so fall back to the object built above
        if hasattr(pipeline, 'named_steps'):
            stored_vectorizer = pipeline.named_steps.get("vectorizer", vectorizer)
        else:
            stored_vectorizer = vectorizer
        self._vectorizers[schema_name] = stored_vectorizer

        # BALD needs its own bootstrap ensemble
        if self.config.query_strategy == "bald":
            self._train_bald_ensemble(training_data, schema_name)

        # Accuracy and confidence histogram on the training texts
        predictions = pipeline.predict(training_data["texts"])
        accuracy = accuracy_score(training_data["labels"], predictions)
        confidence_distribution = self._calculate_confidence_distribution(pipeline, training_data["texts"])

        training_time = time.time() - start_time

        metrics = TrainingMetrics(
            schema_name=schema_name,
            training_time=training_time,
            accuracy=accuracy,
            instance_count=len(training_data["texts"]),
            timestamp=datetime.now(),
            confidence_distribution=confidence_distribution
        )

        self.logger.info(f"Trained classifier for schema {schema_name} with {len(training_data['texts'])} instances, "
                         f"accuracy: {accuracy:.3f}, time: {training_time:.2f}s")

        return pipeline, metrics

    except Exception as e:
        error_msg = f"Error training classifier for schema {schema_name}: {e}"
        self.logger.error(error_msg)
        return failure(error_msg)
1053
+
1054
def _train_bald_ensemble(self, training_data: Dict, schema_name: str):
    """Train an ensemble of classifiers for BALD strategy.

    Each member is a fresh vectorizer + classifier pipeline fitted on a
    bootstrap resample (drawn with replacement) of the training data.
    Resamples containing a single class are skipped. The stored ensemble is
    replaced only when at least one member was trained.
    """
    bald_options = self.config.bald_params
    member_count = bald_options.get("n_estimators", 5)
    fraction = bald_options.get("bootstrap_fraction", 0.8)

    texts, labels = training_data["texts"], training_data["labels"]
    n_samples = len(texts)
    bootstrap_size = max(2, int(n_samples * fraction))

    members = []
    for _ in range(member_count):
        drawn = np.random.choice(n_samples, size=bootstrap_size, replace=True)
        boot_texts = [texts[j] for j in drawn]
        boot_labels = [labels[j] for j in drawn]

        # A usable resample needs at least 2 classes
        if len(set(boot_labels)) >= 2:
            clf = self._create_classifier()
            vec = self._create_vectorizer()
            member = Pipeline([("vectorizer", vec), ("classifier", clf)])
            member.fit(boot_texts, boot_labels)
            members.append(member)

    if not members:
        return

    self._bald_ensembles[schema_name] = members
    self.logger.info(f"Trained BALD ensemble with {len(members)} models for {schema_name}")
1084
+
1085
+ def _calculate_confidence_distribution(self, pipeline, texts: List[str]) -> Dict[str, float]:
1086
+ """Calculate confidence score distribution."""
1087
+ try:
1088
+ probas = pipeline.predict_proba(texts)
1089
+ max_confidences = np.max(probas, axis=1)
1090
+
1091
+ # Create histogram bins
1092
+ bins = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
1093
+ hist, _ = np.histogram(max_confidences, bins=bins)
1094
+
1095
+ # Convert to percentages
1096
+ total = len(max_confidences)
1097
+ distribution = {}
1098
+ for i, count in enumerate(hist):
1099
+ bin_label = f"{bins[i]:.1f}-{bins[i+1]:.1f}"
1100
+ distribution[bin_label] = (count / total) * 100 if total > 0 else 0
1101
+
1102
+ return distribution
1103
+ except Exception as e:
1104
+ self.logger.warning(f"Failed to calculate confidence distribution: {e}")
1105
+ return {}
1106
+
1107
+ def _create_classifier(self):
1108
+ """Create classifier instance based on configuration."""
1109
+ kwargs = dict(self.config.classifier_kwargs)
1110
+
1111
+ if self.config.classifier_name == "sklearn.linear_model.LogisticRegression":
1112
+ return LogisticRegression(**kwargs)
1113
+ elif self.config.classifier_name == "sklearn.ensemble.RandomForestClassifier":
1114
+ return RandomForestClassifier(**kwargs)
1115
+ elif self.config.classifier_name == "sklearn.svm.SVC":
1116
+ kwargs.setdefault("probability", True)
1117
+ return SVC(**kwargs)
1118
+ else:
1119
+ # Try to import dynamically
1120
+ try:
1121
+ module_name, class_name = self.config.classifier_name.rsplit('.', 1)
1122
+ module = __import__(module_name, fromlist=[class_name])
1123
+ classifier_class = getattr(module, class_name)
1124
+ return classifier_class(**kwargs)
1125
+ except Exception as e:
1126
+ self.logger.error(f"Failed to create classifier {self.config.classifier_name}: {e}")
1127
+ return LogisticRegression() # Fallback
1128
+
1129
+ def _create_vectorizer(self):
1130
+ """Create vectorizer instance based on configuration."""
1131
+ kwargs = dict(self.config.vectorizer_kwargs)
1132
+
1133
+ if self.config.vectorizer_name == "sklearn.feature_extraction.text.CountVectorizer":
1134
+ return CountVectorizer(**kwargs)
1135
+ elif self.config.vectorizer_name == "sklearn.feature_extraction.text.TfidfVectorizer":
1136
+ return TfidfVectorizer(**kwargs)
1137
+ elif self.config.vectorizer_name == "sentence-transformers":
1138
+ model_name = kwargs.pop("model_name", "all-MiniLM-L6-v2")
1139
+ return SentenceTransformerVectorizer(model_name=model_name)
1140
+ else:
1141
+ # Try to import dynamically
1142
+ try:
1143
+ module_name, class_name = self.config.vectorizer_name.rsplit('.', 1)
1144
+ module = __import__(module_name, fromlist=[class_name])
1145
+ vectorizer_class = getattr(module, class_name)
1146
+ return vectorizer_class(**kwargs)
1147
+ except Exception as e:
1148
+ self.logger.error(f"Failed to create vectorizer {self.config.vectorizer_name}: {e}")
1149
+ return TfidfVectorizer() # Fallback
1150
+
1151
def _reorder_instances(self, item_manager: ItemStateManager, schema_name: str):
    """Re-rank the unannotated instances for ``schema_name`` with the query strategy.

    Steps, in order:
      1. Bail out (with a warning) when no model is stored for the schema.
      2. Collect every instance that has no annotators and whose item lookup
         succeeds; optionally truncate to ``config.max_instances_to_reorder``.
      3. Ask the query strategy for ``(index, score)`` rankings. The BALD
         ensemble path is used only when the strategy is configured as
         ``"bald"``, an ensemble exists for the schema and the strategy object
         is a ``BaldStrategy``. When no vectorizer is stored on the non-BALD
         path, fall back to plain classifier confidence (least confident first).
      4. Optionally blend the rankings with ICL scores.
      5. Translate indices back to instance IDs and hand them to
         ``_apply_reordering``.
    """
    if schema_name not in self._models:
        self.logger.warning(f"No trained model available for schema {schema_name}")
        return

    # Gather the candidate pool: instances nobody has annotated yet.
    candidate_ids = []
    candidate_texts = []
    for instance_id in item_manager.get_instance_ids():
        if item_manager.get_annotators_for_item(instance_id):
            continue
        record = item_manager.get_item(instance_id)
        if not record:
            continue
        candidate_ids.append(instance_id)
        candidate_texts.append(record.get_text())

    if not candidate_texts:
        self.logger.info("No unlabeled instances to reorder")
        return

    # Cap the amount of work per reordering pass when a limit is configured.
    cap = self.config.max_instances_to_reorder
    if cap:
        candidate_ids = candidate_ids[:cap]
        candidate_texts = candidate_texts[:cap]

    model = self._models[schema_name]
    annotated = self._annotated_texts.get(schema_name, [])
    vectorizer = self._vectorizers.get(schema_name)

    use_bald = (
        self.config.query_strategy == "bald"
        and schema_name in self._bald_ensembles
        and isinstance(self._query_strategy, BaldStrategy)
    )

    if use_bald and vectorizer:
        rankings = self._query_strategy.rank_with_ensemble(
            candidate_texts, self._bald_ensembles[schema_name], vectorizer
        )
    elif use_bald:
        # No stored vectorizer: the model itself is passed in the vectorizer slot.
        rankings = self._query_strategy.rank(candidate_texts, model, model, annotated)
    elif vectorizer:
        rankings = self._query_strategy.rank(
            candidate_texts, model, vectorizer, annotated
        )
    else:
        # Fallback: order by raw classifier confidence, lowest first.
        scored = self._calculate_confidence_scores(
            candidate_ids, item_manager, schema_name
        )
        self._apply_reordering(sorted(scored, key=lambda pair: pair[1]), item_manager)
        return

    # ICL ensemble blending (Phase 5B)
    if self.config.use_icl_ensemble:
        rankings = self._blend_icl_scores(rankings, candidate_texts, schema_name)

    # Map ranked indices back to instance IDs, ignoring indices past the pool.
    pool_size = len(candidate_ids)
    ranked_pairs = []
    for idx, score in rankings:
        if idx < pool_size:
            ranked_pairs.append((candidate_ids[idx], score))

    self._apply_reordering(ranked_pairs, item_manager)
1222
+
1223
+ def _blend_icl_scores(self, rankings: List[Tuple[int, float]],
1224
+ texts: List[str], schema_name: str) -> List[Tuple[int, float]]:
1225
+ """Blend query strategy scores with ICL predictions."""
1226
+ try:
1227
+ from potato.ai.icl_labeler import get_icl_labeler
1228
+ icl_labeler = get_icl_labeler()
1229
+ if icl_labeler is None or not icl_labeler.has_enough_examples(schema_name):
1230
+ return rankings
1231
+
1232
+ # Determine interpolation weight based on annotation count
1233
+ params = self.config.icl_ensemble_params
1234
+ initial_w = params.get("initial_icl_weight", 0.7)
1235
+ final_w = params.get("final_icl_weight", 0.2)
1236
+ transition = params.get("transition_instances", 100)
1237
+
1238
+ annotated_count = len(self._annotated_texts.get(schema_name, []))
1239
+ progress = min(1.0, annotated_count / max(1, transition))
1240
+ icl_weight = initial_w + (final_w - initial_w) * progress
1241
+ strategy_weight = 1.0 - icl_weight
1242
+
1243
+ # Get ICL confidence for each text
1244
+ icl_scores = {}
1245
+ for idx, text in enumerate(texts):
1246
+ try:
1247
+ pred = icl_labeler.label_instance(
1248
+ instance_id=f"_al_blend_{idx}",
1249
+ schema_name=schema_name,
1250
+ instance_text=text,
1251
+ )
1252
+ if pred:
1253
+ # Lower confidence = higher priority (more uncertain)
1254
+ icl_scores[idx] = 1.0 - pred.confidence_score
1255
+ else:
1256
+ icl_scores[idx] = 0.5
1257
+ except Exception:
1258
+ icl_scores[idx] = 0.5
1259
+
1260
+ # Normalize strategy scores
1261
+ strategy_map = {idx: score for idx, score in rankings}
1262
+ s_vals = list(strategy_map.values())
1263
+ s_min, s_max = min(s_vals), max(s_vals)
1264
+ s_rng = s_max - s_min if s_max > s_min else 1.0
1265
+
1266
+ # Normalize ICL scores
1267
+ i_vals = list(icl_scores.values())
1268
+ i_min, i_max = min(i_vals), max(i_vals)
1269
+ i_rng = i_max - i_min if i_max > i_min else 1.0
1270
+
1271
+ blended = []
1272
+ for idx, score in rankings:
1273
+ norm_s = (score - s_min) / s_rng
1274
+ norm_i = (icl_scores.get(idx, 0.5) - i_min) / i_rng
1275
+ combined = strategy_weight * norm_s + icl_weight * norm_i
1276
+ blended.append((idx, combined))
1277
+
1278
+ blended.sort(key=lambda x: x[1], reverse=True)
1279
+ return blended
1280
+
1281
+ except ImportError:
1282
+ return rankings
1283
+ except Exception as e:
1284
+ self.logger.warning(f"ICL blending failed: {e}")
1285
+ return rankings
1286
+
1287
def _cold_start_reorder(self, item_manager: ItemStateManager):
    """LLM-based cold-start instance selection (Phase 3A).

    Based on the Bayer et al. (2024) ActiveLLM approach: before enough
    annotations exist to train a classifier, an LLM scores a random sample of
    unannotated instances and those it is only moderately confident about
    (confidence between 0.4 and 0.7 inclusive) are moved to the front.

    Resulting order passed to ``item_manager.reorder_instances``:
    moderate-confidence predictions, then the other predictions, then the
    unannotated instances that were not sampled (shuffled). Any failure is
    logged as a warning and leaves the ordering untouched.
    """
    try:
        from potato.ai.llm_active_learning import create_llm_active_learning

        llm = create_llm_active_learning(self.config.llm_config)

        # Pool of instances nobody has annotated yet.
        unannotated = []
        for iid in list(item_manager.get_instance_ids()):
            if not item_manager.get_annotators_for_item(iid):
                unannotated.append(iid)
        if not unannotated:
            return

        # Draw the batch the LLM will score.
        batch_size = min(self.config.cold_start_batch_size, len(unannotated))
        candidates = random.sample(unannotated, batch_size)

        instances = []
        for iid in candidates:
            item = item_manager.get_item(iid)
            if item:
                instances.append({"id": iid, "text": item.get_text()})
        if not instances:
            return

        # Get LLM predictions for the sampled batch.
        if self.schema_cycler:
            schema_name = self.schema_cycler.get_current_schema()
        else:
            schema_name = None
        predictions = llm.predict_instances(
            instances=instances,
            annotation_instructions="Rate your confidence in labeling this text.",
            schema_name=schema_name or "default",
            label_options=["positive", "negative", "neutral"],
        )

        # Split predictions into the decision-boundary band and everything else.
        moderate = []
        other = []
        for pred in predictions:
            bucket = moderate if 0.4 <= pred.confidence_score <= 0.7 else other
            bucket.append((pred.instance_id, pred.confidence_score))

        reordered = [iid for iid, _ in moderate]
        reordered += [iid for iid, _ in other]

        # Unsampled instances follow, in random order.
        sampled_set = set(candidates)
        remaining = [iid for iid in unannotated if iid not in sampled_set]
        random.shuffle(remaining)
        reordered.extend(remaining)

        item_manager.reorder_instances(reordered)
        self.logger.info(f"Cold-start LLM reordering: {len(moderate)} moderate-confidence, "
                         f"{len(other)} other, {len(remaining)} remaining")

    except Exception as e:
        self.logger.warning(f"Cold-start LLM reordering failed: {e}")
1355
+
1356
+ def _route_annotation(self, instance_id: str, instance_text: str,
1357
+ schema_name: str) -> Dict[str, Any]:
1358
+ """Noise-aware annotation routing (Phase 5D).
1359
+
1360
+ Based on Yuan et al. (2024) NoiseAL approach. Routes instances
1361
+ between LLM auto-labeling and human annotation based on LLM
1362
+ confidence levels.
1363
+
1364
+ Returns:
1365
+ Dict with 'route' ('human'|'auto'), optional 'suggestion',
1366
+ and optional 'auto_label'.
1367
+ """
1368
+ if not self.config.annotation_routing:
1369
+ return {"route": "human"}
1370
+
1371
+ thresholds = self.config.routing_thresholds
1372
+ auto_min = thresholds.get("auto_label_min_confidence", 0.9)
1373
+ suggest_below = thresholds.get("show_suggestion_below", 0.5)
1374
+
1375
+ try:
1376
+ from potato.ai.icl_labeler import get_icl_labeler
1377
+ icl_labeler = get_icl_labeler()
1378
+ if icl_labeler is None or not icl_labeler.has_enough_examples(schema_name):
1379
+ return {"route": "human"}
1380
+
1381
+ prediction = icl_labeler.label_instance(
1382
+ instance_id=instance_id,
1383
+ schema_name=schema_name,
1384
+ instance_text=instance_text,
1385
+ )
1386
+
1387
+ if prediction is None:
1388
+ return {"route": "human"}
1389
+
1390
+ confidence = prediction.confidence_score
1391
+
1392
+ if confidence >= auto_min:
1393
+ # High confidence: auto-label with periodic verification
1394
+ should_verify = random.random() < self.config.verification_sample_rate
1395
+ return {
1396
+ "route": "auto",
1397
+ "auto_label": prediction.predicted_label,
1398
+ "confidence": confidence,
1399
+ "needs_verification": should_verify,
1400
+ }
1401
+ elif confidence < suggest_below:
1402
+ # Low confidence: route to human with LLM suggestion
1403
+ return {
1404
+ "route": "human",
1405
+ "suggestion": prediction.predicted_label,
1406
+ "confidence": confidence,
1407
+ }
1408
+ else:
1409
+ # Medium confidence: route to human (most informative)
1410
+ return {"route": "human"}
1411
+
1412
+ except ImportError:
1413
+ return {"route": "human"}
1414
+ except Exception as e:
1415
+ self.logger.warning(f"Annotation routing failed for {instance_id}: {e}")
1416
+ return {"route": "human"}
1417
+
1418
def _calculate_confidence_scores(self, instance_ids: List[str], item_manager: ItemStateManager, schema_name: str) -> List[Tuple[str, float]]:
    """Score each instance by the classifier's top predicted probability.

    Instances whose item lookup fails are left out. When prediction raises,
    the instance is still included with a fixed score of 0.1 and a warning is
    logged.

    Returns:
        ``(instance_id, confidence)`` pairs in the order of ``instance_ids``.
    """
    classifier = self._models[schema_name]
    scored = []

    for instance_id in instance_ids:
        record = item_manager.get_item(instance_id)
        if record:
            content = record.get_text()
            try:
                # Highest class probability for this single text.
                top_probability = np.max(classifier.predict_proba([content])[0])
                scored.append((instance_id, top_probability))
            except Exception as e:
                self.logger.warning(f"Error predicting for instance {instance_id}: {e}")
                # Default to low confidence for failed predictions
                scored.append((instance_id, 0.1))

    return scored
1441
+
1442
+ def _apply_reordering(self, sorted_instances: List[Tuple[str, float]], item_manager: ItemStateManager):
1443
+ """Apply the new ordering to the item manager."""
1444
+ # Extract instance IDs in new order
1445
+ new_order = [instance_id for instance_id, _ in sorted_instances]
1446
+
1447
+ if not new_order:
1448
+ return
1449
+
1450
+ # Apply random sampling
1451
+ random_count = int(len(new_order) * self.config.random_sample_percent)
1452
+ if random_count > 0 and random_count <= len(new_order):
1453
+ random_instances = random.sample(new_order, random_count)
1454
+ else:
1455
+ random_instances = []
1456
+
1457
+ # Interleave active learning and random instances
1458
+ final_order = []
1459
+ al_idx = 0
1460
+ rand_idx = 0
1461
+
1462
+ while al_idx < len(new_order) or rand_idx < len(random_instances):
1463
+ if al_idx < len(new_order):
1464
+ final_order.append(new_order[al_idx])
1465
+ al_idx += 1
1466
+ if rand_idx < len(random_instances):
1467
+ final_order.append(random_instances[rand_idx])
1468
+ rand_idx += 1
1469
+
1470
+ # Update item manager ordering
1471
+ item_manager.reorder_instances(final_order)
1472
+ self.logger.info(f"Reordered {len(final_order)} instances")
1473
+
1474
def check_and_trigger_training(self):
    """Queue a training run once enough new annotations have accumulated.

    Does nothing when active learning is disabled. Otherwise, under the
    manager lock, totals the annotations of every user and compares the growth
    since the last queued run against ``config.update_frequency``. When the
    threshold is met, ``"train"`` is put on the training queue and the stored
    count is updated.
    """
    if not self.config.enabled:
        self.logger.debug("Active learning is disabled")
        return

    with self._lock:
        # Total annotations across all users.
        current_annotation_count = 0
        for user_state in get_user_state_manager().get_all_users():
            current_annotation_count += len(user_state.get_all_annotations())

        self.logger.debug(f"Current annotation count: {current_annotation_count}, last count: {self._last_annotation_count}, update_frequency: {self.config.update_frequency}")

        new_since_last = current_annotation_count - self._last_annotation_count
        if new_since_last < self.config.update_frequency:
            self.logger.debug("Not enough new annotations to trigger training")
            return

        self._training_queue.put("train")
        self._last_annotation_count = current_annotation_count
        self.logger.info(f"Queued active learning training (annotations: {current_annotation_count})")
1497
+
1498
def force_training(self):
    """Queue a training run right away, ignoring the annotation threshold.

    Intended for tests. No request is queued when active learning is disabled.
    """
    if self.config.enabled:
        self.logger.info("Forcing immediate active learning training")
        self._training_queue.put("train")
    else:
        self.logger.debug("Active learning is disabled")
1506
+
1507
def get_stats(self) -> Dict[str, Any]:
    """Return a snapshot of the manager's state and configuration flags.

    Collected under the manager lock. When a database manager is attached, a
    ``training_history`` list (each entry converted with ``asdict``) is added;
    if fetching it fails, a warning is logged and an empty list is used.
    """
    with self._lock:
        cfg = self.config
        cycler = self.schema_cycler

        if cycler:
            current_schema = cycler.get_current_schema()
            schema_order = cycler.get_schema_order()
        else:
            current_schema = None
            schema_order = []

        stats = {
            "enabled": cfg.enabled,
            "training_count": self._training_count,
            "last_training_time": self._last_training_time,
            "models_trained": list(self._models.keys()),
            "current_schema": current_schema,
            "schema_order": schema_order,
            "database_enabled": cfg.database_enabled,
            "model_persistence_enabled": cfg.model_persistence_enabled,
            "llm_enabled": cfg.llm_enabled,
            "query_strategy": cfg.query_strategy,
            "calibrate_probabilities": cfg.calibrate_probabilities,
            "cold_start_strategy": cfg.cold_start_strategy,
            "use_icl_ensemble": cfg.use_icl_ensemble,
            "annotation_routing": cfg.annotation_routing,
        }

        # Training metrics are only available with a database manager.
        if self.database_manager:
            try:
                history = []
                for metrics in self.database_manager.get_training_history():
                    history.append(asdict(metrics))
                stats["training_history"] = history
            except Exception as e:
                self.logger.warning(f"Failed to get training history: {e}")
                stats["training_history"] = []

        return stats
1538
+
1539
def shutdown(self):
    """Stop the background training worker.

    Sets the stop flag, and if the training thread is still alive, puts a
    ``None`` sentinel on the queue and waits up to five seconds for the thread
    to finish.
    """
    self._stop_training.set()

    worker = self._training_thread
    if worker and worker.is_alive():
        self._training_queue.put(None)  # Shutdown signal
        worker.join(timeout=5.0)

    self.logger.info("Active learning manager shutdown complete")
1546
+
1547
+
1548
+ # Global singleton instance
1549
+ ACTIVE_LEARNING_MANAGER = None
1550
+
1551
+
1552
def parse_active_learning_config(config_data: Dict[str, Any]) -> Optional[ActiveLearningConfig]:
    """Build an ``ActiveLearningConfig`` from a Potato project config dict.

    Returns None when the ``active_learning:`` section is missing or not
    enabled. Keys of that section that match dataclass fields are passed
    through and unknown keys are ignored. A few values are adapted on the way:

    * a nested ``llm:`` mapping fills ``llm_enabled`` / ``llm_config`` unless
      those are already given explicitly;
    * a list-valued ``vectorizer_params.ngram_range`` becomes a tuple;
    * a string ``resolution_strategy`` becomes the enum, and is dropped when
      it is not a valid enum value;
    * ``schema_names`` defaults to the named radio / multiselect / likert /
      select annotation schemes of the project.
    """
    section = (config_data or {}).get("active_learning", {}) or {}
    if not section.get("enabled"):
        return None

    known_fields = {f.name for f in dataclasses.fields(ActiveLearningConfig)}
    kwargs = {}
    for key, value in section.items():
        if key in known_fields:
            kwargs[key] = value

    # Honor the nested `active_learning.llm:` block (LLM cold-start / ICL).
    # The dataclass uses flat fields (llm_enabled / llm_config), so translate.
    llm_block = section.get("llm")
    if isinstance(llm_block, dict):
        llm_flag = bool(llm_block.get("enabled", False))
        if "llm_enabled" not in kwargs:
            kwargs["llm_enabled"] = llm_flag
        if "llm_config" not in kwargs:
            kwargs["llm_config"] = llm_block

    # YAML parses sequences as lists, but sklearn's vectorizers require a tuple
    # for ngram_range (e.g. (1, 2)). Coerce a copy so training doesn't fail.
    vec_params = kwargs.get("vectorizer_params")
    if isinstance(vec_params, dict):
        ngram = vec_params.get("ngram_range")
        if isinstance(ngram, list):
            kwargs["vectorizer_params"] = {**vec_params, "ngram_range": tuple(ngram)}

    # resolution_strategy may arrive as a string; coerce to the enum.
    strategy = kwargs.get("resolution_strategy")
    if isinstance(strategy, str):
        try:
            kwargs["resolution_strategy"] = ResolutionStrategy(strategy)
        except ValueError:
            kwargs.pop("resolution_strategy", None)

    # Default schema_names to the labelable schemes in the project.
    if not kwargs.get("schema_names"):
        labelable_types = ("radio", "multiselect", "likert", "select")
        names = []
        for scheme in config_data.get("annotation_schemes", []) or []:
            scheme_name = scheme.get("name")
            if scheme_name and scheme.get("annotation_type") in labelable_types:
                names.append(scheme_name)
        kwargs["schema_names"] = names

    return ActiveLearningConfig(**kwargs)
1601
+
1602
+
1603
def init_active_learning_manager(config: ActiveLearningConfig) -> ActiveLearningManager:
    """Create the global active learning manager on first use and return it.

    When a manager already exists it is returned as is and ``config`` is not
    applied.
    """
    global ACTIVE_LEARNING_MANAGER

    if ACTIVE_LEARNING_MANAGER is not None:
        return ACTIVE_LEARNING_MANAGER

    ACTIVE_LEARNING_MANAGER = ActiveLearningManager(config)
    return ACTIVE_LEARNING_MANAGER
1611
+
1612
+
1613
def get_active_learning_manager() -> Optional[ActiveLearningManager]:
    """Return the global active learning manager, or None if none is set."""
    manager = ACTIVE_LEARNING_MANAGER
    return manager
1616
+
1617
+
1618
def clear_active_learning_manager():
    """Shut down and drop the global active learning manager (for testing)."""
    global ACTIVE_LEARNING_MANAGER
    manager = ACTIVE_LEARNING_MANAGER
    if manager:
        manager.shutdown()
        ACTIVE_LEARNING_MANAGER = None
potato/adjudication.py ADDED
@@ -0,0 +1,1224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adjudication Module
3
+
4
+ This module provides a comprehensive adjudication system where designated users
5
+ review items with multiple annotations, resolve disagreements, and produce
6
+ gold-standard final decisions.
7
+
8
+ Adjudication is NOT a phase — it's a parallel workflow accessible via a dedicated
9
+ /adjudicate route, available to users with adjudicator privileges. This avoids
10
+ disrupting the existing phase progression system.
11
+
12
+ Key Components:
13
+ - AdjudicationConfig: Configuration dataclass for adjudication settings
14
+ - AdjudicationItem: Represents an item eligible for adjudication with all annotations
15
+ - AdjudicationDecision: Represents an adjudicator's final decision on an item
16
+ - AdjudicationManager: Singleton manager for the adjudication workflow
17
+
18
+ The workflow:
19
+ 1. Annotators complete annotations via /annotate (existing workflow)
20
+ 2. AdjudicationManager monitors annotation counts and agreement
21
+ 3. Items are flagged when criteria are met (min annotations, low agreement)
22
+ 4. Adjudicators review items via /adjudicate and submit decisions
23
+ 5. Final dataset CLI merges unanimous + adjudicated decisions
24
+ """
25
+
26
+ import json
27
+ import logging
28
+ import math
29
+ import os
30
+ import threading
31
+ from collections import Counter, defaultdict
32
+ from dataclasses import dataclass, field
33
+ from datetime import datetime
34
+ from typing import Dict, List, Optional, Any, Set
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Singleton instance
39
+ _ADJUDICATION_MANAGER = None
40
+ _ADJUDICATION_LOCK = threading.Lock()
41
+
42
+
43
def _default_error_taxonomy() -> List[str]:
    """Return a fresh copy of the built-in error taxonomy labels."""
    return [
        "ambiguous_text", "guideline_gap", "annotator_error",
        "edge_case", "subjective_disagreement", "other"
    ]


@dataclass
class AdjudicationConfig:
    """Settings that control the adjudication workflow."""

    # --- Access ---
    enabled: bool = False
    adjudicator_users: List[str] = field(default_factory=list)

    # --- Trigger criteria: which items are queued for adjudication ---
    min_annotations: int = 2
    require_fully_annotated: bool = False
    agreement_threshold: float = 0.75
    show_all_items: bool = False

    # --- Display options ---
    show_annotator_names: bool = True
    show_timing_data: bool = True
    show_agreement_scores: bool = True
    fast_decision_warning_ms: int = 2000

    # --- Adjudicator metadata fields ---
    require_confidence: bool = True
    require_notes_on_override: bool = False
    error_taxonomy: List[str] = field(default_factory=_default_error_taxonomy)

    # --- Similarity (Phase 3, optional) ---
    similarity_enabled: bool = False
    similarity_model: str = "all-MiniLM-L6-v2"
    similarity_top_k: int = 5
    similarity_precompute: bool = True

    # --- Output ---
    output_subdir: str = "adjudication"
77
+
78
+
79
@dataclass
class AdjudicationItem:
    """An item queued for adjudication, bundled with every annotator's data."""
    instance_id: str
    # user_id -> {schema: {label: value}}
    annotations: Dict[str, Dict[str, Any]]
    # user_id -> [span_dict, ...]
    span_annotations: Dict[str, List[Dict]]
    # user_id -> {total_time_ms, ...}
    behavioral_data: Dict[str, Dict]
    # schema_name -> agreement score
    agreement_scores: Dict[str, float]
    overall_agreement: float
    num_annotators: int
    # One of: pending, in_progress, completed, skipped
    status: str = "pending"
    assigned_adjudicator: Optional[str] = None
    # schema -> predicted label
    mace_predictions: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of this item for JSON output.

        ``mace_predictions`` is included only when it is non-empty.
        """
        always_present = (
            "instance_id",
            "annotations",
            "span_annotations",
            "behavioral_data",
            "agreement_scores",
            "overall_agreement",
            "num_annotators",
            "status",
            "assigned_adjudicator",
        )
        payload = {name: getattr(self, name) for name in always_present}
        if self.mace_predictions:
            payload["mace_predictions"] = self.mace_predictions
        return payload
109
+
110
+
111
@dataclass
class AdjudicationDecision:
    """The final decision an adjudicator recorded for one item."""
    instance_id: str
    adjudicator_id: str
    # ISO format string
    timestamp: str
    # schema -> value
    label_decisions: Dict[str, Any]
    # list of span dicts
    span_decisions: List[Dict]
    # schema -> "annotator_X" | "adjudicator" | "merged"
    source: Dict[str, str]
    # "high", "medium", "low"
    confidence: str
    notes: str
    error_taxonomy: List[str]
    guideline_update_flag: bool = False
    guideline_update_notes: str = ""
    time_spent_ms: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of this decision for JSON output."""
        ordered_names = (
            "instance_id",
            "adjudicator_id",
            "timestamp",
            "label_decisions",
            "span_decisions",
            "source",
            "confidence",
            "notes",
            "error_taxonomy",
            "guideline_update_flag",
            "guideline_update_notes",
            "time_spent_ms",
        )
        return {name: getattr(self, name) for name in ordered_names}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "AdjudicationDecision":
        """Rebuild a decision from a dict such as the one ``to_dict`` produces.

        ``instance_id``, ``adjudicator_id`` and ``timestamp`` are required (a
        missing one raises ``KeyError``); every other key falls back to a
        default.
        """
        values = {
            "instance_id": d["instance_id"],
            "adjudicator_id": d["adjudicator_id"],
            "timestamp": d["timestamp"],
        }
        # Fallbacks are built per call so decisions never share mutable defaults.
        fallbacks = (
            ("label_decisions", {}),
            ("span_decisions", []),
            ("source", {}),
            ("confidence", "medium"),
            ("notes", ""),
            ("error_taxonomy", []),
            ("guideline_update_flag", False),
            ("guideline_update_notes", ""),
            ("time_spent_ms", 0),
        )
        for key, fallback in fallbacks:
            values[key] = d.get(key, fallback)
        return cls(**values)
161
+
162
+
163
+ class AdjudicationManager:
164
+ """
165
+ Manages the adjudication workflow including queue building, agreement
166
+ computation, decision storage, and final dataset generation.
167
+
168
+ Follows the singleton pattern used by QualityControlManager.
169
+ """
170
+
171
def __init__(self, config: Dict[str, Any]):
    """
    Set up the adjudication manager.

    Parses the adjudication settings, creates an empty queue, loads
    previously saved decisions and, when similarity is enabled, initializes
    the similarity engine (precomputing similarities if configured).

    Args:
        config: The full application configuration dictionary
    """
    self.config = config
    self.logger = logging.getLogger(__name__)
    self._lock = threading.RLock()

    # Parsed adjudication settings
    self.adj_config = self._parse_config(config)

    # instance_id -> AdjudicationItem
    self.queue: Dict[str, AdjudicationItem] = {}
    # instance_id -> AdjudicationDecision
    self.decisions: Dict[str, AdjudicationDecision] = {}
    self._queue_built = False

    # Restore decisions saved by an earlier run
    self._load_decisions()

    # Similarity engine (Phase 3) is created only when enabled
    self.similarity_engine = None
    if self.adj_config.similarity_enabled:
        from potato.similarity import init_similarity_engine
        engine = init_similarity_engine(config, self.adj_config)
        self.similarity_engine = engine
        wants_precompute = (
            engine and engine.enabled and self.adj_config.similarity_precompute
        )
        if wants_precompute:
            self._precompute_similarities()

    self.logger.info(
        f"AdjudicationManager initialized: enabled={self.adj_config.enabled}, "
        f"adjudicators={self.adj_config.adjudicator_users}"
    )
206
+
207
def _parse_config(self, config: Dict[str, Any]) -> AdjudicationConfig:
    """Parse adjudication configuration from the main config.

    Returns a default (disabled) ``AdjudicationConfig`` when the
    ``adjudication`` section is missing, empty or not enabled. Otherwise each
    known key overrides its default.

    Null values are tolerated: a key that is present but null (which is how
    an empty YAML key loads) is treated as absent for ``adjudicator_users``,
    ``error_taxonomy`` and ``similarity`` instead of storing ``None`` or
    crashing on ``.get``. A scalar string for ``adjudicator_users`` is wrapped
    in a list, because ``is_adjudicator`` tests membership with ``in`` and a
    bare string would turn that into a substring match.
    """
    adj = AdjudicationConfig()

    adj_config = config.get("adjudication", {})
    if not adj_config or not adj_config.get("enabled", False):
        return adj

    adj.enabled = True

    users = adj_config.get("adjudicator_users") or []
    if isinstance(users, str):
        # `adjudicator_users: alice` -> ["alice"], so "ali" is not accepted.
        users = [users]
    adj.adjudicator_users = users

    adj.min_annotations = adj_config.get("min_annotations", 2)
    adj.require_fully_annotated = adj_config.get("require_fully_annotated", False)
    adj.agreement_threshold = adj_config.get("agreement_threshold", 0.75)
    adj.show_all_items = adj_config.get("show_all_items", False)
    adj.show_annotator_names = adj_config.get("show_annotator_names", True)
    adj.show_timing_data = adj_config.get("show_timing_data", True)
    adj.show_agreement_scores = adj_config.get("show_agreement_scores", True)
    adj.fast_decision_warning_ms = adj_config.get("fast_decision_warning_ms", 2000)
    adj.require_confidence = adj_config.get("require_confidence", True)
    adj.require_notes_on_override = adj_config.get("require_notes_on_override", False)

    # An explicit (even empty) list replaces the default taxonomy; null keeps it.
    taxonomy = adj_config.get("error_taxonomy")
    if taxonomy is not None:
        adj.error_taxonomy = taxonomy

    # Similarity settings
    sim_config = adj_config.get("similarity") or {}
    if sim_config.get("enabled", False):
        adj.similarity_enabled = True
        adj.similarity_model = sim_config.get("model", "all-MiniLM-L6-v2")
        adj.similarity_top_k = sim_config.get("top_k", 5)
        adj.similarity_precompute = sim_config.get("precompute_on_start", True)

    adj.output_subdir = adj_config.get("output_subdir", "adjudication")

    return adj
242
+
243
def is_adjudicator(self, username: str) -> bool:
    """Return True when adjudication is enabled and ``username`` is listed as an adjudicator."""
    settings = self.adj_config
    return bool(settings.enabled) and username in settings.adjudicator_users
248
+
249
def build_queue(self) -> List[AdjudicationItem]:
    """
    Scan all user annotations and build the adjudication queue.

    An item is queued when, after excluding adjudicators, it has at least
    ``min_annotations`` annotators (and, if ``require_fully_annotated`` is
    set, has reached the item manager's per-item maximum), at least one
    annotator contributed label or span data, and its overall agreement is
    below ``agreement_threshold`` (the agreement filter is skipped when
    ``show_all_items`` is set).

    Items that already have a decision are not re-evaluated; if one is still
    in the queue it is marked ``completed``. For re-queued items the existing
    status and assigned adjudicator are kept.

    MACE predictions are attached when a MACE manager with results is
    available. MACE is optional, so any failure there is logged at debug
    level and the item is queued without predictions.

    Returns:
        List of AdjudicationItem objects
    """
    from potato.user_state_management import get_user_state_manager
    from potato.item_state_management import get_item_state_manager

    with self._lock:
        usm = get_user_state_manager()
        ism = get_item_state_manager()

        # Get all annotation schemes from config
        annotation_schemes = self.config.get("annotation_schemes", [])
        scheme_names = [s.get("name", "") for s in annotation_schemes]

        # Resolve the optional MACE manager once instead of once per item.
        mace_mgr = None
        try:
            from potato.mace_manager import get_mace_manager
            mace_mgr = get_mace_manager()
        except Exception as e:
            self.logger.debug(f"MACE manager unavailable; skipping MACE predictions: {e}")

        adjudicators = self.adj_config.adjudicator_users

        # Iterate over all items (only the IDs are needed here)
        for instance_id, _item in ism.instance_id_to_instance.items():
            instance_id_str = str(instance_id)

            # Skip if already decided
            if instance_id_str in self.decisions:
                if instance_id_str in self.queue:
                    # Mark as completed if decision exists
                    self.queue[instance_id_str].status = "completed"
                continue

            # Annotators for this item, with adjudicators filtered out
            annotators = {
                u for u in ism.instance_annotators.get(instance_id, set())
                if u not in adjudicators
            }

            if len(annotators) < self.adj_config.min_annotations:
                continue

            # Check if we require fully annotated items
            if self.adj_config.require_fully_annotated:
                max_per_item = ism.max_annotations_per_item
                if max_per_item > 0 and len(annotators) < max_per_item:
                    continue

            # Collect annotations from all annotators
            item_annotations = {}
            item_spans = {}
            item_behavioral = {}

            for user_id in annotators:
                user_state = usm.get_user_state(user_id)
                if not user_state:
                    continue

                # Get label annotations
                label_annots = user_state.instance_id_to_label_to_value.get(
                    instance_id_str, {}
                )
                if label_annots:
                    item_annotations[user_id] = self._serialize_labels(label_annots)

                # Get span annotations
                span_annots = user_state.instance_id_to_span_to_value.get(
                    instance_id_str, {}
                )
                if span_annots:
                    item_spans[user_id] = self._serialize_spans(span_annots)

                # Get behavioral data
                bd = user_state.instance_id_to_behavioral_data.get(
                    instance_id_str, {}
                )
                if bd:
                    item_behavioral[user_id] = self._serialize_behavioral(bd)

            if not item_annotations and not item_spans:
                continue

            # Compute agreement scores
            agreement_scores = self._compute_agreement(
                item_annotations, scheme_names
            )
            overall = self._compute_overall_agreement(agreement_scores)

            # Filter by agreement threshold
            if not self.adj_config.show_all_items:
                if overall >= self.adj_config.agreement_threshold:
                    continue

            # Preserve existing status if already in queue
            existing = self.queue.get(instance_id_str)
            status = existing.status if existing else "pending"
            assigned = existing.assigned_adjudicator if existing else None

            # Enrich with MACE predictions if available
            mace_preds = {}
            if mace_mgr:
                try:
                    if mace_mgr.results:
                        for sname in scheme_names:
                            pred = mace_mgr.get_prediction(instance_id_str, sname)
                            if pred is not None:
                                mace_preds[sname] = pred
                except Exception as e:
                    # Best effort: keep whatever was collected before the failure.
                    self.logger.debug(
                        f"MACE prediction lookup failed for {instance_id_str}: {e}"
                    )

            self.queue[instance_id_str] = AdjudicationItem(
                instance_id=instance_id_str,
                annotations=item_annotations,
                span_annotations=item_spans,
                behavioral_data=item_behavioral,
                agreement_scores=agreement_scores,
                overall_agreement=overall,
                num_annotators=len(annotators),
                status=status,
                assigned_adjudicator=assigned,
                mace_predictions=mace_preds,
            )

        self._queue_built = True
        return list(self.queue.values())
377
+
378
def try_enqueue_item(self, instance_id: str) -> bool:
    """
    Evaluate a single item and, if it qualifies, add it to the queue.

    Called when an overlap-sample item saturates so that low-agreement
    items show up in the adjudication queue without needing a full
    ``build_queue()`` rescan.

    An item qualifies when adjudication is enabled, no decision exists for
    it, the item state manager knows it, at least ``min_annotations``
    non-adjudicator users annotated it, at least one of them has label or
    span annotations and, unless ``show_all_items`` is set, its overall
    agreement is below ``agreement_threshold``.

    When the item is already queued, its status, assigned adjudicator and
    any MACE predictions on the existing entry are carried over so that a
    re-evaluation does not discard them.

    Args:
        instance_id: ID of the item to evaluate.

    Returns:
        True if the item ended up in the queue, False otherwise.
    """
    if not self.adj_config.enabled:
        return False

    from potato.user_state_management import get_user_state_manager
    from potato.item_state_management import get_item_state_manager

    usm = get_user_state_manager()
    ism = get_item_state_manager()
    if usm is None or ism is None:
        return False

    with self._lock:
        instance_id_str = str(instance_id)
        if instance_id_str in self.decisions:
            return False
        if ism.instance_id_to_instance.get(instance_id) is None:
            return False

        # Adjudicators never count as annotators of the item.
        annotators = {
            u for u in ism.instance_annotators.get(instance_id, set())
            if u not in self.adj_config.adjudicator_users
        }
        if len(annotators) < self.adj_config.min_annotations:
            return False

        scheme_names = [s.get("name", "") for s in self.config.get("annotation_schemes", [])]
        item_annotations: Dict[str, Any] = {}
        item_spans: Dict[str, Any] = {}
        item_behavioral: Dict[str, Any] = {}
        for user_id in annotators:
            ustate = usm.get_user_state(user_id)
            if not ustate:
                continue
            la = ustate.instance_id_to_label_to_value.get(instance_id_str, {})
            if la:
                item_annotations[user_id] = self._serialize_labels(la)
            sa = ustate.instance_id_to_span_to_value.get(instance_id_str, {})
            if sa:
                item_spans[user_id] = self._serialize_spans(sa)
            bd = ustate.instance_id_to_behavioral_data.get(instance_id_str, {})
            if bd:
                item_behavioral[user_id] = self._serialize_behavioral(bd)

        if not item_annotations and not item_spans:
            return False

        agreement_scores = self._compute_agreement(item_annotations, scheme_names)
        overall = self._compute_overall_agreement(agreement_scores)
        if not self.adj_config.show_all_items:
            if overall >= self.adj_config.agreement_threshold:
                return False

        existing = self.queue.get(instance_id_str)

        # Keep MACE predictions that an earlier build_queue() attached; the
        # field is only passed when there is something to preserve so the
        # AdjudicationItem default still applies to brand-new entries.
        carried_over = {}
        if existing is not None:
            previous_preds = getattr(existing, "mace_predictions", None)
            if previous_preds:
                carried_over["mace_predictions"] = previous_preds

        self.queue[instance_id_str] = AdjudicationItem(
            instance_id=instance_id_str,
            annotations=item_annotations,
            span_annotations=item_spans,
            behavioral_data=item_behavioral,
            agreement_scores=agreement_scores,
            overall_agreement=overall,
            num_annotators=len(annotators),
            status=existing.status if existing else "pending",
            assigned_adjudicator=existing.assigned_adjudicator if existing else None,
            **carried_over,
        )
        self.logger.info(
            "Auto-routed item %s into adjudication queue (overall agreement=%.3f, "
            "threshold=%.3f, annotators=%d)",
            instance_id_str, overall, self.adj_config.agreement_threshold, len(annotators),
        )
        return True
458
+
459
+ def _serialize_labels(self, label_data: Dict) -> Dict[str, Any]:
460
+ """Convert label annotation data to serializable dict."""
461
+ result = {}
462
+ for key, value in label_data.items():
463
+ # Key might be a Label object or a string
464
+ if hasattr(key, 'get_schema'):
465
+ schema = key.get_schema()
466
+ name = key.get_name()
467
+ if schema not in result:
468
+ result[schema] = {}
469
+ result[schema][name] = value
470
+ elif isinstance(key, str):
471
+ result[key] = value
472
+ else:
473
+ result[str(key)] = value
474
+ return result
475
+
476
+ def _serialize_spans(self, span_data: Dict) -> List[Dict]:
477
+ """Convert span annotation data to serializable list."""
478
+ spans = []
479
+ for key, value in span_data.items():
480
+ if hasattr(key, 'get_schema'):
481
+ spans.append({
482
+ "schema": key.get_schema(),
483
+ "name": key.get_name(),
484
+ "title": key.get_title() if hasattr(key, 'get_title') else "",
485
+ "start": key.get_start(),
486
+ "end": key.get_end(),
487
+ "id": key.get_id(),
488
+ "target_field": key.get_target_field() if hasattr(key, 'get_target_field') else None,
489
+ })
490
+ elif isinstance(value, dict):
491
+ spans.append(value)
492
+ return spans
493
+
494
+ def _serialize_behavioral(self, bd) -> Dict:
495
+ """Convert behavioral data to serializable dict."""
496
+ if hasattr(bd, 'to_dict'):
497
+ return bd.to_dict()
498
+ elif isinstance(bd, dict):
499
+ return bd
500
+ return {}
501
+
502
+ def _compute_agreement(
503
+ self, item_annotations: Dict[str, Dict], scheme_names: List[str]
504
+ ) -> Dict[str, float]:
505
+ """
506
+ Compute per-schema agreement for an item.
507
+
508
+ Uses simple percentage agreement (proportion of annotators who chose
509
+ the most common label). For more sophisticated metrics, simpledorff
510
+ can be used but requires multiple items.
511
+
512
+ Returns:
513
+ Dict mapping schema_name to agreement score (0.0 - 1.0)
514
+ """
515
+ agreement_scores = {}
516
+
517
+ for schema in scheme_names:
518
+ values = []
519
+ for user_id, user_annots in item_annotations.items():
520
+ if schema in user_annots:
521
+ val = user_annots[schema]
522
+ # Normalize to comparable form
523
+ if isinstance(val, dict):
524
+ # Radio stores {label: label} (value is the label string)
525
+ # and multiselect stores {label: value/true}. A label is
526
+ # "selected" when its value is present/truthy. The old
527
+ # filter (v is True / == "true" / == 1) dropped radio's
528
+ # string value, collapsing every annotator to an empty
529
+ # frozenset -> a spurious 1.0 agreement even on total
530
+ # disagreement.
531
+ falsey = (False, None, "", "false", "False", 0, "0")
532
+ selected = frozenset(
533
+ k for k, v in val.items() if v not in falsey
534
+ )
535
+ values.append(selected)
536
+ else:
537
+ values.append(val)
538
+
539
+ if len(values) < 2:
540
+ continue
541
+
542
+ # Compute pairwise agreement (percentage)
543
+ agree_count = 0
544
+ total_pairs = 0
545
+ for i in range(len(values)):
546
+ for j in range(i + 1, len(values)):
547
+ total_pairs += 1
548
+ if values[i] == values[j]:
549
+ agree_count += 1
550
+
551
+ agreement_scores[schema] = (
552
+ agree_count / total_pairs if total_pairs > 0 else 1.0
553
+ )
554
+
555
+ return agreement_scores
556
+
557
+ def _compute_overall_agreement(self, agreement_scores: Dict[str, float]) -> float:
558
+ """Compute overall agreement as the mean of per-schema scores."""
559
+ if not agreement_scores:
560
+ return 1.0
561
+ return sum(agreement_scores.values()) / len(agreement_scores)
562
+
563
def get_queue(
    self,
    adjudicator_id: Optional[str] = None,
    filter_status: Optional[str] = None,
) -> List[AdjudicationItem]:
    """
    Return the adjudication queue, optionally restricted to one status.

    The queue is built on first use. Results are ordered pending first,
    then in-progress, then everything else; within each group the items
    with the lowest overall agreement come first.

    Args:
        adjudicator_id: Accepted by the signature but not used by this
            implementation.
        filter_status: Optional status filter ("pending", "completed", etc.)

    Returns:
        List of AdjudicationItem objects
    """
    status_rank = {"pending": 0, "in_progress": 1}

    with self._lock:
        if not self._queue_built:
            self.build_queue()

        if filter_status:
            selected = [
                entry for entry in self.queue.values()
                if entry.status == filter_status
            ]
        else:
            selected = list(self.queue.values())

        return sorted(
            selected,
            key=lambda entry: (
                status_rank.get(entry.status, 2),
                entry.overall_agreement,
            ),
        )
594
+
595
def get_item(self, instance_id: str) -> Optional[AdjudicationItem]:
    """
    Look up one queued item, building the queue first if needed.

    Args:
        instance_id: The instance ID to retrieve (stringified for lookup)

    Returns:
        AdjudicationItem or None if not in queue
    """
    queue_key = str(instance_id)
    with self._lock:
        if not self._queue_built:
            self.build_queue()
        return self.queue.get(queue_key)
609
+
610
def get_item_text(self, instance_id: str) -> str:
    """
    Get the text content for an item.

    The configured ``item_properties.text_key`` (default ``"text"``) is
    read from the item's data dict when present; otherwise the item's own
    ``get_text()`` is used.

    Returns:
        The item text, or "" when the item state manager is not
        initialised or the item is unknown.
    """
    from potato.item_state_management import get_item_state_manager

    ism = get_item_state_manager()
    # The manager can be absent (try_enqueue_item guards the same case);
    # treat that like an unknown item instead of raising AttributeError.
    if ism is None:
        return ""

    item = ism.instance_id_to_instance.get(instance_id)
    if not item:
        return ""

    # Use text_key from config if available
    text_key = self.config.get("item_properties", {}).get("text_key", "text")
    data = item.get_data()
    if isinstance(data, dict) and text_key in data:
        return data[text_key]
    return item.get_text()
624
+
625
def get_item_data(self, instance_id: str) -> Dict[str, Any]:
    """
    Get the full raw data for an item.

    Returns:
        The item's data dict; ``{"text": str(data)}`` when the data is not
        a dict; ``{}`` when the item state manager is not initialised or
        the item is unknown.
    """
    from potato.item_state_management import get_item_state_manager

    ism = get_item_state_manager()
    # Missing manager is handled like a missing item rather than crashing.
    if ism is None:
        return {}

    item = ism.instance_id_to_instance.get(instance_id)
    if not item:
        return {}

    data = item.get_data()
    if isinstance(data, dict):
        return data
    return {"text": str(data)}
637
+
638
def get_next_item(self, adjudicator_id: str) -> Optional[AdjudicationItem]:
    """
    Return the first pending item from the sorted queue, or None.

    ``adjudicator_id`` is accepted but does not influence the selection.
    """
    pending = self.get_queue(filter_status="pending")
    return pending[0] if pending else None
644
+
645
def skip_item(self, instance_id: str, adjudicator_id: str) -> bool:
    """
    Set a queued item's status to "skipped".

    Returns:
        True when the item was found in the queue, False otherwise.
    """
    with self._lock:
        entry = self.queue.get(str(instance_id))
        if not entry:
            return False
        entry.status = "skipped"
        return True
653
+
654
def submit_decision(self, decision: AdjudicationDecision) -> bool:
    """
    Record an adjudication decision and persist all decisions.

    The decision is stored under the stringified instance ID. If the item
    is queued, it is marked "completed" and assigned to the deciding
    adjudicator.

    Args:
        decision: The AdjudicationDecision to save

    Returns:
        True if successful
    """
    with self._lock:
        key = str(decision.instance_id)
        self.decisions[key] = decision

        # Reflect the decision on the queue entry, when there is one.
        if key in self.queue:
            queued = self.queue[key]
            queued.status = "completed"
            queued.assigned_adjudicator = decision.adjudicator_id

        # Write the full decision set to disk.
        self._save_decisions()

        message = f"Adjudication decision saved for {key} by {decision.adjudicator_id}"
        self.logger.info(message)
        return True
681
+
682
def get_stats(self) -> Dict[str, Any]:
    """
    Summarise adjudication progress.

    Returns a dict with the queue size, per-status counts, the completion
    rate, the mean overall agreement of queued items and, per adjudicator,
    the number of decisions and the summed ``time_spent_ms``.
    """
    with self._lock:
        if not self._queue_built:
            self.build_queue()

        entries = list(self.queue.values())
        total = len(entries)

        # One pass over the queue instead of one pass per status.
        tally = {"completed": 0, "pending": 0, "skipped": 0, "in_progress": 0}
        for entry in entries:
            for status_name in tally:
                if entry.status == status_name:
                    tally[status_name] += 1

        if entries:
            avg_agreement = sum(e.overall_agreement for e in entries) / len(entries)
        else:
            avg_agreement = 0.0

        # Per-adjudicator stats
        per_adjudicator = {}
        for decision in self.decisions.values():
            bucket = per_adjudicator.setdefault(
                decision.adjudicator_id, {"completed": 0, "total_time_ms": 0}
            )
            bucket["completed"] += 1
            bucket["total_time_ms"] += decision.time_spent_ms

        return {
            "total": total,
            "completed": tally["completed"],
            "pending": tally["pending"],
            "skipped": tally["skipped"],
            "in_progress": tally["in_progress"],
            "completion_rate": tally["completed"] / total if total > 0 else 0.0,
            "avg_agreement": avg_agreement,
            "adjudicator_stats": per_adjudicator,
        }
725
+
726
def get_decision(self, instance_id: str) -> Optional[AdjudicationDecision]:
    """Return the stored decision for an item (looked up by stringified ID), or None."""
    decision_key = str(instance_id)
    return self.decisions.get(decision_key)
729
+
730
+ # ------------------------------------------------------------------
731
+ # Phase 3: Similarity integration
732
+ # ------------------------------------------------------------------
733
+
734
+ def _precompute_similarities(self) -> None:
735
+ """Precompute embeddings for all items in the item state manager."""
736
+ if not self.similarity_engine or not self.similarity_engine.enabled:
737
+ return
738
+
739
+ from potato.item_state_management import get_item_state_manager
740
+
741
+ try:
742
+ ism = get_item_state_manager()
743
+ item_texts = {}
744
+ for instance_id, item in ism.instance_id_to_instance.items():
745
+ text = self.get_item_text(str(instance_id))
746
+ if text:
747
+ item_texts[str(instance_id)] = text
748
+
749
+ if item_texts:
750
+ count = self.similarity_engine.precompute_embeddings(item_texts)
751
+ self.logger.info(f"Precomputed {count} similarity embeddings")
752
+ except Exception as e:
753
+ self.logger.error(f"Error precomputing similarities: {e}")
754
+
755
def get_similar_items(
    self, instance_id: str, include_metadata: bool = True
) -> List[Dict[str, Any]]:
    """
    List the items the similarity engine reports as similar to an instance.

    Args:
        instance_id: The reference instance ID
        include_metadata: When true, each entry also carries queue
            membership, status, overall agreement, whether a decision
            exists and, for undecided queued items, the consensus label

    Returns:
        List of dicts with instance_id, similarity (rounded to 4 places),
        text_preview and the optional metadata; empty when the engine is
        missing or disabled
    """
    engine = self.similarity_engine
    if not engine or not engine.enabled:
        return []

    entries = []
    for other_id, score in engine.find_similar(instance_id):
        entry = {
            "instance_id": other_id,
            "similarity": round(score, 4),
            "text_preview": engine.text_cache.get(other_id, ""),
        }

        if include_metadata:
            queued = self.queue.get(other_id)
            decided = self.decisions.get(other_id)

            entry["in_queue"] = queued is not None
            entry["status"] = queued.status if queued else None
            entry["overall_agreement"] = queued.overall_agreement if queued else None

            # A decided item reports "completed" and no consensus label;
            # otherwise the consensus is derived from the queued annotations.
            entry["decision"] = "completed" if decided else None
            if decided or not queued:
                entry["consensus_label"] = None
            else:
                entry["consensus_label"] = self._get_consensus_label(queued)

        entries.append(entry)

    return entries
808
+
809
+ def _get_consensus_label(self, item: AdjudicationItem) -> Optional[str]:
810
+ """
811
+ Get the majority/consensus label for an item across the first schema.
812
+
813
+ Args:
814
+ item: The AdjudicationItem
815
+
816
+ Returns:
817
+ The most common label value as a string, or None
818
+ """
819
+ if not item.annotations:
820
+ return None
821
+
822
+ # Use the first schema that has values
823
+ for user_annots in item.annotations.values():
824
+ for schema_name in user_annots:
825
+ # Collect all values for this schema
826
+ values = []
827
+ for ua in item.annotations.values():
828
+ val = ua.get(schema_name)
829
+ if val is not None:
830
+ if isinstance(val, dict):
831
+ # Multiselect: use frozenset representation
832
+ selected = sorted(
833
+ k for k, v in val.items()
834
+ if v is True or v == "true" or v == 1
835
+ )
836
+ values.append(", ".join(selected) if selected else str(val))
837
+ else:
838
+ values.append(str(val))
839
+
840
+ if values:
841
+ counter = Counter(values)
842
+ return counter.most_common(1)[0][0]
843
+
844
+ return None
845
+
846
+ # ------------------------------------------------------------------
847
+ # Phase 3: Behavioral signal analysis
848
+ # ------------------------------------------------------------------
849
+
850
def get_annotator_signals(
    self, user_id: str, instance_id: str
) -> Dict[str, Any]:
    """
    Compute per-annotator quality signals for a specific item.

    Signals checked, in order:
      1. speed z-score of this item's time against the user's times on
         queued items (needs at least 3 recorded times);
      2. time below ``fast_decision_warning_ms``;
      3. more than 5 annotation changes;
      4. historical agreement rate with the consensus below 0.4;
      5. differing labels on similar items (only with an enabled
         similarity engine).

    A missing or ``None`` ``total_time_ms`` is treated as 0, i.e. "no
    timing recorded", so the timing checks are skipped rather than failing.

    Returns:
        Dict with user_id, instance_id, flags list, and metrics dict. Both
        are empty when the item is not in the queue.
    """
    flags = []
    metrics = {}

    instance_id = str(instance_id)
    item = self.queue.get(instance_id)
    if not item:
        return {"user_id": user_id, "instance_id": instance_id,
                "flags": [], "metrics": {}}

    # Get behavioral data for this user on this item
    bd = item.behavioral_data.get(user_id, {})
    if hasattr(bd, 'to_dict'):
        bd = bd.to_dict()

    # `or 0` also covers a key that is present but None, which would
    # otherwise raise TypeError in the numeric comparisons below.
    total_time = bd.get("total_time_ms") or 0
    metrics["total_time_ms"] = total_time

    # 1. Speed z-score vs user's typical time
    user_times = self._get_user_times(user_id)
    if len(user_times) >= 3 and total_time > 0:
        mean_time = sum(user_times) / len(user_times)
        std_time = math.sqrt(
            sum((t - mean_time) ** 2 for t in user_times) / len(user_times)
        )
        if std_time > 0:
            z_score = (total_time - mean_time) / std_time
            metrics["speed_z_score"] = round(z_score, 2)
            if z_score < -2.0:
                flags.append({
                    "type": "unusually_fast",
                    "severity": "high",
                    "message": f"Annotation time ({total_time}ms) is {abs(z_score):.1f} std devs below average"
                })

    # 2. Fast decision warning
    fast_threshold = self.adj_config.fast_decision_warning_ms
    if fast_threshold > 0 and 0 < total_time < fast_threshold:
        flags.append({
            "type": "fast_decision",
            "severity": "medium",
            "message": f"Decision made in {total_time}ms (below {fast_threshold}ms threshold)"
        })

    # 3. Annotation change count (stored either as a list of changes or a count)
    raw_changes = bd.get("annotation_changes", [])
    change_count = len(raw_changes) if isinstance(raw_changes, list) else int(raw_changes or 0)
    metrics["annotation_changes"] = change_count
    if change_count > 5:
        flags.append({
            "type": "excessive_changes",
            "severity": "medium",
            "message": f"Made {change_count} annotation changes on this item"
        })

    # 4. Historical agreement rate with consensus
    agreement_rate = self._compute_user_agreement_rate(user_id)
    if agreement_rate is not None:
        metrics["agreement_rate"] = round(agreement_rate, 3)
        if agreement_rate < 0.4:
            flags.append({
                "type": "low_agreement",
                "severity": "high",
                "message": f"Agreement rate with consensus: {agreement_rate:.0%}"
            })

    # 5. Similar item consistency
    if self.similarity_engine and self.similarity_engine.enabled:
        inconsistencies = self._check_similar_item_consistency(
            user_id, instance_id
        )
        metrics["similar_item_inconsistencies"] = inconsistencies
        if inconsistencies > 0:
            flags.append({
                "type": "similar_item_inconsistency",
                "severity": "medium",
                "message": f"Different label on {inconsistencies} similar item(s)"
            })

    return {
        "user_id": user_id,
        "instance_id": instance_id,
        "flags": flags,
        "metrics": metrics,
    }
943
+
944
+ def _get_user_times(self, user_id: str) -> List[float]:
945
+ """Collect all annotation times for a user across queue items."""
946
+ times = []
947
+ for item in self.queue.values():
948
+ bd = item.behavioral_data.get(user_id, {})
949
+ if hasattr(bd, 'to_dict'):
950
+ bd = bd.to_dict()
951
+ t = bd.get("total_time_ms", 0)
952
+ if t > 0:
953
+ times.append(t)
954
+ return times
955
+
956
+ def _compute_user_agreement_rate(self, user_id: str) -> Optional[float]:
957
+ """
958
+ Compute how often a user agrees with the consensus across all items.
959
+
960
+ Returns:
961
+ Float 0-1 or None if insufficient data (needs >= 3 items)
962
+ """
963
+ agree_count = 0
964
+ total_count = 0
965
+
966
+ for item in self.queue.values():
967
+ if user_id not in item.annotations:
968
+ continue
969
+
970
+ consensus = self._get_consensus_label(item)
971
+ if consensus is None:
972
+ continue
973
+
974
+ user_annots = item.annotations[user_id]
975
+ # Check the first schema
976
+ for schema_name, val in user_annots.items():
977
+ if isinstance(val, dict):
978
+ selected = sorted(
979
+ k for k, v in val.items()
980
+ if v is True or v == "true" or v == 1
981
+ )
982
+ user_label = ", ".join(selected) if selected else str(val)
983
+ else:
984
+ user_label = str(val)
985
+
986
+ if user_label == consensus:
987
+ agree_count += 1
988
+ total_count += 1
989
+ break # Only check first schema
990
+
991
+ if total_count < 3:
992
+ return None
993
+
994
+ return agree_count / total_count
995
+
996
+ def _check_similar_item_consistency(
997
+ self, user_id: str, instance_id: str
998
+ ) -> int:
999
+ """
1000
+ Check if user's label on similar items (>0.8 similarity) is consistent.
1001
+
1002
+ Returns:
1003
+ Count of similar items where user's label differs
1004
+ """
1005
+ if not self.similarity_engine:
1006
+ return 0
1007
+
1008
+ similar = self.similarity_engine.find_similar(instance_id)
1009
+ if not similar:
1010
+ return 0
1011
+
1012
+ # Get user's label on the current item
1013
+ item = self.queue.get(instance_id)
1014
+ if not item or user_id not in item.annotations:
1015
+ return 0
1016
+
1017
+ user_annots = item.annotations[user_id]
1018
+ current_label = None
1019
+ current_schema = None
1020
+ for schema_name, val in user_annots.items():
1021
+ current_schema = schema_name
1022
+ if isinstance(val, dict):
1023
+ selected = sorted(
1024
+ k for k, v in val.items()
1025
+ if v is True or v == "true" or v == 1
1026
+ )
1027
+ current_label = ", ".join(selected) if selected else str(val)
1028
+ else:
1029
+ current_label = str(val)
1030
+ break
1031
+
1032
+ if current_label is None:
1033
+ return 0
1034
+
1035
+ inconsistencies = 0
1036
+ for other_id, score in similar:
1037
+ if score < 0.8:
1038
+ break # Results are sorted by score desc
1039
+
1040
+ other_item = self.queue.get(other_id)
1041
+ if not other_item or user_id not in other_item.annotations:
1042
+ continue
1043
+
1044
+ other_annots = other_item.annotations[user_id]
1045
+ other_val = other_annots.get(current_schema)
1046
+ if other_val is None:
1047
+ continue
1048
+
1049
+ if isinstance(other_val, dict):
1050
+ selected = sorted(
1051
+ k for k, v in other_val.items()
1052
+ if v is True or v == "true" or v == 1
1053
+ )
1054
+ other_label = ", ".join(selected) if selected else str(other_val)
1055
+ else:
1056
+ other_label = str(other_val)
1057
+
1058
+ if other_label != current_label:
1059
+ inconsistencies += 1
1060
+
1061
+ return inconsistencies
1062
+
1063
+ def _get_output_dir(self) -> str:
1064
+ """Get the adjudication output directory."""
1065
+ output_dir = self.config.get("output_annotation_dir", "annotation_output")
1066
+ adj_dir = os.path.join(output_dir, self.adj_config.output_subdir)
1067
+ os.makedirs(adj_dir, exist_ok=True)
1068
+ return adj_dir
1069
+
1070
+ def _save_decisions(self) -> None:
1071
+ """Persist all decisions to disk."""
1072
+ try:
1073
+ adj_dir = self._get_output_dir()
1074
+ decisions_file = os.path.join(adj_dir, "decisions.json")
1075
+
1076
+ data = {
1077
+ "decisions": [d.to_dict() for d in self.decisions.values()],
1078
+ "last_updated": datetime.now().isoformat(),
1079
+ }
1080
+
1081
+ with open(decisions_file, "w", encoding="utf-8") as f:
1082
+ json.dump(data, f, indent=2)
1083
+
1084
+ except Exception as e:
1085
+ self.logger.error(f"Failed to save adjudication decisions: {e}")
1086
+
1087
+ def _load_decisions(self) -> None:
1088
+ """Load previously saved decisions from disk."""
1089
+ try:
1090
+ output_dir = self.config.get("output_annotation_dir", "annotation_output")
1091
+ adj_dir = os.path.join(output_dir, self.adj_config.output_subdir)
1092
+ decisions_file = os.path.join(adj_dir, "decisions.json")
1093
+
1094
+ if not os.path.exists(decisions_file):
1095
+ return
1096
+
1097
+ with open(decisions_file, "r", encoding="utf-8") as f:
1098
+ data = json.load(f)
1099
+
1100
+ for d in data.get("decisions", []):
1101
+ decision = AdjudicationDecision.from_dict(d)
1102
+ self.decisions[decision.instance_id] = decision
1103
+
1104
+ self.logger.info(
1105
+ f"Loaded {len(self.decisions)} previous adjudication decisions"
1106
+ )
1107
+
1108
+ except Exception as e:
1109
+ self.logger.warning(f"Failed to load adjudication decisions: {e}")
1110
+
1111
def generate_final_dataset(self) -> List[Dict[str, Any]]:
    """
    Build the final dataset from adjudication decisions and unanimous items.

    For every item known to the item state manager:
      * with a decision, the decision's labels and spans are emitted with
        source "adjudicated";
      * otherwise, items with fewer than two non-adjudicator annotators or
        without any label annotations are left out;
      * the rest are emitted as "unanimous" when every schema answered by
        at least two annotators has identical answers (and at least one
        such schema exists), else as "unresolved" with empty labels.

    Returns:
        List of item dicts with final labels and provenance
    """
    from potato.user_state_management import get_user_state_manager
    from potato.item_state_management import get_item_state_manager

    usm = get_user_state_manager()
    ism = get_item_state_manager()

    scheme_names = [
        scheme.get("name", "")
        for scheme in self.config.get("annotation_schemes", [])
    ]
    dataset = []

    for instance_id, item in ism.instance_id_to_instance.items():
        key = str(instance_id)
        record = {
            "instance_id": key,
            "item_data": item.get_data() if hasattr(item, 'get_data') else {},
        }

        # An adjudication decision always wins.
        decision = self.decisions.get(key)
        if decision:
            record.update([
                ("labels", decision.label_decisions),
                ("spans", decision.span_decisions),
                ("source", "adjudicated"),
                ("adjudicator", decision.adjudicator_id),
                ("confidence", decision.confidence),
                ("provenance", decision.source),
            ])
            dataset.append(record)
            continue

        # Adjudicators are not counted as annotators.
        annotators = {
            user for user in ism.instance_annotators.get(instance_id, set())
            if user not in self.adj_config.adjudicator_users
        }
        if len(annotators) < 2:
            continue

        # Gather each annotator's serialized labels for this item.
        per_user = {}
        for user in annotators:
            state = usm.get_user_state(user)
            if not state:
                continue
            raw_labels = state.instance_id_to_label_to_value.get(key, {})
            if raw_labels:
                per_user[user] = self._serialize_labels(raw_labels)

        if not per_user:
            continue

        # Answers are compared through their canonical JSON form.
        agreed = {}
        all_agree = True
        for schema in scheme_names:
            encoded = [
                json.dumps(annots[schema], sort_keys=True)
                for annots in per_user.values()
                if schema in annots
            ]
            if len(encoded) < 2:
                continue
            if len(set(encoded)) == 1:
                agreed[schema] = json.loads(encoded[0])
            else:
                all_agree = False

        resolved = bool(all_agree and agreed)
        record["labels"] = agreed if resolved else {}
        record["source"] = "unanimous" if resolved else "unresolved"
        record["num_annotators"] = len(annotators)
        dataset.append(record)

    return dataset
1202
+
1203
+
1204
def init_adjudication_manager(config: Dict[str, Any]) -> Optional[AdjudicationManager]:
    """
    Create the module-wide AdjudicationManager on first call and return it.

    Later calls return the existing instance; their ``config`` is ignored.
    """
    global _ADJUDICATION_MANAGER

    with _ADJUDICATION_LOCK:
        manager = _ADJUDICATION_MANAGER
        if manager is None:
            manager = AdjudicationManager(config)
            _ADJUDICATION_MANAGER = manager

    return manager
1213
+
1214
+
1215
def get_adjudication_manager() -> Optional[AdjudicationManager]:
    """Return the module-wide AdjudicationManager, or None if none is set."""
    manager = _ADJUDICATION_MANAGER
    return manager
1218
+
1219
+
1220
def clear_adjudication_manager():
    """Drop the module-wide AdjudicationManager (for testing)."""
    global _ADJUDICATION_MANAGER

    _ADJUDICATION_LOCK.acquire()
    try:
        _ADJUDICATION_MANAGER = None
    finally:
        _ADJUDICATION_LOCK.release()
potato/adjudication_export.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adjudication Export CLI
3
+
4
+ Generate final datasets by merging unanimous agreements and adjudication decisions.
5
+
6
+ Usage:
7
+ python -m potato.adjudication_export --config config.yaml --output final_dataset.jsonl
8
+ python -m potato.adjudication_export --config config.yaml --output final.csv --format csv
9
+ python -m potato.adjudication_export --config config.yaml --output final.json --format json
10
+ """
11
+
12
+ import argparse
13
+ import csv
14
+ import json
15
+ import os
16
+ import sys
17
+ import logging
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def main():
    """
    Command-line entry point: export the adjudicated dataset.

    Loads the Potato config, item data and user annotations, builds the
    adjudication queue, generates the final dataset and writes it as
    JSONL, JSON or CSV. Exits with status 1 when the config cannot be
    loaded or adjudication is not enabled, and with status 0 (writing
    nothing) when CSV output is requested but there are no results.

    Output files are always written as UTF-8 so that non-ASCII label or
    item text does not depend on the platform's default encoding.
    """
    parser = argparse.ArgumentParser(
        description="Export adjudicated dataset from Potato annotation project"
    )
    parser.add_argument(
        "--config", required=True,
        help="Path to the Potato config YAML file"
    )
    parser.add_argument(
        "--output", required=True,
        help="Output file path"
    )
    parser.add_argument(
        "--format", choices=["jsonl", "json", "csv"], default="jsonl",
        help="Output format (default: jsonl)"
    )
    parser.add_argument(
        "--include-unresolved", action="store_true",
        help="Include items without adjudication or consensus"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Verbose output"
    )

    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Load config
    from potato.server_utils.config_module import init_config, config
    try:
        init_config(args.config)
    except Exception as e:
        print(f"Error loading config: {e}", file=sys.stderr)
        sys.exit(1)

    # Initialize state managers
    from potato.item_state_management import init_item_state_manager
    from potato.user_state_management import init_user_state_manager

    init_user_state_manager(config)
    init_item_state_manager(config)

    # Load data (this loads items and user annotations from disk)
    # We need a minimal load - just items and user states
    from potato.flask_server import load_instance_data, load_user_data
    load_instance_data(config)
    load_user_data(config)

    # Initialize adjudication manager
    from potato.adjudication import init_adjudication_manager
    adj_mgr = init_adjudication_manager(config)

    if not adj_mgr or not adj_mgr.adj_config.enabled:
        print("Adjudication is not enabled in this config.", file=sys.stderr)
        sys.exit(1)

    # Build queue to compute agreements
    adj_mgr.build_queue()

    # Generate final dataset
    results = adj_mgr.generate_final_dataset()

    # Filter unresolved if not requested
    if not args.include_unresolved:
        results = [r for r in results if r.get("source") != "unresolved"]

    # Write output
    output_path = args.output
    fmt = args.format

    if fmt == "jsonl":
        with open(output_path, "w", encoding="utf-8") as f:
            for item in results:
                f.write(json.dumps(item) + "\n")

    elif fmt == "json":
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

    elif fmt == "csv":
        if not results:
            print("No results to export.", file=sys.stderr)
            sys.exit(0)

        # Flatten for CSV: one column per schema plus provenance columns.
        fieldnames = set()
        flat_results = []
        for item in results:
            flat = {
                "instance_id": item["instance_id"],
                "source": item.get("source", ""),
            }
            # Flatten labels; dict-valued answers are stored as JSON text.
            labels = item.get("labels", {})
            for schema, value in labels.items():
                if isinstance(value, dict):
                    flat[schema] = json.dumps(value)
                else:
                    flat[schema] = value

            # Add provenance fields
            if "adjudicator" in item:
                flat["adjudicator"] = item["adjudicator"]
            if "confidence" in item:
                flat["confidence"] = item["confidence"]
            if "num_annotators" in item:
                flat["num_annotators"] = item["num_annotators"]

            fieldnames.update(flat.keys())
            flat_results.append(flat)

        # Sort fieldnames for consistent output
        fieldnames = sorted(fieldnames)

        # newline="" lets the csv module control line endings itself.
        with open(output_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
            writer.writeheader()
            writer.writerows(flat_results)

    # Summary
    total = len(results)
    unanimous = sum(1 for r in results if r.get("source") == "unanimous")
    adjudicated = sum(1 for r in results if r.get("source") == "adjudicated")
    unresolved = sum(1 for r in results if r.get("source") == "unresolved")

    print(f"\nExport complete: {output_path}")
    print(f" Total items: {total}")
    print(f" Unanimous: {unanimous}")
    print(f" Adjudicated: {adjudicated}")
    if args.include_unresolved:
        print(f" Unresolved: {unresolved}")
    print(f" Format: {fmt}")
160
+
161
+ if __name__ == "__main__":
162
+ main()
potato/admin.py ADDED
The diff for this file is too large to render. See raw diff
 
potato/agent_proxy/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Proxy Package
3
+
4
+ Provides agent proxy implementations for live agent interaction during annotation.
5
+ Proxies communicate with AI agent backends (echo, HTTP, OpenAI, coding agents) and return
6
+ responses to the annotation interface.
7
+
8
+ Usage:
9
+ from potato.agent_proxy import AgentProxyFactory
10
+
11
+ proxy = AgentProxyFactory.create(config)
12
+ context = proxy.start_session("Book a flight to Paris")
13
+ response = proxy.send_message("Hello", context)
14
+ """
15
+
16
+ from .base import AgentMessage, AgentResponse, BaseAgentProxy, AgentProxyFactory
17
+ from .session import (
18
+ AgentSession,
19
+ AgentSessionManager,
20
+ init_agent_session_manager,
21
+ get_agent_session_manager,
22
+ clear_agent_session_manager,
23
+ )
24
+ from .sandbox import SafetySandbox, SandboxViolation
25
+
26
+ # Import proxy implementations to trigger registration
27
+ from . import echo_proxy
28
+ from . import http_proxy
29
+ from . import openai_proxy
30
+ from . import coding_proxy # subprocess_coding + docker_coding
31
+
32
+ __all__ = [
33
+ "AgentMessage",
34
+ "AgentResponse",
35
+ "BaseAgentProxy",
36
+ "AgentProxyFactory",
37
+ "AgentSession",
38
+ "AgentSessionManager",
39
+ "init_agent_session_manager",
40
+ "get_agent_session_manager",
41
+ "clear_agent_session_manager",
42
+ "SafetySandbox",
43
+ "SandboxViolation",
44
+ ]
potato/agent_proxy/base.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Proxy Base Module
3
+
4
+ Provides the abstract base class and data structures for agent proxies,
5
+ plus a factory registry for creating proxy instances from configuration.
6
+
7
+ Agent proxies allow annotators to interact with AI agents live during
8
+ annotation tasks. Each proxy type (echo, http, openai) handles
9
+ communication with a specific kind of agent backend.
10
+ """
11
+
12
+ from abc import ABC, abstractmethod
13
+ from dataclasses import dataclass, field
14
+ from typing import Dict, Any, Optional, List
15
+ import logging
16
+ import time
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
@dataclass
class AgentMessage:
    """One turn in an agent conversation.

    Attributes:
        role: Speaker tag for the turn ("user", "agent", "system" or "error").
        content: Text of the turn.
        timestamp: Defaults to ``time.time()`` evaluated when the instance
            is created.
        metadata: Free-form extras; each instance gets its own empty dict
            by default.
    """

    role: str  # "user", "agent", "system", "error"
    content: str
    timestamp: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)
28
+
29
+
30
@dataclass
class AgentResponse:
    """What a proxy hands back after a message has been sent to the agent.

    Attributes:
        message: The reply (or error notice) as an ``AgentMessage``.
        done: Defaults to False; proxies set it to True to signal that the
            session has ended.
        error: Optional error string; ``None`` by default.
    """

    message: AgentMessage
    done: bool = False
    error: Optional[str] = None
36
+
37
+
38
class BaseAgentProxy(ABC):
    """Contract that every agent proxy implements.

    A concrete proxy talks to one kind of agent backend (echo for testing,
    HTTP for generic REST APIs, OpenAI for chat completions). Construction
    stores the config and immediately delegates to :meth:`_initialize`.
    """

    proxy_type: str = ""

    def __init__(self, config: dict):
        self.config = config
        self._initialize()

    @abstractmethod
    def _initialize(self):
        """Validate config and set up connections; invoked from ``__init__``."""

    @abstractmethod
    def start_session(self, task_description: str) -> dict:
        """Open a new interaction session.

        Args:
            task_description: The task the annotator should accomplish
                with the agent.

        Returns:
            A proxy-specific session context dict (stored in
            AgentSession.proxy_context).
        """

    @abstractmethod
    def send_message(self, message: str, session_context: dict) -> AgentResponse:
        """Deliver one user message and block until the agent replies.

        Args:
            message: The user's message text.
            session_context: The proxy-specific context returned by
                ``start_session``.

        Returns:
            An AgentResponse carrying the agent's reply.
        """

    def end_session(self, session_context: dict):
        """Release session resources; a no-op unless a subclass overrides it.

        Args:
            session_context: The proxy-specific context returned by
                ``start_session``.
        """
        return None
92
+
93
+
94
class AgentProxyFactory:
    """Factory registry for creating agent proxy instances."""

    # Maps a proxy type name (e.g. "echo") to the class that implements it.
    _proxies: Dict[str, type] = {}

    @classmethod
    def register(cls, proxy_type: str, proxy_class: type):
        """Register ``proxy_class`` under the name ``proxy_type``.

        A later registration under the same name replaces the earlier one.
        """
        cls._proxies[proxy_type] = proxy_class
        logger.debug(f"Registered agent proxy type: {proxy_type}")

    @classmethod
    def create(cls, config: dict) -> BaseAgentProxy:
        """
        Create an agent proxy from configuration.

        Args:
            config: The full config dict. Reads from config["agent_proxy"].

        Returns:
            Configured BaseAgentProxy instance. The proxy class receives
            only the ``agent_proxy`` sub-dict, not the full config.

        Raises:
            ValueError: If the ``agent_proxy`` section is not a mapping, or
                the proxy type is unknown or missing.
        """
        # `or {}` (rather than a .get default) so that an explicit
        # `agent_proxy: null` in YAML yields the documented ValueError
        # instead of an AttributeError on None.
        agent_config = (config or {}).get("agent_proxy") or {}
        if not isinstance(agent_config, dict):
            raise ValueError("agent_proxy must be a mapping")

        proxy_type = agent_config.get("type")
        if not proxy_type:
            raise ValueError("agent_proxy.type is required")

        if proxy_type not in cls._proxies:
            supported = ", ".join(sorted(cls._proxies.keys()))
            raise ValueError(
                f"Unknown agent proxy type: '{proxy_type}'. "
                f"Supported types: {supported}"
            )

        proxy_class = cls._proxies[proxy_type]
        return proxy_class(agent_config)

    @classmethod
    def get_supported_types(cls) -> List[str]:
        """Get the sorted list of registered proxy type names."""
        return sorted(cls._proxies.keys())
potato/agent_proxy/coding_proxy.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Coding-agent proxies — LLM plans + sandboxed code execution.
3
+
4
+ Two implementations behind a shared base class:
5
+
6
+ - :class:`SubprocessCodingAgentProxy` (default, ``type: subprocess_coding``)
7
+ runs each Python or shell action in a per-session temp workspace via
8
+ ``subprocess.run`` with a per-step timeout and an output cap. Light
9
+ isolation — suitable for trusted-input research workflows. The
10
+ workspace is sandboxed (separate cwd) but **not** a security boundary;
11
+ malicious code can still touch the host filesystem outside ``cwd``.
12
+
13
+ - :class:`DockerCodingAgentProxy` (``type: docker_coding``) runs each
14
+ action inside an ephemeral Docker container with ``--network=none``,
15
+ ``--memory``, ``--cpus``, ``--read-only`` and a writable workspace
16
+ bind-mounted at ``/work``. Real isolation — survives untrusted code.
17
+ Requires the ``docker`` Python package and a running Docker daemon.
18
+
19
+ Both inherit per-step / per-session / rate-limit enforcement from the
20
+ existing :mod:`potato.agent_proxy.sandbox` framework via the standard
21
+ ``send_message`` flow in ``routes.py:agent_chat_send``.
22
+
23
+ Configuration shape (both proxies):
24
+
25
+ agent_proxy:
26
+ type: subprocess_coding | docker_coding
27
+ llm:
28
+ endpoint_type: ollama
29
+ model: llama3.2:3b
30
+ base_url: http://localhost:11434
31
+ temperature: 0.2
32
+ max_tokens: 800
33
+ execution:
34
+ per_step_timeout: 8 # seconds
35
+ max_output_chars: 4000
36
+ starter_files: {} # {filename: contents} written into workspace
37
+ docker: # only for docker_coding
38
+ image: python:3.11-slim
39
+ memory: 512m
40
+ cpus: 1.0
41
+ network: none # "none" or "bridge"
42
+ sandbox: { max_steps: 20, ... } # standard agent-proxy sandbox knobs
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ import json
48
+ import logging
49
+ import os
50
+ import re
51
+ import shutil
52
+ import subprocess
53
+ import tempfile
54
+ from abc import abstractmethod
55
+ from dataclasses import dataclass
56
+ from typing import Any, Dict, List, Optional
57
+
58
+ from .base import AgentMessage, AgentProxyFactory, AgentResponse, BaseAgentProxy
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ _PLANNER_SYSTEM_PROMPT = (
64
+ "You are an autonomous coding agent. The user will give you a coding task. "
65
+ "Each turn, decide a SINGLE next action and respond with ONLY a JSON object "
66
+ "of the form: {\"thought\": str, \"action\": {\"type\": str, \"code\": str}}. "
67
+ "The 'type' must be one of:\n"
68
+ " - \"python\": run the contents of 'code' as a python script in the workspace\n"
69
+ " - \"shell\": run 'code' as a bash command in the workspace\n"
70
+ " - \"finish\": stop and return your final answer in 'code' (which is then shown "
71
+ "to the user as your conclusion -- no execution happens)\n"
72
+ "Keep each action small and focused. Use 'finish' as soon as the task is done."
73
+ )
74
+
75
+
76
+ @dataclass
77
+ class _ExecResult:
78
+ stdout: str
79
+ stderr: str
80
+ exit_code: Optional[int]
81
+ timed_out: bool = False
82
+ error: Optional[str] = None
83
+
84
+
85
class CodingAgentProxy(BaseAgentProxy):
    """Shared planner/executor scaffold for coding agents.

    Subclasses implement :meth:`_execute` to run an action in their
    chosen sandbox. Each call to :meth:`send_message`:

    1. Appends the user message to the running history.
    2. Asks an LLM (configurable endpoint) for a JSON ``{thought, action}``.
    3. Hands ``action`` to the subclass for execution.
    4. Returns a single reply combining thought + tool output.

    Planner output is model-generated and therefore treated as untrusted:
    anything that is not a JSON object of the expected shape is normalised
    (see :meth:`_unpack_plan`) instead of being dereferenced blindly.
    """

    def _initialize(self):
        """Read the ``llm`` and ``execution`` config sections into attributes."""
        llm_cfg = self.config.get("llm") or {}
        self.llm_endpoint_type = llm_cfg.get("endpoint_type", "ollama")
        self.llm_model = llm_cfg.get("model")
        self.llm_base_url = llm_cfg.get("base_url")
        self.llm_temperature = llm_cfg.get("temperature", 0.2)
        self.llm_max_tokens = llm_cfg.get("max_tokens", 800)
        # OpenAI-compatible servers (vLLM etc.) ignore the key but the SDK
        # requires a non-empty string. Ollama needs none. Without forwarding
        # this the planner silently failed with "planner_unavailable".
        self.llm_api_key = llm_cfg.get("api_key")
        # Last endpoint init / call error, surfaced to the user instead of
        # an opaque "planner unavailable" message.
        self._llm_error: Optional[str] = None

        execution_cfg = self.config.get("execution") or {}
        self.per_step_timeout = execution_cfg.get("per_step_timeout", 8)
        self.max_output_chars = execution_cfg.get("max_output_chars", 4000)
        self.starter_files: Dict[str, str] = execution_cfg.get("starter_files", {}) or {}

        self._llm = None  # lazy

    # ------------------------------------------------------------------
    # LLM lazy-init
    # ------------------------------------------------------------------

    def _get_llm(self):
        """Return the planner endpoint, creating it on first use.

        Returns:
            The endpoint object, or ``None`` if creation failed (the reason
            is recorded in ``self._llm_error``). A failed creation is
            retried on the next call because ``self._llm`` stays ``None``.
        """
        if self._llm is not None:
            return self._llm
        try:
            from potato.ai.ai_endpoint import AIEndpointFactory

            ai_cfg: Dict[str, Any] = {
                "model": self.llm_model,
                "max_tokens": self.llm_max_tokens,
                "temperature": self.llm_temperature,
            }
            if self.llm_base_url:
                ai_cfg["base_url"] = self.llm_base_url
            # Forward the key for OpenAI-compatible endpoints; vLLM ignores
            # its value but the OpenAI SDK rejects an empty one. Fall back to
            # env then a non-empty placeholder so local servers just work.
            ai_cfg["api_key"] = (
                self.llm_api_key
                or os.environ.get("OPENAI_API_KEY")
                or os.environ.get("ANTHROPIC_API_KEY")
                or "EMPTY"
            )

            self._llm = AIEndpointFactory.create_endpoint({
                "ai_support": {
                    "enabled": True,
                    "endpoint_type": self.llm_endpoint_type,
                    "ai_config": ai_cfg,
                }
            })
            self._llm_error = None
        except Exception as e:
            logger.warning("CodingAgentProxy: planner LLM init failed: %s", e)
            self._llm_error = f"{type(e).__name__}: {e}"
            self._llm = None
        return self._llm

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start_session(self, task_description: str) -> dict:
        """Create a temp workspace, seed starter files, and build the history.

        Args:
            task_description: Task text placed in the second system message.

        Returns:
            Session context with keys ``workspace``, ``history``, ``step``
            and ``finished``.
        """
        workspace = tempfile.mkdtemp(prefix="potato_coding_agent_")
        for filename, contents in self.starter_files.items():
            target = os.path.join(workspace, filename)
            os.makedirs(os.path.dirname(target) or workspace, exist_ok=True)
            # Explicit UTF-8 so the bytes on disk do not depend on the
            # server's locale.
            with open(target, "w", encoding="utf-8") as f:
                f.write(contents)
        history = [
            {"role": "system", "content": _PLANNER_SYSTEM_PROMPT},
            {
                "role": "system",
                "content": f"Workspace: {workspace}\nTask: {task_description}",
            },
        ]
        return {
            "workspace": workspace,
            "history": history,
            "step": 0,
            "finished": False,
        }

    def end_session(self, session_context: dict):
        """Delete the session workspace directory, ignoring removal errors."""
        workspace = session_context.get("workspace") if session_context else None
        if workspace and os.path.isdir(workspace):
            shutil.rmtree(workspace, ignore_errors=True)

    # ------------------------------------------------------------------
    # Per-turn flow
    # ------------------------------------------------------------------

    def send_message(self, message: str, session_context: dict) -> AgentResponse:
        """Run one plan-then-execute turn.

        Args:
            message: The user's message text.
            session_context: Context dict returned by :meth:`start_session`;
                its ``history``, ``step`` and ``finished`` entries are
                updated in place.

        Returns:
            AgentResponse with role "agent" for normal turns (``done=True``
            once the planner finishes or the session was already finished),
            or role "error" with ``error="planner_unavailable"`` when no
            plan could be obtained.
        """
        if session_context.get("finished"):
            return AgentResponse(
                message=AgentMessage(role="agent", content="(session already finished)"),
                done=True,
            )

        history: List[Dict[str, str]] = session_context.setdefault("history", [])
        history.append({"role": "user", "content": message})
        session_context["step"] = session_context.get("step", 0) + 1

        plan = self._plan_next_action(history)
        if plan is None:
            detail = self._llm_error or "no response from planner LLM"
            reply = (
                f"Planner LLM unavailable ({self.llm_endpoint_type}): "
                f"{detail}"
            )
            history.append({"role": "assistant", "content": reply})
            session_context["finished"] = True
            return AgentResponse(
                message=AgentMessage(role="error", content=reply), error="planner_unavailable",
            )

        thought, atype, code = self._unpack_plan(plan)

        if atype == "finish":
            reply = self._format_finish_reply(thought, code)
            history.append({"role": "assistant", "content": reply})
            session_context["finished"] = True
            return AgentResponse(
                message=AgentMessage(role="agent", content=reply), done=True,
            )

        if atype not in ("python", "shell"):
            reply = (
                f"{thought}\n\n[invalid action type {atype!r}; try 'python', "
                "'shell', or 'finish']"
            )
            history.append({"role": "assistant", "content": reply})
            return AgentResponse(message=AgentMessage(role="agent", content=reply))

        result = self._execute(atype, code, session_context)
        reply = self._format_exec_reply(thought, atype, code, result)
        history.append({"role": "assistant", "content": reply})
        return AgentResponse(message=AgentMessage(role="agent", content=reply))

    @staticmethod
    def _unpack_plan(plan: Any) -> tuple:
        """Normalise a planner result into ``(thought, action_type, code)``.

        The planner is an LLM, so any field may be missing or of the wrong
        type. Every element of the returned tuple is a ``str``; a malformed
        plan yields an empty action type, which the caller reports as an
        invalid action instead of raising.
        """
        if not isinstance(plan, dict):
            plan = {}

        thought = plan.get("thought") or ""
        if not isinstance(thought, str):
            thought = str(thought)

        action = plan.get("action")
        if not isinstance(action, dict):
            action = {}

        atype = str(action.get("type") or "").strip().lower()

        code = action.get("code") or ""
        if not isinstance(code, str):
            code = str(code)

        return thought, atype, code

    # ------------------------------------------------------------------
    # Planning
    # ------------------------------------------------------------------

    def _plan_next_action(self, history: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """Ask the planner LLM for the next action.

        Args:
            history: Conversation so far as ``{"role", "content"}`` dicts.

        Returns:
            A dict (ideally ``{"thought", "action"}``), or ``None`` when the
            endpoint is unavailable, the call raised, or the reply was
            empty. Text that does not contain a JSON object is wrapped as a
            ``finish`` action carrying the first 500 characters.
        """
        endpoint = self._get_llm()
        if endpoint is None:
            return None
        try:
            if hasattr(endpoint, "chat_query"):
                raw = endpoint.chat_query(history)
            else:
                flat = "\n".join(f'{m["role"]}: {m["content"]}' for m in history)
                raw = endpoint.query(flat + "\nassistant:", None)
        except Exception as e:
            logger.warning("Planner LLM call failed: %s", e)
            self._llm_error = f"{type(e).__name__}: {e}"
            return None

        if isinstance(raw, dict):
            return raw  # already parsed JSON
        text = str(raw or "").strip()
        if not text:
            return None

        # Try the whole reply first, then the outermost {...} span. Only a
        # JSON *object* is accepted: a bare list/string/number parses fine
        # but is not a plan, and returning it would crash the caller.
        candidates = [text]
        match = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if match:
            candidates.append(match.group(0))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # Last-ditch: treat the whole text as a finish reply.
        return {"thought": "", "action": {"type": "finish", "code": text[:500]}}

    # ------------------------------------------------------------------
    # Reply formatting
    # ------------------------------------------------------------------

    def _format_exec_reply(
        self, thought: str, atype: str, code: str, result: _ExecResult
    ) -> str:
        """Render thought, the fenced code, and execution output as one reply."""
        parts: List[str] = []
        if thought:
            parts.append(thought.strip())
        parts.append(f"```{atype}\n{code.strip()}\n```")
        out_block: List[str] = []
        if result.timed_out:
            out_block.append(f"[timeout after {self.per_step_timeout}s]")
        if result.error:
            out_block.append(f"[error: {result.error}]")
        if result.exit_code is not None:
            out_block.append(f"[exit={result.exit_code}]")
        if result.stdout:
            out_block.append("stdout:\n" + self._truncate(result.stdout))
        if result.stderr:
            out_block.append("stderr:\n" + self._truncate(result.stderr))
        if not out_block:
            out_block.append("(no output)")
        parts.append("\n".join(out_block))
        return "\n\n".join(parts)

    def _format_finish_reply(self, thought: str, final_text: str) -> str:
        """Render the final answer; falls back to "(done)" when both are empty."""
        parts = []
        if thought:
            parts.append(thought.strip())
        if final_text:
            parts.append(final_text.strip())
        if not parts:
            parts.append("(done)")
        return "\n\n".join(parts)

    def _truncate(self, text: str) -> str:
        """Cap ``text`` at ``max_output_chars`` and note how much was dropped."""
        if len(text) <= self.max_output_chars:
            return text
        cut = self.max_output_chars
        return text[:cut] + f"\n[...truncated {len(text) - cut} chars]"

    # ------------------------------------------------------------------
    # Sandbox-specific execution
    # ------------------------------------------------------------------

    @abstractmethod
    def _execute(
        self, action_type: str, code: str, session_context: dict
    ) -> _ExecResult:
        """Execute ``code`` (one of 'python' / 'shell') in the sandbox."""
+
333
+
334
class SubprocessCodingAgentProxy(CodingAgentProxy):
    """Local subprocess-based execution.

    NOT a security boundary -- the per-step timeout + tempdir cwd is the
    only protection. Use ``DockerCodingAgentProxy`` for untrusted input.
    """

    proxy_type = "subprocess_coding"

    def _execute(
        self, action_type: str, code: str, session_context: dict
    ) -> _ExecResult:
        """Run one action with ``subprocess.run`` inside the session workspace.

        Args:
            action_type: "python" writes ``code`` to ``_action.py`` and runs
                it; any other value runs ``code`` through ``bash -c``.
            code: Source text of the action.
            session_context: Must contain ``"workspace"`` (the cwd to use).

        Returns:
            ``_ExecResult`` with captured output. ``exit_code`` is ``None``
            and ``timed_out`` is True on timeout; ``error`` is set when a
            required file or executable was not found.
        """
        workspace = session_context["workspace"]
        env = self._build_env()
        try:
            if action_type == "python":
                script_path = os.path.join(workspace, "_action.py")
                # Python source is UTF-8 by default; do not let the locale
                # encoding decide how the script is written.
                with open(script_path, "w", encoding="utf-8") as f:
                    f.write(code)
                argv = ["python", script_path]
            else:  # shell
                argv = ["bash", "-c", code]
            proc = subprocess.run(
                argv,
                cwd=workspace,
                env=env,
                capture_output=True,
                text=True,
                timeout=self.per_step_timeout,
            )
        except subprocess.TimeoutExpired as e:
            return _ExecResult(
                stdout=self._as_text(e.stdout),
                stderr=self._as_text(e.stderr),
                exit_code=None,
                timed_out=True,
            )
        except FileNotFoundError as e:
            return _ExecResult(stdout="", stderr="", exit_code=None, error=str(e))
        return _ExecResult(
            stdout=proc.stdout or "",
            stderr=proc.stderr or "",
            exit_code=proc.returncode,
        )

    @staticmethod
    def _as_text(stream) -> str:
        """Coerce partial output attached to ``TimeoutExpired`` into ``str``.

        The attached output can be ``bytes`` even when ``text=True`` was
        requested, so it is decoded here (undecodable bytes are replaced)
        rather than thrown away.
        """
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return stream

    def _build_env(self) -> Dict[str, str]:
        """Return the host environment reduced to PATH, HOME, LANG and LC_ALL."""
        # Strip env down to a minimal set so subprocess code can't
        # accidentally exfiltrate the host's secrets via env vars.
        keep = {"PATH", "HOME", "LANG", "LC_ALL"}
        return {k: v for k, v in os.environ.items() if k in keep}
390
+
391
+
392
class DockerCodingAgentProxy(CodingAgentProxy):
    """Ephemeral-container execution. Real isolation; requires Docker."""

    proxy_type = "docker_coding"

    def _initialize(self):
        """Read the ``docker`` config section on top of the shared settings."""
        super()._initialize()
        docker_cfg = self.config.get("docker") or {}
        self.docker_image = docker_cfg.get("image", "python:3.11-slim")
        self.docker_memory = docker_cfg.get("memory", "512m")
        self.docker_cpus = str(docker_cfg.get("cpus", 1.0))
        self.docker_network = docker_cfg.get("network", "none")
        self._docker = None  # lazy
        # Sanity check: warn if docker CLI isn't on PATH
        if shutil.which("docker") is None:
            logger.warning(
                "DockerCodingAgentProxy: 'docker' CLI not found on PATH. "
                "Container execution will fail at runtime."
            )

    def _execute(
        self, action_type: str, code: str, session_context: dict
    ) -> _ExecResult:
        """Run one action inside a fresh ``docker run --rm`` container.

        Args:
            action_type: "python" runs ``/work/_action.py``; any other value
                runs ``/work/_action.sh`` with bash.
            code: Source text of the action, written into the workspace.
            session_context: Must contain ``"workspace"``; that directory is
                bind-mounted at ``/work``.

        Returns:
            ``_ExecResult`` with captured output. ``exit_code`` is ``None``
            and ``timed_out`` is True on timeout; ``error`` is set when the
            ``docker`` executable was not found.
        """
        import uuid  # local import: only needed to name the container

        workspace = session_context["workspace"]
        # Materialise the code as a file in the workspace so the container
        # can run it without inline injection through `-c`.
        if action_type == "python":
            target = os.path.join(workspace, "_action.py")
            with open(target, "w", encoding="utf-8") as f:
                f.write(code)
            container_cmd = ["python", "/work/_action.py"]
        else:  # shell
            target = os.path.join(workspace, "_action.sh")
            with open(target, "w", encoding="utf-8") as f:
                f.write(code)
            os.chmod(target, 0o755)
            container_cmd = ["bash", "/work/_action.sh"]

        # A known name lets us stop the container if the client times out.
        container_name = f"potato_coding_{uuid.uuid4().hex}"
        cmd = [
            "docker", "run", "--rm",
            "--name", container_name,
            f"--network={self.docker_network}",
            f"--memory={self.docker_memory}",
            f"--cpus={self.docker_cpus}",
            "--read-only",
            "--tmpfs", "/tmp:exec,size=64m",
            "-v", f"{workspace}:/work",
            "-w", "/work",
            self.docker_image,
        ] + container_cmd
        try:
            proc = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.per_step_timeout + 5,  # docker pull/start overhead
            )
        except subprocess.TimeoutExpired as e:
            # The timeout only kills the `docker run` client process; the
            # container itself would keep running, so stop it explicitly.
            self._kill_container(container_name)
            return _ExecResult(
                stdout=self._as_text(e.stdout),
                stderr=self._as_text(e.stderr),
                exit_code=None,
                timed_out=True,
            )
        except FileNotFoundError as e:
            return _ExecResult(stdout="", stderr="", exit_code=None, error=str(e))
        return _ExecResult(
            stdout=proc.stdout or "",
            stderr=proc.stderr or "",
            exit_code=proc.returncode,
        )

    @staticmethod
    def _kill_container(name: str) -> None:
        """Best-effort ``docker kill`` of a container left behind by a timeout."""
        try:
            subprocess.run(
                ["docker", "kill", name],
                capture_output=True,
                timeout=10,
            )
        except (OSError, subprocess.SubprocessError) as e:
            logger.warning(
                "DockerCodingAgentProxy: could not kill container %s: %s", name, e
            )

    @staticmethod
    def _as_text(stream) -> str:
        """Coerce partial output attached to ``TimeoutExpired`` into ``str``.

        The attached output can be ``bytes`` even when ``text=True`` was
        requested, so it is decoded here (undecodable bytes are replaced)
        rather than thrown away.
        """
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return stream
462
+
463
+
464
+ # Register both with the factory so configs can refer to them by name.
465
+ AgentProxyFactory.register("subprocess_coding", SubprocessCodingAgentProxy)
466
+ AgentProxyFactory.register("docker_coding", DockerCodingAgentProxy)
potato/agent_proxy/echo_proxy.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Echo Agent Proxy
3
+
4
+ A testing/demo proxy that returns responses from a configurable list.
5
+ Cycles through responses in order, wrapping around when exhausted.
6
+
7
+ Configuration:
8
+ agent_proxy:
9
+ type: echo
10
+ responses:
11
+ - "I understand your request."
12
+ - "Working on it now."
13
+ - "Here's what I found."
14
+ """
15
+
16
+ import logging
17
+
18
+ from .base import BaseAgentProxy, AgentMessage, AgentResponse, AgentProxyFactory
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class EchoProxy(BaseAgentProxy):
    """Test proxy that returns canned responses in order."""

    proxy_type = "echo"

    # Used when the config supplies no usable ``responses`` list.
    _DEFAULT_RESPONSES = ("I understand.", "Working on it.", "Done!")

    def _initialize(self):
        """Load the response list, falling back to the built-in defaults.

        An absent, null or empty ``responses`` entry uses the defaults
        (an empty list would otherwise make ``send_message`` divide by
        zero). A single string is treated as a one-item list rather than
        being cycled character by character.
        """
        configured = self.config.get("responses")
        if isinstance(configured, str):
            configured = [configured]
        self.responses = list(configured) if configured else list(self._DEFAULT_RESPONSES)

    def start_session(self, task_description: str) -> dict:
        """Return a context whose response cursor starts at index 0."""
        return {"response_index": 0, "task_description": task_description}

    def send_message(self, message: str, session_context: dict) -> AgentResponse:
        """Return the next canned response, wrapping around at the end.

        Args:
            message: The user's message text (not used to pick the reply).
            session_context: Context from ``start_session``; its
                ``response_index`` is advanced by one.

        Returns:
            AgentResponse with role "agent" and the selected text.
        """
        idx = session_context.get("response_index", 0)
        response_text = self.responses[idx % len(self.responses)]
        session_context["response_index"] = idx + 1

        return AgentResponse(
            message=AgentMessage(
                role="agent",
                content=response_text,
            )
        )

    def end_session(self, session_context: dict):
        """Nothing to clean up for the echo proxy."""
        pass
52
+
53
+
54
+ # Register with factory
55
+ AgentProxyFactory.register("echo", EchoProxy)
potato/agent_proxy/http_proxy.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generic HTTP Agent Proxy
3
+
4
+ POSTs to any REST endpoint with configurable field mapping.
5
+ Supports sending full conversation history and custom headers.
6
+
7
+ Configuration:
8
+ agent_proxy:
9
+ type: http
10
+ url: "http://localhost:8080/chat"
11
+ headers:
12
+ Authorization: "Bearer YOUR_KEY"
13
+ message_key: "message" # key in request body for user message
14
+ response_key: "response" # key in response JSON for agent reply
15
+ session_id_key: "session_id" # key in request/response for session tracking
16
+ send_history: false # whether to send full conversation history
17
+ history_key: "messages" # key for history array in request body
18
+ """
19
+
20
+ import logging
21
+ import uuid
22
+
23
+ import requests
24
+
25
+ from .base import BaseAgentProxy, AgentMessage, AgentResponse, AgentProxyFactory
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class GenericHTTPProxy(BaseAgentProxy):
    """Generic REST API proxy with configurable field mapping."""

    proxy_type = "http"

    def _initialize(self):
        """Read endpoint URL, headers and field-name mapping from config.

        Raises:
            ValueError: If ``url`` is missing or empty.
        """
        self.url = self.config.get("url")
        if not self.url:
            raise ValueError("http proxy requires 'url' in agent_proxy config")

        # `or {}` so explicit nulls in YAML behave like absent keys.
        self.headers = self.config.get("headers") or {}
        self.message_key = self.config.get("message_key", "message")
        self.response_key = self.config.get("response_key", "response")
        self.session_id_key = self.config.get("session_id_key", "session_id")
        self.send_history = self.config.get("send_history", False)
        self.history_key = self.config.get("history_key", "messages")
        self.timeout = (self.config.get("sandbox") or {}).get(
            "request_timeout_seconds", 60
        )

    def start_session(self, task_description: str) -> dict:
        """Return a context with a fresh UUID session id and empty history."""
        return {
            "session_id": str(uuid.uuid4()),
            "task_description": task_description,
            "history": [],
        }

    def send_message(self, message: str, session_context: dict) -> AgentResponse:
        """POST the message to the configured URL and return the reply.

        Args:
            message: The user's message text.
            session_context: Context from ``start_session``; must contain
                ``session_id``. Its ``history`` gains the user and agent
                turns after a successful exchange.

        Returns:
            AgentResponse with role "agent" on success. On timeout, HTTP or
            connection failure, or a non-JSON body, an AgentResponse with
            role "error" and ``error`` set; history is left untouched.
        """
        payload = {
            self.message_key: message,
            self.session_id_key: session_context["session_id"],
        }

        if self.send_history:
            payload[self.history_key] = session_context.get("history", [])

        try:
            resp = requests.post(
                self.url,
                json=payload,
                headers=self.headers,
                timeout=self.timeout,
            )
            resp.raise_for_status()
            data = resp.json()
        except requests.Timeout:
            return AgentResponse(
                message=AgentMessage(role="error", content="Agent request timed out."),
                error="timeout",
            )
        except requests.RequestException as e:
            logger.error(f"HTTP proxy request failed: {e}")
            return AgentResponse(
                message=AgentMessage(
                    role="error", content=f"Agent communication error: {e}"
                ),
                error=str(e),
            )
        except ValueError as e:
            # Body was not valid JSON and the installed `requests` raised a
            # plain ValueError rather than a RequestException subclass.
            logger.error(f"HTTP proxy returned a non-JSON body: {e}")
            return AgentResponse(
                message=AgentMessage(
                    role="error", content=f"Agent communication error: {e}"
                ),
                error=str(e),
            )

        # Check the payload type before indexing: a bare JSON string is
        # accepted as the reply itself, and only objects support .get().
        if isinstance(data, dict):
            response_text = data.get(self.response_key, "")
        elif isinstance(data, str):
            response_text = data
        else:
            logger.warning(
                "HTTP proxy: unexpected JSON payload type %s", type(data).__name__
            )
            response_text = ""

        # Update history
        history = session_context.setdefault("history", [])
        history.append({"role": "user", "content": message})
        history.append({"role": "agent", "content": response_text})

        return AgentResponse(
            message=AgentMessage(role="agent", content=str(response_text))
        )
105
+
106
+
107
+ # Register with factory
108
+ AgentProxyFactory.register("http", GenericHTTPProxy)
potato/agent_proxy/openai_proxy.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenAI Chat Completions Agent Proxy
3
+
4
+ Uses the OpenAI SDK to communicate with chat completion models.
5
+ Maintains conversation history in session context for multi-turn dialogue.
6
+
7
+ Configuration:
8
+ agent_proxy:
9
+ type: openai
10
+ api_key: "${OPENAI_API_KEY}" # or set OPENAI_API_KEY env var
11
+ model: "gpt-4o"
12
+ system_prompt: "You are a helpful travel agent."
13
+ temperature: 0.7
14
+ max_tokens: 1024
15
+ """
16
+
17
+ import logging
18
+ import os
19
+
20
+ from .base import BaseAgentProxy, AgentMessage, AgentResponse, AgentProxyFactory
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class OpenAIChatProxy(BaseAgentProxy):
    """OpenAI Chat Completions proxy."""

    proxy_type = "openai"

    def _initialize(self):
        """Resolve the API key, create the OpenAI client, read model settings.

        Key resolution order: the ``api_key`` config value (with
        ``${VAR}`` expanded from the environment), then ``OPENAI_API_KEY``.

        Raises:
            ValueError: If no API key can be resolved.
            ImportError: If the ``openai`` package is not installed.
        """
        # `or ""` so an explicit `api_key: null` does not crash .startswith.
        api_key = self.config.get("api_key") or ""
        if not isinstance(api_key, str):
            api_key = str(api_key)
        # Support environment variable references like ${OPENAI_API_KEY}
        if api_key.startswith("${") and api_key.endswith("}"):
            env_var = api_key[2:-1]
            api_key = os.environ.get(env_var, "")

        if not api_key:
            api_key = os.environ.get("OPENAI_API_KEY", "")

        if not api_key:
            raise ValueError(
                "OpenAI proxy requires api_key in config or OPENAI_API_KEY env var"
            )

        try:
            import openai
        except ImportError as e:
            raise ImportError(
                "openai package is required for the OpenAI proxy. "
                "Install with: pip install openai"
            ) from e
        self.client = openai.OpenAI(api_key=api_key)

        self.model = self.config.get("model", "gpt-4o")
        self.system_prompt = self.config.get("system_prompt", "")
        self.temperature = self.config.get("temperature", 0.7)
        self.max_tokens = self.config.get("max_tokens", 1024)
        self.timeout = (self.config.get("sandbox") or {}).get(
            "request_timeout_seconds", 60
        )

    def start_session(self, task_description: str) -> dict:
        """Build the initial message list for a session.

        Returns:
            ``{"messages": [...]}`` holding the optional configured system
            prompt followed by a system message stating the task.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        # Include task description as system context
        messages.append({
            "role": "system",
            "content": f"The user's task: {task_description}",
        })
        return {"messages": messages}

    def send_message(self, message: str, session_context: dict) -> AgentResponse:
        """Send one user turn to the chat completions API.

        Args:
            message: The user's message text.
            session_context: Context from ``start_session``; its
                ``messages`` list gains the user and assistant turns only
                after a successful call.

        Returns:
            AgentResponse with role "agent" on success, or role "error"
            with ``error`` set if the API call raised.
        """
        messages = session_context.get("messages", [])
        user_turn = {"role": "user", "content": message}

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                # Send a copy so a failed call leaves the stored history
                # untouched (no orphan user turn to be resent on retry).
                messages=messages + [user_turn],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                timeout=self.timeout,
            )
            content = response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"OpenAI proxy error: {e}")
            return AgentResponse(
                message=AgentMessage(
                    role="error", content=f"Agent error: {e}"
                ),
                error=str(e),
            )

        messages.append(user_turn)
        messages.append({"role": "assistant", "content": content})
        session_context["messages"] = messages

        return AgentResponse(
            message=AgentMessage(role="agent", content=content)
        )
102
+
103
+
104
+ # Register with factory
105
+ AgentProxyFactory.register("openai", OpenAIChatProxy)
potato/agent_proxy/sandbox.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Proxy Safety Sandbox
3
+
4
+ Enforces limits on agent interactions: step counts, session timeouts,
5
+ rate limits, and request timeouts. Prevents runaway or abusive sessions.
6
+ """
7
+
8
+ import time
9
+ import threading
10
+ import logging
11
+ from collections import defaultdict
12
+ from typing import Dict, List
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class SandboxViolation(Exception):
    """Signals that an agent interaction exceeded a configured safety limit."""
20
+
21
+
22
class SafetySandbox:
    """Guards agent sessions against runaway or abusive usage.

    All limits come from the ``sandbox`` section of the configuration dict.
    Each ``check_*`` method returns ``None`` when the limit is respected and
    raises ``SandboxViolation`` otherwise.
    """

    def __init__(self, config: dict):
        """Read the limits from ``config["sandbox"]``, applying defaults.

        Defaults: 20 steps, 600 s per session, 10 requests per minute and a
        60 s request timeout.
        """
        limits = config.get("sandbox", {})
        self.max_steps = limits.get("max_steps", 20)
        self.max_session_seconds = limits.get("max_session_seconds", 600)
        self.rate_limit_per_minute = limits.get("rate_limit_per_minute", 10)
        self.request_timeout = limits.get("request_timeout_seconds", 60)

        # Per-user request timestamps for the sliding one-minute window,
        # guarded by the lock below.
        self._lock = threading.Lock()
        self._rate_windows: Dict[str, List[float]] = defaultdict(list)

    def check_step_limit(self, current_steps: int):
        """Raise SandboxViolation once ``current_steps`` reaches ``max_steps``."""
        limit = self.max_steps
        if current_steps >= limit:
            raise SandboxViolation(
                f"Step limit reached ({self.max_steps}). "
                f"Please finish the conversation."
            )

    def check_session_timeout(self, session_start: float):
        """Raise SandboxViolation if more than ``max_session_seconds`` have passed.

        Args:
            session_start: Session start as a ``time.time()`` timestamp.
        """
        age = time.time() - session_start
        if age > self.max_session_seconds:
            raise SandboxViolation(
                f"Session timeout ({self.max_session_seconds}s). "
                f"Please finish the conversation."
            )

    def check_rate_limit(self, user_id: str):
        """Record a request for ``user_id`` or raise if the user is too fast.

        Only requests from the last 60 seconds count. A rejected request is
        not recorded.
        """
        now = time.time()
        cutoff = now - 60.0

        with self._lock:
            # Drop timestamps that have slid out of the one-minute window.
            recent = [
                stamp for stamp in self._rate_windows[user_id] if stamp > cutoff
            ]
            self._rate_windows[user_id] = recent

            if len(recent) >= self.rate_limit_per_minute:
                raise SandboxViolation(
                    f"Rate limit exceeded ({self.rate_limit_per_minute}/min). "
                    f"Please wait before sending another message."
                )

            # Accepted: count this request against the window.
            recent.append(now)

    def get_request_timeout(self) -> float:
        """Return the configured timeout (seconds) for proxy requests."""
        return self.request_timeout
potato/agent_proxy/session.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Session Manager
3
+
4
+ Thread-safe singleton that tracks active agent interaction sessions.
5
+ Each session maps a (user_id, instance_id) pair to an AgentSession
6
+ containing the proxy, conversation history, and step count.
7
+
8
+ Follows the same singleton pattern as ItemStateManager and UserStateManager.
9
+ """
10
+
11
+ import threading
12
+ import time
13
+ import logging
14
+ from dataclasses import dataclass, field
15
+ from typing import Dict, List, Optional, Tuple
16
+
17
+ from .base import AgentMessage, BaseAgentProxy
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
@dataclass
class AgentSession:
    """State of one agent interaction, stored by AgentSessionManager under
    the (user_id, instance_id) key."""
    # First half of the manager's session key.
    user_id: str
    # Second half of the manager's session key.
    instance_id: str
    # Proxy used for this session; end_session() is called on it on removal.
    proxy: BaseAgentProxy
    # Task text that was passed to proxy.start_session().
    task_description: str
    # Value returned by proxy.start_session(); handed back to proxy.end_session().
    proxy_context: dict = field(default_factory=dict)
    # Conversation history as AgentMessage objects.
    messages: List[AgentMessage] = field(default_factory=list)
    # Number of steps taken so far; starts at 0.
    step_count: int = 0
    # Creation time as a time.time() timestamp.
    started_at: float = field(default_factory=time.time)
    # When True, create_session() replaces this session instead of reusing it.
    finished: bool = False
34
+
35
+
36
class AgentSessionManager:
    """Registry of agent sessions keyed by ``(user_id, instance_id)``.

    Every access to the registry happens under a re-entrant lock.
    """

    def __init__(self, config: dict):
        """Store the configuration and start with an empty registry."""
        self.config = config
        self._lock = threading.RLock()
        self._sessions: Dict[Tuple[str, str], AgentSession] = {}

    def create_session(
        self,
        user_id: str,
        instance_id: str,
        proxy: BaseAgentProxy,
        task_description: str,
    ) -> AgentSession:
        """Return the session for a user/instance pair, creating it if needed.

        An existing session that is not finished is returned unchanged (a
        warning is logged). Otherwise ``proxy.start_session`` is called and
        a fresh session replaces whatever was stored under the key.
        """
        key = (user_id, instance_id)
        with self._lock:
            current = self._sessions.get(key)
            if current is not None and not current.finished:
                logger.warning(
                    f"Session already exists for {key}, returning existing"
                )
                return current

            fresh = AgentSession(
                user_id=user_id,
                instance_id=instance_id,
                proxy=proxy,
                task_description=task_description,
                proxy_context=proxy.start_session(task_description),
            )
            self._sessions[key] = fresh
            logger.debug(f"Created agent session for {key}")
            return fresh

    def get_session(
        self, user_id: str, instance_id: str
    ) -> Optional[AgentSession]:
        """Return the stored session for the pair, or None if there is none.

        No filtering is applied: a finished session is returned as well.
        """
        key = (user_id, instance_id)
        with self._lock:
            return self._sessions.get(key)

    def remove_session(self, user_id: str, instance_id: str):
        """Drop the session for the pair and let its proxy clean up.

        Errors raised by ``proxy.end_session`` are logged as warnings and do
        not propagate. Removing an unknown pair is a no-op.
        """
        key = (user_id, instance_id)
        with self._lock:
            removed = self._sessions.pop(key, None)
            if not removed:
                return
            try:
                removed.proxy.end_session(removed.proxy_context)
            except Exception as exc:
                logger.warning(f"Error ending proxy session for {key}: {exc}")
            logger.debug(f"Removed agent session for {key}")
90
+
91
+
92
# Singleton management
# Process-wide AgentSessionManager instance; None until
# init_agent_session_manager() has been called.
_AGENT_SESSION_MANAGER: Optional[AgentSessionManager] = None
# Guards creation and clearing of the singleton above.
_AGENT_SESSION_MANAGER_LOCK = threading.Lock()
95
+
96
+
97
def init_agent_session_manager(config: dict) -> AgentSessionManager:
    """Create the singleton AgentSessionManager on first use and return it.

    If the singleton already exists it is returned as-is and ``config`` is
    ignored.
    """
    global _AGENT_SESSION_MANAGER
    needs_init = _AGENT_SESSION_MANAGER is None
    if needs_init:
        with _AGENT_SESSION_MANAGER_LOCK:
            # Re-check under the lock: another thread may have created the
            # instance between the unlocked test and acquiring the lock.
            if _AGENT_SESSION_MANAGER is None:
                _AGENT_SESSION_MANAGER = AgentSessionManager(config)
    return _AGENT_SESSION_MANAGER
105
+
106
+
107
def get_agent_session_manager() -> AgentSessionManager:
    """Return the singleton AgentSessionManager.

    Raises:
        ValueError: If ``init_agent_session_manager`` has not been called
            (or the singleton has been cleared).
    """
    manager = _AGENT_SESSION_MANAGER
    if manager is not None:
        return manager
    raise ValueError("AgentSessionManager has not been initialized yet!")
113
+
114
+
115
def clear_agent_session_manager():
    """Clear the singleton instance (for testing).

    Only the module-level reference is reset, under the singleton lock.
    Sessions held by the old manager are not ended here.
    """
    global _AGENT_SESSION_MANAGER
    with _AGENT_SESSION_MANAGER_LOCK:
        _AGENT_SESSION_MANAGER = None
potato/agent_runner.py ADDED
@@ -0,0 +1,1008 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Live Agent Runner
3
+
4
+ Manages an AI agent that browses the web via Playwright, controlled by an LLM.
5
+ Annotators can observe, pause, instruct, or take over the agent in real time.
6
+
7
+ The agent loop runs in a background thread with its own asyncio event loop.
8
+ Communication with Flask routes happens through thread-safe state and queues.
9
+ """
10
+
11
+ import asyncio
12
+ import base64
13
+ import json
14
+ import logging
15
+ import os
16
+ import threading
17
+ import time
18
+ import uuid
19
+ from dataclasses import dataclass, field
20
+ from enum import Enum
21
+ from queue import Queue, Empty
22
+ from typing import Any, Callable, Dict, List, Optional
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class AgentState(Enum):
    """States of the agent lifecycle."""
    # Initial state of a runner; start() refuses to run from any other state.
    IDLE = "idle"
    # The agent loop is active (set when the loop starts, on resume and on
    # takeover exit).
    RUNNING = "running"
    # Set by pause(); the loop waits before its next step.
    PAUSED = "paused"
    # Set by enter_takeover(); the loop is held while manual actions are processed.
    TAKEOVER = "takeover"
    # Set after the agent loop ends (done action, stop request or max_steps).
    COMPLETED = "completed"
    # Set when the agent thread catches an exception.
    ERROR = "error"
35
+
36
+
37
@dataclass
class AgentStep:
    """Record of one executed agent step (screenshot, decision and outcome)."""
    step_index: int
    screenshot_path: str
    action: Dict[str, Any]
    thought: str
    observation: str
    timestamp: float
    url: str = ""
    viewport: Optional[Dict[str, int]] = None
    coordinates: Optional[Dict[str, int]] = None
    element: Optional[Dict[str, Any]] = None
    annotator_instruction: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the step into a plain dict for event payloads.

        The screenshot path is exposed as ``screenshot_url`` and the action's
        ``type`` is duplicated as ``action_type`` (``"unknown"`` when absent).
        Optional fields are included only when they hold a truthy value.
        """
        payload: Dict[str, Any] = {
            "step_index": self.step_index,
            "screenshot_url": self.screenshot_path,
            "action_type": self.action.get("type", "unknown"),
            "action": self.action,
            "thought": self.thought,
            "observation": self.observation,
            "timestamp": self.timestamp,
            "url": self.url,
        }
        # Optional extras, emitted in a fixed order and skipped when empty.
        for name in ("viewport", "coordinates", "element", "annotator_instruction"):
            value = getattr(self, name)
            if value:
                payload[name] = value
        return payload
72
+
73
+
74
@dataclass
class AgentConfig:
    """Configuration for the agent runner."""
    max_steps: int = 30
    step_delay: float = 1.0
    viewport_width: int = 1280
    viewport_height: int = 720
    system_prompt: str = ""
    model: str = "claude-sonnet-4-20250514"
    api_key: str = ""
    max_tokens: int = 4096
    temperature: float = 0.3
    endpoint_type: str = "anthropic_vision"
    history_window: int = 5  # Number of recent steps to include in LLM context
    timeout: int = 60  # Per-request timeout in seconds

    base_url: str = ""  # Server URL (Ollama or OpenAI-compatible endpoint)

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AgentConfig":
        """Create AgentConfig from a live_agent YAML config dict.

        Args:
            config: Mapping with optional top-level keys (``max_steps``,
                ``step_delay``, ``viewport``, ``system_prompt``,
                ``endpoint_type``, ``history_window``) and an optional
                ``ai_config`` sub-mapping (``model``, ``api_key``,
                ``max_tokens``, ``temperature``, ``timeout``, ``base_url``).

        Returns:
            A populated ``AgentConfig``. The defaults for ``model``,
            ``api_key`` and ``base_url`` depend on ``endpoint_type``.
        """
        ai_config = config.get("ai_config", {})
        viewport = config.get("viewport", {})
        endpoint_type = config.get("endpoint_type", "anthropic_vision")

        # API key: Ollama doesn't need one; OpenAI-compatible servers
        # (e.g. vLLM) ignore it but the SDK requires a non-empty string.
        #
        # base_url: only Ollama defaults to the local Ollama server. Other
        # endpoint types default to "" so the runner's own fallback (e.g.
        # the public OpenAI URL) applies when no base_url is configured.
        if endpoint_type == "ollama_vision":
            api_key = ai_config.get("api_key", "")
            default_model = "gemma3:4b"
            default_base_url = "http://localhost:11434"
        elif endpoint_type == "openai_vision":
            api_key = ai_config.get("api_key", os.environ.get("OPENAI_API_KEY", "EMPTY"))
            default_model = ""  # must be set explicitly (e.g. served model id)
            default_base_url = ""
        else:
            api_key = ai_config.get("api_key", os.environ.get("ANTHROPIC_API_KEY", ""))
            default_model = "claude-sonnet-4-20250514"
            default_base_url = ""

        # Fall back to the module-level default prompt only when the config
        # does not provide one.
        if "system_prompt" in config:
            system_prompt = config["system_prompt"]
        else:
            system_prompt = DEFAULT_SYSTEM_PROMPT

        return cls(
            max_steps=config.get("max_steps", 30),
            step_delay=config.get("step_delay", 1.0),
            viewport_width=viewport.get("width", 1280),
            viewport_height=viewport.get("height", 720),
            system_prompt=system_prompt,
            model=ai_config.get("model", default_model),
            api_key=api_key,
            max_tokens=ai_config.get("max_tokens", 4096),
            temperature=ai_config.get("temperature", 0.3),
            endpoint_type=endpoint_type,
            history_window=config.get("history_window", 5),
            timeout=ai_config.get("timeout", 60),
            base_url=ai_config.get("base_url", default_base_url),
        )
126
+
127
+
128
+ DEFAULT_SYSTEM_PROMPT = """You are a web browsing agent. You can see screenshots of web pages and take actions to complete tasks.
129
+
130
+ For each step, analyze the current screenshot and respond with a JSON object:
131
+ {
132
+ "thought": "Your reasoning about what you see and what to do next",
133
+ "action": {
134
+ "type": "click|type|scroll|navigate|wait|done",
135
+ // For click: "x": 100, "y": 200
136
+ // For type: "text": "hello world"
137
+ // For scroll: "direction": "up|down", "amount": 300
138
+ // For navigate: "url": "https://..."
139
+ // For wait: (no extra fields)
140
+ // For done: "summary": "Task completed because..."
141
+ }
142
+ }
143
+
144
+ Always respond with valid JSON only. No markdown, no extra text."""
145
+
146
+
147
+ class AgentRunner:
148
+ """
149
+ Runs an AI agent that browses the web via Playwright.
150
+
151
+ The agent loop:
152
+ 1. Takes a screenshot
153
+ 2. Sends it to the LLM with context/history
154
+ 3. Parses the LLM response for an action
155
+ 4. Executes the action via Playwright
156
+ 5. Emits events to all listeners (for SSE)
157
+ 6. Repeats until done, error, or max_steps
158
+
159
+ Thread-safe control methods allow pause/resume/instruct/takeover.
160
+ """
161
+
162
def __init__(self, session_id: str, config: AgentConfig, screenshot_dir: str):
    """Prepare a runner in the IDLE state; no thread or browser is started.

    Args:
        session_id: Identifier attached to emitted events and log lines.
        config: Agent settings (limits, model, viewport, ...).
        screenshot_dir: Directory into which step screenshots are written.
    """
    self.session_id = session_id
    self.config = config
    self.screenshot_dir = screenshot_dir

    # Lifecycle state and results
    self._state_lock = threading.Lock()
    self._state = AgentState.IDLE
    self._error: Optional[str] = None
    self._steps: List[AgentStep] = []

    # Annotator interactions log
    self._interactions: List[Dict[str, Any]] = []

    # Control signals shared with the agent loop
    self._stop_flag = threading.Event()
    self._pause_event = threading.Event()
    self._pause_event.set()  # "set" means not paused
    self._instruction_queue: Queue = Queue()
    self._takeover_actions: Queue = Queue()

    # SSE listener callbacks
    self._listeners_lock = threading.Lock()
    self._listeners: List[Callable] = []

    # Resources created while the agent runs
    self._llm_client = None
    self._playwright_session = None
    self._thread: Optional[threading.Thread] = None
193
+
194
@property
def state(self) -> AgentState:
    """Current lifecycle state, read under the state lock."""
    with self._state_lock:
        return self._state

@state.setter
def state(self, new_state: AgentState):
    """Store the new state and notify listeners with a "state_change" event.

    The event is emitted only after the state lock has been released.
    ``_state_lock`` is a non-reentrant ``threading.Lock``, so emitting while
    holding it would deadlock any listener that reads ``state`` from inside
    its callback.
    """
    with self._state_lock:
        old_state = self._state
        self._state = new_state
    self._emit_event("state_change", {
        "old_state": old_state.value,
        "new_state": new_state.value,
        "timestamp": time.time(),
    })
209
+
210
@property
def steps(self) -> List[AgentStep]:
    """Return a shallow copy of the recorded steps (callers cannot mutate the internal list)."""
    return list(self._steps)
213
+
214
@property
def step_count(self) -> int:
    """Number of steps recorded so far."""
    return len(self._steps)
217
+
218
@property
def error(self) -> Optional[str]:
    """Message of the exception that ended the agent thread, or None."""
    return self._error
221
+
222
+ # --- Control methods (thread-safe) ---
223
+
224
def pause(self):
    """Ask the agent loop to hold before its next step.

    Has an effect only while the agent is RUNNING; otherwise it is a no-op.
    """
    if self.state != AgentState.RUNNING:
        return
    # Clearing the event makes the loop wait before its next step.
    self._pause_event.clear()
    self.state = AgentState.PAUSED
    logger.info(f"[{self.session_id}] Agent paused")
230
+
231
def resume(self):
    """Let a PAUSED agent continue; a no-op in any other state."""
    if self.state != AgentState.PAUSED:
        return
    # Publish the state first, then release the waiting loop.
    self.state = AgentState.RUNNING
    self._pause_event.set()
    logger.info(f"[{self.session_id}] Agent resumed")
237
+
238
def inject_instruction(self, instruction: str):
    """Queue an annotator instruction for the agent's next step.

    The instruction is also appended to the interaction log (with the
    current step count) and announced through an "instruction_received"
    event.
    """
    self._instruction_queue.put(instruction)

    log_entry = {
        "type": "instruction",
        "text": instruction,
        "timestamp": time.time(),
        "step_index": self.step_count,
    }
    self._interactions.append(log_entry)

    self._emit_event("instruction_received", {"instruction": instruction})
    # Only the first 100 characters are written to the log line.
    logger.info(f"[{self.session_id}] Instruction injected: {instruction[:100]}")
249
+
250
def enter_takeover(self):
    """Hand control to the annotator (manual takeover).

    Allowed only from RUNNING or PAUSED; otherwise nothing happens. The
    start of the takeover is recorded in the interaction log.
    """
    if self.state not in (AgentState.RUNNING, AgentState.PAUSED):
        return
    self._pause_event.clear()  # Hold the agent loop
    self.state = AgentState.TAKEOVER
    takeover_record = {
        "type": "takeover_start",
        "timestamp": time.time(),
        "step_index": self.step_count,
    }
    self._interactions.append(takeover_record)
    logger.info(f"[{self.session_id}] Takeover mode entered")
261
+
262
def exit_takeover(self):
    """End manual takeover and let the agent loop continue.

    A no-op unless the runner is in TAKEOVER. The end of the takeover is
    recorded in the interaction log before the loop is released.
    """
    if self.state != AgentState.TAKEOVER:
        return
    takeover_record = {
        "type": "takeover_end",
        "timestamp": time.time(),
        "step_index": self.step_count,
    }
    self._interactions.append(takeover_record)
    self.state = AgentState.RUNNING
    self._pause_event.set()
    logger.info(f"[{self.session_id}] Takeover mode exited")
273
+
274
def submit_manual_action(self, action: Dict[str, Any]):
    """Queue a manual action; it is silently dropped outside TAKEOVER mode."""
    in_takeover = self.state == AgentState.TAKEOVER
    if not in_takeover:
        return
    self._takeover_actions.put(action)
278
+
279
def stop(self):
    """Request the agent loop to stop.

    Only sets flags; the loop checks the stop flag at the start of each
    step and while it is waiting in a paused state.
    """
    self._stop_flag.set()
    self._pause_event.set()  # Unblock if paused
    logger.info(f"[{self.session_id}] Stop requested")
284
+
285
+ # --- Listener management ---
286
+
287
def add_listener(self, callback: Callable):
    """Add an SSE listener callback.

    The callback is invoked with one event dict (keys "type", "data",
    "session_id") for every emitted event.
    """
    with self._listeners_lock:
        self._listeners.append(callback)
291
+
292
def remove_listener(self, callback: Callable):
    """Unregister a listener callback.

    Matching is by identity, and every registration of ``callback`` is
    removed. Unknown callbacks are ignored.
    """
    with self._listeners_lock:
        remaining = []
        for registered in self._listeners:
            if registered is not callback:
                remaining.append(registered)
        self._listeners = remaining
296
+
297
+ def _emit_event(self, event_type: str, data: Dict[str, Any]):
298
+ """Emit an event to all listeners."""
299
+ event = {"type": event_type, "data": data, "session_id": self.session_id}
300
+ with self._listeners_lock:
301
+ for listener in self._listeners:
302
+ try:
303
+ listener(event)
304
+ except Exception as e:
305
+ logger.warning(f"Listener error: {e}")
306
+
307
+ # --- Main agent loop ---
308
+
309
def start(self, task_description: str, start_url: str):
    """Launch the agent loop on a daemon background thread.

    Args:
        task_description: Task text given to the LLM on every step.
        start_url: URL the browser session is started with.

    Raises:
        RuntimeError: If the runner is not in the IDLE state.
    """
    current = self.state
    if current != AgentState.IDLE:
        raise RuntimeError(f"Cannot start agent in state {current}")

    worker = threading.Thread(
        name=f"agent-{self.session_id}",
        target=self._run_thread,
        args=(task_description, start_url),
        daemon=True,
    )
    self._thread = worker
    worker.start()
321
+
322
def _run_thread(self, task_description: str, start_url: str):
    """Thread entry point: drive the async agent loop on a private event loop.

    Any exception escaping the loop is logged, stored in ``_error``, turns
    the state into ERROR and is announced through an "error" event. The
    event loop is always closed.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        event_loop.run_until_complete(
            self._run_async(task_description, start_url)
        )
    except Exception as exc:
        message = str(exc)
        logger.error(f"[{self.session_id}] Agent thread error: {exc}")
        self._error = message
        self.state = AgentState.ERROR
        self._emit_event("error", {"message": message})
    finally:
        event_loop.close()
335
+
336
async def _run_async(self, task_description: str, start_url: str):
    """Async agent loop: start the browser, run steps, always clean up.

    Runs at most ``config.max_steps`` steps. The loop ends early on a stop
    request or when the agent returns a ``done`` action. On normal
    termination the state becomes COMPLETED and a "complete" event is
    emitted.

    Raises:
        RuntimeError: If the browser session cannot be started, or from
            LLM client initialisation / step execution.
    """
    from potato.web_playwright import PlaywrightSession

    self.state = AgentState.RUNNING

    # Initialize Playwright
    self._playwright_session = PlaywrightSession(
        width=self.config.viewport_width,
        height=self.config.viewport_height,
    )
    started = await self._playwright_session.start(start_url)
    if not started:
        raise RuntimeError("Failed to start Playwright browser session")

    try:
        # Initialize the LLM client inside the try block: if it fails
        # (missing package or API key) the browser started above is still
        # shut down by the finally clause instead of being leaked.
        self._init_llm_client()

        self._emit_event("started", {
            "task": task_description,
            "start_url": start_url,
            "max_steps": self.config.max_steps,
        })

        for step_index in range(self.config.max_steps):
            # Check stop flag
            if self._stop_flag.is_set():
                logger.info(f"[{self.session_id}] Stopped by user")
                break

            # Wait if paused (blocks until resume/stop)
            while not self._pause_event.is_set():
                if self._stop_flag.is_set():
                    break
                # Handle takeover actions while paused in takeover mode
                if self.state == AgentState.TAKEOVER:
                    await self._process_takeover_actions()
                await asyncio.sleep(0.1)

            if self._stop_flag.is_set():
                break

            # Check for injected instructions (at most one per step)
            instruction = None
            try:
                instruction = self._instruction_queue.get_nowait()
            except Empty:
                pass

            # Execute one agent step
            step = await self._agent_step(
                step_index, task_description, instruction
            )
            self._steps.append(step)

            # Check if agent decided it's done
            if step.action.get("type") == "done":
                logger.info(f"[{self.session_id}] Agent completed task")
                break

            # Step delay
            if self.config.step_delay > 0:
                await asyncio.sleep(self.config.step_delay)

        self.state = AgentState.COMPLETED
        self._emit_event("complete", {
            "total_steps": len(self._steps),
            "final_url": (await self._playwright_session.get_state()).get("url", ""),
        })

    finally:
        try:
            await self._playwright_session.stop()
        finally:
            # Drop the reference even if stop() itself raised.
            self._playwright_session = None
410
+
411
async def _agent_step(
    self,
    step_index: int,
    task_description: str,
    instruction: Optional[str] = None,
) -> AgentStep:
    """Run one observe / decide / act cycle and return its record.

    Sequence: capture and save a screenshot, read the page state, emit a
    "thinking" event, query the LLM, parse and execute the chosen action,
    then emit a "step" event.

    Raises:
        RuntimeError: If the screenshot capture returns nothing.
    """
    browser = self._playwright_session

    # Observe: capture the page and persist the image for this step.
    png_bytes = await browser.screenshot()
    if not png_bytes:
        raise RuntimeError("Failed to capture screenshot")

    image_path = os.path.join(
        self.screenshot_dir, f"step_{step_index:03d}.png"
    )
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    with open(image_path, "wb") as image_file:
        image_file.write(png_bytes)

    page_state = await browser.get_state()
    current_url = page_state.get("url", "")

    self._emit_event("thinking", {
        "step_index": step_index,
        "screenshot_url": image_path,
        "url": current_url,
    })

    # Decide: send the screenshot plus context to the LLM.
    encoded_image = base64.b64encode(png_bytes).decode("utf-8")
    llm_reply = self._query_llm(
        self._build_llm_messages(encoded_image, task_description, instruction)
    )
    thought, action = self._parse_action(llm_reply)

    # Act: run the chosen action in the browser.
    observation = await self._execute_action(action)

    record = AgentStep(
        step_index=step_index,
        screenshot_path=image_path,
        action=action,
        thought=thought,
        observation=observation,
        timestamp=time.time(),
        url=current_url,
        viewport=page_state.get("viewport"),
        coordinates=_extract_coordinates(action),
        annotator_instruction=instruction,
    )
    self._emit_event("step", record.to_dict())
    return record
472
+
473
+ def _build_llm_messages(
474
+ self,
475
+ screenshot_b64: str,
476
+ task_description: str,
477
+ instruction: Optional[str] = None,
478
+ ) -> List[Dict[str, Any]]:
479
+ """Build message list for the LLM vision API."""
480
+ messages = []
481
+
482
+ # System message
483
+ system_prompt = self.config.system_prompt or DEFAULT_SYSTEM_PROMPT
484
+ messages.append({"role": "system", "content": system_prompt})
485
+
486
+ # Task description
487
+ task_msg = f"Task: {task_description}"
488
+ if instruction:
489
+ task_msg += f"\n\nAnnotator instruction: {instruction}"
490
+
491
+ # Include recent step history
492
+ history_steps = self._steps[-self.config.history_window:]
493
+ if history_steps:
494
+ history_parts = []
495
+ for s in history_steps:
496
+ entry = f"Step {s.step_index}: thought='{s.thought}', action={json.dumps(s.action)}, observation='{s.observation}'"
497
+ history_parts.append(entry)
498
+ task_msg += "\n\nRecent history:\n" + "\n".join(history_parts)
499
+
500
+ messages.append({"role": "user", "content": task_msg})
501
+
502
+ # Current screenshot (as a separate user message with image)
503
+ messages.append({
504
+ "role": "user",
505
+ "content": [
506
+ {
507
+ "type": "image",
508
+ "source": {
509
+ "type": "base64",
510
+ "media_type": "image/png",
511
+ "data": screenshot_b64,
512
+ },
513
+ },
514
+ {
515
+ "type": "text",
516
+ "text": f"Current page screenshot (step {len(self._steps)}). What action should I take next?",
517
+ },
518
+ ],
519
+ })
520
+
521
+ return messages
522
+
523
+ def _init_llm_client(self):
524
+ """Initialize the LLM client based on endpoint_type."""
525
+ if self.config.endpoint_type == "anthropic_vision":
526
+ try:
527
+ import anthropic
528
+ except ImportError:
529
+ raise RuntimeError(
530
+ "anthropic package required. Install with: pip install anthropic"
531
+ )
532
+ api_key = self.config.api_key or os.environ.get("ANTHROPIC_API_KEY")
533
+ if not api_key:
534
+ raise RuntimeError(
535
+ "Anthropic API key required. Set in config or ANTHROPIC_API_KEY env var."
536
+ )
537
+ self._llm_client = anthropic.Anthropic(
538
+ api_key=api_key, timeout=self.config.timeout
539
+ )
540
+ elif self.config.endpoint_type == "ollama_vision":
541
+ try:
542
+ import ollama
543
+ except ImportError:
544
+ raise RuntimeError(
545
+ "ollama package required. Install with: pip install ollama"
546
+ )
547
+ host = self.config.base_url or "http://localhost:11434"
548
+ self._llm_client = ollama.Client(
549
+ host=host, timeout=self.config.timeout
550
+ )
551
+ # Verify connectivity
552
+ try:
553
+ self._llm_client.list()
554
+ logger.info(f"Connected to Ollama at {host}, model: {self.config.model}")
555
+ except Exception as e:
556
+ raise RuntimeError(f"Failed to connect to Ollama at {host}: {e}")
557
+ elif self.config.endpoint_type == "openai_vision":
558
+ try:
559
+ from openai import OpenAI
560
+ except ImportError:
561
+ raise RuntimeError(
562
+ "openai package required. Install with: pip install openai"
563
+ )
564
+ base_url = self.config.base_url or "https://api.openai.com/v1"
565
+ self._llm_client = OpenAI(
566
+ base_url=base_url,
567
+ api_key=self.config.api_key or "EMPTY",
568
+ timeout=self.config.timeout,
569
+ )
570
+ try:
571
+ self._llm_client.models.list()
572
+ logger.info(
573
+ f"Connected to OpenAI-compatible endpoint at {base_url}, "
574
+ f"model: {self.config.model}"
575
+ )
576
+ except Exception as e:
577
+ # Non-fatal: some servers gate /models; the chat call will
578
+ # surface a real error if the endpoint is truly unreachable.
579
+ logger.warning(
580
+ f"Could not list models at {base_url} ({e}); continuing."
581
+ )
582
+ else:
583
+ raise RuntimeError(
584
+ f"Unsupported endpoint_type: {self.config.endpoint_type}. "
585
+ f"Supported: 'anthropic_vision', 'ollama_vision', 'openai_vision'."
586
+ )
587
+
588
+ def _query_llm(self, messages: List[Dict[str, Any]]) -> str:
589
+ """Send messages to the LLM and return the text response."""
590
+ if self.config.endpoint_type == "anthropic_vision":
591
+ return self._query_anthropic(messages)
592
+ elif self.config.endpoint_type == "ollama_vision":
593
+ return self._query_ollama(messages)
594
+ elif self.config.endpoint_type == "openai_vision":
595
+ return self._query_openai(messages)
596
+ raise RuntimeError(f"Unsupported endpoint type: {self.config.endpoint_type}")
597
+
598
+ def _query_openai(self, messages: List[Dict[str, Any]]) -> str:
599
+ """Query an OpenAI-compatible vision endpoint (OpenAI, vLLM, etc.).
600
+
601
+ Converts the internal Anthropic-style message blocks into OpenAI
602
+ chat-completions format (image blocks become ``image_url`` data
603
+ URIs). Requests a JSON object response when the server supports it,
604
+ falling back gracefully if it does not.
605
+ """
606
+ oai_messages = []
607
+ for msg in messages:
608
+ role = msg["role"]
609
+ content = msg.get("content", "")
610
+ if isinstance(content, str):
611
+ oai_messages.append({"role": role, "content": content})
612
+ continue
613
+ parts = []
614
+ for block in content:
615
+ if not isinstance(block, dict):
616
+ continue
617
+ if block.get("type") == "text":
618
+ parts.append({"type": "text", "text": block.get("text", "")})
619
+ elif block.get("type") == "image":
620
+ src = block.get("source", {})
621
+ if src.get("type") == "base64":
622
+ media = src.get("media_type", "image/png")
623
+ parts.append({
624
+ "type": "image_url",
625
+ "image_url": {
626
+ "url": f"data:{media};base64,{src['data']}"
627
+ },
628
+ })
629
+ oai_messages.append({"role": role, "content": parts})
630
+
631
+ kwargs = {
632
+ "model": self.config.model,
633
+ "messages": oai_messages,
634
+ "max_tokens": self.config.max_tokens,
635
+ "temperature": self.config.temperature,
636
+ }
637
+
638
+ def _is_rate_limit(exc) -> bool:
639
+ if getattr(exc, "status_code", None) == 429:
640
+ return True
641
+ s = str(exc).lower()
642
+ return ("429" in s or "rate limit" in s or "quota" in s
643
+ or "resource_exhausted" in s)
644
+
645
+ def _create(use_rf: bool):
646
+ if use_rf:
647
+ return self._llm_client.chat.completions.create(
648
+ response_format={"type": "json_object"}, **kwargs)
649
+ return self._llm_client.chat.completions.create(**kwargs)
650
+
651
+ # Transient 429s (per-minute rate/token bursts) are common mid-run
652
+ # even on paid tiers; back off and retry instead of failing the
653
+ # whole agent session.
654
+ backoffs = [5, 15, 30, 30, 30]
655
+ use_rf = True
656
+ attempt = 0
657
+ while True:
658
+ try:
659
+ resp = _create(use_rf)
660
+ break
661
+ except Exception as e:
662
+ if _is_rate_limit(e):
663
+ if attempt >= len(backoffs):
664
+ raise
665
+ wait = backoffs[attempt]
666
+ attempt += 1
667
+ logger.warning(
668
+ f"[{self.session_id}] LLM 429/rate-limited; "
669
+ f"retry {attempt}/{len(backoffs)} in {wait}s"
670
+ )
671
+ self._emit_event("thinking", {
672
+ "text": f"Rate-limited by the model API; "
673
+ f"waiting {wait}s before retrying…"
674
+ })
675
+ time.sleep(wait)
676
+ continue
677
+ if use_rf:
678
+ # Server may not support response_format; drop it once.
679
+ use_rf = False
680
+ continue
681
+ raise
682
+ return resp.choices[0].message.content or ""
683
+
684
+ def _query_anthropic(self, messages: List[Dict[str, Any]]) -> str:
685
+ """Query Anthropic Claude with vision support."""
686
+ # Separate system message
687
+ system = ""
688
+ api_messages = []
689
+ for msg in messages:
690
+ if msg["role"] == "system":
691
+ system = msg["content"]
692
+ else:
693
+ api_messages.append(msg)
694
+
695
+ kwargs = {
696
+ "model": self.config.model,
697
+ "max_tokens": self.config.max_tokens,
698
+ "temperature": self.config.temperature,
699
+ "messages": api_messages,
700
+ }
701
+ if system:
702
+ kwargs["system"] = system
703
+
704
+ response = self._llm_client.messages.create(**kwargs)
705
+ return response.content[0].text
706
+
707
+ def _query_ollama(self, messages: List[Dict[str, Any]]) -> str:
708
+ """Query Ollama vision model.
709
+
710
+ Converts Anthropic-format messages to Ollama format:
711
+ - System messages are prepended to the prompt text
712
+ - Multiple user messages are merged into a single message
713
+ - Content blocks with images use Ollama's 'images' key
714
+ """
715
+ # Extract text and images from Anthropic-format messages
716
+ all_text_parts = []
717
+ all_images = []
718
+ for msg in messages:
719
+ content = msg.get("content", "")
720
+ if msg["role"] == "system":
721
+ if isinstance(content, str) and content:
722
+ all_text_parts.insert(0, content)
723
+ continue
724
+ if isinstance(content, list):
725
+ for block in content:
726
+ if isinstance(block, dict):
727
+ if block.get("type") == "text":
728
+ all_text_parts.append(block["text"])
729
+ elif block.get("type") == "image":
730
+ source = block.get("source", {})
731
+ if source.get("type") == "base64":
732
+ all_images.append(source["data"])
733
+ elif isinstance(content, str) and content:
734
+ all_text_parts.append(content)
735
+
736
+ ollama_msg = {
737
+ "role": "user",
738
+ "content": "\n\n".join(all_text_parts),
739
+ }
740
+ if all_images:
741
+ ollama_msg["images"] = all_images
742
+
743
+ options = {
744
+ "temperature": self.config.temperature,
745
+ "num_predict": self.config.max_tokens,
746
+ }
747
+
748
+ # Use Ollama's format schema to force structured JSON output
749
+ agent_schema = {
750
+ "type": "object",
751
+ "properties": {
752
+ "thought": {"type": "string"},
753
+ "action": {
754
+ "type": "object",
755
+ "properties": {
756
+ "type": {"type": "string"},
757
+ "x": {"type": "integer"},
758
+ "y": {"type": "integer"},
759
+ "text": {"type": "string"},
760
+ "url": {"type": "string"},
761
+ "direction": {"type": "string"},
762
+ "amount": {"type": "integer"},
763
+ "summary": {"type": "string"},
764
+ },
765
+ "required": ["type"],
766
+ },
767
+ },
768
+ "required": ["thought", "action"],
769
+ }
770
+
771
+ response = self._llm_client.chat(
772
+ model=self.config.model,
773
+ messages=[ollama_msg],
774
+ options=options,
775
+ format=agent_schema,
776
+ )
777
+
778
+ # Extract content from response (handle both dict and Pydantic model)
779
+ message = (
780
+ response.get("message")
781
+ if hasattr(response, "get")
782
+ else getattr(response, "message", None)
783
+ )
784
+ if message is None:
785
+ raise RuntimeError("No message in Ollama response")
786
+
787
+ content = (
788
+ message.get("content")
789
+ if hasattr(message, "get")
790
+ else getattr(message, "content", None)
791
+ )
792
+
793
+ # Some models (e.g. qwen3-vl) put responses in 'thinking' field
794
+ # and leave content empty. Extract the agent JSON from thinking.
795
+ if not content:
796
+ thinking = (
797
+ message.get("thinking")
798
+ if hasattr(message, "get")
799
+ else getattr(message, "thinking", None)
800
+ )
801
+ if thinking:
802
+ content = _extract_agent_json(thinking)
803
+
804
+ return content or ""
805
+
806
+ def _parse_action(self, llm_response: str) -> tuple:
807
+ """Parse thought and action from LLM JSON response.
808
+
809
+ Returns:
810
+ (thought, action_dict)
811
+ """
812
+ # Try to extract JSON from response
813
+ text = llm_response.strip()
814
+
815
+ # Handle markdown code blocks
816
+ if "```json" in text:
817
+ import re
818
+ match = re.search(r"```json\s*([\s\S]*?)\s*```", text)
819
+ if match:
820
+ text = match.group(1).strip()
821
+ elif "```" in text:
822
+ import re
823
+ match = re.search(r"```\s*([\s\S]*?)\s*```", text)
824
+ if match:
825
+ text = match.group(1).strip()
826
+
827
+ try:
828
+ parsed = json.loads(text)
829
+ except json.JSONDecodeError:
830
+ logger.warning(f"Failed to parse LLM response as JSON: {text[:200]}")
831
+ return text, {"type": "wait"}
832
+
833
+ thought = parsed.get("thought", "")
834
+ action = parsed.get("action", {"type": "wait"})
835
+
836
+ # Validate action has a type
837
+ if "type" not in action:
838
+ action["type"] = "wait"
839
+
840
+ return thought, action
841
+
842
+ async def _execute_action(self, action: Dict[str, Any]) -> str:
843
+ """Execute an action via Playwright and return observation."""
844
+ action_type = action.get("type", "wait")
845
+ pw = self._playwright_session
846
+
847
+ try:
848
+ if action_type == "click":
849
+ x = int(action.get("x", 0))
850
+ y = int(action.get("y", 0))
851
+ success = await pw.click(x, y)
852
+ return f"Clicked at ({x}, {y})" if success else f"Click failed at ({x}, {y})"
853
+
854
+ elif action_type == "type":
855
+ text = action.get("text", "")
856
+ # Handle control characters via keyboard.press
857
+ if text == "\b":
858
+ success = await pw.page.keyboard.press("Backspace") or True
859
+ return "Pressed Backspace"
860
+ elif text == "\n":
861
+ success = await pw.page.keyboard.press("Enter") or True
862
+ return "Pressed Enter"
863
+ elif text == "\t":
864
+ success = await pw.page.keyboard.press("Tab") or True
865
+ return "Pressed Tab"
866
+ else:
867
+ success = await pw.type_text(text)
868
+ return f"Typed '{text}'" if success else f"Type failed: '{text}'"
869
+
870
+ elif action_type == "scroll":
871
+ direction = action.get("direction", "down")
872
+ amount = int(action.get("amount", 300))
873
+ dy = amount if direction == "down" else -amount
874
+ success = await pw.scroll(0, dy)
875
+ return f"Scrolled {direction} by {amount}px" if success else "Scroll failed"
876
+
877
+ elif action_type == "navigate":
878
+ url = action.get("url", "")
879
+ success = await pw.navigate(url)
880
+ return f"Navigated to {url}" if success else f"Navigation failed: {url}"
881
+
882
+ elif action_type == "wait":
883
+ await asyncio.sleep(1)
884
+ return "Waited 1 second"
885
+
886
+ elif action_type == "done":
887
+ summary = action.get("summary", "Task completed")
888
+ return summary
889
+
890
+ else:
891
+ logger.warning(f"Unknown action type: {action_type}")
892
+ return f"Unknown action: {action_type}"
893
+
894
+ except Exception as e:
895
+ logger.error(f"Action execution error: {e}")
896
+ return f"Error executing {action_type}: {e}"
897
+
898
async def _process_takeover_actions(self):
    """Apply at most one queued manual action and record it as a step.

    Pulls a single action from the takeover queue without blocking. If the
    queue is empty, or no Playwright session is attached, nothing is
    recorded. Otherwise the action is executed, a screenshot is saved
    (when one is returned), and a step flagged as manual is appended and
    emitted.
    """
    try:
        manual_action = self._takeover_actions.get_nowait()
    except Empty:
        return

    session = self._playwright_session
    if not session:
        return

    outcome = await self._execute_action(manual_action)

    # Capture the page as it looks after the manual action.
    image_bytes = await session.screenshot()
    index = len(self._steps)
    image_path = os.path.join(
        self.screenshot_dir, f"step_{index:03d}_manual.png"
    )
    if image_bytes:
        with open(image_path, "wb") as out:
            out.write(image_bytes)

    snapshot = await session.get_state()

    manual_step = AgentStep(
        step_index=index,
        screenshot_path=image_path,
        action={**manual_action, "_manual": True},
        thought="[Manual takeover action]",
        observation=outcome,
        timestamp=time.time(),
        url=snapshot.get("url", ""),
        viewport=snapshot.get("viewport"),
        coordinates=_extract_coordinates(manual_action),
    )
    self._steps.append(manual_step)
    self._emit_event("step", manual_step.to_dict())
936
+
937
+ # --- Trace export ---
938
+
939
def get_trace(self) -> Dict[str, Any]:
    """Build a web_agent_trace-compatible dict describing this session.

    ``task_description`` is left empty here; the caller is expected to fill
    it in.
    """
    serialized_steps = [recorded.to_dict() for recorded in self._steps]
    cfg = self.config
    trace: Dict[str, Any] = {}
    trace["steps"] = serialized_steps
    trace["task_description"] = ""  # Set by caller
    trace["session_id"] = self.session_id
    trace["agent_config"] = {
        "model": cfg.model,
        "endpoint_type": cfg.endpoint_type,
        "max_steps": cfg.max_steps,
    }
    trace["annotator_interactions"] = self._interactions
    trace["state"] = self.state.value
    trace["total_steps"] = len(self._steps)
    return trace
954
+
955
def get_state_summary(self) -> Dict[str, Any]:
    """Return a compact status snapshot suitable for API responses."""
    pending_instructions = not self._instruction_queue.empty()
    summary: Dict[str, Any] = dict(
        session_id=self.session_id,
        state=self.state.value,
        step_count=len(self._steps),
        error=self._error,
        has_instructions_pending=pending_instructions,
    )
    return summary
964
+
965
+
966
+ def _extract_agent_json(text: str) -> str:
967
+ """Extract the last valid JSON object containing 'thought' or 'action' from text.
968
+
969
+ Some models (qwen3-vl) put their chain-of-thought in the thinking field
970
+ with the actual JSON answer embedded in the text. This function finds
971
+ that JSON, skipping any example/template JSON from the prompt.
972
+ """
973
+ import re
974
+
975
+ # Find all JSON-like blocks (balanced braces)
976
+ candidates = []
977
+ depth = 0
978
+ start = None
979
+ for i, ch in enumerate(text):
980
+ if ch == "{":
981
+ if depth == 0:
982
+ start = i
983
+ depth += 1
984
+ elif ch == "}":
985
+ depth -= 1
986
+ if depth == 0 and start is not None:
987
+ candidates.append(text[start : i + 1])
988
+ start = None
989
+
990
+ # Try each candidate (last first — most likely to be the final answer)
991
+ for candidate in reversed(candidates):
992
+ try:
993
+ parsed = json.loads(candidate)
994
+ if isinstance(parsed, dict) and ("thought" in parsed or "action" in parsed):
995
+ return candidate
996
+ except (json.JSONDecodeError, ValueError):
997
+ continue
998
+
999
+ # Fallback: try greedy regex for any JSON
1000
+ match = re.search(r"\{[^{}]*\}", text)
1001
+ return match.group(0) if match else ""
1002
+
1003
+
1004
+ def _extract_coordinates(action: Dict[str, Any]) -> Optional[Dict[str, int]]:
1005
+ """Extract x, y coordinates from an action if present."""
1006
+ if "x" in action and "y" in action:
1007
+ return {"x": int(action["x"]), "y": int(action["y"])}
1008
+ return None
potato/agent_runner_manager.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Runner Session Manager
3
+
4
+ Singleton that manages active AgentRunner sessions.
5
+ Keyed by "{user_id}:{instance_id}" for per-user, per-instance isolation.
6
+ Includes TTL-based cleanup and max concurrent session limits.
7
+ """
8
+
9
+ import atexit
10
+ import logging
11
+ import threading
12
+ import time
13
+ from typing import Dict, Optional
14
+
15
+ from potato.agent_runner import AgentConfig, AgentRunner, AgentState
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Default limits
20
+ DEFAULT_MAX_SESSIONS = 10
21
+ DEFAULT_SESSION_TTL = 3600 # 1 hour
22
+
23
+
24
class AgentRunnerManager:
    """
    Manages active AgentRunner sessions with lifecycle control.

    Thread-safe singleton. Sessions are keyed by "{user_id}:{instance_id}".

    Three parallel dicts share that key: the runner itself, its creation
    timestamp (used for TTL expiry) and a small metadata record. A daemon
    thread sweeps expired sessions once a minute; the same sweep also runs
    at the start of every ``create_session`` call.
    """

    _instance = None
    # Class-level lock guarding singleton creation. Each instance also has
    # its own ``self._lock`` guarding the session dicts.
    _lock = threading.Lock()

    def __init__(
        self,
        max_sessions: int = DEFAULT_MAX_SESSIONS,
        session_ttl: int = DEFAULT_SESSION_TTL,
    ):
        """
        Args:
            max_sessions: Maximum number of simultaneously active
                (running / paused / takeover) sessions.
            session_ttl: Age in seconds after which a finished or idle
                session is dropped; sessions still active after twice this
                age are force-stopped.
        """
        self._sessions: Dict[str, AgentRunner] = {}
        self._session_created: Dict[str, float] = {}
        self._session_meta: Dict[str, Dict] = {}
        self._lock = threading.Lock()
        self.max_sessions = max_sessions
        self.session_ttl = session_ttl

        # Start cleanup thread
        self._cleanup_stop = threading.Event()
        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop, daemon=True, name="agent-cleanup"
        )
        self._cleanup_thread.start()

    @classmethod
    def get_instance(cls, **kwargs) -> "AgentRunnerManager":
        """Get or create the singleton instance.

        ``kwargs`` are forwarded to the constructor only when the instance
        is first created; later calls ignore them.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls(**kwargs)
        return cls._instance

    @classmethod
    def clear_instance(cls):
        """Clear the singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None

    def create_session(
        self,
        user_id: str,
        instance_id: str,
        config: AgentConfig,
        screenshot_dir: str,
    ) -> AgentRunner:
        """
        Create a new agent session.

        Args:
            user_id: Annotator user ID
            instance_id: Annotation instance ID
            config: Agent configuration
            screenshot_dir: Directory to store screenshots

        Returns:
            AgentRunner instance

        Raises:
            RuntimeError: If max sessions reached or session already exists
        """
        session_key = f"{user_id}:{instance_id}"

        with self._lock:
            # Clean up expired sessions first
            self._cleanup_expired_locked()

            # Check for existing active session
            if session_key in self._sessions:
                existing = self._sessions[session_key]
                if existing.state in (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER):
                    raise RuntimeError(
                        f"Active session already exists for {session_key}. "
                        f"Stop it first."
                    )
                # Old completed/error session — remove it. pop() rather than
                # del so a missing bookkeeping entry cannot raise KeyError.
                self._sessions.pop(session_key, None)
                self._session_created.pop(session_key, None)
                self._session_meta.pop(session_key, None)

            # Check capacity
            active_count = sum(
                1
                for s in self._sessions.values()
                if s.state in (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER)
            )
            if active_count >= self.max_sessions:
                raise RuntimeError(
                    f"Maximum concurrent sessions ({self.max_sessions}) reached"
                )

            import uuid
            session_id = str(uuid.uuid4())[:12]
            runner = AgentRunner(session_id, config, screenshot_dir)

            self._sessions[session_key] = runner
            self._session_created[session_key] = time.time()
            self._session_meta[session_key] = {
                "user_id": user_id,
                "instance_id": instance_id,
                "session_id": session_id,
            }

        logger.info(
            f"Created agent session {session_id} for {session_key}"
        )
        return runner

    def get_session(self, session_id: str) -> Optional[AgentRunner]:
        """Get a session by its session_id (linear scan over all sessions)."""
        with self._lock:
            for runner in self._sessions.values():
                if runner.session_id == session_id:
                    return runner
        return None

    def get_session_by_key(self, user_id: str, instance_id: str) -> Optional[AgentRunner]:
        """Get a session by user_id and instance_id."""
        session_key = f"{user_id}:{instance_id}"
        with self._lock:
            return self._sessions.get(session_key)

    def remove_session(self, session_id: str):
        """Remove a session by session_id, stopping its runner.

        Does nothing if no session has that id.
        """
        with self._lock:
            key_to_remove = None
            for key, runner in self._sessions.items():
                if runner.session_id == session_id:
                    key_to_remove = key
                    break
            if key_to_remove:
                runner = self._sessions.pop(key_to_remove)
                self._session_created.pop(key_to_remove, None)
                self._session_meta.pop(key_to_remove, None)
                runner.stop()
                logger.info(f"Removed agent session {session_id}")

    def list_sessions(self) -> list:
        """List every tracked session (in any state) as a summary dict."""
        with self._lock:
            result = []
            for key, runner in self._sessions.items():
                meta = self._session_meta.get(key, {})
                result.append({
                    "session_id": runner.session_id,
                    "user_id": meta.get("user_id"),
                    "instance_id": meta.get("instance_id"),
                    "state": runner.state.value,
                    "step_count": runner.step_count,
                    "created": self._session_created.get(key),
                })
            return result

    def _cleanup_expired_locked(self):
        """Remove expired sessions. Must be called with self._lock held.

        A session older than ``session_ttl`` is dropped if it is completed,
        errored or idle. A session still active after ``2 * session_ttl`` is
        force-stopped and dropped. Timestamps with no matching runner are
        purged.
        """
        now = time.time()
        expired_keys = []
        for key, created_at in self._session_created.items():
            age = now - created_at
            if age <= self.session_ttl:
                continue
            runner = self._sessions.get(key)
            if runner is None:
                # Orphaned bookkeeping entry; nothing to stop.
                expired_keys.append(key)
            elif runner.state in (AgentState.COMPLETED, AgentState.ERROR, AgentState.IDLE):
                expired_keys.append(key)
            elif age > self.session_ttl * 2:
                # Force-stop sessions that have been running too long.
                # A failing stop() must not abort the sweep: this runs on
                # the background thread and inside create_session().
                try:
                    runner.stop()
                except Exception as e:
                    logger.warning(f"Error stopping expired session {key}: {e}")
                expired_keys.append(key)

        for key in expired_keys:
            self._sessions.pop(key, None)
            self._session_created.pop(key, None)
            self._session_meta.pop(key, None)
            logger.info(f"Cleaned up expired session: {key}")

    def _cleanup_loop(self):
        """Background cleanup thread."""
        while not self._cleanup_stop.is_set():
            self._cleanup_stop.wait(60)  # Check every 60 seconds
            if self._cleanup_stop.is_set():
                break
            try:
                with self._lock:
                    self._cleanup_expired_locked()
            except Exception as e:
                # Keep the sweeper alive; an unexpected error in one pass
                # should not disable TTL cleanup for the rest of the process.
                logger.warning(f"Agent session cleanup pass failed: {e}")

    def shutdown(self):
        """Stop all sessions and cleanup thread."""
        self._cleanup_stop.set()
        with self._lock:
            for key, runner in self._sessions.items():
                try:
                    runner.stop()
                except Exception as e:
                    logger.warning(f"Error stopping session {key}: {e}")
            self._sessions.clear()
            self._session_created.clear()
            self._session_meta.clear()
        logger.info("AgentRunnerManager shut down")
potato/agreement.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inter-Annotator Agreement Calculation Module
3
+
4
+ This module provides functionality for calculating inter-annotator agreement metrics,
5
+ including Krippendorff's alpha, Cohen's kappa (pairwise), and Fleiss' kappa
6
+ (N raters), from annotation data. It supports both rating agreement (interval
7
+ metric) and skip agreement (nominal metric) calculations.
8
+
9
+ The module processes annotation files in JSON format and outputs agreement statistics
10
+ along with a CSV file containing the processed annotation data.
11
+ """
12
+
13
+ import argparse
14
+ from itertools import combinations
15
+ import simpledorff
16
+ from simpledorff.metrics import *
17
+ import ujson
18
+ import pandas as pd
19
+
20
+ from collections import defaultdict
21
+ import numpy as np
22
+
23
+
24
def get_nans(shape):
    """
    Create a numpy array filled with NaN values.

    Args:
        shape: The shape of the array to create (int or tuple of ints)

    Returns:
        numpy.ndarray: float64 array of the given shape filled with NaN
    """
    # np.nan (lowercase) is the supported spelling; the np.NaN alias was
    # removed in NumPy 2.0 and raises AttributeError there.
    return np.full(shape, np.nan)
37
+
38
+
39
def cohen_kappa_pairwise(reliability_df):
    """
    Pairwise Cohen's kappa across all annotators, with an aggregate mean.

    Cohen's kappa only handles two raters, so for N>2 raters every pair is
    scored separately on the units both members labelled. Pairs sharing
    fewer than 2 units, or for which the score cannot be computed, are
    counted as skipped.

    Args:
        reliability_df: long-format DataFrame with columns
            unit (item id), annotator (user), annotation (label value).

    Returns:
        dict with keys: mean_kappa (float | None), pairs (list of
        {annotator_a, annotator_b, kappa, n_items}), n_pairs_evaluated,
        n_pairs_skipped. mean_kappa is the mean of the rounded per-pair
        values, or None when no pair could be evaluated.
    """
    from sklearn.metrics import cohen_kappa_score

    def _labels_of(rater):
        # Series of this rater's labels indexed by unit id.
        own_rows = reliability_df[reliability_df["annotator"] == rater]
        return own_rows.set_index("unit")["annotation"]

    raters = sorted(reliability_df["annotator"].unique())
    scored_pairs = []
    n_skipped = 0

    for first, second in combinations(raters, 2):
        first_labels = _labels_of(first)
        second_labels = _labels_of(second)
        common_units = first_labels.index.intersection(second_labels.index)
        if len(common_units) < 2:
            n_skipped += 1
            continue

        # Labels are compared as strings so mixed label types line up.
        first_values = first_labels.loc[common_units].astype(str).tolist()
        second_values = second_labels.loc[common_units].astype(str).tolist()
        try:
            score = float(cohen_kappa_score(first_values, second_values))
        except Exception:
            n_skipped += 1
            continue

        scored_pairs.append({
            "annotator_a": first,
            "annotator_b": second,
            "kappa": round(score, 4),
            "n_items": int(len(common_units)),
        })

    if scored_pairs:
        average = sum(entry["kappa"] for entry in scored_pairs) / len(scored_pairs)
        mean_value = round(average, 4)
    else:
        mean_value = None

    return {
        "mean_kappa": mean_value,
        "pairs": scored_pairs,
        "n_pairs_evaluated": len(scored_pairs),
        "n_pairs_skipped": n_skipped,
    }
91
+
92
+
93
def fleiss_kappa(reliability_df):
    """
    Fleiss' kappa for N raters over a categorical label set.

    Fleiss' kappa expects every item to have the same number of ratings.
    Items with fewer than 2 ratings are dropped; for the rest, each item's
    category counts are rescaled so they sum to ``n_raters`` (the largest
    number of ratings any kept item has). With uneven rating counts the
    result is therefore an approximation; ``n_raters`` and
    ``n_items_evaluated`` are reported so the caller can judge.

    Args:
        reliability_df: long-format DataFrame with columns
            unit (item id), annotator (user), annotation (label value).

    Returns:
        dict with keys: kappa (float | None), n_items_evaluated (int),
        n_raters (int), n_categories (int), interpretation (str).
    """

    def _no_kappa(n_items, n_raters, n_categories, reason):
        # Shared shape for every "could not compute" outcome.
        return {"kappa": None, "n_items_evaluated": n_items, "n_raters": n_raters,
                "n_categories": n_categories, "interpretation": reason}

    if reliability_df.empty:
        return _no_kappa(0, 0, 0, "No data")

    labelled = reliability_df.copy()
    labelled["annotation"] = labelled["annotation"].astype(str)

    # Rows: items, columns: categories, cells: number of ratings.
    tally = labelled.groupby(["unit", "annotation"]).size().unstack(fill_value=0)
    ratings_per_item = tally.sum(axis=1)
    tally = tally.loc[ratings_per_item >= 2]

    if tally.empty:
        return _no_kappa(0, 0, int(labelled["annotation"].nunique()),
                         "No items with >=2 raters")

    n_raters = int(tally.sum(axis=1).max())
    n_items = int(tally.shape[0])
    n_categories = int(tally.shape[1])

    counts = tally.to_numpy(dtype=float)
    item_totals = counts.sum(axis=1, keepdims=True)
    item_totals[item_totals == 0] = 1.0
    # Rescale each row so it sums to n_raters.
    counts = counts * (n_raters / item_totals)

    category_share = counts.sum(axis=0) / (n_items * n_raters)
    if n_raters < 2:
        return _no_kappa(n_items, n_raters, n_categories, "Need >=2 raters per item")

    per_item_agreement = (np.sum(counts ** 2, axis=1) - n_raters) / (n_raters * (n_raters - 1))
    observed = float(per_item_agreement.mean())
    expected = float(np.sum(category_share ** 2))

    if expected >= 1.0:
        # Only one category in use: chance agreement is total, so the usual
        # formula would divide by zero.
        kappa = 1.0 if observed >= 1.0 else 0.0
    else:
        kappa = (observed - expected) / (1 - expected)

    return {
        "kappa": round(float(kappa), 4),
        "n_items_evaluated": n_items,
        "n_raters": n_raters,
        "n_categories": n_categories,
        "interpretation": interpret_kappa(kappa),
    }
158
+
159
+
160
def interpret_kappa(kappa):
    """Map a kappa value to its Landis & Koch (1977) agreement band.

    Returns "No agreement computable" for None; otherwise the label of the
    first band whose upper bound exceeds ``kappa``, falling through to
    "Almost perfect".
    """
    if kappa is None:
        return "No agreement computable"

    # (exclusive upper bound, label), checked in ascending order.
    bands = (
        (0, "Worse than chance"),
        (0.21, "Slight"),
        (0.41, "Fair"),
        (0.61, "Moderate"),
        (0.81, "Substantial"),
    )
    for upper_bound, label in bands:
        if kappa < upper_bound:
            return label
    return "Almost perfect"
175
+
176
+
177
def flatten(annotations):
    """
    Collapse each item's annotation records into one user -> label mapping.

    Args:
        annotations: List with one entry per item; each entry is a list of
            dicts carrying "user" and "label" keys.

    Returns:
        list: One dict per item mapping user ID to that user's label. If a
        user appears more than once for an item, the last record wins.

    Example:
        Input: [[{"user": "user1", "label": "positive"}, {"user": "user2", "label": "negative"}]]
        Output: [{"user1": "positive", "user2": "negative"}]
    """
    flattened = []
    for item_records in annotations:
        labels_by_user = {}
        for record in item_records:
            labels_by_user[record["user"]] = record["label"]
        flattened.append(labels_by_user)
    return flattened
195
+
196
+
197
def main(args):
    """
    Main function for calculating inter-annotator agreement.

    Reads a JSON-lines file (one object per line with an "annotations" list
    of {"user", "label"} records), computes Krippendorff's alpha for rating
    agreement (interval metric, skipped items treated as missing) and for
    skip agreement (nominal metric on skipped / not-skipped), prints the
    results, and writes the rating matrix to a CSV file.

    Args:
        args: Parsed command line arguments with ``file`` (input path) and
            ``outfile`` (CSV output path).

    Side Effects:
        - Reads annotation data from input file
        - Prints agreement statistics to console
        - Writes processed data to output CSV file (one row per item, one
          column per user, users in sorted order)

    Only the first 385 items are used. A negative label marks a skipped
    item; -1 specifically is excluded from the rating matrix.
    """
    # Load annotation data; blank lines are skipped rather than fed to the
    # JSON parser.
    with open(args.file, "r") as f:
        annotations = [ujson.loads(line)["annotations"] for line in f if line.strip()]

    # Sort the user IDs so matrix rows and CSV columns come out in the same
    # order on every run (set iteration order is not stable across runs).
    users = sorted({a["user"] for ann in annotations for a in ann}, key=str)
    annotations = flatten(annotations)

    # Limit to first 385 annotations
    annotations = annotations[:385]

    # Rating matrix: one row per user, one column per item.
    # NaN marks both "not annotated by this user" and "skipped (-1)".
    data = [
        [np.nan if user not in a or int(a[user]) == -1 else int(a[user]) for a in annotations]
        for user in users
    ]

    # Skip matrix: 1.0 if the user skipped the item, 0.0 if they rated it,
    # NaN if they never saw it. Floats keep the frame numeric alongside NaN.
    skip_data = [
        [np.nan if user not in a else float(int(a[user]) < 0) for a in annotations]
        for user in users
    ]

    # Per-user counts
    labeled = ~np.isnan(data)
    skipped = [
        [False if user not in a else int(a[user]) < 0 for a in annotations] for user in users
    ]

    # Print summary statistics
    print("calculating over:")
    for user, skip in zip(labeled, skipped):
        print("labeled:", sum(user))
        print("skipped:", sum(skip))

    # Count instances where all users provided annotations
    print(np.all(labeled, axis=0).sum())

    # Krippendorff's alpha for rating agreement (interval metric)
    print("rating agreement:")
    print(simpledorff.calculate_krippendorffs_alpha(pd.DataFrame(data), metric_fn=interval_metric))

    # Krippendorff's alpha for skip agreement (nominal metric). This must be
    # computed on the skip matrix, not the rating matrix.
    print("skip agreement:")
    print(simpledorff.calculate_krippendorffs_alpha(pd.DataFrame(skip_data), metric_fn=nominal_metric))

    # Write the rating matrix to CSV, transposed so each line is one item.
    with open(args.outfile, "w") as f:
        for row in zip(*data):
            f.write(",".join([str(a) for a in row]) + "\n")
269
+
270
+
271
if __name__ == "__main__":
    # Command-line entry point: positional input path and CSV output path.
    parser = argparse.ArgumentParser(
        description="Calculate Krippendorff's alpha from given JSON file of annotations"
    )
    parser.add_argument("file", help="path to JSON file")
    parser.add_argument("outfile", help="write path to CSV")
    main(parser.parse_args())
potato/ai/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .ai_help_wrapper import generate_ai_help_html
potato/ai/ai_cache.py ADDED
@@ -0,0 +1,1473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json
3
+ import logging
4
+ import os
5
+ from typing import Any, Dict, Union
6
+ import requests
7
+ from tqdm import tqdm
8
+ import time
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ import threading
11
+ from builtins import open
12
+ from potato.server_utils.config_module import config
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ from potato.item_state_management import get_item_state_manager
17
+ from potato.ai.ai_endpoint import (
18
+ AIEndpointFactory,
19
+ Annotation_Type,
20
+ AnnotationInput,
21
+ ImageData,
22
+ VisualAnnotationInput,
23
+ ModelCapabilities,
24
+ )
25
+ from potato.ai.ollama_endpoint import OllamaEndpoint
26
+ from potato.ai.openrouter_endpoint import OpenRouterEndpoint
27
+ from potato.ai.ai_prompt import ModelManager, get_ai_prompt
28
+
29
+
30
+ AICACHEMANAGER = None
31
+
32
+
33
+ def _get_scheme_field(annotation_id: int, field: str, default=None):
34
+ """Safely get a field from an annotation scheme with a clear error message."""
35
+ schemes = config.get("annotation_schemes", [])
36
+ if annotation_id >= len(schemes):
37
+ raise ValueError(
38
+ f"AI cache: annotation_id {annotation_id} out of range "
39
+ f"(only {len(schemes)} scheme(s) configured)"
40
+ )
41
+ scheme = schemes[annotation_id]
42
+ if default is not None:
43
+ return scheme.get(field, default)
44
+ if field not in scheme:
45
+ scheme_name = scheme.get("name", f"index {annotation_id}")
46
+ scheme_type = scheme.get("annotation_type", "unknown")
47
+ raise ValueError(
48
+ f"AI cache: annotation scheme '{scheme_name}' (type '{scheme_type}') "
49
+ f"missing required field '{field}'"
50
+ )
51
+ return scheme[field]
52
+
53
+
54
+ def _get_instance_text(instance_id: int) -> str:
55
+ """Get the text content from an instance using the configured text_key."""
56
+ item = get_item_state_manager().items()[instance_id]
57
+ item_data = item.get_data()
58
+
59
+ # Get the configured text_key
60
+ text_key = config.get("item_properties", {}).get("text_key", "text")
61
+
62
+ # Try the configured text_key first
63
+ if text_key in item_data:
64
+ return item_data[text_key]
65
+
66
+ # Fall back to common keys
67
+ for key in ['text', 'content', 'message']:
68
+ if key in item_data:
69
+ return item_data[key]
70
+
71
+ # Last resort: return any string value
72
+ for value in item_data.values():
73
+ if isinstance(value, str):
74
+ return value
75
+
76
+ return str(item_data)
77
+
78
+ def _is_image_url(text: str) -> bool:
79
+ """Check if text appears to be an image URL."""
80
+ if not isinstance(text, str):
81
+ return False
82
+ text_lower = text.lower()
83
+ # Check for image extensions
84
+ image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp']
85
+ if any(ext in text_lower for ext in image_extensions):
86
+ return True
87
+ # Check for common image hosting services
88
+ image_hosts = ['unsplash.com', 'imgur.com', 'flickr.com', 'picsum.photos']
89
+ if any(host in text_lower for host in image_hosts):
90
+ return True
91
+ # Check if URL starts with http and might be an image
92
+ if text_lower.startswith(('http://', 'https://')) and 'image' in text_lower:
93
+ return True
94
+ return False
95
+
96
+ def _get_image_data_from_url(url: str) -> ImageData:
97
+ """Download image from URL and return as ImageData.
98
+
99
+ Includes SSRF protection to prevent fetching from private/internal IPs.
100
+ """
101
+ import base64
102
+ import ipaddress
103
+ import socket
104
+ from urllib.parse import urlparse
105
+
106
+ # SSRF protection: validate URL scheme and resolve hostname
107
+ try:
108
+ parsed = urlparse(url)
109
+ if parsed.scheme not in ('http', 'https'):
110
+ logger.warning(f"Blocked non-HTTP image URL: {url[:100]}")
111
+ return None
112
+
113
+ hostname = parsed.hostname
114
+ if hostname:
115
+ addr_info = socket.getaddrinfo(hostname, None)
116
+ for info in addr_info:
117
+ ip_str = info[4][0]
118
+ try:
119
+ ip = ipaddress.ip_address(ip_str)
120
+ if ip.is_private or ip.is_loopback or ip.is_link_local:
121
+ logger.warning(
122
+ f"Blocked image URL resolving to private IP: "
123
+ f"{hostname} -> {ip_str}"
124
+ )
125
+ return None
126
+ except ValueError:
127
+ pass
128
+ except Exception as e:
129
+ logger.warning(f"Failed to validate image URL {url[:100]}: {e}")
130
+ return None
131
+
132
+ try:
133
+ response = requests.get(url, timeout=30)
134
+ response.raise_for_status()
135
+ b64_data = base64.b64encode(response.content).decode('utf-8')
136
+ # Determine mime type from content-type header or URL
137
+ content_type = response.headers.get('content-type', 'image/jpeg')
138
+ return ImageData(source='base64', data=b64_data, mime_type=content_type)
139
+ except Exception as e:
140
+ logger.error(f"Failed to download image from {url}: {e}")
141
+ return None
142
+
143
def init_ai_cache_manager():
    """Create the module-level ``AiCacheManager`` on first call and return it.

    Later calls return the already-created instance unchanged.
    """
    global AICACHEMANAGER
    if AICACHEMANAGER is not None:
        return AICACHEMANAGER

    AICACHEMANAGER = AiCacheManager()
    return AICACHEMANAGER
149
+
150
def get_ai_cache_manager():
    """Return the shared AI cache manager, or None when it was never initialized
    (AI support disabled)."""
    # Read-only access to the module global; no ``global`` statement required.
    return AICACHEMANAGER
154
+
155
def clear_ai_cache_manager():
    """Drop the shared AI cache manager so the next init builds a fresh one.

    Intended for tests.
    """
    global AICACHEMANAGER
    AICACHEMANAGER = None
159
+
160
+ class AiCacheManager:
161
def __init__(self):
    """Configure the cache manager from the global ``config`` and warm it up.

    Reads ``config["ai_support"]``. When AI support is disabled the
    constructor returns immediately and none of the attributes below are set.
    Otherwise it loads include rules, disk-cache and prefetch settings,
    option-highlighting settings, registers the AI endpoints, creates the
    main (and optional visual) endpoint, and finally runs ``start_warmup``.

    Raises:
        KeyError: If ``config`` has no ``ai_support`` block or it lacks ``enabled``.
        Exception: If the disk cache is enabled without a ``path``.
        ValueError: If the cache path resolves outside the task directory.
    """
    ai_support = config["ai_support"]
    if not ai_support["enabled"]:
        return
    cache_config = ai_support.get("cache_config", {})
    # ``or {}`` tolerates blocks that are present but null in the config.
    ai_config = ai_support.get("ai_config") or {}
    include = ai_config.get("include") or {}
    special_include = include.get("special_include", None)
    self.include_all = include.get("all", False)
    self.special_includes = {}

    self.model_manager = ModelManager()
    self.model_manager.load_models_module()

    if special_include:
        for page_key, page_value in special_include.items():
            # Convert string keys to integers for easier lookup
            page_index = int(page_key)
            self.special_includes[page_index] = {}
            for annotation_id, annotation_types in page_value.items():
                self.special_includes[page_index][int(annotation_id)] = annotation_types

    # Disk cache configuration.
    # F-028: tolerate a partial/absent ai_cache config (e.g. AI support
    # enabled for ICL with no disk_cache block) instead of crashing boot
    # with KeyError: 'disk_cache'.
    disk_cache_cfg = cache_config.get("disk_cache", {}) if isinstance(cache_config, dict) else {}
    self.disk_cache_enabled = disk_cache_cfg.get("enabled", False)

    disk_cache_path = disk_cache_cfg.get("path")
    if self.disk_cache_enabled and not disk_cache_path:
        raise Exception("You have enable disk cache, but you did not specific the path!")
    self.disk_persistence_path = disk_cache_path

    # Validate cache path stays within task directory.
    # NOTE(review): only the resolved path is validated; the stored
    # ``disk_persistence_path`` stays as configured, so a relative path is
    # later opened relative to the process working directory — confirm that
    # this matches ``task_dir`` in deployment.
    if self.disk_persistence_path:
        task_dir = os.path.abspath(config.get("task_dir", "."))
        cache_abs = os.path.abspath(
            os.path.join(task_dir, self.disk_persistence_path)
            if not os.path.isabs(self.disk_persistence_path)
            else self.disk_persistence_path
        )
        if not cache_abs.startswith(task_dir + os.sep) and cache_abs != task_dir:
            raise ValueError(
                f"Cache path '{self.disk_persistence_path}' resolves to "
                f"'{cache_abs}' which is outside the task directory "
                f"'{task_dir}'. Path traversal is not allowed."
            )

    # Prefetch configuration — clamp to sane ranges.
    # F-028: default to no prefetch when the prefetch block is absent
    # (e.g. cache_config: {enabled: false}) instead of KeyError on boot.
    prefetch_cfg = cache_config.get("prefetch", {}) if isinstance(cache_config, dict) else {}
    self.warm_up_page_count = max(0, min(int(prefetch_cfg.get("warm_up_page_count", 0)), 10000))
    self.prefetch_page_count_on_next = max(0, min(int(prefetch_cfg.get("on_next", 0)), 10000))
    self.prefetch_page_count_on_prev = max(0, min(int(prefetch_cfg.get("on_prev", 0)), 10000))

    # Option highlighting configuration
    option_highlighting = ai_support.get("option_highlighting") or {}
    self.option_highlighting_enabled = option_highlighting.get("enabled", False)
    self.option_highlighting_top_k = option_highlighting.get("top_k", 3)
    self.option_highlighting_dim_opacity = option_highlighting.get("dim_opacity", 0.4)
    self.option_highlighting_auto_apply = option_highlighting.get("auto_apply", True)
    self.option_highlighting_schemas = option_highlighting.get("schemas", None)  # None means all
    # Prefetch count for option highlighting — clamp to sane range
    self.option_highlighting_prefetch_count = max(0, min(
        int(option_highlighting.get("prefetch_count", 20)), 10000
    ))

    # Threading
    self.in_progress = {}
    self.lock = threading.RLock()
    self.executor = ThreadPoolExecutor(max_workers=20)

    AIEndpointFactory.register_endpoint("ollama", OllamaEndpoint)
    AIEndpointFactory.register_endpoint("open_router", OpenRouterEndpoint)

    # Register visual AI endpoints; each is optional and skipped when its
    # module cannot be imported.
    try:
        from potato.ai.yolo_endpoint import YOLOEndpoint
        AIEndpointFactory.register_endpoint("yolo", YOLOEndpoint)
    except ImportError:
        logger.debug("YOLO endpoint not available (ultralytics not installed)")

    try:
        from potato.ai.ollama_vision_endpoint import OllamaVisionEndpoint
        AIEndpointFactory.register_endpoint("ollama_vision", OllamaVisionEndpoint)
    except ImportError:
        logger.debug("Ollama Vision endpoint not available")

    try:
        from potato.ai.openai_vision_endpoint import OpenAIVisionEndpoint
        AIEndpointFactory.register_endpoint("openai_vision", OpenAIVisionEndpoint)
    except ImportError:
        logger.debug("OpenAI Vision endpoint not available")

    try:
        from potato.ai.anthropic_vision_endpoint import AnthropicVisionEndpoint
        AIEndpointFactory.register_endpoint("anthropic_vision", AnthropicVisionEndpoint)
    except ImportError:
        logger.debug("Anthropic Vision endpoint not available")

    # Degrade gracefully if the AI backend (e.g. a local Ollama/vLLM server)
    # is unreachable at boot: log a warning and serve the task with AI
    # support disabled rather than aborting server startup.
    try:
        self.ai_endpoint = AIEndpointFactory.create_endpoint(config)
    except Exception as e:
        logger.warning(
            "AI endpoint unavailable at startup (%s). Continuing with AI "
            "support disabled. Check that your AI backend is running.", e
        )
        self.ai_endpoint = None

    # Create visual endpoint if different from main endpoint
    self.visual_endpoint = None
    visual_endpoint_type = config.get("ai_support", {}).get("visual_endpoint_type")
    if visual_endpoint_type and visual_endpoint_type != config.get("ai_support", {}).get("endpoint_type"):
        visual_config = {
            "ai_support": {
                "enabled": True,
                "endpoint_type": visual_endpoint_type,
                # Falls back to the main ai_config when no visual one is given.
                "ai_config": config.get("ai_support", {}).get(
                    "visual_ai_config", config.get("ai_support", {}).get("ai_config", {})
                ),
            }
        }
        try:
            self.visual_endpoint = AIEndpointFactory.create_endpoint(visual_config)
        except Exception as e:
            logger.warning(
                "Visual AI endpoint unavailable at startup (%s). Continuing "
                "without visual AI support.", e
            )
            self.visual_endpoint = None

    # A missing or null ``annotation_schemes`` yields an empty list instead
    # of a TypeError, matching how _get_scheme_field reads the same key.
    self.annotations = list(config.get("annotation_schemes") or [])

    # Check if main endpoint supports vision
    self.endpoint_supports_vision = hasattr(self.ai_endpoint, 'query_with_image')
    logger.info(f"AI endpoint supports vision: {self.endpoint_supports_vision}")

    # Initialize cache
    if self.disk_cache_enabled:
        self.load_cache_from_disk()
    self.start_warmup()
309
+
310
+ def _validate_assistant_compatibility(
311
+ self, instance_id: int, annotation_id: int, ai_assistant: str
312
+ ) -> tuple:
313
+ """
314
+ Validate that the AI assistant is compatible with the input type and model capabilities.
315
+
316
+ Args:
317
+ instance_id: The instance/item index
318
+ annotation_id: The annotation scheme index
319
+ ai_assistant: Type of assistance ('hint', 'keyword', 'rationale', 'detection', etc.)
320
+
321
+ Returns:
322
+ Tuple of (is_valid: bool, error_message: str)
323
+ If valid, error_message is empty string.
324
+ """
325
+ try:
326
+ text = _get_instance_text(instance_id)
327
+ is_image = _is_image_url(text)
328
+
329
+ # Determine which endpoint to use
330
+ if is_image and self.visual_endpoint:
331
+ endpoint = self.visual_endpoint
332
+ elif is_image and self.endpoint_supports_vision:
333
+ endpoint = self.ai_endpoint
334
+ else:
335
+ endpoint = self.ai_endpoint
336
+
337
+ # Get capabilities from endpoint
338
+ capabilities = getattr(endpoint, 'CAPABILITIES', None)
339
+
340
+ if capabilities is None:
341
+ # No capabilities declared - allow all (backward compatibility)
342
+ logger.debug(f"Endpoint {type(endpoint).__name__} has no CAPABILITIES, allowing {ai_assistant}")
343
+ return True, ""
344
+
345
+ # Check if the assistant type is supported
346
+ if not capabilities.supports_assistant(ai_assistant, is_image):
347
+ input_type = "image" if is_image else "text"
348
+ return False, (
349
+ f"Model {type(endpoint).__name__} does not support '{ai_assistant}' "
350
+ f"for {input_type} content"
351
+ )
352
+
353
+ return True, ""
354
+
355
+ except Exception as e:
356
+ logger.warning(f"Error validating assistant compatibility: {e}")
357
+ # On validation error, allow the request (fail open for now)
358
+ return True, ""
359
+
360
def get_endpoint_capabilities(self, for_image: bool = False) -> ModelCapabilities:
    """
    Return the capabilities of the endpoint that would handle the input.

    Args:
        for_image: Whether the input is an image

    Returns:
        The endpoint's declared ``CAPABILITIES``, or a permissive default
        ``ModelCapabilities`` when the endpoint declares none.
    """
    # The visual endpoint is chosen only for images and only if configured.
    use_visual = for_image and self.visual_endpoint
    endpoint = self.visual_endpoint if use_visual else self.ai_endpoint

    declared = getattr(endpoint, 'CAPABILITIES', None)
    if declared is not None:
        return declared

    # Return permissive defaults for backward compatibility
    return ModelCapabilities(
        text_generation=True,
        vision_input=for_image,
        bounding_box_output=False,
        text_classification=True,
        image_classification=for_image,
        rationale_generation=True,
        keyword_extraction=not for_image,
    )
390
+
391
+ def _get_ai_with_vision_support(self, text: str, prompt: str, output_format) -> str:
392
+ """
393
+ Get AI response, using vision if text is an image URL and endpoint supports it.
394
+ """
395
+ # Check if we should use vision
396
+ if self.endpoint_supports_vision and _is_image_url(text):
397
+ logger.debug(f"Using vision query for image URL: {text[:50]}...")
398
+ image_data = _get_image_data_from_url(text)
399
+ if image_data:
400
+ try:
401
+ return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
402
+ except Exception as e:
403
+ logger.error(f"Vision query failed: {e}")
404
+ # Fall back to text query
405
+
406
+ # Fall back to regular text query
407
+ return self.ai_endpoint.query(prompt, output_format)
408
+
409
def start_warmup(self):
    """Prefetch the first ``warm_up_page_count`` pages and block until done.

    Also prefetches option highlights when that feature is enabled. Progress
    is shown with a tqdm bar driven by how many entries have left
    ``self.in_progress``, polled every 0.2 seconds.
    """
    self.start_prefetch(0, self.warm_up_page_count)

    # Also prefetch option highlights if enabled
    if self.option_highlighting_enabled:
        self.start_option_highlight_prefetch(0, self.warm_up_page_count)

    total = len(self.in_progress)
    progress_bar = tqdm(total=total, desc="Preloading the AI", unit="item")

    reported = 0
    while self.in_progress:
        finished = total - len(self.in_progress)
        progress_bar.update(finished - reported)
        reported = finished
        time.sleep(0.2)

    # Account for anything that completed after the last poll.
    finished = total - len(self.in_progress)
    if finished > reported:
        progress_bar.update(finished - reported)

    progress_bar.close()
436
+
437
def load_disk_cache_data(self, file_path: str) -> Dict[str, Any]:
    """Load the cache JSON from disk as a dictionary of stringified keys to values.

    Args:
        file_path: Path of the JSON cache file.

    Returns:
        The parsed cache dict. An empty dict is returned (and the problem
        logged) when the file cannot be read or parsed, or when its top-level
        JSON value is not an object.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"Error loading disk cache: {e}")
        return {}

    # Valid JSON that is not an object (null, list, ...) would break the
    # callers' len(), membership and item-assignment operations.
    if not isinstance(data, dict):
        logger.error(
            f"Error loading disk cache: expected a JSON object, "
            f"got {type(data).__name__}"
        )
        return {}
    return data
445
+
446
def load_cache_from_disk(self):
    """Make sure the disk cache file exists, creating an empty one if needed.

    Does nothing when the disk cache is disabled or no path is configured.
    An existing file is loaded only to log how many entries it holds.
    Failures while creating the file are logged, not raised.
    """
    if not self.disk_cache_enabled or not self.disk_persistence_path:
        return

    if os.path.exists(self.disk_persistence_path):
        data = self.load_disk_cache_data(self.disk_persistence_path)
        logger.info(f"Disk cache initialized with {len(data)} items")
        return

    try:
        # Create parent directory if it doesn't exist. A bare filename has an
        # empty dirname, and os.makedirs("") raises, so skip it in that case.
        parent_dir = os.path.dirname(self.disk_persistence_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(self.disk_persistence_path, 'w', encoding='utf-8') as file:
            json.dump({}, file)
        logger.info(f"Initialized empty disk cache at {self.disk_persistence_path}")
    except Exception as e:
        logger.error(f"Failed to create disk cache: {e}")
463
+
464
def save_cache_to_disk(self, key, value):
    """Persist a single key-value pair to the disk cache using an atomic write.

    The whole cache file is re-read, updated with ``str(key) -> value``,
    written to ``<path>.tmp`` and then moved over the real file. Does nothing
    when the disk cache is disabled or no path is configured. Errors are
    logged, not raised.

    Args:
        key: Cache key; stored under its ``str()`` form.
        value: JSON-serializable value to store.
    """
    if not self.disk_cache_enabled or not self.disk_persistence_path:
        return

    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # which previously prevented anything from being saved.
        parent_dir = os.path.dirname(self.disk_persistence_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Load existing disk data first
        existing_disk_data = {}
        if os.path.exists(self.disk_persistence_path):
            existing_disk_data = self.load_disk_cache_data(self.disk_persistence_path)

        # Add the new key-value pair
        existing_disk_data[str(key)] = value

        temp_path = self.disk_persistence_path + ".tmp"
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(existing_disk_data, f, indent=2, ensure_ascii=False)
        # os.replace overwrites an existing destination on every platform;
        # os.rename fails on Windows when the target already exists.
        os.replace(temp_path, self.disk_persistence_path)
    except Exception as e:
        logger.error(f"Error saving cache to disk: {e}")
486
+
487
def add_to_cache(self, key, value):
    """Store ``key -> value`` in the disk cache while holding the manager lock.

    Nothing is stored when the disk cache is disabled.
    """
    with self.lock:
        if not self.disk_cache_enabled:
            return
        self.save_cache_to_disk(key, value)
492
+
493
def get_from_cache(self, key):
    """Look up ``str(key)`` in the disk cache; return None on a miss.

    Also returns None when the disk cache is disabled, has no path, the file
    does not exist, or reading it fails. Runs under the manager lock.
    """
    with self.lock:
        cache_available = (
            self.disk_cache_enabled
            and self.disk_persistence_path
            and os.path.exists(self.disk_persistence_path)
        )
        if not cache_available:
            return None

        try:
            entries = self.load_disk_cache_data(self.disk_persistence_path)
            return entries.get(str(key))
        except Exception as e:
            logger.error(f"Error reading from disk: {e}")
            return None
506
+
507
def generate_likert(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a likert-scale annotation scheme.

    When the main endpoint supports vision and the instance text looks like
    an image URL, the image is downloaded and a vision prompt is sent. If the
    image cannot be fetched or the vision query fails, the request falls back
    to the standard text-based generation.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('hint', 'rationale', 'keyword', ...)

    Returns:
        Whatever the endpoint returns for the query.

    Raises:
        ValueError: If the scheme lacks a required field (``annotation_type``,
            ``description``, ``min_label``, ``max_label`` or ``size``).
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)
    min_label = _get_scheme_field(annotation_id, "min_label")
    max_label = _get_scheme_field(annotation_id, "max_label")
    size = _get_scheme_field(annotation_id, "size")

    ai_prompt = get_ai_prompt()
    output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))

    # Check if we should use vision endpoint for image-based content
    if self.endpoint_supports_vision and _is_image_url(text):
        logger.debug(f"Using vision for likert {ai_assistant} on image: {text[:50]}...")
        image_data = _get_image_data_from_url(text)
        if image_data:
            # Build vision-specific prompts based on ai_assistant type
            if ai_assistant == "hint":
                prompt = f"""Look at this image and help with the following annotation task:

Task: {description}
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})

Please analyze the image and suggest an appropriate rating with a brief explanation.
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choice": "<rating label>"}}"""
            elif ai_assistant == "rationale":
                prompt = f"""Look at this image and explain the reasoning for different rating choices:

Task: {description}
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})

For each possible rating, explain what visual evidence in the image would support that rating.
Respond in JSON format: {{"rationales": [{{"label": "<rating>", "reasoning": "<explanation>"}}]}}"""
            elif ai_assistant == "keyword":
                prompt = f"""Look at this image and identify visual features relevant to the rating task:

Task: {description}
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})

Identify key visual elements that would influence the rating.
Respond in JSON format: {{"keywords": ["<visual_feature_1>", "<visual_feature_2>"]}}"""
            else:
                prompt = f"Analyze this image for: {description}"

            try:
                return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
            except Exception as e:
                # Logged, then handled by the text fallback below.
                logger.error(f"Vision query failed for likert {ai_assistant}: {e}")

    # Fall back to standard text-based generation
    data = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        min_label=min_label,
        max_label=max_label,
        size=size
    )
    res = self.ai_endpoint.get_ai(data, output_format)
    return res
569
+
570
def generate_multiselect(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a multiselect annotation scheme.

    Uses a vision prompt when the main endpoint supports vision and the
    instance text looks like an image URL that can be downloaded; otherwise,
    or if the vision query fails, uses the standard text-based generation.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('hint', 'rationale', 'keyword', ...)

    Returns:
        Whatever the endpoint returns for the query.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    labels = _get_scheme_field(annotation_id, "labels")
    text = _get_instance_text(instance_id)

    prompt_spec = get_ai_prompt()[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))

    # Only attempt a download when vision is possible for this item.
    image_data = None
    if self.endpoint_supports_vision and _is_image_url(text):
        logger.debug(f"Using vision for multiselect {ai_assistant} on image: {text[:50]}...")
        image_data = _get_image_data_from_url(text)

    if image_data:
        # Labels may be plain values or dicts carrying a 'name'.
        label_names = [
            entry.get('name', entry) if isinstance(entry, dict) else entry
            for entry in labels
        ]
        labels_str = ', '.join(f'"{name}"' for name in label_names)

        # One vision prompt per assistant type, with a generic fallback.
        vision_prompts = {
            "hint": f"""Look at this image and help with the following annotation task:

Task: {description}
Available options (select all that apply): {labels_str}

Please analyze the image and suggest which options apply.
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choices": ["<option1>", "<option2>"]}}""",
            "rationale": f"""Look at this image and explain the reasoning for each option:

Task: {description}
Available options: {labels_str}

For each option, explain what visual evidence supports or contradicts it.
Respond in JSON format: {{"rationales": [{{"label": "<option>", "reasoning": "<explanation>"}}]}}""",
            "keyword": f"""Look at this image and identify visual features for each option:

Task: {description}
Available options: {labels_str}

For each option, identify visual cues that indicate its presence.
Respond in JSON format: {{"label_keywords": [{{"label": "<option>", "keywords": ["<feature1>", "<feature2>"]}}]}}""",
        }
        prompt = vision_prompts.get(
            ai_assistant,
            f"Analyze this image for: {description}. Options: {labels_str}",
        )

        try:
            return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
        except Exception as e:
            logger.error(f"Vision query failed for multiselect {ai_assistant}: {e}")

    # Fall back to standard text-based generation
    return self.ai_endpoint.get_ai(
        AnnotationInput(
            ai_assistant=ai_assistant,
            annotation_type=annotation_type,
            text=text,
            description=description,
            labels=labels
        ),
        output_format,
    )
631
+
632
def generate_radio(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a radio (single-choice) annotation scheme.

    Uses a vision prompt when the main endpoint supports vision and the
    instance text looks like an image URL that can be downloaded; otherwise,
    or if the vision query fails, uses the standard text-based generation.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('hint', 'rationale', 'keyword', ...)

    Returns:
        Whatever the endpoint returns for the query.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)
    labels = _get_scheme_field(annotation_id, "labels")

    prompt_spec = get_ai_prompt()[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))

    # Only attempt a download when vision is possible for this item.
    image_data = None
    if self.endpoint_supports_vision and _is_image_url(text):
        logger.debug(f"Using vision for radio {ai_assistant} on image: {text[:50]}...")
        image_data = _get_image_data_from_url(text)

    if image_data:
        # Labels may be plain values or dicts carrying a 'name'.
        label_names = [
            entry.get('name', entry) if isinstance(entry, dict) else entry
            for entry in labels
        ]
        labels_str = ', '.join(f'"{name}"' for name in label_names)

        # One vision prompt per assistant type, with a generic fallback.
        vision_prompts = {
            "hint": f"""Look at this image and help with the following annotation task:

Task: {description}
Available options: {labels_str}

Please analyze the image and suggest the most appropriate option.
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choice": "<selected option>"}}""",
            "rationale": f"""Look at this image and explain the reasoning for each option:

Task: {description}
Available options: {labels_str}

For each option, explain what visual evidence in the image supports or contradicts it.
Respond in JSON format: {{"rationales": [{{"label": "<option>", "reasoning": "<explanation>"}}]}}""",
            "keyword": f"""Look at this image and identify visual features for each option:

Task: {description}
Available options: {labels_str}

For each option, identify visual cues that would indicate its presence.
Respond in JSON format: {{"label_keywords": [{{"label": "<option>", "keywords": ["<feature1>", "<feature2>"]}}]}}""",
        }
        prompt = vision_prompts.get(
            ai_assistant,
            f"Analyze this image for: {description}. Options: {labels_str}",
        )

        try:
            return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
        except Exception as e:
            logger.error(f"Vision query failed for radio {ai_assistant}: {e}")

    # Fall back to standard text-based generation
    return self.ai_endpoint.get_ai(
        AnnotationInput(
            ai_assistant=ai_assistant,
            annotation_type=annotation_type,
            text=text,
            description=description,
            labels=labels
        ),
        output_format,
    )
693
+
694
def generate_number(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a number annotation scheme (text only).

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance, used to pick the prompt's output format

    Returns:
        Whatever the endpoint's ``get_ai`` returns.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
    )
    prompt_spec = get_ai_prompt()[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))
    return self.ai_endpoint.get_ai(request, output_format)
709
+
710
def generate_select(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a select (dropdown) annotation scheme (text only).

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance, used to pick the prompt's output format

    Returns:
        Whatever the endpoint's ``get_ai`` returns.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    labels = _get_scheme_field(annotation_id, "labels")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        labels=labels
    )
    prompt_spec = get_ai_prompt()[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))
    return self.ai_endpoint.get_ai(request, output_format)
728
+
729
def generate_slider(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a slider annotation scheme (text only).

    ``min_value`` and ``max_value`` are required scheme fields; ``step``
    defaults to 1 when the scheme does not define it.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance, used to pick the prompt's output format

    Returns:
        Whatever the endpoint's ``get_ai`` returns.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    min_value = _get_scheme_field(annotation_id, "min_value")
    max_value = _get_scheme_field(annotation_id, "max_value")
    step = _get_scheme_field(annotation_id, "step", default=1)
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        min_value=min_value,
        max_value=max_value,
        step=step
    )
    prompt_spec = get_ai_prompt()[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))
    return self.ai_endpoint.get_ai(request, output_format)
751
+
752
def generate_span(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a span annotation scheme (text only).

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance, used to pick the prompt's output format

    Returns:
        Whatever the endpoint's ``get_ai`` returns.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    labels = _get_scheme_field(annotation_id, "labels")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        labels=labels
    )
    prompt_table = get_ai_prompt()
    logger.debug(f"Generating span annotation with labels: {labels}")
    prompt_spec = prompt_table[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))
    return self.ai_endpoint.get_ai(request, output_format)
770
+
771
def generate_textbox(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a textbox annotation scheme (text only).

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance, used to pick the prompt's output format

    Returns:
        Whatever the endpoint's ``get_ai`` returns.
    """
    logger.debug(f"Generating textbox for annotation_id: {annotation_id}")
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
    )
    prompt_spec = get_ai_prompt()[annotation_type].get(ai_assistant)
    output_format = self.model_manager.get_model_class_by_name(prompt_spec.get("output_format"))
    return self.ai_endpoint.get_ai(request, output_format)
787
+
788
def generate_image_annotation(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
    """Produce AI suggestions for an image annotation scheme.

    Args:
        instance_id: Index of the item to analyse.
        annotation_id: Index of the annotation scheme.
        ai_assistant: Kind of help requested ('detection', 'classification',
            'hint', 'pre_annotate', etc.); also forwarded as the task type.

    Returns:
        The visual endpoint's result, a text-only hint dict when the chosen
        endpoint has no ``query_with_image`` attribute, or ``{"error": ...}``
        when no image URL or no endpoint is available.
    """
    logger.debug(f"Generating image annotation for instance={instance_id}, annotation={annotation_id}, assistant={ai_assistant}")

    scheme_type = _get_scheme_field(annotation_id, "annotation_type")
    task_description = _get_scheme_field(annotation_id, "description", default="")
    label_names = _get_scheme_field(annotation_id, "labels", default=[])

    # Labels may be configured as dicts; reduce them to their display names.
    if label_names and isinstance(label_names[0], dict):
        label_names = [entry.get("name", str(entry)) for entry in label_names]

    # Locate the image referenced by the item.
    item = get_item_state_manager().items()[instance_id]
    image_url = self._extract_image_url(item.get_data())
    if not image_url:
        return {"error": "No image URL found in instance data"}

    # Pick an endpoint capable of handling the request.
    visual_endpoint = self._get_visual_endpoint()
    if not visual_endpoint:
        return {"error": "No visual AI endpoint configured"}

    # Without image support, degrade to a generic textual hint.
    if not hasattr(visual_endpoint, 'query_with_image'):
        return self._generate_text_hint_for_visual(instance_id, annotation_id, ai_assistant)

    encoded_image = self._prepare_image_data(image_url)

    # Per-scheme detection threshold, defaulting to 0.5.
    ai_settings = _get_scheme_field(annotation_id, "ai_support", default={})
    min_confidence = ai_settings.get("confidence_threshold", 0.5)

    request = VisualAnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=scheme_type,
        task_type=ai_assistant,  # detection, classification, hint, etc.
        image_data=encoded_image,
        description=task_description,
        labels=label_names,
        confidence_threshold=min_confidence,
    )

    # Resolve the response model named by the prompt configuration.
    assistant_prompts = get_ai_prompt().get(scheme_type, {})
    format_name = assistant_prompts.get(ai_assistant, {}).get("output_format", "visual_detection")
    response_model = self.model_manager.get_model_class_by_name(format_name)

    return visual_endpoint.get_visual_ai(request, response_model)
854
+
855
def generate_video_annotation(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
    """Produce AI suggestions for a video annotation scheme.

    Args:
        instance_id: Index of the item to analyse.
        annotation_id: Index of the annotation scheme.
        ai_assistant: Kind of help requested ('scene_detection',
            'frame_classification', etc.); also forwarded as the task type.

    Returns:
        The visual endpoint's result, a text-only hint dict when the chosen
        endpoint has no ``query_with_image`` attribute, or ``{"error": ...}``
        when the video URL, the endpoint, or frame extraction is unavailable.
    """
    logger.debug(f"Generating video annotation for instance={instance_id}, annotation={annotation_id}, assistant={ai_assistant}")

    scheme_type = _get_scheme_field(annotation_id, "annotation_type")
    task_description = _get_scheme_field(annotation_id, "description", default="")
    label_names = _get_scheme_field(annotation_id, "labels", default=[])

    # Labels may be configured as dicts; reduce them to their display names.
    if label_names and isinstance(label_names[0], dict):
        label_names = [entry.get("name", str(entry)) for entry in label_names]

    # Locate the video referenced by the item.
    item = get_item_state_manager().items()[instance_id]
    video_url = self._extract_video_url(item.get_data())
    if not video_url:
        return {"error": "No video URL found in instance data"}

    # Pick an endpoint capable of handling the request.
    visual_endpoint = self._get_visual_endpoint()
    if not visual_endpoint:
        return {"error": "No visual AI endpoint configured"}

    # Without image support, degrade to a generic textual hint.
    if not hasattr(visual_endpoint, 'query_with_image'):
        return self._generate_text_hint_for_visual(instance_id, annotation_id, ai_assistant)

    # Frame extraction and metadata lookup are delegated to the endpoint.
    try:
        sampled_frames = visual_endpoint.extract_video_frames(video_url)
        clip_metadata = visual_endpoint.get_video_metadata(video_url)
    except Exception as e:
        logger.error(f"Failed to extract video frames: {e}")
        return {"error": f"Failed to process video: {str(e)}"}

    request = VisualAnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=scheme_type,
        task_type=ai_assistant,
        image_data=sampled_frames,  # List of frame images
        description=task_description,
        labels=label_names,
        video_metadata=clip_metadata,
    )

    # Resolve the response model named by the prompt configuration.
    assistant_prompts = get_ai_prompt().get(scheme_type, {})
    format_name = assistant_prompts.get(ai_assistant, {}).get("output_format", "video_scene_detection")
    response_model = self.model_manager.get_model_class_by_name(format_name)

    return visual_endpoint.get_visual_ai(request, response_model)
920
+
921
+ def _get_visual_endpoint(self):
922
+ """Get the appropriate endpoint for visual tasks."""
923
+ # Use dedicated visual endpoint if configured
924
+ if self.visual_endpoint:
925
+ return self.visual_endpoint
926
+
927
+ # Check if main endpoint supports vision
928
+ if hasattr(self.ai_endpoint, 'query_with_image'):
929
+ return self.ai_endpoint
930
+
931
+ # Try to find a visual endpoint from registered types
932
+ visual_types = ['yolo', 'ollama_vision', 'openai_vision', 'anthropic_vision']
933
+ for vtype in visual_types:
934
+ if vtype in AIEndpointFactory._endpoints:
935
+ try:
936
+ visual_config = {
937
+ "ai_support": {
938
+ "enabled": True,
939
+ "endpoint_type": vtype,
940
+ "ai_config": config.get("ai_support", {}).get("ai_config", {})
941
+ }
942
+ }
943
+ return AIEndpointFactory.create_endpoint(visual_config)
944
+ except Exception as e:
945
+ logger.debug(f"Could not create {vtype} endpoint: {e}")
946
+ continue
947
+
948
+ return None
949
+
950
+ def _extract_image_url(self, item_data: Dict) -> str:
951
+ """Extract image URL from item data.
952
+
953
+ Looks for common field names that might contain image URLs.
954
+ """
955
+ # Common field names for images
956
+ image_fields = ['image', 'image_url', 'img', 'img_url', 'url', 'path', 'file', 'src']
957
+
958
+ for field in image_fields:
959
+ if field in item_data:
960
+ value = item_data[field]
961
+ if isinstance(value, str) and (
962
+ value.startswith(('http://', 'https://', '/')) or
963
+ value.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
964
+ ):
965
+ return value
966
+
967
+ # Check 'text' field for URL (common in simple configs)
968
+ if 'text' in item_data:
969
+ text = item_data['text']
970
+ if isinstance(text, str) and (
971
+ text.startswith(('http://', 'https://')) and
972
+ any(ext in text.lower() for ext in ['.jpg', '.jpeg', '.png', '.gif', '.webp'])
973
+ ):
974
+ return text
975
+
976
+ return None
977
+
978
+ def _extract_video_url(self, item_data: Dict) -> str:
979
+ """Extract video URL from item data."""
980
+ # Common field names for videos
981
+ video_fields = ['video', 'video_url', 'url', 'path', 'file', 'src', 'media']
982
+
983
+ for field in video_fields:
984
+ if field in item_data:
985
+ value = item_data[field]
986
+ if isinstance(value, str) and (
987
+ value.startswith(('http://', 'https://', '/')) or
988
+ value.endswith(('.mp4', '.webm', '.ogg', '.avi', '.mov'))
989
+ ):
990
+ return value
991
+
992
+ # Check 'text' field for URL
993
+ if 'text' in item_data:
994
+ text = item_data['text']
995
+ if isinstance(text, str) and (
996
+ text.startswith(('http://', 'https://')) and
997
+ any(ext in text.lower() for ext in ['.mp4', '.webm', '.ogg', '.avi', '.mov'])
998
+ ):
999
+ return text
1000
+
1001
+ return None
1002
+
1003
def _prepare_image_data(self, image_url: str) -> ImageData:
    """Turn an image reference into the payload passed to a visual endpoint.

    Args:
        image_url: An ``http(s)://`` URL, or anything else, which is treated
            as a local file path.

    Returns:
        An ``ImageData`` with ``source="url"`` for remote images; otherwise
        the result of ``BaseVisualAIEndpoint.encode_image_to_base64``.
    """
    is_remote = image_url.startswith(('http://', 'https://'))
    if not is_remote:
        # Local file path - encode as base64 (imported lazily, as before).
        from potato.ai.visual_ai_endpoint import BaseVisualAIEndpoint
        return BaseVisualAIEndpoint.encode_image_to_base64(image_url)

    return ImageData(source="url", data=image_url)
1011
+
1012
def _generate_text_hint_for_visual(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
    """Build a generic textual hint for use when no vision-capable endpoint exists.

    Only the annotation scheme is consulted; ``instance_id`` and
    ``ai_assistant`` are accepted for signature parity but not used.

    Returns:
        Dict with a ``hint`` sentence and an empty ``suggestive_choice``.
    """
    scheme = config["annotation_schemes"][annotation_id]
    description = scheme.get("description", "")
    labels = scheme.get("labels", [])

    # Labels may be configured as dicts; reduce them to their display names.
    if labels and isinstance(labels[0], dict):
        labels = [entry.get("name", str(entry)) for entry in labels]

    media_kind = 'image' if 'image' in scheme['annotation_type'] else 'video'
    targets = ', '.join(labels) if labels else 'relevant content'

    return {
        "hint": f"Review the {media_kind} carefully. Look for: {targets}. Task: {description}",
        "suggestive_choice": ""
    }
1026
+
1027
def is_option_highlighting_enabled_for_scheme(self, annotation_id: int) -> bool:
    """Tell whether option highlighting applies to one annotation scheme.

    Highlighting must be enabled globally, the scheme must be one of the
    discrete option types, and, when a schema filter is set, the scheme's
    name must be listed in it.
    """
    if not self.option_highlighting_enabled:
        return False

    scheme = config["annotation_schemes"][annotation_id]

    # Only applicable to discrete option types
    if scheme.get("annotation_type", "") not in ("radio", "multiselect", "likert", "select"):
        return False

    # A filter of None means "all schemes".
    allowed_schemas = self.option_highlighting_schemas
    return allowed_schemas is None or scheme.get("name", "") in allowed_schemas
1047
+
1048
def get_option_highlighting_config(self) -> Dict:
    """Return the option highlighting settings as a dict for the frontend."""
    payload = {"enabled": self.option_highlighting_enabled}
    # The remaining keys mirror the ``option_highlighting_<key>`` attributes.
    for setting in ("top_k", "dim_opacity", "auto_apply", "schemas", "prefetch_count"):
        payload[setting] = getattr(self, f"option_highlighting_{setting}")
    return payload
1058
+
1059
def generate_option_highlights(self, instance_id: int, annotation_id: int) -> Dict:
    """Generate option highlighting suggestions for an annotation.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index

    Returns:
        Dict with highlighted options and configuration:
        {
            "highlighted": ["option1", "option2"],
            "top_k": 3,
            "confidence": 0.85
        }
        or ``{"error": <message>}`` when highlighting is not enabled for the
        scheme, the prompt is not configured, or the endpoint call or the
        response parsing fails.
    """
    from string import Template

    if not self.is_option_highlighting_enabled_for_scheme(annotation_id):
        return {"error": "Option highlighting not enabled for this scheme"}

    annotation_type = _get_scheme_field(annotation_id, "annotation_type", default="")
    description = _get_scheme_field(annotation_id, "description", default="")
    labels = _get_scheme_field(annotation_id, "labels", default=[])

    # Extract label names
    if labels and isinstance(labels[0], dict):
        label_names = [l.get("name", str(l)) for l in labels]
    else:
        label_names = [str(l) for l in labels]

    # For likert scales, generate label names from min/max labels
    if annotation_type == "likert":
        # The scheme dict must be looked up here: the previous code read an
        # undefined ``scheme`` name and raised NameError for likert schemes.
        scheme = config["annotation_schemes"][annotation_id]
        size = scheme.get("size", 5)
        min_label = scheme.get("min_label", "1")
        max_label = scheme.get("max_label", str(size))
        label_names = [f"{i+1} ({min_label if i == 0 else max_label if i == size-1 else ''})" for i in range(size)]
        # Clean up empty parentheses
        label_names = [l.replace(" ()", "") for l in label_names]

    text = _get_instance_text(instance_id)
    top_k = min(self.option_highlighting_top_k, len(label_names))

    # Get prompt template
    ai_prompt = get_ai_prompt()
    prompt_config = ai_prompt.get("option_highlight", {}).get("option_highlight", {})

    if not prompt_config:
        return {"error": "Option highlight prompt not configured"}

    prompt_template = prompt_config.get("prompt", "")
    output_format_name = prompt_config.get("output_format", "option_highlight")
    output_format = self.model_manager.get_model_class_by_name(output_format_name)

    # Build the prompt with clear delimiters to mitigate prompt injection.
    # The user content is wrapped in XML-style tags so the LLM can
    # distinguish between instructions and untrusted data.
    delimited_text = (
        f"<user_content>\n{text}\n</user_content>"
    )
    template = Template(prompt_template)
    # safe_substitute leaves unknown placeholders in place instead of raising.
    prompt = template.safe_substitute(
        text=delimited_text,
        description=description,
        labels=", ".join(label_names),
        top_k=top_k
    )

    # Query the AI endpoint
    try:
        result = self.ai_endpoint.query(prompt, output_format)
        logger.debug(f"Option highlight raw result: {result}")

        # Parse the result
        if isinstance(result, str):
            import json as json_module
            try:
                # Try to parse JSON from the response
                result = json_module.loads(result)
            except json_module.JSONDecodeError:
                # Try to extract JSON from markdown code block
                if "```json" in result:
                    json_start = result.find("```json") + 7
                elif "```" in result:
                    json_start = result.find("```") + 3
                else:
                    return {"error": f"Could not parse response: {result[:100]}"}
                json_end = result.find("```", json_start)
                if json_end == -1:
                    # Unterminated fence: take everything after the opener
                    # rather than silently dropping the last character.
                    json_end = len(result)
                result = json_module.loads(result[json_start:json_end].strip())

        # A null "highlighted_options" is treated as "no suggestions".
        highlighted = result.get("highlighted_options") or []
        confidence = result.get("confidence", None)

        # Validate highlighted options against available labels
        valid_highlighted = [opt for opt in highlighted if opt in label_names]

        return {
            "highlighted": valid_highlighted[:top_k],
            "top_k": top_k,
            "confidence": confidence
        }

    except Exception as e:
        # Endpoint and parsing failures are reported, not raised.
        logger.error(f"Error generating option highlights: {e}")
        return {"error": str(e)}
1165
+
1166
def get_option_highlights(self, instance_id: int, annotation_id: int) -> Dict:
    """Return option highlights, using the cache when it is enabled.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index

    Returns:
        The cached entry if present; otherwise the freshly generated result,
        which is stored only when it carries no ``"error"`` key.
    """
    cache_key = (instance_id, annotation_id, "option_highlight")

    if self.disk_cache_enabled:
        hit = self.get_from_cache(cache_key)
        if hit is not None:
            logger.debug(f"Option highlight cache hit for {cache_key}")
            return hit

    highlights = self.generate_option_highlights(instance_id, annotation_id)

    # Failed generations are never cached so they can be retried.
    succeeded = "error" not in highlights
    if succeeded and self.disk_cache_enabled:
        self.add_to_cache(cache_key, highlights)

    return highlights
1193
+
1194
def start_option_highlight_prefetch(self, page_id: int, prefetch_amount: int = None):
    """Queue background generation of option highlights around a page.

    Args:
        page_id: Current page/instance index
        prefetch_amount: Number of items to prefetch (uses the configured
            default if None). A negative value selects the items before
            ``page_id`` instead of the items from it onwards.
    """
    if not (self.option_highlighting_enabled and self.disk_cache_enabled):
        return

    if prefetch_amount is None:
        prefetch_amount = self.option_highlighting_prefetch_count

    ism = get_item_state_manager()
    with self.lock:
        # Work out the half-open index window to warm.
        if prefetch_amount >= 0:
            first = page_id
            stop = min(first + prefetch_amount, len(ism.items()))
        else:
            first = max(page_id + prefetch_amount, 0)
            stop = page_id

        # Collect (item, scheme) pairs that are neither cached nor running.
        pending = []
        for item_idx in range(first, stop):
            for scheme_idx, _scheme in enumerate(config["annotation_schemes"]):
                if not self.is_option_highlighting_enabled_for_scheme(scheme_idx):
                    continue
                candidate = (item_idx, scheme_idx, "option_highlight")
                if self.get_from_cache(candidate) is None and candidate not in self.in_progress:
                    pending.append(candidate)

        # Submit one job per pending key.
        for pending_key in pending:
            item_idx, scheme_idx, _ = pending_key
            job = self.executor.submit(self.generate_option_highlights, item_idx, scheme_idx)
            self.in_progress[pending_key] = job

            # cache_key is bound as a default so each callback keeps its own key.
            def _store_result(done, cache_key=pending_key):
                with self.lock:
                    try:
                        outcome = done.result()
                        if "error" not in outcome:
                            self.add_to_cache(cache_key, outcome)
                    except Exception as e:
                        logger.error(f"Option highlight prefetch failed for {cache_key}: {e}")
                    self.in_progress.pop(cache_key, None)

            job.add_done_callback(_store_result)

        if pending:
            logger.debug(f"Started option highlight prefetch for {len(pending)} items")
1246
+
1247
def get_include_all(self):
    """Return the ``include_all`` setting stored on this instance."""
    include_everything = self.include_all
    return include_everything
1249
+
1250
def get_special_include(self, page_number_int, annotation_id_int):
    """Look up the special-include entry for a page and annotation scheme.

    Returns:
        The stored entry, or None when the page or the scheme has no entry
        (or the entry is empty/falsy).
    """
    logger.debug(f"get_special_include: page={page_number_int}, annotation_id={annotation_id_int}")

    page_entry = self.special_includes.get(page_number_int)
    if not page_entry:
        return None

    scheme_entry = page_entry.get(annotation_id_int)
    if not scheme_entry:
        return None

    return scheme_entry
1257
+
1258
def start_prefetch(self, page_id, prefetch_amount):
    """Prefetches a fixed number of nearby items to warm the cache.

    Args:
        page_id: Current page/instance index.
        prefetch_amount: Number of items to prefetch. A non-negative value
            covers ``page_id`` onwards (clamped to the number of items); a
            negative value covers the ``abs(prefetch_amount)`` items before
            ``page_id`` (clamped at index 0).

    Raises:
        Exception: If an included scheme's annotation type has no entry in
            the AI prompt configuration.
    """
    if not config.get("ai_support", {}).get("enabled") or not self.disk_cache_enabled:
        return

    ism = get_item_state_manager()
    with self.lock:
        # Calculate range bounds
        if prefetch_amount >= 0:
            start_idx = page_id
            end_idx = min(start_idx + prefetch_amount, len(ism.items()))
        else:
            # prefetch_amount is negative here, so it must be ADDED to move
            # backwards. Subtracting it put start_idx past end_idx and made
            # backward prefetch a silent no-op.
            start_idx = max(page_id + prefetch_amount, 0)
            end_idx = page_id

        logger.debug(f"Prefetch range: start_idx={start_idx}, end_idx={end_idx}")
        keys = []

        for i in range(start_idx, end_idx):
            # Check if this page should be included
            if not self.should_include_page(i):
                continue

            # Process each annotation scheme for this page
            for annotation_id, scheme in enumerate(config["annotation_schemes"]):
                if not self.should_include_scheme(i, annotation_id):
                    continue

                annotation_type = scheme["annotation_type"]
                ai_prompt = get_ai_prompt()

                # .get() so a missing type reaches the descriptive error
                # below instead of surfacing as a bare KeyError.
                if not ai_prompt.get(annotation_type):
                    raise Exception(f"{annotation_type} is not defined in ai_prompt")

                # Generate keys for this page/scheme combination
                scheme_keys = self.get_keys_for_scheme(i, annotation_type, annotation_id, ai_prompt)
                keys.extend(scheme_keys)

    # prefetch() acquires self.lock itself, so it is called after the
    # range computation has released it.
    if keys:
        self.prefetch(keys)
1299
+
1300
def should_include_page(self, page_index):
    """Return True when AI help should be prepared for the given page.

    Every page qualifies when ``include_all`` is set; otherwise only pages
    that have an entry in ``special_includes``.
    """
    return True if self.include_all else page_index in self.special_includes
1305
+
1306
def should_include_scheme(self, page_index, annotation_id):
    """Return True when a scheme on a given page should receive AI help.

    With ``include_all`` set every scheme qualifies. Otherwise the scheme
    must be listed in the page's ``special_includes`` entry, which may be
    either a dict keyed by scheme index or a list of scheme indices.
    """
    if self.include_all:
        return True

    if page_index not in self.special_includes:
        return False

    page_includes = self.special_includes[page_index]
    # Membership means the same thing for both supported container types;
    # any other type never matches.
    if isinstance(page_includes, (dict, list)):
        return annotation_id in page_includes

    return False
1321
+
1322
def get_keys_for_scheme(self, page_index, annotation_type, annotation_id, ai_prompt):
    """Build the cache keys to prefetch for one page/scheme pair.

    A dict-style ``special_includes`` entry for the page that names this
    scheme takes precedence over ``include_all``. Otherwise, with
    ``include_all`` set, every assistant defined for the annotation type in
    ``ai_prompt`` is used. If neither applies the result is empty.

    Returns:
        List of ``(page_index, annotation_id, assistant_name)`` tuples.
    """
    has_override = False
    if page_index in self.special_includes:
        page_entry = self.special_includes[page_index]
        has_override = isinstance(page_entry, dict) and annotation_id in page_entry

    if has_override:
        # Use special_includes (overrides include_all setting)
        assistant_names = self.special_includes[page_index][annotation_id]
    elif self.include_all:
        # No specific override: every assistant for this annotation type.
        assistant_names = ai_prompt[annotation_type]
    else:
        assistant_names = ()

    return [(page_index, annotation_id, name) for name in assistant_names]
1342
+
1343
def prefetch(self, keys: list):
    """Schedule background computation for every key that is not yet cached.

    Args:
        keys: ``(instance_id, annotation_id, ai_assistant)`` tuples. Keys
            that are already cached or already being computed are skipped.
    """
    with self.lock:
        for key in keys:
            already_cached = self.get_from_cache(key) is not None
            if already_cached or key in self.in_progress:
                continue

            instance_id, annotation_id, ai_assistant = key
            job = self.executor.submit(self.compute_help, instance_id, annotation_id, ai_assistant)
            self.in_progress[key] = job

            # cache_key is bound as a default so each callback keeps its own key.
            def _store_result(done, cache_key=key):
                with self.lock:
                    try:
                        self.add_to_cache(cache_key, done.result())
                    except Exception as e:
                        logger.error(f"Prefetch failed for key {cache_key}: {e}")
                    self.in_progress.pop(cache_key, None)

            job.add_done_callback(_store_result)
1363
+
1364
def get_ai_help(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Retrieve AI help from cache, from an in-progress job, or by computing it.

    Args:
        instance_id: The instance/item index.
        annotation_id: The annotation scheme index.
        ai_assistant: The assistant type being requested.

    Returns:
        The help payload produced by ``compute_help`` (a string or a dict,
        depending on the generator), or an ``"Error: ..."`` string when the
        computation raises or does not finish within 60 seconds.
    """
    key = (instance_id, annotation_id, ai_assistant)

    # With caching disabled there is nothing to look up or share.
    if not self.disk_cache_enabled:
        return self.compute_help(instance_id, annotation_id, ai_assistant)

    # Try to get from cache if caching is enabled
    cached_value = self.get_from_cache(key)
    if cached_value is not None:
        logger.debug(f"Cache hit for key: {key}")
        return cached_value

    # Join an existing job for this key, or start one.
    with self.lock:
        if key in self.in_progress:
            future = self.in_progress[key]
        else:
            future = self.executor.submit(self.compute_help, instance_id, annotation_id, ai_assistant)
            self.in_progress[key] = future

    try:
        result = future.result(timeout=60)

        # Don't cache error responses. Errors arrive either as message
        # strings or as {"error": ...} dicts (the shape compute_help and the
        # visual generators return); previously only strings were checked,
        # so dict errors were persisted in the cache.
        if isinstance(result, str):
            is_error_response = (
                result.startswith("Unable to generate") or
                result.startswith("Error:") or
                "error" in result.lower()[:50]
            )
        elif isinstance(result, dict):
            is_error_response = "error" in result
        else:
            is_error_response = False

        if self.disk_cache_enabled and not is_error_response:
            self.add_to_cache(key, result)
        elif is_error_response:
            # str() first: the result may be a dict, which cannot be sliced.
            logger.warning(f"Not caching error response for key {key}: {str(result)[:100]}")
        with self.lock:
            self.in_progress.pop(key, None)
        return result
    except Exception as e:
        logger.error(f"Error computing help for key {key}: {e}")
        with self.lock:
            self.in_progress.pop(key, None)
        return f"Error: {str(e)}"
1405
+
1406
def compute_help(self, instance_id: int, annotation_id: int, ai_assistant: str):
    """Generate AI help by dispatching to the generator for the scheme's type.

    Returns:
        ``{"error": <message>}`` when the assistant is not compatible with
        the model/input; otherwise the selected generator's result.

    Raises:
        ValueError: If the scheme's annotation type has no generator (or is
            not a valid ``Annotation_Type`` value).
    """
    # Validate that the assistant type is compatible with the model and input
    compatible, reason = self._validate_assistant_compatibility(
        instance_id, annotation_id, ai_assistant
    )
    if not compatible:
        logger.warning(f"Assistant compatibility check failed: {reason}")
        return {"error": reason}

    scheme_type = Annotation_Type(config["annotation_schemes"][annotation_id]["annotation_type"])

    # Method names are resolved lazily so only the selected generator is touched.
    generator_names = {
        Annotation_Type.LIKERT: "generate_likert",
        Annotation_Type.RADIO: "generate_radio",
        Annotation_Type.MULTISELECT: "generate_multiselect",
        Annotation_Type.NUMBER: "generate_number",
        Annotation_Type.SELECT: "generate_select",
        Annotation_Type.SLIDER: "generate_slider",
        Annotation_Type.SPAN: "generate_span",
        Annotation_Type.TEXTBOX: "generate_textbox",
        Annotation_Type.IMAGE_ANNOTATION: "generate_image_annotation",
        Annotation_Type.VIDEO_ANNOTATION: "generate_video_annotation",
    }
    generator_name = generator_names.get(scheme_type)
    if generator_name is None:
        raise ValueError(f"Unknown annotation type: {scheme_type}")

    return getattr(self, generator_name)(instance_id, annotation_id, ai_assistant)
1439
+
1440
def get_cache_stats(self) -> Dict[str, int]:
    """Return statistics on the disk cache and in-progress entries.

    Returns:
        Dict with ``disk_cache_enabled``, ``cached_items_disk`` (0 when the
        cache is disabled, missing or unreadable) and ``in_progress_items``.
    """
    with self.lock:
        disk_count = 0
        if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path):
            try:
                disk_data = self.load_disk_cache_data(self.disk_persistence_path)
                disk_count = len(disk_data)
            except Exception as e:
                # Stats stay best-effort, but the failure is now logged, and
                # KeyboardInterrupt/SystemExit are no longer swallowed as
                # they were by the previous bare ``except``.
                logger.warning(f"Could not read disk cache for stats: {e}")

        return {
            'disk_cache_enabled': self.disk_cache_enabled,
            'cached_items_disk': disk_count,
            'in_progress_items': len(self.in_progress)
        }
1456
+
1457
def clear_cache(self):
    """Cancel pending generation jobs and delete the disk cache file."""
    with self.lock:
        # Ask every outstanding job to cancel, then forget them all.
        for pending in self.in_progress.values():
            pending.cancel()
        self.in_progress.clear()

        has_cache_file = (
            self.disk_cache_enabled
            and self.disk_persistence_path
            and os.path.exists(self.disk_persistence_path)
        )
        if has_cache_file:
            try:
                os.remove(self.disk_persistence_path)
                logger.info("Disk cache file removed")
            except Exception as e:
                # A failed removal is logged rather than raised.
                logger.error(f"Error removing disk cache file: {e}")
        logger.info("Cache cleared")
1471
+
1472
+
1473
+
potato/ai/ai_endpoint.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Unified AI endpoint interface for various LLM providers.
4
+
5
+ This module provides a common interface for interacting with different LLM providers
6
+ including OpenAI, Anthropic, Hugging Face, Ollama, and VLLM endpoints.
7
+ """
8
+
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+ import logging
12
+ from abc import ABC, abstractmethod
13
+ import os
14
+ from typing import Dict, Any, Optional, List, Type, Union
15
+ import json
16
+ from string import Template
17
+
18
+ from pydantic import BaseModel
19
+
20
+ from .ai_prompt import get_ai_prompt
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class Annotation_Type(Enum):
26
+ RADIO = "radio"
27
+ LIKERT = "likert"
28
+ NUMBER = "number"
29
+ TEXTBOX = "text"
30
+ MULTISELECT = "multiselect"
31
+ SPAN = "span"
32
+ SELECT = "select"
33
+ SLIDER = "slider"
34
+ IMAGE_ANNOTATION = "image_annotation"
35
+ VIDEO_ANNOTATION = "video_annotation"
36
+
37
+
38
+ @dataclass
39
+ class ImageData:
40
+ """Data structure for image input to visual AI endpoints."""
41
+ source: str # 'url' | 'base64'
42
+ data: str # The URL or base64-encoded image data
43
+ width: Optional[int] = None
44
+ height: Optional[int] = None
45
+ mime_type: Optional[str] = None # e.g., 'image/jpeg', 'image/png'
46
+
47
+
48
+ @dataclass
49
+ class VisualAnnotationInput:
50
+ """Input data structure for visual annotation AI assistance."""
51
+ ai_assistant: str # 'detection', 'classification', 'hint', 'pre_annotate', etc.
52
+ annotation_type: str # 'image_annotation' | 'video_annotation'
53
+ task_type: str # Specific task: 'detection', 'classification', 'scene_detection', etc.
54
+ image_data: Union[ImageData, List[ImageData]] # Single image or list of frames
55
+ description: str # Task description from annotation scheme
56
+ labels: Optional[List[str]] = None # Available labels for the task
57
+ video_metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # fps, duration for video
58
+ region: Optional[Dict[str, float]] = None # Selected region for classification (x, y, width, height)
59
+ confidence_threshold: float = 0.5 # Minimum confidence for detections
60
+
61
+
62
+ @dataclass
63
+ class AnnotationInput:
64
+ ai_assistant: str
65
+ annotation_type: Annotation_Type
66
+ text: str
67
+ description: str
68
+ min_label: Optional[str] = ""
69
+ max_label: Optional[str] = ""
70
+ size: Optional[int] = -1
71
+ labels: Optional[List[str]] = None
72
+ min_value: Optional[int] = -1
73
+ max_value: Optional[int] = -1
74
+ step: Optional[int] = -1
75
+
76
+
77
@dataclass
class ModelCapabilities:
    """
    Capability flags describing which operations an AI endpoint can perform.

    The flags are consulted to decide which AI assistant types can be offered
    for a given piece of content (text or image).

    Attributes:
        text_generation: Can generate text (hints, rationales, descriptions)
        vision_input: Can process images as input
        bounding_box_output: Can output precise coordinate detections
        text_classification: Can classify text into categories
        image_classification: Can classify images into categories
        rationale_generation: Can generate explanations/rationales for labels
        keyword_extraction: Can extract keywords from text (not applicable to images)
    """
    text_generation: bool = False
    vision_input: bool = False
    bounding_box_output: bool = False
    text_classification: bool = False
    image_classification: bool = False
    rationale_generation: bool = False
    keyword_extraction: bool = False

    def supports_assistant(self, assistant_type: str, has_image_input: bool = False) -> bool:
        """
        Report whether this model can serve the given assistant type.

        Args:
            assistant_type: One of 'hint', 'keyword', 'rationale', 'detection',
                'detect', 'pre_annotate' or 'classification'.
            has_image_input: Whether the current content is an image.

        Returns:
            True if the required capability flags are set for this input kind;
            False for unsupported combinations and for unknown assistant types.
        """
        # Keyword highlighting is text-only.
        if assistant_type == "keyword":
            return self.keyword_extraction and not has_image_input

        # Detection needs both box output and vision, whatever the content kind.
        if assistant_type in ("detection", "detect", "pre_annotate"):
            return self.bounding_box_output and self.vision_input

        # The remaining types need one base flag, plus vision for image content.
        if assistant_type == "hint":
            base_flag = self.text_generation
        elif assistant_type == "rationale":
            base_flag = self.rationale_generation
        elif assistant_type == "classification":
            base_flag = self.image_classification if has_image_input else self.text_classification
        else:
            # Unknown assistant type - refuse rather than guess.
            return False

        if has_image_input:
            return base_flag and self.vision_input
        return base_flag

    def get_supported_assistants(self, has_image_input: bool = False) -> List[str]:
        """
        List the assistant types available for the given input kind.

        Args:
            has_image_input: Whether the current content is an image.

        Returns:
            Supported assistant type names, in a fixed canonical order.
        """
        supported = []
        for candidate in ("hint", "keyword", "rationale", "detection", "pre_annotate", "classification"):
            if self.supports_assistant(candidate, has_image_input):
                supported.append(candidate)
        return supported
157
+
158
+
159
class AIEndpointError(Exception):
    """Root of the AI endpoint exception hierarchy."""


class AIEndpointConfigError(AIEndpointError):
    """Raised when an endpoint's configuration is invalid."""


class AIEndpointRequestError(AIEndpointError):
    """Raised when a request to the AI provider fails."""
172
+
173
+
174
+ class BaseAIEndpoint(ABC):
175
+ """
176
+ Abstract base class for AI endpoints.
177
+
178
+ All AI endpoint implementations should inherit from this class
179
+ and implement the required methods.
180
+ """
181
+
182
def __init__(self, config: Dict[str, Any]):
    """
    Store the configuration, resolve model settings and set up the client.

    Args:
        config: Configuration dictionary; reads the optional keys
            "description", "annotation_type" and "ai_config".
    """
    self.config = config
    self.description = config.get("description", "")
    self.annotation_type = config.get("annotation_type", "")

    model_settings = config.get("ai_config", {})
    self.ai_config = model_settings

    # The provider default is resolved up front (after ai_config is stored),
    # even when the configuration names a model explicitly.
    fallback_model = self._get_default_model()
    self.model = model_settings.get("model", fallback_model)
    self.max_tokens = model_settings.get("max_tokens", 100)
    self.temperature = model_settings.get("temperature", 0.1)

    # Prompt definitions as loaded at construction time.
    self.prompts = get_ai_prompt()

    # Provider-specific client setup comes last so it can use the fields above.
    self._initialize_client()
204
+
205
+ @abstractmethod
206
+ def _initialize_client(self) -> None:
207
+ """Initialize the client for the specific AI provider."""
208
+ pass
209
+
210
+ @abstractmethod
211
+ def _get_default_model(self) -> str:
212
+ """Get the default model name for this provider."""
213
+ pass
214
+
215
@abstractmethod
def query(self, prompt: str, output_format: Type[BaseModel]):
    """
    Send a query to the AI model and return the response.

    Args:
        prompt: The prompt to send to the model
        output_format: Pydantic model class describing the structured output
            expected from the model

    Returns:
        The model's response. No return type is declared here.
        NOTE(review): callers in this class handle both dict and string
        results (see chat_query) -- confirm what each subclass returns.

    Raises:
        AIEndpointRequestError: If the request fails
    """
    pass
230
+
231
def chat_query_with_image(
    self,
    messages: List[Dict[str, Any]],
    images: Optional[List["ImageData"]] = None,
) -> str:
    """
    Multi-turn chat with interleaved images (vision-capable endpoints only).

    This base implementation always raises; endpoints that support vision
    override it. It is used by the live agent runner for vision-based loops.

    Args:
        messages: List of message dicts. 'content' may be a string or a list
            of content blocks (e.g., {"type": "text", "text": "..."} or
            {"type": "image", "source": {...}}).
        images: Optional list of ImageData to include (alternative to inline images).

    Returns:
        The model's response as a plain text string (in overriding classes).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    endpoint_name = self.__class__.__name__
    message = (
        f"{endpoint_name} does not support chat_query_with_image. "
        "Use a vision-capable endpoint (e.g., anthropic_vision)."
    )
    raise NotImplementedError(message)
261
+
262
def chat_query(self, messages: List[Dict[str, str]]) -> str:
    """
    Send a multi-turn chat conversation to the AI model.

    Default implementation flattens the messages into a single
    "Role: content" transcript ending in "Assistant:" and calls query().
    Subclasses should override with native multi-turn support.

    Args:
        messages: List of message dicts with 'role' and 'content' keys.
            Roles: 'system', 'user', 'assistant'; a missing or unrecognised
            role is rendered as 'User'.

    Returns:
        The model's response as a plain text string. For dict results the
        first non-None value of "response" / "content" is used, otherwise the
        whole result is stringified.

    Raises:
        AIEndpointRequestError: If the underlying query fails (the original
            exception is attached as ``__cause__``).
    """
    role_labels = {"system": "System", "assistant": "Assistant"}

    # Flatten messages into a single prompt.
    parts = []
    for msg in messages:
        label = role_labels.get(msg.get("role", "user"), "User")
        content = msg.get("content", "")
        parts.append(f"{label}: {content}")
    prompt = "\n\n".join(parts) + "\n\nAssistant:"

    try:
        # NOTE(review): the abstract query() declares a required
        # `output_format` argument; this call only works for subclasses whose
        # query() gives it a default -- confirm, or override chat_query().
        result = self.query(prompt)

        # query() may return parsed JSON or a string; always hand back a string.
        if isinstance(result, dict):
            for key in ("response", "content"):
                value = result.get(key)
                if value is not None:
                    return value if isinstance(value, str) else str(value)
        return str(result)
    except Exception as e:
        # Chain the cause so the provider traceback is not lost.
        raise AIEndpointRequestError(f"Chat query failed: {e}") from e
297
+
298
def parseStringToJson(self, response_content: str) -> Any:
    """
    Parse structured output from any LLM response, with robust fallbacks.

    Strategies are tried in order and the first that succeeds wins:
      0. Strip <think>/<thinking>/<thought>/<inner_monologue> blocks
      1. Direct JSON parse of the whole response
      2. JSON inside a markdown code block (```json ... ``` or ``` ... ```)
      3. JSON inside <tool_call>/<function_call>/<output>/<result>/<answer>
         tags (for a dict with an "arguments" key, the arguments are returned)
      4. The first JSON object embedded in surrounding prose
      5. Salvage of complete key/value pairs from truncated or malformed JSON
      6. XML-style <tag>value</tag> pairs
      7. Raw text wrapped as {"response": text}

    Args:
        response_content: Raw response from the LLM (a dict is returned as-is)

    Returns:
        The parsed JSON value (normally a dict), or {"response": text} when
        nothing structured could be extracted.

    Raises:
        ValueError: Only if the response is empty
    """
    import re

    # Handle empty or None content
    if not response_content:
        raise ValueError("Empty response content received from AI endpoint")

    # If it's already a dict, return it
    if isinstance(response_content, dict):
        return response_content

    # Convert to string if needed
    content_str = str(response_content).strip()
    if not content_str:
        raise ValueError("Empty response content received from AI endpoint")

    # Strategy 0: Strip thinking/reasoning blocks that wrap the actual output
    # Many models (qwen3, deepseek, etc.) produce <think>...</think> blocks
    cleaned = content_str
    for tag in ['think', 'thinking', 'thought', 'inner_monologue']:
        cleaned = re.sub(
            rf'<{tag}>[\s\S]*?</{tag}>\s*',
            '', cleaned, flags=re.IGNORECASE
        ).strip()
    # Keep the original text if stripping would leave nothing to parse.
    if cleaned and cleaned != content_str:
        content_str = cleaned

    # Strategy 1: Try direct JSON parse
    try:
        return json.loads(content_str)
    except json.JSONDecodeError:
        pass

    # Strategy 2: Extract from markdown code blocks
    for pattern in [
        r'```json\s*([\s\S]*?)\s*```',
        r'```\s*([\s\S]*?)\s*```',
    ]:
        match = re.search(pattern, content_str)
        if match:
            try:
                return json.loads(match.group(1).strip())
            except json.JSONDecodeError:
                pass

    # Strategy 3: Extract from tool call format
    # Some models return: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
    # or function_call blocks
    for pattern in [
        r'<tool_call>\s*([\s\S]*?)\s*</tool_call>',
        r'<function_call>\s*([\s\S]*?)\s*</function_call>',
        r'<output>\s*([\s\S]*?)\s*</output>',
        r'<result>\s*([\s\S]*?)\s*</result>',
        r'<answer>\s*([\s\S]*?)\s*</answer>',
    ]:
        match = re.search(pattern, content_str, re.IGNORECASE)
        if match:
            try:
                parsed = json.loads(match.group(1).strip())
                # If it's a tool call wrapper, extract the arguments
                if isinstance(parsed, dict) and 'arguments' in parsed:
                    return parsed['arguments']
                return parsed
            except json.JSONDecodeError:
                pass

    # Strategy 4: Decode the JSON object that starts at the first '{'.
    # raw_decode stops at the end of that object, so a stray '}' later in the
    # prose no longer breaks extraction (a greedy "{...}" regex would fail).
    first_brace = content_str.find('{')
    if first_brace != -1:
        try:
            embedded, _end = json.JSONDecoder().raw_decode(content_str, first_brace)
            return embedded
        except json.JSONDecodeError:
            pass

    # Strategy 5: Salvage truncated JSON
    # Extract complete "key": "value" and "key": number pairs
    salvaged = self._salvage_key_value_pairs(content_str)
    if salvaged:
        logger.warning(
            f"Salvaged {len(salvaged)} fields from truncated/malformed response"
        )
        return salvaged

    # Strategy 6: Parse XML-style output
    # Some models (especially larger ones) produce XML like:
    #   <label>joy</label><confidence>90</confidence>
    # or <response><label>joy</label></response>
    xml_result = self._parse_xml_to_dict(content_str)
    if xml_result:
        logger.info(
            f"Parsed {len(xml_result)} fields from XML-style response"
        )
        return xml_result

    # Strategy 7: Return raw text wrapped in a dict
    logger.warning(
        f"Could not parse JSON or XML from response ({len(content_str)} chars), "
        f"returning as raw text"
    )
    return {"response": content_str}
418
+
419
+ @staticmethod
420
+ def _salvage_key_value_pairs(text: str) -> Optional[dict]:
421
+ """Extract key-value pairs from truncated or malformed JSON.
422
+
423
+ Handles cases where max_tokens cuts off a response mid-field, e.g.:
424
+ {"label": "joy", "confidence": 90, "reasoning": "The text expres...
425
+
426
+ Returns:
427
+ Dict of extracted key-value pairs, or None if nothing found.
428
+ """
429
+ import re
430
+ result = {}
431
+
432
+ # Extract "key": "value" pairs (string values)
433
+ for match in re.finditer(r'"(\w+)"\s*:\s*"([^"]*)"', text):
434
+ result[match.group(1)] = match.group(2)
435
+
436
+ # Extract "key": number pairs
437
+ for match in re.finditer(r'"(\w+)"\s*:\s*(-?\d+(?:\.\d+)?)\b', text):
438
+ key = match.group(1)
439
+ if key not in result:
440
+ try:
441
+ val = float(match.group(2))
442
+ result[key] = int(val) if val == int(val) else val
443
+ except ValueError:
444
+ pass
445
+
446
+ # Extract "key": true/false/null
447
+ for match in re.finditer(r'"(\w+)"\s*:\s*(true|false|null)\b', text):
448
+ key = match.group(1)
449
+ if key not in result:
450
+ val_str = match.group(2)
451
+ result[key] = (
452
+ True if val_str == 'true'
453
+ else False if val_str == 'false'
454
+ else None
455
+ )
456
+
457
+ return result if result else None
458
+
459
+ @staticmethod
460
+ def _parse_xml_to_dict(text: str) -> Optional[dict]:
461
+ """Extract key-value pairs from XML-style LLM output.
462
+
463
+ Handles patterns like:
464
+ - <label>joy</label><confidence>90</confidence>
465
+ - <response><label>joy</label></response>
466
+ - Mixed XML with text: "The emotion is <label>joy</label>"
467
+
468
+ Returns:
469
+ Dict of tag->content pairs, or None if no XML tags found.
470
+ """
471
+ import re
472
+ result = {}
473
+
474
+ # Find all <tag>content</tag> pairs (non-nested simple tags)
475
+ for match in re.finditer(
476
+ r'<(\w+)>([^<]*)</\1>', text, re.IGNORECASE
477
+ ):
478
+ tag = match.group(1).lower()
479
+ value = match.group(2).strip()
480
+
481
+ # Skip wrapper tags that contain other tags
482
+ if tag in ('response', 'output', 'result', 'answer', 'root'):
483
+ continue
484
+
485
+ # Try to parse numeric values
486
+ try:
487
+ if '.' in value:
488
+ result[tag] = float(value)
489
+ else:
490
+ result[tag] = int(value)
491
+ except ValueError:
492
+ # Boolean
493
+ if value.lower() in ('true', 'false'):
494
+ result[tag] = value.lower() == 'true'
495
+ else:
496
+ result[tag] = value
497
+
498
+ return result if result else None
499
+
500
def get_ai(self, data: AnnotationInput, output_format) -> str:
    """
    Build the prompt for an annotation request and query the model.

    The prompt template is looked up by ``data.annotation_type`` and
    ``data.ai_assistant`` and filled with the fields of ``data``.

    Args:
        data: The annotation request (text, description, labels, ranges, ...)
        output_format: Structured output format forwarded to ``query()``

    Returns:
        Whatever ``query()`` returns on success. On any failure a
        human-readable "Unable to generate ..." message string is returned
        instead of raising.
    """

    try:
        # Check if annotation type exists (comparing string against enum values)
        valid_types = [e.value for e in Annotation_Type]
        if data.annotation_type not in valid_types:
            logger.warning(f"Annotation type '{data.annotation_type}' not found")
            return "Unable to generate suggestion - annotation type not configured"

        # Resolve the prompt table once so the membership check and the
        # template lookup always read the same data. A missing annotation type
        # yields the specific "prompt not configured" message, not a KeyError.
        prompts = self.prompts or get_ai_prompt() or {}
        assistant_prompts = prompts.get(data.annotation_type) or {}
        if data.ai_assistant not in assistant_prompts:
            logger.warning(f"'ai_assistant' not found for {data.annotation_type}")
            return "Unable to generate suggestion - prompt not configured"

        template_str = assistant_prompts[data.ai_assistant].get("prompt")
        template = Template(template_str)
        prompt = template.substitute(
            text=data.text,
            description=data.description,
            min_label=data.min_label,
            max_label=data.max_label,
            size=data.size,
            labels=data.labels,
            min_value=data.min_value,
            max_value=data.max_value,
            step=data.step
        )
        return self.query(prompt, output_format)
    except Exception as e:
        # Best-effort by design: log full details, return a placeholder message.
        logger.error(f"[get_ai] AnnotationInput: {data}")
        logger.error(f"[get_ai] Error for {data.annotation_type}/{data.ai_assistant}: {type(e).__name__}: {e}")
        import traceback
        logger.error(f"[get_ai] Traceback:\n{traceback.format_exc()}")
        return "Unable to generate hint at this time."
544
+
545
def health_check(self) -> bool:
    """
    Check if the AI endpoint is healthy and accessible.

    Sends a trivial test query and treats any non-empty answer as healthy.
    String answers must contain non-whitespace text. Other results (e.g. a
    parsed dict) are healthy when truthy; previously they raised inside the
    check and were reported as unhealthy.

    Returns:
        True if the endpoint is healthy, False otherwise (including when the
        query raises).
    """
    try:
        # Simple test query
        # NOTE(review): the abstract query() declares a required
        # `output_format`; this call relies on the subclass defaulting it.
        test_response = self.query("Hello")
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return False

    if isinstance(test_response, str):
        return bool(test_response.strip())
    return bool(test_response)
559
+
560
+
561
class AIEndpointFactory:
    """
    Factory class for creating AI endpoint instances.

    Endpoint classes are registered under a string key and instantiated on
    demand from the ``ai_support`` section of the configuration.
    """

    # Registry: endpoint_type string -> endpoint class.
    _endpoints = {}

    @classmethod
    def register_endpoint(cls, endpoint_type: str, endpoint_class: type):
        """Register a new endpoint type (replaces any previous entry)."""
        cls._endpoints[endpoint_type] = endpoint_class

    @classmethod
    def create_endpoint(cls, config: Dict[str, Any]) -> Optional["BaseAIEndpoint"]:
        """
        Create an AI endpoint instance based on configuration.

        Args:
            config: Configuration dictionary containing ai_support settings

        Returns:
            An AI endpoint instance or None if AI support is disabled
            (a missing or empty ``ai_support`` section counts as disabled)

        Raises:
            AIEndpointConfigError: If the configuration is invalid or the
                endpoint constructor fails
        """
        # `or {}` also covers an explicit `ai_support: null` in the config.
        ai_support = config.get("ai_support") or {}
        if not ai_support.get("enabled", False):
            return None

        endpoint_type = ai_support.get("endpoint_type")

        if not endpoint_type:
            raise AIEndpointConfigError("endpoint_type is required when ai_support is enabled")

        if endpoint_type not in cls._endpoints:
            raise AIEndpointConfigError(f"Unknown endpoint type: {endpoint_type}")

        # Prepare endpoint configuration; a null ai_config becomes an empty dict.
        endpoint_config = {
            "ai_config": ai_support.get("ai_config") or {}
        }

        try:
            endpoint_class = cls._endpoints[endpoint_type]
            return endpoint_class(endpoint_config)
        except Exception as e:
            # Chain the cause so the constructor's traceback is preserved.
            raise AIEndpointConfigError(f"Failed to create {endpoint_type} endpoint: {e}") from e
609
+
610
+
611
+ # Legacy function for backward compatibility
612
def get_ai_endpoint(config: dict):
    """
    Legacy entry point for obtaining an AI endpoint instance.

    Kept for backward compatibility; it simply delegates to
    AIEndpointFactory.create_endpoint(), which new code should call directly.
    """
    endpoint = AIEndpointFactory.create_endpoint(config)
    return endpoint
620
+
621
+
622
+ # Register built-in endpoints
623
# Each backend is imported in its own try block: if importing it raises
# ImportError, that endpoint type is left unregistered and a debug message is
# logged; the remaining registrations still run.
try:
    from .ollama_endpoint import OllamaEndpoint
    AIEndpointFactory.register_endpoint("ollama", OllamaEndpoint)
except ImportError:
    logger.debug("Ollama endpoint not available")

try:
    from .openai_endpoint import OpenAIEndpoint
    AIEndpointFactory.register_endpoint("openai", OpenAIEndpoint)
except ImportError:
    logger.debug("OpenAI endpoint not available")

try:
    from .huggingface_endpoint import HuggingfaceEndpoint
    AIEndpointFactory.register_endpoint("huggingface", HuggingfaceEndpoint)
except ImportError:
    logger.debug("Hugging Face endpoint not available")

try:
    from .gemini_endpoint import GeminiEndpoint
    AIEndpointFactory.register_endpoint("gemini", GeminiEndpoint)
except ImportError:
    logger.debug("Gemini endpoint not available")

try:
    from .anthropic_endpoint import AnthropicEndpoint
    AIEndpointFactory.register_endpoint("anthropic", AnthropicEndpoint)
except ImportError:
    logger.debug("Anthropic endpoint not available")

try:
    from .vllm_endpoint import VLLMEndpoint
    AIEndpointFactory.register_endpoint("vllm", VLLMEndpoint)
except ImportError:
    logger.debug("VLLM endpoint not available")

# Register visual AI endpoints
try:
    from .yolo_endpoint import YOLOEndpoint
    AIEndpointFactory.register_endpoint("yolo", YOLOEndpoint)
except ImportError:
    logger.debug("YOLO endpoint not available (ultralytics not installed)")

try:
    from .ollama_vision_endpoint import OllamaVisionEndpoint
    AIEndpointFactory.register_endpoint("ollama_vision", OllamaVisionEndpoint)
except ImportError:
    logger.debug("Ollama Vision endpoint not available")

try:
    from .openai_vision_endpoint import OpenAIVisionEndpoint
    AIEndpointFactory.register_endpoint("openai_vision", OpenAIVisionEndpoint)
except ImportError:
    logger.debug("OpenAI Vision endpoint not available")

try:
    from .anthropic_vision_endpoint import AnthropicVisionEndpoint
    AIEndpointFactory.register_endpoint("anthropic_vision", AnthropicVisionEndpoint)
except ImportError:
    logger.debug("Anthropic Vision endpoint not available")

try:
    from .openrouter_endpoint import OpenRouterEndpoint
    AIEndpointFactory.register_endpoint("openrouter", OpenRouterEndpoint)
except ImportError:
    logger.debug("OpenRouter endpoint not available")
potato/ai/ai_help_wrapper.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import render_template_string
2
+ from typing import Optional, Dict, Any, List
3
+ from potato.ai.ai_cache import get_ai_cache_manager, _is_image_url, _get_instance_text
4
+ from potato.ai.ai_prompt import get_ai_prompt
5
+ from potato.server_utils.config_module import config
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # Global instance
11
+ DYNAMICAIHELP = None
12
+
13
def init_dynamic_ai_help():
    """
    Create the module-level DynamicAIHelp singleton if AI support is enabled.

    The ``ai_support.enabled`` flag is read once, tolerating a missing or
    null ``ai_support`` section (previously the log line read it safely but
    the check itself raised).

    Returns:
        The shared DynamicAIHelp instance, or None when AI support is disabled.
    """
    global DYNAMICAIHELP

    ai_enabled = (config.get("ai_support") or {}).get("enabled", False)
    logger.info(f"[init_dynamic_ai_help] Called. ai_support.enabled={ai_enabled}")

    if not ai_enabled:
        logger.info("[init_dynamic_ai_help] AI support disabled, returning")
        return None

    if DYNAMICAIHELP is None:
        DYNAMICAIHELP = DynamicAIHelp()
        logger.info(f"[init_dynamic_ai_help] Created DYNAMICAIHELP instance: {id(DYNAMICAIHELP)}")
    else:
        logger.info(f"[init_dynamic_ai_help] DYNAMICAIHELP already exists: {id(DYNAMICAIHELP)}")

    return DYNAMICAIHELP
29
+
30
def get_dynamic_ai_help():
    """Return the module-level DynamicAIHelp instance (None until initialised)."""
    # Read-only access to the module global; no `global` statement is needed.
    return DYNAMICAIHELP
33
+
34
+ class DynamicAIHelp:
35
+ def __init__(self):
36
+ self.template = """
37
+ {% if ai_assistant %}
38
+ {{ ai_assistant | safe }}
39
+ {% elif error_message %}
40
+ <span class="error">{{ error_message }}</span>
41
+ {% endif %}
42
+ """
43
+
44
def get_empty_wrapper(self):
    """Return the empty AI-help wrapper markup (a hidden container with a tooltip div)."""
    wrapper_html = '<div class="ai-help none"><div class="tooltip"></div></div>'
    return wrapper_html
46
+
47
def generate_ai_assistant(self, ai_prompts, annotation_type, ai_assistant):
    """
    Build the button markup for one AI assistant.

    Args:
        ai_prompts: Prompt table, indexed as ai_prompts[annotation_type][ai_assistant]
        annotation_type: Annotation type whose assistants are being rendered
        ai_assistant: Assistant key; also used as the CSS class and, capitalised,
            as the label when the prompt entry has no "name"

    Returns:
        HTML string: a container div with an optional icon and the label.
    """
    entry = ai_prompts[annotation_type].get(ai_assistant)
    icon_url = entry.get("img")
    label = entry.get("name", ai_assistant.capitalize())

    pieces = [f'<div class="{ai_assistant} ai-assistant-containter">']
    if icon_url:
        # Use empty alt since the button already has a text label
        pieces.append(f'<span class="ai-assistant-img"><img src="{icon_url}" alt=""></span>')
    pieces.append(f'<span>{label}</span>')
    pieces.append("</div>")
    return "".join(pieces)
57
+
58
+ def _filter_assistants_by_capability(
59
+ self, ai_cache_manager, assistant_keys: List[str], is_image_content: bool
60
+ ) -> List[str]:
61
+ """
62
+ Filter assistant types based on model capabilities.
63
+
64
+ Args:
65
+ ai_cache_manager: The AI cache manager instance
66
+ assistant_keys: List of assistant type keys to filter
67
+ is_image_content: Whether the current content is an image
68
+
69
+ Returns:
70
+ Filtered list of assistant keys that the model supports
71
+ """
72
+ # Get capabilities from the cache manager
73
+ capabilities = ai_cache_manager.get_endpoint_capabilities(for_image=is_image_content)
74
+
75
+ filtered_keys = []
76
+ for key in assistant_keys:
77
+ if capabilities.supports_assistant(key, is_image_content):
78
+ filtered_keys.append(key)
79
+ else:
80
+ logger.debug(
81
+ f"[get_ai_help_data] Skipping '{key}' button - "
82
+ f"not supported for {'image' if is_image_content else 'text'} content"
83
+ )
84
+
85
+ return filtered_keys
86
+
87
def get_ai_help_data(self, instance: int, annotation_id: int, annotation_type: str) -> Dict[str, Any]:
    """
    Assemble the template context for the AI help buttons of one annotation.

    Args:
        instance: Instance identifier, used to fetch the instance text and
            the per-annotation include list
        annotation_id: Annotation scheme identifier
        annotation_type: Annotation type whose prompts should be offered

    Returns:
        Dict with 'ai_assistant' (button HTML or None) and 'error_message'
        (text or None). Never raises: any exception is logged and reported
        through 'error_message'.
    """
    try:
        context = {
            'ai_assistant': None,
            'error_message': None,
        }
        ai_prompts = get_ai_prompt()
        logger.debug(f"[get_ai_help_data] ai_prompts keys: {list(ai_prompts.keys()) if ai_prompts else 'None'}")

        # Guard clauses: bail out with an error message on missing prerequisites.
        if not ai_prompts:
            context["error_message"] = 'No AI prompt configured'
            logger.debug("[get_ai_help_data] No AI prompts configured")
            return context
        if annotation_type not in ai_prompts:
            context["error_message"] = f'annotation type {annotation_type} does not exist in ai_prompts'
            logger.debug(f"[get_ai_help_data] annotation type {annotation_type} not in prompts")
            return context

        ai_cache_manager = get_ai_cache_manager()
        logger.debug(f"[get_ai_help_data] ai_cache_manager: {ai_cache_manager is not None}")
        if ai_cache_manager is None:
            context["error_message"] = "AI cache manager not initialized"
            logger.debug("[get_ai_help_data] AI cache manager is None")
            return context

        # Determine if content is an image for capability-based filtering
        is_image_content = False
        try:
            instance_text = _get_instance_text(instance)
            is_image_content = _is_image_url(instance_text)
            if is_image_content:
                logger.debug("[get_ai_help_data] Content is an image URL")
        except Exception as e:
            logger.debug(f"[get_ai_help_data] Could not determine if content is image: {e}")

        # Check if user specified specific assistant types
        special_include_types = ai_cache_manager.get_special_include(instance, annotation_id)
        logger.debug(f"[get_ai_help_data] special_include_types: {special_include_types}")

        type_prompts = ai_prompts[annotation_type]
        if special_include_types:
            # Explicit per-annotation list: keep known keys, then filter by capability.
            logger.debug(f"[get_ai_help_data] Using special include types: {special_include_types}")
            known_keys = [k for k in special_include_types if k in type_prompts]
            selected_keys = self._filter_assistants_by_capability(
                ai_cache_manager, known_keys, is_image_content
            )
        elif ai_cache_manager.get_include_all():
            # Offer every assistant defined for this annotation type.
            all_keys = list(type_prompts.keys())
            logger.debug(f"[get_ai_help_data] include_all=True, available keys: {all_keys}")
            selected_keys = self._filter_assistants_by_capability(
                ai_cache_manager, all_keys, is_image_content
            )
            logger.debug(f"[get_ai_help_data] After capability filter: {selected_keys}")
        else:
            logger.debug("[get_ai_help_data] No special includes and include_all=False")
            selected_keys = []

        html_parts = [
            self.generate_ai_assistant(ai_prompts, annotation_type, key)
            for key in selected_keys
        ]

        # Combine all HTML parts
        combined_html = '<span>|</span>'.join(html_parts) if html_parts else None
        logger.debug(f"[get_ai_help_data] ai_assistant_html_parts count: {len(html_parts)}")

        if combined_html:
            context['ai_assistant'] = combined_html

        assistant_state = 'set' if context['ai_assistant'] else 'None'
        error_state = 'set' if context['error_message'] else 'None'
        logger.debug(f"[get_ai_help_data] Final context: ai_assistant={assistant_state}, error={error_state}")
        return context
    except Exception as e:
        logger.error(f"[get_ai_help_data] Exception: {e}", exc_info=True)
        return {
            'ai_assistant': None,
            'error_message': f'Error loading AI help: {str(e)}',
        }
170
+
171
def render(self, instance: int, annotation_id: int, annotation_type) -> str:
    """
    Render the AI help HTML for one annotation.

    Builds the context via get_ai_help_data(), adds the instance and
    annotation identifiers, and renders ``self.template`` with it.
    """
    context = self.get_ai_help_data(instance, annotation_id, annotation_type)
    context['instance'] = instance
    context['annotation_id'] = annotation_id
    return render_template_string(self.template, **context)
179
+
180
def generate_ai_help_html(instance: int, annotation_id: int, annotation_type: str) -> Optional[str]:
    """
    Render the dynamic AI help HTML for one annotation.

    Works with the prompt structure {annotation_type: {prompt: ..., outputformat: ...}}.
    Returns an empty string when the DynamicAIHelp singleton has not been
    created (AI support not enabled).
    """
    import logging
    log = logging.getLogger(__name__)

    helper = DYNAMICAIHELP
    if helper is None:
        log.debug("[generate_ai_help_html] DYNAMICAIHELP is None - AI support not enabled")
        return ""  # AI support not enabled

    rendered = helper.render(instance, annotation_id, annotation_type)
    log.debug(f"[generate_ai_help_html] Rendered result: '{rendered[:100] if rendered else 'empty'}...'")
    return rendered
195
+
196
def get_ai_wrapper():
    """Return the empty AI-help wrapper markup, or "" when AI help is not initialised."""
    import logging
    log = logging.getLogger(__name__)

    helper = get_dynamic_ai_help()
    log.debug(f"[get_ai_wrapper] DYNAMICAIHELP is {'set' if helper else 'None'}")

    if helper:
        wrapper = helper.get_empty_wrapper()
    else:
        wrapper = ""

    log.debug(f"[get_ai_wrapper] Returning: '{wrapper[:50] if wrapper else 'empty'}...'")
    return wrapper
potato/ai/ai_prompt.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional, Type
6
+ from pydantic import BaseModel
7
+ from potato.server_utils.config_module import config
8
+ ANNOTATIONS = None
9
+
10
class ModelManager:
    """Lazily loads the Python module that defines the structured-output models.

    The module path comes from ``ai_support.model_module`` in the config; when
    that is unset, ``prompt/models_module.py`` next to this file is used. The
    loaded module must expose a ``CLASS_REGISTRY`` mapping of name -> class.
    """

    def __init__(self):
        # Cached module object; None until the first successful load.
        self.models_module = None

    @staticmethod
    def _import_module_from_file(file_path, missing_message: str):
        """Import and return the module stored at ``file_path``."""
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly.
        import importlib.util

        if not file_path.exists():
            raise FileNotFoundError(missing_message)

        spec = importlib.util.spec_from_file_location(file_path.stem, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load a Python module from: {file_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def load_models_module(self):
        """Load the models module if not already loaded.

        The module is cached only after it has executed successfully, so a
        failing module is retried on the next call instead of being returned
        half-initialised.

        Raises:
            FileNotFoundError: If the configured or default module file is missing.
        """
        if self.models_module is None:
            # A missing or null ai_support section means "use the default module".
            ai_support = config.get("ai_support") or {}
            module_path = ai_support.get("model_module")
            if module_path:
                file_path = Path(module_path)
                missing_message = f"Model module file not found: {file_path}"
            else:
                file_path = Path(__file__).resolve().parent / "prompt" / "models_module.py"
                missing_message = f"Default model module file not found: {file_path}"

            self.models_module = self._import_module_from_file(file_path, missing_message)

        return self.models_module

    def get_model_class_by_name(self, name: str) -> Optional[Type["BaseModel"]]:
        """
        Return the model class registered under ``name`` in the loaded
        module's CLASS_REGISTRY, or None if there is no such entry.
        """
        models_module = self.load_models_module()
        return models_module.CLASS_REGISTRY.get(name)
49
+
50
+
51
def init_ai_prompt(config):
    """
    Load the AI prompt definitions into the module-level ``ANNOTATIONS`` dict.

    With ``ai_support.annotation_path`` set (a mapping of key -> JSON file
    path), each listed file is loaded under its key. Otherwise every
    ``*.json`` file in the ``prompt`` directory next to this module is loaded,
    keyed by file name without extension (in sorted order).

    ``ANNOTATIONS`` is replaced only after every file has loaded successfully,
    so a failure never leaves a partially populated table behind. Does nothing
    when AI support is disabled or the ``ai_support`` section is missing.

    Args:
        config: Application configuration dictionary.

    Raises:
        ValueError: If an annotation file contains invalid JSON.
        RuntimeError: For any other loading problem (missing file/directory, ...).
    """
    global ANNOTATIONS

    ai_support = config.get("ai_support") or {}
    if not ai_support.get("enabled", False):
        return

    loaded = {}
    current_file = None  # tracked only to make JSON error messages useful
    try:
        annotation_paths = ai_support.get("annotation_path")

        if annotation_paths:
            # Load files from specified paths
            for key, path in annotation_paths.items():
                if not (path and os.path.exists(path)):
                    raise Exception(f"File path for annotations does not exist: {path}")
                current_file = path
                with open(path, "r", encoding="utf-8") as f:
                    loaded[key] = json.load(f)
        else:
            # Load all JSON files from default directory (parent/prompt)
            default_path = Path(__file__).resolve().parent / "prompt"
            if not (default_path.exists() and default_path.is_dir()):
                raise Exception(f"Default annotation directory does not exist: {default_path}")

            # Sorted so the key order does not depend on directory listing order.
            json_files = sorted(default_path.glob("*.json"))
            if not json_files:
                raise Exception(f"No JSON files found in default directory: {default_path}")

            # Load each JSON file, using filename (without extension) as key
            for file_path in json_files:
                current_file = file_path
                with open(file_path, "r", encoding="utf-8") as f:
                    loaded[file_path.stem] = json.load(f)

    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in annotation file {current_file}: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error loading AI prompt: {e}") from e

    ANNOTATIONS = loaded
91
+
92
def get_ai_prompt():
    """Return the loaded prompt table (module-level ANNOTATIONS; None until initialised)."""
    # Read-only access to the module global; no `global` statement is needed.
    return ANNOTATIONS
potato/ai/anthropic_endpoint.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Anthropic AI endpoint implementation.
3
+
4
+ This module provides integration with Anthropic's Claude API for LLM inference.
5
+ """
6
+
7
+ from typing import Dict, List
8
+ import anthropic
9
+ from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError, ModelCapabilities
10
+
11
+ DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
12
+ DEFAULT_HINT_PROMPT = '''
13
+ You are assisting a user with an annotation task.
14
+ The annotation instruction is : {description}
15
+ The annotation task type is: {annotation_type}
16
+ The sentence (or item) to annotate is : {text}
17
+ Your goal is to generate a short, helpful hint that guides the annotator in how to think about the input — **without providing the answer**.
18
+
19
+ The hint should:
20
+ - Highlight key aspects of the input relevant to the task
21
+ - Encourage thoughtful reasoning or observation
22
+ - Point to subtle features (tone, wording, structure, implication) that matter for the annotation
23
+ - Be specific and informative, not vague or generic
24
+ '''
25
+
26
+ DEFAULT_KEYWORD_PROMPT = '''
27
+ You are assisting a user with an annotation task.
28
+ The annotation instruction is : {description}
29
+ The annotation task type is: {annotation_type}
30
+ The sentence (or item) to annotate is : {text}
31
+ Your goal is : Print out just a sequence of keywords, not sentences, in the text that most relate to the task. Do not explain your answer. Do not print out the entire text. If no part of the text relates to the task, print the empty string.
32
+ '''
33
+
34
class AnthropicEndpoint(BaseAIEndpoint):
    """Anthropic Claude endpoint for cloud-based LLM inference."""

    # Capabilities declaration for text-based Anthropic Claude models
    CAPABILITIES = ModelCapabilities(
        text_generation=True,
        vision_input=False,
        bounding_box_output=False,
        text_classification=True,
        image_classification=False,
        rationale_generation=True,
        keyword_extraction=True,
    )

    def _initialize_client(self) -> None:
        """
        Initialize the Anthropic client.

        The API key is read from ``ai_config["api_key"]`` and, when that is
        missing or empty, from the ANTHROPIC_API_KEY environment variable.

        Raises:
            AIEndpointRequestError: If no API key can be found.
        """
        import os

        api_key = self.ai_config.get("api_key") or os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise AIEndpointRequestError("Anthropic API key is required")

        # Default timeout of 30 seconds, configurable via ai_config
        timeout = self.ai_config.get("timeout", 30)
        self.client = anthropic.Anthropic(api_key=api_key, timeout=timeout)

    def _get_default_model(self) -> str:
        """Get the default Anthropic model."""
        return DEFAULT_MODEL

    def _get_default_hint_prompt(self) -> str:
        """Get the default hint prompt for Anthropic."""
        return DEFAULT_HINT_PROMPT

    def _get_default_keyword_prompt(self) -> str:
        """Get the default keyword prompt for Anthropic."""
        return DEFAULT_KEYWORD_PROMPT

    @staticmethod
    def _first_text(response) -> str:
        """Return the text of the first content block, failing clearly when there is none."""
        if not response.content:
            raise AIEndpointRequestError("Anthropic response contained no content blocks")
        return response.content[0].text

    def query(self, prompt: str, output_format=None) -> str:
        """
        Send a query to Anthropic Claude and return the response.

        Args:
            prompt: The prompt to send to the model
            output_format: Optional object exposing ``model_json_schema()``
                (e.g. a Pydantic model class). Accepted so this method can be
                called with the same ``(prompt, output_format)`` arguments as
                the other endpoints' ``query``; when given, the schema is
                appended to the prompt as a JSON-format instruction.
                NOTE(review): the raw response text is still returned
                unparsed -- confirm what the caller expects.

        Returns:
            The model's response as a string

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            if output_format is not None:
                prompt = (
                    f"{prompt}\n\n"
                    "Please respond with valid JSON matching this schema:\n"
                    f"{output_format.model_json_schema()}"
                )

            response = self.client.messages.create(
                model=self.model,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                messages=[{"role": "user", "content": prompt}],
            )
            return self._first_text(response)
        except AIEndpointRequestError:
            raise
        except Exception as e:
            raise AIEndpointRequestError(f"Anthropic request failed: {e}") from e

    def chat_query(self, messages: List[Dict[str, str]]) -> str:
        """
        Send a multi-turn chat to Anthropic using native messages API.

        Messages with role "system" are removed from the turn list and passed
        through the separate ``system`` argument; when several are present
        they are joined with a blank line instead of keeping only the last.

        Args:
            messages: List of dicts with 'role' and 'content' keys.

        Returns:
            The model's response as a string

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            system_parts = []
            chat_messages = []
            for msg in messages:
                if msg["role"] == "system":
                    system_parts.append(msg["content"])
                else:
                    chat_messages.append({"role": msg["role"], "content": msg["content"]})

            kwargs = {
                "model": self.model,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "messages": chat_messages,
            }
            system_text = "\n\n".join(part for part in system_parts if part)
            if system_text:
                kwargs["system"] = system_text

            response = self.client.messages.create(**kwargs)
            return self._first_text(response)
        except AIEndpointRequestError:
            raise
        except Exception as e:
            raise AIEndpointRequestError(f"Anthropic chat request failed: {e}") from e
potato/ai/anthropic_vision_endpoint.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Anthropic Vision AI Endpoint
3
+
4
+ This module provides integration with Anthropic's Claude models for visual
5
+ analysis using the image content block format.
6
+ """
7
+
8
+ import base64
9
+ import logging
10
+ from typing import Any, Dict, List, Type, Union
11
+
12
+ from pydantic import BaseModel
13
+
14
+ from .ai_endpoint import AIEndpointRequestError, ImageData, ModelCapabilities
15
+ from .visual_ai_endpoint import BaseVisualAIEndpoint
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ DEFAULT_MODEL = "claude-sonnet-4-20250514"
20
+
21
+ # Supported image types for Claude
22
+ SUPPORTED_MEDIA_TYPES = [
23
+ "image/jpeg",
24
+ "image/png",
25
+ "image/gif",
26
+ "image/webp"
27
+ ]
28
+
29
+
30
class AnthropicVisionEndpoint(BaseVisualAIEndpoint):
    """
    Anthropic Vision endpoint for Claude models with vision capabilities.

    Uses the image content block format for multimodal inputs.

    Configuration options:
    - model: Model to use (default: claude-sonnet-4-20250514)
    - api_key: Anthropic API key (can also use ANTHROPIC_API_KEY env var)
    - max_tokens: Maximum response tokens (default: 1024)
    - temperature: Sampling temperature (default: 0.1)
    """

    # Capabilities declaration for Anthropic Claude vision models
    # Claude models can understand images and generate detailed reasoning but bboxes are approximate
    CAPABILITIES = ModelCapabilities(
        text_generation=True,
        vision_input=True,
        bounding_box_output=False,  # Claude bboxes are approximate, not precise
        text_classification=True,
        image_classification=True,
        rationale_generation=True,
        keyword_extraction=False,  # Keywords don't apply to images
    )

    def _initialize_client(self) -> None:
        """
        Initialize the Anthropic client.

        Raises:
            AIEndpointRequestError: If the anthropic package is not installed
                or no API key is found in the config or environment.
        """
        try:
            import anthropic
        except ImportError:
            raise AIEndpointRequestError(
                "anthropic package is required. Install it with: pip install anthropic"
            )

        import os

        api_key = self.ai_config.get("api_key") or os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            raise AIEndpointRequestError(
                "Anthropic API key is required. Set it in config or ANTHROPIC_API_KEY env var."
            )

        timeout = self.ai_config.get("timeout", 60)

        self.client = anthropic.Anthropic(api_key=api_key, timeout=timeout)
        logger.info(f"Anthropic Vision client initialized with model: {self.model}")

    def _get_default_model(self) -> str:
        """Get the default Anthropic model."""
        return DEFAULT_MODEL

    @staticmethod
    def _first_text(response) -> str:
        """Return the text of the first content block, failing clearly when there is none."""
        if not response.content:
            raise AIEndpointRequestError("Anthropic response contained no content blocks")
        return response.content[0].text

    def _prepare_image_data(self, image_path_or_url: str) -> ImageData:
        """Build image data from an http(s) URL (passed by reference) or a local path (base64-encoded)."""
        if image_path_or_url.startswith(("http://", "https://")):
            # Claude can use URLs directly
            return self.create_url_image_data(image_path_or_url)
        return self.encode_image_to_base64(image_path_or_url)

    def query(self, prompt: str, output_format: Type[BaseModel]) -> Any:
        """
        Standard text query without images.

        Args:
            prompt: Text prompt
            output_format: Pydantic model for structured output

        Returns:
            Parsed response (result of ``parseStringToJson`` on the reply text)

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            # Add JSON instruction to prompt
            json_prompt = f"""{prompt}

Please respond with valid JSON matching this schema:
{output_format.model_json_schema()}"""

            response = self.client.messages.create(
                model=self.model,
                max_tokens=self.max_tokens,
                # Honour the configured temperature, as chat_query_with_image does.
                temperature=self.temperature,
                messages=[{"role": "user", "content": json_prompt}],
            )

            content = self._first_text(response)
            return self.parseStringToJson(content)

        except AIEndpointRequestError:
            raise
        except Exception as e:
            raise AIEndpointRequestError(f"Anthropic query failed: {e}") from e

    def query_with_image(
        self,
        prompt: str,
        image_data: Union[ImageData, List[ImageData]],
        output_format: Type[BaseModel]
    ) -> Any:
        """
        Send a query with image(s) to Claude vision model.

        Args:
            prompt: Text prompt describing what to analyze
            image_data: Single ImageData or list of ImageData
            output_format: Pydantic model for structured output

        Returns:
            Parsed response according to output_format

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            # Prepare images
            images = [image_data] if isinstance(image_data, ImageData) else image_data

            # Build content array with images first, then text
            content = [self._build_image_block(img) for img in images]

            # Add JSON instruction to prompt
            json_prompt = f"""{prompt}

Please respond with valid JSON matching this schema:
{output_format.model_json_schema()}

Only return the JSON object, no other text."""

            content.append({"type": "text", "text": json_prompt})

            # Make request
            response = self.client.messages.create(
                model=self.model,
                max_tokens=self.max_tokens,
                # Honour the configured temperature, as chat_query_with_image does.
                temperature=self.temperature,
                messages=[{"role": "user", "content": content}],
            )

            response_content = self._first_text(response)
            logger.debug(f"Anthropic vision response: {response_content[:500] if response_content else 'empty'}")

            return self.parseStringToJson(response_content)

        except AIEndpointRequestError:
            raise
        except Exception as e:
            # logger.exception records the message at ERROR level plus the traceback.
            logger.exception(f"Anthropic vision query failed: {e}")
            raise AIEndpointRequestError(f"Anthropic vision query failed: {e}") from e

    def chat_query_with_image(
        self,
        messages: List[Dict[str, Any]],
        images: Any = None,
    ) -> str:
        """
        Multi-turn chat with interleaved images for vision-based agent loops.

        Messages may have 'content' as a string (text only) or a list of
        content blocks (text + image dicts in Anthropic format).

        Messages with role "system" are passed through the separate
        ``system`` argument; several system messages are joined with a blank
        line instead of keeping only the last one.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            images: Unused (images are inline in messages).

        Returns:
            The model's response as a plain text string.

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            system_parts = []
            api_messages = []

            for msg in messages:
                if msg["role"] == "system":
                    raw = msg["content"]
                    system_parts.append(raw if isinstance(raw, str) else str(raw))
                else:
                    api_messages.append({
                        "role": msg["role"],
                        "content": msg["content"],
                    })

            kwargs = {
                "model": self.model,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "messages": api_messages,
            }
            system = "\n\n".join(part for part in system_parts if part)
            if system:
                kwargs["system"] = system

            response = self.client.messages.create(**kwargs)
            return self._first_text(response)

        except Exception as e:
            logger.error(f"Anthropic vision chat query failed: {e}")
            raise AIEndpointRequestError(f"Anthropic vision chat query failed: {e}") from e

    def _build_image_block(self, image_data: ImageData) -> Dict[str, Any]:
        """
        Build image content block for Anthropic API.

        Args:
            image_data: ImageData object

        Returns:
            Dict with type: "image" and source content

        Raises:
            AIEndpointRequestError: If ``image_data.source`` is neither
                "url" nor "base64".
        """
        if image_data.source == "url":
            # Claude supports URL sources directly
            return {
                "type": "image",
                "source": {
                    "type": "url",
                    "url": image_data.data
                }
            }

        elif image_data.source == "base64":
            # Determine media type
            media_type = image_data.mime_type or "image/jpeg"

            # Validate media type
            if media_type not in SUPPORTED_MEDIA_TYPES:
                logger.warning(f"Media type {media_type} may not be supported. Using image/jpeg.")
                media_type = "image/jpeg"

            return {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": image_data.data
                }
            }

        else:
            raise AIEndpointRequestError(f"Unknown image source: {image_data.source}")

    def analyze_image(
        self,
        image_path_or_url: str,
        prompt: str,
        output_format: Type[BaseModel] = None
    ) -> Any:
        """
        Convenience method for analyzing a single image.

        Args:
            image_path_or_url: Path to image file or URL
            prompt: Analysis prompt
            output_format: Optional output format model; GeneralHintFormat
                is used when omitted

        Returns:
            Analysis result
        """
        image_data = self._prepare_image_data(image_path_or_url)

        # Use a generic format if not specified
        if output_format is None:
            from .prompt.models_module import GeneralHintFormat
            output_format = GeneralHintFormat

        return self.query_with_image(prompt, image_data, output_format)

    def detect_objects(
        self,
        image_path_or_url: str,
        labels: List[str] = None
    ) -> Dict[str, Any]:
        """
        Detect objects in an image and return bounding boxes.

        Args:
            image_path_or_url: Path to image file or URL
            labels: Optional list of labels to detect; when omitted the model
                is asked for "all visible objects"

        Returns:
            Dict with detections list
        """
        from .prompt.models_module import VisualDetectionFormat

        labels_str = ", ".join(labels) if labels else "all visible objects"

        prompt = f"""Analyze this image and detect objects. For each object, provide:
1. The label (from: {labels_str})
2. A bounding box with normalized coordinates (0-1 range)
3. Confidence score (0-1)

Return a JSON object with this exact structure:
{{
"detections": [
{{
"label": "object_name",
"bbox": {{"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.4}},
"confidence": 0.95
}}
]
}}

Important:
- Coordinates are normalized (0-1) where x,y is the top-left corner
- x increases left to right, y increases top to bottom
- width and height are also normalized (0-1)
- Only include objects you can clearly identify
- Estimate bounding boxes as accurately as possible"""

        image_data = self._prepare_image_data(image_path_or_url)

        return self.query_with_image(prompt, image_data, VisualDetectionFormat)

    def get_annotation_hint(
        self,
        image_path_or_url: str,
        task_description: str,
        labels: List[str]
    ) -> Dict[str, Any]:
        """
        Get a hint for annotating an image without revealing exact locations.

        Args:
            image_path_or_url: Path to image file or URL
            task_description: Description of the annotation task
            labels: Available labels

        Returns:
            Dict with hint text and suggested focus
        """
        labels_str = ", ".join(labels)

        prompt = f"""You are helping an annotator with this task: {task_description}

Available labels: {labels_str}

Provide a helpful hint that guides the annotator without giving away the exact answer.
The hint should:
1. Point out relevant features to consider
2. Suggest what to look for
3. Not explicitly state the answer or exact locations

Return JSON:
{{
"hint": "Your helpful hint here",
"suggested_focus": "What area or aspect to focus on"
}}"""

        image_data = self._prepare_image_data(image_path_or_url)

        # Local schema describing the two fields requested in the prompt.
        class HintFormat(BaseModel):
            hint: str
            suggested_focus: str

        return self.query_with_image(prompt, image_data, HintFormat)

    def health_check(self) -> bool:
        """
        Check if the Anthropic API is accessible.

        Sends a minimal real request (max_tokens=10) to the configured model.

        Returns:
            True if API is reachable, False otherwise
        """
        try:
            # Simple test message
            self.client.messages.create(
                model=self.model,
                max_tokens=10,
                messages=[{"role": "user", "content": "Hello"}]
            )
            return True
        except Exception as e:
            logger.error(f"Anthropic health check failed: {e}")
            return False
potato/ai/gemini_endpoint.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Google Gemini AI endpoint implementation.
3
+
4
+ This module provides integration with Google's Gemini API for LLM inference.
5
+ """
6
+
7
+ from google import genai
8
+ from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError
9
+
10
+ DEFAULT_MODEL = "gemini-2.0-flash-exp"
11
+
12
+
13
class GeminiEndpoint(BaseAIEndpoint):
    """Google Gemini endpoint for cloud-based LLM inference."""

    def _initialize_client(self) -> None:
        """
        Initialize the Gemini client.

        Raises:
            AIEndpointRequestError: If ``ai_config`` has no API key.
        """
        api_key = self.ai_config.get("api_key", "")
        if not api_key:
            raise AIEndpointRequestError("Gemini API key is required")

        # Default timeout of 30 seconds, configurable via ai_config.
        # google-genai's HttpOptions.timeout is expressed in milliseconds, so
        # the seconds value from the config must be converted.
        timeout_seconds = self.ai_config.get("timeout", 30)
        http_options = None
        if timeout_seconds is not None:
            http_options = {'timeout': int(timeout_seconds * 1000)}

        self.client = genai.Client(
            api_key=api_key,
            http_options=http_options,
        )

    def _get_default_model(self) -> str:
        """Get the default Gemini model."""
        return DEFAULT_MODEL

    def query(self, prompt: str, prompt_format: type) -> str:
        """
        Send a query to Gemini and return the response.

        Args:
            prompt: The prompt to send to the model
            prompt_format: Pydantic model class describing the expected JSON
                response; passed to the API as the response schema

        Returns:
            The model's response as a string

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            response = self.client.models.generate_content(
                model=self.model,
                contents=prompt,
                # The google-genai client takes generation settings through
                # `config`; it has no `generation_config` keyword.
                config={
                    'max_output_tokens': self.max_tokens,
                    'temperature': self.temperature,
                    # A response schema requires a JSON response MIME type.
                    'response_mime_type': 'application/json',
                    'response_schema': prompt_format,
                },
            )
            return response.text
        except Exception as e:
            raise AIEndpointRequestError(f"Gemini request failed: {e}") from e
potato/ai/huggingface_endpoint.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face AI endpoint implementation.
3
+
4
+ This module provides integration with Hugging Face's Inference API for LLM inference.
5
+ """
6
+
7
+ from huggingface_hub import InferenceClient
8
+ from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError
9
+
10
+ DEFAULT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
11
+
12
class HuggingfaceEndpoint(BaseAIEndpoint):
    """Hugging Face endpoint for cloud-based LLM inference."""

    def _initialize_client(self) -> None:
        """
        Create the InferenceClient for the configured model.

        Raises:
            AIEndpointRequestError: If ``ai_config`` has no API key.
        """
        token = self.ai_config.get("api_key", "")
        if not token:
            raise AIEndpointRequestError("Hugging Face API key is required")

        # Timeout comes from ai_config and falls back to 30.
        request_timeout = self.ai_config.get("timeout", 30)
        self.client = InferenceClient(
            model=self.model,
            token=token,
            timeout=request_timeout,
        )

    def _get_default_model(self) -> str:
        """Return the model name used when none is configured."""
        return DEFAULT_MODEL

    def query(self, prompt: str, output_format: dict) -> str:
        """
        Send a single-turn chat completion request and return the reply text.

        Args:
            prompt: The prompt to send to the model
            output_format: Object exposing ``model_json_schema()``; its schema
                is sent as a strict JSON-schema response format

        Returns:
            The content of the first returned choice

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            schema_spec = {
                "name": "output_format",
                "schema": output_format.model_json_schema(),
                "strict": True,
            }
            completion = self.client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                response_format={"type": "json_schema", "json_schema": schema_spec},
            )
            first_choice = completion.choices[0]
            return first_choice.message.content
        except Exception as e:
            raise AIEndpointRequestError(f"Hugging Face request failed: {e}")
potato/ai/icl_labeler.py ADDED
@@ -0,0 +1,1110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In-Context Learning (ICL) Labeler Module
3
+
4
+ This module provides AI-assisted labeling using high-confidence human annotations
5
+ as in-context examples to prompt an LLM to label remaining data.
6
+
7
+ Key features:
8
+ - Identifies high-confidence examples where annotators agree
9
+ - Uses examples as in-context demonstrations for LLM labeling
10
+ - Tracks LLM confidence scores on predictions
11
+ - Routes subset of LLM labels to humans for verification
12
+ - Calculates and reports LLM accuracy based on verification
13
+ """
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ import random
19
+ import threading
20
+ import time
21
+ from collections import Counter, defaultdict
22
+ from dataclasses import dataclass, field, asdict
23
+ from datetime import datetime
24
+ from typing import Dict, List, Optional, Any, Tuple, Set
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
@dataclass
class HighConfidenceExample:
    """One human-labelled instance kept as an in-context learning example."""
    instance_id: str
    text: str
    schema_name: str
    label: str
    agreement_score: float  # Proportion of annotators who chose this label
    annotator_count: int
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of all fields; the timestamp becomes an ISO-8601 string."""
        plain_fields = (
            'instance_id',
            'text',
            'schema_name',
            'label',
            'agreement_score',
            'annotator_count',
        )
        payload = {name: getattr(self, name) for name in plain_fields}
        payload['timestamp'] = self.timestamp.isoformat()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'HighConfidenceExample':
        """
        Build an example from a dict such as the one produced by to_dict().

        A string timestamp is parsed with datetime.fromisoformat, a missing or
        None timestamp becomes the current time, and any other value is used
        as given. All other keys are required.
        """
        raw_ts = data.get('timestamp')
        if raw_ts is None:
            when = datetime.now()
        elif isinstance(raw_ts, str):
            when = datetime.fromisoformat(raw_ts)
        else:
            when = raw_ts

        required = (
            'instance_id',
            'text',
            'schema_name',
            'label',
            'agreement_score',
            'annotator_count',
        )
        kwargs = {key: data[key] for key in required}
        return cls(timestamp=when, **kwargs)
70
+
71
+
72
@dataclass
class ICLPrediction:
    """A single LLM-produced label plus its verification bookkeeping."""
    instance_id: str
    schema_name: str
    predicted_label: str
    confidence_score: float  # 0.0-1.0
    timestamp: datetime = field(default_factory=datetime.now)

    # In-context examples used
    example_instance_ids: List[str] = field(default_factory=list)

    # Verification tracking
    verification_status: str = 'pending'  # 'pending', 'verified_correct', 'verified_incorrect'
    verified_by: Optional[str] = None
    verified_at: Optional[datetime] = None
    human_label: Optional[str] = None  # Human's label if verified

    # LLM metadata
    model_name: str = ""
    reasoning: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of all fields; datetimes become ISO-8601 strings (or None)."""
        verified_at_iso = None
        if self.verified_at:
            verified_at_iso = self.verified_at.isoformat()

        return dict(
            instance_id=self.instance_id,
            schema_name=self.schema_name,
            predicted_label=self.predicted_label,
            confidence_score=self.confidence_score,
            timestamp=self.timestamp.isoformat(),
            example_instance_ids=self.example_instance_ids,
            verification_status=self.verification_status,
            verified_by=self.verified_by,
            verified_at=verified_at_iso,
            human_label=self.human_label,
            model_name=self.model_name,
            reasoning=self.reasoning,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ICLPrediction':
        """
        Build a prediction from a dict such as the one produced by to_dict().

        String timestamps are parsed with datetime.fromisoformat; a missing or
        None 'timestamp' becomes the current time. The four identifying keys
        are required, every other key falls back to the field default.
        """
        def _parse(value):
            # Only strings are parsed; datetimes and None pass through unchanged.
            return datetime.fromisoformat(value) if isinstance(value, str) else value

        created = _parse(data.get('timestamp'))
        if created is None:
            created = datetime.now()
        verified = _parse(data.get('verified_at'))

        kwargs = {
            key: data[key]
            for key in ('instance_id', 'schema_name', 'predicted_label', 'confidence_score')
        }
        kwargs.update(
            timestamp=created,
            example_instance_ids=data.get('example_instance_ids', []),
            verification_status=data.get('verification_status', 'pending'),
            verified_by=data.get('verified_by'),
            verified_at=verified,
            human_label=data.get('human_label'),
            model_name=data.get('model_name', ''),
            reasoning=data.get('reasoning', ''),
        )
        return cls(**kwargs)
138
+
139
+
140
+ class ICLLabeler:
141
+ """
142
+ Manages in-context learning based labeling using high-confidence human annotations.
143
+
144
+ Workflow:
145
+ 1. Monitors annotation progress for high-confidence examples
146
+ 2. Periodically refreshes pool of high-confidence examples
147
+ 3. Uses examples to prompt LLM for labeling unlabeled instances
148
+ 4. Routes some LLM-labeled instances for human verification (blind)
149
+ 5. Tracks accuracy metrics
150
+ """
151
+
152
+ _instance = None
153
+ _lock = threading.RLock()
154
+
155
+ def __new__(cls, *args, **kwargs):
156
+ """Singleton pattern."""
157
+ if cls._instance is None:
158
+ with cls._lock:
159
+ if cls._instance is None:
160
+ cls._instance = super().__new__(cls)
161
+ cls._instance._initialized = False
162
+ return cls._instance
163
+
164
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """
    Configure the labeler from ``config``; does nothing if already initialized.

    Args:
        config: Configuration dictionary; settings are read from its
            'icl_labeling' section, with defaults for anything missing.
    """
    # The instance is shared, so only the first construction configures it.
    if self._initialized:
        return

    self.config = config or {}
    self._ai_endpoint = None

    # Sub-sections of the ICL labeling config
    icl_settings = self.config.get('icl_labeling', {})
    selection = icl_settings.get('example_selection', {})
    llm_settings = icl_settings.get('llm_labeling', {})
    verification = icl_settings.get('verification', {})
    persistence = icl_settings.get('persistence', {})

    # Example selection config
    self.min_agreement_threshold = selection.get('min_agreement_threshold', 0.8)
    self.min_annotators_per_instance = selection.get('min_annotators_per_instance', 2)
    self.max_examples_per_schema = selection.get('max_examples_per_schema', 10)
    self.example_refresh_interval = selection.get('refresh_interval_seconds', 300)

    # LLM labeling config
    self.batch_size = llm_settings.get('batch_size', 20)
    self.trigger_threshold = llm_settings.get('trigger_threshold', 5)
    self.confidence_threshold = llm_settings.get('confidence_threshold', 0.7)
    self.batch_interval = llm_settings.get('batch_interval_seconds', 600)

    # Limits to prevent labeling entire dataset at once
    # This allows iterative improvement - verify accuracy before labeling more
    self.max_total_labels = llm_settings.get('max_total_labels', None)  # Max instances to label total
    self.max_unlabeled_ratio = llm_settings.get('max_unlabeled_ratio', 0.5)  # Max % of unlabeled to label
    self.pause_on_low_accuracy = llm_settings.get('pause_on_low_accuracy', True)
    self.min_accuracy_threshold = llm_settings.get('min_accuracy_threshold', 0.7)  # Pause if accuracy below

    # Verification config
    self.verification_enabled = verification.get('enabled', True)
    self.verification_sample_rate = verification.get('sample_rate', 0.2)
    self.verification_strategy = verification.get('selection_strategy', 'low_confidence')

    # Persistence config
    self.predictions_file = persistence.get('predictions_file', 'icl_predictions.json')

    # State
    self.schema_to_examples: Dict[str, List[HighConfidenceExample]] = {}
    self.predictions: Dict[str, Dict[str, ICLPrediction]] = {}  # instance_id -> schema -> prediction
    self.verification_queue: List[Tuple[str, str]] = []  # [(instance_id, schema_name), ...]
    self.labeled_instance_ids: Set[str] = set()  # Instances labeled by LLM

    self.last_example_refresh: Optional[datetime] = None
    self.last_batch_run: Optional[datetime] = None

    # Background worker
    self._worker_thread: Optional[threading.Thread] = None
    self._stop_worker = threading.Event()

    self._initialized = True
    logger.info("ICLLabeler initialized")
226
+
227
+ def _get_ai_endpoint(self):
228
+ """Get or create AI endpoint from config (reuses ai_support config)."""
229
+ if self._ai_endpoint is None:
230
+ from potato.ai.ai_endpoint import AIEndpointFactory
231
+ self._ai_endpoint = AIEndpointFactory.create_endpoint(self.config)
232
+ return self._ai_endpoint
233
+
234
+ def _get_annotation_schemes(self) -> List[Dict[str, Any]]:
235
+ """Get annotation schemes from config."""
236
+ return self.config.get('annotation_schemes', [])
237
+
238
+ def _get_text_key(self) -> str:
239
+ """Get the text key from item_properties."""
240
+ return self.config.get('item_properties', {}).get('text_key', 'text')
241
+
242
+ # === High-Confidence Example Collection ===
243
+
244
def refresh_high_confidence_examples(self) -> Dict[str, List[HighConfidenceExample]]:
    """
    Scan all users' annotations and rebuild the high-confidence example pool.

    For every (instance, schema) pair the votes of all users are collected.
    A pair becomes an example when at least ``min_annotators_per_instance``
    distinct users voted and the share of users backing the most common value
    reaches ``min_agreement_threshold``. Each schema is then capped at
    ``max_examples_per_schema`` via ``_select_diverse_examples``; smaller
    pools are simply sorted by agreement, highest first.

    On success ``self.schema_to_examples`` and ``self.last_example_refresh``
    are replaced. If anything raises, the error is logged with its traceback
    and the previously stored examples are kept.

    Returns:
        Dictionary mapping schema name to list of high-confidence examples.
        When the item state manager is unavailable an empty mapping is
        returned and the stored examples are left untouched.
    """
    from potato.flask_server import get_users, get_user_state, get_item_state_manager

    with self._lock:
        new_examples: Dict[str, List[HighConfidenceExample]] = defaultdict(list)

        try:
            ism = get_item_state_manager()
            if ism is None:
                logger.warning("ItemStateManager not available")
                return new_examples

            text_key = self._get_text_key()
            schemas = self._get_annotation_schemes()
            schema_names = [s.get('name') for s in schemas if s.get('name')]

            # instance_id -> schema_name -> [(user_id, value), ...]
            instance_annotations: Dict[str, Dict[str, List[Tuple[str, Any]]]] = defaultdict(
                lambda: defaultdict(list)
            )

            for username in get_users():
                user_state = get_user_state(username)
                if not user_state:
                    continue

                for instance_id, instance_data in user_state.get_all_annotations().items():
                    if 'labels' not in instance_data:
                        continue

                    for label, value in instance_data['labels'].items():
                        schema_name = label.get_schema() if hasattr(label, 'get_schema') else str(label)
                        if schema_name in schema_names:
                            instance_annotations[instance_id][schema_name].append((username, value))

            for instance_id, schema_data in instance_annotations.items():
                for schema_name, annotations in schema_data.items():
                    # A user may contribute several label entries for one schema
                    # (presumably multi-select style schemas; verify against the
                    # label types in use). Collapse identical (user, value) pairs,
                    # keeping first-seen order so tie-breaking stays deterministic.
                    try:
                        votes = list(dict.fromkeys(annotations))
                    except TypeError:
                        # Unhashable value (e.g. list/dict): skip this pair only,
                        # instead of aborting the whole refresh.
                        logger.debug(
                            f"Skipping {instance_id}/{schema_name}: unhashable annotation value"
                        )
                        continue

                    # Count people, not label entries, so one user can never
                    # satisfy the minimum-annotator requirement alone.
                    annotator_count = len({user for user, _ in votes})
                    if annotator_count < self.min_annotators_per_instance:
                        continue

                    label_counts = Counter(value for _, value in votes)
                    most_common_label, most_common_count = label_counts.most_common(1)[0]

                    # Share of distinct annotators that chose the winning value (<= 1.0).
                    agreement_score = most_common_count / annotator_count
                    if agreement_score < self.min_agreement_threshold:
                        continue

                    item = ism.get_item(instance_id)
                    instance_data = item.get_data() if item else None
                    if instance_data is None:
                        continue

                    text = instance_data.get(text_key, '')
                    if not text:
                        continue

                    new_examples[schema_name].append(
                        HighConfidenceExample(
                            instance_id=instance_id,
                            text=text,
                            schema_name=schema_name,
                            label=str(most_common_label),
                            agreement_score=agreement_score,
                            annotator_count=annotator_count
                        )
                    )

            # Cap each schema: diverse subset when over the limit, otherwise
            # just order by agreement.
            for schema_name in new_examples:
                candidates = new_examples[schema_name]
                if len(candidates) > self.max_examples_per_schema:
                    new_examples[schema_name] = self._select_diverse_examples(
                        candidates, self.max_examples_per_schema
                    )
                else:
                    candidates.sort(key=lambda x: x.agreement_score, reverse=True)

            self.schema_to_examples = dict(new_examples)
            self.last_example_refresh = datetime.now()

            total_examples = sum(len(examples) for examples in new_examples.values())
            logger.info(f"Refreshed high-confidence examples: {total_examples} examples across {len(new_examples)} schemas")

        except Exception as e:
            # Keep the previous examples; include the traceback for diagnosis.
            logger.exception(f"Error refreshing examples: {e}")

        return self.schema_to_examples
346
+
347
def get_examples_for_schema(self, schema_name: str) -> List[HighConfidenceExample]:
    """Look up the stored high-confidence examples for ``schema_name`` (empty list if none)."""
    stored = self.schema_to_examples
    return stored.get(schema_name, [])
350
+
351
def has_enough_examples(self, schema_name: str) -> bool:
    """Tell whether ``schema_name`` has at least ``trigger_threshold`` stored examples."""
    available = self.get_examples_for_schema(schema_name)
    return self.trigger_threshold <= len(available)
354
+
355
+ def _select_diverse_examples(
356
+ self,
357
+ candidates: List[HighConfidenceExample],
358
+ k: int,
359
+ ) -> List[HighConfidenceExample]:
360
+ """CoverICL-inspired coverage-based example selection.
361
+
362
+ Uses greedy facility location to select examples that maximize
363
+ coverage of the instance embedding space, ensuring diverse and
364
+ representative ICL demonstrations.
365
+
366
+ Inspired by Mavromatis et al. (2024) CoverICL. Falls back to
367
+ agreement-score sorting if vectorization fails.
368
+
369
+ Args:
370
+ candidates: Pool of high-confidence examples
371
+ k: Number of examples to select
372
+
373
+ Returns:
374
+ List of selected examples maximizing coverage
375
+ """
376
+ if len(candidates) <= k:
377
+ return candidates
378
+
379
+ try:
380
+ from sklearn.feature_extraction.text import TfidfVectorizer
381
+ from sklearn.metrics.pairwise import cosine_distances
382
+
383
+ texts = [c.text for c in candidates]
384
+ vectorizer = TfidfVectorizer(max_features=5000)
385
+ features = vectorizer.fit_transform(texts).toarray()
386
+
387
+ # Greedy facility location: iteratively pick the candidate that
388
+ # maximizes the minimum distance to already-selected examples
389
+ n = len(candidates)
390
+ selected_indices = []
391
+
392
+ # Start with the candidate that has highest agreement score
393
+ # (quality-weighted seed)
394
+ agreements = [c.agreement_score for c in candidates]
395
+ first = int(max(range(n), key=lambda i: agreements[i]))
396
+ selected_indices.append(first)
397
+
398
+ dist_matrix = cosine_distances(features)
399
+
400
+ for _ in range(k - 1):
401
+ # For each unselected candidate, compute min distance to selected set
402
+ best_idx = -1
403
+ best_score = -1.0
404
+
405
+ for i in range(n):
406
+ if i in selected_indices:
407
+ continue
408
+ min_dist = min(dist_matrix[i][j] for j in selected_indices)
409
+ # Weight by agreement score for quality-aware selection
410
+ score = min_dist * candidates[i].agreement_score
411
+ if score > best_score:
412
+ best_score = score
413
+ best_idx = i
414
+
415
+ if best_idx >= 0:
416
+ selected_indices.append(best_idx)
417
+ else:
418
+ break
419
+
420
+ selected = [candidates[i] for i in selected_indices]
421
+ logger.debug(f"CoverICL selection: {len(selected)} diverse examples from {n} candidates")
422
+ return selected
423
+
424
+ except Exception as e:
425
+ logger.warning(f"Coverage-based selection failed, using agreement sorting: {e}")
426
+ candidates.sort(key=lambda x: x.agreement_score, reverse=True)
427
+ return candidates[:k]
428
+
429
+ # === LLM Labeling ===
430
+
431
def label_instance(
    self,
    instance_id: str,
    schema_name: str,
    instance_text: str
) -> Optional[ICLPrediction]:
    """
    Label a single instance using in-context learning.

    Builds a prompt from the schema's stored examples, queries the AI
    endpoint for a ``label`` / ``confidence`` / ``reasoning`` response, and
    validates the label against the schema's labels (exact match first, then
    case/whitespace-insensitive). A valid prediction is stored in
    ``self.predictions``, the instance is added to
    ``self.labeled_instance_ids``, and, with probability
    ``verification_sample_rate`` when verification is enabled, the pair is
    queued for human verification (at most once per pair).

    Args:
        instance_id: The instance to label
        schema_name: The annotation schema to use
        instance_text: The text to label

    Returns:
        ICLPrediction if successful; None when there are no examples, the
        schema is unknown, no endpoint is available, the label is invalid,
        or any error occurs (errors are logged, not raised).
    """
    from potato.ai.icl_prompt_builder import ICLPromptBuilder

    examples = self.get_examples_for_schema(schema_name)
    if not examples:
        logger.warning(f"No examples available for schema {schema_name}")
        return None

    # Get schema info
    schemas = self._get_annotation_schemes()
    schema_info = next((s for s in schemas if s.get('name') == schema_name), None)
    if not schema_info:
        logger.warning(f"Schema {schema_name} not found in config")
        return None

    endpoint = self._get_ai_endpoint()
    if endpoint is None:
        logger.warning("AI endpoint not available")
        return None

    try:
        # Build prompt
        prompt_builder = ICLPromptBuilder()
        prompt = prompt_builder.build_prompt(
            schema=schema_info,
            examples=examples,
            target_text=instance_text
        )

        # Query LLM
        from pydantic import BaseModel

        class ICLResponse(BaseModel):
            label: str
            confidence: float
            reasoning: str = ""

        response = endpoint.query(prompt, ICLResponse)

        # The endpoint may hand back a JSON string, a pydantic model, or a mapping.
        if isinstance(response, str):
            response_data = json.loads(response)
        elif hasattr(response, 'model_dump'):
            response_data = response.model_dump()
        else:
            response_data = response

        predicted_label = response_data.get('label', '')
        # Clamp once so the stored score and the log line agree.
        confidence = min(1.0, max(0.0, float(response_data.get('confidence', 0.5))))
        reasoning = response_data.get('reasoning', '')

        # Validate label against schema
        valid_labels = self._get_valid_labels(schema_info)
        if valid_labels and predicted_label not in valid_labels:
            # Try fuzzy matching
            predicted_label = self._fuzzy_match_label(predicted_label, valid_labels)
            if predicted_label is None:
                logger.warning(f"LLM returned invalid label for {instance_id}")
                return None

        # Create prediction
        prediction = ICLPrediction(
            instance_id=instance_id,
            schema_name=schema_name,
            predicted_label=predicted_label,
            confidence_score=confidence,
            example_instance_ids=[e.instance_id for e in examples],
            model_name=endpoint.model if hasattr(endpoint, 'model') else '',
            reasoning=reasoning
        )

        # Store prediction
        with self._lock:
            if instance_id not in self.predictions:
                self.predictions[instance_id] = {}
            self.predictions[instance_id][schema_name] = prediction
            self.labeled_instance_ids.add(instance_id)

            # Maybe add to verification queue. Re-labeling the same pair must
            # not enqueue it twice, or the queue size would be overstated.
            queue_entry = (instance_id, schema_name)
            if (
                self.verification_enabled
                and random.random() < self.verification_sample_rate
                and queue_entry not in self.verification_queue
            ):
                self.verification_queue.append(queue_entry)

        logger.debug(f"Labeled {instance_id} with {predicted_label} (confidence: {confidence:.2f})")
        return prediction

    except Exception as e:
        logger.error(f"Error labeling instance {instance_id}: {e}")
        return None
535
+
536
+ def _get_valid_labels(self, schema_info: Dict[str, Any]) -> List[str]:
537
+ """Extract valid labels from schema info."""
538
+ labels = schema_info.get('labels', [])
539
+ valid_labels = []
540
+ for label in labels:
541
+ if isinstance(label, str):
542
+ valid_labels.append(label)
543
+ elif isinstance(label, dict):
544
+ valid_labels.append(label.get('name', str(label)))
545
+ return valid_labels
546
+
547
+ def _fuzzy_match_label(self, predicted: str, valid_labels: List[str]) -> Optional[str]:
548
+ """Try to match predicted label to a valid label."""
549
+ predicted_lower = predicted.lower().strip()
550
+ for label in valid_labels:
551
+ if label.lower().strip() == predicted_lower:
552
+ return label
553
+ return None
554
+
555
def should_pause_labeling(self) -> Tuple[bool, str]:
    """
    Decide whether LLM labeling should pause.

    Labeling pauses when the number of LLM-labeled instances has reached
    ``max_total_labels``, or, if ``pause_on_low_accuracy`` is set, when at
    least 10 predictions have been verified and the measured accuracy is
    below ``min_accuracy_threshold``.

    Returns:
        Tuple of (should_pause, reason); reason is "" when not pausing.
    """
    limit = self.max_total_labels
    if limit is not None and len(self.labeled_instance_ids) >= limit:
        return True, f"Reached max_total_labels limit ({self.max_total_labels})"

    if not self.pause_on_low_accuracy:
        return False, ""

    metrics = self.get_accuracy_metrics()
    verified = metrics.get('total_verified', 0)
    accuracy = metrics.get('accuracy')

    # Accuracy is only trusted once enough verifications exist.
    min_verifications = 10
    has_signal = verified >= min_verifications and accuracy is not None
    if has_signal and accuracy < self.min_accuracy_threshold:
        return True, f"Accuracy ({accuracy:.1%}) below threshold ({self.min_accuracy_threshold:.1%})"

    return False, ""
581
+
582
def get_remaining_label_capacity(self) -> int:
    """
    Get how many more instances the LLM may label right now.

    The capacity is ``int(unlabeled * max_unlabeled_ratio)``, where
    ``unlabeled`` counts instances with neither an LLM label nor any human
    annotation. When ``max_total_labels`` is set, it is further capped by
    the labels still available under that limit.

    NOTE(review): because ``unlabeled`` shrinks as the LLM labels instances,
    the ratio is applied to the *remaining* pool on every call rather than
    to a fixed pool; confirm this is the intended meaning of
    ``max_unlabeled_ratio``.

    Returns:
        Number of instances that can still be labeled, never negative.
        0 is also returned when the item state manager is not available.
    """
    from potato.item_state_management import get_item_state_manager
    from potato.user_state_management import get_user_state_manager

    try:
        ism = get_item_state_manager()
    except ValueError:
        # ISM not initialized yet
        return 0
    if ism is None:
        return 0

    # Gather every human-annotated instance id once, instead of re-reading
    # each user's annotations for every instance.
    usm = get_user_state_manager()
    human_annotated = set()
    for username in usm.get_all_users():
        user_state = usm.get_user_state(username)
        if user_state:
            human_annotated.update(user_state.get_all_annotations())

    # Count unlabeled instances (not labeled by humans or LLM)
    unlabeled_count = sum(
        1
        for instance_id in ism.instance_id_ordering
        if instance_id not in self.labeled_instance_ids
        and instance_id not in human_annotated
    )

    # Calculate max based on ratio
    capacity = int(unlabeled_count * self.max_unlabeled_ratio)

    # Calculate max based on total limit
    if self.max_total_labels is not None:
        capacity = min(capacity, self.max_total_labels - len(self.labeled_instance_ids))

    # An exceeded limit means "no capacity", not a negative count.
    return max(0, capacity)
630
+
631
def batch_label_instances(self, schema_name: str) -> List[ICLPrediction]:
    """
    Label multiple unlabeled instances for a schema.

    Respects configured limits to prevent labeling entire dataset at once:
    nothing is labeled while ``should_pause_labeling`` says to pause, while
    the schema has too few examples, or when no label capacity remains. At
    most ``min(batch_size, remaining capacity)`` instances without any human
    annotation are sent to ``label_instance``.

    An instance is skipped as "already labeled" when it has a stored
    prediction for *this* schema. Instances that only carry predictions for
    other schemas are still eligible, so every schema can be labeled.

    Returns:
        List of successful predictions whose confidence is at least
        ``confidence_threshold`` (lower-confidence predictions are still
        stored by ``label_instance`` but not returned).
    """
    from potato.flask_server import get_item_state_manager, get_users, get_user_state

    # Check if we should pause labeling
    should_pause, reason = self.should_pause_labeling()
    if should_pause:
        logger.info(f"Labeling paused: {reason}")
        return []

    if not self.has_enough_examples(schema_name):
        logger.info(f"Not enough examples for schema {schema_name}")
        return []

    ism = get_item_state_manager()
    if ism is None:
        return []

    # Check remaining capacity
    remaining_capacity = self.get_remaining_label_capacity()
    if remaining_capacity <= 0:
        logger.info("No remaining label capacity")
        return []

    # Limit batch size to remaining capacity
    effective_batch_size = min(self.batch_size, remaining_capacity)

    text_key = self._get_text_key()
    predictions = []

    # Gather every human-annotated instance id once, instead of re-reading
    # each user's annotations for every instance.
    human_annotated = set()
    for username in get_users():
        user_state = get_user_state(username)
        if user_state:
            human_annotated.update(user_state.get_all_annotations())

    # Find unlabeled instances
    unlabeled_ids = []
    for instance_id in ism.instance_id_ordering:
        existing = self.predictions.get(instance_id)
        if existing:
            # Per-schema check: only skip if this schema was already predicted.
            if schema_name in existing:
                continue
        elif instance_id in self.labeled_instance_ids:
            # Marked as LLM-labeled but no prediction details available.
            continue

        # Skip if already annotated by humans
        if instance_id in human_annotated:
            continue

        unlabeled_ids.append(instance_id)
        if len(unlabeled_ids) >= effective_batch_size:
            break

    # Label instances
    for instance_id in unlabeled_ids:
        item = ism.get_item(instance_id)
        instance_data = item.get_data() if item else None
        if instance_data is None:
            continue

        text = instance_data.get(text_key, '')
        if not text:
            continue

        prediction = self.label_instance(instance_id, schema_name, text)
        if prediction and prediction.confidence_score >= self.confidence_threshold:
            predictions.append(prediction)

    self.last_batch_run = datetime.now()
    logger.info(f"Batch labeled {len(predictions)} instances for schema {schema_name}")

    return predictions
710
+
711
+ # === Verification Workflow ===
712
+
713
def get_pending_verifications(self, count: int = 1) -> List[Tuple[str, str]]:
    """
    Pick queued predictions that still await human verification.

    Only queue entries whose stored prediction has status ``'pending'`` are
    considered. Ordering depends on ``self.verification_strategy``:
    ``'low_confidence'`` returns the least confident first, ``'random'``
    shuffles, and any other value mixes both (the ``count // 2`` least
    confident, followed by a random draw from the remainder).

    Args:
        count: Number of verification tasks to return

    Returns:
        List of (instance_id, schema_name) tuples
    """
    with self._lock:
        def _is_pending(inst_id, schema):
            return (
                inst_id in self.predictions
                and schema in self.predictions[inst_id]
                and self.predictions[inst_id][schema].verification_status == 'pending'
            )

        def _confidence(entry):
            return self.predictions[entry[0]][entry[1]].confidence_score

        pending = [
            (inst_id, schema)
            for inst_id, schema in self.verification_queue
            if _is_pending(inst_id, schema)
        ]

        strategy = self.verification_strategy
        if strategy == 'random':
            random.shuffle(pending)
            return pending[:count]

        # Both remaining strategies start from a confidence-ascending order.
        pending.sort(key=_confidence)
        if strategy == 'low_confidence':
            return pending[:count]

        # mixed: 50% low confidence, 50% random
        half = count // 2
        remainder = pending[half:]
        random.shuffle(remainder)
        return pending[:half] + remainder[:count - half]
766
+
767
def record_verification(
    self,
    instance_id: str,
    schema_name: str,
    human_label: str,
    verified_by: str
) -> bool:
    """
    Store a human's verdict on an LLM prediction.

    The prediction receives the human label, verifier and timestamp, its
    status becomes ``'verified_correct'`` or ``'verified_incorrect'``
    depending on whether the labels are equal, and the pair is removed from
    the verification queue if present.

    Args:
        instance_id: The verified instance
        schema_name: The schema verified
        human_label: The human's label
        verified_by: Username of verifier

    Returns:
        True if verification recorded successfully, False when no prediction
        exists for the instance/schema pair
    """
    with self._lock:
        if instance_id not in self.predictions:
            logger.warning(f"No prediction found for instance {instance_id}")
            return False

        per_schema = self.predictions[instance_id]
        if schema_name not in per_schema:
            logger.warning(f"No prediction found for schema {schema_name}")
            return False

        prediction = per_schema[schema_name]
        prediction.human_label = human_label
        prediction.verified_by = verified_by
        prediction.verified_at = datetime.now()
        prediction.verification_status = (
            'verified_correct'
            if prediction.predicted_label == human_label
            else 'verified_incorrect'
        )

        # Drop the (first) queue entry for this pair, if any.
        queue_key = (instance_id, schema_name)
        if queue_key in self.verification_queue:
            self.verification_queue.remove(queue_key)

        logger.info(
            f"Verification recorded for {instance_id}: "
            f"predicted={prediction.predicted_label}, human={human_label}, "
            f"status={prediction.verification_status}"
        )

        return True
818
+
819
+ # === Accuracy Tracking ===
820
+
821
def get_accuracy_metrics(self, schema_name: Optional[str] = None) -> Dict[str, Any]:
    """
    Summarize how verified predictions compare with human labels.

    Args:
        schema_name: Optional schema to filter by; a falsy value includes
            every schema

    Returns:
        Dictionary with prediction counts per verification status, the
        accuracy over verified predictions (None when nothing is verified),
        the mean confidence of correct and incorrect predictions (None when
        the respective group is empty), and the ``schema_name`` filter used
    """
    with self._lock:
        considered = [
            prediction
            for schemas in self.predictions.values()
            for s_name, prediction in schemas.items()
            if not (schema_name and s_name != schema_name)
        ]

        confidence_correct = [
            p.confidence_score for p in considered
            if p.verification_status == 'verified_correct'
        ]
        confidence_incorrect = [
            p.confidence_score for p in considered
            if p.verification_status == 'verified_incorrect'
        ]

        total_predictions = len(considered)
        verified_correct = len(confidence_correct)
        verified_incorrect = len(confidence_incorrect)
        total_verified = verified_correct + verified_incorrect
        # Everything that is neither verified-correct nor verified-incorrect.
        pending = total_predictions - total_verified

        def _mean(values):
            return sum(values) / len(values) if values else None

        return {
            'total_predictions': total_predictions,
            'verified_correct': verified_correct,
            'verified_incorrect': verified_incorrect,
            'pending_verification': pending,
            'total_verified': total_verified,
            'accuracy': verified_correct / total_verified if total_verified > 0 else None,
            'avg_confidence_correct': _mean(confidence_correct),
            'avg_confidence_incorrect': _mean(confidence_incorrect),
            'schema_name': schema_name
        }
879
+
880
def get_status(self) -> Dict[str, Any]:
    """Get overall ICL labeler status.

    Plain counters are snapshotted while holding ``self._lock``. The pause
    check, the capacity scan and the accuracy metrics are computed after the
    lock is released: ``get_accuracy_metrics`` (also reached through
    ``should_pause_labeling``) takes ``self._lock`` itself, so calling it
    with the lock held would deadlock on a non-reentrant lock, and the
    capacity scan walks every user's annotations.

    Returns:
        Dictionary with example/prediction counts, timestamps (ISO format or
        None), worker liveness, accuracy metrics and labeling-limit status.
    """
    with self._lock:
        examples_by_schema = {
            schema: len(examples)
            for schema, examples in self.schema_to_examples.items()
        }
        total_predictions = sum(len(schemas) for schemas in self.predictions.values())
        labeled_instances = len(self.labeled_instance_ids)
        verification_queue_size = len(self.verification_queue)
        last_example_refresh = self.last_example_refresh
        last_batch_run = self.last_batch_run

    # Computed outside the lock (see docstring).
    should_pause, pause_reason = self.should_pause_labeling()
    remaining_capacity = self.get_remaining_label_capacity()
    accuracy_metrics = self.get_accuracy_metrics()

    worker = self._worker_thread

    return {
        'enabled': self.config.get('icl_labeling', {}).get('enabled', False),
        'total_examples': sum(examples_by_schema.values()),
        'examples_by_schema': examples_by_schema,
        'total_predictions': total_predictions,
        'labeled_instances': labeled_instances,
        'verification_queue_size': verification_queue_size,
        'last_example_refresh': (
            last_example_refresh.isoformat() if last_example_refresh else None
        ),
        'last_batch_run': (
            last_batch_run.isoformat() if last_batch_run else None
        ),
        'worker_running': worker is not None and worker.is_alive(),
        'accuracy_metrics': accuracy_metrics,
        # Labeling limits status
        'labeling_paused': should_pause,
        'pause_reason': pause_reason,
        'remaining_label_capacity': remaining_capacity,
        'max_total_labels': self.max_total_labels,
        'max_unlabeled_ratio': self.max_unlabeled_ratio,
        'min_accuracy_threshold': self.min_accuracy_threshold
    }
923
+
924
+ # === Background Worker ===
925
+
926
def start_background_worker(self) -> None:
    """Launch the daemon worker thread unless one is already alive.

    The stop event is cleared first, then a thread named
    ``ICLLabelerWorker`` running ``self._worker_loop`` is stored in
    ``self._worker_thread`` and started.
    """
    current = self._worker_thread
    if current is not None and current.is_alive():
        logger.warning("Background worker already running")
        return

    self._stop_worker.clear()
    worker = threading.Thread(
        target=self._worker_loop,
        name="ICLLabelerWorker",
        daemon=True
    )
    self._worker_thread = worker
    worker.start()
    logger.info("Started ICL labeler background worker")
940
+
941
def stop_background_worker(self) -> None:
    """Signal the background worker to stop and wait up to 5 seconds for it.

    Does nothing when no worker was started. The thread reference is only
    dropped once the thread has really finished. If the worker is still busy
    after the timeout, the reference is kept so that
    ``start_background_worker`` cannot clear the stop event and launch a
    second worker next to the one that is still running; the old worker
    exits on its own the next time it checks the stop event.
    """
    worker = self._worker_thread
    if worker is None:
        return

    self._stop_worker.set()
    worker.join(timeout=5.0)

    if worker.is_alive():
        # join() timed out: the worker is mid-step (e.g. a long batch).
        logger.warning(
            "ICL labeler background worker did not stop within 5s; "
            "it will exit after its current step"
        )
        return

    self._worker_thread = None
    logger.info("Stopped ICL labeler background worker")
950
+
951
+ def _worker_loop(self) -> None:
952
+ """Main loop for the background worker."""
953
+ logger.info(
954
+ f"ICL background worker started, "
955
+ f"example_refresh={self.example_refresh_interval}s, "
956
+ f"batch_interval={self.batch_interval}s"
957
+ )
958
+
959
+ last_example_refresh = 0
960
+ last_batch = 0
961
+
962
+ while not self._stop_worker.is_set():
963
+ try:
964
+ current_time = time.time()
965
+
966
+ # Refresh examples periodically
967
+ if current_time - last_example_refresh >= self.example_refresh_interval:
968
+ self.refresh_high_confidence_examples()
969
+ last_example_refresh = current_time
970
+
971
+ # Run batch labeling periodically
972
+ if current_time - last_batch >= self.batch_interval:
973
+ schemas = self._get_annotation_schemes()
974
+ for schema in schemas:
975
+ schema_name = schema.get('name')
976
+ if schema_name and self.has_enough_examples(schema_name):
977
+ predictions = self.batch_label_instances(schema_name)
978
+ if predictions:
979
+ self.save_state()
980
+ last_batch = current_time
981
+
982
+ except Exception as e:
983
+ logger.error(f"ICL background worker error: {e}")
984
+
985
+ # Wait for next interval or stop signal
986
+ self._stop_worker.wait(min(self.example_refresh_interval, self.batch_interval) / 2)
987
+
988
+ # === Persistence ===
989
+
990
def save_state(self) -> None:
    """Save current state to disk.

    Writes predictions, examples, the verification queue, the LLM-labeled
    instance ids and the two timestamps as JSON to
    ``<output_annotation_dir>/<predictions_file>``. Nothing is written when
    ``output_annotation_dir`` is not configured. The target directory is
    created if missing. The file is written to a ``.tmp`` sibling and then
    moved into place with ``os.replace`` so readers never see a partial
    file; a leftover temp file is removed if the write fails. Errors are
    logged, not raised.
    """
    task_dir = self.config.get('output_annotation_dir', '')
    if not task_dir:
        return

    filepath = os.path.join(task_dir, self.predictions_file)
    temp_path = filepath + '.tmp'

    try:
        # Snapshot and write under the lock so concurrent saves cannot
        # interleave on the shared temp file.
        with self._lock:
            state = {
                'predictions': {
                    inst_id: {
                        schema: pred.to_dict()
                        for schema, pred in schemas.items()
                    }
                    for inst_id, schemas in self.predictions.items()
                },
                'examples': {
                    schema: [ex.to_dict() for ex in examples]
                    for schema, examples in self.schema_to_examples.items()
                },
                'verification_queue': self.verification_queue,
                'labeled_instance_ids': list(self.labeled_instance_ids),
                'last_example_refresh': (
                    self.last_example_refresh.isoformat()
                    if self.last_example_refresh else None
                ),
                'last_batch_run': (
                    self.last_batch_run.isoformat()
                    if self.last_batch_run else None
                )
            }

            # The output directory may not exist yet on a fresh task.
            os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)

            # Atomic write
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(state, f, indent=2)
            os.replace(temp_path, filepath)

        logger.debug(f"Saved ICL state to {filepath}")

    except Exception as e:
        logger.error(f"Error saving ICL state: {e}")
        # Best effort: do not leave a half-written temp file behind.
        try:
            os.remove(temp_path)
        except OSError:
            pass
1034
+
1035
def load_state(self) -> None:
    """Load state from disk.

    Reads ``<output_annotation_dir>/<predictions_file>`` when it exists and
    restores predictions, examples, the verification queue, the LLM-labeled
    instance ids and, if present in the file, the two timestamps. Nothing
    happens when ``output_annotation_dir`` is not configured or the file is
    missing.

    The whole file is parsed into local objects first and only then swapped
    into ``self`` under the lock, so a corrupt or partially valid file leaves
    the in-memory state exactly as it was. Errors are logged, not raised.
    """
    task_dir = self.config.get('output_annotation_dir', '')
    if not task_dir:
        return

    filepath = os.path.join(task_dir, self.predictions_file)

    if not os.path.exists(filepath):
        return

    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)

        # Parse everything before touching self (all-or-nothing load).
        predictions = {
            inst_id: {
                schema: ICLPrediction.from_dict(pred_data)
                for schema, pred_data in schemas.items()
            }
            for inst_id, schemas in state.get('predictions', {}).items()
        }
        schema_to_examples = {
            schema: [HighConfidenceExample.from_dict(ex) for ex in examples]
            for schema, examples in state.get('examples', {}).items()
        }
        # JSON stores the (instance_id, schema_name) pairs as lists.
        verification_queue = [
            tuple(item) for item in state.get('verification_queue', [])
        ]
        labeled_instance_ids = set(state.get('labeled_instance_ids', []))
        last_example_refresh = (
            datetime.fromisoformat(state['last_example_refresh'])
            if state.get('last_example_refresh') else None
        )
        last_batch_run = (
            datetime.fromisoformat(state['last_batch_run'])
            if state.get('last_batch_run') else None
        )

        with self._lock:
            self.predictions = predictions
            self.schema_to_examples = schema_to_examples
            self.verification_queue = verification_queue
            self.labeled_instance_ids = labeled_instance_ids
            # Timestamps absent from the file leave the current values alone.
            if last_example_refresh is not None:
                self.last_example_refresh = last_example_refresh
            if last_batch_run is not None:
                self.last_batch_run = last_batch_run

        logger.info(f"Loaded ICL state from {filepath}")

    except Exception as e:
        logger.error(f"Error loading ICL state: {e}")
1085
+
1086
+
1087
# Module-level singleton access.
# Holds the labeler created by init_icl_labeler(); it is None until then and
# is reset to None by clear_icl_labeler().
_icl_labeler: Optional[ICLLabeler] = None
1089
+
1090
+
1091
def init_icl_labeler(config: Dict[str, Any]) -> ICLLabeler:
    """Create the global ICL labeler from ``config``, load its saved state and return it.

    The module-level reference is set before the state is loaded.
    """
    global _icl_labeler
    labeler = ICLLabeler(config)
    _icl_labeler = labeler
    labeler.load_state()
    return labeler
1097
+
1098
+
1099
def get_icl_labeler() -> Optional[ICLLabeler]:
    """Return the module-level ICL labeler, or None if none is currently set."""
    return _icl_labeler
1102
+
1103
+
1104
def clear_icl_labeler() -> None:
    """Reset the global ICL labeler (for testing).

    Stops the background worker of the current labeler, if there is one,
    then clears both the module-level reference and ``ICLLabeler._instance``.
    """
    global _icl_labeler
    current = _icl_labeler
    if current:
        current.stop_background_worker()
    _icl_labeler = None
    ICLLabeler._instance = None
potato/ai/icl_prompt_builder.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In-Context Learning Prompt Builder
3
+
4
+ This module builds effective prompts for in-context learning based labeling.
5
+ It formats high-confidence examples and target instances into prompts that
6
+ elicit accurate label predictions with confidence scores.
7
+ """
8
+
9
+ import json
10
+ import logging
11
+ import re
12
+ from typing import Dict, List, Any, Tuple, Optional, TYPE_CHECKING
13
+
14
+ if TYPE_CHECKING:
15
+ from potato.ai.icl_labeler import HighConfidenceExample
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class ICLPromptBuilder:
    """
    Builds prompts for in-context learning and parses the model's reply.

    The prompt structure:
    1. System instructions explaining the task
    2. Schema description and available labels
    3. High-confidence examples with their labels
    4. Target text to label
    5. Output format instructions (JSON with label, confidence, reasoning)
    """

    # Confidence used when the reply omits it or supplies an unusable value.
    _DEFAULT_CONFIDENCE = 0.5

    def __init__(self, max_example_length: int = 500, max_target_length: int = 1000):
        """
        Initialize the prompt builder.

        Args:
            max_example_length: Maximum characters per example text
            max_target_length: Maximum characters for target text
        """
        self.max_example_length = max_example_length
        self.max_target_length = max_target_length

    def build_prompt(
        self,
        schema: Dict[str, Any],
        examples: List['HighConfidenceExample'],
        target_text: str
    ) -> str:
        """
        Build a complete ICL prompt.

        Args:
            schema: Annotation schema dictionary with name, description, labels
            examples: List of high-confidence examples
            target_text: The text to be labeled

        Returns:
            Complete prompt string (sections joined with newlines)
        """
        parts = [self._build_system_prompt(schema)]

        # The examples section is omitted entirely when there are none.
        if examples:
            parts.append("\n## Examples\n")
            parts.append("Here are examples of correctly labeled texts:\n")
            parts.extend(
                self._format_example(example, position)
                for position, example in enumerate(examples, 1)
            )

        parts.append("\n## Your Task\n")
        parts.append("Now label the following text:\n")
        parts.append(f'Text: "{self._truncate_text(target_text, self.max_target_length)}"\n')

        parts.append(self._build_output_instructions(schema))

        return "\n".join(parts)

    def _build_system_prompt(self, schema: Dict[str, Any]) -> str:
        """Build the system/instruction portion of the prompt."""
        schema_name = schema.get('name', 'unknown')
        description = schema.get('description', 'Label the text according to the schema.')
        labels = self._get_labels_from_schema(schema)
        annotation_type = schema.get('annotation_type', 'radio')

        prompt = f"""You are an expert annotation assistant. Your task is to label text according to a specific annotation schema.

## Schema: {schema_name}

**Description:** {description}

**Available Labels:** {', '.join(labels)}
"""

        # Add type-specific instructions; unknown types get no extra line.
        if annotation_type == 'radio':
            prompt += "\n**Task Type:** Single-choice classification. Select exactly ONE label.\n"
        elif annotation_type == 'multiselect':
            prompt += "\n**Task Type:** Multi-label classification. Select ALL applicable labels.\n"
        elif annotation_type == 'likert':
            prompt += "\n**Task Type:** Rating scale. Choose the most appropriate rating.\n"

        return prompt

    def _format_example(self, example: 'HighConfidenceExample', index: int) -> str:
        """Format a single example (text, label, agreement stats) for the prompt."""
        truncated_text = self._truncate_text(example.text, self.max_example_length)

        return f"""
### Example {index}
Text: "{truncated_text}"
Label: **{example.label}**
(Agreement: {example.agreement_score:.0%} from {example.annotator_count} annotators)
"""

    def _build_output_instructions(self, schema: Dict[str, Any]) -> str:
        """Build instructions for the expected output format."""
        labels = self._get_labels_from_schema(schema)
        labels_json = json.dumps(labels)

        return f"""
## Output Format

Respond with a JSON object containing:
- `label`: Your chosen label (must be one of: {labels_json})
- `confidence`: Your confidence score from 0.0 to 1.0
- 1.0 = Absolutely certain
- 0.7-0.9 = High confidence
- 0.5-0.7 = Moderate confidence
- 0.3-0.5 = Low confidence
- 0.0-0.3 = Very uncertain
- `reasoning`: Brief explanation for your choice (1-2 sentences)

**Important:**
- Only use labels from the provided list
- Be honest about your confidence - reflect your actual certainty, not a fixed value
- Base your decision on the examples and schema description
- Use the full 0.0–1.0 range: reserve 0.9+ for near-certain cases, use 0.4–0.6 when genuinely unsure

Example response (the confidence value here is illustrative only — yours should reflect actual certainty):
```json
{{"label": "example_label", "confidence": 0.72, "reasoning": "The text shows clear indicators of..."}}
```

Now provide your response as JSON:
"""

    def _get_labels_from_schema(self, schema: Dict[str, Any]) -> List[str]:
        """Extract label names from schema definition.

        String entries are used as-is; dict entries contribute their ``name``
        (or the dict's string form when ``name`` is absent). Entries of any
        other type are ignored.
        """
        result = []
        for label in schema.get('labels', []):
            if isinstance(label, str):
                result.append(label)
            elif isinstance(label, dict):
                result.append(label.get('name', str(label)))
        return result

    def _truncate_text(self, text: str, max_length: int) -> str:
        """Truncate text to max length, preserving word boundaries.

        ``None`` is treated as an empty string and other non-string values are
        converted with ``str()`` so that a malformed instance cannot crash
        prompt construction. Truncated text gets a trailing ``"..."``.
        """
        if text is None:
            text = ""
        elif not isinstance(text, str):
            text = str(text)

        if len(text) <= max_length:
            return text

        truncated = text[:max_length]
        # Back up to the last space only if that keeps more than 80% of the budget.
        last_space = truncated.rfind(' ')
        if last_space > max_length * 0.8:
            truncated = truncated[:last_space]

        return truncated + "..."

    def _coerce_confidence(self, value: Any) -> float:
        """Convert a reported confidence to a float clamped to [0.0, 1.0].

        Values that cannot be converted (``None``, non-numeric strings, NaN)
        fall back to ``_DEFAULT_CONFIDENCE`` instead of invalidating the whole
        response.
        """
        try:
            number = float(value)
        except (TypeError, ValueError):
            number = self._DEFAULT_CONFIDENCE
        if number != number:  # NaN is the only value not equal to itself
            number = self._DEFAULT_CONFIDENCE
        return min(1.0, max(0.0, number))

    def parse_response(
        self,
        response: str,
        schema: Dict[str, Any]
    ) -> Tuple[Optional[str], float, str]:
        """
        Parse the LLM response to extract label, confidence, and reasoning.

        The JSON payload is tried first (exact label, then case-insensitive
        match); if that yields nothing the raw text is scanned for a label.

        Args:
            response: Raw response from LLM
            schema: Schema for validation

        Returns:
            Tuple of (label, confidence, reasoning) or (None, 0.0, "") on failure
        """
        try:
            data = self._extract_json(response)
            if data:
                raw_label = data.get('label', '')
                if raw_label is None:
                    label = ''
                elif isinstance(raw_label, str):
                    label = raw_label
                else:
                    # e.g. a numeric rating returned as a JSON number
                    label = str(raw_label)
                confidence = self._coerce_confidence(
                    data.get('confidence', self._DEFAULT_CONFIDENCE)
                )
                reasoning = data.get('reasoning', '')
                if reasoning is None:
                    reasoning = ''

                valid_labels = self._get_labels_from_schema(schema)
                if label in valid_labels:
                    return label, confidence, reasoning

                matched = self._fuzzy_match_label(label, valid_labels)
                if matched:
                    return matched, confidence, reasoning

            # Fallback: try to extract label from text
            return self._extract_label_from_text(response, schema)

        except Exception as e:
            logger.warning(f"Error parsing response: {e}")
            return None, 0.0, ""

    def _extract_json(self, text: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from text, handling markdown code blocks.

        Only JSON *objects* are returned; a reply that parses to a list,
        number or string is treated as "no JSON found" so callers can safely
        use ``.get`` on the result.
        """
        if not isinstance(text, str):
            return None

        # 1. The whole reply is JSON.
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        # 2. JSON inside a fenced code block.
        for pattern in (r'```json\s*(.*?)\s*```', r'```\s*(.*?)\s*```'):
            match = re.search(pattern, text, re.DOTALL)
            if not match:
                continue
            try:
                parsed = json.loads(match.group(1))
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # 3. An object embedded in prose. raw_decode stops at the end of the
        #    object, so braces inside string values or later text are harmless.
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict) and ('label' in parsed or 'labels' in parsed):
                return parsed
            start = text.find('{', start + 1)

        return None

    def _fuzzy_match_label(self, label: str, valid_labels: List[str]) -> Optional[str]:
        """Try to match label with case-insensitive, whitespace-trimmed comparison."""
        label_lower = label.lower().strip()
        for valid in valid_labels:
            if valid.lower().strip() == label_lower:
                return valid
        return None

    def _extract_label_from_text(
        self,
        text: str,
        schema: Dict[str, Any]
    ) -> Tuple[Optional[str], float, str]:
        """Fallback: try to extract label directly from text.

        Longer labels are tried before shorter ones and whole-word hits are
        preferred over plain substring hits, so a reply containing "unhappy"
        is not attributed to the label "happy". Empty labels never match.
        """
        if not isinstance(text, str):
            return None, 0.0, ""

        text_lower = text.lower()
        # sorted() is stable, so equally long labels keep their schema order.
        candidates = sorted(
            (label for label in self._get_labels_from_schema(schema) if label),
            key=len,
            reverse=True,
        )
        note = "Extracted from response text (low confidence)"

        # Pass 1: label appears as a whole word/phrase.
        for label in candidates:
            if re.search(r'(?<!\w)' + re.escape(label.lower()) + r'(?!\w)', text_lower):
                return label, 0.5, note

        # Pass 2: plain substring, kept for backward compatibility.
        for label in candidates:
            if label.lower() in text_lower:
                return label, 0.5, note

        return None, 0.0, ""
264
+
265
+
266
class MultiSelectPromptBuilder(ICLPromptBuilder):
    """
    Specialized prompt builder for multi-select (multi-label) tasks.

    Differs from the base builder only in the output-format instructions
    (a ``labels`` array instead of a single ``label``) and in how the reply
    is parsed.
    """

    def _build_output_instructions(self, schema: Dict[str, Any]) -> str:
        """Build output instructions for multi-select."""
        labels = self._get_labels_from_schema(schema)
        labels_json = json.dumps(labels)

        return f"""
## Output Format

Respond with a JSON object containing:
- `labels`: Array of selected labels (from: {labels_json})
- `confidence`: Your overall confidence score from 0.0 to 1.0 — reflect actual certainty, not a fixed value
- `reasoning`: Brief explanation for your choices

Example response (the confidence value here is illustrative only — yours should reflect actual certainty):
```json
{{"labels": ["label1", "label2"], "confidence": 0.65, "reasoning": "The text exhibits both..."}}
```

Now provide your response as JSON:
"""

    def parse_response(
        self,
        response: str,
        schema: Dict[str, Any]
    ) -> Tuple[Optional[List[str]], float, str]:
        """Parse multi-select response.

        Args:
            response: Raw response from LLM
            schema: Schema whose labels are used for validation

        Returns:
            Tuple of (labels, confidence, reasoning); ``(None, 0.0, "")`` when
            no valid label could be extracted. Returned labels use the
            schema's own spelling, are de-duplicated, and keep the order in
            which the model listed them.
        """
        try:
            data = self._extract_json(response)
            if isinstance(data, dict) and data:
                raw_labels = data.get('labels', [])
                if isinstance(raw_labels, str):
                    # A lone string would otherwise be iterated per character.
                    raw_labels = [raw_labels]
                elif not isinstance(raw_labels, (list, tuple)):
                    raw_labels = []

                try:
                    confidence = float(data.get('confidence', 0.5))
                except (TypeError, ValueError):
                    # An unusable confidence should not discard valid labels.
                    confidence = 0.5
                reasoning = data.get('reasoning', '')

                valid_labels = self._get_labels_from_schema(schema)
                validated: List[str] = []
                for candidate in raw_labels:
                    if not isinstance(candidate, str):
                        continue
                    if candidate in valid_labels:
                        resolved = candidate
                    else:
                        # Same case-insensitive matching the single-label parser uses.
                        resolved = self._fuzzy_match_label(candidate, valid_labels)
                    if resolved is not None and resolved not in validated:
                        validated.append(resolved)

                if validated:
                    return validated, min(1.0, max(0.0, confidence)), reasoning

            return None, 0.0, ""

        except Exception as e:
            logger.warning(f"Error parsing multi-select response: {e}")
            return None, 0.0, ""
potato/ai/judge.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM-as-Judge service for human-alignment.
3
+
4
+ Produces a judge verdict (label + confidence + reasoning) for an annotation
5
+ instance, given the schema (labels + description + an editable rubric) and,
6
+ optionally, few-shot examples drawn from high-agreement human labels. The
7
+ verdicts are compared against human labels elsewhere
8
+ (``potato/server_utils/judge_alignment.py``) to measure and calibrate
9
+ human↔judge agreement (Cohen's κ).
10
+
11
+ This deliberately does NOT reuse ``ICLLabeler`` as the judge — ICL auto-labels
12
+ from inter-annotator agreement, which would leak the gold labels we are trying
13
+ to measure the judge against. We only borrow ICLLabeler's *example selection*
14
+ for few-shot calibration, and we always exclude the instance being judged from
15
+ its own example set.
16
+
17
+ The judge call goes through the same ``AIEndpointFactory`` / ``BaseAIEndpoint``
18
+ machinery as every other AI feature (mirrors ``icl_labeler.label_instance``),
19
+ so it works with any configured provider.
20
+ """
21
+
22
+ import hashlib
23
+ import json
24
+ import logging
25
+ from dataclasses import dataclass, field
26
+ from typing import Any, Dict, List, Optional
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
@dataclass
class JudgePrediction:
    """One LLM-judge verdict for a single (instance, schema) pair.

    Instances convert to and from plain dictionaries via ``to_dict`` and
    ``from_dict``.
    """

    instance_id: str
    schema_name: str
    predicted_label: str
    confidence: float  # 0.0–1.0
    reasoning: str = ""
    model_name: str = ""
    prompt_version: str = ""
    examples_used: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the verdict as a plain dict keyed by field name."""
        exported = (
            "instance_id",
            "schema_name",
            "predicted_label",
            "confidence",
            "reasoning",
            "model_name",
            "prompt_version",
            "examples_used",
        )
        return {key: getattr(self, key) for key in exported}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "JudgePrediction":
        """Rebuild a verdict from a dict such as the one ``to_dict`` produces.

        ``instance_id`` and ``schema_name`` are required (``KeyError`` if
        absent); every other key is optional. ``confidence`` is passed
        through ``float()`` and defaults to 0.0.
        """
        # Required keys first so a missing id is reported before anything else.
        instance_id = data["instance_id"]
        schema_name = data["schema_name"]
        predicted_label = data.get("predicted_label", "")
        confidence = float(data.get("confidence", 0.0))
        return cls(
            instance_id,
            schema_name,
            predicted_label,
            confidence,
            reasoning=data.get("reasoning", ""),
            model_name=data.get("model_name", ""),
            prompt_version=data.get("prompt_version", ""),
            examples_used=data.get("examples_used", []),
        )
68
+
69
+
70
def extract_labels(schema_info: Dict[str, Any]) -> List[str]:
    """Return the allowed label names for a schema.

    For ``annotation_type == "likert"`` the labels are the strings
    ``"1"`` .. ``str(size)`` (``size`` defaults to 5). For every other type
    the ``labels`` list is read: dict entries contribute ``str(name)``, all
    other entries contribute ``str(entry)``. Empty names are dropped, so a
    schema without usable labels yields ``[]``.
    """
    if schema_info.get("annotation_type", "") == "likert":
        upper = int(schema_info.get("size", 5))
        return [str(point) for point in range(1, upper + 1)]

    names = (
        str(entry.get("name", "")) if isinstance(entry, dict) else str(entry)
        for entry in schema_info.get("labels", [])
    )
    return [name for name in names if name]
88
+
89
+
90
def compute_prompt_version(rubric: str, schema_name: str, few_shot: bool) -> str:
    """Derive a short, stable identifier for a judge configuration.

    The identifier is ``"v_"`` followed by the first 10 hex characters of the
    SHA-1 of the schema name, the few-shot flag (as 0/1) and the rubric,
    joined by the ``␟`` separator. Any change to one of the three inputs
    therefore produces a different version; a falsy rubric counts as empty.
    """
    few_shot_flag = int(bool(few_shot))
    rubric_text = rubric or ''
    fingerprint_source = "␟".join((f"{schema_name}", f"{few_shot_flag}", f"{rubric_text}"))
    digest = hashlib.sha1(fingerprint_source.encode("utf-8")).hexdigest()
    return "v_" + digest[:10]
98
+
99
+
100
class JudgeService:
    """Builds judge prompts and queries the configured AI endpoint.

    Configuration is read from the ``judge_alignment`` section of the task
    config; per-schema settings (currently the rubric) live under
    ``judge_alignment.schemas.<schema name>``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the task config; the endpoint is created on first use."""
        self.config = config or {}
        self.judge_config = self.config.get("judge_alignment", {}) or {}
        self._endpoint = None
        self._endpoint_initialized = False

    # ----- endpoint -------------------------------------------------------

    def _get_endpoint(self):
        """Return the AI endpoint, creating it on the first call.

        Creation is attempted only once: a failure (or missing configuration)
        is remembered and ``None`` is returned on every subsequent call.
        """
        if not self._endpoint_initialized:
            self._endpoint_initialized = True
            try:
                from potato.ai.ai_endpoint import AIEndpointFactory
                # The judge endpoint config lives under judge_alignment, but
                # fall back to the task's ai_support so a single endpoint can
                # serve both. Shape mirrors what AIEndpointFactory expects.
                ai_support = self.judge_config.get("ai_support") or self.config.get("ai_support")
                if not ai_support:
                    logger.warning("Judge: no ai_support / judge_alignment.ai_support configured")
                    return None
                self._endpoint = AIEndpointFactory.create_endpoint({"ai_support": ai_support})
            except Exception as e:
                logger.error(f"Judge: failed to create endpoint: {e}")
                self._endpoint = None
        return self._endpoint

    # ----- prompt ---------------------------------------------------------

    def _schema_judge_config(self, schema_name: str) -> Dict[str, Any]:
        """Return the per-schema judge settings, or ``{}`` when none exist."""
        per_schema = self.judge_config.get("schemas", {}) or {}
        return per_schema.get(schema_name, {}) or {}

    def get_rubric(self, schema_info: Dict[str, Any]) -> str:
        """Editable rubric for a schema; falls back to its description."""
        sc = self._schema_judge_config(schema_info.get("name", ""))
        return sc.get("rubric") or schema_info.get("description", "") or ""

    def build_prompt(
        self,
        schema_info: Dict[str, Any],
        instance_text: str,
        few_shot_examples: Optional[List[Dict[str, str]]] = None,
    ) -> str:
        """Compose the judge prompt.

        few_shot_examples: list of {"text": ..., "label": ...} gold exemplars
        (already excluding the target instance).
        """
        labels = extract_labels(schema_info)
        rubric = self.get_rubric(schema_info)
        parts = [
            "You are an expert evaluator acting as an impartial judge.",
            "Assign exactly one label to the item below, following the rubric.",
            "",
            f"Task: {schema_info.get('description', '')}".rstrip(),
            f"Rubric: {rubric}".rstrip(),
            "Allowed labels: " + ", ".join(labels) if labels else "",
        ]
        if few_shot_examples:
            parts.append("\nExamples (item → correct label):")
            for ex in few_shot_examples:
                parts.append(f"- {_truncate(ex.get('text', ''))} → {ex.get('label', '')}")
        parts.append("\nItem to judge:")
        parts.append(_truncate(instance_text, 4000))
        parts.append(
            '\nRespond as JSON: {"label": <one of the allowed labels>, '
            '"confidence": <0.0-1.0>, "reasoning": <one sentence>}.'
        )
        # Empty entries (including the blank spacer above) are dropped here.
        return "\n".join(p for p in parts if p != "")

    # ----- judging --------------------------------------------------------

    @staticmethod
    def _coerce_response(response: Any) -> Optional[Dict[str, Any]]:
        """Normalize an endpoint reply to a dict, or ``None`` if impossible.

        Strings are parsed as JSON; when the whole string is not valid JSON
        (e.g. it is wrapped in a markdown fence) the first decodable
        ``{...}`` object inside it is used. Objects exposing ``model_dump()``
        or ``dict()`` are converted with that method. A falsy reply becomes
        ``{}``. Anything that does not end up as a dict yields ``None``.
        """
        if isinstance(response, str):
            try:
                data = json.loads(response)
            except json.JSONDecodeError:
                data = None
                decoder = json.JSONDecoder()
                start = response.find("{")
                while start != -1:
                    try:
                        data, _ = decoder.raw_decode(response, start)
                        break
                    except json.JSONDecodeError:
                        start = response.find("{", start + 1)
        elif hasattr(response, "model_dump"):
            data = response.model_dump()
        elif hasattr(response, "dict"):
            data = response.dict()
        else:
            data = response or {}
        return data if isinstance(data, dict) else None

    def judge_instance(
        self,
        instance_id: str,
        schema_info: Dict[str, Any],
        instance_text: str,
        few_shot_examples: Optional[List[Dict[str, str]]] = None,
        prompt_version: Optional[str] = None,
    ) -> Optional["JudgePrediction"]:
        """Query the judge for one instance. Returns None on failure.

        Failure covers: no endpoint available, the query raising, a reply
        that cannot be interpreted as a JSON object, and a label that matches
        none of the schema's allowed labels (when the schema defines any).

        Args:
            instance_id: Identifier recorded on the returned prediction.
            schema_info: Schema dict (name, description, labels, ...).
            instance_text: Text of the item to judge.
            few_shot_examples: Optional exemplars; entries with an ``id`` are
                recorded in ``examples_used``.
            prompt_version: Version tag to record; computed from the rubric,
                schema name and few-shot flag when omitted.
        """
        endpoint = self._get_endpoint()
        if endpoint is None:
            return None

        schema_name = schema_info.get("name", "")
        valid_labels = extract_labels(schema_info)
        rubric = self.get_rubric(schema_info)
        if prompt_version is None:
            prompt_version = compute_prompt_version(
                rubric, schema_name, bool(few_shot_examples)
            )

        prompt = self.build_prompt(schema_info, instance_text, few_shot_examples)

        try:
            from pydantic import BaseModel

            class JudgeVerdict(BaseModel):
                label: str
                confidence: float = 0.5
                reasoning: str = ""

            response = endpoint.query(prompt, JudgeVerdict)
            data = self._coerce_response(response)
        except Exception as e:
            logger.error(f"Judge: query/parse failed for {instance_id}/{schema_name}: {e}")
            return None

        if data is None:
            # e.g. the model answered with a JSON list or a bare string
            logger.error(
                f"Judge: query/parse failed for {instance_id}/{schema_name}: "
                f"response is not a JSON object"
            )
            return None

        predicted = str(data.get("label", "")).strip()
        try:
            confidence = float(data.get("confidence", 0.5))
        except (TypeError, ValueError):
            confidence = 0.5
        confidence = min(1.0, max(0.0, confidence))
        reasoning = str(data.get("reasoning", ""))

        if valid_labels and predicted not in valid_labels:
            matched = _fuzzy_match_label(predicted, valid_labels)
            if matched is None:
                logger.warning(
                    f"Judge: invalid label '{predicted}' for {instance_id}/{schema_name}"
                )
                return None
            predicted = matched

        return JudgePrediction(
            instance_id=instance_id,
            schema_name=schema_name,
            predicted_label=predicted,
            confidence=confidence,
            reasoning=reasoning,
            model_name=getattr(endpoint, "model", ""),
            prompt_version=prompt_version,
            examples_used=[e.get("id", "") for e in (few_shot_examples or []) if e.get("id")],
        )
246
+
247
+
248
+ def _truncate(text: str, limit: int = 300) -> str:
249
+ text = str(text or "")
250
+ return text if len(text) <= limit else text[:limit] + "…"
251
+
252
+
253
+ def _fuzzy_match_label(predicted: str, valid_labels: List[str]) -> Optional[str]:
254
+ """Case-insensitive / prefix match a model label to an allowed label."""
255
+ if not predicted:
256
+ return None
257
+ low = predicted.lower().strip()
258
+ for lab in valid_labels:
259
+ if lab.lower() == low:
260
+ return lab
261
+ for lab in valid_labels:
262
+ ll = lab.lower()
263
+ if ll in low or low in ll:
264
+ return lab
265
+ return None
potato/ai/llm_active_learning.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Integration for Active Learning
3
+
4
+ This module provides LLM-based active learning capabilities using VLLM endpoints.
5
+ It implements confidence-based instance selection and prediction using large language
6
+ models, with support for multiple confidence elicitation methods:
7
+
8
+ - **logprobs**: Extract token-level log probabilities from VLLM/OpenAI-compatible
9
+ endpoints for calibrated confidence scores.
10
+ - **verbalized**: Ask the LLM to self-report confidence on a 1-10 scale (default).
11
+ - **consistency**: Query the same instance N times with temperature > 0 and use
12
+ agreement rate as confidence (works with any endpoint).
13
+
14
+ References:
15
+ Tian et al. (2023) "Just Ask for Calibration: Strategies for Eliciting
16
+ Calibrated Confidence Scores from Language Models Fine-Tuned with Human
17
+ Feedback." EMNLP 2023.
18
+
19
+ Xiong et al. (2024) "Can LLMs Express Their Uncertainty? An Empirical
20
+ Evaluation of Confidence Elicitation in LLMs." ICLR 2024.
21
+ """
22
+
23
+ import logging
24
+ import math
25
+ import time
26
+ import json
27
+ import requests
28
+ from collections import Counter
29
+ from typing import Dict, List, Optional, Tuple, Any
30
+ from dataclasses import dataclass, field
31
+ from concurrent.futures import ThreadPoolExecutor, as_completed
32
+ import numpy as np
33
+
34
+ from potato.active_learning_manager import TrainingMetrics
35
+
36
+
37
+ def _loads_lenient(content: str):
38
+ """json.loads that tolerates markdown code fences and surrounding prose.
39
+
40
+ Many models (e.g. Gemma on vLLM) wrap JSON in ```json ... ``` fences even
41
+ when response_format=json_object is requested, which breaks a naive
42
+ json.loads(). Strip fences and, failing that, extract the first {...} block.
43
+ """
44
+ import re
45
+ if content is None:
46
+ raise json.JSONDecodeError("empty content", "", 0)
47
+ s = content.strip()
48
+ # Strip a leading ```json / ``` fence and trailing ```
49
+ s = re.sub(r"^```(?:json|JSON)?\s*", "", s)
50
+ s = re.sub(r"\s*```$", "", s).strip()
51
+ try:
52
+ return json.loads(s)
53
+ except json.JSONDecodeError:
54
+ # Fall back to the first balanced-looking {...} object in the text.
55
+ m = re.search(r"\{.*\}", s, re.DOTALL)
56
+ if m:
57
+ return json.loads(m.group(0))
58
+ raise
59
+
60
+
61
@dataclass
class LLMPrediction:
    """Outcome of asking the LLM to label one instance."""

    # Identifier of the instance the prediction belongs to.
    instance_id: str
    # Label chosen by the model.
    predicted_label: str
    # Confidence attached to the prediction.
    confidence_score: float
    # Unprocessed model output the prediction was derived from.
    raw_response: str
    # Set when the prediction could not be produced normally.
    error_message: Optional[str] = None
    # Which elicitation method produced ``confidence_score``.
    confidence_method: str = "verbalized"
70
+
71
+
72
@dataclass
class LLMConfig:
    """Settings for the LLM endpoint used by active learning."""

    # --- endpoint ---------------------------------------------------------
    endpoint_url: str
    model_name: str

    # --- generation -------------------------------------------------------
    max_tokens: int = 512
    temperature: float = 0.1

    # --- request handling -------------------------------------------------
    timeout: int = 30
    batch_size: int = 10
    retry_attempts: int = 3
    retry_delay: float = 1.0
    max_instances_per_request: int = 5

    # --- confidence elicitation --------------------------------------------
    confidence_method: str = "verbalized"  # logprobs | verbalized | consistency
    consistency_samples: int = 3
86
+
87
+
88
+ class LLMActiveLearning:
89
+ """
90
+ LLM-based active learning implementation.
91
+
92
+ This class provides methods for:
93
+ - Querying LLMs for predictions and confidence scores
94
+ - Batch processing of instances
95
+ - Error handling and retry logic
96
+ - Integration with the active learning pipeline
97
+ """
98
+
99
def __init__(self, config: LLMConfig):
    """Store the config, set up the HTTP session and probe the endpoint.

    Args:
        config: Endpoint, generation and retry settings.
    """
    import functools

    self.config = config
    self.logger = logging.getLogger(__name__)
    self.session = requests.Session()

    # Configure session. requests does not honor a ``timeout`` attribute on
    # Session objects, so the configured value is also installed as the
    # default ``timeout`` argument of every request made through this
    # session. Calls that pass their own ``timeout=`` still override it.
    self.session.timeout = config.timeout
    self.session.request = functools.partial(
        self.session.request, timeout=config.timeout
    )

    # Test connection on initialization
    self._test_connection()
109
+
110
def _test_connection(self):
    """Send a tiny chat request to the endpoint and log the outcome.

    Never raises: a non-200 status is logged as a warning and any exception
    as an error, so construction can proceed and callers can fall back to
    traditional methods.
    """
    probe = {
        "model": self.config.model_name,
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 10,
        "temperature": 0.1
    }
    try:
        # Short fixed timeout: this is only a reachability check.
        response = self.session.post(self.config.endpoint_url, json=probe, timeout=5)

        if response.status_code != 200:
            self.logger.warning(f"LLM endpoint returned status {response.status_code}: {response.text}")
        else:
            self.logger.info(f"Successfully connected to LLM endpoint: {self.config.endpoint_url}")

    except Exception as e:
        # Don't raise - allow fallback to traditional methods
        self.logger.error(f"Failed to connect to LLM endpoint: {e}")
134
+
135
def predict_instances(self, instances: List[Dict[str, Any]],
                      annotation_instructions: str,
                      schema_name: str,
                      label_options: List[str]) -> List['LLMPrediction']:
    """
    Predict labels and confidence scores for instances using LLM.

    Instances are turned into prompts and processed in chunks of
    ``config.batch_size``, with a short pause between chunks. A batch size
    below 1 is treated as 1 so that a misconfigured value neither crashes
    the run nor silently skips every instance.

    Args:
        instances: List of instances to predict
        annotation_instructions: Instructions for the annotation task
        schema_name: Name of the annotation schema
        label_options: Available label options

    Returns:
        List of LLM predictions with confidence scores
    """
    if not instances:
        return []

    self.logger.info(f"Starting LLM prediction for {len(instances)} instances")

    # Create prompts for each instance
    prompts = self._create_prompts(instances, annotation_instructions, schema_name, label_options)

    # range() rejects a step of 0 and yields nothing for a negative step.
    batch_size = max(1, int(self.config.batch_size))

    all_predictions = []
    for start in range(0, len(prompts), batch_size):
        stop = start + batch_size
        all_predictions.extend(
            self._process_batch(prompts[start:stop], instances[start:stop])
        )

        # Small delay between batches to avoid overwhelming the endpoint
        if stop < len(prompts):
            time.sleep(0.1)

    self.logger.info(f"Completed LLM prediction for {len(all_predictions)} instances")
    return all_predictions
175
+
176
+ def _create_prompts(self, instances: List[Dict[str, Any]],
177
+ annotation_instructions: str,
178
+ schema_name: str,
179
+ label_options: List[str]) -> List[str]:
180
+ """Create prompts for LLM prediction."""
181
+ prompts = []
182
+
183
+ # Create the base prompt template
184
+ base_prompt = self._create_base_prompt(annotation_instructions, schema_name, label_options)
185
+
186
+ for instance in instances:
187
+ # Extract text content
188
+ text_content = self._extract_text_content(instance)
189
+
190
+ # Create instance-specific prompt
191
+ prompt = f"{base_prompt}\n\nText to annotate:\n{text_content}\n\nPlease provide your prediction and confidence score."
192
+
193
+ prompts.append(prompt)
194
+
195
+ return prompts
196
+
197
+ def _create_base_prompt(self, annotation_instructions: str,
198
+ schema_name: str,
199
+ label_options: List[str]) -> str:
200
+ """Create the base prompt for LLM prediction."""
201
+ prompt = f"""You are an expert annotator for a text classification task.
202
+
203
+ Task: {annotation_instructions}
204
+
205
+ Schema: {schema_name}
206
+
207
+ Available labels: {', '.join(label_options)}
208
+
209
+ For each text, please:
210
+ 1. Analyze the text carefully
211
+ 2. Choose the most appropriate label from the available options
212
+ 3. Provide a confidence score from 1 to 10 (where 1 = very uncertain, 10 = very confident)
213
+
214
+ Please respond in the following JSON format:
215
+ {{
216
+ "label": "chosen_label",
217
+ "confidence": confidence_score,
218
+ "reasoning": "brief explanation of your choice"
219
+ }}
220
+
221
+ Example response:
222
+ {{
223
+ "label": "{label_options[0] if label_options else 'example'}",
224
+ "confidence": 8,
225
+ "reasoning": "The text clearly expresses positive sentiment based on the language used."
226
+ }}"""
227
+
228
+ return prompt
229
+
230
+ def _extract_text_content(self, instance: Dict[str, Any]) -> str:
231
+ """Extract text content from an instance."""
232
+ # Try common text field names
233
+ text_fields = ['text', 'content', 'message', 'sentence', 'document']
234
+
235
+ for field in text_fields:
236
+ if field in instance:
237
+ content = instance[field]
238
+ if isinstance(content, str):
239
+ return content
240
+ elif isinstance(content, dict):
241
+ # Handle nested text fields
242
+ for nested_field in text_fields:
243
+ if nested_field in content:
244
+ return str(content[nested_field])
245
+
246
+ # Fallback: convert the entire instance to string
247
+ return str(instance)
248
+
249
+ def _process_batch(self, prompts: List[str], instances: List[Dict[str, Any]]) -> List[LLMPrediction]:
250
+ """Process a batch of prompts."""
251
+ predictions = []
252
+
253
+ # Use ThreadPoolExecutor for parallel processing within the batch
254
+ with ThreadPoolExecutor(max_workers=min(len(prompts), 5)) as executor:
255
+ future_to_index = {
256
+ executor.submit(self._predict_single, prompt, instances[i]): i
257
+ for i, prompt in enumerate(prompts)
258
+ }
259
+
260
+ for future in as_completed(future_to_index):
261
+ index = future_to_index[future]
262
+ try:
263
+ prediction = future.result()
264
+ predictions.append(prediction)
265
+ except Exception as e:
266
+ self.logger.error(f"Error processing instance {index}: {e}")
267
+ # Create error prediction
268
+ error_prediction = LLMPrediction(
269
+ instance_id=instances[index].get('id', f'instance_{index}'),
270
+ predicted_label='',
271
+ confidence_score=0.1,
272
+ raw_response='',
273
+ error_message=str(e)
274
+ )
275
+ predictions.append(error_prediction)
276
+
277
+ return predictions
278
+
279
def _predict_single(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
    """Predict one instance using the configured confidence strategy.

    ``self.config.confidence_method`` selects the implementation:
    - "logprobs": confidence derived from token log probabilities
    - "consistency": several samples, confidence = agreement rate
    - anything else: self-reported ("verbalized") confidence from the JSON
    """
    strategy = self.config.confidence_method

    if strategy == "logprobs":
        return self._predict_with_logprobs(prompt, instance)

    if strategy == "consistency":
        return self._predict_consistency(prompt, instance)

    # Default, also used for unrecognised method names
    return self._predict_verbalized(prompt, instance)
295
+
296
def _predict_verbalized(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
    """Predict a label using the model's self-reported confidence (1-10 scale).

    Posts an OpenAI-style chat completion request and parses the JSON reply.
    The reported confidence is clamped to 1..10 and rescaled to 0.1..1.0.
    A missing, non-numeric or non-finite confidence maps to the lowest
    value (0.1); numeric strings such as "8" are accepted.

    The request is retried up to ``config.retry_attempts`` times with a
    linearly growing delay. If the reply is not valid JSON the raw text is
    handed to ``_extract_from_raw_response``.

    Args:
        prompt: Fully rendered prompt for this instance.
        instance: The instance being predicted; its 'id' is used for reporting.

    Returns:
        The prediction, or an error prediction (empty label, confidence 0.1)
        once all attempts have failed.
    """
    instance_id = instance.get('id', 'unknown')

    for attempt in range(self.config.retry_attempts):
        try:
            payload = {
                "model": self.config.model_name,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant that provides structured JSON responses."},
                    {"role": "user", "content": prompt}
                ],
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "response_format": {"type": "json_object"}
            }

            response = self.session.post(
                self.config.endpoint_url,
                json=payload,
                timeout=self.config.timeout
            )

            if response.status_code != 200:
                raise Exception(f"HTTP {response.status_code}: {response.text}")

            result = response.json()
            if 'choices' not in result or len(result['choices']) == 0:
                raise Exception(f"Invalid response format: {result}")

            content = result['choices'][0]['message']['content']

            try:
                parsed_response = _loads_lenient(content)
            except json.JSONDecodeError as e:
                self.logger.warning(f"Failed to parse JSON response for instance {instance_id}: {e}")
                return self._extract_from_raw_response(content, instance_id)

            predicted_label = parsed_response.get('label', '')
            raw_confidence = parsed_response.get('confidence', 1)

            # Models sometimes quote the number ("8"); accept that form too.
            if isinstance(raw_confidence, str):
                try:
                    raw_confidence = float(raw_confidence.strip())
                except ValueError:
                    raw_confidence = None

            if isinstance(raw_confidence, (int, float)) and math.isfinite(raw_confidence):
                confidence_score = max(1, min(10, raw_confidence)) / 10.0
            else:
                # Unusable confidence is treated as "very uncertain", the
                # same value a missing confidence (default 1) produces.
                confidence_score = 0.1

            return LLMPrediction(
                instance_id=instance_id,
                predicted_label=predicted_label,
                confidence_score=confidence_score,
                raw_response=content,
                confidence_method="verbalized"
            )

        except Exception as e:
            self.logger.warning(f"Attempt {attempt + 1} failed for instance {instance_id}: {e}")

            if attempt < self.config.retry_attempts - 1:
                # Linear back-off: delay, 2*delay, 3*delay, ...
                time.sleep(self.config.retry_delay * (attempt + 1))
            else:
                return LLMPrediction(
                    instance_id=instance_id,
                    predicted_label='',
                    confidence_score=0.1,
                    raw_response='',
                    error_message=f"All attempts failed: {e}",
                    confidence_method="verbalized"
                )

    # Only reached when retry_attempts is zero or negative
    return LLMPrediction(
        instance_id=instance_id,
        predicted_label='',
        confidence_score=0.1,
        raw_response='',
        error_message="Unknown error",
        confidence_method="verbalized"
    )
376
+
377
def _predict_with_logprobs(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
    """Predict a label and derive confidence from token log probabilities.

    Requests ``logprobs=True`` from an OpenAI-compatible endpoint and uses
    ``exp(mean logprob)`` over all returned response tokens, clamped to
    0.0..1.0, as the confidence.

    When the endpoint returns no usable log probabilities (field missing,
    ``null``, empty, or without numeric values) the self-reported confidence
    from the JSON reply is used instead (default 5 -> 0.5), and the
    prediction's ``confidence_method`` is reported as "verbalized".

    Args:
        prompt: Fully rendered prompt for this instance.
        instance: The instance being predicted; its 'id' is used for reporting.

    Returns:
        The prediction, or an error prediction (empty label, confidence 0.1)
        once all attempts have failed.
    """
    instance_id = instance.get('id', 'unknown')

    for attempt in range(self.config.retry_attempts):
        try:
            payload = {
                "model": self.config.model_name,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant that provides structured JSON responses."},
                    {"role": "user", "content": prompt}
                ],
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "response_format": {"type": "json_object"},
                "logprobs": True,
                "top_logprobs": 5,
            }

            response = self.session.post(
                self.config.endpoint_url,
                json=payload,
                timeout=self.config.timeout
            )

            if response.status_code != 200:
                raise Exception(f"HTTP {response.status_code}: {response.text}")

            result = response.json()
            if 'choices' not in result or len(result['choices']) == 0:
                raise Exception(f"Invalid response format: {result}")

            choice = result['choices'][0]
            content = choice['message']['content']

            # Parse label from JSON content
            try:
                parsed = _loads_lenient(content)
            except json.JSONDecodeError:
                return self._extract_from_raw_response(content, instance_id)

            predicted_label = parsed.get('label', '')

            # `or` guards against an explicit null: endpoints that do not
            # support logprobs commonly send "logprobs": null, which
            # dict.get's default would not replace.
            logprobs_data = choice.get('logprobs') or {}
            token_logprobs = logprobs_data.get('content') or []
            log_probs = [
                t['logprob'] for t in token_logprobs
                if 'logprob' in t and t['logprob'] is not None
            ]

            if log_probs:
                mean_logprob = sum(log_probs) / len(log_probs)
                confidence_score = min(1.0, max(0.0, math.exp(mean_logprob)))
                method_used = "logprobs"
            else:
                # No usable logprobs, fall back to verbalized confidence
                self.logger.debug(f"No logprobs returned for {instance_id}, using verbalized")
                confidence_score = parsed.get('confidence', 5)
                if isinstance(confidence_score, (int, float)):
                    confidence_score = max(1, min(10, confidence_score)) / 10.0
                else:
                    confidence_score = 0.5
                method_used = "verbalized"

            return LLMPrediction(
                instance_id=instance_id,
                predicted_label=predicted_label,
                confidence_score=confidence_score,
                raw_response=content,
                confidence_method=method_used
            )

        except Exception as e:
            self.logger.warning(f"Logprobs attempt {attempt + 1} failed for {instance_id}: {e}")
            if attempt < self.config.retry_attempts - 1:
                # Linear back-off: delay, 2*delay, 3*delay, ...
                time.sleep(self.config.retry_delay * (attempt + 1))
            else:
                return LLMPrediction(
                    instance_id=instance_id,
                    predicted_label='',
                    confidence_score=0.1,
                    raw_response='',
                    error_message=f"All logprob attempts failed: {e}",
                    confidence_method="logprobs"
                )

    # Only reached when retry_attempts is zero or negative
    return LLMPrediction(
        instance_id=instance_id,
        predicted_label='',
        confidence_score=0.1,
        raw_response='',
        error_message="Unknown error",
        confidence_method="logprobs"
    )
486
+
487
def _predict_consistency(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
    """Predict a label with consistency-based confidence.

    The same prompt is sent ``config.consistency_samples`` times with a
    temperature of at least 0.5. The most frequent label wins and the
    confidence is the fraction of successfully parsed samples that agree
    with it. This needs no logprob support from the endpoint.

    Samples that fail (request error, non-200 status, malformed payload,
    unparseable JSON) are logged at debug level and left out of the vote.

    Args:
        prompt: Fully rendered prompt for this instance.
        instance: The instance being predicted; its 'id' is used for reporting.

    Returns:
        The prediction; ``raw_response`` is the first sample that produced
        the winning label. If no sample could be parsed, an error prediction
        (empty label, confidence 0.1) is returned.
    """
    instance_id = instance.get('id', 'unknown')
    n_samples = self.config.consistency_samples

    # (label, raw content) for every sample that parsed successfully
    samples = []

    for _ in range(n_samples):
        try:
            payload = {
                "model": self.config.model_name,
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant that provides structured JSON responses."},
                    {"role": "user", "content": prompt}
                ],
                "max_tokens": self.config.max_tokens,
                "temperature": max(0.5, self.config.temperature),  # Need some randomness
                "response_format": {"type": "json_object"}
            }

            response = self.session.post(
                self.config.endpoint_url,
                json=payload,
                timeout=self.config.timeout
            )

            if response.status_code != 200:
                self.logger.debug(
                    f"Consistency sample for {instance_id} returned HTTP {response.status_code}"
                )
                continue

            result = response.json()
            if 'choices' not in result or len(result['choices']) == 0:
                self.logger.debug(f"Consistency sample for {instance_id} had no choices")
                continue

            content = result['choices'][0]['message']['content']
            try:
                parsed = _loads_lenient(content)
            except json.JSONDecodeError as e:
                self.logger.debug(f"Consistency sample for {instance_id} was not valid JSON: {e}")
                continue

            samples.append((parsed.get('label', ''), content))

        except Exception as e:
            self.logger.debug(f"Consistency sample failed for {instance_id}: {e}")

    if not samples:
        return LLMPrediction(
            instance_id=instance_id,
            predicted_label='',
            confidence_score=0.1,
            raw_response='',
            error_message="All consistency samples failed",
            confidence_method="consistency"
        )

    # Most common label; ties resolve to the label seen first
    label_counts = Counter(label for label, _ in samples)
    predicted_label, count = label_counts.most_common(1)[0]
    confidence_score = count / len(samples)

    # Report a raw response that actually contains the winning label
    winning_raw = next(content for label, content in samples if label == predicted_label)

    return LLMPrediction(
        instance_id=instance_id,
        predicted_label=predicted_label,
        confidence_score=confidence_score,
        raw_response=winning_raw,
        confidence_method="consistency"
    )
555
+
556
+ def _extract_from_raw_response(self, raw_response: str, instance_id: str) -> LLMPrediction:
557
+ """Extract prediction from raw response when JSON parsing fails."""
558
+ try:
559
+ # Try to find label and confidence in the raw text
560
+ lines = raw_response.lower().split('\n')
561
+
562
+ predicted_label = ''
563
+ confidence_score = 0.1
564
+
565
+ for line in lines:
566
+ if 'label' in line and ':' in line:
567
+ label_part = line.split(':', 1)[1].strip().strip('"\'')
568
+ if label_part:
569
+ predicted_label = label_part
570
+
571
+ if 'confidence' in line and ':' in line:
572
+ conf_part = line.split(':', 1)[1].strip()
573
+ try:
574
+ conf_value = float(conf_part)
575
+ confidence_score = max(0.1, min(1.0, conf_value / 10.0))
576
+ except ValueError:
577
+ pass
578
+
579
+ return LLMPrediction(
580
+ instance_id=instance_id,
581
+ predicted_label=predicted_label,
582
+ confidence_score=confidence_score,
583
+ raw_response=raw_response
584
+ )
585
+
586
+ except Exception as e:
587
+ self.logger.error(f"Failed to extract from raw response for instance {instance_id}: {e}")
588
+ return LLMPrediction(
589
+ instance_id=instance_id,
590
+ predicted_label='',
591
+ confidence_score=0.1,
592
+ raw_response=raw_response,
593
+ error_message=f"Failed to extract prediction: {e}"
594
+ )
595
+
596
def calculate_confidence_distribution(self, predictions: List[LLMPrediction]) -> Dict[str, float]:
    """Summarise confidence scores of error-free predictions as a histogram.

    Args:
        predictions: Predictions to summarise; those with an
            ``error_message`` are ignored.

    Returns:
        Mapping from bin label ("0.0-0.2" ... "0.8-1.0") to the percentage
        of valid predictions in that bin. Empty when there is nothing valid
        to summarise.
    """
    if not predictions:
        return {}

    scores = [p.confidence_score for p in predictions if p.error_message is None]
    if not scores:
        return {}

    # Five equal-width bins over the 0..1 confidence range
    edges = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    counts, _ = np.histogram(scores, bins=edges)

    n_scores = len(scores)
    return {
        f"{edges[i]:.1f}-{edges[i+1]:.1f}": (count / n_scores) * 100 if n_scores > 0 else 0
        for i, count in enumerate(counts)
    }
621
+
622
def get_prediction_stats(self, predictions: List[LLMPrediction]) -> Dict[str, Any]:
    """Compute summary statistics for a list of predictions.

    Args:
        predictions: Predictions to summarise.

    Returns:
        Dict with the total count, the number without an error, the error
        rate, the mean confidence of the error-free predictions and their
        confidence histogram. All values are zero/empty for empty input.
    """
    if not predictions:
        return {
            "total_predictions": 0,
            "successful_predictions": 0,
            "error_rate": 0.0,
            "average_confidence": 0.0,
            "confidence_distribution": {}
        }

    error_free = [p for p in predictions if p.error_message is None]
    n_total = len(predictions)
    n_ok = len(error_free)

    if error_free:
        mean_confidence = np.mean([p.confidence_score for p in error_free])
    else:
        mean_confidence = 0.0

    return {
        "total_predictions": n_total,
        "successful_predictions": n_ok,
        "error_rate": (n_total - n_ok) / n_total if n_total > 0 else 0.0,
        "average_confidence": mean_confidence,
        "confidence_distribution": self.calculate_confidence_distribution(predictions)
    }
649
+
650
+
651
class MockLLMActiveLearning(LLMActiveLearning):
    """
    Stand-in LLM backend for tests and local development.

    Cycles through a fixed list of canned responses instead of calling a
    real endpoint, adding a little random noise to the confidence.
    """

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self.logger.info("Using Mock LLM for active learning")

        # Canned (label, confidence on the 1-10 scale, reasoning) triples
        canned = (
            ("positive", 8, "Clear positive sentiment"),
            ("negative", 7, "Negative tone detected"),
            ("neutral", 6, "Balanced perspective"),
            ("positive", 9, "Very positive language"),
            ("negative", 5, "Somewhat negative"),
            ("neutral", 4, "Mixed signals"),
            ("positive", 3, "Uncertain positive"),
            ("negative", 8, "Clearly negative"),
            ("neutral", 7, "Neutral stance"),
            ("positive", 6, "Moderately positive"),
        )
        self._mock_responses = [
            {"label": label, "confidence": confidence, "reasoning": reasoning}
            for label, confidence, reasoning in canned
        ]
        self._response_index = 0

    def _test_connection(self):
        """Pretend the connection test succeeded; nothing is contacted."""
        self.logger.info("Mock LLM connection test successful")

    def _predict_single(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
        """Return the next canned response as a prediction for ``instance``."""
        instance_id = instance.get('id', 'unknown')

        # Simulate processing time
        time.sleep(0.1)

        # Round-robin over the canned responses
        position = self._response_index % len(self._mock_responses)
        canned_response = self._mock_responses[position]
        self._response_index += 1

        # Jitter the rescaled confidence, then clamp to 0.1..1.0
        jitter = np.random.normal(0, 0.1)
        noisy_confidence = (canned_response['confidence'] / 10.0) + jitter
        confidence_score = max(0.1, min(1.0, noisy_confidence))

        return LLMPrediction(
            instance_id=instance_id,
            predicted_label=canned_response['label'],
            confidence_score=confidence_score,
            raw_response=json.dumps(canned_response)
        )
703
+
704
+
705
def create_llm_active_learning(config: Dict[str, Any]) -> LLMActiveLearning:
    """
    Build an LLM active learning backend from a plain configuration dict.

    Args:
        config: LLM configuration dictionary. Missing keys fall back to the
            defaults listed below. ``use_mock`` forces the mock backend.

    Returns:
        LLMActiveLearning: A mock backend when ``use_mock`` is truthy or no
        endpoint URL is configured, otherwise the real backend.
    """
    defaults = {
        'endpoint_url': '',
        'model_name': '',
        'max_tokens': 512,
        'temperature': 0.1,
        'timeout': 30,
        'batch_size': 10,
        'retry_attempts': 3,
        'retry_delay': 1.0,
        'max_instances_per_request': 5,
        'confidence_method': 'verbalized',
        'consistency_samples': 3,
    }
    llm_config = LLMConfig(**{
        option: config.get(option, fallback) for option, fallback in defaults.items()
    })

    # Use mock implementation for testing or when endpoint is not available
    wants_mock = config.get('use_mock', False) or not llm_config.endpoint_url
    backend_cls = MockLLMActiveLearning if wants_mock else LLMActiveLearning
    return backend_cls(llm_config)
potato/ai/ollama_endpoint.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ollama AI endpoint implementation.
3
+
4
+ This module provides integration with Ollama for local LLM inference.
5
+ """
6
+
7
+ import json
8
+ from typing import Dict, List, Optional, Type
9
+ import ollama
10
+ from pydantic import BaseModel
11
+ from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError, ModelCapabilities
12
+ import re
13
+
14
+
15
+ DEFAULT_MODEL = "llama3.2"
16
+
17
+
18
class OllamaEndpoint(BaseAIEndpoint):
    """Ollama endpoint for local LLM inference."""

    # Capabilities declaration for text-based Ollama models
    CAPABILITIES = ModelCapabilities(
        text_generation=True,
        vision_input=False,
        bounding_box_output=False,
        text_classification=True,
        image_classification=False,
        rationale_generation=True,
        keyword_extraction=True,
    )

    def _initialize_client(self) -> None:
        """Create the Ollama client and verify the server is reachable.

        Reads ``timeout`` (default 60 s, local inference can be slow) and
        ``base_url`` (default http://localhost:11434) from ``ai_config``.

        Raises:
            AIEndpointRequestError: If listing models on the server fails.
        """
        timeout = self.ai_config.get("timeout", 60)
        host = self.ai_config.get("base_url", "http://localhost:11434")

        self.client = ollama.Client(host=host, timeout=timeout)

        # Listing models is used purely as a connectivity check
        try:
            self.client.list()
        except Exception as e:
            raise AIEndpointRequestError(f"Failed to connect to Ollama: {e}") from e

    def _get_default_model(self) -> str:
        """Get the default Ollama model."""
        return DEFAULT_MODEL

    @staticmethod
    def _response_field(obj, name):
        """Read ``name`` from a response part that may be a mapping or an object."""
        return obj.get(name) if hasattr(obj, 'get') else getattr(obj, name, None)

    def query(self, prompt: str, output_format: Type[BaseModel]) -> str:
        """
        Send a single-turn query to Ollama and return the parsed response.

        Args:
            prompt: The prompt to send to the model
            output_format: Pydantic model whose JSON schema is passed to
                Ollama as the required output format

        Returns:
            The content as-is when it is already a dict, otherwise the result
            of ``parseStringToJson`` on the content string.
            NOTE(review): the ``-> str`` annotation does not cover the dict
            case; left unchanged to keep the signature stable.

        Raises:
            AIEndpointRequestError: If the request fails or the response has
                no usable content
        """
        import logging
        logger = logging.getLogger(__name__)

        try:
            logger.debug(f"[Ollama] Querying model: {self.model}")
            logger.debug(f"[Ollama] Prompt (first 200 chars): {prompt[:200]}...")

            options = {
                'temperature': self.temperature,
                'num_predict': self.max_tokens
            }

            # Think mode: configurable via ai_config['think'], defaults to False
            # for fast structured output (thinking wastes tokens on JSON tasks)
            think = self.ai_config.get('think', False)

            response = self.client.chat(
                model=self.model,
                messages=[{'role': 'user', 'content': prompt}],
                options=options,
                format=output_format.model_json_schema(),
                think=think,
            )

            # Log full response structure for debugging
            logger.debug(f"[Ollama] Response type: {type(response)}")
            logger.debug(f"[Ollama] Full response: {response}")

            message = self._response_field(response, 'message')
            if message is None:
                raise AIEndpointRequestError("No message in Ollama response")

            content = self._response_field(message, 'content')

            # Some models put the answer in the 'thinking' field instead;
            # read it the same way as 'content' so dict-style messages work too.
            thinking_text = self._response_field(message, 'thinking')
            if not content and thinking_text:
                logger.warning("[Ollama] Content empty but thinking field has data - model may need think=False")
                # Fallback: take the first flat {...} object from the thinking text
                json_match = re.search(r'\{[^{}]*\}', thinking_text)
                if json_match:
                    content = json_match.group(0)
                    logger.debug(f"[Ollama] Extracted JSON from thinking: {content}")

            logger.debug(f"[Ollama] Content type: {type(content)}")
            logger.debug(f"[Ollama] Content value: {repr(content)[:200] if content else 'EMPTY'}")

            # If content is already a dict (structured output), return it directly
            if isinstance(content, dict):
                logger.debug("[Ollama] Content is already a dict, returning directly")
                return content

            if not content:
                raise AIEndpointRequestError("Empty content from Ollama - try a different model or disable thinking mode")

            # Parse response using the base class's parser
            logger.debug(f"[Ollama] Response content (first 500 chars): {str(content)[:500]}")
            return self.parseStringToJson(content)
        except AIEndpointRequestError as e:
            # Already the public error type: log it but do not wrap it again
            logger.error(f"[Ollama] Request failed: {e}")
            raise
        except Exception as e:
            logger.error(f"[Ollama] Request failed: {e}")
            raise AIEndpointRequestError(f"Ollama request failed: {e}") from e

    def chat_query(self, messages: List[Dict[str, str]]) -> str:
        """Send a multi-turn chat to Ollama using the native chat API.

        Args:
            messages: Chat history, passed through unchanged as the
                ``messages`` argument of the client's ``chat`` call.

        Returns:
            The reply's content, or an empty string when it has none.

        Raises:
            AIEndpointRequestError: If the request fails or the response
                carries no message.
        """
        import logging
        logger = logging.getLogger(__name__)

        try:
            options = {
                'temperature': self.temperature,
                'num_predict': self.max_tokens,
            }

            think = self.ai_config.get('think', False)
            response = self.client.chat(
                model=self.model,
                messages=messages,
                options=options,
                think=think,
            )

            message = self._response_field(response, 'message')
            if message is None:
                raise AIEndpointRequestError("No message in Ollama chat response")

            content = self._response_field(message, 'content')
            return content or ""
        except AIEndpointRequestError as e:
            logger.error(f"[Ollama] Chat request failed: {e}")
            raise
        except Exception as e:
            logger.error(f"[Ollama] Chat request failed: {e}")
            raise AIEndpointRequestError(f"Ollama chat request failed: {e}") from e
159
+
160
+
potato/ai/ollama_vision_endpoint.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ollama Vision AI Endpoint
3
+
4
+ This module provides integration with Ollama vision models for local
5
+ visual AI inference. Supports LLaVA, Llama 3.2 Vision, BakLLaVA, and Qwen-VL models.
6
+ """
7
+
8
+ import base64
9
+ import json
10
+ import logging
11
+ from typing import Any, Dict, List, Type, Union
12
+
13
+ from pydantic import BaseModel
14
+
15
+ from .ai_endpoint import AIEndpointRequestError, ImageData, ModelCapabilities
16
+ from .visual_ai_endpoint import BaseVisualAIEndpoint
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Default vision model
21
+ DEFAULT_MODEL = "llava:latest"
22
+
23
+ # Models known to support vision (used only for warning suppression, not to restrict usage)
24
+ # Any Ollama model can be used - this list just prevents "may not support vision" warnings
25
+ # for models we know are vision-capable
26
+ VISION_MODELS = [
27
+ # LLaVA family
28
+ "llava",
29
+ "llava-llama3",
30
+ "llava-phi3",
31
+ "bakllava",
32
+ # Llama Vision
33
+ "llama3.2-vision",
34
+ # Qwen Vision-Language
35
+ "qwen2.5-vl",
36
+ "qwen2-vl",
37
+ "qwen3-vl",
38
+ # Other vision models
39
+ "moondream",
40
+ "minicpm-v",
41
+ "gemma3", # Gemma 3 has vision capabilities
42
+ "gemma4", # Gemma 4 family (e.g. gemma4:e4b) supports vision
43
+ ]
44
+
45
+
46
class OllamaVisionEndpoint(BaseVisualAIEndpoint):
    """
    Ollama Vision endpoint for multimodal local inference.

    Supports vision-capable models like LLaVA, Llama 3.2 Vision, BakLLaVA.
    Images are sent as base64 in the 'images' field.

    Configuration options:
    - model: Vision model to use (default: llava:latest)
    - base_url: Ollama server URL (default: http://localhost:11434)
    - timeout: Request timeout in seconds (default: 120)
    - max_tokens: Maximum response tokens (default: 500)
    - temperature: Sampling temperature (default: 0.1)
    """

    # Capabilities declaration for vision-capable Ollama models (LLaVA, Qwen-VL, etc.)
    # Note: VLLMs can generate text about images but cannot do precise bounding box detection
    # Keyword extraction is disabled because it doesn't apply to image content
    CAPABILITIES = ModelCapabilities(
        text_generation=True,
        vision_input=True,
        bounding_box_output=False,  # VLLMs are not reliable for precise bbox coordinates
        text_classification=True,
        image_classification=True,
        rationale_generation=True,
        keyword_extraction=False,  # Keywords don't apply to images
    )

    def _initialize_client(self) -> None:
        """Create the Ollama client and verify the server is reachable.

        Raises:
            AIEndpointRequestError: If the ``ollama`` package is missing or
                listing models on the server fails.
        """
        try:
            import ollama
        except ImportError:
            raise AIEndpointRequestError(
                "ollama package is required. Install it with: pip install ollama"
            )

        timeout = self.ai_config.get("timeout", 120)  # Vision models can be slower
        host = self.ai_config.get("base_url", "http://localhost:11434")

        self.client = ollama.Client(host=host, timeout=timeout)

        try:
            # Connectivity check only; the returned model list is not inspected
            self.client.list()
            logger.info(f"Connected to Ollama at {host}")

            # Informational only: any model may be used, known or not
            model_lower = self.model.lower()
            is_known_vision_model = any(vm in model_lower for vm in VISION_MODELS)
            if not is_known_vision_model:
                logger.info(
                    f"Model '{self.model}' not in known vision models list. "
                    f"This is fine if it supports vision - proceeding anyway."
                )

        except Exception as e:
            raise AIEndpointRequestError(f"Failed to connect to Ollama: {e}") from e

    def _get_default_model(self) -> str:
        """Get the default vision model."""
        return DEFAULT_MODEL

    def query(self, prompt: str, output_format: Type[BaseModel]) -> Any:
        """
        Standard text-only query.

        For vision tasks, use query_with_image() instead.

        Args:
            prompt: Text prompt to send
            output_format: Pydantic model whose JSON schema constrains the output

        Returns:
            The content as-is when it is already a dict, otherwise the result
            of ``parseStringToJson`` on the content string

        Raises:
            AIEndpointRequestError: If the request fails or yields no content
        """
        try:
            options = {
                'temperature': self.temperature,
                'num_predict': self.max_tokens
            }

            response = self.client.chat(
                model=self.model,
                messages=[{'role': 'user', 'content': prompt}],
                options=options,
                format=output_format.model_json_schema(),
                think=False,
            )

            # Handle both dict and object access
            message = response.get('message') if hasattr(response, 'get') else getattr(response, 'message', None)
            if message is None:
                raise AIEndpointRequestError("No message in Ollama response")

            content = message.get('content') if hasattr(message, 'get') else getattr(message, 'content', None)

            if isinstance(content, dict):
                return content

            if not content:
                raise AIEndpointRequestError("Empty content from Ollama")

            return self.parseStringToJson(content)

        except AIEndpointRequestError:
            # Already the public error type; same handling as query_with_image
            raise
        except Exception as e:
            raise AIEndpointRequestError(f"Ollama query failed: {e}") from e

    def query_with_image(
        self,
        prompt: str,
        image_data: Union[ImageData, List[ImageData]],
        output_format: Type[BaseModel]
    ) -> Any:
        """
        Send a query with image(s) to Ollama vision model.

        Args:
            prompt: Text prompt describing what to analyze
            image_data: Single ImageData or list of ImageData
            output_format: Pydantic model for structured output

        Returns:
            The content as-is when it is already a dict, otherwise the result
            of ``parseStringToJson`` on the content string

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            # Normalise to a list, then to the base64 strings Ollama expects
            images = [image_data] if isinstance(image_data, ImageData) else image_data
            image_base64_list = [self._get_base64_image(img) for img in images]

            options = {
                'temperature': self.temperature,
                'num_predict': self.max_tokens
            }

            # Ollama expects images as a list of base64 strings
            response = self.client.chat(
                model=self.model,
                messages=[{
                    'role': 'user',
                    'content': prompt,
                    'images': image_base64_list
                }],
                options=options,
                format=output_format.model_json_schema(),
                think=False,
            )

            logger.debug(f"Ollama vision response type: {type(response)}")

            # Handle both dict and object access
            message = response.get('message') if hasattr(response, 'get') else getattr(response, 'message', None)
            if message is None:
                raise AIEndpointRequestError("No message in Ollama vision response")

            content = message.get('content') if hasattr(message, 'get') else getattr(message, 'content', None)

            logger.debug(f"Ollama vision content type: {type(content)}")

            if isinstance(content, dict):
                return content

            if not content:
                raise AIEndpointRequestError("Empty content from Ollama vision model")

            return self.parseStringToJson(content)

        except AIEndpointRequestError:
            raise
        except Exception as e:
            logger.error(f"Ollama vision query failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise AIEndpointRequestError(f"Ollama vision query failed: {e}") from e

    def _get_base64_image(self, image_data: ImageData) -> str:
        """
        Get base64-encoded image data.

        Args:
            image_data: ImageData object with ``source`` "base64" or "url"

        Returns:
            Base64-encoded image string (without data URL prefix)

        Raises:
            AIEndpointRequestError: If ``source`` is neither "base64" nor "url"
        """
        if image_data.source == "base64":
            encoded = image_data.data
        elif image_data.source == "url":
            # Download and convert to base64
            encoded = self.download_image_to_base64(image_data.data).data
        else:
            raise AIEndpointRequestError(f"Unknown image source: {image_data.source}")

        # Drop a "data:<mime>;base64," prefix if present so only the raw
        # payload is sent, as the return contract above promises.
        if isinstance(encoded, str) and encoded.startswith("data:") and "," in encoded:
            encoded = encoded.split(",", 1)[1]
        return encoded

    def analyze_image(
        self,
        image_path_or_url: str,
        prompt: str,
        output_format: Type[BaseModel] = None
    ) -> Any:
        """
        Convenience method for analyzing a single image.

        Args:
            image_path_or_url: Path to image file or URL
            prompt: Analysis prompt
            output_format: Optional output format model; GeneralHintFormat
                is used when omitted

        Returns:
            Analysis result
        """
        # http(s) URLs are downloaded; anything else is treated as a file path
        if image_path_or_url.startswith(("http://", "https://")):
            image_data = self.download_image_to_base64(image_path_or_url)
        else:
            image_data = self.encode_image_to_base64(image_path_or_url)

        # Use a generic format if not specified
        if output_format is None:
            from .prompt.models_module import GeneralHintFormat
            output_format = GeneralHintFormat

        return self.query_with_image(prompt, image_data, output_format)

    def describe_image(self, image_path_or_url: str) -> str:
        """
        Get a natural language description of an image.

        Args:
            image_path_or_url: Path to image file or URL

        Returns:
            Text description of the image; if the result is not a dict with
            a "description" key, its string form is returned instead
        """
        # Minimal schema: a single free-text field
        class DescriptionFormat(BaseModel):
            description: str

        result = self.analyze_image(
            image_path_or_url,
            "Describe this image in detail. What objects, people, or scenes do you see?",
            DescriptionFormat
        )

        if isinstance(result, dict) and "description" in result:
            return result["description"]
        return str(result)

    def health_check(self) -> bool:
        """
        Check whether the Ollama server responds.

        Only server reachability is tested (by listing models); the
        configured model's presence is not verified.

        Returns:
            True if the server answered, False otherwise
        """
        try:
            self.client.list()
            return True
        except Exception as e:
            logger.error(f"Ollama vision health check failed: {e}")
            return False
potato/ai/openai_endpoint.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenAI AI endpoint implementation.
3
+
4
+ This module provides integration with OpenAI's API for LLM inference.
5
+ """
6
+
7
+ import os
8
+ from typing import Dict, List
9
+ from openai import OpenAI
10
+ from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError, ModelCapabilities
11
+
12
+ DEFAULT_MODEL = "gpt-4o-mini"
13
+
14
+
15
class OpenAIEndpoint(BaseAIEndpoint):
    """OpenAI endpoint for cloud-based LLM inference.

    Also works against OpenAI-compatible servers (vLLM, llama.cpp, etc.)
    when ``base_url`` is set in the AI config.
    """

    # Capabilities declaration for text-based OpenAI models
    CAPABILITIES = ModelCapabilities(
        text_generation=True,
        vision_input=False,
        bounding_box_output=False,
        text_classification=True,
        image_classification=False,
        rationale_generation=True,
        keyword_extraction=True,
    )

    def _initialize_client(self) -> None:
        """Initialize the OpenAI client.

        Reads ``api_key`` (falling back to the OPENAI_API_KEY environment
        variable), ``base_url`` and ``timeout`` from ``self.ai_config``.

        Raises:
            AIEndpointRequestError: If no API key is available and no
                ``base_url`` is configured.
        """
        # OpenAI-compatible servers (vLLM, llama.cpp, etc.) ignore the key
        # but the SDK rejects an empty string, so accept a placeholder.
        api_key = self.ai_config.get("api_key") or os.environ.get(
            "OPENAI_API_KEY", ""
        )
        base_url = self.ai_config.get("base_url")
        if not api_key:
            if base_url:
                api_key = "EMPTY"  # non-empty placeholder for local servers
            else:
                raise AIEndpointRequestError("OpenAI API key is required")

        # Default timeout of 30 seconds, configurable via ai_config
        timeout = self.ai_config.get("timeout", 30)
        client_kwargs = {"api_key": api_key, "timeout": timeout}
        # Honor a custom base_url so this endpoint can target any
        # OpenAI-compatible server instead of always hitting api.openai.com.
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = OpenAI(**client_kwargs)

    def _get_default_model(self) -> str:
        """Get the default OpenAI model."""
        return DEFAULT_MODEL

    def query(self, prompt: str, output_format: dict) -> str:
        """
        Send a query to OpenAI and return the response.

        Args:
            prompt: The prompt to send to the model
            output_format: Object exposing ``model_json_schema()`` (e.g. a
                Pydantic model class) describing the desired JSON output.
                When None, no structured-output constraint is sent.

        Returns:
            The model's response as a string

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            request_kwargs = {
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
            }
            # chat.completions.create() has no `text_format` parameter (that
            # keyword made every call fail with TypeError). Structured output
            # is requested through `response_format` with a JSON schema.
            if output_format is not None:
                request_kwargs["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "response",
                        "schema": output_format.model_json_schema(),
                    },
                }
            response = self.client.chat.completions.create(**request_kwargs)
            # NOTE(review): returned as the raw JSON string, per the declared
            # return type; confirm callers do not expect a parsed dict.
            return response.choices[0].message.content
        except Exception as e:
            raise AIEndpointRequestError(f"OpenAI request failed: {e}")

    def chat_query(self, messages: List[Dict[str, str]]) -> str:
        """Send a multi-turn chat to OpenAI using native messages API.

        Args:
            messages: Chat messages passed through unchanged to the API

        Returns:
            The content of the first choice in the response

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            return response.choices[0].message.content
        except Exception as e:
            raise AIEndpointRequestError(f"OpenAI chat request failed: {e}")
94
+
potato/ai/openai_vision_endpoint.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenAI Vision AI Endpoint
3
+
4
+ This module provides integration with OpenAI's vision models (GPT-4o, GPT-4o-mini)
5
+ for visual analysis and annotation assistance.
6
+ """
7
+
8
+ import base64
9
+ import logging
10
+ from typing import Any, Dict, List, Type, Union
11
+
12
+ from pydantic import BaseModel
13
+
14
+ from .ai_endpoint import AIEndpointRequestError, ImageData, ModelCapabilities
15
+ from .visual_ai_endpoint import BaseVisualAIEndpoint
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ DEFAULT_MODEL = "gpt-4o"
20
+
21
+
22
class OpenAIVisionEndpoint(BaseVisualAIEndpoint):
    """
    OpenAI Vision endpoint for GPT-4o and GPT-4o-mini vision capabilities.

    Supports both URL and base64 image inputs using the image_url content type.

    Configuration options:
        - model: Model to use (gpt-4o, gpt-4o-mini) (default: gpt-4o)
        - api_key: OpenAI API key (can also use OPENAI_API_KEY env var)
        - base_url: Optional base URL of an OpenAI-compatible server
        - timeout: Request timeout in seconds (default: 60)
        - max_tokens: Maximum response tokens (default: 1000)
        - temperature: Sampling temperature (default: 0.1)
        - detail: Image detail level - 'low', 'high', or 'auto' (default: auto)
    """

    # Capabilities declaration for OpenAI vision models (GPT-4o, GPT-4o-mini)
    # These models can understand images and generate text but bounding boxes are approximate
    CAPABILITIES = ModelCapabilities(
        text_generation=True,
        vision_input=True,
        bounding_box_output=False,  # GPT-4V bboxes are approximate, not precise
        text_classification=True,
        image_classification=True,
        rationale_generation=True,
        keyword_extraction=False,  # Keywords don't apply to images
    )

    def _initialize_client(self) -> None:
        """Initialize the OpenAI client.

        Raises:
            AIEndpointRequestError: If the openai package is missing, or if
                no API key is available and no ``base_url`` is configured.
        """
        try:
            import openai
        except ImportError:
            raise AIEndpointRequestError(
                "openai package is required. Install it with: pip install openai"
            )

        import os

        api_key = self.ai_config.get("api_key") or os.environ.get("OPENAI_API_KEY")
        base_url = self.ai_config.get("base_url")
        if not api_key:
            if base_url:
                # Same convention as the text endpoint: local OpenAI-compatible
                # servers ignore the key, but the SDK needs a non-empty value.
                api_key = "EMPTY"
            else:
                raise AIEndpointRequestError(
                    "OpenAI API key is required. Set it in config or OPENAI_API_KEY env var."
                )

        timeout = self.ai_config.get("timeout", 60)
        self.detail = self.ai_config.get("detail", "auto")

        client_kwargs = {"api_key": api_key, "timeout": timeout}
        # Honor a configured base_url instead of always targeting api.openai.com.
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = openai.OpenAI(**client_kwargs)
        logger.info(f"OpenAI Vision client initialized with model: {self.model}")

    def _get_default_model(self) -> str:
        """Get the default OpenAI vision model."""
        return DEFAULT_MODEL

    def _prepare_image_data(self, image_path_or_url: str) -> ImageData:
        """Build image data from an http(s) URL (passed by reference) or a local file (base64)."""
        if image_path_or_url.startswith(("http://", "https://")):
            return self.create_url_image_data(image_path_or_url)
        return self.encode_image_to_base64(image_path_or_url)

    def query(self, prompt: str, output_format: Type[BaseModel]) -> Any:
        """
        Standard text query without images.

        Args:
            prompt: Text prompt
            output_format: Pydantic model for structured output (not sent to
                the API here; the request only asks for a JSON object)

        Returns:
            Parsed response

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                response_format={"type": "json_object"},
            )

            content = response.choices[0].message.content
            return self.parseStringToJson(content)

        except Exception as e:
            raise AIEndpointRequestError(f"OpenAI query failed: {e}")

    def query_with_image(
        self,
        prompt: str,
        image_data: Union[ImageData, List[ImageData]],
        output_format: Type[BaseModel]
    ) -> Any:
        """
        Send a query with image(s) to OpenAI vision model.

        Args:
            prompt: Text prompt describing what to analyze
            image_data: Single ImageData or list of ImageData
            output_format: Pydantic model for structured output (not sent to
                the API here; the request only asks for a JSON object)

        Returns:
            Response content parsed with parseStringToJson

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            # Normalize to a list of images
            images = [image_data] if isinstance(image_data, ImageData) else image_data

            # Content array: the text prompt first, then one block per image
            content = [{"type": "text", "text": prompt}]
            content.extend(self._build_image_content(img) for img in images)

            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": content}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                response_format={"type": "json_object"},
            )

            response_content = response.choices[0].message.content
            logger.debug(f"OpenAI vision response: {response_content[:500] if response_content else 'empty'}")

            return self.parseStringToJson(response_content)

        except AIEndpointRequestError:
            # Already the right type (e.g. unknown image source); do not re-wrap.
            raise
        except Exception as e:
            logger.error(f"OpenAI vision query failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise AIEndpointRequestError(f"OpenAI vision query failed: {e}")

    def _build_image_content(self, image_data: ImageData) -> Dict[str, Any]:
        """
        Build image content block for OpenAI API.

        Args:
            image_data: ImageData object

        Returns:
            Dict with type: "image_url" and image_url content

        Raises:
            AIEndpointRequestError: If the image source is neither "url" nor "base64"
        """
        if image_data.source == "url":
            # Direct URL reference
            url = image_data.data
        elif image_data.source == "base64":
            # Data URL format; JPEG is assumed when no MIME type is given
            mime_type = image_data.mime_type or "image/jpeg"
            url = f"data:{mime_type};base64,{image_data.data}"
        else:
            raise AIEndpointRequestError(f"Unknown image source: {image_data.source}")

        return {
            "type": "image_url",
            "image_url": {
                "url": url,
                "detail": self.detail
            }
        }

    def analyze_image(
        self,
        image_path_or_url: str,
        prompt: str,
        output_format: Type[BaseModel] = None
    ) -> Any:
        """
        Convenience method for analyzing a single image.

        Args:
            image_path_or_url: Path to image file or URL
            prompt: Analysis prompt
            output_format: Optional output format model (GeneralHintFormat when omitted)

        Returns:
            Analysis result
        """
        # Use the URL directly if possible, otherwise encode the file
        image_data = self._prepare_image_data(image_path_or_url)

        # Use a generic format if not specified
        if output_format is None:
            from .prompt.models_module import GeneralHintFormat
            output_format = GeneralHintFormat

        return self.query_with_image(prompt, image_data, output_format)

    def detect_objects(
        self,
        image_path_or_url: str,
        labels: List[str] = None
    ) -> Dict[str, Any]:
        """
        Detect objects in an image and return bounding boxes.

        Args:
            image_path_or_url: Path to image file or URL
            labels: Optional list of labels to detect

        Returns:
            Dict with detections list
        """
        from .prompt.models_module import VisualDetectionFormat

        labels_str = ", ".join(labels) if labels else "all visible objects"

        prompt = f"""Analyze this image and detect objects. For each object, provide:
1. The label (from: {labels_str})
2. A bounding box with normalized coordinates (0-1 range)
3. Confidence score (0-1)

Return JSON with this structure:
{{
"detections": [
{{
"label": "object_name",
"bbox": {{"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.4}},
"confidence": 0.95
}}
]
}}

Coordinates are normalized (0-1) where x,y is the top-left corner.
Only include objects you can clearly identify with confidence > 0.5."""

        image_data = self._prepare_image_data(image_path_or_url)

        return self.query_with_image(prompt, image_data, VisualDetectionFormat)

    def describe_region(
        self,
        image_path_or_url: str,
        region: Dict[str, float],
        labels: List[str] = None
    ) -> Dict[str, Any]:
        """
        Describe or classify a specific region in an image.

        Args:
            image_path_or_url: Path to image file or URL
            region: Dict with x, y, width, height (normalized 0-1); all four
                keys are read when building the prompt
            labels: Optional list of possible labels

        Returns:
            Classification result with suggested label and confidence
        """
        labels_str = ", ".join(labels) if labels else "any appropriate category"

        prompt = f"""Look at the region marked in this image:
- Region: x={region['x']:.2f}, y={region['y']:.2f}, width={region['width']:.2f}, height={region['height']:.2f}
(Coordinates are normalized 0-1, where 0,0 is top-left)

Classify what you see in this region from these options: {labels_str}

Return JSON:
{{
"suggested_label": "label_name",
"confidence": 0.85,
"reasoning": "Brief explanation"
}}"""

        image_data = self._prepare_image_data(image_path_or_url)

        class RegionClassificationFormat(BaseModel):
            suggested_label: str
            confidence: float
            reasoning: str

        return self.query_with_image(prompt, image_data, RegionClassificationFormat)

    def health_check(self) -> bool:
        """
        Check if the OpenAI API is accessible.

        Returns:
            True if the models list call succeeds, False otherwise
        """
        try:
            # Simple models list check
            self.client.models.list()
            return True
        except Exception as e:
            logger.error(f"OpenAI health check failed: {e}")
            return False
potato/ai/openrouter_endpoint.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenRouter endpoint implementation.
3
+ This module provides integration with OpenRouter's API for LLM inference.
4
+ """
5
+ import requests
6
+ from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError
7
+
8
+ DEFAULT_MODEL = "openai/gpt-4o-mini"
9
+
10
class OpenRouterEndpoint(BaseAIEndpoint):
    """OpenRouter endpoint for cloud-based LLM inference."""

    # Models that support structured output
    STRUCTURED_OUTPUT_MODELS = {
        "openai/gpt-4o",
        "openai/gpt-4o-mini",
        "openai/gpt-4-turbo",
        "anthropic/claude-3-5-sonnet",
        "deepseek/deepseek-r1:free"
    }

    def _initialize_client(self) -> None:
        """Validate the OpenRouter configuration (requests are sent directly, no client object).

        Raises:
            AIEndpointRequestError: If no ``api_key`` is configured.
        """
        api_key = self.ai_config.get("api_key", "")
        if not api_key:
            raise AIEndpointRequestError("OpenRouter API key is required")

    def _get_default_model(self) -> str:
        """Get the default OpenRouter model."""
        return DEFAULT_MODEL

    def supports_structured_output(self) -> bool:
        """Check if the current model supports structured output.

        NOTE(review): matching is by provider prefix, so any model whose id
        starts with the provider part of a listed model ("openai",
        "anthropic", "deepseek") is treated as supported - confirm intent.
        """
        model = self.model or DEFAULT_MODEL
        if model in self.STRUCTURED_OUTPUT_MODELS:
            return True
        providers = {entry.split('/')[0] for entry in self.STRUCTURED_OUTPUT_MODELS}
        return any(model.startswith(provider) for provider in providers)

    def query(self, prompt: str, output_format: dict) -> str:
        """
        Send a query to OpenRouter and return the response.

        Args:
            prompt: The prompt to send to the model; sent as the content of a
                single user message
            output_format: Pydantic model for structured output

        Returns:
            The parsed JSON response when the model supports structured
            output, otherwise the raw response content string

        Raises:
            AIEndpointRequestError: If the request fails
        """
        try:
            url = "https://openrouter.ai/api/v1/chat/completions"
            headers = {
                "Authorization": f"Bearer {self.ai_config.get('api_key')}",
                "Content-Type": "application/json"
            }

            schema = output_format.model_json_schema()
            structured = self.supports_structured_output()

            body = {
                "model": self.model or DEFAULT_MODEL,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "messages": [{"role": "user", "content": prompt}],
            }

            # Only ask for schema-constrained output when the model supports it;
            # otherwise the raw prompt is sent on its own.
            if structured:
                body["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "response",
                        "schema": schema,
                        "strict": True
                    }
                }

            # Without a timeout requests waits forever on a stalled connection.
            timeout = self.ai_config.get("timeout", 30)
            r = requests.post(url, headers=headers, json=body, timeout=timeout)

            if r.status_code >= 400:
                raise AIEndpointRequestError(f"OpenRouter error {r.status_code}: {r.text}")

            content = r.json()["choices"][0]["message"]["content"]
            if structured:
                return self.parseStringToJson(content)
            return content
        except AIEndpointRequestError:
            # Already descriptive (HTTP status + body); avoid double-wrapping.
            raise
        except Exception as e:
            raise AIEndpointRequestError(f"OpenRouter request failed: {e}")
potato/ai/prompt/image_annotation.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "detect": {
3
+ "prompt": "TASK: Detect objects in this image that match the specified labels.\n\nDescription: ${description}\nLabels to detect: ${labels}\nMinimum confidence: ${confidence_threshold}\n\nFor each detected object, provide:\n1. The label (must be from the provided labels list)\n2. Bounding box with normalized coordinates (0-1 range)\n3. Confidence score (0-1)\n\nCoordinate system:\n- x,y is the top-left corner of the box\n- x increases left to right (0 = left edge, 1 = right edge)\n- y increases top to bottom (0 = top edge, 1 = bottom edge)\n- width and height are also normalized (0-1)\n\nReturn JSON:\n{\n \"detections\": [\n {\n \"label\": \"object_name\",\n \"bbox\": {\"x\": 0.1, \"y\": 0.2, \"width\": 0.3, \"height\": 0.4},\n \"confidence\": 0.95\n }\n ]\n}\n\nOnly include objects you can clearly identify with confidence >= ${confidence_threshold}.",
4
+ "output_format": "visual_detection",
5
+ "img": "/static/ai_assistant_img/detect.svg",
6
+ "name": "Detect"
7
+ },
8
+ "detection": {
9
+ "prompt": "TASK: Detect objects in this image that match the specified labels.\n\nDescription: ${description}\nLabels to detect: ${labels}\nMinimum confidence: ${confidence_threshold}\n\nFor each detected object, provide:\n1. The label (must be from the provided labels list)\n2. Bounding box with normalized coordinates (0-1 range)\n3. Confidence score (0-1)\n\nCoordinate system:\n- x,y is the top-left corner of the box\n- x increases left to right (0 = left edge, 1 = right edge)\n- y increases top to bottom (0 = top edge, 1 = bottom edge)\n- width and height are also normalized (0-1)\n\nReturn JSON:\n{\n \"detections\": [\n {\n \"label\": \"object_name\",\n \"bbox\": {\"x\": 0.1, \"y\": 0.2, \"width\": 0.3, \"height\": 0.4},\n \"confidence\": 0.95\n }\n ]\n}\n\nOnly include objects you can clearly identify with confidence >= ${confidence_threshold}.",
10
+ "output_format": "visual_detection",
11
+ "img": "/static/ai_assistant_img/detect.svg",
12
+ "name": "Detect"
13
+ },
14
+ "pre_annotate": {
15
+ "prompt": "TASK: Pre-annotate this image by detecting all objects that match the available labels.\n\nDescription: ${description}\nAvailable labels: ${labels}\n\nDetect ALL instances of objects matching these labels. Be thorough - it's better to include uncertain detections (the human annotator will review them).\n\nFor each detection provide:\n1. Label from the available list\n2. Bounding box (normalized 0-1 coordinates)\n3. Confidence score\n\nReturn JSON:\n{\n \"detections\": [\n {\n \"label\": \"label_name\",\n \"bbox\": {\"x\": 0.0, \"y\": 0.0, \"width\": 0.0, \"height\": 0.0},\n \"confidence\": 0.0\n }\n ]\n}\n\nInclude all potential detections, even with lower confidence. The annotator will verify.",
16
+ "output_format": "visual_detection",
17
+ "img": "/static/ai_assistant_img/auto.svg",
18
+ "name": "Auto"
19
+ },
20
+ "classification": {
21
+ "prompt": "TASK: Classify the specified region in this image.\n\nDescription: ${description}\nRegion: ${region}\nAvailable labels: ${labels}\n\nLook at the indicated region and determine which label best describes what you see there.\n\nReturn JSON:\n{\n \"suggested_label\": \"label_name\",\n \"confidence\": 0.85,\n \"reasoning\": \"Brief explanation of why this label fits\"\n}",
22
+ "output_format": "visual_classification",
23
+ "img": "/static/ai_assistant_img/classify.svg",
24
+ "name": "Classify"
25
+ },
26
+ "hint": {
27
+ "prompt": "TASK: Provide a helpful hint for annotating this image WITHOUT revealing exact answers.\n\nAnnotation task: ${description}\nAvailable labels: ${labels}\n\nProvide guidance that helps the annotator without giving away:\n- Exact object locations\n- Specific label assignments\n\nGood hints:\n- Point out relevant visual features\n- Suggest areas that deserve attention\n- Note potential challenges or ambiguities\n- Remind about edge cases\n\nReturn JSON:\n{\n \"hint\": \"Your helpful guidance here\",\n \"suggestive_choice\": \"optional_focus_area\"\n}",
28
+ "output_format": "default_hint",
29
+ "img": "/static/ai_assistant_img/blub.svg",
30
+ "name": "Hint"
31
+ },
32
+ "keyword": {
33
+ "prompt": "TASK: Identify visual keywords/features associated with each label in this image.\n\nAnnotation task: ${description}\nAvailable labels: ${labels}\n\nFor each label, identify visual cues or features that would indicate its presence.\n\nReturn JSON:\n{\n \"label_keywords\": [\n {\n \"label\": \"label_name\",\n \"keywords\": [\"visual_feature_1\", \"visual_feature_2\"]\n }\n ]\n}",
34
+ "output_format": "default_keyword",
35
+ "img": "/static/ai_assistant_img/highlight.svg",
36
+ "name": "Keywords"
37
+ },
38
+ "rationale": {
39
+ "prompt": "TASK: Provide rationale for how each label might apply to this image.\n\nAnnotation task: ${description}\nAvailable labels: ${labels}\n\nFor each label, explain what evidence in the image supports or contradicts its application.\n\nReturn JSON:\n{\n \"rationales\": [\n {\n \"label\": \"label_name\",\n \"reasoning\": \"Explanation of visual evidence for/against this label\"\n }\n ]\n}",
40
+ "output_format": "default_rationale",
41
+ "img": "/static/ai_assistant_img/question.svg",
42
+ "name": "Rationale"
43
+ }
44
+ }
potato/ai/prompt/likert.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hint": {
3
+ "prompt": "TASK: Generate annotation guidance for Likert scale task.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Scale: ${min_label} (1) to ${max_label} (${size})\n- Scale points: 1, 2, 3, ..., ${size}\n\nINSTRUCTIONS:\n1. Analyze the text for features relevant to the annotation task\n2. Generate a helpful hint that guides thinking WITHOUT revealing the answer\n3. Suggest a scale position based on your analysis\n\nHINT REQUIREMENTS:\n- Focus on specific textual evidence (word choice, tone, structure)\n- Point out subtle indicators the annotator should notice\n- Be concrete and actionable, not generic\n- Guide analytical thinking without bias toward any scale position",
4
+ "output_format": "default_hint",
5
+ "img": "/static/ai_assistant_img/blub.svg"
6
+ },
7
+ "keyword": {
8
+ "prompt": "TASK: Extract key words/phrases that guide Likert scale annotation decisions.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n- Scale: ${min_label} (1) to ${max_label} (${size})\n- Scale range: 1, 2, 3, ..., ${size}\n\nOBJECTIVE: Identify 3-5 most significant words/phrases that directly indicate scale positioning. Output as JSON array format.\n\nSELECTION CRITERIA:\n- Words that signal intensity/degree (extremely, slightly, moderately)\n- Sentiment markers (positive/negative indicators)\n- Qualifying language (hedges, certainty markers)\n- Context-specific terminology relevant to the annotation task\n- Structural indicators (but, however, although)\n\nPRIORITIZE:\n1. Words with clear scale implications\n2. Phrases that distinguish between scale levels\n3. Contextual clues that affect interpretation\n4. Intensity modifiers and qualifiers",
9
+ "output_format": "default_keyword",
10
+ "img": "/static/ai_assistant_img/highlight.svg"
11
+ },
12
+ "rationale": {
13
+ "name": "Rationale",
14
+ "prompt": "TASK: Generate rationales for different positions on the Likert scale.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Scale: ${min_label} (1) to ${max_label} (${size})\n\nINSTRUCTIONS:\nProvide rationales for different scale positions (low, middle, high). Explain what evidence in the text could support each position. Be balanced and objective.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing objects with \"label\" (scale position descriptor) and \"reasoning\" fields.\n\nEXAMPLE OUTPUT:\n{\"rationales\": [{\"label\": \"low (1-2)\", \"reasoning\": \"The negative tone and words like 'disappointing' suggest a low rating\"}, {\"label\": \"middle (3)\", \"reasoning\": \"Mixed signals with both positive and negative elements\"}, {\"label\": \"high (4-5)\", \"reasoning\": \"Strong positive language like 'excellent' supports a high rating\"}]}",
15
+ "output_format": "default_rationale",
16
+ "img": "/static/ai_assistant_img/question.svg"
17
+ }
18
+ }
19
+
20
+
potato/ai/prompt/models_module.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Type, Union, Dict, List
2
+ from pydantic import BaseModel
3
+
4
+
5
class GeneralHintFormat(BaseModel):
    """Hint response: a text hint plus a suggested choice (string or integer)."""
    # Free-text guidance for the annotator
    hint: str
    # Suggested option; may be a label string or an integer
    suggestive_choice: Union[str, int]
8
+
9
+
10
class LabelKeywords(BaseModel):
    """One label together with the keywords/phrases tied to it."""
    label: str
    keywords: List[str]
14
+
15
+
16
class GeneralKeywordFormat(BaseModel):
    """Keyword response: a list of per-label keyword entries.

    Example output:
        {
            "label_keywords": [
                {"label": "positive", "keywords": ["great", "love it", "excellent"]},
                {"label": "negative", "keywords": ["terrible", "awful"]}
            ]
        }
    """
    # One LabelKeywords entry per label
    label_keywords: List[LabelKeywords]
28
+
29
+
30
class GeneralRandomFormat(BaseModel):
    """Deprecated single-string format; prefer GeneralRationaleFormat."""
    random: str
33
+
34
+
35
class LabelRationale(BaseModel):
    """One label paired with the reasoning for why it might apply."""
    label: str
    reasoning: str
39
+
40
+
41
class GeneralRationaleFormat(BaseModel):
    """Rationale response: per-label explanations of how each label could fit the text.

    Example output:
        {
            "rationales": [
                {"label": "positive", "reasoning": "The phrase 'excellent quality' suggests satisfaction"},
                {"label": "negative", "reasoning": "The mention of 'delayed shipping' indicates frustration"}
            ]
        }
    """
    # One LabelRationale entry per label
    rationales: List[LabelRationale]
53
+
54
+
55
+ # ============================================================================
56
+ # Visual Annotation Output Formats
57
+ # ============================================================================
58
+
59
class BoundingBox(BaseModel):
    """Bounding box in coordinates normalized to the image size (0-1 range).

    (x, y) is the top-left corner; width and height are the box dimensions.
    """
    x: float
    y: float
    width: float
    height: float
70
+
71
+
72
class Detection(BaseModel):
    """A single detected object: label, bounding box and confidence.

    Example:
        {
            "label": "person",
            "bbox": {"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.5},
            "confidence": 0.95
        }
    """
    label: str
    bbox: BoundingBox
    confidence: float
85
+
86
+
87
class VisualDetectionFormat(BaseModel):
    """Object detection response for one image: a list of detections.

    Example output:
        {
            "detections": [
                {"label": "car", "bbox": {"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.2}, "confidence": 0.92},
                {"label": "person", "bbox": {"x": 0.5, "y": 0.3, "width": 0.1, "height": 0.4}, "confidence": 0.87}
            ]
        }
    """
    detections: List[Detection]
99
+
100
+
101
class VisualClassificationFormat(BaseModel):
    """Classification response for an image or a region of it.

    Example output:
        {
            "suggested_label": "cat",
            "confidence": 0.89,
            "reasoning": "The image shows a feline with pointed ears and whiskers"
        }
    """
    suggested_label: str
    confidence: float
    # Explanation is optional and defaults to None
    reasoning: Optional[str] = None
114
+
115
+
116
class VideoSegment(BaseModel):
    """A labeled time span within a video; start and end times are in seconds."""
    start_time: float
    end_time: float
    suggested_label: str
    confidence: float
    # Optional free-text description, None when absent
    description: Optional[str] = None
126
+
127
+
128
class VideoSceneDetectionFormat(BaseModel):
    """Scene/segment detection response for a video: a list of segments.

    Example output:
        {
            "segments": [
                {"start_time": 0.0, "end_time": 5.5, "suggested_label": "intro", "confidence": 0.9},
                {"start_time": 5.5, "end_time": 15.0, "suggested_label": "action", "confidence": 0.85}
            ]
        }
    """
    segments: List[VideoSegment]
140
+
141
+
142
class VideoKeyframe(BaseModel):
    """A labeled keyframe in a video; timestamp is in seconds."""
    timestamp: float
    suggested_label: str
    confidence: float
    # Optional explanation, None when absent
    reason: Optional[str] = None
151
+
152
+
153
class VideoKeyframeDetectionFormat(BaseModel):
    """Keyframe detection response for a video: a list of keyframes.

    Example output:
        {
            "keyframes": [
                {"timestamp": 2.5, "suggested_label": "scene_change", "confidence": 0.95, "reason": "Major visual transition"},
                {"timestamp": 8.0, "suggested_label": "action_peak", "confidence": 0.82, "reason": "Key moment in action"}
            ]
        }
    """
    keyframes: List[VideoKeyframe]
165
+
166
+
167
class TrackPosition(BaseModel):
    """Where a tracked object sits in one frame, with a confidence value."""
    frame_index: int
    bbox: BoundingBox
    confidence: float
172
+
173
+
174
class ObjectTrack(BaseModel):
    """An object followed across frames: id, label and per-frame positions."""
    track_id: int
    label: str
    positions: List[TrackPosition]
179
+
180
+
181
class VideoTrackingSuggestionFormat(BaseModel):
    """Object tracking response for a video: a list of object tracks.

    Example output:
        {
            "tracks": [
                {
                    "track_id": 1,
                    "label": "person",
                    "positions": [
                        {"frame_index": 0, "bbox": {"x": 0.1, "y": 0.2, "width": 0.15, "height": 0.3}, "confidence": 0.9},
                        {"frame_index": 1, "bbox": {"x": 0.12, "y": 0.22, "width": 0.15, "height": 0.3}, "confidence": 0.88}
                    ]
                }
            ]
        }
    """
    tracks: List[ObjectTrack]
199
+
200
+
201
class FrameDetections(BaseModel):
    """All detections belonging to one video frame."""
    frame_index: int
    detections: List[Detection]
205
+
206
+
207
class MultiFrameDetectionFormat(BaseModel):
    """Detection response spanning several video frames.

    Intended for detection run on sampled video frames.
    """
    frames: List[FrameDetections]
213
+
214
+
215
+ # ============================================================================
216
+ # Class Registry
217
+ # ============================================================================
218
+
219
+ # ============================================================================
220
+ # Option Highlighting Output Format
221
+ # ============================================================================
222
+
223
class OptionHighlightFormat(BaseModel):
    """LLM response naming the options to highlight in a discrete annotation task.

    Highlighted options are displayed at full opacity; the remaining ones are dimmed.

    Example output:
        {
            "highlighted_options": ["positive", "neutral"],
            "confidence": 0.85
        }
    """
    # Names/values of the top-k most likely options
    highlighted_options: List[str]
    # Overall confidence score (0-1); optional, None when absent
    confidence: Optional[float] = None
237
+
238
+
239
# Maps the "output_format" names used in the prompt JSON files to the
# response model classes defined in this module.
CLASS_REGISTRY = {
    # --- Text annotation ---
    "default_hint": GeneralHintFormat,
    "default_keyword": GeneralKeywordFormat,
    "default_random": GeneralRandomFormat,  # deprecated; retained for backwards compatibility
    "default_rationale": GeneralRationaleFormat,

    # --- Option highlighting ---
    "option_highlight": OptionHighlightFormat,

    # --- Visual annotation: image ---
    "visual_detection": VisualDetectionFormat,
    "visual_classification": VisualClassificationFormat,

    # --- Visual annotation: video ---
    "video_scene_detection": VideoSceneDetectionFormat,
    "video_keyframe_detection": VideoKeyframeDetectionFormat,
    "video_tracking_suggestion": VideoTrackingSuggestionFormat,
    "multi_frame_detection": MultiFrameDetectionFormat,
}
potato/ai/prompt/multiselect.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hint": {
3
+ "prompt": "TASK: Generate guidance for multiple label selection.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nINSTRUCTIONS:\n1. Analyze text for features that may correspond to multiple labels\n2. Guide toward identifying ALL applicable categories\n3. Focus on overlapping characteristics and comprehensive analysis\n\nHINT REQUIREMENTS:\n- Identify indicators for each potential label category\n- Point out that multiple selections may be appropriate\n- Guide systematic evaluation of all label options\n- Highlight overlapping or complementary features",
4
+ "output_format": "default_hint",
5
+ "img": "/static/ai_assistant_img/blub.svg"
6
+ },
7
+ "keyword": {
8
+ "prompt": "TASK: Extract words/phrases that indicate multiple applicable labels.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nOBJECTIVE: Identify terms that support selection of multiple labels.\n\nSELECTION CRITERIA:\n- Words that indicate multiple categories simultaneously\n- Terms supporting different label aspects\n- Overlapping category indicators\n- Comprehensive feature markers\n- Multi-faceted descriptors",
9
+ "output_format": "default_keyword",
10
+ "img": "/static/ai_assistant_img/highlight.svg"
11
+ },
12
+ "rationale": {
13
+ "name": "Rationale",
14
+ "prompt": "TASK: Generate rationales explaining why each label might apply to this text.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nCRITICAL REQUIREMENT:\nYou MUST provide a rationale for EVERY label listed above. Your output must have exactly one entry per label.\n\nINSTRUCTIONS:\nFor EACH available label (ALL of them), provide a brief rationale explaining what evidence in the text could support selecting that label. Since multiple labels can be selected, focus on independent evidence for each label. If a label doesn't apply, explain why.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing one object per label, each with \"label\" and \"reasoning\" fields.\n\nEXAMPLE OUTPUT:\n{\"rationales\": [{\"label\": \"category1\", \"reasoning\": \"The text mentions X which relates to this category\"}, {\"label\": \"category2\", \"reasoning\": \"The phrase Y suggests this also applies\"}, {\"label\": \"category3\", \"reasoning\": \"No direct evidence for this label in the text\"}]}",
15
+ "output_format": "default_rationale",
16
+ "img": "/static/ai_assistant_img/question.svg"
17
+ }
18
+ }
19
+
20
+
potato/ai/prompt/number.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hint": {
3
+ "prompt": "TASK: Generate guidance for numerical value extraction/estimation.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n\nINSTRUCTIONS:\n1. Analyze the text for numerical clues or quantifiable elements\n2. Guide the annotator toward identifying the correct numerical value\n3. Focus on mathematical, statistical, or countable aspects\n\nHINT REQUIREMENTS:\n- Identify numerical indicators, quantities, or measurable elements\n- Point out calculation methods or counting strategies\n- Highlight context that affects numerical interpretation\n- Guide toward systematic analysis approach",
4
+ "output_format": "default_hint",
5
+ "img": "/static/ai_assistant_img/blub.svg"
6
+ },
7
+ "keyword": {
8
+ "prompt": "TASK: Extract words/phrases that contain or indicate numerical information.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n\nOBJECTIVE: Identify words/phrases that directly relate to the numerical answer.\n\nSELECTION CRITERIA:\n- Explicit numbers, quantities, or measurements\n- Words indicating amount, frequency, or degree\n- Mathematical or statistical terminology\n- Comparative language (more, less, double, half)\n- Time references, percentages, ratios",
9
+ "output_format": "default_keyword",
10
+ "img": "/static/ai_assistant_img/highlight.svg"
11
+ },
12
+ "rationale": {
13
+ "name": "Rationale",
14
+ "prompt": "TASK: Generate rationales for different approaches to determining the numerical answer.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n\nINSTRUCTIONS:\nProvide different rationales or approaches for determining the numerical value. Consider different interpretations or calculation methods if applicable.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing objects with \"label\" (approach name) and \"reasoning\" fields.\n\nEXAMPLE OUTPUT:\n{\"rationales\": [{\"label\": \"literal count\", \"reasoning\": \"Counting explicit mentions gives X\"}, {\"label\": \"inclusive interpretation\", \"reasoning\": \"Including implicit references increases the count to Y\"}]}",
15
+ "output_format": "default_rationale",
16
+ "img": "/static/ai_assistant_img/question.svg"
17
+ }
18
+ }
potato/ai/prompt/option_highlight.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "option_highlight": {
3
+ "prompt": "TASK: Identify the most likely correct options for an annotation task.\n\nINPUT DETAILS:\n- Content to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available options: ${labels}\n- Number of options to highlight: ${top_k}\n\nINSTRUCTIONS:\n1. Analyze the content carefully in the context of the annotation task\n2. Consider which ${top_k} options are most likely to be correct based on:\n - Direct evidence in the content\n - Contextual clues and tone\n - Domain knowledge relevant to the task\n3. Select exactly ${top_k} options (or fewer if fewer options exist)\n\nIMPORTANT:\n- Return ONLY the option names/values exactly as they appear in the available options list\n- Do not modify, paraphrase, or abbreviate the option names\n- If you're uncertain, choose options that seem most plausible given the content\n\nOUTPUT FORMAT:\nReturn a JSON object with:\n- \"highlighted_options\": array of ${top_k} option names (or fewer if fewer options exist) from the available options\n- \"confidence\": your confidence score from 0.0 to 1.0\n\nEXAMPLE (if options are: positive, negative, neutral and 2 options are requested):\n{\"highlighted_options\": [\"positive\", \"neutral\"], \"confidence\": 0.75}",
4
+ "output_format": "option_highlight",
5
+ "img": "/static/ai_assistant_img/highlight.svg",
6
+ "name": "Option Highlight"
7
+ }
8
+ }
potato/ai/prompt/radio.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "hint": {
3
+ "prompt": "TASK: Generate annotation guidance for single-choice selection.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nINSTRUCTIONS:\n1. Analyze the text for features that distinguish between the available labels\n2. Generate a helpful hint that guides classification WITHOUT revealing the answer\n3. Focus on decision-making criteria between options\n\nHINT REQUIREMENTS:\n- Identify key textual indicators that differentiate between label options\n- Point out distinguishing features (style, content, context)\n- Guide analytical thinking toward the classification criteria\n- Be specific about what to examine, not generic",
4
+ "output_format": "default_hint",
5
+ "img": "/static/ai_assistant_img/blub.svg"
6
+ },
7
+ "keyword": {
8
+ "prompt": "TASK: Identify words or short phrases in the text that relate to each label.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nINSTRUCTIONS:\nFor each label, find words or short phrases from the text that indicate or relate to that label. Only include words/phrases that actually appear in the text. Return 1-5 keywords per label. If no words relate to a label, return an empty list for that label.\n\nEXAMPLE OUTPUT FORMAT:\n{\"label_keywords\": [{\"label\": \"positive\", \"keywords\": [\"great\", \"love it\"]}, {\"label\": \"negative\", \"keywords\": [\"terrible\"]}]}",
9
+ "output_format": "default_keyword",
10
+ "img": "/static/ai_assistant_img/highlight.svg"
11
+ },
12
+ "rationale": {
13
+ "name": "Rationale",
14
+ "prompt": "TASK: Generate rationales explaining why each label might apply to this text.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nCRITICAL REQUIREMENT:\nYou MUST provide a rationale for EVERY label listed above. Your output must have exactly one entry per label.\n\nINSTRUCTIONS:\nFor EACH available label (ALL of them), provide a brief rationale explaining what evidence in the text could support choosing that label. Be balanced and objective - present the case for each label fairly. If a label doesn't strongly apply, explain what would need to be present.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing one object per label, each with \"label\" and \"reasoning\" fields.\n\nEXAMPLE (if labels are: positive, negative, neutral, mixed):\n{\"rationales\": [{\"label\": \"positive\", \"reasoning\": \"The phrase 'excellent service' suggests satisfaction\"}, {\"label\": \"negative\", \"reasoning\": \"No clear negative indicators present\"}, {\"label\": \"neutral\", \"reasoning\": \"The tone is more emotional than neutral\"}, {\"label\": \"mixed\", \"reasoning\": \"Would need both positive and negative elements\"}]}",
15
+ "output_format": "default_rationale",
16
+ "img": "/static/ai_assistant_img/question.svg"
17
+ }
18
+ }