Spaces:
Paused
Paused
Deploy: Potato — Codebook Annotation
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +4 -0
- Dockerfile +41 -0
- README.md +42 -5
- annotation_output/.gitkeep +0 -0
- config.yaml +27 -0
- data/codebook-example.json +7 -0
- entrypoint.sh +30 -0
- layouts/task_layout.html +90 -0
- potato/__init__.py +45 -0
- potato/__main__.py +3 -0
- potato/active_learning_manager.py +1623 -0
- potato/adjudication.py +1224 -0
- potato/adjudication_export.py +162 -0
- potato/admin.py +0 -0
- potato/agent_proxy/__init__.py +44 -0
- potato/agent_proxy/base.py +138 -0
- potato/agent_proxy/coding_proxy.py +466 -0
- potato/agent_proxy/echo_proxy.py +55 -0
- potato/agent_proxy/http_proxy.py +108 -0
- potato/agent_proxy/openai_proxy.py +105 -0
- potato/agent_proxy/sandbox.py +76 -0
- potato/agent_proxy/session.py +119 -0
- potato/agent_runner.py +1008 -0
- potato/agent_runner_manager.py +226 -0
- potato/agreement.py +278 -0
- potato/ai/__init__.py +1 -0
- potato/ai/ai_cache.py +1473 -0
- potato/ai/ai_endpoint.py +688 -0
- potato/ai/ai_help_wrapper.py +203 -0
- potato/ai/ai_prompt.py +94 -0
- potato/ai/anthropic_endpoint.py +118 -0
- potato/ai/anthropic_vision_endpoint.py +405 -0
- potato/ai/gemini_endpoint.py +58 -0
- potato/ai/huggingface_endpoint.py +62 -0
- potato/ai/icl_labeler.py +1110 -0
- potato/ai/icl_prompt_builder.py +315 -0
- potato/ai/judge.py +265 -0
- potato/ai/llm_active_learning.py +733 -0
- potato/ai/ollama_endpoint.py +160 -0
- potato/ai/ollama_vision_endpoint.py +313 -0
- potato/ai/openai_endpoint.py +94 -0
- potato/ai/openai_vision_endpoint.py +324 -0
- potato/ai/openrouter_endpoint.py +93 -0
- potato/ai/prompt/image_annotation.json +44 -0
- potato/ai/prompt/likert.json +20 -0
- potato/ai/prompt/models_module.py +258 -0
- potato/ai/prompt/multiselect.json +20 -0
- potato/ai/prompt/number.json +18 -0
- potato/ai/prompt/option_highlight.json +8 -0
- potato/ai/prompt/radio.json +18 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
potato/static/vendor/font-awesome-6.7.2/webfonts/fa-brands-400.ttf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
potato/static/vendor/font-awesome-6.7.2/webfonts/fa-brands-400.woff2 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
potato/static/vendor/font-awesome-6.7.2/webfonts/fa-solid-900.ttf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
potato/static/vendor/font-awesome-6.7.2/webfonts/fa-solid-900.woff2 filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# Create non-root user (HF Spaces requires UID 1000)
|
| 4 |
+
RUN useradd -m -u 1000 potato
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && \
|
| 8 |
+
apt-get install -y --no-install-recommends git && \
|
| 9 |
+
rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
# Set working directory
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
# Copy requirements first for layer caching
|
| 15 |
+
COPY requirements.txt .
|
| 16 |
+
RUN pip install --no-cache-dir -r requirements.txt gunicorn
|
| 17 |
+
|
| 18 |
+
# Copy the application source
|
| 19 |
+
COPY . .
|
| 20 |
+
|
| 21 |
+
# Install potato
|
| 22 |
+
RUN pip install --no-cache-dir -e .
|
| 23 |
+
|
| 24 |
+
# Copy entrypoint
|
| 25 |
+
COPY entrypoint.sh /entrypoint.sh
|
| 26 |
+
RUN chmod +x /entrypoint.sh
|
| 27 |
+
|
| 28 |
+
# Create directories for output
|
| 29 |
+
RUN mkdir -p /app/annotation_output && \
|
| 30 |
+
chown -R potato:potato /app
|
| 31 |
+
|
| 32 |
+
# Switch to non-root user
|
| 33 |
+
USER potato
|
| 34 |
+
|
| 35 |
+
# HuggingFace Spaces expects port 7860
|
| 36 |
+
EXPOSE 7860
|
| 37 |
+
|
| 38 |
+
ENV POTATO_CONFIG=config.yaml
|
| 39 |
+
ENV PORT=7860
|
| 40 |
+
|
| 41 |
+
ENTRYPOINT ["/entrypoint.sh"]
|
README.md
CHANGED
|
@@ -1,10 +1,47 @@
|
|
| 1 |
---
|
| 2 |
-
title: Codebook
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Potato — Codebook Annotation
|
| 3 |
+
emoji: 🥔
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
+
license: apache-2.0
|
| 10 |
+
tags:
|
| 11 |
+
- annotation
|
| 12 |
+
- potato
|
| 13 |
+
- advanced
|
| 14 |
+
- qda
|
| 15 |
+
- codebook
|
| 16 |
---
|
| 17 |
|
| 18 |
+
# Potato — Codebook Annotation
|
| 19 |
+
|
| 20 |
+
Shared evolving codebook across annotators.
|
| 21 |
+
|
| 22 |
+
A live demo of **[Potato](https://www.potatoannotator.com)** — the free, self-hosted annotation platform for NLP,
|
| 23 |
+
agentic, and GenAI research, configured entirely through YAML.
|
| 24 |
+
Visit **[www.potatoannotator.com](https://www.potatoannotator.com)** for docs, the schema gallery, and more demos.
|
| 25 |
+
|
| 26 |
+
## Try it out
|
| 27 |
+
|
| 28 |
+
1. Enter any username to log in (no password required).
|
| 29 |
+
2. Read the item shown in the main panel.
|
| 30 |
+
3. Annotate using the schemes on the right.
|
| 31 |
+
4. Click **Next** to continue.
|
| 32 |
+
|
| 33 |
+
> **Run your own copy:** click the **⋮ → Duplicate this Space** button (top-right) to launch
|
| 34 |
+
> this exact demo in your own account on free hardware — change the data and config to make it yours.
|
| 35 |
+
|
| 36 |
+
> Annotations in this demo are ephemeral. To collect and keep data, deploy your own
|
| 37 |
+
> Space — see the [deployment guide](https://github.com/davidjurgens/potato/blob/master/deployment/huggingface-spaces/deploy.md).
|
| 38 |
+
|
| 39 |
+
## About Potato
|
| 40 |
+
|
| 41 |
+
Potato supports 20+ annotation types — text, spans, images, audio, video, documents,
|
| 42 |
+
and agent traces — with AI-assisted labeling, quality control, and adjudication.
|
| 43 |
+
|
| 44 |
+
🥔 **Website: [www.potatoannotator.com](https://www.potatoannotator.com)** ·
|
| 45 |
+
[Documentation](https://www.potatoannotator.com) ·
|
| 46 |
+
[GitHub](https://github.com/davidjurgens/potato) ·
|
| 47 |
+
[All demos](https://github.com/davidjurgens/potato/blob/master/docs/data-export/potato_on_huggingface.md)
|
annotation_output/.gitkeep
ADDED
|
File without changes
|
config.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
port: 8000
|
| 2 |
+
annotation_task_name: Codebook Example
|
| 3 |
+
task_dir: .
|
| 4 |
+
output_annotation_dir: annotation_output/codebook-example/
|
| 5 |
+
output_annotation_format: tsv
|
| 6 |
+
data_files:
|
| 7 |
+
- data/codebook-example.json
|
| 8 |
+
item_properties:
|
| 9 |
+
id_key: id
|
| 10 |
+
text_key: text
|
| 11 |
+
user_config:
|
| 12 |
+
allow_all_users: true
|
| 13 |
+
users: []
|
| 14 |
+
alert_time_each_instance: 10000000
|
| 15 |
+
codebook_mode: open
|
| 16 |
+
annotation_schemes:
|
| 17 |
+
- annotation_type: multiselect
|
| 18 |
+
name: themes
|
| 19 |
+
description: Which qualitative themes appear in this excerpt?
|
| 20 |
+
codebook: true
|
| 21 |
+
labels:
|
| 22 |
+
- access barriers
|
| 23 |
+
- cost concerns
|
| 24 |
+
- provider trust
|
| 25 |
+
- wait times
|
| 26 |
+
site_dir: default
|
| 27 |
+
require_no_password: true
|
data/codebook-example.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{"id": "1", "text": "I couldn't get an appointment for three weeks, and even then the clinic was far from a bus line."},
|
| 3 |
+
{"id": "2", "text": "The doctor explained everything clearly and I felt listened to for the first time."},
|
| 4 |
+
{"id": "3", "text": "My copay went up again this year, so I've been skipping the follow-up visits."},
|
| 5 |
+
{"id": "4", "text": "Front desk staff were dismissive, and I waited over an hour past my scheduled time."},
|
| 6 |
+
{"id": "5", "text": "Telehealth made it much easier, but I still worry whether they have my full history."}
|
| 7 |
+
]
|
entrypoint.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
# Configuration
|
| 5 |
+
CONFIG_FILE="${POTATO_CONFIG:-config.yaml}"
|
| 6 |
+
PORT="${PORT:-7860}"
|
| 7 |
+
WORKERS="${GUNICORN_WORKERS:-2}"
|
| 8 |
+
THREADS="${GUNICORN_THREADS:-4}"
|
| 9 |
+
TIMEOUT="${GUNICORN_TIMEOUT:-120}"
|
| 10 |
+
|
| 11 |
+
echo "Starting Potato Demo Space..."
|
| 12 |
+
echo " Config: ${CONFIG_FILE}"
|
| 13 |
+
echo " Port: ${PORT}"
|
| 14 |
+
echo " Workers: ${WORKERS}"
|
| 15 |
+
|
| 16 |
+
# Validate config exists
|
| 17 |
+
if [ ! -f "${CONFIG_FILE}" ]; then
|
| 18 |
+
echo "ERROR: Config file not found: ${CONFIG_FILE}"
|
| 19 |
+
exit 1
|
| 20 |
+
fi
|
| 21 |
+
|
| 22 |
+
# Start with gunicorn using the factory pattern
|
| 23 |
+
exec gunicorn \
|
| 24 |
+
--bind "0.0.0.0:${PORT}" \
|
| 25 |
+
--workers "${WORKERS}" \
|
| 26 |
+
--threads "${THREADS}" \
|
| 27 |
+
--timeout "${TIMEOUT}" \
|
| 28 |
+
--access-logfile - \
|
| 29 |
+
--error-logfile - \
|
| 30 |
+
"potato.flask_server:create_app('${CONFIG_FILE}')"
|
layouts/task_layout.html
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- CONFIG_HASH: 80d495277c377949581be1771e4d8c3e_c4da35694b3e1122725c151f4a96e755 -->
|
| 2 |
+
<!-- Generated annotation layout file -->
|
| 3 |
+
<!-- This file was automatically generated based on the annotation schemes in your config -->
|
| 4 |
+
<!-- You can customize this file to modify the layout of your annotation interface -->
|
| 5 |
+
<!-- Changes to this file will be preserved across server restarts -->
|
| 6 |
+
|
| 7 |
+
<div class="annotation_schema">
|
| 8 |
+
|
| 9 |
+
<form id="themes" class="annotation-form multiselect shadcn-multiselect-container" action="javascript:void(0)" data-annotation-id="0" data-annotation-type="multiselect" data-schema-name="themes" data-grid-columns="1">
|
| 10 |
+
|
| 11 |
+
<fieldset schema="themes">
|
| 12 |
+
<legend class="shadcn-multiselect-title">Which qualitative themes appear in this excerpt?</legend>
|
| 13 |
+
<div class="shadcn-multiselect-grid" style="grid-template-columns: repeat(1, 1fr);">
|
| 14 |
+
<div class="shadcn-multiselect-item">
|
| 15 |
+
<input class="themes shadcn-multiselect-checkbox annotation-input"
|
| 16 |
+
type="checkbox"
|
| 17 |
+
id="themes_access barriers_checkbox"
|
| 18 |
+
name="themes:::access barriers"
|
| 19 |
+
value="access barriers"
|
| 20 |
+
label_name="access barriers"
|
| 21 |
+
schema="themes"
|
| 22 |
+
onclick="whetherNone(this);registerAnnotation(this)"
|
| 23 |
+
validation="">
|
| 24 |
+
<label for="themes_access barriers_checkbox" schema="themes" class="shadcn-multiselect-label">
|
| 25 |
+
Access Barriers
|
| 26 |
+
</label>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
<div class="shadcn-multiselect-item">
|
| 30 |
+
<input class="themes shadcn-multiselect-checkbox annotation-input"
|
| 31 |
+
type="checkbox"
|
| 32 |
+
id="themes_cost concerns_checkbox"
|
| 33 |
+
name="themes:::cost concerns"
|
| 34 |
+
value="cost concerns"
|
| 35 |
+
label_name="cost concerns"
|
| 36 |
+
schema="themes"
|
| 37 |
+
onclick="whetherNone(this);registerAnnotation(this)"
|
| 38 |
+
validation="">
|
| 39 |
+
<label for="themes_cost concerns_checkbox" schema="themes" class="shadcn-multiselect-label">
|
| 40 |
+
Cost Concerns
|
| 41 |
+
</label>
|
| 42 |
+
</div>
|
| 43 |
+
|
| 44 |
+
<div class="shadcn-multiselect-item">
|
| 45 |
+
<input class="themes shadcn-multiselect-checkbox annotation-input"
|
| 46 |
+
type="checkbox"
|
| 47 |
+
id="themes_provider trust_checkbox"
|
| 48 |
+
name="themes:::provider trust"
|
| 49 |
+
value="provider trust"
|
| 50 |
+
label_name="provider trust"
|
| 51 |
+
schema="themes"
|
| 52 |
+
onclick="whetherNone(this);registerAnnotation(this)"
|
| 53 |
+
validation="">
|
| 54 |
+
<label for="themes_provider trust_checkbox" schema="themes" class="shadcn-multiselect-label">
|
| 55 |
+
Provider Trust
|
| 56 |
+
</label>
|
| 57 |
+
</div>
|
| 58 |
+
|
| 59 |
+
<div class="shadcn-multiselect-item">
|
| 60 |
+
<input class="themes shadcn-multiselect-checkbox annotation-input"
|
| 61 |
+
type="checkbox"
|
| 62 |
+
id="themes_wait times_checkbox"
|
| 63 |
+
name="themes:::wait times"
|
| 64 |
+
value="wait times"
|
| 65 |
+
label_name="wait times"
|
| 66 |
+
schema="themes"
|
| 67 |
+
onclick="whetherNone(this);registerAnnotation(this)"
|
| 68 |
+
validation="">
|
| 69 |
+
<label for="themes_wait times_checkbox" schema="themes" class="shadcn-multiselect-label">
|
| 70 |
+
Wait Times
|
| 71 |
+
</label>
|
| 72 |
+
</div>
|
| 73 |
+
|
| 74 |
+
<div class="shadcn-multiselect-item">
|
| 75 |
+
<input class="themes shadcn-multiselect-checkbox annotation-input"
|
| 76 |
+
type="checkbox"
|
| 77 |
+
id="themes_InspectCode_checkbox"
|
| 78 |
+
name="themes:::InspectCode"
|
| 79 |
+
value="InspectCode"
|
| 80 |
+
label_name="InspectCode"
|
| 81 |
+
schema="themes"
|
| 82 |
+
onclick="whetherNone(this);registerAnnotation(this)"
|
| 83 |
+
validation="">
|
| 84 |
+
<label for="themes_InspectCode_checkbox" schema="themes" class="shadcn-multiselect-label">
|
| 85 |
+
InspectCode
|
| 86 |
+
</label>
|
| 87 |
+
</div>
|
| 88 |
+
</div></fieldset></form>
|
| 89 |
+
|
| 90 |
+
</div>
|
potato/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Potato Annotation Platform
|
| 3 |
+
|
| 4 |
+
A flexible, web-based platform for text annotation tasks.
|
| 5 |
+
|
| 6 |
+
This package provides a comprehensive annotation system with the following features:
|
| 7 |
+
- Multi-phase annotation workflows (consent, instructions, training, annotation, post-study)
|
| 8 |
+
- Support for various annotation types (labels, spans, text, likert scales, best-worst scaling)
|
| 9 |
+
- User authentication and session management
|
| 10 |
+
- Active learning capabilities
|
| 11 |
+
- Admin dashboard for monitoring progress
|
| 12 |
+
- Configurable assignment strategies
|
| 13 |
+
- Multi-language and multi-task support
|
| 14 |
+
|
| 15 |
+
Main Components:
|
| 16 |
+
- flask_server: Core Flask application and server logic
|
| 17 |
+
- routes: HTTP route handlers and request processing
|
| 18 |
+
- user_state_management: User progress tracking and state persistence
|
| 19 |
+
- item_state_management: Data item management and assignment
|
| 20 |
+
- authentificaton: User authentication backends
|
| 21 |
+
- admin: Admin dashboard functionality
|
| 22 |
+
- activelearning: Active learning algorithms and model training
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
from potato.flask_server import create_app
|
| 26 |
+
app = create_app()
|
| 27 |
+
app.run()
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from .flask_server import create_app
|
| 31 |
+
|
| 32 |
+
__version__ = "2.6.0"
|
| 33 |
+
__author__ = "Potato Annotation Platform Team"
|
| 34 |
+
__description__ = "A flexible, web-based platform for text annotation tasks"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def __getattr__(name):
|
| 38 |
+
"""Lazy imports for optional heavy dependencies."""
|
| 39 |
+
if name == "load_as_dataset":
|
| 40 |
+
from .datasets_integration import load_as_dataset
|
| 41 |
+
return load_as_dataset
|
| 42 |
+
if name == "load_annotations":
|
| 43 |
+
from .datasets_integration import load_annotations
|
| 44 |
+
return load_annotations
|
| 45 |
+
raise AttributeError(f"module 'potato' has no attribute {name!r}")
|
potato/__main__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from potato.flask_server import main
|
| 2 |
+
|
| 3 |
+
main()
|
potato/active_learning_manager.py
ADDED
|
@@ -0,0 +1,1623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced Active Learning Manager with Database Persistence
|
| 3 |
+
|
| 4 |
+
This module provides a comprehensive active learning system with optional
|
| 5 |
+
database persistence, model saving, LLM integration, and multiple query
|
| 6 |
+
strategies including uncertainty sampling, diversity sampling, BADGE, BALD,
|
| 7 |
+
and hybrid combinations.
|
| 8 |
+
|
| 9 |
+
References:
|
| 10 |
+
[1] Ash et al. (2020) "Deep Batch Active Learning by Diverse, Uncertain
|
| 11 |
+
Gradient Lower Bounds" (BADGE). ICLR 2020.
|
| 12 |
+
[2] Houlsby et al. (2011) "Bayesian Active Learning for Classification
|
| 13 |
+
and Preference Learning" (BALD).
|
| 14 |
+
[3] Bayer et al. (2024) "ActiveLLM: Large Language Model-Based Active
|
| 15 |
+
Learning for Textual Few-Shot Scenarios". TACL.
|
| 16 |
+
[4] Yuan et al. (2024) "Hide and Seek in Noise Labels: Noise-Robust
|
| 17 |
+
Collaborative Active Learning" (NoiseAL). ACL 2024.
|
| 18 |
+
[5] Mavromatis et al. (2024) "CoverICL: Selective Annotation for
|
| 19 |
+
In-Context Learning via Active Graph Coverage". EMNLP 2024.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import threading
|
| 23 |
+
import logging
|
| 24 |
+
import time
|
| 25 |
+
import os
|
| 26 |
+
import pickle
|
| 27 |
+
import json
|
| 28 |
+
from typing import Dict, List, Optional, Tuple, Any, Union
|
| 29 |
+
from collections import defaultdict, Counter
|
| 30 |
+
import dataclasses
|
| 31 |
+
from dataclasses import dataclass, field, asdict
|
| 32 |
+
from enum import Enum
|
| 33 |
+
import random
|
| 34 |
+
import queue
|
| 35 |
+
from datetime import datetime
|
| 36 |
+
from abc import ABC, abstractmethod
|
| 37 |
+
|
| 38 |
+
from sklearn.pipeline import Pipeline
|
| 39 |
+
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
|
| 40 |
+
from sklearn.linear_model import LogisticRegression
|
| 41 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 42 |
+
from sklearn.svm import SVC
|
| 43 |
+
from sklearn.metrics import accuracy_score, classification_report
|
| 44 |
+
import numpy as np
|
| 45 |
+
|
| 46 |
+
from potato.item_state_management import ItemStateManager, get_item_state_manager
|
| 47 |
+
from potato.user_state_management import get_user_state_manager
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ResolutionStrategy(Enum):
|
| 54 |
+
"""Strategies for resolving multiple annotations per instance."""
|
| 55 |
+
MAJORITY_VOTE = "majority_vote"
|
| 56 |
+
RANDOM = "random"
|
| 57 |
+
CONSENSUS = "consensus"
|
| 58 |
+
WEIGHTED_AVERAGE = "weighted_average"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# SentenceTransformerVectorizer
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
class SentenceTransformerVectorizer:
|
| 66 |
+
"""sklearn-compatible wrapper for sentence-transformers.
|
| 67 |
+
|
| 68 |
+
Uses dense embeddings from pre-trained transformer models instead of
|
| 69 |
+
bag-of-words features. Produces 384-dim vectors (for default model)
|
| 70 |
+
that capture semantic meaning, enabling better classification with
|
| 71 |
+
fewer training examples.
|
| 72 |
+
|
| 73 |
+
The ``sentence-transformers`` package is an **optional** dependency and
|
| 74 |
+
is only imported when this vectorizer is actually used.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
|
| 78 |
+
self.model_name = model_name
|
| 79 |
+
self._model = None
|
| 80 |
+
|
| 81 |
+
def fit(self, X, y=None):
|
| 82 |
+
from sentence_transformers import SentenceTransformer
|
| 83 |
+
self._model = SentenceTransformer(self.model_name)
|
| 84 |
+
return self
|
| 85 |
+
|
| 86 |
+
def transform(self, X):
|
| 87 |
+
if self._model is None:
|
| 88 |
+
raise RuntimeError("SentenceTransformerVectorizer has not been fitted yet")
|
| 89 |
+
return self._model.encode(list(X), show_progress_bar=False)
|
| 90 |
+
|
| 91 |
+
def fit_transform(self, X, y=None):
|
| 92 |
+
self.fit(X, y)
|
| 93 |
+
return self.transform(X)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
# Query Strategies
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
|
| 100 |
+
class QueryStrategy(ABC):
|
| 101 |
+
"""Base class for active learning query strategies."""
|
| 102 |
+
|
| 103 |
+
@abstractmethod
|
| 104 |
+
def rank(self, texts: List[str], model, vectorizer,
|
| 105 |
+
annotated_texts: Optional[List[str]] = None) -> List[Tuple[int, float]]:
|
| 106 |
+
"""Return list of (index, score) sorted by selection priority (highest first)."""
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class UncertaintySampling(QueryStrategy):
|
| 110 |
+
"""Select instances where classifier is least confident.
|
| 111 |
+
|
| 112 |
+
Selects x* = argmax_x (1 - max_y P(y|x)), i.e., instances where the
|
| 113 |
+
model's best guess has lowest confidence.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
def rank(self, texts, model, vectorizer, annotated_texts=None):
|
| 117 |
+
try:
|
| 118 |
+
features = vectorizer.transform(texts)
|
| 119 |
+
probas = model.predict_proba(features)
|
| 120 |
+
# Score = 1 - max_prob (higher = more uncertain = higher priority)
|
| 121 |
+
scores = 1.0 - np.max(probas, axis=1)
|
| 122 |
+
ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
|
| 123 |
+
return ranked
|
| 124 |
+
except Exception as e:
|
| 125 |
+
logger.warning(f"UncertaintySampling failed: {e}")
|
| 126 |
+
return [(i, 0.5) for i in range(len(texts))]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class DiversitySampling(QueryStrategy):
|
| 130 |
+
"""Select instances that maximize feature-space coverage.
|
| 131 |
+
|
| 132 |
+
Uses cosine distance from already-annotated instances in the vectorized
|
| 133 |
+
feature space. Ensures the training set covers the full data distribution
|
| 134 |
+
rather than over-sampling one region.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def rank(self, texts, model, vectorizer, annotated_texts=None):
|
| 138 |
+
from sklearn.metrics.pairwise import cosine_distances
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
features = vectorizer.transform(texts)
|
| 142 |
+
if hasattr(features, 'toarray'):
|
| 143 |
+
features = features.toarray()
|
| 144 |
+
|
| 145 |
+
if annotated_texts:
|
| 146 |
+
annotated_features = vectorizer.transform(annotated_texts)
|
| 147 |
+
if hasattr(annotated_features, 'toarray'):
|
| 148 |
+
annotated_features = annotated_features.toarray()
|
| 149 |
+
# Score = min cosine distance to any annotated instance
|
| 150 |
+
distances = cosine_distances(features, annotated_features)
|
| 151 |
+
scores = np.min(distances, axis=1)
|
| 152 |
+
else:
|
| 153 |
+
# No annotated texts yet: use distance from centroid
|
| 154 |
+
centroid = np.mean(features, axis=0, keepdims=True)
|
| 155 |
+
scores = cosine_distances(features, centroid).ravel()
|
| 156 |
+
|
| 157 |
+
ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
|
| 158 |
+
return ranked
|
| 159 |
+
except Exception as e:
|
| 160 |
+
logger.warning(f"DiversitySampling failed: {e}")
|
| 161 |
+
return [(i, 0.5) for i in range(len(texts))]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class BadgeStrategy(QueryStrategy):
|
| 165 |
+
"""BADGE approximation: uncertainty-weighted diversity.
|
| 166 |
+
|
| 167 |
+
Inspired by Ash et al. (2020) [Ref 1]. Full BADGE uses gradient embeddings
|
| 168 |
+
from neural networks. Our approximation:
|
| 169 |
+
1. Weight feature vectors by (1 - max_prob) as uncertainty proxy
|
| 170 |
+
2. Run k-means++ initialization on weighted vectors to select
|
| 171 |
+
diverse-uncertain instances.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def rank(self, texts, model, vectorizer, annotated_texts=None):
|
| 175 |
+
try:
|
| 176 |
+
features = vectorizer.transform(texts)
|
| 177 |
+
if hasattr(features, 'toarray'):
|
| 178 |
+
features = features.toarray()
|
| 179 |
+
|
| 180 |
+
probas = model.predict_proba(features)
|
| 181 |
+
uncertainty = 1.0 - np.max(probas, axis=1)
|
| 182 |
+
|
| 183 |
+
# Weight features by uncertainty
|
| 184 |
+
weighted = features * uncertainty[:, np.newaxis]
|
| 185 |
+
|
| 186 |
+
# Use k-means++ initialization to select diverse-uncertain points
|
| 187 |
+
from sklearn.cluster import kmeans_plusplus
|
| 188 |
+
n_clusters = min(len(texts), max(1, len(texts) // 2))
|
| 189 |
+
_, indices = kmeans_plusplus(weighted, n_clusters=n_clusters,
|
| 190 |
+
random_state=42)
|
| 191 |
+
|
| 192 |
+
# Build score: selected centroids get highest scores
|
| 193 |
+
scores = np.zeros(len(texts))
|
| 194 |
+
for rank_pos, idx in enumerate(indices):
|
| 195 |
+
scores[idx] = len(indices) - rank_pos # highest for first-selected
|
| 196 |
+
|
| 197 |
+
# For non-selected, use uncertainty as tiebreaker
|
| 198 |
+
for i in range(len(texts)):
|
| 199 |
+
if scores[i] == 0:
|
| 200 |
+
scores[i] = uncertainty[i] * 0.01
|
| 201 |
+
|
| 202 |
+
ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
|
| 203 |
+
return ranked
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.warning(f"BadgeStrategy failed, falling back to uncertainty: {e}")
|
| 206 |
+
return UncertaintySampling().rank(texts, model, vectorizer, annotated_texts)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class BaldStrategy(QueryStrategy):
|
| 210 |
+
"""BALD: Bayesian Active Learning by Disagreement.
|
| 211 |
+
|
| 212 |
+
Based on Houlsby et al. (2011) [Ref 2]. Trains an ensemble of classifiers
|
| 213 |
+
with different random seeds/bootstrap samples. Selects instances with
|
| 214 |
+
highest mutual information: H[y|x] - E_theta[H[y|x,theta]], i.e.,
|
| 215 |
+
where the ensemble disagrees most.
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
def __init__(self, n_estimators: int = 5, bootstrap_fraction: float = 0.8):
|
| 219 |
+
self.n_estimators = n_estimators
|
| 220 |
+
self.bootstrap_fraction = bootstrap_fraction
|
| 221 |
+
|
| 222 |
+
def rank(self, texts, model, vectorizer, annotated_texts=None):
|
| 223 |
+
try:
|
| 224 |
+
features = vectorizer.transform(texts)
|
| 225 |
+
if hasattr(features, 'toarray'):
|
| 226 |
+
features = features.toarray()
|
| 227 |
+
|
| 228 |
+
probas = model.predict_proba(features)
|
| 229 |
+
# Average entropy
|
| 230 |
+
avg_proba = probas
|
| 231 |
+
entropy_avg = -np.sum(avg_proba * np.log(avg_proba + 1e-10), axis=1)
|
| 232 |
+
|
| 233 |
+
# For a single model, we approximate BALD by using dropout-like noise
|
| 234 |
+
# or by comparing with uniform. Since we store the ensemble models
|
| 235 |
+
# on the manager, we just use the single model's entropy here and
|
| 236 |
+
# the ensemble version is handled in ActiveLearningManager._train_bald_ensemble
|
| 237 |
+
scores = entropy_avg
|
| 238 |
+
ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
|
| 239 |
+
return ranked
|
| 240 |
+
except Exception as e:
|
| 241 |
+
logger.warning(f"BaldStrategy failed: {e}")
|
| 242 |
+
return [(i, 0.5) for i in range(len(texts))]
|
| 243 |
+
|
| 244 |
+
def rank_with_ensemble(self, texts, ensemble_models, vectorizer):
|
| 245 |
+
"""Rank using actual ensemble disagreement (mutual information)."""
|
| 246 |
+
try:
|
| 247 |
+
features = vectorizer.transform(texts)
|
| 248 |
+
if hasattr(features, 'toarray'):
|
| 249 |
+
features = features.toarray()
|
| 250 |
+
|
| 251 |
+
all_probas = []
|
| 252 |
+
for m in ensemble_models:
|
| 253 |
+
all_probas.append(m.predict_proba(features))
|
| 254 |
+
|
| 255 |
+
all_probas = np.array(all_probas) # (n_estimators, n_samples, n_classes)
|
| 256 |
+
|
| 257 |
+
# Mean prediction across ensemble
|
| 258 |
+
mean_proba = np.mean(all_probas, axis=0) # (n_samples, n_classes)
|
| 259 |
+
|
| 260 |
+
# H[y|x] - entropy of mean prediction
|
| 261 |
+
entropy_mean = -np.sum(mean_proba * np.log(mean_proba + 1e-10), axis=1)
|
| 262 |
+
|
| 263 |
+
# E_theta[H[y|x,theta]] - mean of individual entropies
|
| 264 |
+
individual_entropies = -np.sum(all_probas * np.log(all_probas + 1e-10), axis=2)
|
| 265 |
+
mean_entropy = np.mean(individual_entropies, axis=0)
|
| 266 |
+
|
| 267 |
+
# Mutual information = H[y|x] - E[H[y|x,theta]]
|
| 268 |
+
mutual_info = entropy_mean - mean_entropy
|
| 269 |
+
|
| 270 |
+
ranked = sorted(enumerate(mutual_info), key=lambda x: x[1], reverse=True)
|
| 271 |
+
return ranked
|
| 272 |
+
except Exception as e:
|
| 273 |
+
logger.warning(f"BaldStrategy ensemble ranking failed: {e}")
|
| 274 |
+
return [(i, 0.5) for i in range(len(texts))]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class HybridStrategy(QueryStrategy):
|
| 278 |
+
"""Weighted combination of uncertainty and diversity scores.
|
| 279 |
+
|
| 280 |
+
Combines strategies with configurable weights. Default: 0.7 uncertainty +
|
| 281 |
+
0.3 diversity.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
def __init__(self, weights: Optional[Dict[str, float]] = None):
|
| 285 |
+
self.weights = weights or {"uncertainty": 0.7, "diversity": 0.3}
|
| 286 |
+
|
| 287 |
+
def rank(self, texts, model, vectorizer, annotated_texts=None):
|
| 288 |
+
try:
|
| 289 |
+
strategies = {}
|
| 290 |
+
if self.weights.get("uncertainty", 0) > 0:
|
| 291 |
+
strategies["uncertainty"] = UncertaintySampling()
|
| 292 |
+
if self.weights.get("diversity", 0) > 0:
|
| 293 |
+
strategies["diversity"] = DiversitySampling()
|
| 294 |
+
|
| 295 |
+
# Collect raw scores from each strategy
|
| 296 |
+
all_scores = {}
|
| 297 |
+
for name, strategy in strategies.items():
|
| 298 |
+
rankings = strategy.rank(texts, model, vectorizer, annotated_texts)
|
| 299 |
+
score_map = {idx: score for idx, score in rankings}
|
| 300 |
+
all_scores[name] = score_map
|
| 301 |
+
|
| 302 |
+
# Normalize each strategy's scores to [0, 1]
|
| 303 |
+
for name in all_scores:
|
| 304 |
+
vals = list(all_scores[name].values())
|
| 305 |
+
min_val, max_val = min(vals), max(vals)
|
| 306 |
+
rng = max_val - min_val if max_val > min_val else 1.0
|
| 307 |
+
all_scores[name] = {
|
| 308 |
+
idx: (s - min_val) / rng for idx, s in all_scores[name].items()
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
# Weighted combination
|
| 312 |
+
combined = {}
|
| 313 |
+
for i in range(len(texts)):
|
| 314 |
+
combined[i] = sum(
|
| 315 |
+
self.weights.get(name, 0) * all_scores.get(name, {}).get(i, 0)
|
| 316 |
+
for name in self.weights
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
ranked = sorted(combined.items(), key=lambda x: x[1], reverse=True)
|
| 320 |
+
return ranked
|
| 321 |
+
except Exception as e:
|
| 322 |
+
logger.warning(f"HybridStrategy failed: {e}")
|
| 323 |
+
return UncertaintySampling().rank(texts, model, vectorizer, annotated_texts)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# Strategy registry
|
| 327 |
+
STRATEGY_REGISTRY = {
|
| 328 |
+
"uncertainty": UncertaintySampling,
|
| 329 |
+
"diversity": DiversitySampling,
|
| 330 |
+
"badge": BadgeStrategy,
|
| 331 |
+
"bald": BaldStrategy,
|
| 332 |
+
"hybrid": HybridStrategy,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def create_query_strategy(config: 'ActiveLearningConfig') -> QueryStrategy:
|
| 337 |
+
"""Create a query strategy from config."""
|
| 338 |
+
strategy_name = config.query_strategy
|
| 339 |
+
if strategy_name == "hybrid":
|
| 340 |
+
return HybridStrategy(weights=config.hybrid_weights)
|
| 341 |
+
elif strategy_name == "bald":
|
| 342 |
+
params = config.bald_params
|
| 343 |
+
return BaldStrategy(
|
| 344 |
+
n_estimators=params.get("n_estimators", 5),
|
| 345 |
+
bootstrap_fraction=params.get("bootstrap_fraction", 0.8),
|
| 346 |
+
)
|
| 347 |
+
elif strategy_name in STRATEGY_REGISTRY:
|
| 348 |
+
return STRATEGY_REGISTRY[strategy_name]()
|
| 349 |
+
else:
|
| 350 |
+
logger.warning(f"Unknown strategy '{strategy_name}', falling back to uncertainty")
|
| 351 |
+
return UncertaintySampling()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# ---------------------------------------------------------------------------
|
| 355 |
+
# ICLClassifier wrapper (Phase 5A)
|
| 356 |
+
# ---------------------------------------------------------------------------
|
| 357 |
+
|
| 358 |
+
class ICLClassifier:
|
| 359 |
+
"""Wraps ICLLabeler as an sklearn-compatible classifier for ensemble use.
|
| 360 |
+
|
| 361 |
+
Enables combining LLM-based ICL predictions with traditional classifier
|
| 362 |
+
predictions in a hybrid ensemble for active learning scoring.
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
def __init__(self, icl_labeler, schema_name: str, label_names: List[str]):
|
| 366 |
+
self.icl_labeler = icl_labeler
|
| 367 |
+
self.schema_name = schema_name
|
| 368 |
+
self.label_names = label_names
|
| 369 |
+
self.classes_ = np.array(label_names)
|
| 370 |
+
|
| 371 |
+
def predict_proba(self, texts: List[str]) -> np.ndarray:
|
| 372 |
+
"""Get label probabilities from LLM via ICL."""
|
| 373 |
+
n_classes = len(self.label_names)
|
| 374 |
+
probas = np.full((len(texts), n_classes), 1.0 / n_classes)
|
| 375 |
+
|
| 376 |
+
for i, text in enumerate(texts):
|
| 377 |
+
try:
|
| 378 |
+
prediction = self.icl_labeler.label_instance(
|
| 379 |
+
instance_id=f"_al_query_{i}",
|
| 380 |
+
schema_name=self.schema_name,
|
| 381 |
+
instance_text=text,
|
| 382 |
+
)
|
| 383 |
+
if prediction and prediction.predicted_label in self.label_names:
|
| 384 |
+
idx = self.label_names.index(prediction.predicted_label)
|
| 385 |
+
conf = prediction.confidence_score
|
| 386 |
+
# Distribute: conf to predicted label, (1-conf)/(n-1) to others
|
| 387 |
+
remaining = (1.0 - conf) / max(1, n_classes - 1)
|
| 388 |
+
probas[i] = remaining
|
| 389 |
+
probas[i, idx] = conf
|
| 390 |
+
except Exception:
|
| 391 |
+
pass # Keep uniform distribution
|
| 392 |
+
|
| 393 |
+
return probas
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# ---------------------------------------------------------------------------
|
| 397 |
+
# Configuration
|
| 398 |
+
# ---------------------------------------------------------------------------
|
| 399 |
+
|
| 400 |
+
@dataclass
|
| 401 |
+
class ActiveLearningConfig:
|
| 402 |
+
"""Enhanced configuration for active learning."""
|
| 403 |
+
enabled: bool = False
|
| 404 |
+
classifier_name: str = "sklearn.linear_model.LogisticRegression"
|
| 405 |
+
classifier_kwargs: Dict[str, Any] = None
|
| 406 |
+
vectorizer_name: str = "sklearn.feature_extraction.text.TfidfVectorizer"
|
| 407 |
+
vectorizer_kwargs: Dict[str, Any] = None
|
| 408 |
+
min_annotations_per_instance: int = 1
|
| 409 |
+
min_instances_for_training: int = 10
|
| 410 |
+
max_instances_to_reorder: Optional[int] = None
|
| 411 |
+
resolution_strategy: ResolutionStrategy = ResolutionStrategy.MAJORITY_VOTE
|
| 412 |
+
random_sample_percent: float = 0.2
|
| 413 |
+
update_frequency: int = 5
|
| 414 |
+
schema_names: List[str] = None
|
| 415 |
+
|
| 416 |
+
# Classifier/vectorizer passthrough params (Phase 1C)
|
| 417 |
+
classifier_params: Dict[str, Any] = field(default_factory=dict)
|
| 418 |
+
vectorizer_params: Dict[str, Any] = field(default_factory=dict)
|
| 419 |
+
|
| 420 |
+
# Probability calibration (Phase 1D)
|
| 421 |
+
calibrate_probabilities: bool = True
|
| 422 |
+
|
| 423 |
+
# Query strategy (Phase 2)
|
| 424 |
+
query_strategy: str = "uncertainty"
|
| 425 |
+
hybrid_weights: Dict[str, float] = field(
|
| 426 |
+
default_factory=lambda: {"uncertainty": 0.7, "diversity": 0.3}
|
| 427 |
+
)
|
| 428 |
+
bald_params: Dict[str, Any] = field(
|
| 429 |
+
default_factory=lambda: {"n_estimators": 5, "bootstrap_fraction": 0.8}
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
# Cold-start (Phase 3)
|
| 433 |
+
cold_start_strategy: str = "random"
|
| 434 |
+
cold_start_batch_size: int = 20
|
| 435 |
+
|
| 436 |
+
# ICL ensemble (Phase 5)
|
| 437 |
+
use_icl_ensemble: bool = False
|
| 438 |
+
icl_ensemble_params: Dict[str, Any] = field(default_factory=lambda: {
|
| 439 |
+
"initial_icl_weight": 0.7,
|
| 440 |
+
"final_icl_weight": 0.2,
|
| 441 |
+
"transition_instances": 100,
|
| 442 |
+
})
|
| 443 |
+
|
| 444 |
+
# Annotation routing (Phase 5D)
|
| 445 |
+
annotation_routing: bool = False
|
| 446 |
+
routing_thresholds: Dict[str, float] = field(default_factory=lambda: {
|
| 447 |
+
"auto_label_min_confidence": 0.9,
|
| 448 |
+
"show_suggestion_below": 0.5,
|
| 449 |
+
})
|
| 450 |
+
verification_sample_rate: float = 0.2
|
| 451 |
+
|
| 452 |
+
# Database persistence
|
| 453 |
+
database_enabled: bool = False
|
| 454 |
+
database_config: Dict[str, Any] = None
|
| 455 |
+
|
| 456 |
+
# Model persistence
|
| 457 |
+
model_persistence_enabled: bool = False
|
| 458 |
+
model_save_directory: Optional[str] = None
|
| 459 |
+
model_retention_count: int = 2
|
| 460 |
+
|
| 461 |
+
# LLM integration
|
| 462 |
+
llm_enabled: bool = False
|
| 463 |
+
llm_config: Dict[str, Any] = None
|
| 464 |
+
|
| 465 |
+
def __post_init__(self):
|
| 466 |
+
if self.classifier_kwargs is None:
|
| 467 |
+
self.classifier_kwargs = {}
|
| 468 |
+
if self.vectorizer_kwargs is None:
|
| 469 |
+
self.vectorizer_kwargs = {}
|
| 470 |
+
if self.schema_names is None:
|
| 471 |
+
self.schema_names = []
|
| 472 |
+
if self.database_config is None:
|
| 473 |
+
self.database_config = {}
|
| 474 |
+
if self.llm_config is None:
|
| 475 |
+
self.llm_config = {}
|
| 476 |
+
# Merge classifier_params into classifier_kwargs
|
| 477 |
+
if self.classifier_params:
|
| 478 |
+
self.classifier_kwargs.update(self.classifier_params)
|
| 479 |
+
# Merge vectorizer_params into vectorizer_kwargs
|
| 480 |
+
if self.vectorizer_params:
|
| 481 |
+
self.vectorizer_kwargs.update(self.vectorizer_params)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
@dataclass
|
| 485 |
+
class TrainingMetrics:
|
| 486 |
+
"""Metrics for a training run."""
|
| 487 |
+
schema_name: str
|
| 488 |
+
training_time: float
|
| 489 |
+
accuracy: float
|
| 490 |
+
instance_count: int
|
| 491 |
+
timestamp: datetime
|
| 492 |
+
model_file_path: Optional[str] = None
|
| 493 |
+
confidence_distribution: Dict[str, float] = None
|
| 494 |
+
error_message: Optional[str] = None
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class ModelPersistence:
|
| 498 |
+
"""Handles model saving and loading with metadata."""
|
| 499 |
+
|
| 500 |
+
def __init__(self, save_directory: str, retention_count: int = 2):
|
| 501 |
+
self.save_directory = save_directory
|
| 502 |
+
self.retention_count = retention_count
|
| 503 |
+
self.logger = logging.getLogger(__name__)
|
| 504 |
+
|
| 505 |
+
# Ensure directory exists
|
| 506 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 507 |
+
|
| 508 |
+
def save_model(self, model: Pipeline, schema_name: str, instance_count: int) -> str:
|
| 509 |
+
"""Save a trained model with metadata."""
|
| 510 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 511 |
+
filename = f"{schema_name}_{instance_count}_{timestamp}.pkl"
|
| 512 |
+
filepath = os.path.join(self.save_directory, filename)
|
| 513 |
+
|
| 514 |
+
try:
|
| 515 |
+
# Save the complete model (including vectorizer)
|
| 516 |
+
with open(filepath, 'wb') as f:
|
| 517 |
+
pickle.dump(model, f)
|
| 518 |
+
|
| 519 |
+
self.logger.info(f"Saved model to {filepath}")
|
| 520 |
+
|
| 521 |
+
# Clean up old models
|
| 522 |
+
self._cleanup_old_models(schema_name)
|
| 523 |
+
|
| 524 |
+
return filepath
|
| 525 |
+
except Exception as e:
|
| 526 |
+
self.logger.error(f"Failed to save model: {e}")
|
| 527 |
+
raise
|
| 528 |
+
|
| 529 |
+
def load_model(self, filepath: str) -> Optional[Pipeline]:
|
| 530 |
+
"""Load a saved model."""
|
| 531 |
+
try:
|
| 532 |
+
with open(filepath, 'rb') as f:
|
| 533 |
+
model = pickle.load(f)
|
| 534 |
+
|
| 535 |
+
# TODO: Add schema validation here in the future
|
| 536 |
+
# This is a placeholder for future schema validation enhancement
|
| 537 |
+
|
| 538 |
+
self.logger.info(f"Loaded model from {filepath}")
|
| 539 |
+
return model
|
| 540 |
+
except Exception as e:
|
| 541 |
+
self.logger.error(f"Failed to load model from {filepath}: {e}")
|
| 542 |
+
return None
|
| 543 |
+
|
| 544 |
+
def _cleanup_old_models(self, schema_name: str):
|
| 545 |
+
"""Clean up old models based on retention policy."""
|
| 546 |
+
try:
|
| 547 |
+
# Find all model files for this schema
|
| 548 |
+
model_files = []
|
| 549 |
+
|
| 550 |
+
for filename in os.listdir(self.save_directory):
|
| 551 |
+
if filename.startswith(f"{schema_name}_") and filename.endswith(".pkl"):
|
| 552 |
+
filepath = os.path.join(self.save_directory, filename)
|
| 553 |
+
model_files.append((filepath, os.path.getmtime(filepath)))
|
| 554 |
+
|
| 555 |
+
# Sort by modification time (newest first)
|
| 556 |
+
model_files.sort(key=lambda x: x[1], reverse=True)
|
| 557 |
+
|
| 558 |
+
# Remove old models beyond retention count
|
| 559 |
+
for filepath, _ in model_files[self.retention_count:]:
|
| 560 |
+
try:
|
| 561 |
+
os.remove(filepath)
|
| 562 |
+
self.logger.info(f"Removed old model: {filepath}")
|
| 563 |
+
except Exception as e:
|
| 564 |
+
self.logger.warning(f"Failed to remove old model {filepath}: {e}")
|
| 565 |
+
|
| 566 |
+
except Exception as e:
|
| 567 |
+
self.logger.error(f"Error during model cleanup: {e}")
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class DatabaseStateManager:
|
| 571 |
+
"""Manages database persistence for active learning state."""
|
| 572 |
+
|
| 573 |
+
def __init__(self, config: Dict[str, Any]):
|
| 574 |
+
self.config = config
|
| 575 |
+
self.logger = logging.getLogger(__name__)
|
| 576 |
+
self.connection = None
|
| 577 |
+
self._initialize_database()
|
| 578 |
+
|
| 579 |
+
def _initialize_database(self):
|
| 580 |
+
"""Initialize database connection and create tables."""
|
| 581 |
+
try:
|
| 582 |
+
# Use the same database system as main Potato application
|
| 583 |
+
if self.config.get('type') == 'mysql':
|
| 584 |
+
self._init_mysql_connection()
|
| 585 |
+
else:
|
| 586 |
+
self._init_file_based_connection()
|
| 587 |
+
|
| 588 |
+
self._create_tables()
|
| 589 |
+
self.logger.info("Active learning database initialized successfully")
|
| 590 |
+
except Exception as e:
|
| 591 |
+
self.logger.error(f"Failed to initialize database: {e}")
|
| 592 |
+
raise
|
| 593 |
+
|
| 594 |
+
def _init_mysql_connection(self):
|
| 595 |
+
"""Initialize MySQL connection."""
|
| 596 |
+
# TODO: Implement MySQL connection
|
| 597 |
+
pass
|
| 598 |
+
|
| 599 |
+
def _init_file_based_connection(self):
|
| 600 |
+
"""Initialize file-based database connection."""
|
| 601 |
+
# TODO: Implement file-based database
|
| 602 |
+
pass
|
| 603 |
+
|
| 604 |
+
def _create_tables(self):
|
| 605 |
+
"""Create database tables for active learning."""
|
| 606 |
+
# TODO: Implement table creation
|
| 607 |
+
pass
|
| 608 |
+
|
| 609 |
+
def save_training_metrics(self, metrics: TrainingMetrics):
|
| 610 |
+
"""Save training metrics to database."""
|
| 611 |
+
# TODO: Implement metrics saving
|
| 612 |
+
pass
|
| 613 |
+
|
| 614 |
+
def get_training_history(self, schema_name: Optional[str] = None) -> List[TrainingMetrics]:
|
| 615 |
+
"""Get training history from database."""
|
| 616 |
+
# TODO: Implement history retrieval
|
| 617 |
+
return []
|
| 618 |
+
|
| 619 |
+
def save_schema_cycling_state(self, current_schema: str, schema_order: List[str]):
|
| 620 |
+
"""Save current schema cycling state."""
|
| 621 |
+
# TODO: Implement state saving
|
| 622 |
+
pass
|
| 623 |
+
|
| 624 |
+
def get_schema_cycling_state(self) -> Tuple[str, List[str]]:
|
| 625 |
+
"""Get current schema cycling state."""
|
| 626 |
+
# TODO: Implement state retrieval
|
| 627 |
+
return "", []
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
class SchemaCycler:
|
| 631 |
+
"""Manages cycling through multiple annotation schemes."""
|
| 632 |
+
|
| 633 |
+
def __init__(self, schema_names: List[str], database_manager: Optional[DatabaseStateManager] = None):
|
| 634 |
+
self.schema_names = self._validate_schemas(schema_names)
|
| 635 |
+
self.database_manager = database_manager
|
| 636 |
+
self.current_index = 0
|
| 637 |
+
self.logger = logging.getLogger(__name__)
|
| 638 |
+
self._lock = threading.Lock()
|
| 639 |
+
|
| 640 |
+
# Load state from database if available
|
| 641 |
+
if self.database_manager:
|
| 642 |
+
self._load_state()
|
| 643 |
+
|
| 644 |
+
def _validate_schemas(self, schema_names: List[str]) -> List[str]:
|
| 645 |
+
"""Validate and filter schema names."""
|
| 646 |
+
valid_schemas = []
|
| 647 |
+
|
| 648 |
+
for schema in schema_names:
|
| 649 |
+
# Exclude text and span annotation schemes
|
| 650 |
+
if schema in ['text', 'span']:
|
| 651 |
+
raise ValueError(f"Text and span annotation schemes are not supported for active learning: {schema}")
|
| 652 |
+
valid_schemas.append(schema)
|
| 653 |
+
|
| 654 |
+
return valid_schemas
|
| 655 |
+
|
| 656 |
+
def _load_state(self):
|
| 657 |
+
"""Load cycling state from database."""
|
| 658 |
+
try:
|
| 659 |
+
current_schema, schema_order = self.database_manager.get_schema_cycling_state()
|
| 660 |
+
with self._lock:
|
| 661 |
+
if current_schema in self.schema_names:
|
| 662 |
+
self.current_index = self.schema_names.index(current_schema)
|
| 663 |
+
except Exception as e:
|
| 664 |
+
self.logger.warning(f"Failed to load schema cycling state: {e}")
|
| 665 |
+
|
| 666 |
+
def get_current_schema(self) -> Optional[str]:
|
| 667 |
+
"""Get the current schema for training."""
|
| 668 |
+
if not self.schema_names:
|
| 669 |
+
return None
|
| 670 |
+
with self._lock:
|
| 671 |
+
return self.schema_names[self.current_index]
|
| 672 |
+
|
| 673 |
+
def advance_schema(self):
|
| 674 |
+
"""Advance to the next schema in the cycle."""
|
| 675 |
+
if not self.schema_names:
|
| 676 |
+
return
|
| 677 |
+
|
| 678 |
+
with self._lock:
|
| 679 |
+
self.current_index = (self.current_index + 1) % len(self.schema_names)
|
| 680 |
+
current_schema = self.schema_names[self.current_index]
|
| 681 |
+
|
| 682 |
+
# Save state to database if available
|
| 683 |
+
if self.database_manager:
|
| 684 |
+
try:
|
| 685 |
+
self.database_manager.save_schema_cycling_state(
|
| 686 |
+
current_schema,
|
| 687 |
+
self.schema_names
|
| 688 |
+
)
|
| 689 |
+
except Exception as e:
|
| 690 |
+
self.logger.warning(f"Failed to save schema cycling state: {e}")
|
| 691 |
+
|
| 692 |
+
def get_schema_order(self) -> List[str]:
|
| 693 |
+
"""Get the current schema cycling order."""
|
| 694 |
+
return self.schema_names.copy()
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
class ActiveLearningManager:
|
| 698 |
+
"""
|
| 699 |
+
Manages active learning operations including classifier training and instance reordering.
|
| 700 |
+
|
| 701 |
+
This class provides thread-safe operations for:
|
| 702 |
+
- Training classifiers on annotated data
|
| 703 |
+
- Predicting confidence scores for unlabeled instances
|
| 704 |
+
- Reordering instances based on configurable query strategies
|
| 705 |
+
- Cold-start LLM-based instance selection
|
| 706 |
+
- ICL/classifier ensemble for improved ranking
|
| 707 |
+
- Noise-aware annotation routing
|
| 708 |
+
- Managing training state and progress
|
| 709 |
+
- Database persistence and model saving
|
| 710 |
+
"""
|
| 711 |
+
|
| 712 |
+
def __init__(self, config: ActiveLearningConfig):
|
| 713 |
+
self.config = config
|
| 714 |
+
self.logger = logging.getLogger(__name__)
|
| 715 |
+
|
| 716 |
+
# Thread safety
|
| 717 |
+
self._lock = threading.RLock()
|
| 718 |
+
self._training_queue = queue.Queue()
|
| 719 |
+
self._training_thread = None
|
| 720 |
+
self._stop_training = threading.Event()
|
| 721 |
+
|
| 722 |
+
# State tracking
|
| 723 |
+
self._last_training_time = 0
|
| 724 |
+
self._training_count = 0
|
| 725 |
+
self._models = {} # schema_name -> trained_model
|
| 726 |
+
self._vectorizers = {} # schema_name -> fitted vectorizer
|
| 727 |
+
self._bald_ensembles = {} # schema_name -> list of classifiers
|
| 728 |
+
self._last_annotation_count = 0
|
| 729 |
+
self._training_metrics = [] # List of TrainingMetrics
|
| 730 |
+
self._annotated_texts = {} # schema_name -> list of annotated texts
|
| 731 |
+
|
| 732 |
+
# Query strategy
|
| 733 |
+
self._query_strategy = create_query_strategy(config)
|
| 734 |
+
|
| 735 |
+
# Database and persistence
|
| 736 |
+
self.database_manager = None
|
| 737 |
+
self.model_persistence = None
|
| 738 |
+
self.schema_cycler = None
|
| 739 |
+
|
| 740 |
+
# Initialize components
|
| 741 |
+
self._initialize_components()
|
| 742 |
+
|
| 743 |
+
# Start training thread if enabled
|
| 744 |
+
if self.config.enabled:
|
| 745 |
+
self._start_training_thread()
|
| 746 |
+
|
| 747 |
+
def _initialize_components(self):
|
| 748 |
+
"""Initialize database, model persistence, and schema cycler."""
|
| 749 |
+
# Initialize database manager if enabled
|
| 750 |
+
if self.config.database_enabled:
|
| 751 |
+
try:
|
| 752 |
+
self.database_manager = DatabaseStateManager(self.config.database_config)
|
| 753 |
+
except Exception as e:
|
| 754 |
+
self.logger.error(f"Failed to initialize database manager: {e}")
|
| 755 |
+
# Continue without database persistence
|
| 756 |
+
|
| 757 |
+
# Initialize model persistence if enabled
|
| 758 |
+
if self.config.model_persistence_enabled and self.config.model_save_directory:
|
| 759 |
+
try:
|
| 760 |
+
self.model_persistence = ModelPersistence(
|
| 761 |
+
self.config.model_save_directory,
|
| 762 |
+
self.config.model_retention_count
|
| 763 |
+
)
|
| 764 |
+
except Exception as e:
|
| 765 |
+
self.logger.error(f"Failed to initialize model persistence: {e}")
|
| 766 |
+
# Continue without model persistence
|
| 767 |
+
|
| 768 |
+
# Initialize schema cycler
|
| 769 |
+
try:
|
| 770 |
+
self.schema_cycler = SchemaCycler(self.config.schema_names, self.database_manager)
|
| 771 |
+
except Exception as e:
|
| 772 |
+
self.logger.error(f"Failed to initialize schema cycler: {e}")
|
| 773 |
+
raise # Schema cycler is critical
|
| 774 |
+
|
| 775 |
+
def _start_training_thread(self):
|
| 776 |
+
"""Start the background training thread."""
|
| 777 |
+
if self._training_thread is None or not self._training_thread.is_alive():
|
| 778 |
+
self._training_thread = threading.Thread(target=self._training_worker, daemon=True)
|
| 779 |
+
self._training_thread.start()
|
| 780 |
+
self.logger.info("Active learning training thread started")
|
| 781 |
+
|
| 782 |
+
def _training_worker(self):
|
| 783 |
+
"""Background worker for training classifiers."""
|
| 784 |
+
while not self._stop_training.is_set():
|
| 785 |
+
try:
|
| 786 |
+
# Wait for training request
|
| 787 |
+
training_request = self._training_queue.get(timeout=1.0)
|
| 788 |
+
if training_request is None: # Shutdown signal
|
| 789 |
+
break
|
| 790 |
+
|
| 791 |
+
self._perform_training()
|
| 792 |
+
self._training_queue.task_done()
|
| 793 |
+
|
| 794 |
+
except queue.Empty:
|
| 795 |
+
continue
|
| 796 |
+
except Exception as e:
|
| 797 |
+
self.logger.error(f"Error in training worker: {e}")
|
| 798 |
+
|
| 799 |
+
def _perform_training(self):
|
| 800 |
+
"""Perform the actual classifier training."""
|
| 801 |
+
with self._lock:
|
| 802 |
+
try:
|
| 803 |
+
self.logger.info("Starting active learning classifier training")
|
| 804 |
+
start_time = time.time()
|
| 805 |
+
|
| 806 |
+
# Get current schema for training
|
| 807 |
+
current_schema = self.schema_cycler.get_current_schema()
|
| 808 |
+
if not current_schema:
|
| 809 |
+
self.logger.warning("No schema available for training")
|
| 810 |
+
return
|
| 811 |
+
|
| 812 |
+
# Get current annotation state
|
| 813 |
+
item_manager = get_item_state_manager()
|
| 814 |
+
user_manager = get_user_state_manager()
|
| 815 |
+
|
| 816 |
+
# Collect training data
|
| 817 |
+
training_data = self._collect_training_data(item_manager, user_manager, current_schema)
|
| 818 |
+
|
| 819 |
+
if not training_data:
|
| 820 |
+
self.logger.warning(f"No training data available for schema {current_schema}")
|
| 821 |
+
# If in cold-start phase, try LLM-based reordering
|
| 822 |
+
if self.config.cold_start_strategy == "llm" and self.config.llm_enabled:
|
| 823 |
+
self._cold_start_reorder(item_manager)
|
| 824 |
+
return
|
| 825 |
+
|
| 826 |
+
# Train classifier
|
| 827 |
+
model, metrics = self._train_classifier(training_data, current_schema)
|
| 828 |
+
|
| 829 |
+
if model:
|
| 830 |
+
self._models[current_schema] = model
|
| 831 |
+
self._annotated_texts[current_schema] = training_data["texts"]
|
| 832 |
+
|
| 833 |
+
# Save model if persistence is enabled
|
| 834 |
+
if self.model_persistence:
|
| 835 |
+
try:
|
| 836 |
+
model_path = self.model_persistence.save_model(
|
| 837 |
+
model, current_schema, len(training_data["texts"])
|
| 838 |
+
)
|
| 839 |
+
metrics.model_file_path = model_path
|
| 840 |
+
except Exception as e:
|
| 841 |
+
self.logger.error(f"Failed to save model: {e}")
|
| 842 |
+
|
| 843 |
+
# Save metrics to database if available
|
| 844 |
+
if self.database_manager:
|
| 845 |
+
try:
|
| 846 |
+
self.database_manager.save_training_metrics(metrics)
|
| 847 |
+
except Exception as e:
|
| 848 |
+
self.logger.error(f"Failed to save metrics: {e}")
|
| 849 |
+
|
| 850 |
+
# Reorder instances
|
| 851 |
+
self._reorder_instances(item_manager, current_schema)
|
| 852 |
+
|
| 853 |
+
# Advance to next schema
|
| 854 |
+
self.schema_cycler.advance_schema()
|
| 855 |
+
|
| 856 |
+
self._training_count += 1
|
| 857 |
+
self._last_training_time = time.time()
|
| 858 |
+
|
| 859 |
+
training_duration = time.time() - start_time
|
| 860 |
+
self.logger.info(f"Active learning training completed for schema {current_schema} "
|
| 861 |
+
f"(run #{self._training_count}, duration: {training_duration:.2f}s)")
|
| 862 |
+
else:
|
| 863 |
+
self.logger.warning(f"Failed to train model for schema {current_schema}")
|
| 864 |
+
# Try cold-start if not enough data
|
| 865 |
+
if (self.config.cold_start_strategy == "llm"
|
| 866 |
+
and self.config.llm_enabled
|
| 867 |
+
and len(training_data.get("texts", [])) < self.config.min_instances_for_training):
|
| 868 |
+
self._cold_start_reorder(item_manager)
|
| 869 |
+
|
| 870 |
+
except Exception as e:
|
| 871 |
+
self.logger.error(f"Error during training: {e}")
|
| 872 |
+
# Continue without failing the entire system
|
| 873 |
+
|
| 874 |
+
def _collect_training_data(self, item_manager: ItemStateManager, user_manager, schema_name: str) -> Dict:
|
| 875 |
+
"""Collect training data for a specific schema."""
|
| 876 |
+
training_data = {"texts": [], "labels": [], "instance_ids": []}
|
| 877 |
+
|
| 878 |
+
# Get all user states
|
| 879 |
+
user_states = user_manager.get_all_users()
|
| 880 |
+
self.logger.debug(f"Found {len(user_states)} user states")
|
| 881 |
+
|
| 882 |
+
# Collect annotations per instance
|
| 883 |
+
instance_annotations = defaultdict(list)
|
| 884 |
+
|
| 885 |
+
for user_state in user_states:
|
| 886 |
+
user_annotations = user_state.get_all_annotations()
|
| 887 |
+
self.logger.debug(f"User {user_state.user_id} has {len(user_annotations)} annotations")
|
| 888 |
+
for instance_id, annotations in user_annotations.items():
|
| 889 |
+
# Check if the schema exists in the labels section
|
| 890 |
+
if 'labels' in annotations:
|
| 891 |
+
labels_dict = annotations['labels']
|
| 892 |
+
# Handle Label objects as keys
|
| 893 |
+
for label_obj, value in labels_dict.items():
|
| 894 |
+
if hasattr(label_obj, 'get_schema') and label_obj.get_schema() == schema_name:
|
| 895 |
+
instance_annotations[instance_id].append({
|
| 896 |
+
"label": label_obj.get_name(),
|
| 897 |
+
"value": value,
|
| 898 |
+
"user": user_state.user_id
|
| 899 |
+
})
|
| 900 |
+
|
| 901 |
+
self.logger.debug(f"Collected annotations for {len(instance_annotations)} instances")
|
| 902 |
+
|
| 903 |
+
# Filter instances with sufficient annotations
|
| 904 |
+
for instance_id, annotations in instance_annotations.items():
|
| 905 |
+
if len(annotations) >= self.config.min_annotations_per_instance:
|
| 906 |
+
# Resolve multiple annotations
|
| 907 |
+
resolved_label = self._resolve_annotations(annotations)
|
| 908 |
+
if resolved_label:
|
| 909 |
+
item = item_manager.get_item(instance_id)
|
| 910 |
+
if item:
|
| 911 |
+
text = item.get_text()
|
| 912 |
+
training_data["texts"].append(text)
|
| 913 |
+
training_data["labels"].append(resolved_label)
|
| 914 |
+
training_data["instance_ids"].append(instance_id)
|
| 915 |
+
|
| 916 |
+
self.logger.debug(f"Training data collected: {len(training_data['texts'])} texts, {len(training_data['labels'])} labels")
|
| 917 |
+
return training_data
|
| 918 |
+
|
| 919 |
+
def _resolve_annotations(self, annotations: List[Dict]) -> Optional[str]:
|
| 920 |
+
"""Resolve multiple annotations using the configured strategy."""
|
| 921 |
+
if not annotations:
|
| 922 |
+
return None
|
| 923 |
+
|
| 924 |
+
if self.config.resolution_strategy == ResolutionStrategy.MAJORITY_VOTE:
|
| 925 |
+
return self._majority_vote(annotations)
|
| 926 |
+
elif self.config.resolution_strategy == ResolutionStrategy.RANDOM:
|
| 927 |
+
return self._random_selection(annotations)
|
| 928 |
+
elif self.config.resolution_strategy == ResolutionStrategy.CONSENSUS:
|
| 929 |
+
return self._consensus_resolution(annotations)
|
| 930 |
+
else:
|
| 931 |
+
return self._majority_vote(annotations) # Default fallback
|
| 932 |
+
|
| 933 |
+
def _majority_vote(self, annotations: List[Dict]) -> str:
|
| 934 |
+
"""Resolve annotations using majority vote with random tie-breaking."""
|
| 935 |
+
label_counts = Counter(ann["label"] for ann in annotations)
|
| 936 |
+
max_count = max(label_counts.values())
|
| 937 |
+
# Find all labels with the maximum count (handles ties)
|
| 938 |
+
tied_labels = [label for label, count in label_counts.items() if count == max_count]
|
| 939 |
+
# Break ties randomly
|
| 940 |
+
return random.choice(tied_labels)
|
| 941 |
+
|
| 942 |
+
def _random_selection(self, annotations: List[Dict]) -> str:
|
| 943 |
+
"""Resolve annotations by random selection."""
|
| 944 |
+
return random.choice(annotations)["label"]
|
| 945 |
+
|
| 946 |
+
def _consensus_resolution(self, annotations: List[Dict]) -> Optional[str]:
|
| 947 |
+
"""Resolve annotations by consensus (all must agree)."""
|
| 948 |
+
labels = [ann["label"] for ann in annotations]
|
| 949 |
+
if len(set(labels)) == 1:
|
| 950 |
+
return labels[0]
|
| 951 |
+
return None
|
| 952 |
+
|
| 953 |
+
def _train_classifier(self, training_data: Dict, schema_name: str) -> Tuple[Optional[Pipeline], TrainingMetrics]:
|
| 954 |
+
"""Train a classifier for a specific schema."""
|
| 955 |
+
start_time = time.time()
|
| 956 |
+
|
| 957 |
+
if len(training_data["texts"]) < self.config.min_instances_for_training:
|
| 958 |
+
error_msg = f"Insufficient training data for schema {schema_name}: {len(training_data['texts'])} < {self.config.min_instances_for_training}"
|
| 959 |
+
self.logger.warning(error_msg)
|
| 960 |
+
return None, TrainingMetrics(
|
| 961 |
+
schema_name=schema_name,
|
| 962 |
+
training_time=time.time() - start_time,
|
| 963 |
+
accuracy=0.0,
|
| 964 |
+
instance_count=len(training_data["texts"]),
|
| 965 |
+
timestamp=datetime.now(),
|
| 966 |
+
error_message=error_msg
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
# Check for sufficient label diversity
|
| 970 |
+
unique_labels = set(training_data["labels"])
|
| 971 |
+
if len(unique_labels) < 2:
|
| 972 |
+
error_msg = f"Insufficient label diversity for schema {schema_name}: {len(unique_labels)} unique labels"
|
| 973 |
+
self.logger.warning(error_msg)
|
| 974 |
+
return None, TrainingMetrics(
|
| 975 |
+
schema_name=schema_name,
|
| 976 |
+
training_time=time.time() - start_time,
|
| 977 |
+
accuracy=0.0,
|
| 978 |
+
instance_count=len(training_data["texts"]),
|
| 979 |
+
timestamp=datetime.now(),
|
| 980 |
+
error_message=error_msg
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
try:
|
| 984 |
+
# Create and train classifier
|
| 985 |
+
classifier = self._create_classifier()
|
| 986 |
+
vectorizer = self._create_vectorizer()
|
| 987 |
+
|
| 988 |
+
pipeline = Pipeline([
|
| 989 |
+
("vectorizer", vectorizer),
|
| 990 |
+
("classifier", classifier)
|
| 991 |
+
])
|
| 992 |
+
|
| 993 |
+
pipeline.fit(training_data["texts"], training_data["labels"])
|
| 994 |
+
|
| 995 |
+
# Apply probability calibration if enabled
|
| 996 |
+
if self.config.calibrate_probabilities and hasattr(classifier, 'predict_proba'):
|
| 997 |
+
num_samples = len(training_data["texts"])
|
| 998 |
+
if num_samples >= 5:
|
| 999 |
+
try:
|
| 1000 |
+
from sklearn.calibration import CalibratedClassifierCV
|
| 1001 |
+
cv_folds = min(3, num_samples // 2)
|
| 1002 |
+
if cv_folds >= 2:
|
| 1003 |
+
calibrated = CalibratedClassifierCV(
|
| 1004 |
+
pipeline, cv=cv_folds, method='isotonic'
|
| 1005 |
+
)
|
| 1006 |
+
calibrated.fit(training_data["texts"], training_data["labels"])
|
| 1007 |
+
pipeline = calibrated
|
| 1008 |
+
self.logger.debug(f"Applied probability calibration with {cv_folds}-fold CV")
|
| 1009 |
+
except Exception as e:
|
| 1010 |
+
self.logger.warning(f"Calibration failed, using uncalibrated model: {e}")
|
| 1011 |
+
|
| 1012 |
+
# Store vectorizer separately for strategy use
|
| 1013 |
+
self._vectorizers[schema_name] = pipeline.named_steps.get("vectorizer", vectorizer) if hasattr(pipeline, 'named_steps') else vectorizer
|
| 1014 |
+
|
| 1015 |
+
# Train BALD ensemble if needed
|
| 1016 |
+
if self.config.query_strategy == "bald":
|
| 1017 |
+
self._train_bald_ensemble(training_data, schema_name)
|
| 1018 |
+
|
| 1019 |
+
# Calculate accuracy
|
| 1020 |
+
predictions = pipeline.predict(training_data["texts"])
|
| 1021 |
+
accuracy = accuracy_score(training_data["labels"], predictions)
|
| 1022 |
+
|
| 1023 |
+
# Calculate confidence distribution
|
| 1024 |
+
confidence_distribution = self._calculate_confidence_distribution(pipeline, training_data["texts"])
|
| 1025 |
+
|
| 1026 |
+
training_time = time.time() - start_time
|
| 1027 |
+
|
| 1028 |
+
metrics = TrainingMetrics(
|
| 1029 |
+
schema_name=schema_name,
|
| 1030 |
+
training_time=training_time,
|
| 1031 |
+
accuracy=accuracy,
|
| 1032 |
+
instance_count=len(training_data["texts"]),
|
| 1033 |
+
timestamp=datetime.now(),
|
| 1034 |
+
confidence_distribution=confidence_distribution
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
self.logger.info(f"Trained classifier for schema {schema_name} with {len(training_data['texts'])} instances, "
|
| 1038 |
+
f"accuracy: {accuracy:.3f}, time: {training_time:.2f}s")
|
| 1039 |
+
|
| 1040 |
+
return pipeline, metrics
|
| 1041 |
+
|
| 1042 |
+
except Exception as e:
|
| 1043 |
+
error_msg = f"Error training classifier for schema {schema_name}: {e}"
|
| 1044 |
+
self.logger.error(error_msg)
|
| 1045 |
+
return None, TrainingMetrics(
|
| 1046 |
+
schema_name=schema_name,
|
| 1047 |
+
training_time=time.time() - start_time,
|
| 1048 |
+
accuracy=0.0,
|
| 1049 |
+
instance_count=len(training_data["texts"]),
|
| 1050 |
+
timestamp=datetime.now(),
|
| 1051 |
+
error_message=error_msg
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
def _train_bald_ensemble(self, training_data: Dict, schema_name: str):
|
| 1055 |
+
"""Train an ensemble of classifiers for BALD strategy."""
|
| 1056 |
+
params = self.config.bald_params
|
| 1057 |
+
n_estimators = params.get("n_estimators", 5)
|
| 1058 |
+
bootstrap_fraction = params.get("bootstrap_fraction", 0.8)
|
| 1059 |
+
|
| 1060 |
+
texts = training_data["texts"]
|
| 1061 |
+
labels = training_data["labels"]
|
| 1062 |
+
n_samples = len(texts)
|
| 1063 |
+
bootstrap_size = max(2, int(n_samples * bootstrap_fraction))
|
| 1064 |
+
|
| 1065 |
+
ensemble = []
|
| 1066 |
+
for i in range(n_estimators):
|
| 1067 |
+
indices = np.random.choice(n_samples, size=bootstrap_size, replace=True)
|
| 1068 |
+
boot_texts = [texts[j] for j in indices]
|
| 1069 |
+
boot_labels = [labels[j] for j in indices]
|
| 1070 |
+
|
| 1071 |
+
# Need at least 2 classes
|
| 1072 |
+
if len(set(boot_labels)) < 2:
|
| 1073 |
+
continue
|
| 1074 |
+
|
| 1075 |
+
clf = self._create_classifier()
|
| 1076 |
+
vec = self._create_vectorizer()
|
| 1077 |
+
pipe = Pipeline([("vectorizer", vec), ("classifier", clf)])
|
| 1078 |
+
pipe.fit(boot_texts, boot_labels)
|
| 1079 |
+
ensemble.append(pipe)
|
| 1080 |
+
|
| 1081 |
+
if ensemble:
|
| 1082 |
+
self._bald_ensembles[schema_name] = ensemble
|
| 1083 |
+
self.logger.info(f"Trained BALD ensemble with {len(ensemble)} models for {schema_name}")
|
| 1084 |
+
|
| 1085 |
+
def _calculate_confidence_distribution(self, pipeline, texts: List[str]) -> Dict[str, float]:
|
| 1086 |
+
"""Calculate confidence score distribution."""
|
| 1087 |
+
try:
|
| 1088 |
+
probas = pipeline.predict_proba(texts)
|
| 1089 |
+
max_confidences = np.max(probas, axis=1)
|
| 1090 |
+
|
| 1091 |
+
# Create histogram bins
|
| 1092 |
+
bins = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
|
| 1093 |
+
hist, _ = np.histogram(max_confidences, bins=bins)
|
| 1094 |
+
|
| 1095 |
+
# Convert to percentages
|
| 1096 |
+
total = len(max_confidences)
|
| 1097 |
+
distribution = {}
|
| 1098 |
+
for i, count in enumerate(hist):
|
| 1099 |
+
bin_label = f"{bins[i]:.1f}-{bins[i+1]:.1f}"
|
| 1100 |
+
distribution[bin_label] = (count / total) * 100 if total > 0 else 0
|
| 1101 |
+
|
| 1102 |
+
return distribution
|
| 1103 |
+
except Exception as e:
|
| 1104 |
+
self.logger.warning(f"Failed to calculate confidence distribution: {e}")
|
| 1105 |
+
return {}
|
| 1106 |
+
|
| 1107 |
+
def _create_classifier(self):
|
| 1108 |
+
"""Create classifier instance based on configuration."""
|
| 1109 |
+
kwargs = dict(self.config.classifier_kwargs)
|
| 1110 |
+
|
| 1111 |
+
if self.config.classifier_name == "sklearn.linear_model.LogisticRegression":
|
| 1112 |
+
return LogisticRegression(**kwargs)
|
| 1113 |
+
elif self.config.classifier_name == "sklearn.ensemble.RandomForestClassifier":
|
| 1114 |
+
return RandomForestClassifier(**kwargs)
|
| 1115 |
+
elif self.config.classifier_name == "sklearn.svm.SVC":
|
| 1116 |
+
kwargs.setdefault("probability", True)
|
| 1117 |
+
return SVC(**kwargs)
|
| 1118 |
+
else:
|
| 1119 |
+
# Try to import dynamically
|
| 1120 |
+
try:
|
| 1121 |
+
module_name, class_name = self.config.classifier_name.rsplit('.', 1)
|
| 1122 |
+
module = __import__(module_name, fromlist=[class_name])
|
| 1123 |
+
classifier_class = getattr(module, class_name)
|
| 1124 |
+
return classifier_class(**kwargs)
|
| 1125 |
+
except Exception as e:
|
| 1126 |
+
self.logger.error(f"Failed to create classifier {self.config.classifier_name}: {e}")
|
| 1127 |
+
return LogisticRegression() # Fallback
|
| 1128 |
+
|
| 1129 |
+
def _create_vectorizer(self):
|
| 1130 |
+
"""Create vectorizer instance based on configuration."""
|
| 1131 |
+
kwargs = dict(self.config.vectorizer_kwargs)
|
| 1132 |
+
|
| 1133 |
+
if self.config.vectorizer_name == "sklearn.feature_extraction.text.CountVectorizer":
|
| 1134 |
+
return CountVectorizer(**kwargs)
|
| 1135 |
+
elif self.config.vectorizer_name == "sklearn.feature_extraction.text.TfidfVectorizer":
|
| 1136 |
+
return TfidfVectorizer(**kwargs)
|
| 1137 |
+
elif self.config.vectorizer_name == "sentence-transformers":
|
| 1138 |
+
model_name = kwargs.pop("model_name", "all-MiniLM-L6-v2")
|
| 1139 |
+
return SentenceTransformerVectorizer(model_name=model_name)
|
| 1140 |
+
else:
|
| 1141 |
+
# Try to import dynamically
|
| 1142 |
+
try:
|
| 1143 |
+
module_name, class_name = self.config.vectorizer_name.rsplit('.', 1)
|
| 1144 |
+
module = __import__(module_name, fromlist=[class_name])
|
| 1145 |
+
vectorizer_class = getattr(module, class_name)
|
| 1146 |
+
return vectorizer_class(**kwargs)
|
| 1147 |
+
except Exception as e:
|
| 1148 |
+
self.logger.error(f"Failed to create vectorizer {self.config.vectorizer_name}: {e}")
|
| 1149 |
+
return TfidfVectorizer() # Fallback
|
| 1150 |
+
|
| 1151 |
+
def _reorder_instances(self, item_manager: ItemStateManager, schema_name: str):
|
| 1152 |
+
"""Reorder instances based on the configured query strategy."""
|
| 1153 |
+
if schema_name not in self._models:
|
| 1154 |
+
self.logger.warning(f"No trained model available for schema {schema_name}")
|
| 1155 |
+
return
|
| 1156 |
+
|
| 1157 |
+
# Get unlabeled instances
|
| 1158 |
+
unlabeled_instances = []
|
| 1159 |
+
unlabeled_texts = []
|
| 1160 |
+
for instance_id in item_manager.get_instance_ids():
|
| 1161 |
+
if not item_manager.get_annotators_for_item(instance_id):
|
| 1162 |
+
item = item_manager.get_item(instance_id)
|
| 1163 |
+
if item:
|
| 1164 |
+
unlabeled_instances.append(instance_id)
|
| 1165 |
+
unlabeled_texts.append(item.get_text())
|
| 1166 |
+
|
| 1167 |
+
if not unlabeled_texts:
|
| 1168 |
+
self.logger.info("No unlabeled instances to reorder")
|
| 1169 |
+
return
|
| 1170 |
+
|
| 1171 |
+
# Limit number of instances to process
|
| 1172 |
+
if self.config.max_instances_to_reorder:
|
| 1173 |
+
limit = self.config.max_instances_to_reorder
|
| 1174 |
+
unlabeled_instances = unlabeled_instances[:limit]
|
| 1175 |
+
unlabeled_texts = unlabeled_texts[:limit]
|
| 1176 |
+
|
| 1177 |
+
model = self._models[schema_name]
|
| 1178 |
+
annotated = self._annotated_texts.get(schema_name, [])
|
| 1179 |
+
|
| 1180 |
+
# Get rankings from strategy
|
| 1181 |
+
if (self.config.query_strategy == "bald"
|
| 1182 |
+
and schema_name in self._bald_ensembles
|
| 1183 |
+
and isinstance(self._query_strategy, BaldStrategy)):
|
| 1184 |
+
vectorizer = self._vectorizers.get(schema_name)
|
| 1185 |
+
if vectorizer:
|
| 1186 |
+
rankings = self._query_strategy.rank_with_ensemble(
|
| 1187 |
+
unlabeled_texts, self._bald_ensembles[schema_name], vectorizer
|
| 1188 |
+
)
|
| 1189 |
+
else:
|
| 1190 |
+
rankings = self._query_strategy.rank(unlabeled_texts, model, model, annotated)
|
| 1191 |
+
else:
|
| 1192 |
+
# Extract vectorizer and classifier from pipeline for strategy use
|
| 1193 |
+
vectorizer = self._vectorizers.get(schema_name)
|
| 1194 |
+
classifier = model
|
| 1195 |
+
if vectorizer:
|
| 1196 |
+
rankings = self._query_strategy.rank(
|
| 1197 |
+
unlabeled_texts, classifier, vectorizer, annotated
|
| 1198 |
+
)
|
| 1199 |
+
else:
|
| 1200 |
+
# Fallback: use confidence scores directly
|
| 1201 |
+
instance_scores = self._calculate_confidence_scores(
|
| 1202 |
+
unlabeled_instances, item_manager, schema_name
|
| 1203 |
+
)
|
| 1204 |
+
sorted_instances = sorted(instance_scores, key=lambda x: x[1])
|
| 1205 |
+
self._apply_reordering(sorted_instances, item_manager)
|
| 1206 |
+
return
|
| 1207 |
+
|
| 1208 |
+
# ICL ensemble blending (Phase 5B)
|
| 1209 |
+
if self.config.use_icl_ensemble:
|
| 1210 |
+
rankings = self._blend_icl_scores(
|
| 1211 |
+
rankings, unlabeled_texts, schema_name
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
# Map rankings back to instance IDs
|
| 1215 |
+
sorted_instances = [
|
| 1216 |
+
(unlabeled_instances[idx], score) for idx, score in rankings
|
| 1217 |
+
if idx < len(unlabeled_instances)
|
| 1218 |
+
]
|
| 1219 |
+
|
| 1220 |
+
# Apply reordering with random sampling
|
| 1221 |
+
self._apply_reordering(sorted_instances, item_manager)
|
| 1222 |
+
|
| 1223 |
+
def _blend_icl_scores(self, rankings: List[Tuple[int, float]],
|
| 1224 |
+
texts: List[str], schema_name: str) -> List[Tuple[int, float]]:
|
| 1225 |
+
"""Blend query strategy scores with ICL predictions."""
|
| 1226 |
+
try:
|
| 1227 |
+
from potato.ai.icl_labeler import get_icl_labeler
|
| 1228 |
+
icl_labeler = get_icl_labeler()
|
| 1229 |
+
if icl_labeler is None or not icl_labeler.has_enough_examples(schema_name):
|
| 1230 |
+
return rankings
|
| 1231 |
+
|
| 1232 |
+
# Determine interpolation weight based on annotation count
|
| 1233 |
+
params = self.config.icl_ensemble_params
|
| 1234 |
+
initial_w = params.get("initial_icl_weight", 0.7)
|
| 1235 |
+
final_w = params.get("final_icl_weight", 0.2)
|
| 1236 |
+
transition = params.get("transition_instances", 100)
|
| 1237 |
+
|
| 1238 |
+
annotated_count = len(self._annotated_texts.get(schema_name, []))
|
| 1239 |
+
progress = min(1.0, annotated_count / max(1, transition))
|
| 1240 |
+
icl_weight = initial_w + (final_w - initial_w) * progress
|
| 1241 |
+
strategy_weight = 1.0 - icl_weight
|
| 1242 |
+
|
| 1243 |
+
# Get ICL confidence for each text
|
| 1244 |
+
icl_scores = {}
|
| 1245 |
+
for idx, text in enumerate(texts):
|
| 1246 |
+
try:
|
| 1247 |
+
pred = icl_labeler.label_instance(
|
| 1248 |
+
instance_id=f"_al_blend_{idx}",
|
| 1249 |
+
schema_name=schema_name,
|
| 1250 |
+
instance_text=text,
|
| 1251 |
+
)
|
| 1252 |
+
if pred:
|
| 1253 |
+
# Lower confidence = higher priority (more uncertain)
|
| 1254 |
+
icl_scores[idx] = 1.0 - pred.confidence_score
|
| 1255 |
+
else:
|
| 1256 |
+
icl_scores[idx] = 0.5
|
| 1257 |
+
except Exception:
|
| 1258 |
+
icl_scores[idx] = 0.5
|
| 1259 |
+
|
| 1260 |
+
# Normalize strategy scores
|
| 1261 |
+
strategy_map = {idx: score for idx, score in rankings}
|
| 1262 |
+
s_vals = list(strategy_map.values())
|
| 1263 |
+
s_min, s_max = min(s_vals), max(s_vals)
|
| 1264 |
+
s_rng = s_max - s_min if s_max > s_min else 1.0
|
| 1265 |
+
|
| 1266 |
+
# Normalize ICL scores
|
| 1267 |
+
i_vals = list(icl_scores.values())
|
| 1268 |
+
i_min, i_max = min(i_vals), max(i_vals)
|
| 1269 |
+
i_rng = i_max - i_min if i_max > i_min else 1.0
|
| 1270 |
+
|
| 1271 |
+
blended = []
|
| 1272 |
+
for idx, score in rankings:
|
| 1273 |
+
norm_s = (score - s_min) / s_rng
|
| 1274 |
+
norm_i = (icl_scores.get(idx, 0.5) - i_min) / i_rng
|
| 1275 |
+
combined = strategy_weight * norm_s + icl_weight * norm_i
|
| 1276 |
+
blended.append((idx, combined))
|
| 1277 |
+
|
| 1278 |
+
blended.sort(key=lambda x: x[1], reverse=True)
|
| 1279 |
+
return blended
|
| 1280 |
+
|
| 1281 |
+
except ImportError:
|
| 1282 |
+
return rankings
|
| 1283 |
+
except Exception as e:
|
| 1284 |
+
self.logger.warning(f"ICL blending failed: {e}")
|
| 1285 |
+
return rankings
|
| 1286 |
+
|
| 1287 |
+
def _cold_start_reorder(self, item_manager: ItemStateManager):
|
| 1288 |
+
"""LLM-based cold-start instance selection (Phase 3A).
|
| 1289 |
+
|
| 1290 |
+
Based on Bayer et al. (2024) ActiveLLM approach. Before enough
|
| 1291 |
+
annotations exist for classifier training, use LLM to estimate
|
| 1292 |
+
which instances are most informative by finding those where LLM
|
| 1293 |
+
confidence is moderate (on the decision boundary).
|
| 1294 |
+
"""
|
| 1295 |
+
try:
|
| 1296 |
+
from potato.ai.llm_active_learning import create_llm_active_learning
|
| 1297 |
+
|
| 1298 |
+
llm = create_llm_active_learning(self.config.llm_config)
|
| 1299 |
+
|
| 1300 |
+
# Sample candidate instances
|
| 1301 |
+
all_ids = list(item_manager.get_instance_ids())
|
| 1302 |
+
unannotated = [
|
| 1303 |
+
iid for iid in all_ids
|
| 1304 |
+
if not item_manager.get_annotators_for_item(iid)
|
| 1305 |
+
]
|
| 1306 |
+
|
| 1307 |
+
if not unannotated:
|
| 1308 |
+
return
|
| 1309 |
+
|
| 1310 |
+
batch_size = min(self.config.cold_start_batch_size, len(unannotated))
|
| 1311 |
+
candidates = random.sample(unannotated, batch_size)
|
| 1312 |
+
|
| 1313 |
+
instances = []
|
| 1314 |
+
for iid in candidates:
|
| 1315 |
+
item = item_manager.get_item(iid)
|
| 1316 |
+
if item:
|
| 1317 |
+
instances.append({"id": iid, "text": item.get_text()})
|
| 1318 |
+
|
| 1319 |
+
if not instances:
|
| 1320 |
+
return
|
| 1321 |
+
|
| 1322 |
+
# Get LLM predictions
|
| 1323 |
+
schema_name = self.schema_cycler.get_current_schema() if self.schema_cycler else None
|
| 1324 |
+
predictions = llm.predict_instances(
|
| 1325 |
+
instances=instances,
|
| 1326 |
+
annotation_instructions="Rate your confidence in labeling this text.",
|
| 1327 |
+
schema_name=schema_name or "default",
|
| 1328 |
+
label_options=["positive", "negative", "neutral"],
|
| 1329 |
+
)
|
| 1330 |
+
|
| 1331 |
+
# Select instances with moderate confidence (decision boundary)
|
| 1332 |
+
moderate = []
|
| 1333 |
+
other = []
|
| 1334 |
+
for pred in predictions:
|
| 1335 |
+
if 0.4 <= pred.confidence_score <= 0.7:
|
| 1336 |
+
moderate.append((pred.instance_id, pred.confidence_score))
|
| 1337 |
+
else:
|
| 1338 |
+
other.append((pred.instance_id, pred.confidence_score))
|
| 1339 |
+
|
| 1340 |
+
# Moderate-confidence first, then others, interleaved with random
|
| 1341 |
+
reordered = [iid for iid, _ in moderate] + [iid for iid, _ in other]
|
| 1342 |
+
|
| 1343 |
+
# Add remaining unannotated instances not in the sample
|
| 1344 |
+
sampled_set = set(candidates)
|
| 1345 |
+
remaining = [iid for iid in unannotated if iid not in sampled_set]
|
| 1346 |
+
random.shuffle(remaining)
|
| 1347 |
+
reordered.extend(remaining)
|
| 1348 |
+
|
| 1349 |
+
item_manager.reorder_instances(reordered)
|
| 1350 |
+
self.logger.info(f"Cold-start LLM reordering: {len(moderate)} moderate-confidence, "
|
| 1351 |
+
f"{len(other)} other, {len(remaining)} remaining")
|
| 1352 |
+
|
| 1353 |
+
except Exception as e:
|
| 1354 |
+
self.logger.warning(f"Cold-start LLM reordering failed: {e}")
|
| 1355 |
+
|
| 1356 |
+
def _route_annotation(self, instance_id: str, instance_text: str,
|
| 1357 |
+
schema_name: str) -> Dict[str, Any]:
|
| 1358 |
+
"""Noise-aware annotation routing (Phase 5D).
|
| 1359 |
+
|
| 1360 |
+
Based on Yuan et al. (2024) NoiseAL approach. Routes instances
|
| 1361 |
+
between LLM auto-labeling and human annotation based on LLM
|
| 1362 |
+
confidence levels.
|
| 1363 |
+
|
| 1364 |
+
Returns:
|
| 1365 |
+
Dict with 'route' ('human'|'auto'), optional 'suggestion',
|
| 1366 |
+
and optional 'auto_label'.
|
| 1367 |
+
"""
|
| 1368 |
+
if not self.config.annotation_routing:
|
| 1369 |
+
return {"route": "human"}
|
| 1370 |
+
|
| 1371 |
+
thresholds = self.config.routing_thresholds
|
| 1372 |
+
auto_min = thresholds.get("auto_label_min_confidence", 0.9)
|
| 1373 |
+
suggest_below = thresholds.get("show_suggestion_below", 0.5)
|
| 1374 |
+
|
| 1375 |
+
try:
|
| 1376 |
+
from potato.ai.icl_labeler import get_icl_labeler
|
| 1377 |
+
icl_labeler = get_icl_labeler()
|
| 1378 |
+
if icl_labeler is None or not icl_labeler.has_enough_examples(schema_name):
|
| 1379 |
+
return {"route": "human"}
|
| 1380 |
+
|
| 1381 |
+
prediction = icl_labeler.label_instance(
|
| 1382 |
+
instance_id=instance_id,
|
| 1383 |
+
schema_name=schema_name,
|
| 1384 |
+
instance_text=instance_text,
|
| 1385 |
+
)
|
| 1386 |
+
|
| 1387 |
+
if prediction is None:
|
| 1388 |
+
return {"route": "human"}
|
| 1389 |
+
|
| 1390 |
+
confidence = prediction.confidence_score
|
| 1391 |
+
|
| 1392 |
+
if confidence >= auto_min:
|
| 1393 |
+
# High confidence: auto-label with periodic verification
|
| 1394 |
+
should_verify = random.random() < self.config.verification_sample_rate
|
| 1395 |
+
return {
|
| 1396 |
+
"route": "auto",
|
| 1397 |
+
"auto_label": prediction.predicted_label,
|
| 1398 |
+
"confidence": confidence,
|
| 1399 |
+
"needs_verification": should_verify,
|
| 1400 |
+
}
|
| 1401 |
+
elif confidence < suggest_below:
|
| 1402 |
+
# Low confidence: route to human with LLM suggestion
|
| 1403 |
+
return {
|
| 1404 |
+
"route": "human",
|
| 1405 |
+
"suggestion": prediction.predicted_label,
|
| 1406 |
+
"confidence": confidence,
|
| 1407 |
+
}
|
| 1408 |
+
else:
|
| 1409 |
+
# Medium confidence: route to human (most informative)
|
| 1410 |
+
return {"route": "human"}
|
| 1411 |
+
|
| 1412 |
+
except ImportError:
|
| 1413 |
+
return {"route": "human"}
|
| 1414 |
+
except Exception as e:
|
| 1415 |
+
self.logger.warning(f"Annotation routing failed for {instance_id}: {e}")
|
| 1416 |
+
return {"route": "human"}
|
| 1417 |
+
|
| 1418 |
+
def _calculate_confidence_scores(self, instance_ids: List[str], item_manager: ItemStateManager, schema_name: str) -> List[Tuple[str, float]]:
|
| 1419 |
+
"""Calculate confidence scores for instances."""
|
| 1420 |
+
instance_scores = []
|
| 1421 |
+
model = self._models[schema_name]
|
| 1422 |
+
|
| 1423 |
+
for instance_id in instance_ids:
|
| 1424 |
+
item = item_manager.get_item(instance_id)
|
| 1425 |
+
if not item:
|
| 1426 |
+
continue
|
| 1427 |
+
|
| 1428 |
+
text = item.get_text()
|
| 1429 |
+
|
| 1430 |
+
try:
|
| 1431 |
+
# Get prediction probabilities
|
| 1432 |
+
probas = model.predict_proba([text])[0]
|
| 1433 |
+
confidence = np.max(probas)
|
| 1434 |
+
instance_scores.append((instance_id, confidence))
|
| 1435 |
+
except Exception as e:
|
| 1436 |
+
self.logger.warning(f"Error predicting for instance {instance_id}: {e}")
|
| 1437 |
+
# Default to low confidence for failed predictions
|
| 1438 |
+
instance_scores.append((instance_id, 0.1))
|
| 1439 |
+
|
| 1440 |
+
return instance_scores
|
| 1441 |
+
|
| 1442 |
+
def _apply_reordering(self, sorted_instances: List[Tuple[str, float]], item_manager: ItemStateManager):
|
| 1443 |
+
"""Apply the new ordering to the item manager."""
|
| 1444 |
+
# Extract instance IDs in new order
|
| 1445 |
+
new_order = [instance_id for instance_id, _ in sorted_instances]
|
| 1446 |
+
|
| 1447 |
+
if not new_order:
|
| 1448 |
+
return
|
| 1449 |
+
|
| 1450 |
+
# Apply random sampling
|
| 1451 |
+
random_count = int(len(new_order) * self.config.random_sample_percent)
|
| 1452 |
+
if random_count > 0 and random_count <= len(new_order):
|
| 1453 |
+
random_instances = random.sample(new_order, random_count)
|
| 1454 |
+
else:
|
| 1455 |
+
random_instances = []
|
| 1456 |
+
|
| 1457 |
+
# Interleave active learning and random instances
|
| 1458 |
+
final_order = []
|
| 1459 |
+
al_idx = 0
|
| 1460 |
+
rand_idx = 0
|
| 1461 |
+
|
| 1462 |
+
while al_idx < len(new_order) or rand_idx < len(random_instances):
|
| 1463 |
+
if al_idx < len(new_order):
|
| 1464 |
+
final_order.append(new_order[al_idx])
|
| 1465 |
+
al_idx += 1
|
| 1466 |
+
if rand_idx < len(random_instances):
|
| 1467 |
+
final_order.append(random_instances[rand_idx])
|
| 1468 |
+
rand_idx += 1
|
| 1469 |
+
|
| 1470 |
+
# Update item manager ordering
|
| 1471 |
+
item_manager.reorder_instances(final_order)
|
| 1472 |
+
self.logger.info(f"Reordered {len(final_order)} instances")
|
| 1473 |
+
|
| 1474 |
+
def check_and_trigger_training(self):
|
| 1475 |
+
"""Check if training should be triggered and queue it if needed."""
|
| 1476 |
+
if not self.config.enabled:
|
| 1477 |
+
self.logger.debug("Active learning is disabled")
|
| 1478 |
+
return
|
| 1479 |
+
|
| 1480 |
+
with self._lock:
|
| 1481 |
+
# Count current annotations
|
| 1482 |
+
user_manager = get_user_state_manager()
|
| 1483 |
+
current_annotation_count = sum(
|
| 1484 |
+
len(user_state.get_all_annotations())
|
| 1485 |
+
for user_state in user_manager.get_all_users()
|
| 1486 |
+
)
|
| 1487 |
+
|
| 1488 |
+
self.logger.debug(f"Current annotation count: {current_annotation_count}, last count: {self._last_annotation_count}, update_frequency: {self.config.update_frequency}")
|
| 1489 |
+
|
| 1490 |
+
# Check if we should trigger training
|
| 1491 |
+
if (current_annotation_count - self._last_annotation_count) >= self.config.update_frequency:
|
| 1492 |
+
self._training_queue.put("train")
|
| 1493 |
+
self._last_annotation_count = current_annotation_count
|
| 1494 |
+
self.logger.info(f"Queued active learning training (annotations: {current_annotation_count})")
|
| 1495 |
+
else:
|
| 1496 |
+
self.logger.debug("Not enough new annotations to trigger training")
|
| 1497 |
+
|
| 1498 |
+
def force_training(self):
|
| 1499 |
+
"""Force immediate training (for testing purposes)."""
|
| 1500 |
+
if not self.config.enabled:
|
| 1501 |
+
self.logger.debug("Active learning is disabled")
|
| 1502 |
+
return
|
| 1503 |
+
|
| 1504 |
+
self.logger.info("Forcing immediate active learning training")
|
| 1505 |
+
self._training_queue.put("train")
|
| 1506 |
+
|
| 1507 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 1508 |
+
"""Get active learning statistics."""
|
| 1509 |
+
with self._lock:
|
| 1510 |
+
stats = {
|
| 1511 |
+
"enabled": self.config.enabled,
|
| 1512 |
+
"training_count": self._training_count,
|
| 1513 |
+
"last_training_time": self._last_training_time,
|
| 1514 |
+
"models_trained": list(self._models.keys()),
|
| 1515 |
+
"current_schema": self.schema_cycler.get_current_schema() if self.schema_cycler else None,
|
| 1516 |
+
"schema_order": self.schema_cycler.get_schema_order() if self.schema_cycler else [],
|
| 1517 |
+
"database_enabled": self.config.database_enabled,
|
| 1518 |
+
"model_persistence_enabled": self.config.model_persistence_enabled,
|
| 1519 |
+
"llm_enabled": self.config.llm_enabled,
|
| 1520 |
+
"query_strategy": self.config.query_strategy,
|
| 1521 |
+
"calibrate_probabilities": self.config.calibrate_probabilities,
|
| 1522 |
+
"cold_start_strategy": self.config.cold_start_strategy,
|
| 1523 |
+
"use_icl_ensemble": self.config.use_icl_ensemble,
|
| 1524 |
+
"annotation_routing": self.config.annotation_routing,
|
| 1525 |
+
}
|
| 1526 |
+
|
| 1527 |
+
# Add training metrics if available
|
| 1528 |
+
if self.database_manager:
|
| 1529 |
+
try:
|
| 1530 |
+
stats["training_history"] = [
|
| 1531 |
+
asdict(metrics) for metrics in self.database_manager.get_training_history()
|
| 1532 |
+
]
|
| 1533 |
+
except Exception as e:
|
| 1534 |
+
self.logger.warning(f"Failed to get training history: {e}")
|
| 1535 |
+
stats["training_history"] = []
|
| 1536 |
+
|
| 1537 |
+
return stats
|
| 1538 |
+
|
| 1539 |
+
def shutdown(self):
|
| 1540 |
+
"""Shutdown the active learning manager."""
|
| 1541 |
+
self._stop_training.set()
|
| 1542 |
+
if self._training_thread and self._training_thread.is_alive():
|
| 1543 |
+
self._training_queue.put(None) # Shutdown signal
|
| 1544 |
+
self._training_thread.join(timeout=5.0)
|
| 1545 |
+
self.logger.info("Active learning manager shutdown complete")
|
| 1546 |
+
|
| 1547 |
+
|
| 1548 |
+
# Global singleton instance
|
| 1549 |
+
ACTIVE_LEARNING_MANAGER = None
|
| 1550 |
+
|
| 1551 |
+
|
| 1552 |
+
def parse_active_learning_config(config_data: Dict[str, Any]) -> Optional[ActiveLearningConfig]:
|
| 1553 |
+
"""Build an ``ActiveLearningConfig`` from a Potato project config dict.
|
| 1554 |
+
|
| 1555 |
+
Returns None when active learning is not enabled. Maps the keys under the
|
| 1556 |
+
``active_learning:`` section onto the dataclass fields (unknown keys are
|
| 1557 |
+
ignored), and defaults ``schema_names`` to the project's labelable
|
| 1558 |
+
annotation schemes when not given.
|
| 1559 |
+
"""
|
| 1560 |
+
al_dict = (config_data or {}).get("active_learning", {}) or {}
|
| 1561 |
+
if not al_dict.get("enabled"):
|
| 1562 |
+
return None
|
| 1563 |
+
|
| 1564 |
+
valid_fields = {f.name for f in dataclasses.fields(ActiveLearningConfig)}
|
| 1565 |
+
kwargs = {k: v for k, v in al_dict.items() if k in valid_fields}
|
| 1566 |
+
|
| 1567 |
+
# Honor the nested `active_learning.llm:` block (LLM cold-start / ICL).
|
| 1568 |
+
# The dataclass uses flat fields (llm_enabled / llm_config), so translate.
|
| 1569 |
+
llm_block = al_dict.get("llm")
|
| 1570 |
+
if isinstance(llm_block, dict):
|
| 1571 |
+
kwargs.setdefault("llm_enabled", bool(llm_block.get("enabled", False)))
|
| 1572 |
+
kwargs.setdefault("llm_config", llm_block)
|
| 1573 |
+
|
| 1574 |
+
# YAML parses sequences as lists, but sklearn's vectorizers require a tuple
|
| 1575 |
+
# for ngram_range (e.g. (1, 2)). Coerce it so training doesn't fail.
|
| 1576 |
+
vec_params = kwargs.get("vectorizer_params")
|
| 1577 |
+
if isinstance(vec_params, dict) and isinstance(vec_params.get("ngram_range"), list):
|
| 1578 |
+
vec_params = dict(vec_params)
|
| 1579 |
+
vec_params["ngram_range"] = tuple(vec_params["ngram_range"])
|
| 1580 |
+
kwargs["vectorizer_params"] = vec_params
|
| 1581 |
+
|
| 1582 |
+
# resolution_strategy may arrive as a string; coerce to the enum.
|
| 1583 |
+
rs = kwargs.get("resolution_strategy")
|
| 1584 |
+
if isinstance(rs, str):
|
| 1585 |
+
try:
|
| 1586 |
+
kwargs["resolution_strategy"] = ResolutionStrategy(rs)
|
| 1587 |
+
except ValueError:
|
| 1588 |
+
kwargs.pop("resolution_strategy", None)
|
| 1589 |
+
|
| 1590 |
+
# Default schema_names to the labelable schemes in the project.
|
| 1591 |
+
if not kwargs.get("schema_names"):
|
| 1592 |
+
schemes = config_data.get("annotation_schemes", []) or []
|
| 1593 |
+
kwargs["schema_names"] = [
|
| 1594 |
+
s.get("name") for s in schemes
|
| 1595 |
+
if s.get("name") and s.get("annotation_type") in (
|
| 1596 |
+
"radio", "multiselect", "likert", "select"
|
| 1597 |
+
)
|
| 1598 |
+
]
|
| 1599 |
+
|
| 1600 |
+
return ActiveLearningConfig(**kwargs)
|
| 1601 |
+
|
| 1602 |
+
|
| 1603 |
+
def init_active_learning_manager(config: ActiveLearningConfig) -> ActiveLearningManager:
|
| 1604 |
+
"""Initialize the global active learning manager."""
|
| 1605 |
+
global ACTIVE_LEARNING_MANAGER
|
| 1606 |
+
|
| 1607 |
+
if ACTIVE_LEARNING_MANAGER is None:
|
| 1608 |
+
ACTIVE_LEARNING_MANAGER = ActiveLearningManager(config)
|
| 1609 |
+
|
| 1610 |
+
return ACTIVE_LEARNING_MANAGER
|
| 1611 |
+
|
| 1612 |
+
|
| 1613 |
+
def get_active_learning_manager() -> Optional[ActiveLearningManager]:
|
| 1614 |
+
"""Get the global active learning manager."""
|
| 1615 |
+
return ACTIVE_LEARNING_MANAGER
|
| 1616 |
+
|
| 1617 |
+
|
| 1618 |
+
def clear_active_learning_manager():
|
| 1619 |
+
"""Clear the global active learning manager (for testing)."""
|
| 1620 |
+
global ACTIVE_LEARNING_MANAGER
|
| 1621 |
+
if ACTIVE_LEARNING_MANAGER:
|
| 1622 |
+
ACTIVE_LEARNING_MANAGER.shutdown()
|
| 1623 |
+
ACTIVE_LEARNING_MANAGER = None
|
potato/adjudication.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adjudication Module
|
| 3 |
+
|
| 4 |
+
This module provides a comprehensive adjudication system where designated users
|
| 5 |
+
review items with multiple annotations, resolve disagreements, and produce
|
| 6 |
+
gold-standard final decisions.
|
| 7 |
+
|
| 8 |
+
Adjudication is NOT a phase — it's a parallel workflow accessible via a dedicated
|
| 9 |
+
/adjudicate route, available to users with adjudicator privileges. This avoids
|
| 10 |
+
disrupting the existing phase progression system.
|
| 11 |
+
|
| 12 |
+
Key Components:
|
| 13 |
+
- AdjudicationConfig: Configuration dataclass for adjudication settings
|
| 14 |
+
- AdjudicationItem: Represents an item eligible for adjudication with all annotations
|
| 15 |
+
- AdjudicationDecision: Represents an adjudicator's final decision on an item
|
| 16 |
+
- AdjudicationManager: Singleton manager for the adjudication workflow
|
| 17 |
+
|
| 18 |
+
The workflow:
|
| 19 |
+
1. Annotators complete annotations via /annotate (existing workflow)
|
| 20 |
+
2. AdjudicationManager monitors annotation counts and agreement
|
| 21 |
+
3. Items are flagged when criteria are met (min annotations, low agreement)
|
| 22 |
+
4. Adjudicators review items via /adjudicate and submit decisions
|
| 23 |
+
5. Final dataset CLI merges unanimous + adjudicated decisions
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import logging
|
| 28 |
+
import math
|
| 29 |
+
import os
|
| 30 |
+
import threading
|
| 31 |
+
from collections import Counter, defaultdict
|
| 32 |
+
from dataclasses import dataclass, field
|
| 33 |
+
from datetime import datetime
|
| 34 |
+
from typing import Dict, List, Optional, Any, Set
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
# Singleton instance
|
| 39 |
+
_ADJUDICATION_MANAGER = None
|
| 40 |
+
_ADJUDICATION_LOCK = threading.Lock()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class AdjudicationConfig:
|
| 45 |
+
"""Configuration for adjudication features."""
|
| 46 |
+
enabled: bool = False
|
| 47 |
+
adjudicator_users: List[str] = field(default_factory=list)
|
| 48 |
+
|
| 49 |
+
# Trigger criteria
|
| 50 |
+
min_annotations: int = 2
|
| 51 |
+
require_fully_annotated: bool = False
|
| 52 |
+
agreement_threshold: float = 0.75
|
| 53 |
+
show_all_items: bool = False
|
| 54 |
+
|
| 55 |
+
# Display options
|
| 56 |
+
show_annotator_names: bool = True
|
| 57 |
+
show_timing_data: bool = True
|
| 58 |
+
show_agreement_scores: bool = True
|
| 59 |
+
fast_decision_warning_ms: int = 2000
|
| 60 |
+
|
| 61 |
+
# Adjudicator metadata fields
|
| 62 |
+
require_confidence: bool = True
|
| 63 |
+
require_notes_on_override: bool = False
|
| 64 |
+
error_taxonomy: List[str] = field(default_factory=lambda: [
|
| 65 |
+
"ambiguous_text", "guideline_gap", "annotator_error",
|
| 66 |
+
"edge_case", "subjective_disagreement", "other"
|
| 67 |
+
])
|
| 68 |
+
|
| 69 |
+
# Similarity (Phase 3, optional)
|
| 70 |
+
similarity_enabled: bool = False
|
| 71 |
+
similarity_model: str = "all-MiniLM-L6-v2"
|
| 72 |
+
similarity_top_k: int = 5
|
| 73 |
+
similarity_precompute: bool = True
|
| 74 |
+
|
| 75 |
+
# Output
|
| 76 |
+
output_subdir: str = "adjudication"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass
|
| 80 |
+
class AdjudicationItem:
|
| 81 |
+
"""Represents an item eligible for adjudication with all annotator data."""
|
| 82 |
+
instance_id: str
|
| 83 |
+
annotations: Dict[str, Dict[str, Any]] # user_id -> {schema: {label: value}}
|
| 84 |
+
span_annotations: Dict[str, List[Dict]] # user_id -> [span_dict, ...]
|
| 85 |
+
behavioral_data: Dict[str, Dict] # user_id -> {total_time_ms, ...}
|
| 86 |
+
agreement_scores: Dict[str, float] # schema_name -> agreement score
|
| 87 |
+
overall_agreement: float
|
| 88 |
+
num_annotators: int
|
| 89 |
+
status: str = "pending" # pending, in_progress, completed, skipped
|
| 90 |
+
assigned_adjudicator: Optional[str] = None
|
| 91 |
+
mace_predictions: Dict[str, Any] = field(default_factory=dict) # schema -> predicted label
|
| 92 |
+
|
| 93 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 94 |
+
"""Serialize to dictionary for JSON output."""
|
| 95 |
+
result = {
|
| 96 |
+
"instance_id": self.instance_id,
|
| 97 |
+
"annotations": self.annotations,
|
| 98 |
+
"span_annotations": self.span_annotations,
|
| 99 |
+
"behavioral_data": self.behavioral_data,
|
| 100 |
+
"agreement_scores": self.agreement_scores,
|
| 101 |
+
"overall_agreement": self.overall_agreement,
|
| 102 |
+
"num_annotators": self.num_annotators,
|
| 103 |
+
"status": self.status,
|
| 104 |
+
"assigned_adjudicator": self.assigned_adjudicator,
|
| 105 |
+
}
|
| 106 |
+
if self.mace_predictions:
|
| 107 |
+
result["mace_predictions"] = self.mace_predictions
|
| 108 |
+
return result
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@dataclass
|
| 112 |
+
class AdjudicationDecision:
|
| 113 |
+
"""Represents an adjudicator's final decision on an item."""
|
| 114 |
+
instance_id: str
|
| 115 |
+
adjudicator_id: str
|
| 116 |
+
timestamp: str # ISO format string
|
| 117 |
+
label_decisions: Dict[str, Any] # schema -> value
|
| 118 |
+
span_decisions: List[Dict] # list of span dicts
|
| 119 |
+
source: Dict[str, str] # schema -> "annotator_X" | "adjudicator" | "merged"
|
| 120 |
+
confidence: str # "high", "medium", "low"
|
| 121 |
+
notes: str
|
| 122 |
+
error_taxonomy: List[str]
|
| 123 |
+
guideline_update_flag: bool = False
|
| 124 |
+
guideline_update_notes: str = ""
|
| 125 |
+
time_spent_ms: int = 0
|
| 126 |
+
|
| 127 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 128 |
+
"""Serialize to dictionary for JSON output."""
|
| 129 |
+
return {
|
| 130 |
+
"instance_id": self.instance_id,
|
| 131 |
+
"adjudicator_id": self.adjudicator_id,
|
| 132 |
+
"timestamp": self.timestamp,
|
| 133 |
+
"label_decisions": self.label_decisions,
|
| 134 |
+
"span_decisions": self.span_decisions,
|
| 135 |
+
"source": self.source,
|
| 136 |
+
"confidence": self.confidence,
|
| 137 |
+
"notes": self.notes,
|
| 138 |
+
"error_taxonomy": self.error_taxonomy,
|
| 139 |
+
"guideline_update_flag": self.guideline_update_flag,
|
| 140 |
+
"guideline_update_notes": self.guideline_update_notes,
|
| 141 |
+
"time_spent_ms": self.time_spent_ms,
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
@classmethod
|
| 145 |
+
def from_dict(cls, d: Dict[str, Any]) -> "AdjudicationDecision":
|
| 146 |
+
"""Deserialize from dictionary."""
|
| 147 |
+
return cls(
|
| 148 |
+
instance_id=d["instance_id"],
|
| 149 |
+
adjudicator_id=d["adjudicator_id"],
|
| 150 |
+
timestamp=d["timestamp"],
|
| 151 |
+
label_decisions=d.get("label_decisions", {}),
|
| 152 |
+
span_decisions=d.get("span_decisions", []),
|
| 153 |
+
source=d.get("source", {}),
|
| 154 |
+
confidence=d.get("confidence", "medium"),
|
| 155 |
+
notes=d.get("notes", ""),
|
| 156 |
+
error_taxonomy=d.get("error_taxonomy", []),
|
| 157 |
+
guideline_update_flag=d.get("guideline_update_flag", False),
|
| 158 |
+
guideline_update_notes=d.get("guideline_update_notes", ""),
|
| 159 |
+
time_spent_ms=d.get("time_spent_ms", 0),
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class AdjudicationManager:
|
| 164 |
+
"""
|
| 165 |
+
Manages the adjudication workflow including queue building, agreement
|
| 166 |
+
computation, decision storage, and final dataset generation.
|
| 167 |
+
|
| 168 |
+
Follows the singleton pattern used by QualityControlManager.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
def __init__(self, config: Dict[str, Any]):
|
| 172 |
+
"""
|
| 173 |
+
Initialize the adjudication manager.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
config: The full application configuration dictionary
|
| 177 |
+
"""
|
| 178 |
+
self.config = config
|
| 179 |
+
self.logger = logging.getLogger(__name__)
|
| 180 |
+
self._lock = threading.RLock()
|
| 181 |
+
|
| 182 |
+
# Parse configuration
|
| 183 |
+
self.adj_config = self._parse_config(config)
|
| 184 |
+
|
| 185 |
+
# Queue and decisions
|
| 186 |
+
self.queue: Dict[str, AdjudicationItem] = {} # instance_id -> AdjudicationItem
|
| 187 |
+
self.decisions: Dict[str, AdjudicationDecision] = {} # instance_id -> decision
|
| 188 |
+
self._queue_built = False
|
| 189 |
+
|
| 190 |
+
# Load any previously saved decisions
|
| 191 |
+
self._load_decisions()
|
| 192 |
+
|
| 193 |
+
# Initialize similarity engine (Phase 3)
|
| 194 |
+
self.similarity_engine = None
|
| 195 |
+
if self.adj_config.similarity_enabled:
|
| 196 |
+
from potato.similarity import init_similarity_engine
|
| 197 |
+
self.similarity_engine = init_similarity_engine(config, self.adj_config)
|
| 198 |
+
if (self.similarity_engine and self.similarity_engine.enabled
|
| 199 |
+
and self.adj_config.similarity_precompute):
|
| 200 |
+
self._precompute_similarities()
|
| 201 |
+
|
| 202 |
+
self.logger.info(
|
| 203 |
+
f"AdjudicationManager initialized: enabled={self.adj_config.enabled}, "
|
| 204 |
+
f"adjudicators={self.adj_config.adjudicator_users}"
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def _parse_config(self, config: Dict[str, Any]) -> AdjudicationConfig:
|
| 208 |
+
"""Parse adjudication configuration from the main config."""
|
| 209 |
+
adj = AdjudicationConfig()
|
| 210 |
+
|
| 211 |
+
adj_config = config.get("adjudication", {})
|
| 212 |
+
if not adj_config or not adj_config.get("enabled", False):
|
| 213 |
+
return adj
|
| 214 |
+
|
| 215 |
+
adj.enabled = True
|
| 216 |
+
adj.adjudicator_users = adj_config.get("adjudicator_users", [])
|
| 217 |
+
adj.min_annotations = adj_config.get("min_annotations", 2)
|
| 218 |
+
adj.require_fully_annotated = adj_config.get("require_fully_annotated", False)
|
| 219 |
+
adj.agreement_threshold = adj_config.get("agreement_threshold", 0.75)
|
| 220 |
+
adj.show_all_items = adj_config.get("show_all_items", False)
|
| 221 |
+
adj.show_annotator_names = adj_config.get("show_annotator_names", True)
|
| 222 |
+
adj.show_timing_data = adj_config.get("show_timing_data", True)
|
| 223 |
+
adj.show_agreement_scores = adj_config.get("show_agreement_scores", True)
|
| 224 |
+
adj.fast_decision_warning_ms = adj_config.get("fast_decision_warning_ms", 2000)
|
| 225 |
+
adj.require_confidence = adj_config.get("require_confidence", True)
|
| 226 |
+
adj.require_notes_on_override = adj_config.get("require_notes_on_override", False)
|
| 227 |
+
|
| 228 |
+
if "error_taxonomy" in adj_config:
|
| 229 |
+
adj.error_taxonomy = adj_config["error_taxonomy"]
|
| 230 |
+
|
| 231 |
+
# Similarity settings
|
| 232 |
+
sim_config = adj_config.get("similarity", {})
|
| 233 |
+
if sim_config.get("enabled", False):
|
| 234 |
+
adj.similarity_enabled = True
|
| 235 |
+
adj.similarity_model = sim_config.get("model", "all-MiniLM-L6-v2")
|
| 236 |
+
adj.similarity_top_k = sim_config.get("top_k", 5)
|
| 237 |
+
adj.similarity_precompute = sim_config.get("precompute_on_start", True)
|
| 238 |
+
|
| 239 |
+
adj.output_subdir = adj_config.get("output_subdir", "adjudication")
|
| 240 |
+
|
| 241 |
+
return adj
|
| 242 |
+
|
| 243 |
+
def is_adjudicator(self, username: str) -> bool:
|
| 244 |
+
"""Check if a user is an authorized adjudicator."""
|
| 245 |
+
if not self.adj_config.enabled:
|
| 246 |
+
return False
|
| 247 |
+
return username in self.adj_config.adjudicator_users
|
| 248 |
+
|
| 249 |
+
def build_queue(self) -> List[AdjudicationItem]:
|
| 250 |
+
"""
|
| 251 |
+
Scan all user annotations and build the adjudication queue.
|
| 252 |
+
|
| 253 |
+
Items become eligible when they have enough annotations and
|
| 254 |
+
agreement is below the threshold.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
List of AdjudicationItem objects
|
| 258 |
+
"""
|
| 259 |
+
from potato.user_state_management import get_user_state_manager
|
| 260 |
+
from potato.item_state_management import get_item_state_manager
|
| 261 |
+
|
| 262 |
+
with self._lock:
|
| 263 |
+
usm = get_user_state_manager()
|
| 264 |
+
ism = get_item_state_manager()
|
| 265 |
+
|
| 266 |
+
# Get all annotation schemes from config
|
| 267 |
+
annotation_schemes = self.config.get("annotation_schemes", [])
|
| 268 |
+
scheme_names = [s.get("name", "") for s in annotation_schemes]
|
| 269 |
+
|
| 270 |
+
# Iterate over all items
|
| 271 |
+
for instance_id, item in ism.instance_id_to_instance.items():
|
| 272 |
+
instance_id_str = str(instance_id)
|
| 273 |
+
|
| 274 |
+
# Skip if already decided
|
| 275 |
+
if instance_id_str in self.decisions:
|
| 276 |
+
if instance_id_str not in self.queue:
|
| 277 |
+
continue
|
| 278 |
+
# Mark as completed if decision exists
|
| 279 |
+
self.queue[instance_id_str].status = "completed"
|
| 280 |
+
continue
|
| 281 |
+
|
| 282 |
+
# Get all annotators for this item
|
| 283 |
+
annotators = ism.instance_annotators.get(instance_id, set())
|
| 284 |
+
# Filter out adjudicators from annotator list
|
| 285 |
+
annotators = {
|
| 286 |
+
u for u in annotators
|
| 287 |
+
if u not in self.adj_config.adjudicator_users
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
if len(annotators) < self.adj_config.min_annotations:
|
| 291 |
+
continue
|
| 292 |
+
|
| 293 |
+
# Check if we require fully annotated items
|
| 294 |
+
if self.adj_config.require_fully_annotated:
|
| 295 |
+
max_per_item = ism.max_annotations_per_item
|
| 296 |
+
if max_per_item > 0 and len(annotators) < max_per_item:
|
| 297 |
+
continue
|
| 298 |
+
|
| 299 |
+
# Collect annotations from all annotators
|
| 300 |
+
item_annotations = {}
|
| 301 |
+
item_spans = {}
|
| 302 |
+
item_behavioral = {}
|
| 303 |
+
|
| 304 |
+
for user_id in annotators:
|
| 305 |
+
user_state = usm.get_user_state(user_id)
|
| 306 |
+
if not user_state:
|
| 307 |
+
continue
|
| 308 |
+
|
| 309 |
+
# Get label annotations
|
| 310 |
+
label_annots = user_state.instance_id_to_label_to_value.get(
|
| 311 |
+
instance_id_str, {}
|
| 312 |
+
)
|
| 313 |
+
if label_annots:
|
| 314 |
+
item_annotations[user_id] = self._serialize_labels(label_annots)
|
| 315 |
+
|
| 316 |
+
# Get span annotations
|
| 317 |
+
span_annots = user_state.instance_id_to_span_to_value.get(
|
| 318 |
+
instance_id_str, {}
|
| 319 |
+
)
|
| 320 |
+
if span_annots:
|
| 321 |
+
item_spans[user_id] = self._serialize_spans(span_annots)
|
| 322 |
+
|
| 323 |
+
# Get behavioral data
|
| 324 |
+
bd = user_state.instance_id_to_behavioral_data.get(
|
| 325 |
+
instance_id_str, {}
|
| 326 |
+
)
|
| 327 |
+
if bd:
|
| 328 |
+
item_behavioral[user_id] = self._serialize_behavioral(bd)
|
| 329 |
+
|
| 330 |
+
if not item_annotations and not item_spans:
|
| 331 |
+
continue
|
| 332 |
+
|
| 333 |
+
# Compute agreement scores
|
| 334 |
+
agreement_scores = self._compute_agreement(
|
| 335 |
+
item_annotations, scheme_names
|
| 336 |
+
)
|
| 337 |
+
overall = self._compute_overall_agreement(agreement_scores)
|
| 338 |
+
|
| 339 |
+
# Filter by agreement threshold
|
| 340 |
+
if not self.adj_config.show_all_items:
|
| 341 |
+
if overall >= self.adj_config.agreement_threshold:
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
# Preserve existing status if already in queue
|
| 345 |
+
existing = self.queue.get(instance_id_str)
|
| 346 |
+
status = existing.status if existing else "pending"
|
| 347 |
+
assigned = existing.assigned_adjudicator if existing else None
|
| 348 |
+
|
| 349 |
+
# Enrich with MACE predictions if available
|
| 350 |
+
mace_preds = {}
|
| 351 |
+
try:
|
| 352 |
+
from potato.mace_manager import get_mace_manager
|
| 353 |
+
mace_mgr = get_mace_manager()
|
| 354 |
+
if mace_mgr and mace_mgr.results:
|
| 355 |
+
for sname in scheme_names:
|
| 356 |
+
pred = mace_mgr.get_prediction(instance_id_str, sname)
|
| 357 |
+
if pred is not None:
|
| 358 |
+
mace_preds[sname] = pred
|
| 359 |
+
except Exception:
|
| 360 |
+
pass # MACE is optional
|
| 361 |
+
|
| 362 |
+
self.queue[instance_id_str] = AdjudicationItem(
|
| 363 |
+
instance_id=instance_id_str,
|
| 364 |
+
annotations=item_annotations,
|
| 365 |
+
span_annotations=item_spans,
|
| 366 |
+
behavioral_data=item_behavioral,
|
| 367 |
+
agreement_scores=agreement_scores,
|
| 368 |
+
overall_agreement=overall,
|
| 369 |
+
num_annotators=len(annotators),
|
| 370 |
+
status=status,
|
| 371 |
+
assigned_adjudicator=assigned,
|
| 372 |
+
mace_predictions=mace_preds,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
self._queue_built = True
|
| 376 |
+
return list(self.queue.values())
|
| 377 |
+
|
| 378 |
+
def try_enqueue_item(self, instance_id: str) -> bool:
|
| 379 |
+
"""
|
| 380 |
+
Evaluate a single item and, if it qualifies, add it to the queue.
|
| 381 |
+
|
| 382 |
+
Called when an overlap-sample item saturates so that low-agreement
|
| 383 |
+
items show up in the adjudication queue without needing a full
|
| 384 |
+
``build_queue()`` rescan. Returns True if the item ended up in the
|
| 385 |
+
queue, False otherwise.
|
| 386 |
+
"""
|
| 387 |
+
if not self.adj_config.enabled:
|
| 388 |
+
return False
|
| 389 |
+
|
| 390 |
+
from potato.user_state_management import get_user_state_manager
|
| 391 |
+
from potato.item_state_management import get_item_state_manager
|
| 392 |
+
|
| 393 |
+
usm = get_user_state_manager()
|
| 394 |
+
ism = get_item_state_manager()
|
| 395 |
+
if usm is None or ism is None:
|
| 396 |
+
return False
|
| 397 |
+
|
| 398 |
+
with self._lock:
|
| 399 |
+
instance_id_str = str(instance_id)
|
| 400 |
+
if instance_id_str in self.decisions:
|
| 401 |
+
return False
|
| 402 |
+
item = ism.instance_id_to_instance.get(instance_id)
|
| 403 |
+
if item is None:
|
| 404 |
+
return False
|
| 405 |
+
|
| 406 |
+
annotators = {
|
| 407 |
+
u for u in ism.instance_annotators.get(instance_id, set())
|
| 408 |
+
if u not in self.adj_config.adjudicator_users
|
| 409 |
+
}
|
| 410 |
+
if len(annotators) < self.adj_config.min_annotations:
|
| 411 |
+
return False
|
| 412 |
+
|
| 413 |
+
scheme_names = [s.get("name", "") for s in self.config.get("annotation_schemes", [])]
|
| 414 |
+
item_annotations: Dict[str, Any] = {}
|
| 415 |
+
item_spans: Dict[str, Any] = {}
|
| 416 |
+
item_behavioral: Dict[str, Any] = {}
|
| 417 |
+
for user_id in annotators:
|
| 418 |
+
ustate = usm.get_user_state(user_id)
|
| 419 |
+
if not ustate:
|
| 420 |
+
continue
|
| 421 |
+
la = ustate.instance_id_to_label_to_value.get(instance_id_str, {})
|
| 422 |
+
if la:
|
| 423 |
+
item_annotations[user_id] = self._serialize_labels(la)
|
| 424 |
+
sa = ustate.instance_id_to_span_to_value.get(instance_id_str, {})
|
| 425 |
+
if sa:
|
| 426 |
+
item_spans[user_id] = self._serialize_spans(sa)
|
| 427 |
+
bd = ustate.instance_id_to_behavioral_data.get(instance_id_str, {})
|
| 428 |
+
if bd:
|
| 429 |
+
item_behavioral[user_id] = self._serialize_behavioral(bd)
|
| 430 |
+
|
| 431 |
+
if not item_annotations and not item_spans:
|
| 432 |
+
return False
|
| 433 |
+
|
| 434 |
+
agreement_scores = self._compute_agreement(item_annotations, scheme_names)
|
| 435 |
+
overall = self._compute_overall_agreement(agreement_scores)
|
| 436 |
+
if not self.adj_config.show_all_items:
|
| 437 |
+
if overall >= self.adj_config.agreement_threshold:
|
| 438 |
+
return False
|
| 439 |
+
|
| 440 |
+
existing = self.queue.get(instance_id_str)
|
| 441 |
+
self.queue[instance_id_str] = AdjudicationItem(
|
| 442 |
+
instance_id=instance_id_str,
|
| 443 |
+
annotations=item_annotations,
|
| 444 |
+
span_annotations=item_spans,
|
| 445 |
+
behavioral_data=item_behavioral,
|
| 446 |
+
agreement_scores=agreement_scores,
|
| 447 |
+
overall_agreement=overall,
|
| 448 |
+
num_annotators=len(annotators),
|
| 449 |
+
status=existing.status if existing else "pending",
|
| 450 |
+
assigned_adjudicator=existing.assigned_adjudicator if existing else None,
|
| 451 |
+
)
|
| 452 |
+
self.logger.info(
|
| 453 |
+
"Auto-routed item %s into adjudication queue (overall agreement=%.3f, "
|
| 454 |
+
"threshold=%.3f, annotators=%d)",
|
| 455 |
+
instance_id_str, overall, self.adj_config.agreement_threshold, len(annotators),
|
| 456 |
+
)
|
| 457 |
+
return True
|
| 458 |
+
|
| 459 |
+
def _serialize_labels(self, label_data: Dict) -> Dict[str, Any]:
|
| 460 |
+
"""Convert label annotation data to serializable dict."""
|
| 461 |
+
result = {}
|
| 462 |
+
for key, value in label_data.items():
|
| 463 |
+
# Key might be a Label object or a string
|
| 464 |
+
if hasattr(key, 'get_schema'):
|
| 465 |
+
schema = key.get_schema()
|
| 466 |
+
name = key.get_name()
|
| 467 |
+
if schema not in result:
|
| 468 |
+
result[schema] = {}
|
| 469 |
+
result[schema][name] = value
|
| 470 |
+
elif isinstance(key, str):
|
| 471 |
+
result[key] = value
|
| 472 |
+
else:
|
| 473 |
+
result[str(key)] = value
|
| 474 |
+
return result
|
| 475 |
+
|
| 476 |
+
def _serialize_spans(self, span_data: Dict) -> List[Dict]:
|
| 477 |
+
"""Convert span annotation data to serializable list."""
|
| 478 |
+
spans = []
|
| 479 |
+
for key, value in span_data.items():
|
| 480 |
+
if hasattr(key, 'get_schema'):
|
| 481 |
+
spans.append({
|
| 482 |
+
"schema": key.get_schema(),
|
| 483 |
+
"name": key.get_name(),
|
| 484 |
+
"title": key.get_title() if hasattr(key, 'get_title') else "",
|
| 485 |
+
"start": key.get_start(),
|
| 486 |
+
"end": key.get_end(),
|
| 487 |
+
"id": key.get_id(),
|
| 488 |
+
"target_field": key.get_target_field() if hasattr(key, 'get_target_field') else None,
|
| 489 |
+
})
|
| 490 |
+
elif isinstance(value, dict):
|
| 491 |
+
spans.append(value)
|
| 492 |
+
return spans
|
| 493 |
+
|
| 494 |
+
def _serialize_behavioral(self, bd) -> Dict:
|
| 495 |
+
"""Convert behavioral data to serializable dict."""
|
| 496 |
+
if hasattr(bd, 'to_dict'):
|
| 497 |
+
return bd.to_dict()
|
| 498 |
+
elif isinstance(bd, dict):
|
| 499 |
+
return bd
|
| 500 |
+
return {}
|
| 501 |
+
|
| 502 |
+
def _compute_agreement(
|
| 503 |
+
self, item_annotations: Dict[str, Dict], scheme_names: List[str]
|
| 504 |
+
) -> Dict[str, float]:
|
| 505 |
+
"""
|
| 506 |
+
Compute per-schema agreement for an item.
|
| 507 |
+
|
| 508 |
+
Uses simple percentage agreement (proportion of annotators who chose
|
| 509 |
+
the most common label). For more sophisticated metrics, simpledorff
|
| 510 |
+
can be used but requires multiple items.
|
| 511 |
+
|
| 512 |
+
Returns:
|
| 513 |
+
Dict mapping schema_name to agreement score (0.0 - 1.0)
|
| 514 |
+
"""
|
| 515 |
+
agreement_scores = {}
|
| 516 |
+
|
| 517 |
+
for schema in scheme_names:
|
| 518 |
+
values = []
|
| 519 |
+
for user_id, user_annots in item_annotations.items():
|
| 520 |
+
if schema in user_annots:
|
| 521 |
+
val = user_annots[schema]
|
| 522 |
+
# Normalize to comparable form
|
| 523 |
+
if isinstance(val, dict):
|
| 524 |
+
# Radio stores {label: label} (value is the label string)
|
| 525 |
+
# and multiselect stores {label: value/true}. A label is
|
| 526 |
+
# "selected" when its value is present/truthy. The old
|
| 527 |
+
# filter (v is True / == "true" / == 1) dropped radio's
|
| 528 |
+
# string value, collapsing every annotator to an empty
|
| 529 |
+
# frozenset -> a spurious 1.0 agreement even on total
|
| 530 |
+
# disagreement.
|
| 531 |
+
falsey = (False, None, "", "false", "False", 0, "0")
|
| 532 |
+
selected = frozenset(
|
| 533 |
+
k for k, v in val.items() if v not in falsey
|
| 534 |
+
)
|
| 535 |
+
values.append(selected)
|
| 536 |
+
else:
|
| 537 |
+
values.append(val)
|
| 538 |
+
|
| 539 |
+
if len(values) < 2:
|
| 540 |
+
continue
|
| 541 |
+
|
| 542 |
+
# Compute pairwise agreement (percentage)
|
| 543 |
+
agree_count = 0
|
| 544 |
+
total_pairs = 0
|
| 545 |
+
for i in range(len(values)):
|
| 546 |
+
for j in range(i + 1, len(values)):
|
| 547 |
+
total_pairs += 1
|
| 548 |
+
if values[i] == values[j]:
|
| 549 |
+
agree_count += 1
|
| 550 |
+
|
| 551 |
+
agreement_scores[schema] = (
|
| 552 |
+
agree_count / total_pairs if total_pairs > 0 else 1.0
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
return agreement_scores
|
| 556 |
+
|
| 557 |
+
def _compute_overall_agreement(self, agreement_scores: Dict[str, float]) -> float:
|
| 558 |
+
"""Compute overall agreement as the mean of per-schema scores."""
|
| 559 |
+
if not agreement_scores:
|
| 560 |
+
return 1.0
|
| 561 |
+
return sum(agreement_scores.values()) / len(agreement_scores)
|
| 562 |
+
|
| 563 |
+
def get_queue(
|
| 564 |
+
self,
|
| 565 |
+
adjudicator_id: Optional[str] = None,
|
| 566 |
+
filter_status: Optional[str] = None,
|
| 567 |
+
) -> List[AdjudicationItem]:
|
| 568 |
+
"""
|
| 569 |
+
Get the adjudication queue, optionally filtered by status.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
adjudicator_id: Optional adjudicator to filter by assignment
|
| 573 |
+
filter_status: Optional status filter ("pending", "completed", etc.)
|
| 574 |
+
|
| 575 |
+
Returns:
|
| 576 |
+
List of AdjudicationItem objects
|
| 577 |
+
"""
|
| 578 |
+
with self._lock:
|
| 579 |
+
if not self._queue_built:
|
| 580 |
+
self.build_queue()
|
| 581 |
+
|
| 582 |
+
items = list(self.queue.values())
|
| 583 |
+
|
| 584 |
+
if filter_status:
|
| 585 |
+
items = [i for i in items if i.status == filter_status]
|
| 586 |
+
|
| 587 |
+
# Sort: pending first, then by agreement (lowest first)
|
| 588 |
+
items.sort(key=lambda x: (
|
| 589 |
+
0 if x.status == "pending" else 1 if x.status == "in_progress" else 2,
|
| 590 |
+
x.overall_agreement,
|
| 591 |
+
))
|
| 592 |
+
|
| 593 |
+
return items
|
| 594 |
+
|
| 595 |
+
def get_item(self, instance_id: str) -> Optional[AdjudicationItem]:
|
| 596 |
+
"""
|
| 597 |
+
Get full item data for adjudication.
|
| 598 |
+
|
| 599 |
+
Args:
|
| 600 |
+
instance_id: The instance ID to retrieve
|
| 601 |
+
|
| 602 |
+
Returns:
|
| 603 |
+
AdjudicationItem or None if not in queue
|
| 604 |
+
"""
|
| 605 |
+
with self._lock:
|
| 606 |
+
if not self._queue_built:
|
| 607 |
+
self.build_queue()
|
| 608 |
+
return self.queue.get(str(instance_id))
|
| 609 |
+
|
| 610 |
+
def get_item_text(self, instance_id: str) -> str:
|
| 611 |
+
"""Get the text content for an item."""
|
| 612 |
+
from potato.item_state_management import get_item_state_manager
|
| 613 |
+
|
| 614 |
+
ism = get_item_state_manager()
|
| 615 |
+
item = ism.instance_id_to_instance.get(instance_id)
|
| 616 |
+
if item:
|
| 617 |
+
# Use text_key from config if available
|
| 618 |
+
text_key = self.config.get("item_properties", {}).get("text_key", "text")
|
| 619 |
+
data = item.get_data()
|
| 620 |
+
if isinstance(data, dict) and text_key in data:
|
| 621 |
+
return data[text_key]
|
| 622 |
+
return item.get_text()
|
| 623 |
+
return ""
|
| 624 |
+
|
| 625 |
+
def get_item_data(self, instance_id: str) -> Dict[str, Any]:
|
| 626 |
+
"""Get the full raw data for an item."""
|
| 627 |
+
from potato.item_state_management import get_item_state_manager
|
| 628 |
+
|
| 629 |
+
ism = get_item_state_manager()
|
| 630 |
+
item = ism.instance_id_to_instance.get(instance_id)
|
| 631 |
+
if item:
|
| 632 |
+
data = item.get_data()
|
| 633 |
+
if isinstance(data, dict):
|
| 634 |
+
return data
|
| 635 |
+
return {"text": str(data)}
|
| 636 |
+
return {}
|
| 637 |
+
|
| 638 |
+
def get_next_item(self, adjudicator_id: str) -> Optional[AdjudicationItem]:
|
| 639 |
+
"""Get the next pending item for an adjudicator."""
|
| 640 |
+
items = self.get_queue(filter_status="pending")
|
| 641 |
+
if items:
|
| 642 |
+
return items[0]
|
| 643 |
+
return None
|
| 644 |
+
|
| 645 |
+
def skip_item(self, instance_id: str, adjudicator_id: str) -> bool:
|
| 646 |
+
"""Mark an item as skipped."""
|
| 647 |
+
with self._lock:
|
| 648 |
+
item = self.queue.get(str(instance_id))
|
| 649 |
+
if item:
|
| 650 |
+
item.status = "skipped"
|
| 651 |
+
return True
|
| 652 |
+
return False
|
| 653 |
+
|
| 654 |
+
def submit_decision(self, decision: AdjudicationDecision) -> bool:
|
| 655 |
+
"""
|
| 656 |
+
Submit an adjudication decision.
|
| 657 |
+
|
| 658 |
+
Args:
|
| 659 |
+
decision: The AdjudicationDecision to save
|
| 660 |
+
|
| 661 |
+
Returns:
|
| 662 |
+
True if successful
|
| 663 |
+
"""
|
| 664 |
+
with self._lock:
|
| 665 |
+
instance_id = str(decision.instance_id)
|
| 666 |
+
self.decisions[instance_id] = decision
|
| 667 |
+
|
| 668 |
+
# Update queue status
|
| 669 |
+
if instance_id in self.queue:
|
| 670 |
+
self.queue[instance_id].status = "completed"
|
| 671 |
+
self.queue[instance_id].assigned_adjudicator = decision.adjudicator_id
|
| 672 |
+
|
| 673 |
+
# Persist to disk
|
| 674 |
+
self._save_decisions()
|
| 675 |
+
|
| 676 |
+
self.logger.info(
|
| 677 |
+
f"Adjudication decision saved for {instance_id} "
|
| 678 |
+
f"by {decision.adjudicator_id}"
|
| 679 |
+
)
|
| 680 |
+
return True
|
| 681 |
+
|
| 682 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 683 |
+
"""Get adjudication progress statistics."""
|
| 684 |
+
with self._lock:
|
| 685 |
+
if not self._queue_built:
|
| 686 |
+
self.build_queue()
|
| 687 |
+
|
| 688 |
+
total = len(self.queue)
|
| 689 |
+
completed = sum(
|
| 690 |
+
1 for i in self.queue.values() if i.status == "completed"
|
| 691 |
+
)
|
| 692 |
+
pending = sum(
|
| 693 |
+
1 for i in self.queue.values() if i.status == "pending"
|
| 694 |
+
)
|
| 695 |
+
skipped = sum(
|
| 696 |
+
1 for i in self.queue.values() if i.status == "skipped"
|
| 697 |
+
)
|
| 698 |
+
in_progress = sum(
|
| 699 |
+
1 for i in self.queue.values() if i.status == "in_progress"
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
avg_agreement = 0.0
|
| 703 |
+
if self.queue:
|
| 704 |
+
avg_agreement = sum(
|
| 705 |
+
i.overall_agreement for i in self.queue.values()
|
| 706 |
+
) / len(self.queue)
|
| 707 |
+
|
| 708 |
+
# Per-adjudicator stats
|
| 709 |
+
adjudicator_stats = defaultdict(lambda: {"completed": 0, "total_time_ms": 0})
|
| 710 |
+
for decision in self.decisions.values():
|
| 711 |
+
adj_id = decision.adjudicator_id
|
| 712 |
+
adjudicator_stats[adj_id]["completed"] += 1
|
| 713 |
+
adjudicator_stats[adj_id]["total_time_ms"] += decision.time_spent_ms
|
| 714 |
+
|
| 715 |
+
return {
|
| 716 |
+
"total": total,
|
| 717 |
+
"completed": completed,
|
| 718 |
+
"pending": pending,
|
| 719 |
+
"skipped": skipped,
|
| 720 |
+
"in_progress": in_progress,
|
| 721 |
+
"completion_rate": completed / total if total > 0 else 0.0,
|
| 722 |
+
"avg_agreement": avg_agreement,
|
| 723 |
+
"adjudicator_stats": dict(adjudicator_stats),
|
| 724 |
+
}
|
| 725 |
+
|
| 726 |
+
def get_decision(self, instance_id: str) -> Optional[AdjudicationDecision]:
|
| 727 |
+
"""Get the decision for an item, if one exists."""
|
| 728 |
+
return self.decisions.get(str(instance_id))
|
| 729 |
+
|
| 730 |
+
# ------------------------------------------------------------------
|
| 731 |
+
# Phase 3: Similarity integration
|
| 732 |
+
# ------------------------------------------------------------------
|
| 733 |
+
|
| 734 |
+
def _precompute_similarities(self) -> None:
|
| 735 |
+
"""Precompute embeddings for all items in the item state manager."""
|
| 736 |
+
if not self.similarity_engine or not self.similarity_engine.enabled:
|
| 737 |
+
return
|
| 738 |
+
|
| 739 |
+
from potato.item_state_management import get_item_state_manager
|
| 740 |
+
|
| 741 |
+
try:
|
| 742 |
+
ism = get_item_state_manager()
|
| 743 |
+
item_texts = {}
|
| 744 |
+
for instance_id, item in ism.instance_id_to_instance.items():
|
| 745 |
+
text = self.get_item_text(str(instance_id))
|
| 746 |
+
if text:
|
| 747 |
+
item_texts[str(instance_id)] = text
|
| 748 |
+
|
| 749 |
+
if item_texts:
|
| 750 |
+
count = self.similarity_engine.precompute_embeddings(item_texts)
|
| 751 |
+
self.logger.info(f"Precomputed {count} similarity embeddings")
|
| 752 |
+
except Exception as e:
|
| 753 |
+
self.logger.error(f"Error precomputing similarities: {e}")
|
| 754 |
+
|
| 755 |
+
def get_similar_items(
|
| 756 |
+
self, instance_id: str, include_metadata: bool = True
|
| 757 |
+
) -> List[Dict[str, Any]]:
|
| 758 |
+
"""
|
| 759 |
+
Get similar items for a given instance, enriched with queue metadata.
|
| 760 |
+
|
| 761 |
+
Args:
|
| 762 |
+
instance_id: The reference instance ID
|
| 763 |
+
include_metadata: Whether to include decision/consensus data
|
| 764 |
+
|
| 765 |
+
Returns:
|
| 766 |
+
List of dicts with instance_id, similarity, and optional metadata
|
| 767 |
+
"""
|
| 768 |
+
if not self.similarity_engine or not self.similarity_engine.enabled:
|
| 769 |
+
return []
|
| 770 |
+
|
| 771 |
+
similar = self.similarity_engine.find_similar(instance_id)
|
| 772 |
+
results = []
|
| 773 |
+
|
| 774 |
+
for other_id, score in similar:
|
| 775 |
+
entry = {
|
| 776 |
+
"instance_id": other_id,
|
| 777 |
+
"similarity": round(score, 4),
|
| 778 |
+
"text_preview": self.similarity_engine.text_cache.get(
|
| 779 |
+
other_id, ""
|
| 780 |
+
),
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
if include_metadata:
|
| 784 |
+
queue_item = self.queue.get(other_id)
|
| 785 |
+
decision = self.decisions.get(other_id)
|
| 786 |
+
|
| 787 |
+
entry["in_queue"] = queue_item is not None
|
| 788 |
+
entry["status"] = queue_item.status if queue_item else None
|
| 789 |
+
entry["overall_agreement"] = (
|
| 790 |
+
queue_item.overall_agreement if queue_item else None
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
if decision:
|
| 794 |
+
entry["decision"] = "completed"
|
| 795 |
+
entry["consensus_label"] = None
|
| 796 |
+
else:
|
| 797 |
+
entry["decision"] = None
|
| 798 |
+
if queue_item:
|
| 799 |
+
entry["consensus_label"] = self._get_consensus_label(
|
| 800 |
+
queue_item
|
| 801 |
+
)
|
| 802 |
+
else:
|
| 803 |
+
entry["consensus_label"] = None
|
| 804 |
+
|
| 805 |
+
results.append(entry)
|
| 806 |
+
|
| 807 |
+
return results
|
| 808 |
+
|
| 809 |
+
def _get_consensus_label(self, item: AdjudicationItem) -> Optional[str]:
|
| 810 |
+
"""
|
| 811 |
+
Get the majority/consensus label for an item across the first schema.
|
| 812 |
+
|
| 813 |
+
Args:
|
| 814 |
+
item: The AdjudicationItem
|
| 815 |
+
|
| 816 |
+
Returns:
|
| 817 |
+
The most common label value as a string, or None
|
| 818 |
+
"""
|
| 819 |
+
if not item.annotations:
|
| 820 |
+
return None
|
| 821 |
+
|
| 822 |
+
# Use the first schema that has values
|
| 823 |
+
for user_annots in item.annotations.values():
|
| 824 |
+
for schema_name in user_annots:
|
| 825 |
+
# Collect all values for this schema
|
| 826 |
+
values = []
|
| 827 |
+
for ua in item.annotations.values():
|
| 828 |
+
val = ua.get(schema_name)
|
| 829 |
+
if val is not None:
|
| 830 |
+
if isinstance(val, dict):
|
| 831 |
+
# Multiselect: use frozenset representation
|
| 832 |
+
selected = sorted(
|
| 833 |
+
k for k, v in val.items()
|
| 834 |
+
if v is True or v == "true" or v == 1
|
| 835 |
+
)
|
| 836 |
+
values.append(", ".join(selected) if selected else str(val))
|
| 837 |
+
else:
|
| 838 |
+
values.append(str(val))
|
| 839 |
+
|
| 840 |
+
if values:
|
| 841 |
+
counter = Counter(values)
|
| 842 |
+
return counter.most_common(1)[0][0]
|
| 843 |
+
|
| 844 |
+
return None
|
| 845 |
+
|
| 846 |
+
# ------------------------------------------------------------------
|
| 847 |
+
# Phase 3: Behavioral signal analysis
|
| 848 |
+
# ------------------------------------------------------------------
|
| 849 |
+
|
| 850 |
+
def get_annotator_signals(
|
| 851 |
+
self, user_id: str, instance_id: str
|
| 852 |
+
) -> Dict[str, Any]:
|
| 853 |
+
"""
|
| 854 |
+
Compute per-annotator quality signals for a specific item.
|
| 855 |
+
|
| 856 |
+
Returns:
|
| 857 |
+
Dict with user_id, instance_id, flags list, and metrics dict
|
| 858 |
+
"""
|
| 859 |
+
flags = []
|
| 860 |
+
metrics = {}
|
| 861 |
+
|
| 862 |
+
instance_id = str(instance_id)
|
| 863 |
+
item = self.queue.get(instance_id)
|
| 864 |
+
if not item:
|
| 865 |
+
return {"user_id": user_id, "instance_id": instance_id,
|
| 866 |
+
"flags": [], "metrics": {}}
|
| 867 |
+
|
| 868 |
+
# Get behavioral data for this user on this item
|
| 869 |
+
bd = item.behavioral_data.get(user_id, {})
|
| 870 |
+
if hasattr(bd, 'to_dict'):
|
| 871 |
+
bd = bd.to_dict()
|
| 872 |
+
|
| 873 |
+
total_time = bd.get("total_time_ms", 0)
|
| 874 |
+
metrics["total_time_ms"] = total_time
|
| 875 |
+
|
| 876 |
+
# 1. Speed z-score vs user's typical time
|
| 877 |
+
user_times = self._get_user_times(user_id)
|
| 878 |
+
if len(user_times) >= 3 and total_time > 0:
|
| 879 |
+
mean_time = sum(user_times) / len(user_times)
|
| 880 |
+
std_time = math.sqrt(
|
| 881 |
+
sum((t - mean_time) ** 2 for t in user_times) / len(user_times)
|
| 882 |
+
)
|
| 883 |
+
if std_time > 0:
|
| 884 |
+
z_score = (total_time - mean_time) / std_time
|
| 885 |
+
metrics["speed_z_score"] = round(z_score, 2)
|
| 886 |
+
if z_score < -2.0:
|
| 887 |
+
flags.append({
|
| 888 |
+
"type": "unusually_fast",
|
| 889 |
+
"severity": "high",
|
| 890 |
+
"message": f"Annotation time ({total_time}ms) is {abs(z_score):.1f} std devs below average"
|
| 891 |
+
})
|
| 892 |
+
|
| 893 |
+
# 2. Fast decision warning
|
| 894 |
+
fast_threshold = self.adj_config.fast_decision_warning_ms
|
| 895 |
+
if fast_threshold > 0 and 0 < total_time < fast_threshold:
|
| 896 |
+
flags.append({
|
| 897 |
+
"type": "fast_decision",
|
| 898 |
+
"severity": "medium",
|
| 899 |
+
"message": f"Decision made in {total_time}ms (below {fast_threshold}ms threshold)"
|
| 900 |
+
})
|
| 901 |
+
|
| 902 |
+
# 3. Annotation change count
|
| 903 |
+
raw_changes = bd.get("annotation_changes", [])
|
| 904 |
+
change_count = len(raw_changes) if isinstance(raw_changes, list) else int(raw_changes or 0)
|
| 905 |
+
metrics["annotation_changes"] = change_count
|
| 906 |
+
if change_count > 5:
|
| 907 |
+
flags.append({
|
| 908 |
+
"type": "excessive_changes",
|
| 909 |
+
"severity": "medium",
|
| 910 |
+
"message": f"Made {change_count} annotation changes on this item"
|
| 911 |
+
})
|
| 912 |
+
|
| 913 |
+
# 4. Historical agreement rate with consensus
|
| 914 |
+
agreement_rate = self._compute_user_agreement_rate(user_id)
|
| 915 |
+
if agreement_rate is not None:
|
| 916 |
+
metrics["agreement_rate"] = round(agreement_rate, 3)
|
| 917 |
+
if agreement_rate < 0.4:
|
| 918 |
+
flags.append({
|
| 919 |
+
"type": "low_agreement",
|
| 920 |
+
"severity": "high",
|
| 921 |
+
"message": f"Agreement rate with consensus: {agreement_rate:.0%}"
|
| 922 |
+
})
|
| 923 |
+
|
| 924 |
+
# 5. Similar item consistency
|
| 925 |
+
if self.similarity_engine and self.similarity_engine.enabled:
|
| 926 |
+
inconsistencies = self._check_similar_item_consistency(
|
| 927 |
+
user_id, instance_id
|
| 928 |
+
)
|
| 929 |
+
metrics["similar_item_inconsistencies"] = inconsistencies
|
| 930 |
+
if inconsistencies > 0:
|
| 931 |
+
flags.append({
|
| 932 |
+
"type": "similar_item_inconsistency",
|
| 933 |
+
"severity": "medium",
|
| 934 |
+
"message": f"Different label on {inconsistencies} similar item(s)"
|
| 935 |
+
})
|
| 936 |
+
|
| 937 |
+
return {
|
| 938 |
+
"user_id": user_id,
|
| 939 |
+
"instance_id": instance_id,
|
| 940 |
+
"flags": flags,
|
| 941 |
+
"metrics": metrics,
|
| 942 |
+
}
|
| 943 |
+
|
| 944 |
+
def _get_user_times(self, user_id: str) -> List[float]:
|
| 945 |
+
"""Collect all annotation times for a user across queue items."""
|
| 946 |
+
times = []
|
| 947 |
+
for item in self.queue.values():
|
| 948 |
+
bd = item.behavioral_data.get(user_id, {})
|
| 949 |
+
if hasattr(bd, 'to_dict'):
|
| 950 |
+
bd = bd.to_dict()
|
| 951 |
+
t = bd.get("total_time_ms", 0)
|
| 952 |
+
if t > 0:
|
| 953 |
+
times.append(t)
|
| 954 |
+
return times
|
| 955 |
+
|
| 956 |
+
def _compute_user_agreement_rate(self, user_id: str) -> Optional[float]:
|
| 957 |
+
"""
|
| 958 |
+
Compute how often a user agrees with the consensus across all items.
|
| 959 |
+
|
| 960 |
+
Returns:
|
| 961 |
+
Float 0-1 or None if insufficient data (needs >= 3 items)
|
| 962 |
+
"""
|
| 963 |
+
agree_count = 0
|
| 964 |
+
total_count = 0
|
| 965 |
+
|
| 966 |
+
for item in self.queue.values():
|
| 967 |
+
if user_id not in item.annotations:
|
| 968 |
+
continue
|
| 969 |
+
|
| 970 |
+
consensus = self._get_consensus_label(item)
|
| 971 |
+
if consensus is None:
|
| 972 |
+
continue
|
| 973 |
+
|
| 974 |
+
user_annots = item.annotations[user_id]
|
| 975 |
+
# Check the first schema
|
| 976 |
+
for schema_name, val in user_annots.items():
|
| 977 |
+
if isinstance(val, dict):
|
| 978 |
+
selected = sorted(
|
| 979 |
+
k for k, v in val.items()
|
| 980 |
+
if v is True or v == "true" or v == 1
|
| 981 |
+
)
|
| 982 |
+
user_label = ", ".join(selected) if selected else str(val)
|
| 983 |
+
else:
|
| 984 |
+
user_label = str(val)
|
| 985 |
+
|
| 986 |
+
if user_label == consensus:
|
| 987 |
+
agree_count += 1
|
| 988 |
+
total_count += 1
|
| 989 |
+
break # Only check first schema
|
| 990 |
+
|
| 991 |
+
if total_count < 3:
|
| 992 |
+
return None
|
| 993 |
+
|
| 994 |
+
return agree_count / total_count
|
| 995 |
+
|
| 996 |
+
def _check_similar_item_consistency(
|
| 997 |
+
self, user_id: str, instance_id: str
|
| 998 |
+
) -> int:
|
| 999 |
+
"""
|
| 1000 |
+
Check if user's label on similar items (>0.8 similarity) is consistent.
|
| 1001 |
+
|
| 1002 |
+
Returns:
|
| 1003 |
+
Count of similar items where user's label differs
|
| 1004 |
+
"""
|
| 1005 |
+
if not self.similarity_engine:
|
| 1006 |
+
return 0
|
| 1007 |
+
|
| 1008 |
+
similar = self.similarity_engine.find_similar(instance_id)
|
| 1009 |
+
if not similar:
|
| 1010 |
+
return 0
|
| 1011 |
+
|
| 1012 |
+
# Get user's label on the current item
|
| 1013 |
+
item = self.queue.get(instance_id)
|
| 1014 |
+
if not item or user_id not in item.annotations:
|
| 1015 |
+
return 0
|
| 1016 |
+
|
| 1017 |
+
user_annots = item.annotations[user_id]
|
| 1018 |
+
current_label = None
|
| 1019 |
+
current_schema = None
|
| 1020 |
+
for schema_name, val in user_annots.items():
|
| 1021 |
+
current_schema = schema_name
|
| 1022 |
+
if isinstance(val, dict):
|
| 1023 |
+
selected = sorted(
|
| 1024 |
+
k for k, v in val.items()
|
| 1025 |
+
if v is True or v == "true" or v == 1
|
| 1026 |
+
)
|
| 1027 |
+
current_label = ", ".join(selected) if selected else str(val)
|
| 1028 |
+
else:
|
| 1029 |
+
current_label = str(val)
|
| 1030 |
+
break
|
| 1031 |
+
|
| 1032 |
+
if current_label is None:
|
| 1033 |
+
return 0
|
| 1034 |
+
|
| 1035 |
+
inconsistencies = 0
|
| 1036 |
+
for other_id, score in similar:
|
| 1037 |
+
if score < 0.8:
|
| 1038 |
+
break # Results are sorted by score desc
|
| 1039 |
+
|
| 1040 |
+
other_item = self.queue.get(other_id)
|
| 1041 |
+
if not other_item or user_id not in other_item.annotations:
|
| 1042 |
+
continue
|
| 1043 |
+
|
| 1044 |
+
other_annots = other_item.annotations[user_id]
|
| 1045 |
+
other_val = other_annots.get(current_schema)
|
| 1046 |
+
if other_val is None:
|
| 1047 |
+
continue
|
| 1048 |
+
|
| 1049 |
+
if isinstance(other_val, dict):
|
| 1050 |
+
selected = sorted(
|
| 1051 |
+
k for k, v in other_val.items()
|
| 1052 |
+
if v is True or v == "true" or v == 1
|
| 1053 |
+
)
|
| 1054 |
+
other_label = ", ".join(selected) if selected else str(other_val)
|
| 1055 |
+
else:
|
| 1056 |
+
other_label = str(other_val)
|
| 1057 |
+
|
| 1058 |
+
if other_label != current_label:
|
| 1059 |
+
inconsistencies += 1
|
| 1060 |
+
|
| 1061 |
+
return inconsistencies
|
| 1062 |
+
|
| 1063 |
+
def _get_output_dir(self) -> str:
|
| 1064 |
+
"""Get the adjudication output directory."""
|
| 1065 |
+
output_dir = self.config.get("output_annotation_dir", "annotation_output")
|
| 1066 |
+
adj_dir = os.path.join(output_dir, self.adj_config.output_subdir)
|
| 1067 |
+
os.makedirs(adj_dir, exist_ok=True)
|
| 1068 |
+
return adj_dir
|
| 1069 |
+
|
| 1070 |
+
def _save_decisions(self) -> None:
|
| 1071 |
+
"""Persist all decisions to disk."""
|
| 1072 |
+
try:
|
| 1073 |
+
adj_dir = self._get_output_dir()
|
| 1074 |
+
decisions_file = os.path.join(adj_dir, "decisions.json")
|
| 1075 |
+
|
| 1076 |
+
data = {
|
| 1077 |
+
"decisions": [d.to_dict() for d in self.decisions.values()],
|
| 1078 |
+
"last_updated": datetime.now().isoformat(),
|
| 1079 |
+
}
|
| 1080 |
+
|
| 1081 |
+
with open(decisions_file, "w", encoding="utf-8") as f:
|
| 1082 |
+
json.dump(data, f, indent=2)
|
| 1083 |
+
|
| 1084 |
+
except Exception as e:
|
| 1085 |
+
self.logger.error(f"Failed to save adjudication decisions: {e}")
|
| 1086 |
+
|
| 1087 |
+
def _load_decisions(self) -> None:
|
| 1088 |
+
"""Load previously saved decisions from disk."""
|
| 1089 |
+
try:
|
| 1090 |
+
output_dir = self.config.get("output_annotation_dir", "annotation_output")
|
| 1091 |
+
adj_dir = os.path.join(output_dir, self.adj_config.output_subdir)
|
| 1092 |
+
decisions_file = os.path.join(adj_dir, "decisions.json")
|
| 1093 |
+
|
| 1094 |
+
if not os.path.exists(decisions_file):
|
| 1095 |
+
return
|
| 1096 |
+
|
| 1097 |
+
with open(decisions_file, "r", encoding="utf-8") as f:
|
| 1098 |
+
data = json.load(f)
|
| 1099 |
+
|
| 1100 |
+
for d in data.get("decisions", []):
|
| 1101 |
+
decision = AdjudicationDecision.from_dict(d)
|
| 1102 |
+
self.decisions[decision.instance_id] = decision
|
| 1103 |
+
|
| 1104 |
+
self.logger.info(
|
| 1105 |
+
f"Loaded {len(self.decisions)} previous adjudication decisions"
|
| 1106 |
+
)
|
| 1107 |
+
|
| 1108 |
+
except Exception as e:
|
| 1109 |
+
self.logger.warning(f"Failed to load adjudication decisions: {e}")
|
| 1110 |
+
|
| 1111 |
+
def generate_final_dataset(self) -> List[Dict[str, Any]]:
|
| 1112 |
+
"""
|
| 1113 |
+
Generate the final dataset by merging unanimous agreements
|
| 1114 |
+
and adjudication decisions.
|
| 1115 |
+
|
| 1116 |
+
Returns:
|
| 1117 |
+
List of item dicts with final labels and provenance
|
| 1118 |
+
"""
|
| 1119 |
+
from potato.user_state_management import get_user_state_manager
|
| 1120 |
+
from potato.item_state_management import get_item_state_manager
|
| 1121 |
+
|
| 1122 |
+
usm = get_user_state_manager()
|
| 1123 |
+
ism = get_item_state_manager()
|
| 1124 |
+
|
| 1125 |
+
annotation_schemes = self.config.get("annotation_schemes", [])
|
| 1126 |
+
scheme_names = [s.get("name", "") for s in annotation_schemes]
|
| 1127 |
+
results = []
|
| 1128 |
+
|
| 1129 |
+
for instance_id, item in ism.instance_id_to_instance.items():
|
| 1130 |
+
instance_id_str = str(instance_id)
|
| 1131 |
+
result = {
|
| 1132 |
+
"instance_id": instance_id_str,
|
| 1133 |
+
"item_data": item.get_data() if hasattr(item, 'get_data') else {},
|
| 1134 |
+
}
|
| 1135 |
+
|
| 1136 |
+
# Check if we have an adjudication decision
|
| 1137 |
+
decision = self.decisions.get(instance_id_str)
|
| 1138 |
+
if decision:
|
| 1139 |
+
result["labels"] = decision.label_decisions
|
| 1140 |
+
result["spans"] = decision.span_decisions
|
| 1141 |
+
result["source"] = "adjudicated"
|
| 1142 |
+
result["adjudicator"] = decision.adjudicator_id
|
| 1143 |
+
result["confidence"] = decision.confidence
|
| 1144 |
+
result["provenance"] = decision.source
|
| 1145 |
+
results.append(result)
|
| 1146 |
+
continue
|
| 1147 |
+
|
| 1148 |
+
# Check for unanimous agreement
|
| 1149 |
+
annotators = ism.instance_annotators.get(instance_id, set())
|
| 1150 |
+
annotators = {
|
| 1151 |
+
u for u in annotators
|
| 1152 |
+
if u not in self.adj_config.adjudicator_users
|
| 1153 |
+
}
|
| 1154 |
+
|
| 1155 |
+
if len(annotators) < 2:
|
| 1156 |
+
continue
|
| 1157 |
+
|
| 1158 |
+
# Collect annotations
|
| 1159 |
+
annotations = {}
|
| 1160 |
+
for user_id in annotators:
|
| 1161 |
+
user_state = usm.get_user_state(user_id)
|
| 1162 |
+
if not user_state:
|
| 1163 |
+
continue
|
| 1164 |
+
labels = user_state.instance_id_to_label_to_value.get(
|
| 1165 |
+
instance_id_str, {}
|
| 1166 |
+
)
|
| 1167 |
+
if labels:
|
| 1168 |
+
annotations[user_id] = self._serialize_labels(labels)
|
| 1169 |
+
|
| 1170 |
+
if not annotations:
|
| 1171 |
+
continue
|
| 1172 |
+
|
| 1173 |
+
# Check for unanimity per schema
|
| 1174 |
+
unanimous_labels = {}
|
| 1175 |
+
is_unanimous = True
|
| 1176 |
+
for schema in scheme_names:
|
| 1177 |
+
values = []
|
| 1178 |
+
for user_annots in annotations.values():
|
| 1179 |
+
if schema in user_annots:
|
| 1180 |
+
values.append(json.dumps(user_annots[schema], sort_keys=True))
|
| 1181 |
+
|
| 1182 |
+
if len(values) < 2:
|
| 1183 |
+
continue
|
| 1184 |
+
|
| 1185 |
+
if len(set(values)) == 1:
|
| 1186 |
+
unanimous_labels[schema] = json.loads(values[0])
|
| 1187 |
+
else:
|
| 1188 |
+
is_unanimous = False
|
| 1189 |
+
|
| 1190 |
+
if is_unanimous and unanimous_labels:
|
| 1191 |
+
result["labels"] = unanimous_labels
|
| 1192 |
+
result["source"] = "unanimous"
|
| 1193 |
+
result["num_annotators"] = len(annotators)
|
| 1194 |
+
results.append(result)
|
| 1195 |
+
else:
|
| 1196 |
+
result["labels"] = {}
|
| 1197 |
+
result["source"] = "unresolved"
|
| 1198 |
+
result["num_annotators"] = len(annotators)
|
| 1199 |
+
results.append(result)
|
| 1200 |
+
|
| 1201 |
+
return results
|
| 1202 |
+
|
| 1203 |
+
|
| 1204 |
+
def init_adjudication_manager(config: Dict[str, Any]) -> Optional[AdjudicationManager]:
|
| 1205 |
+
"""Initialize the singleton AdjudicationManager."""
|
| 1206 |
+
global _ADJUDICATION_MANAGER
|
| 1207 |
+
|
| 1208 |
+
with _ADJUDICATION_LOCK:
|
| 1209 |
+
if _ADJUDICATION_MANAGER is None:
|
| 1210 |
+
_ADJUDICATION_MANAGER = AdjudicationManager(config)
|
| 1211 |
+
|
| 1212 |
+
return _ADJUDICATION_MANAGER
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
def get_adjudication_manager() -> Optional[AdjudicationManager]:
|
| 1216 |
+
"""Get the singleton AdjudicationManager instance."""
|
| 1217 |
+
return _ADJUDICATION_MANAGER
|
| 1218 |
+
|
| 1219 |
+
|
| 1220 |
+
def clear_adjudication_manager():
|
| 1221 |
+
"""Clear the singleton (for testing)."""
|
| 1222 |
+
global _ADJUDICATION_MANAGER
|
| 1223 |
+
with _ADJUDICATION_LOCK:
|
| 1224 |
+
_ADJUDICATION_MANAGER = None
|
potato/adjudication_export.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adjudication Export CLI
|
| 3 |
+
|
| 4 |
+
Generate final datasets by merging unanimous agreements and adjudication decisions.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python -m potato.adjudication_export --config config.yaml --output final_dataset.jsonl
|
| 8 |
+
python -m potato.adjudication_export --config config.yaml --output final.csv --format csv
|
| 9 |
+
python -m potato.adjudication_export --config config.yaml --output final.json --format json
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import csv
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
parser = argparse.ArgumentParser(
|
| 24 |
+
description="Export adjudicated dataset from Potato annotation project"
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--config", required=True,
|
| 28 |
+
help="Path to the Potato config YAML file"
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--output", required=True,
|
| 32 |
+
help="Output file path"
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--format", choices=["jsonl", "json", "csv"], default="jsonl",
|
| 36 |
+
help="Output format (default: jsonl)"
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--include-unresolved", action="store_true",
|
| 40 |
+
help="Include items without adjudication or consensus"
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--verbose", "-v", action="store_true",
|
| 44 |
+
help="Verbose output"
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
args = parser.parse_args()
|
| 48 |
+
|
| 49 |
+
if args.verbose:
|
| 50 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 51 |
+
else:
|
| 52 |
+
logging.basicConfig(level=logging.INFO)
|
| 53 |
+
|
| 54 |
+
# Load config
|
| 55 |
+
from potato.server_utils.config_module import init_config, config
|
| 56 |
+
try:
|
| 57 |
+
init_config(args.config)
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"Error loading config: {e}", file=sys.stderr)
|
| 60 |
+
sys.exit(1)
|
| 61 |
+
|
| 62 |
+
# Initialize state managers
|
| 63 |
+
from potato.item_state_management import init_item_state_manager
|
| 64 |
+
from potato.user_state_management import init_user_state_manager
|
| 65 |
+
|
| 66 |
+
init_user_state_manager(config)
|
| 67 |
+
init_item_state_manager(config)
|
| 68 |
+
|
| 69 |
+
# Load data (this loads items and user annotations from disk)
|
| 70 |
+
# We need a minimal load - just items and user states
|
| 71 |
+
from potato.flask_server import load_instance_data, load_user_data
|
| 72 |
+
load_instance_data(config)
|
| 73 |
+
load_user_data(config)
|
| 74 |
+
|
| 75 |
+
# Initialize adjudication manager
|
| 76 |
+
from potato.adjudication import init_adjudication_manager
|
| 77 |
+
adj_mgr = init_adjudication_manager(config)
|
| 78 |
+
|
| 79 |
+
if not adj_mgr or not adj_mgr.adj_config.enabled:
|
| 80 |
+
print("Adjudication is not enabled in this config.", file=sys.stderr)
|
| 81 |
+
sys.exit(1)
|
| 82 |
+
|
| 83 |
+
# Build queue to compute agreements
|
| 84 |
+
adj_mgr.build_queue()
|
| 85 |
+
|
| 86 |
+
# Generate final dataset
|
| 87 |
+
results = adj_mgr.generate_final_dataset()
|
| 88 |
+
|
| 89 |
+
# Filter unresolved if not requested
|
| 90 |
+
if not args.include_unresolved:
|
| 91 |
+
results = [r for r in results if r.get("source") != "unresolved"]
|
| 92 |
+
|
| 93 |
+
# Write output
|
| 94 |
+
output_path = args.output
|
| 95 |
+
fmt = args.format
|
| 96 |
+
|
| 97 |
+
if fmt == "jsonl":
|
| 98 |
+
with open(output_path, "w") as f:
|
| 99 |
+
for item in results:
|
| 100 |
+
f.write(json.dumps(item) + "\n")
|
| 101 |
+
|
| 102 |
+
elif fmt == "json":
|
| 103 |
+
with open(output_path, "w") as f:
|
| 104 |
+
json.dump(results, f, indent=2)
|
| 105 |
+
|
| 106 |
+
elif fmt == "csv":
|
| 107 |
+
if not results:
|
| 108 |
+
print("No results to export.", file=sys.stderr)
|
| 109 |
+
sys.exit(0)
|
| 110 |
+
|
| 111 |
+
# Flatten for CSV
|
| 112 |
+
fieldnames = set()
|
| 113 |
+
flat_results = []
|
| 114 |
+
for item in results:
|
| 115 |
+
flat = {
|
| 116 |
+
"instance_id": item["instance_id"],
|
| 117 |
+
"source": item.get("source", ""),
|
| 118 |
+
}
|
| 119 |
+
# Flatten labels
|
| 120 |
+
labels = item.get("labels", {})
|
| 121 |
+
for schema, value in labels.items():
|
| 122 |
+
if isinstance(value, dict):
|
| 123 |
+
flat[schema] = json.dumps(value)
|
| 124 |
+
else:
|
| 125 |
+
flat[schema] = value
|
| 126 |
+
|
| 127 |
+
# Add provenance fields
|
| 128 |
+
if "adjudicator" in item:
|
| 129 |
+
flat["adjudicator"] = item["adjudicator"]
|
| 130 |
+
if "confidence" in item:
|
| 131 |
+
flat["confidence"] = item["confidence"]
|
| 132 |
+
if "num_annotators" in item:
|
| 133 |
+
flat["num_annotators"] = item["num_annotators"]
|
| 134 |
+
|
| 135 |
+
fieldnames.update(flat.keys())
|
| 136 |
+
flat_results.append(flat)
|
| 137 |
+
|
| 138 |
+
# Sort fieldnames for consistent output
|
| 139 |
+
fieldnames = sorted(fieldnames)
|
| 140 |
+
|
| 141 |
+
with open(output_path, "w", newline="") as f:
|
| 142 |
+
writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
|
| 143 |
+
writer.writeheader()
|
| 144 |
+
writer.writerows(flat_results)
|
| 145 |
+
|
| 146 |
+
# Summary
|
| 147 |
+
total = len(results)
|
| 148 |
+
unanimous = sum(1 for r in results if r.get("source") == "unanimous")
|
| 149 |
+
adjudicated = sum(1 for r in results if r.get("source") == "adjudicated")
|
| 150 |
+
unresolved = sum(1 for r in results if r.get("source") == "unresolved")
|
| 151 |
+
|
| 152 |
+
print(f"\nExport complete: {output_path}")
|
| 153 |
+
print(f" Total items: {total}")
|
| 154 |
+
print(f" Unanimous: {unanimous}")
|
| 155 |
+
print(f" Adjudicated: {adjudicated}")
|
| 156 |
+
if args.include_unresolved:
|
| 157 |
+
print(f" Unresolved: {unresolved}")
|
| 158 |
+
print(f" Format: {fmt}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
main()
|
potato/admin.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
potato/agent_proxy/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Proxy Package
|
| 3 |
+
|
| 4 |
+
Provides agent proxy implementations for live agent interaction during annotation.
|
| 5 |
+
Proxies communicate with AI agent backends (echo, HTTP, OpenAI) and return
|
| 6 |
+
responses to the annotation interface.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
from potato.agent_proxy import AgentProxyFactory
|
| 10 |
+
|
| 11 |
+
proxy = AgentProxyFactory.create(config)
|
| 12 |
+
context = proxy.start_session("Book a flight to Paris")
|
| 13 |
+
response = proxy.send_message("Hello", context)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from .base import AgentMessage, AgentResponse, BaseAgentProxy, AgentProxyFactory
|
| 17 |
+
from .session import (
|
| 18 |
+
AgentSession,
|
| 19 |
+
AgentSessionManager,
|
| 20 |
+
init_agent_session_manager,
|
| 21 |
+
get_agent_session_manager,
|
| 22 |
+
clear_agent_session_manager,
|
| 23 |
+
)
|
| 24 |
+
from .sandbox import SafetySandbox, SandboxViolation
|
| 25 |
+
|
| 26 |
+
# Import proxy implementations to trigger registration
|
| 27 |
+
from . import echo_proxy
|
| 28 |
+
from . import http_proxy
|
| 29 |
+
from . import openai_proxy
|
| 30 |
+
from . import coding_proxy # subprocess_coding + docker_coding
|
| 31 |
+
|
| 32 |
+
__all__ = [
|
| 33 |
+
"AgentMessage",
|
| 34 |
+
"AgentResponse",
|
| 35 |
+
"BaseAgentProxy",
|
| 36 |
+
"AgentProxyFactory",
|
| 37 |
+
"AgentSession",
|
| 38 |
+
"AgentSessionManager",
|
| 39 |
+
"init_agent_session_manager",
|
| 40 |
+
"get_agent_session_manager",
|
| 41 |
+
"clear_agent_session_manager",
|
| 42 |
+
"SafetySandbox",
|
| 43 |
+
"SandboxViolation",
|
| 44 |
+
]
|
potato/agent_proxy/base.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Proxy Base Module
|
| 3 |
+
|
| 4 |
+
Provides the abstract base class and data structures for agent proxies,
|
| 5 |
+
plus a factory registry for creating proxy instances from configuration.
|
| 6 |
+
|
| 7 |
+
Agent proxies allow annotators to interact with AI agents live during
|
| 8 |
+
annotation tasks. Each proxy type (echo, http, openai) handles
|
| 9 |
+
communication with a specific kind of agent backend.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import Dict, Any, Optional, List
|
| 15 |
+
import logging
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class AgentMessage:
|
| 23 |
+
"""A single message in an agent conversation."""
|
| 24 |
+
role: str # "user", "agent", "system", "error"
|
| 25 |
+
content: str
|
| 26 |
+
timestamp: float = field(default_factory=time.time)
|
| 27 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class AgentResponse:
|
| 32 |
+
"""Response from an agent proxy after sending a message."""
|
| 33 |
+
message: AgentMessage
|
| 34 |
+
done: bool = False
|
| 35 |
+
error: Optional[str] = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BaseAgentProxy(ABC):
|
| 39 |
+
"""
|
| 40 |
+
Abstract base class for agent proxies.
|
| 41 |
+
|
| 42 |
+
Subclasses implement communication with specific agent backends
|
| 43 |
+
(echo for testing, HTTP for generic REST APIs, OpenAI for chat completions).
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
proxy_type: str = ""
|
| 47 |
+
|
| 48 |
+
def __init__(self, config: dict):
|
| 49 |
+
self.config = config
|
| 50 |
+
self._initialize()
|
| 51 |
+
|
| 52 |
+
@abstractmethod
|
| 53 |
+
def _initialize(self):
|
| 54 |
+
"""Set up connections, validate config. Called by __init__."""
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
@abstractmethod
|
| 58 |
+
def start_session(self, task_description: str) -> dict:
|
| 59 |
+
"""
|
| 60 |
+
Start a new interaction session.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
task_description: The task the annotator should accomplish with the agent.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Proxy-specific session context dict (stored in AgentSession.proxy_context).
|
| 67 |
+
"""
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
@abstractmethod
|
| 71 |
+
def send_message(self, message: str, session_context: dict) -> AgentResponse:
|
| 72 |
+
"""
|
| 73 |
+
Send a message to the agent and get a blocking response.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
message: The user's message text.
|
| 77 |
+
session_context: The proxy-specific context from start_session.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
AgentResponse with the agent's reply.
|
| 81 |
+
"""
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
def end_session(self, session_context: dict):
|
| 85 |
+
"""
|
| 86 |
+
Clean up session resources. Override if needed.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
session_context: The proxy-specific context from start_session.
|
| 90 |
+
"""
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class AgentProxyFactory:
|
| 95 |
+
"""Factory registry for creating agent proxy instances."""
|
| 96 |
+
|
| 97 |
+
_proxies: Dict[str, type] = {}
|
| 98 |
+
|
| 99 |
+
@classmethod
|
| 100 |
+
def register(cls, proxy_type: str, proxy_class: type):
|
| 101 |
+
"""Register a proxy type."""
|
| 102 |
+
cls._proxies[proxy_type] = proxy_class
|
| 103 |
+
logger.debug(f"Registered agent proxy type: {proxy_type}")
|
| 104 |
+
|
| 105 |
+
@classmethod
|
| 106 |
+
def create(cls, config: dict) -> BaseAgentProxy:
|
| 107 |
+
"""
|
| 108 |
+
Create an agent proxy from configuration.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
config: The full config dict. Reads from config["agent_proxy"].
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Configured BaseAgentProxy instance.
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
ValueError: If proxy type is unknown or missing.
|
| 118 |
+
"""
|
| 119 |
+
agent_config = config.get("agent_proxy", {})
|
| 120 |
+
proxy_type = agent_config.get("type")
|
| 121 |
+
|
| 122 |
+
if not proxy_type:
|
| 123 |
+
raise ValueError("agent_proxy.type is required")
|
| 124 |
+
|
| 125 |
+
if proxy_type not in cls._proxies:
|
| 126 |
+
supported = ", ".join(sorted(cls._proxies.keys()))
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Unknown agent proxy type: '{proxy_type}'. "
|
| 129 |
+
f"Supported types: {supported}"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
proxy_class = cls._proxies[proxy_type]
|
| 133 |
+
return proxy_class(agent_config)
|
| 134 |
+
|
| 135 |
+
@classmethod
|
| 136 |
+
def get_supported_types(cls) -> List[str]:
|
| 137 |
+
"""Get list of registered proxy type names."""
|
| 138 |
+
return sorted(cls._proxies.keys())
|
potato/agent_proxy/coding_proxy.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Coding-agent proxies — LLM plans + sandboxed code execution.
|
| 3 |
+
|
| 4 |
+
Two implementations behind a shared base class:
|
| 5 |
+
|
| 6 |
+
- :class:`SubprocessCodingAgentProxy` (default, ``type: subprocess_coding``)
|
| 7 |
+
runs each Python or shell action in a per-session temp workspace via
|
| 8 |
+
``subprocess.run`` with a per-step timeout and an output cap. Light
|
| 9 |
+
isolation — suitable for trusted-input research workflows. The
|
| 10 |
+
workspace is sandboxed (separate cwd) but **not** a security boundary;
|
| 11 |
+
malicious code can still touch the host filesystem outside ``cwd``.
|
| 12 |
+
|
| 13 |
+
- :class:`DockerCodingAgentProxy` (``type: docker_coding``) runs each
|
| 14 |
+
action inside an ephemeral Docker container with ``--network=none``,
|
| 15 |
+
``--memory``, ``--cpus``, ``--read-only`` and a writable workspace
|
| 16 |
+
bind-mounted at ``/work``. Real isolation — survives untrusted code.
|
| 17 |
+
Requires the ``docker`` Python package and a running Docker daemon.
|
| 18 |
+
|
| 19 |
+
Both inherit per-step / per-session / rate-limit enforcement from the
|
| 20 |
+
existing :mod:`potato.agent_proxy.sandbox` framework via the standard
|
| 21 |
+
``send_message`` flow in ``routes.py:agent_chat_send``.
|
| 22 |
+
|
| 23 |
+
Configuration shape (both proxies):
|
| 24 |
+
|
| 25 |
+
agent_proxy:
|
| 26 |
+
type: subprocess_coding | docker_coding
|
| 27 |
+
llm:
|
| 28 |
+
endpoint_type: ollama
|
| 29 |
+
model: llama3.2:3b
|
| 30 |
+
base_url: http://localhost:11434
|
| 31 |
+
temperature: 0.2
|
| 32 |
+
max_tokens: 800
|
| 33 |
+
execution:
|
| 34 |
+
per_step_timeout: 8 # seconds
|
| 35 |
+
max_output_chars: 4000
|
| 36 |
+
starter_files: {} # {filename: contents} written into workspace
|
| 37 |
+
docker: # only for docker_coding
|
| 38 |
+
image: python:3.11-slim
|
| 39 |
+
memory: 512m
|
| 40 |
+
cpus: 1.0
|
| 41 |
+
network: none # "none" or "bridge"
|
| 42 |
+
sandbox: { max_steps: 20, ... } # standard agent-proxy sandbox knobs
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
from __future__ import annotations
|
| 46 |
+
|
| 47 |
+
import json
|
| 48 |
+
import logging
|
| 49 |
+
import os
|
| 50 |
+
import re
|
| 51 |
+
import shutil
|
| 52 |
+
import subprocess
|
| 53 |
+
import tempfile
|
| 54 |
+
from abc import abstractmethod
|
| 55 |
+
from dataclasses import dataclass
|
| 56 |
+
from typing import Any, Dict, List, Optional
|
| 57 |
+
|
| 58 |
+
from .base import AgentMessage, AgentProxyFactory, AgentResponse, BaseAgentProxy
|
| 59 |
+
|
| 60 |
+
logger = logging.getLogger(__name__)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
_PLANNER_SYSTEM_PROMPT = (
|
| 64 |
+
"You are an autonomous coding agent. The user will give you a coding task. "
|
| 65 |
+
"Each turn, decide a SINGLE next action and respond with ONLY a JSON object "
|
| 66 |
+
"of the form: {\"thought\": str, \"action\": {\"type\": str, \"code\": str}}. "
|
| 67 |
+
"The 'type' must be one of:\n"
|
| 68 |
+
" - \"python\": run the contents of 'code' as a python script in the workspace\n"
|
| 69 |
+
" - \"shell\": run 'code' as a bash command in the workspace\n"
|
| 70 |
+
" - \"finish\": stop and return your final answer in 'code' (which is then shown "
|
| 71 |
+
"to the user as your conclusion -- no execution happens)\n"
|
| 72 |
+
"Keep each action small and focused. Use 'finish' as soon as the task is done."
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
|
| 77 |
+
class _ExecResult:
|
| 78 |
+
stdout: str
|
| 79 |
+
stderr: str
|
| 80 |
+
exit_code: Optional[int]
|
| 81 |
+
timed_out: bool = False
|
| 82 |
+
error: Optional[str] = None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class CodingAgentProxy(BaseAgentProxy):
|
| 86 |
+
"""Shared planner/executor scaffold for coding agents.
|
| 87 |
+
|
| 88 |
+
Subclasses implement :meth:`_execute` to run an action in their
|
| 89 |
+
chosen sandbox. Each call to :meth:`send_message`:
|
| 90 |
+
|
| 91 |
+
1. Appends the user message to the running history.
|
| 92 |
+
2. Asks an LLM (configurable endpoint) for a JSON ``{thought, action}``.
|
| 93 |
+
3. Hands ``action`` to the subclass for execution.
|
| 94 |
+
4. Returns a single reply combining thought + tool output.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def _initialize(self):
|
| 98 |
+
llm_cfg = self.config.get("llm") or {}
|
| 99 |
+
self.llm_endpoint_type = llm_cfg.get("endpoint_type", "ollama")
|
| 100 |
+
self.llm_model = llm_cfg.get("model")
|
| 101 |
+
self.llm_base_url = llm_cfg.get("base_url")
|
| 102 |
+
self.llm_temperature = llm_cfg.get("temperature", 0.2)
|
| 103 |
+
self.llm_max_tokens = llm_cfg.get("max_tokens", 800)
|
| 104 |
+
# OpenAI-compatible servers (vLLM etc.) ignore the key but the SDK
|
| 105 |
+
# requires a non-empty string. Ollama needs none. Without forwarding
|
| 106 |
+
# this the planner silently failed with "planner_unavailable".
|
| 107 |
+
self.llm_api_key = llm_cfg.get("api_key")
|
| 108 |
+
# Last endpoint init / call error, surfaced to the user instead of
|
| 109 |
+
# an opaque "planner unavailable" message.
|
| 110 |
+
self._llm_error: Optional[str] = None
|
| 111 |
+
|
| 112 |
+
execution_cfg = self.config.get("execution") or {}
|
| 113 |
+
self.per_step_timeout = execution_cfg.get("per_step_timeout", 8)
|
| 114 |
+
self.max_output_chars = execution_cfg.get("max_output_chars", 4000)
|
| 115 |
+
self.starter_files: Dict[str, str] = execution_cfg.get("starter_files", {}) or {}
|
| 116 |
+
|
| 117 |
+
self._llm = None # lazy
|
| 118 |
+
|
| 119 |
+
# ------------------------------------------------------------------
|
| 120 |
+
# LLM lazy-init
|
| 121 |
+
# ------------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
def _get_llm(self):
|
| 124 |
+
if self._llm is not None:
|
| 125 |
+
return self._llm
|
| 126 |
+
try:
|
| 127 |
+
from potato.ai.ai_endpoint import AIEndpointFactory
|
| 128 |
+
|
| 129 |
+
ai_cfg: Dict[str, Any] = {
|
| 130 |
+
"model": self.llm_model,
|
| 131 |
+
"max_tokens": self.llm_max_tokens,
|
| 132 |
+
"temperature": self.llm_temperature,
|
| 133 |
+
}
|
| 134 |
+
if self.llm_base_url:
|
| 135 |
+
ai_cfg["base_url"] = self.llm_base_url
|
| 136 |
+
# Forward the key for OpenAI-compatible endpoints; vLLM ignores
|
| 137 |
+
# its value but the OpenAI SDK rejects an empty one. Fall back to
|
| 138 |
+
# env then a non-empty placeholder so local servers just work.
|
| 139 |
+
ai_cfg["api_key"] = (
|
| 140 |
+
self.llm_api_key
|
| 141 |
+
or os.environ.get("OPENAI_API_KEY")
|
| 142 |
+
or os.environ.get("ANTHROPIC_API_KEY")
|
| 143 |
+
or "EMPTY"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
self._llm = AIEndpointFactory.create_endpoint({
|
| 147 |
+
"ai_support": {
|
| 148 |
+
"enabled": True,
|
| 149 |
+
"endpoint_type": self.llm_endpoint_type,
|
| 150 |
+
"ai_config": ai_cfg,
|
| 151 |
+
}
|
| 152 |
+
})
|
| 153 |
+
self._llm_error = None
|
| 154 |
+
except Exception as e:
|
| 155 |
+
logger.warning("CodingAgentProxy: planner LLM init failed: %s", e)
|
| 156 |
+
self._llm_error = f"{type(e).__name__}: {e}"
|
| 157 |
+
self._llm = None
|
| 158 |
+
return self._llm
|
| 159 |
+
|
| 160 |
+
# ------------------------------------------------------------------
|
| 161 |
+
# Lifecycle
|
| 162 |
+
# ------------------------------------------------------------------
|
| 163 |
+
|
| 164 |
+
def start_session(self, task_description: str) -> dict:
|
| 165 |
+
workspace = tempfile.mkdtemp(prefix="potato_coding_agent_")
|
| 166 |
+
for filename, contents in self.starter_files.items():
|
| 167 |
+
target = os.path.join(workspace, filename)
|
| 168 |
+
os.makedirs(os.path.dirname(target) or workspace, exist_ok=True)
|
| 169 |
+
with open(target, "w") as f:
|
| 170 |
+
f.write(contents)
|
| 171 |
+
history = [
|
| 172 |
+
{"role": "system", "content": _PLANNER_SYSTEM_PROMPT},
|
| 173 |
+
{
|
| 174 |
+
"role": "system",
|
| 175 |
+
"content": f"Workspace: {workspace}\nTask: {task_description}",
|
| 176 |
+
},
|
| 177 |
+
]
|
| 178 |
+
return {
|
| 179 |
+
"workspace": workspace,
|
| 180 |
+
"history": history,
|
| 181 |
+
"step": 0,
|
| 182 |
+
"finished": False,
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
def end_session(self, session_context: dict):
|
| 186 |
+
workspace = session_context.get("workspace") if session_context else None
|
| 187 |
+
if workspace and os.path.isdir(workspace):
|
| 188 |
+
shutil.rmtree(workspace, ignore_errors=True)
|
| 189 |
+
|
| 190 |
+
# ------------------------------------------------------------------
|
| 191 |
+
# Per-turn flow
|
| 192 |
+
# ------------------------------------------------------------------
|
| 193 |
+
|
| 194 |
+
def send_message(self, message: str, session_context: dict) -> AgentResponse:
|
| 195 |
+
if session_context.get("finished"):
|
| 196 |
+
return AgentResponse(
|
| 197 |
+
message=AgentMessage(role="agent", content="(session already finished)"),
|
| 198 |
+
done=True,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
history: List[Dict[str, str]] = session_context.setdefault("history", [])
|
| 202 |
+
history.append({"role": "user", "content": message})
|
| 203 |
+
session_context["step"] = session_context.get("step", 0) + 1
|
| 204 |
+
|
| 205 |
+
plan = self._plan_next_action(history)
|
| 206 |
+
if plan is None:
|
| 207 |
+
detail = self._llm_error or "no response from planner LLM"
|
| 208 |
+
reply = (
|
| 209 |
+
f"Planner LLM unavailable ({self.llm_endpoint_type}): "
|
| 210 |
+
f"{detail}"
|
| 211 |
+
)
|
| 212 |
+
history.append({"role": "assistant", "content": reply})
|
| 213 |
+
session_context["finished"] = True
|
| 214 |
+
return AgentResponse(
|
| 215 |
+
message=AgentMessage(role="error", content=reply), error="planner_unavailable",
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
thought = plan.get("thought", "")
|
| 219 |
+
action = plan.get("action") or {}
|
| 220 |
+
atype = (action.get("type") or "").strip().lower()
|
| 221 |
+
code = action.get("code") or ""
|
| 222 |
+
|
| 223 |
+
if atype == "finish":
|
| 224 |
+
reply = self._format_finish_reply(thought, code)
|
| 225 |
+
history.append({"role": "assistant", "content": reply})
|
| 226 |
+
session_context["finished"] = True
|
| 227 |
+
return AgentResponse(
|
| 228 |
+
message=AgentMessage(role="agent", content=reply), done=True,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
if atype not in ("python", "shell"):
|
| 232 |
+
reply = (
|
| 233 |
+
f"{thought}\n\n[invalid action type {atype!r}; try 'python', "
|
| 234 |
+
"'shell', or 'finish']"
|
| 235 |
+
)
|
| 236 |
+
history.append({"role": "assistant", "content": reply})
|
| 237 |
+
return AgentResponse(message=AgentMessage(role="agent", content=reply))
|
| 238 |
+
|
| 239 |
+
result = self._execute(atype, code, session_context)
|
| 240 |
+
reply = self._format_exec_reply(thought, atype, code, result)
|
| 241 |
+
history.append({"role": "assistant", "content": reply})
|
| 242 |
+
return AgentResponse(message=AgentMessage(role="agent", content=reply))
|
| 243 |
+
|
| 244 |
+
# ------------------------------------------------------------------
|
| 245 |
+
# Planning
|
| 246 |
+
# ------------------------------------------------------------------
|
| 247 |
+
|
| 248 |
+
def _plan_next_action(self, history: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
|
| 249 |
+
endpoint = self._get_llm()
|
| 250 |
+
if endpoint is None:
|
| 251 |
+
return None
|
| 252 |
+
try:
|
| 253 |
+
if hasattr(endpoint, "chat_query"):
|
| 254 |
+
raw = endpoint.chat_query(history)
|
| 255 |
+
else:
|
| 256 |
+
flat = "\n".join(f'{m["role"]}: {m["content"]}' for m in history)
|
| 257 |
+
raw = endpoint.query(flat + "\nassistant:", None)
|
| 258 |
+
except Exception as e:
|
| 259 |
+
logger.warning("Planner LLM call failed: %s", e)
|
| 260 |
+
self._llm_error = f"{type(e).__name__}: {e}"
|
| 261 |
+
return None
|
| 262 |
+
|
| 263 |
+
if isinstance(raw, dict):
|
| 264 |
+
return raw # already parsed JSON
|
| 265 |
+
text = str(raw or "").strip()
|
| 266 |
+
if not text:
|
| 267 |
+
return None
|
| 268 |
+
try:
|
| 269 |
+
return json.loads(text)
|
| 270 |
+
except json.JSONDecodeError:
|
| 271 |
+
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
|
| 272 |
+
if match:
|
| 273 |
+
try:
|
| 274 |
+
return json.loads(match.group(0))
|
| 275 |
+
except json.JSONDecodeError:
|
| 276 |
+
pass
|
| 277 |
+
# Last-ditch: treat the whole text as a finish reply.
|
| 278 |
+
return {"thought": "", "action": {"type": "finish", "code": text[:500]}}
|
| 279 |
+
|
| 280 |
+
# ------------------------------------------------------------------
|
| 281 |
+
# Reply formatting
|
| 282 |
+
# ------------------------------------------------------------------
|
| 283 |
+
|
| 284 |
+
def _format_exec_reply(
|
| 285 |
+
self, thought: str, atype: str, code: str, result: _ExecResult
|
| 286 |
+
) -> str:
|
| 287 |
+
parts: List[str] = []
|
| 288 |
+
if thought:
|
| 289 |
+
parts.append(thought.strip())
|
| 290 |
+
parts.append(f"```{atype}\n{code.strip()}\n```")
|
| 291 |
+
out_block: List[str] = []
|
| 292 |
+
if result.timed_out:
|
| 293 |
+
out_block.append(f"[timeout after {self.per_step_timeout}s]")
|
| 294 |
+
if result.error:
|
| 295 |
+
out_block.append(f"[error: {result.error}]")
|
| 296 |
+
if result.exit_code is not None:
|
| 297 |
+
out_block.append(f"[exit={result.exit_code}]")
|
| 298 |
+
if result.stdout:
|
| 299 |
+
out_block.append("stdout:\n" + self._truncate(result.stdout))
|
| 300 |
+
if result.stderr:
|
| 301 |
+
out_block.append("stderr:\n" + self._truncate(result.stderr))
|
| 302 |
+
if not out_block:
|
| 303 |
+
out_block.append("(no output)")
|
| 304 |
+
parts.append("\n".join(out_block))
|
| 305 |
+
return "\n\n".join(parts)
|
| 306 |
+
|
| 307 |
+
def _format_finish_reply(self, thought: str, final_text: str) -> str:
|
| 308 |
+
parts = []
|
| 309 |
+
if thought:
|
| 310 |
+
parts.append(thought.strip())
|
| 311 |
+
if final_text:
|
| 312 |
+
parts.append(final_text.strip())
|
| 313 |
+
if not parts:
|
| 314 |
+
parts.append("(done)")
|
| 315 |
+
return "\n\n".join(parts)
|
| 316 |
+
|
| 317 |
+
def _truncate(self, text: str) -> str:
|
| 318 |
+
if len(text) <= self.max_output_chars:
|
| 319 |
+
return text
|
| 320 |
+
cut = self.max_output_chars
|
| 321 |
+
return text[:cut] + f"\n[...truncated {len(text) - cut} chars]"
|
| 322 |
+
|
| 323 |
+
# ------------------------------------------------------------------
|
| 324 |
+
# Sandbox-specific execution
|
| 325 |
+
# ------------------------------------------------------------------
|
| 326 |
+
|
| 327 |
+
@abstractmethod
|
| 328 |
+
def _execute(
|
| 329 |
+
self, action_type: str, code: str, session_context: dict
|
| 330 |
+
) -> _ExecResult:
|
| 331 |
+
"""Execute ``code`` (one of 'python' / 'shell') in the sandbox."""
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class SubprocessCodingAgentProxy(CodingAgentProxy):
|
| 335 |
+
"""Local subprocess-based execution.
|
| 336 |
+
|
| 337 |
+
NOT a security boundary -- the per-step timeout + tempdir cwd is the
|
| 338 |
+
only protection. Use ``DockerCodingAgentProxy`` for untrusted input.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
proxy_type = "subprocess_coding"
|
| 342 |
+
|
| 343 |
+
def _execute(
|
| 344 |
+
self, action_type: str, code: str, session_context: dict
|
| 345 |
+
) -> _ExecResult:
|
| 346 |
+
workspace = session_context["workspace"]
|
| 347 |
+
env = self._build_env()
|
| 348 |
+
try:
|
| 349 |
+
if action_type == "python":
|
| 350 |
+
script_path = os.path.join(workspace, "_action.py")
|
| 351 |
+
with open(script_path, "w") as f:
|
| 352 |
+
f.write(code)
|
| 353 |
+
proc = subprocess.run(
|
| 354 |
+
["python", script_path],
|
| 355 |
+
cwd=workspace,
|
| 356 |
+
env=env,
|
| 357 |
+
capture_output=True,
|
| 358 |
+
text=True,
|
| 359 |
+
timeout=self.per_step_timeout,
|
| 360 |
+
)
|
| 361 |
+
else: # shell
|
| 362 |
+
proc = subprocess.run(
|
| 363 |
+
["bash", "-c", code],
|
| 364 |
+
cwd=workspace,
|
| 365 |
+
env=env,
|
| 366 |
+
capture_output=True,
|
| 367 |
+
text=True,
|
| 368 |
+
timeout=self.per_step_timeout,
|
| 369 |
+
)
|
| 370 |
+
except subprocess.TimeoutExpired as e:
|
| 371 |
+
return _ExecResult(
|
| 372 |
+
stdout=(e.stdout or "") if isinstance(e.stdout, str) else "",
|
| 373 |
+
stderr=(e.stderr or "") if isinstance(e.stderr, str) else "",
|
| 374 |
+
exit_code=None,
|
| 375 |
+
timed_out=True,
|
| 376 |
+
)
|
| 377 |
+
except FileNotFoundError as e:
|
| 378 |
+
return _ExecResult(stdout="", stderr="", exit_code=None, error=str(e))
|
| 379 |
+
return _ExecResult(
|
| 380 |
+
stdout=proc.stdout or "",
|
| 381 |
+
stderr=proc.stderr or "",
|
| 382 |
+
exit_code=proc.returncode,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
def _build_env(self) -> Dict[str, str]:
|
| 386 |
+
# Strip env down to a minimal set so subprocess code can't
|
| 387 |
+
# accidentally exfiltrate the host's secrets via env vars.
|
| 388 |
+
keep = {"PATH", "HOME", "LANG", "LC_ALL"}
|
| 389 |
+
return {k: v for k, v in os.environ.items() if k in keep}
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class DockerCodingAgentProxy(CodingAgentProxy):
|
| 393 |
+
"""Ephemeral-container execution. Real isolation; requires Docker."""
|
| 394 |
+
|
| 395 |
+
proxy_type = "docker_coding"
|
| 396 |
+
|
| 397 |
+
def _initialize(self):
|
| 398 |
+
super()._initialize()
|
| 399 |
+
docker_cfg = self.config.get("docker") or {}
|
| 400 |
+
self.docker_image = docker_cfg.get("image", "python:3.11-slim")
|
| 401 |
+
self.docker_memory = docker_cfg.get("memory", "512m")
|
| 402 |
+
self.docker_cpus = str(docker_cfg.get("cpus", 1.0))
|
| 403 |
+
self.docker_network = docker_cfg.get("network", "none")
|
| 404 |
+
self._docker = None # lazy
|
| 405 |
+
# Sanity check: warn if docker CLI isn't on PATH
|
| 406 |
+
if shutil.which("docker") is None:
|
| 407 |
+
logger.warning(
|
| 408 |
+
"DockerCodingAgentProxy: 'docker' CLI not found on PATH. "
|
| 409 |
+
"Container execution will fail at runtime."
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def _execute(
|
| 413 |
+
self, action_type: str, code: str, session_context: dict
|
| 414 |
+
) -> _ExecResult:
|
| 415 |
+
workspace = session_context["workspace"]
|
| 416 |
+
# Materialise the code as a file in the workspace so the container
|
| 417 |
+
# can run it without inline injection through `-c`.
|
| 418 |
+
if action_type == "python":
|
| 419 |
+
target = os.path.join(workspace, "_action.py")
|
| 420 |
+
with open(target, "w") as f:
|
| 421 |
+
f.write(code)
|
| 422 |
+
container_cmd = ["python", "/work/_action.py"]
|
| 423 |
+
else: # shell
|
| 424 |
+
target = os.path.join(workspace, "_action.sh")
|
| 425 |
+
with open(target, "w") as f:
|
| 426 |
+
f.write(code)
|
| 427 |
+
os.chmod(target, 0o755)
|
| 428 |
+
container_cmd = ["bash", "/work/_action.sh"]
|
| 429 |
+
|
| 430 |
+
cmd = [
|
| 431 |
+
"docker", "run", "--rm",
|
| 432 |
+
f"--network={self.docker_network}",
|
| 433 |
+
f"--memory={self.docker_memory}",
|
| 434 |
+
f"--cpus={self.docker_cpus}",
|
| 435 |
+
"--read-only",
|
| 436 |
+
"--tmpfs", "/tmp:exec,size=64m",
|
| 437 |
+
"-v", f"{workspace}:/work",
|
| 438 |
+
"-w", "/work",
|
| 439 |
+
self.docker_image,
|
| 440 |
+
] + container_cmd
|
| 441 |
+
try:
|
| 442 |
+
proc = subprocess.run(
|
| 443 |
+
cmd,
|
| 444 |
+
capture_output=True,
|
| 445 |
+
text=True,
|
| 446 |
+
timeout=self.per_step_timeout + 5, # docker pull/start overhead
|
| 447 |
+
)
|
| 448 |
+
except subprocess.TimeoutExpired as e:
|
| 449 |
+
return _ExecResult(
|
| 450 |
+
stdout=(e.stdout or "") if isinstance(e.stdout, str) else "",
|
| 451 |
+
stderr=(e.stderr or "") if isinstance(e.stderr, str) else "",
|
| 452 |
+
exit_code=None,
|
| 453 |
+
timed_out=True,
|
| 454 |
+
)
|
| 455 |
+
except FileNotFoundError as e:
|
| 456 |
+
return _ExecResult(stdout="", stderr="", exit_code=None, error=str(e))
|
| 457 |
+
return _ExecResult(
|
| 458 |
+
stdout=proc.stdout or "",
|
| 459 |
+
stderr=proc.stderr or "",
|
| 460 |
+
exit_code=proc.returncode,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# Register both with the factory so configs can refer to them by name.
|
| 465 |
+
AgentProxyFactory.register("subprocess_coding", SubprocessCodingAgentProxy)
|
| 466 |
+
AgentProxyFactory.register("docker_coding", DockerCodingAgentProxy)
|
potato/agent_proxy/echo_proxy.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Echo Agent Proxy
|
| 3 |
+
|
| 4 |
+
A testing/demo proxy that returns responses from a configurable list.
|
| 5 |
+
Cycles through responses in order, wrapping around when exhausted.
|
| 6 |
+
|
| 7 |
+
Configuration:
|
| 8 |
+
agent_proxy:
|
| 9 |
+
type: echo
|
| 10 |
+
responses:
|
| 11 |
+
- "I understand your request."
|
| 12 |
+
- "Working on it now."
|
| 13 |
+
- "Here's what I found."
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
from .base import BaseAgentProxy, AgentMessage, AgentResponse, AgentProxyFactory
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class EchoProxy(BaseAgentProxy):
|
| 24 |
+
"""Test proxy that returns canned responses in order."""
|
| 25 |
+
|
| 26 |
+
proxy_type = "echo"
|
| 27 |
+
|
| 28 |
+
def _initialize(self):
|
| 29 |
+
self.responses = self.config.get("responses", [
|
| 30 |
+
"I understand.",
|
| 31 |
+
"Working on it.",
|
| 32 |
+
"Done!",
|
| 33 |
+
])
|
| 34 |
+
|
| 35 |
+
def start_session(self, task_description: str) -> dict:
|
| 36 |
+
return {"response_index": 0, "task_description": task_description}
|
| 37 |
+
|
| 38 |
+
def send_message(self, message: str, session_context: dict) -> AgentResponse:
|
| 39 |
+
idx = session_context.get("response_index", 0)
|
| 40 |
+
response_text = self.responses[idx % len(self.responses)]
|
| 41 |
+
session_context["response_index"] = idx + 1
|
| 42 |
+
|
| 43 |
+
return AgentResponse(
|
| 44 |
+
message=AgentMessage(
|
| 45 |
+
role="agent",
|
| 46 |
+
content=response_text,
|
| 47 |
+
)
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def end_session(self, session_context: dict):
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Register with factory
|
| 55 |
+
AgentProxyFactory.register("echo", EchoProxy)
|
potato/agent_proxy/http_proxy.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generic HTTP Agent Proxy
|
| 3 |
+
|
| 4 |
+
POSTs to any REST endpoint with configurable field mapping.
|
| 5 |
+
Supports sending full conversation history and custom headers.
|
| 6 |
+
|
| 7 |
+
Configuration:
|
| 8 |
+
agent_proxy:
|
| 9 |
+
type: http
|
| 10 |
+
url: "http://localhost:8080/chat"
|
| 11 |
+
headers:
|
| 12 |
+
Authorization: "Bearer YOUR_KEY"
|
| 13 |
+
message_key: "message" # key in request body for user message
|
| 14 |
+
response_key: "response" # key in response JSON for agent reply
|
| 15 |
+
session_id_key: "session_id" # key in request/response for session tracking
|
| 16 |
+
send_history: false # whether to send full conversation history
|
| 17 |
+
history_key: "messages" # key for history array in request body
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import uuid
|
| 22 |
+
|
| 23 |
+
import requests
|
| 24 |
+
|
| 25 |
+
from .base import BaseAgentProxy, AgentMessage, AgentResponse, AgentProxyFactory
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class GenericHTTPProxy(BaseAgentProxy):
|
| 31 |
+
"""Generic REST API proxy with configurable field mapping."""
|
| 32 |
+
|
| 33 |
+
proxy_type = "http"
|
| 34 |
+
|
| 35 |
+
def _initialize(self):
|
| 36 |
+
self.url = self.config.get("url")
|
| 37 |
+
if not self.url:
|
| 38 |
+
raise ValueError("http proxy requires 'url' in agent_proxy config")
|
| 39 |
+
|
| 40 |
+
self.headers = self.config.get("headers", {})
|
| 41 |
+
self.message_key = self.config.get("message_key", "message")
|
| 42 |
+
self.response_key = self.config.get("response_key", "response")
|
| 43 |
+
self.session_id_key = self.config.get("session_id_key", "session_id")
|
| 44 |
+
self.send_history = self.config.get("send_history", False)
|
| 45 |
+
self.history_key = self.config.get("history_key", "messages")
|
| 46 |
+
self.timeout = self.config.get("sandbox", {}).get(
|
| 47 |
+
"request_timeout_seconds", 60
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def start_session(self, task_description: str) -> dict:
|
| 51 |
+
return {
|
| 52 |
+
"session_id": str(uuid.uuid4()),
|
| 53 |
+
"task_description": task_description,
|
| 54 |
+
"history": [],
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
def send_message(self, message: str, session_context: dict) -> AgentResponse:
|
| 58 |
+
payload = {
|
| 59 |
+
self.message_key: message,
|
| 60 |
+
self.session_id_key: session_context["session_id"],
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
if self.send_history:
|
| 64 |
+
payload[self.history_key] = session_context.get("history", [])
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
resp = requests.post(
|
| 68 |
+
self.url,
|
| 69 |
+
json=payload,
|
| 70 |
+
headers=self.headers,
|
| 71 |
+
timeout=self.timeout,
|
| 72 |
+
)
|
| 73 |
+
resp.raise_for_status()
|
| 74 |
+
data = resp.json()
|
| 75 |
+
|
| 76 |
+
response_text = data.get(self.response_key, "")
|
| 77 |
+
if not response_text and isinstance(data, str):
|
| 78 |
+
response_text = data
|
| 79 |
+
|
| 80 |
+
# Update history
|
| 81 |
+
session_context.setdefault("history", []).append(
|
| 82 |
+
{"role": "user", "content": message}
|
| 83 |
+
)
|
| 84 |
+
session_context["history"].append(
|
| 85 |
+
{"role": "agent", "content": response_text}
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return AgentResponse(
|
| 89 |
+
message=AgentMessage(role="agent", content=str(response_text))
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
except requests.Timeout:
|
| 93 |
+
return AgentResponse(
|
| 94 |
+
message=AgentMessage(role="error", content="Agent request timed out."),
|
| 95 |
+
error="timeout",
|
| 96 |
+
)
|
| 97 |
+
except requests.RequestException as e:
|
| 98 |
+
logger.error(f"HTTP proxy request failed: {e}")
|
| 99 |
+
return AgentResponse(
|
| 100 |
+
message=AgentMessage(
|
| 101 |
+
role="error", content=f"Agent communication error: {e}"
|
| 102 |
+
),
|
| 103 |
+
error=str(e),
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Register with factory
|
| 108 |
+
AgentProxyFactory.register("http", GenericHTTPProxy)
|
potato/agent_proxy/openai_proxy.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI Chat Completions Agent Proxy
|
| 3 |
+
|
| 4 |
+
Uses the OpenAI SDK to communicate with chat completion models.
|
| 5 |
+
Maintains conversation history in session context for multi-turn dialogue.
|
| 6 |
+
|
| 7 |
+
Configuration:
|
| 8 |
+
agent_proxy:
|
| 9 |
+
type: openai
|
| 10 |
+
api_key: "${OPENAI_API_KEY}" # or set OPENAI_API_KEY env var
|
| 11 |
+
model: "gpt-4o"
|
| 12 |
+
system_prompt: "You are a helpful travel agent."
|
| 13 |
+
temperature: 0.7
|
| 14 |
+
max_tokens: 1024
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from .base import BaseAgentProxy, AgentMessage, AgentResponse, AgentProxyFactory
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OpenAIChatProxy(BaseAgentProxy):
|
| 26 |
+
"""OpenAI Chat Completions proxy."""
|
| 27 |
+
|
| 28 |
+
proxy_type = "openai"
|
| 29 |
+
|
| 30 |
+
def _initialize(self):
|
| 31 |
+
api_key = self.config.get("api_key", "")
|
| 32 |
+
# Support environment variable references like ${OPENAI_API_KEY}
|
| 33 |
+
if api_key.startswith("${") and api_key.endswith("}"):
|
| 34 |
+
env_var = api_key[2:-1]
|
| 35 |
+
api_key = os.environ.get(env_var, "")
|
| 36 |
+
|
| 37 |
+
if not api_key:
|
| 38 |
+
api_key = os.environ.get("OPENAI_API_KEY", "")
|
| 39 |
+
|
| 40 |
+
if not api_key:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
"OpenAI proxy requires api_key in config or OPENAI_API_KEY env var"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
import openai
|
| 47 |
+
self.client = openai.OpenAI(api_key=api_key)
|
| 48 |
+
except ImportError:
|
| 49 |
+
raise ImportError(
|
| 50 |
+
"openai package is required for the OpenAI proxy. "
|
| 51 |
+
"Install with: pip install openai"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.model = self.config.get("model", "gpt-4o")
|
| 55 |
+
self.system_prompt = self.config.get("system_prompt", "")
|
| 56 |
+
self.temperature = self.config.get("temperature", 0.7)
|
| 57 |
+
self.max_tokens = self.config.get("max_tokens", 1024)
|
| 58 |
+
self.timeout = self.config.get("sandbox", {}).get(
|
| 59 |
+
"request_timeout_seconds", 60
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def start_session(self, task_description: str) -> dict:
|
| 63 |
+
messages = []
|
| 64 |
+
if self.system_prompt:
|
| 65 |
+
messages.append({"role": "system", "content": self.system_prompt})
|
| 66 |
+
# Include task description as system context
|
| 67 |
+
messages.append({
|
| 68 |
+
"role": "system",
|
| 69 |
+
"content": f"The user's task: {task_description}",
|
| 70 |
+
})
|
| 71 |
+
return {"messages": messages}
|
| 72 |
+
|
| 73 |
+
def send_message(self, message: str, session_context: dict) -> AgentResponse:
|
| 74 |
+
messages = session_context.get("messages", [])
|
| 75 |
+
messages.append({"role": "user", "content": message})
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
response = self.client.chat.completions.create(
|
| 79 |
+
model=self.model,
|
| 80 |
+
messages=messages,
|
| 81 |
+
temperature=self.temperature,
|
| 82 |
+
max_tokens=self.max_tokens,
|
| 83 |
+
timeout=self.timeout,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
content = response.choices[0].message.content or ""
|
| 87 |
+
messages.append({"role": "assistant", "content": content})
|
| 88 |
+
session_context["messages"] = messages
|
| 89 |
+
|
| 90 |
+
return AgentResponse(
|
| 91 |
+
message=AgentMessage(role="agent", content=content)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error(f"OpenAI proxy error: {e}")
|
| 96 |
+
return AgentResponse(
|
| 97 |
+
message=AgentMessage(
|
| 98 |
+
role="error", content=f"Agent error: {e}"
|
| 99 |
+
),
|
| 100 |
+
error=str(e),
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Register with factory
|
| 105 |
+
AgentProxyFactory.register("openai", OpenAIChatProxy)
|
potato/agent_proxy/sandbox.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Proxy Safety Sandbox
|
| 3 |
+
|
| 4 |
+
Enforces limits on agent interactions: step counts, session timeouts,
|
| 5 |
+
rate limits, and request timeouts. Prevents runaway or abusive sessions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import threading
|
| 10 |
+
import logging
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from typing import Dict, List
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SandboxViolation(Exception):
|
| 18 |
+
"""Raised when a safety limit is exceeded."""
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SafetySandbox:
|
| 23 |
+
"""Enforces safety limits on agent interactions."""
|
| 24 |
+
|
| 25 |
+
def __init__(self, config: dict):
|
| 26 |
+
sandbox_config = config.get("sandbox", {})
|
| 27 |
+
self.max_steps = sandbox_config.get("max_steps", 20)
|
| 28 |
+
self.max_session_seconds = sandbox_config.get("max_session_seconds", 600)
|
| 29 |
+
self.rate_limit_per_minute = sandbox_config.get("rate_limit_per_minute", 10)
|
| 30 |
+
self.request_timeout = sandbox_config.get("request_timeout_seconds", 60)
|
| 31 |
+
|
| 32 |
+
# Sliding window rate limit tracking: user_id -> list of timestamps
|
| 33 |
+
self._rate_windows: Dict[str, List[float]] = defaultdict(list)
|
| 34 |
+
self._lock = threading.Lock()
|
| 35 |
+
|
| 36 |
+
def check_step_limit(self, current_steps: int):
|
| 37 |
+
"""Raise SandboxViolation if step limit reached."""
|
| 38 |
+
if current_steps >= self.max_steps:
|
| 39 |
+
raise SandboxViolation(
|
| 40 |
+
f"Step limit reached ({self.max_steps}). "
|
| 41 |
+
f"Please finish the conversation."
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def check_session_timeout(self, session_start: float):
|
| 45 |
+
"""Raise SandboxViolation if session has timed out."""
|
| 46 |
+
elapsed = time.time() - session_start
|
| 47 |
+
if elapsed > self.max_session_seconds:
|
| 48 |
+
raise SandboxViolation(
|
| 49 |
+
f"Session timeout ({self.max_session_seconds}s). "
|
| 50 |
+
f"Please finish the conversation."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def check_rate_limit(self, user_id: str):
|
| 54 |
+
"""Raise SandboxViolation if user is sending too fast."""
|
| 55 |
+
now = time.time()
|
| 56 |
+
window_start = now - 60.0
|
| 57 |
+
|
| 58 |
+
with self._lock:
|
| 59 |
+
# Remove old entries outside the 1-minute window
|
| 60 |
+
timestamps = self._rate_windows[user_id]
|
| 61 |
+
self._rate_windows[user_id] = [
|
| 62 |
+
t for t in timestamps if t > window_start
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
if len(self._rate_windows[user_id]) >= self.rate_limit_per_minute:
|
| 66 |
+
raise SandboxViolation(
|
| 67 |
+
f"Rate limit exceeded ({self.rate_limit_per_minute}/min). "
|
| 68 |
+
f"Please wait before sending another message."
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Record this request
|
| 72 |
+
self._rate_windows[user_id].append(now)
|
| 73 |
+
|
| 74 |
+
def get_request_timeout(self) -> float:
|
| 75 |
+
"""Get the timeout in seconds for proxy HTTP requests."""
|
| 76 |
+
return self.request_timeout
|
potato/agent_proxy/session.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Session Manager
|
| 3 |
+
|
| 4 |
+
Thread-safe singleton that tracks active agent interaction sessions.
|
| 5 |
+
Each session maps a (user_id, instance_id) pair to an AgentSession
|
| 6 |
+
containing the proxy, conversation history, and step count.
|
| 7 |
+
|
| 8 |
+
Follows the same singleton pattern as ItemStateManager and UserStateManager.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from typing import Dict, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
from .base import AgentMessage, BaseAgentProxy
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class AgentSession:
|
| 24 |
+
"""An active agent interaction session."""
|
| 25 |
+
user_id: str
|
| 26 |
+
instance_id: str
|
| 27 |
+
proxy: BaseAgentProxy
|
| 28 |
+
task_description: str
|
| 29 |
+
proxy_context: dict = field(default_factory=dict)
|
| 30 |
+
messages: List[AgentMessage] = field(default_factory=list)
|
| 31 |
+
step_count: int = 0
|
| 32 |
+
started_at: float = field(default_factory=time.time)
|
| 33 |
+
finished: bool = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AgentSessionManager:
|
| 37 |
+
"""Thread-safe manager for active agent sessions."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, config: dict):
|
| 40 |
+
self.config = config
|
| 41 |
+
self._sessions: Dict[Tuple[str, str], AgentSession] = {}
|
| 42 |
+
self._lock = threading.RLock()
|
| 43 |
+
|
| 44 |
+
def create_session(
|
| 45 |
+
self,
|
| 46 |
+
user_id: str,
|
| 47 |
+
instance_id: str,
|
| 48 |
+
proxy: BaseAgentProxy,
|
| 49 |
+
task_description: str,
|
| 50 |
+
) -> AgentSession:
|
| 51 |
+
"""Create a new session for a user/instance pair."""
|
| 52 |
+
with self._lock:
|
| 53 |
+
key = (user_id, instance_id)
|
| 54 |
+
if key in self._sessions and not self._sessions[key].finished:
|
| 55 |
+
logger.warning(
|
| 56 |
+
f"Session already exists for {key}, returning existing"
|
| 57 |
+
)
|
| 58 |
+
return self._sessions[key]
|
| 59 |
+
|
| 60 |
+
proxy_context = proxy.start_session(task_description)
|
| 61 |
+
session = AgentSession(
|
| 62 |
+
user_id=user_id,
|
| 63 |
+
instance_id=instance_id,
|
| 64 |
+
proxy=proxy,
|
| 65 |
+
task_description=task_description,
|
| 66 |
+
proxy_context=proxy_context,
|
| 67 |
+
)
|
| 68 |
+
self._sessions[key] = session
|
| 69 |
+
logger.debug(f"Created agent session for {key}")
|
| 70 |
+
return session
|
| 71 |
+
|
| 72 |
+
def get_session(
|
| 73 |
+
self, user_id: str, instance_id: str
|
| 74 |
+
) -> Optional[AgentSession]:
|
| 75 |
+
"""Get an active session, or None if not found."""
|
| 76 |
+
with self._lock:
|
| 77 |
+
return self._sessions.get((user_id, instance_id))
|
| 78 |
+
|
| 79 |
+
def remove_session(self, user_id: str, instance_id: str):
|
| 80 |
+
"""Remove a session and clean up proxy resources."""
|
| 81 |
+
with self._lock:
|
| 82 |
+
key = (user_id, instance_id)
|
| 83 |
+
session = self._sessions.pop(key, None)
|
| 84 |
+
if session:
|
| 85 |
+
try:
|
| 86 |
+
session.proxy.end_session(session.proxy_context)
|
| 87 |
+
except Exception as e:
|
| 88 |
+
logger.warning(f"Error ending proxy session for {key}: {e}")
|
| 89 |
+
logger.debug(f"Removed agent session for {key}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Singleton management
|
| 93 |
+
_AGENT_SESSION_MANAGER: Optional[AgentSessionManager] = None
|
| 94 |
+
_AGENT_SESSION_MANAGER_LOCK = threading.Lock()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def init_agent_session_manager(config: dict) -> AgentSessionManager:
|
| 98 |
+
"""Initialize the singleton AgentSessionManager."""
|
| 99 |
+
global _AGENT_SESSION_MANAGER
|
| 100 |
+
if _AGENT_SESSION_MANAGER is None:
|
| 101 |
+
with _AGENT_SESSION_MANAGER_LOCK:
|
| 102 |
+
if _AGENT_SESSION_MANAGER is None:
|
| 103 |
+
_AGENT_SESSION_MANAGER = AgentSessionManager(config)
|
| 104 |
+
return _AGENT_SESSION_MANAGER
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_agent_session_manager() -> AgentSessionManager:
|
| 108 |
+
"""Get the singleton AgentSessionManager."""
|
| 109 |
+
global _AGENT_SESSION_MANAGER
|
| 110 |
+
if _AGENT_SESSION_MANAGER is None:
|
| 111 |
+
raise ValueError("AgentSessionManager has not been initialized yet!")
|
| 112 |
+
return _AGENT_SESSION_MANAGER
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def clear_agent_session_manager():
|
| 116 |
+
"""Clear the singleton instance (for testing)."""
|
| 117 |
+
global _AGENT_SESSION_MANAGER
|
| 118 |
+
with _AGENT_SESSION_MANAGER_LOCK:
|
| 119 |
+
_AGENT_SESSION_MANAGER = None
|
potato/agent_runner.py
ADDED
|
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Live Agent Runner
|
| 3 |
+
|
| 4 |
+
Manages an AI agent that browses the web via Playwright, controlled by an LLM.
|
| 5 |
+
Annotators can observe, pause, instruct, or take over the agent in real time.
|
| 6 |
+
|
| 7 |
+
The agent loop runs in a background thread with its own asyncio event loop.
|
| 8 |
+
Communication with Flask routes happens through thread-safe state and queues.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import asyncio
|
| 12 |
+
import base64
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import threading
|
| 17 |
+
import time
|
| 18 |
+
import uuid
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from enum import Enum
|
| 21 |
+
from queue import Queue, Empty
|
| 22 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AgentState(Enum):
|
| 28 |
+
"""States of the agent lifecycle."""
|
| 29 |
+
IDLE = "idle"
|
| 30 |
+
RUNNING = "running"
|
| 31 |
+
PAUSED = "paused"
|
| 32 |
+
TAKEOVER = "takeover"
|
| 33 |
+
COMPLETED = "completed"
|
| 34 |
+
ERROR = "error"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class AgentStep:
|
| 39 |
+
"""A single step in the agent's execution."""
|
| 40 |
+
step_index: int
|
| 41 |
+
screenshot_path: str
|
| 42 |
+
action: Dict[str, Any]
|
| 43 |
+
thought: str
|
| 44 |
+
observation: str
|
| 45 |
+
timestamp: float
|
| 46 |
+
url: str = ""
|
| 47 |
+
viewport: Optional[Dict[str, int]] = None
|
| 48 |
+
coordinates: Optional[Dict[str, int]] = None
|
| 49 |
+
element: Optional[Dict[str, Any]] = None
|
| 50 |
+
annotator_instruction: Optional[str] = None
|
| 51 |
+
|
| 52 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 53 |
+
d = {
|
| 54 |
+
"step_index": self.step_index,
|
| 55 |
+
"screenshot_url": self.screenshot_path,
|
| 56 |
+
"action_type": self.action.get("type", "unknown"),
|
| 57 |
+
"action": self.action,
|
| 58 |
+
"thought": self.thought,
|
| 59 |
+
"observation": self.observation,
|
| 60 |
+
"timestamp": self.timestamp,
|
| 61 |
+
"url": self.url,
|
| 62 |
+
}
|
| 63 |
+
if self.viewport:
|
| 64 |
+
d["viewport"] = self.viewport
|
| 65 |
+
if self.coordinates:
|
| 66 |
+
d["coordinates"] = self.coordinates
|
| 67 |
+
if self.element:
|
| 68 |
+
d["element"] = self.element
|
| 69 |
+
if self.annotator_instruction:
|
| 70 |
+
d["annotator_instruction"] = self.annotator_instruction
|
| 71 |
+
return d
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class AgentConfig:
|
| 76 |
+
"""Configuration for the agent runner."""
|
| 77 |
+
max_steps: int = 30
|
| 78 |
+
step_delay: float = 1.0
|
| 79 |
+
viewport_width: int = 1280
|
| 80 |
+
viewport_height: int = 720
|
| 81 |
+
system_prompt: str = ""
|
| 82 |
+
model: str = "claude-sonnet-4-20250514"
|
| 83 |
+
api_key: str = ""
|
| 84 |
+
max_tokens: int = 4096
|
| 85 |
+
temperature: float = 0.3
|
| 86 |
+
endpoint_type: str = "anthropic_vision"
|
| 87 |
+
history_window: int = 5 # Number of recent steps to include in LLM context
|
| 88 |
+
timeout: int = 60 # Per-request timeout in seconds
|
| 89 |
+
|
| 90 |
+
base_url: str = "" # For Ollama: server URL
|
| 91 |
+
|
| 92 |
+
@classmethod
|
| 93 |
+
def from_config(cls, config: Dict[str, Any]) -> "AgentConfig":
|
| 94 |
+
"""Create AgentConfig from a live_agent YAML config dict."""
|
| 95 |
+
ai_config = config.get("ai_config", {})
|
| 96 |
+
viewport = config.get("viewport", {})
|
| 97 |
+
endpoint_type = config.get("endpoint_type", "anthropic_vision")
|
| 98 |
+
|
| 99 |
+
# API key: Ollama doesn't need one; OpenAI-compatible servers
|
| 100 |
+
# (e.g. vLLM) ignore it but the SDK requires a non-empty string.
|
| 101 |
+
if endpoint_type == "ollama_vision":
|
| 102 |
+
api_key = ai_config.get("api_key", "")
|
| 103 |
+
default_model = "gemma3:4b"
|
| 104 |
+
elif endpoint_type == "openai_vision":
|
| 105 |
+
api_key = ai_config.get("api_key", os.environ.get("OPENAI_API_KEY", "EMPTY"))
|
| 106 |
+
default_model = "" # must be set explicitly (e.g. served model id)
|
| 107 |
+
else:
|
| 108 |
+
api_key = ai_config.get("api_key", os.environ.get("ANTHROPIC_API_KEY", ""))
|
| 109 |
+
default_model = "claude-sonnet-4-20250514"
|
| 110 |
+
|
| 111 |
+
return cls(
|
| 112 |
+
max_steps=config.get("max_steps", 30),
|
| 113 |
+
step_delay=config.get("step_delay", 1.0),
|
| 114 |
+
viewport_width=viewport.get("width", 1280),
|
| 115 |
+
viewport_height=viewport.get("height", 720),
|
| 116 |
+
system_prompt=config.get("system_prompt", DEFAULT_SYSTEM_PROMPT),
|
| 117 |
+
model=ai_config.get("model", default_model),
|
| 118 |
+
api_key=api_key,
|
| 119 |
+
max_tokens=ai_config.get("max_tokens", 4096),
|
| 120 |
+
temperature=ai_config.get("temperature", 0.3),
|
| 121 |
+
endpoint_type=endpoint_type,
|
| 122 |
+
history_window=config.get("history_window", 5),
|
| 123 |
+
timeout=ai_config.get("timeout", 60),
|
| 124 |
+
base_url=ai_config.get("base_url", "http://localhost:11434"),
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
DEFAULT_SYSTEM_PROMPT = """You are a web browsing agent. You can see screenshots of web pages and take actions to complete tasks.
|
| 129 |
+
|
| 130 |
+
For each step, analyze the current screenshot and respond with a JSON object:
|
| 131 |
+
{
|
| 132 |
+
"thought": "Your reasoning about what you see and what to do next",
|
| 133 |
+
"action": {
|
| 134 |
+
"type": "click|type|scroll|navigate|wait|done",
|
| 135 |
+
// For click: "x": 100, "y": 200
|
| 136 |
+
// For type: "text": "hello world"
|
| 137 |
+
// For scroll: "direction": "up|down", "amount": 300
|
| 138 |
+
// For navigate: "url": "https://..."
|
| 139 |
+
// For wait: (no extra fields)
|
| 140 |
+
// For done: "summary": "Task completed because..."
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
Always respond with valid JSON only. No markdown, no extra text."""
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class AgentRunner:
|
| 148 |
+
"""
|
| 149 |
+
Runs an AI agent that browses the web via Playwright.
|
| 150 |
+
|
| 151 |
+
The agent loop:
|
| 152 |
+
1. Takes a screenshot
|
| 153 |
+
2. Sends it to the LLM with context/history
|
| 154 |
+
3. Parses the LLM response for an action
|
| 155 |
+
4. Executes the action via Playwright
|
| 156 |
+
5. Emits events to all listeners (for SSE)
|
| 157 |
+
6. Repeats until done, error, or max_steps
|
| 158 |
+
|
| 159 |
+
Thread-safe control methods allow pause/resume/instruct/takeover.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(self, session_id: str, config: AgentConfig, screenshot_dir: str):
|
| 163 |
+
self.session_id = session_id
|
| 164 |
+
self.config = config
|
| 165 |
+
self.screenshot_dir = screenshot_dir
|
| 166 |
+
|
| 167 |
+
# State
|
| 168 |
+
self._state = AgentState.IDLE
|
| 169 |
+
self._state_lock = threading.Lock()
|
| 170 |
+
self._steps: List[AgentStep] = []
|
| 171 |
+
self._error: Optional[str] = None
|
| 172 |
+
|
| 173 |
+
# Control
|
| 174 |
+
self._pause_event = threading.Event()
|
| 175 |
+
self._pause_event.set() # Not paused initially
|
| 176 |
+
self._stop_flag = threading.Event()
|
| 177 |
+
self._instruction_queue: Queue = Queue()
|
| 178 |
+
self._takeover_actions: Queue = Queue()
|
| 179 |
+
|
| 180 |
+
# Listeners for SSE
|
| 181 |
+
self._listeners: List[Callable] = []
|
| 182 |
+
self._listeners_lock = threading.Lock()
|
| 183 |
+
|
| 184 |
+
# Annotator interactions log
|
| 185 |
+
self._interactions: List[Dict[str, Any]] = []
|
| 186 |
+
|
| 187 |
+
# Playwright session (set during run)
|
| 188 |
+
self._playwright_session = None
|
| 189 |
+
self._llm_client = None
|
| 190 |
+
|
| 191 |
+
# Background thread
|
| 192 |
+
self._thread: Optional[threading.Thread] = None
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def state(self) -> AgentState:
|
| 196 |
+
with self._state_lock:
|
| 197 |
+
return self._state
|
| 198 |
+
|
| 199 |
+
@state.setter
|
| 200 |
+
def state(self, new_state: AgentState):
|
| 201 |
+
with self._state_lock:
|
| 202 |
+
old_state = self._state
|
| 203 |
+
self._state = new_state
|
| 204 |
+
self._emit_event("state_change", {
|
| 205 |
+
"old_state": old_state.value,
|
| 206 |
+
"new_state": new_state.value,
|
| 207 |
+
"timestamp": time.time(),
|
| 208 |
+
})
|
| 209 |
+
|
| 210 |
+
@property
|
| 211 |
+
def steps(self) -> List[AgentStep]:
|
| 212 |
+
return list(self._steps)
|
| 213 |
+
|
| 214 |
+
@property
|
| 215 |
+
def step_count(self) -> int:
|
| 216 |
+
return len(self._steps)
|
| 217 |
+
|
| 218 |
+
@property
|
| 219 |
+
def error(self) -> Optional[str]:
|
| 220 |
+
return self._error
|
| 221 |
+
|
| 222 |
+
# --- Control methods (thread-safe) ---
|
| 223 |
+
|
| 224 |
+
def pause(self):
|
| 225 |
+
"""Pause the agent loop after the current step completes."""
|
| 226 |
+
if self.state == AgentState.RUNNING:
|
| 227 |
+
self._pause_event.clear()
|
| 228 |
+
self.state = AgentState.PAUSED
|
| 229 |
+
logger.info(f"[{self.session_id}] Agent paused")
|
| 230 |
+
|
| 231 |
+
def resume(self):
|
| 232 |
+
"""Resume a paused agent."""
|
| 233 |
+
if self.state == AgentState.PAUSED:
|
| 234 |
+
self.state = AgentState.RUNNING
|
| 235 |
+
self._pause_event.set()
|
| 236 |
+
logger.info(f"[{self.session_id}] Agent resumed")
|
| 237 |
+
|
| 238 |
+
def inject_instruction(self, instruction: str):
|
| 239 |
+
"""Send an instruction to the agent (processed at next step)."""
|
| 240 |
+
self._instruction_queue.put(instruction)
|
| 241 |
+
self._interactions.append({
|
| 242 |
+
"type": "instruction",
|
| 243 |
+
"text": instruction,
|
| 244 |
+
"timestamp": time.time(),
|
| 245 |
+
"step_index": self.step_count,
|
| 246 |
+
})
|
| 247 |
+
self._emit_event("instruction_received", {"instruction": instruction})
|
| 248 |
+
logger.info(f"[{self.session_id}] Instruction injected: {instruction[:100]}")
|
| 249 |
+
|
| 250 |
+
def enter_takeover(self):
|
| 251 |
+
"""Switch to manual takeover mode."""
|
| 252 |
+
if self.state in (AgentState.RUNNING, AgentState.PAUSED):
|
| 253 |
+
self._pause_event.clear() # Pause the agent loop
|
| 254 |
+
self.state = AgentState.TAKEOVER
|
| 255 |
+
self._interactions.append({
|
| 256 |
+
"type": "takeover_start",
|
| 257 |
+
"timestamp": time.time(),
|
| 258 |
+
"step_index": self.step_count,
|
| 259 |
+
})
|
| 260 |
+
logger.info(f"[{self.session_id}] Takeover mode entered")
|
| 261 |
+
|
| 262 |
+
def exit_takeover(self):
|
| 263 |
+
"""Exit manual takeover and resume the agent."""
|
| 264 |
+
if self.state == AgentState.TAKEOVER:
|
| 265 |
+
self._interactions.append({
|
| 266 |
+
"type": "takeover_end",
|
| 267 |
+
"timestamp": time.time(),
|
| 268 |
+
"step_index": self.step_count,
|
| 269 |
+
})
|
| 270 |
+
self.state = AgentState.RUNNING
|
| 271 |
+
self._pause_event.set()
|
| 272 |
+
logger.info(f"[{self.session_id}] Takeover mode exited")
|
| 273 |
+
|
| 274 |
+
def submit_manual_action(self, action: Dict[str, Any]):
|
| 275 |
+
"""Submit a manual action during takeover mode."""
|
| 276 |
+
if self.state == AgentState.TAKEOVER:
|
| 277 |
+
self._takeover_actions.put(action)
|
| 278 |
+
|
| 279 |
+
def stop(self):
|
| 280 |
+
"""Stop the agent loop."""
|
| 281 |
+
self._stop_flag.set()
|
| 282 |
+
self._pause_event.set() # Unblock if paused
|
| 283 |
+
logger.info(f"[{self.session_id}] Stop requested")
|
| 284 |
+
|
| 285 |
+
# --- Listener management ---
|
| 286 |
+
|
| 287 |
+
def add_listener(self, callback: Callable):
|
| 288 |
+
"""Add an SSE listener callback."""
|
| 289 |
+
with self._listeners_lock:
|
| 290 |
+
self._listeners.append(callback)
|
| 291 |
+
|
| 292 |
+
def remove_listener(self, callback: Callable):
|
| 293 |
+
"""Remove an SSE listener callback."""
|
| 294 |
+
with self._listeners_lock:
|
| 295 |
+
self._listeners = [l for l in self._listeners if l is not callback]
|
| 296 |
+
|
| 297 |
+
def _emit_event(self, event_type: str, data: Dict[str, Any]):
|
| 298 |
+
"""Emit an event to all listeners."""
|
| 299 |
+
event = {"type": event_type, "data": data, "session_id": self.session_id}
|
| 300 |
+
with self._listeners_lock:
|
| 301 |
+
for listener in self._listeners:
|
| 302 |
+
try:
|
| 303 |
+
listener(event)
|
| 304 |
+
except Exception as e:
|
| 305 |
+
logger.warning(f"Listener error: {e}")
|
| 306 |
+
|
| 307 |
+
# --- Main agent loop ---
|
| 308 |
+
|
| 309 |
+
def start(self, task_description: str, start_url: str):
|
| 310 |
+
"""Start the agent in a background thread."""
|
| 311 |
+
if self.state != AgentState.IDLE:
|
| 312 |
+
raise RuntimeError(f"Cannot start agent in state {self.state}")
|
| 313 |
+
|
| 314 |
+
self._thread = threading.Thread(
|
| 315 |
+
target=self._run_thread,
|
| 316 |
+
args=(task_description, start_url),
|
| 317 |
+
daemon=True,
|
| 318 |
+
name=f"agent-{self.session_id}",
|
| 319 |
+
)
|
| 320 |
+
self._thread.start()
|
| 321 |
+
|
| 322 |
+
def _run_thread(self, task_description: str, start_url: str):
|
| 323 |
+
"""Thread target: runs the async agent loop."""
|
| 324 |
+
loop = asyncio.new_event_loop()
|
| 325 |
+
asyncio.set_event_loop(loop)
|
| 326 |
+
try:
|
| 327 |
+
loop.run_until_complete(self._run_async(task_description, start_url))
|
| 328 |
+
except Exception as e:
|
| 329 |
+
logger.error(f"[{self.session_id}] Agent thread error: {e}")
|
| 330 |
+
self._error = str(e)
|
| 331 |
+
self.state = AgentState.ERROR
|
| 332 |
+
self._emit_event("error", {"message": str(e)})
|
| 333 |
+
finally:
|
| 334 |
+
loop.close()
|
| 335 |
+
|
| 336 |
+
async def _run_async(self, task_description: str, start_url: str):
|
| 337 |
+
"""Async agent loop."""
|
| 338 |
+
from potato.web_playwright import PlaywrightSession
|
| 339 |
+
|
| 340 |
+
self.state = AgentState.RUNNING
|
| 341 |
+
|
| 342 |
+
# Initialize Playwright
|
| 343 |
+
self._playwright_session = PlaywrightSession(
|
| 344 |
+
width=self.config.viewport_width,
|
| 345 |
+
height=self.config.viewport_height,
|
| 346 |
+
)
|
| 347 |
+
started = await self._playwright_session.start(start_url)
|
| 348 |
+
if not started:
|
| 349 |
+
raise RuntimeError("Failed to start Playwright browser session")
|
| 350 |
+
|
| 351 |
+
# Initialize LLM client
|
| 352 |
+
self._init_llm_client()
|
| 353 |
+
|
| 354 |
+
self._emit_event("started", {
|
| 355 |
+
"task": task_description,
|
| 356 |
+
"start_url": start_url,
|
| 357 |
+
"max_steps": self.config.max_steps,
|
| 358 |
+
})
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
for step_index in range(self.config.max_steps):
|
| 362 |
+
# Check stop flag
|
| 363 |
+
if self._stop_flag.is_set():
|
| 364 |
+
logger.info(f"[{self.session_id}] Stopped by user")
|
| 365 |
+
break
|
| 366 |
+
|
| 367 |
+
# Wait if paused (blocks until resume/stop)
|
| 368 |
+
while not self._pause_event.is_set():
|
| 369 |
+
if self._stop_flag.is_set():
|
| 370 |
+
break
|
| 371 |
+
# Handle takeover actions while paused in takeover mode
|
| 372 |
+
if self.state == AgentState.TAKEOVER:
|
| 373 |
+
await self._process_takeover_actions()
|
| 374 |
+
await asyncio.sleep(0.1)
|
| 375 |
+
|
| 376 |
+
if self._stop_flag.is_set():
|
| 377 |
+
break
|
| 378 |
+
|
| 379 |
+
# Check for injected instructions
|
| 380 |
+
instruction = None
|
| 381 |
+
try:
|
| 382 |
+
instruction = self._instruction_queue.get_nowait()
|
| 383 |
+
except Empty:
|
| 384 |
+
pass
|
| 385 |
+
|
| 386 |
+
# Execute one agent step
|
| 387 |
+
step = await self._agent_step(
|
| 388 |
+
step_index, task_description, instruction
|
| 389 |
+
)
|
| 390 |
+
self._steps.append(step)
|
| 391 |
+
|
| 392 |
+
# Check if agent decided it's done
|
| 393 |
+
if step.action.get("type") == "done":
|
| 394 |
+
logger.info(f"[{self.session_id}] Agent completed task")
|
| 395 |
+
break
|
| 396 |
+
|
| 397 |
+
# Step delay
|
| 398 |
+
if self.config.step_delay > 0:
|
| 399 |
+
await asyncio.sleep(self.config.step_delay)
|
| 400 |
+
|
| 401 |
+
self.state = AgentState.COMPLETED
|
| 402 |
+
self._emit_event("complete", {
|
| 403 |
+
"total_steps": len(self._steps),
|
| 404 |
+
"final_url": (await self._playwright_session.get_state()).get("url", ""),
|
| 405 |
+
})
|
| 406 |
+
|
| 407 |
+
finally:
|
| 408 |
+
await self._playwright_session.stop()
|
| 409 |
+
self._playwright_session = None
|
| 410 |
+
|
| 411 |
+
async def _agent_step(
|
| 412 |
+
self,
|
| 413 |
+
step_index: int,
|
| 414 |
+
task_description: str,
|
| 415 |
+
instruction: Optional[str] = None,
|
| 416 |
+
) -> AgentStep:
|
| 417 |
+
"""Execute a single agent step: screenshot → LLM → action → emit."""
|
| 418 |
+
|
| 419 |
+
# 1. Take screenshot
|
| 420 |
+
screenshot_bytes = await self._playwright_session.screenshot()
|
| 421 |
+
if not screenshot_bytes:
|
| 422 |
+
raise RuntimeError("Failed to capture screenshot")
|
| 423 |
+
|
| 424 |
+
screenshot_path = os.path.join(
|
| 425 |
+
self.screenshot_dir, f"step_{step_index:03d}.png"
|
| 426 |
+
)
|
| 427 |
+
os.makedirs(os.path.dirname(screenshot_path), exist_ok=True)
|
| 428 |
+
with open(screenshot_path, "wb") as f:
|
| 429 |
+
f.write(screenshot_bytes)
|
| 430 |
+
|
| 431 |
+
# 2. Get page state
|
| 432 |
+
page_state = await self._playwright_session.get_state()
|
| 433 |
+
|
| 434 |
+
# 3. Emit thinking event
|
| 435 |
+
self._emit_event("thinking", {
|
| 436 |
+
"step_index": step_index,
|
| 437 |
+
"screenshot_url": screenshot_path,
|
| 438 |
+
"url": page_state.get("url", ""),
|
| 439 |
+
})
|
| 440 |
+
|
| 441 |
+
# 4. Build messages and query LLM
|
| 442 |
+
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
| 443 |
+
messages = self._build_llm_messages(
|
| 444 |
+
screenshot_b64, task_description, instruction
|
| 445 |
+
)
|
| 446 |
+
llm_response = self._query_llm(messages)
|
| 447 |
+
|
| 448 |
+
# 5. Parse action from response
|
| 449 |
+
thought, action = self._parse_action(llm_response)
|
| 450 |
+
|
| 451 |
+
# 6. Execute action
|
| 452 |
+
observation = await self._execute_action(action)
|
| 453 |
+
|
| 454 |
+
# 7. Build step
|
| 455 |
+
step = AgentStep(
|
| 456 |
+
step_index=step_index,
|
| 457 |
+
screenshot_path=screenshot_path,
|
| 458 |
+
action=action,
|
| 459 |
+
thought=thought,
|
| 460 |
+
observation=observation,
|
| 461 |
+
timestamp=time.time(),
|
| 462 |
+
url=page_state.get("url", ""),
|
| 463 |
+
viewport=page_state.get("viewport"),
|
| 464 |
+
coordinates=_extract_coordinates(action),
|
| 465 |
+
annotator_instruction=instruction,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
# 8. Emit step event
|
| 469 |
+
self._emit_event("step", step.to_dict())
|
| 470 |
+
|
| 471 |
+
return step
|
| 472 |
+
|
| 473 |
+
def _build_llm_messages(
|
| 474 |
+
self,
|
| 475 |
+
screenshot_b64: str,
|
| 476 |
+
task_description: str,
|
| 477 |
+
instruction: Optional[str] = None,
|
| 478 |
+
) -> List[Dict[str, Any]]:
|
| 479 |
+
"""Build message list for the LLM vision API."""
|
| 480 |
+
messages = []
|
| 481 |
+
|
| 482 |
+
# System message
|
| 483 |
+
system_prompt = self.config.system_prompt or DEFAULT_SYSTEM_PROMPT
|
| 484 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 485 |
+
|
| 486 |
+
# Task description
|
| 487 |
+
task_msg = f"Task: {task_description}"
|
| 488 |
+
if instruction:
|
| 489 |
+
task_msg += f"\n\nAnnotator instruction: {instruction}"
|
| 490 |
+
|
| 491 |
+
# Include recent step history
|
| 492 |
+
history_steps = self._steps[-self.config.history_window:]
|
| 493 |
+
if history_steps:
|
| 494 |
+
history_parts = []
|
| 495 |
+
for s in history_steps:
|
| 496 |
+
entry = f"Step {s.step_index}: thought='{s.thought}', action={json.dumps(s.action)}, observation='{s.observation}'"
|
| 497 |
+
history_parts.append(entry)
|
| 498 |
+
task_msg += "\n\nRecent history:\n" + "\n".join(history_parts)
|
| 499 |
+
|
| 500 |
+
messages.append({"role": "user", "content": task_msg})
|
| 501 |
+
|
| 502 |
+
# Current screenshot (as a separate user message with image)
|
| 503 |
+
messages.append({
|
| 504 |
+
"role": "user",
|
| 505 |
+
"content": [
|
| 506 |
+
{
|
| 507 |
+
"type": "image",
|
| 508 |
+
"source": {
|
| 509 |
+
"type": "base64",
|
| 510 |
+
"media_type": "image/png",
|
| 511 |
+
"data": screenshot_b64,
|
| 512 |
+
},
|
| 513 |
+
},
|
| 514 |
+
{
|
| 515 |
+
"type": "text",
|
| 516 |
+
"text": f"Current page screenshot (step {len(self._steps)}). What action should I take next?",
|
| 517 |
+
},
|
| 518 |
+
],
|
| 519 |
+
})
|
| 520 |
+
|
| 521 |
+
return messages
|
| 522 |
+
|
| 523 |
+
def _init_llm_client(self):
|
| 524 |
+
"""Initialize the LLM client based on endpoint_type."""
|
| 525 |
+
if self.config.endpoint_type == "anthropic_vision":
|
| 526 |
+
try:
|
| 527 |
+
import anthropic
|
| 528 |
+
except ImportError:
|
| 529 |
+
raise RuntimeError(
|
| 530 |
+
"anthropic package required. Install with: pip install anthropic"
|
| 531 |
+
)
|
| 532 |
+
api_key = self.config.api_key or os.environ.get("ANTHROPIC_API_KEY")
|
| 533 |
+
if not api_key:
|
| 534 |
+
raise RuntimeError(
|
| 535 |
+
"Anthropic API key required. Set in config or ANTHROPIC_API_KEY env var."
|
| 536 |
+
)
|
| 537 |
+
self._llm_client = anthropic.Anthropic(
|
| 538 |
+
api_key=api_key, timeout=self.config.timeout
|
| 539 |
+
)
|
| 540 |
+
elif self.config.endpoint_type == "ollama_vision":
|
| 541 |
+
try:
|
| 542 |
+
import ollama
|
| 543 |
+
except ImportError:
|
| 544 |
+
raise RuntimeError(
|
| 545 |
+
"ollama package required. Install with: pip install ollama"
|
| 546 |
+
)
|
| 547 |
+
host = self.config.base_url or "http://localhost:11434"
|
| 548 |
+
self._llm_client = ollama.Client(
|
| 549 |
+
host=host, timeout=self.config.timeout
|
| 550 |
+
)
|
| 551 |
+
# Verify connectivity
|
| 552 |
+
try:
|
| 553 |
+
self._llm_client.list()
|
| 554 |
+
logger.info(f"Connected to Ollama at {host}, model: {self.config.model}")
|
| 555 |
+
except Exception as e:
|
| 556 |
+
raise RuntimeError(f"Failed to connect to Ollama at {host}: {e}")
|
| 557 |
+
elif self.config.endpoint_type == "openai_vision":
|
| 558 |
+
try:
|
| 559 |
+
from openai import OpenAI
|
| 560 |
+
except ImportError:
|
| 561 |
+
raise RuntimeError(
|
| 562 |
+
"openai package required. Install with: pip install openai"
|
| 563 |
+
)
|
| 564 |
+
base_url = self.config.base_url or "https://api.openai.com/v1"
|
| 565 |
+
self._llm_client = OpenAI(
|
| 566 |
+
base_url=base_url,
|
| 567 |
+
api_key=self.config.api_key or "EMPTY",
|
| 568 |
+
timeout=self.config.timeout,
|
| 569 |
+
)
|
| 570 |
+
try:
|
| 571 |
+
self._llm_client.models.list()
|
| 572 |
+
logger.info(
|
| 573 |
+
f"Connected to OpenAI-compatible endpoint at {base_url}, "
|
| 574 |
+
f"model: {self.config.model}"
|
| 575 |
+
)
|
| 576 |
+
except Exception as e:
|
| 577 |
+
# Non-fatal: some servers gate /models; the chat call will
|
| 578 |
+
# surface a real error if the endpoint is truly unreachable.
|
| 579 |
+
logger.warning(
|
| 580 |
+
f"Could not list models at {base_url} ({e}); continuing."
|
| 581 |
+
)
|
| 582 |
+
else:
|
| 583 |
+
raise RuntimeError(
|
| 584 |
+
f"Unsupported endpoint_type: {self.config.endpoint_type}. "
|
| 585 |
+
f"Supported: 'anthropic_vision', 'ollama_vision', 'openai_vision'."
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
def _query_llm(self, messages: List[Dict[str, Any]]) -> str:
|
| 589 |
+
"""Send messages to the LLM and return the text response."""
|
| 590 |
+
if self.config.endpoint_type == "anthropic_vision":
|
| 591 |
+
return self._query_anthropic(messages)
|
| 592 |
+
elif self.config.endpoint_type == "ollama_vision":
|
| 593 |
+
return self._query_ollama(messages)
|
| 594 |
+
elif self.config.endpoint_type == "openai_vision":
|
| 595 |
+
return self._query_openai(messages)
|
| 596 |
+
raise RuntimeError(f"Unsupported endpoint type: {self.config.endpoint_type}")
|
| 597 |
+
|
| 598 |
+
def _query_openai(self, messages: List[Dict[str, Any]]) -> str:
|
| 599 |
+
"""Query an OpenAI-compatible vision endpoint (OpenAI, vLLM, etc.).
|
| 600 |
+
|
| 601 |
+
Converts the internal Anthropic-style message blocks into OpenAI
|
| 602 |
+
chat-completions format (image blocks become ``image_url`` data
|
| 603 |
+
URIs). Requests a JSON object response when the server supports it,
|
| 604 |
+
falling back gracefully if it does not.
|
| 605 |
+
"""
|
| 606 |
+
oai_messages = []
|
| 607 |
+
for msg in messages:
|
| 608 |
+
role = msg["role"]
|
| 609 |
+
content = msg.get("content", "")
|
| 610 |
+
if isinstance(content, str):
|
| 611 |
+
oai_messages.append({"role": role, "content": content})
|
| 612 |
+
continue
|
| 613 |
+
parts = []
|
| 614 |
+
for block in content:
|
| 615 |
+
if not isinstance(block, dict):
|
| 616 |
+
continue
|
| 617 |
+
if block.get("type") == "text":
|
| 618 |
+
parts.append({"type": "text", "text": block.get("text", "")})
|
| 619 |
+
elif block.get("type") == "image":
|
| 620 |
+
src = block.get("source", {})
|
| 621 |
+
if src.get("type") == "base64":
|
| 622 |
+
media = src.get("media_type", "image/png")
|
| 623 |
+
parts.append({
|
| 624 |
+
"type": "image_url",
|
| 625 |
+
"image_url": {
|
| 626 |
+
"url": f"data:{media};base64,{src['data']}"
|
| 627 |
+
},
|
| 628 |
+
})
|
| 629 |
+
oai_messages.append({"role": role, "content": parts})
|
| 630 |
+
|
| 631 |
+
kwargs = {
|
| 632 |
+
"model": self.config.model,
|
| 633 |
+
"messages": oai_messages,
|
| 634 |
+
"max_tokens": self.config.max_tokens,
|
| 635 |
+
"temperature": self.config.temperature,
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
def _is_rate_limit(exc) -> bool:
|
| 639 |
+
if getattr(exc, "status_code", None) == 429:
|
| 640 |
+
return True
|
| 641 |
+
s = str(exc).lower()
|
| 642 |
+
return ("429" in s or "rate limit" in s or "quota" in s
|
| 643 |
+
or "resource_exhausted" in s)
|
| 644 |
+
|
| 645 |
+
def _create(use_rf: bool):
|
| 646 |
+
if use_rf:
|
| 647 |
+
return self._llm_client.chat.completions.create(
|
| 648 |
+
response_format={"type": "json_object"}, **kwargs)
|
| 649 |
+
return self._llm_client.chat.completions.create(**kwargs)
|
| 650 |
+
|
| 651 |
+
# Transient 429s (per-minute rate/token bursts) are common mid-run
|
| 652 |
+
# even on paid tiers; back off and retry instead of failing the
|
| 653 |
+
# whole agent session.
|
| 654 |
+
backoffs = [5, 15, 30, 30, 30]
|
| 655 |
+
use_rf = True
|
| 656 |
+
attempt = 0
|
| 657 |
+
while True:
|
| 658 |
+
try:
|
| 659 |
+
resp = _create(use_rf)
|
| 660 |
+
break
|
| 661 |
+
except Exception as e:
|
| 662 |
+
if _is_rate_limit(e):
|
| 663 |
+
if attempt >= len(backoffs):
|
| 664 |
+
raise
|
| 665 |
+
wait = backoffs[attempt]
|
| 666 |
+
attempt += 1
|
| 667 |
+
logger.warning(
|
| 668 |
+
f"[{self.session_id}] LLM 429/rate-limited; "
|
| 669 |
+
f"retry {attempt}/{len(backoffs)} in {wait}s"
|
| 670 |
+
)
|
| 671 |
+
self._emit_event("thinking", {
|
| 672 |
+
"text": f"Rate-limited by the model API; "
|
| 673 |
+
f"waiting {wait}s before retrying…"
|
| 674 |
+
})
|
| 675 |
+
time.sleep(wait)
|
| 676 |
+
continue
|
| 677 |
+
if use_rf:
|
| 678 |
+
# Server may not support response_format; drop it once.
|
| 679 |
+
use_rf = False
|
| 680 |
+
continue
|
| 681 |
+
raise
|
| 682 |
+
return resp.choices[0].message.content or ""
|
| 683 |
+
|
| 684 |
+
def _query_anthropic(self, messages: List[Dict[str, Any]]) -> str:
|
| 685 |
+
"""Query Anthropic Claude with vision support."""
|
| 686 |
+
# Separate system message
|
| 687 |
+
system = ""
|
| 688 |
+
api_messages = []
|
| 689 |
+
for msg in messages:
|
| 690 |
+
if msg["role"] == "system":
|
| 691 |
+
system = msg["content"]
|
| 692 |
+
else:
|
| 693 |
+
api_messages.append(msg)
|
| 694 |
+
|
| 695 |
+
kwargs = {
|
| 696 |
+
"model": self.config.model,
|
| 697 |
+
"max_tokens": self.config.max_tokens,
|
| 698 |
+
"temperature": self.config.temperature,
|
| 699 |
+
"messages": api_messages,
|
| 700 |
+
}
|
| 701 |
+
if system:
|
| 702 |
+
kwargs["system"] = system
|
| 703 |
+
|
| 704 |
+
response = self._llm_client.messages.create(**kwargs)
|
| 705 |
+
return response.content[0].text
|
| 706 |
+
|
| 707 |
+
def _query_ollama(self, messages: List[Dict[str, Any]]) -> str:
|
| 708 |
+
"""Query Ollama vision model.
|
| 709 |
+
|
| 710 |
+
Converts Anthropic-format messages to Ollama format:
|
| 711 |
+
- System messages are prepended to the prompt text
|
| 712 |
+
- Multiple user messages are merged into a single message
|
| 713 |
+
- Content blocks with images use Ollama's 'images' key
|
| 714 |
+
"""
|
| 715 |
+
# Extract text and images from Anthropic-format messages
|
| 716 |
+
all_text_parts = []
|
| 717 |
+
all_images = []
|
| 718 |
+
for msg in messages:
|
| 719 |
+
content = msg.get("content", "")
|
| 720 |
+
if msg["role"] == "system":
|
| 721 |
+
if isinstance(content, str) and content:
|
| 722 |
+
all_text_parts.insert(0, content)
|
| 723 |
+
continue
|
| 724 |
+
if isinstance(content, list):
|
| 725 |
+
for block in content:
|
| 726 |
+
if isinstance(block, dict):
|
| 727 |
+
if block.get("type") == "text":
|
| 728 |
+
all_text_parts.append(block["text"])
|
| 729 |
+
elif block.get("type") == "image":
|
| 730 |
+
source = block.get("source", {})
|
| 731 |
+
if source.get("type") == "base64":
|
| 732 |
+
all_images.append(source["data"])
|
| 733 |
+
elif isinstance(content, str) and content:
|
| 734 |
+
all_text_parts.append(content)
|
| 735 |
+
|
| 736 |
+
ollama_msg = {
|
| 737 |
+
"role": "user",
|
| 738 |
+
"content": "\n\n".join(all_text_parts),
|
| 739 |
+
}
|
| 740 |
+
if all_images:
|
| 741 |
+
ollama_msg["images"] = all_images
|
| 742 |
+
|
| 743 |
+
options = {
|
| 744 |
+
"temperature": self.config.temperature,
|
| 745 |
+
"num_predict": self.config.max_tokens,
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
# Use Ollama's format schema to force structured JSON output
|
| 749 |
+
agent_schema = {
|
| 750 |
+
"type": "object",
|
| 751 |
+
"properties": {
|
| 752 |
+
"thought": {"type": "string"},
|
| 753 |
+
"action": {
|
| 754 |
+
"type": "object",
|
| 755 |
+
"properties": {
|
| 756 |
+
"type": {"type": "string"},
|
| 757 |
+
"x": {"type": "integer"},
|
| 758 |
+
"y": {"type": "integer"},
|
| 759 |
+
"text": {"type": "string"},
|
| 760 |
+
"url": {"type": "string"},
|
| 761 |
+
"direction": {"type": "string"},
|
| 762 |
+
"amount": {"type": "integer"},
|
| 763 |
+
"summary": {"type": "string"},
|
| 764 |
+
},
|
| 765 |
+
"required": ["type"],
|
| 766 |
+
},
|
| 767 |
+
},
|
| 768 |
+
"required": ["thought", "action"],
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
response = self._llm_client.chat(
|
| 772 |
+
model=self.config.model,
|
| 773 |
+
messages=[ollama_msg],
|
| 774 |
+
options=options,
|
| 775 |
+
format=agent_schema,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
# Extract content from response (handle both dict and Pydantic model)
|
| 779 |
+
message = (
|
| 780 |
+
response.get("message")
|
| 781 |
+
if hasattr(response, "get")
|
| 782 |
+
else getattr(response, "message", None)
|
| 783 |
+
)
|
| 784 |
+
if message is None:
|
| 785 |
+
raise RuntimeError("No message in Ollama response")
|
| 786 |
+
|
| 787 |
+
content = (
|
| 788 |
+
message.get("content")
|
| 789 |
+
if hasattr(message, "get")
|
| 790 |
+
else getattr(message, "content", None)
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
# Some models (e.g. qwen3-vl) put responses in 'thinking' field
|
| 794 |
+
# and leave content empty. Extract the agent JSON from thinking.
|
| 795 |
+
if not content:
|
| 796 |
+
thinking = (
|
| 797 |
+
message.get("thinking")
|
| 798 |
+
if hasattr(message, "get")
|
| 799 |
+
else getattr(message, "thinking", None)
|
| 800 |
+
)
|
| 801 |
+
if thinking:
|
| 802 |
+
content = _extract_agent_json(thinking)
|
| 803 |
+
|
| 804 |
+
return content or ""
|
| 805 |
+
|
| 806 |
+
def _parse_action(self, llm_response: str) -> tuple:
|
| 807 |
+
"""Parse thought and action from LLM JSON response.
|
| 808 |
+
|
| 809 |
+
Returns:
|
| 810 |
+
(thought, action_dict)
|
| 811 |
+
"""
|
| 812 |
+
# Try to extract JSON from response
|
| 813 |
+
text = llm_response.strip()
|
| 814 |
+
|
| 815 |
+
# Handle markdown code blocks
|
| 816 |
+
if "```json" in text:
|
| 817 |
+
import re
|
| 818 |
+
match = re.search(r"```json\s*([\s\S]*?)\s*```", text)
|
| 819 |
+
if match:
|
| 820 |
+
text = match.group(1).strip()
|
| 821 |
+
elif "```" in text:
|
| 822 |
+
import re
|
| 823 |
+
match = re.search(r"```\s*([\s\S]*?)\s*```", text)
|
| 824 |
+
if match:
|
| 825 |
+
text = match.group(1).strip()
|
| 826 |
+
|
| 827 |
+
try:
|
| 828 |
+
parsed = json.loads(text)
|
| 829 |
+
except json.JSONDecodeError:
|
| 830 |
+
logger.warning(f"Failed to parse LLM response as JSON: {text[:200]}")
|
| 831 |
+
return text, {"type": "wait"}
|
| 832 |
+
|
| 833 |
+
thought = parsed.get("thought", "")
|
| 834 |
+
action = parsed.get("action", {"type": "wait"})
|
| 835 |
+
|
| 836 |
+
# Validate action has a type
|
| 837 |
+
if "type" not in action:
|
| 838 |
+
action["type"] = "wait"
|
| 839 |
+
|
| 840 |
+
return thought, action
|
| 841 |
+
|
| 842 |
+
async def _execute_action(self, action: Dict[str, Any]) -> str:
|
| 843 |
+
"""Execute an action via Playwright and return observation."""
|
| 844 |
+
action_type = action.get("type", "wait")
|
| 845 |
+
pw = self._playwright_session
|
| 846 |
+
|
| 847 |
+
try:
|
| 848 |
+
if action_type == "click":
|
| 849 |
+
x = int(action.get("x", 0))
|
| 850 |
+
y = int(action.get("y", 0))
|
| 851 |
+
success = await pw.click(x, y)
|
| 852 |
+
return f"Clicked at ({x}, {y})" if success else f"Click failed at ({x}, {y})"
|
| 853 |
+
|
| 854 |
+
elif action_type == "type":
|
| 855 |
+
text = action.get("text", "")
|
| 856 |
+
# Handle control characters via keyboard.press
|
| 857 |
+
if text == "\b":
|
| 858 |
+
success = await pw.page.keyboard.press("Backspace") or True
|
| 859 |
+
return "Pressed Backspace"
|
| 860 |
+
elif text == "\n":
|
| 861 |
+
success = await pw.page.keyboard.press("Enter") or True
|
| 862 |
+
return "Pressed Enter"
|
| 863 |
+
elif text == "\t":
|
| 864 |
+
success = await pw.page.keyboard.press("Tab") or True
|
| 865 |
+
return "Pressed Tab"
|
| 866 |
+
else:
|
| 867 |
+
success = await pw.type_text(text)
|
| 868 |
+
return f"Typed '{text}'" if success else f"Type failed: '{text}'"
|
| 869 |
+
|
| 870 |
+
elif action_type == "scroll":
|
| 871 |
+
direction = action.get("direction", "down")
|
| 872 |
+
amount = int(action.get("amount", 300))
|
| 873 |
+
dy = amount if direction == "down" else -amount
|
| 874 |
+
success = await pw.scroll(0, dy)
|
| 875 |
+
return f"Scrolled {direction} by {amount}px" if success else "Scroll failed"
|
| 876 |
+
|
| 877 |
+
elif action_type == "navigate":
|
| 878 |
+
url = action.get("url", "")
|
| 879 |
+
success = await pw.navigate(url)
|
| 880 |
+
return f"Navigated to {url}" if success else f"Navigation failed: {url}"
|
| 881 |
+
|
| 882 |
+
elif action_type == "wait":
|
| 883 |
+
await asyncio.sleep(1)
|
| 884 |
+
return "Waited 1 second"
|
| 885 |
+
|
| 886 |
+
elif action_type == "done":
|
| 887 |
+
summary = action.get("summary", "Task completed")
|
| 888 |
+
return summary
|
| 889 |
+
|
| 890 |
+
else:
|
| 891 |
+
logger.warning(f"Unknown action type: {action_type}")
|
| 892 |
+
return f"Unknown action: {action_type}"
|
| 893 |
+
|
| 894 |
+
except Exception as e:
|
| 895 |
+
logger.error(f"Action execution error: {e}")
|
| 896 |
+
return f"Error executing {action_type}: {e}"
|
| 897 |
+
|
| 898 |
+
async def _process_takeover_actions(self):
|
| 899 |
+
"""Process manual actions submitted during takeover mode."""
|
| 900 |
+
try:
|
| 901 |
+
action = self._takeover_actions.get_nowait()
|
| 902 |
+
except Empty:
|
| 903 |
+
return
|
| 904 |
+
|
| 905 |
+
pw = self._playwright_session
|
| 906 |
+
if not pw:
|
| 907 |
+
return
|
| 908 |
+
|
| 909 |
+
observation = await self._execute_action(action)
|
| 910 |
+
|
| 911 |
+
# Take screenshot after manual action
|
| 912 |
+
screenshot_bytes = await pw.screenshot()
|
| 913 |
+
step_index = len(self._steps)
|
| 914 |
+
screenshot_path = os.path.join(
|
| 915 |
+
self.screenshot_dir, f"step_{step_index:03d}_manual.png"
|
| 916 |
+
)
|
| 917 |
+
if screenshot_bytes:
|
| 918 |
+
with open(screenshot_path, "wb") as f:
|
| 919 |
+
f.write(screenshot_bytes)
|
| 920 |
+
|
| 921 |
+
page_state = await pw.get_state()
|
| 922 |
+
|
| 923 |
+
step = AgentStep(
|
| 924 |
+
step_index=step_index,
|
| 925 |
+
screenshot_path=screenshot_path,
|
| 926 |
+
action={**action, "_manual": True},
|
| 927 |
+
thought="[Manual takeover action]",
|
| 928 |
+
observation=observation,
|
| 929 |
+
timestamp=time.time(),
|
| 930 |
+
url=page_state.get("url", ""),
|
| 931 |
+
viewport=page_state.get("viewport"),
|
| 932 |
+
coordinates=_extract_coordinates(action),
|
| 933 |
+
)
|
| 934 |
+
self._steps.append(step)
|
| 935 |
+
self._emit_event("step", step.to_dict())
|
| 936 |
+
|
| 937 |
+
# --- Trace export ---
|
| 938 |
+
|
| 939 |
+
def get_trace(self) -> Dict[str, Any]:
|
| 940 |
+
"""Export the session as a web_agent_trace-compatible dict."""
|
| 941 |
+
return {
|
| 942 |
+
"steps": [s.to_dict() for s in self._steps],
|
| 943 |
+
"task_description": "", # Set by caller
|
| 944 |
+
"session_id": self.session_id,
|
| 945 |
+
"agent_config": {
|
| 946 |
+
"model": self.config.model,
|
| 947 |
+
"endpoint_type": self.config.endpoint_type,
|
| 948 |
+
"max_steps": self.config.max_steps,
|
| 949 |
+
},
|
| 950 |
+
"annotator_interactions": self._interactions,
|
| 951 |
+
"state": self.state.value,
|
| 952 |
+
"total_steps": len(self._steps),
|
| 953 |
+
}
|
| 954 |
+
|
| 955 |
+
def get_state_summary(self) -> Dict[str, Any]:
|
| 956 |
+
"""Get a summary of current state for API responses."""
|
| 957 |
+
return {
|
| 958 |
+
"session_id": self.session_id,
|
| 959 |
+
"state": self.state.value,
|
| 960 |
+
"step_count": len(self._steps),
|
| 961 |
+
"error": self._error,
|
| 962 |
+
"has_instructions_pending": not self._instruction_queue.empty(),
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def _extract_agent_json(text: str) -> str:
|
| 967 |
+
"""Extract the last valid JSON object containing 'thought' or 'action' from text.
|
| 968 |
+
|
| 969 |
+
Some models (qwen3-vl) put their chain-of-thought in the thinking field
|
| 970 |
+
with the actual JSON answer embedded in the text. This function finds
|
| 971 |
+
that JSON, skipping any example/template JSON from the prompt.
|
| 972 |
+
"""
|
| 973 |
+
import re
|
| 974 |
+
|
| 975 |
+
# Find all JSON-like blocks (balanced braces)
|
| 976 |
+
candidates = []
|
| 977 |
+
depth = 0
|
| 978 |
+
start = None
|
| 979 |
+
for i, ch in enumerate(text):
|
| 980 |
+
if ch == "{":
|
| 981 |
+
if depth == 0:
|
| 982 |
+
start = i
|
| 983 |
+
depth += 1
|
| 984 |
+
elif ch == "}":
|
| 985 |
+
depth -= 1
|
| 986 |
+
if depth == 0 and start is not None:
|
| 987 |
+
candidates.append(text[start : i + 1])
|
| 988 |
+
start = None
|
| 989 |
+
|
| 990 |
+
# Try each candidate (last first — most likely to be the final answer)
|
| 991 |
+
for candidate in reversed(candidates):
|
| 992 |
+
try:
|
| 993 |
+
parsed = json.loads(candidate)
|
| 994 |
+
if isinstance(parsed, dict) and ("thought" in parsed or "action" in parsed):
|
| 995 |
+
return candidate
|
| 996 |
+
except (json.JSONDecodeError, ValueError):
|
| 997 |
+
continue
|
| 998 |
+
|
| 999 |
+
# Fallback: try greedy regex for any JSON
|
| 1000 |
+
match = re.search(r"\{[^{}]*\}", text)
|
| 1001 |
+
return match.group(0) if match else ""
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
def _extract_coordinates(action: Dict[str, Any]) -> Optional[Dict[str, int]]:
|
| 1005 |
+
"""Extract x, y coordinates from an action if present."""
|
| 1006 |
+
if "x" in action and "y" in action:
|
| 1007 |
+
return {"x": int(action["x"]), "y": int(action["y"])}
|
| 1008 |
+
return None
|
potato/agent_runner_manager.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Runner Session Manager
|
| 3 |
+
|
| 4 |
+
Singleton that manages active AgentRunner sessions.
|
| 5 |
+
Keyed by "{user_id}:{instance_id}" for per-user, per-instance isolation.
|
| 6 |
+
Includes TTL-based cleanup and max concurrent session limits.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import atexit
|
| 10 |
+
import logging
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
from typing import Dict, Optional
|
| 14 |
+
|
| 15 |
+
from potato.agent_runner import AgentConfig, AgentRunner, AgentState
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
# Default limits
|
| 20 |
+
DEFAULT_MAX_SESSIONS = 10
|
| 21 |
+
DEFAULT_SESSION_TTL = 3600 # 1 hour
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AgentRunnerManager:
|
| 25 |
+
"""
|
| 26 |
+
Manages active AgentRunner sessions with lifecycle control.
|
| 27 |
+
|
| 28 |
+
Thread-safe singleton. Sessions are keyed by "{user_id}:{instance_id}".
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
_instance = None
|
| 32 |
+
_lock = threading.Lock()
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
max_sessions: int = DEFAULT_MAX_SESSIONS,
|
| 37 |
+
session_ttl: int = DEFAULT_SESSION_TTL,
|
| 38 |
+
):
|
| 39 |
+
self._sessions: Dict[str, AgentRunner] = {}
|
| 40 |
+
self._session_created: Dict[str, float] = {}
|
| 41 |
+
self._session_meta: Dict[str, Dict] = {}
|
| 42 |
+
self._lock = threading.Lock()
|
| 43 |
+
self.max_sessions = max_sessions
|
| 44 |
+
self.session_ttl = session_ttl
|
| 45 |
+
|
| 46 |
+
# Start cleanup thread
|
| 47 |
+
self._cleanup_stop = threading.Event()
|
| 48 |
+
self._cleanup_thread = threading.Thread(
|
| 49 |
+
target=self._cleanup_loop, daemon=True, name="agent-cleanup"
|
| 50 |
+
)
|
| 51 |
+
self._cleanup_thread.start()
|
| 52 |
+
|
| 53 |
+
@classmethod
|
| 54 |
+
def get_instance(cls, **kwargs) -> "AgentRunnerManager":
|
| 55 |
+
"""Get or create the singleton instance."""
|
| 56 |
+
if cls._instance is None:
|
| 57 |
+
with cls._lock:
|
| 58 |
+
if cls._instance is None:
|
| 59 |
+
cls._instance = cls(**kwargs)
|
| 60 |
+
return cls._instance
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def clear_instance(cls):
|
| 64 |
+
"""Clear the singleton (for testing)."""
|
| 65 |
+
with cls._lock:
|
| 66 |
+
if cls._instance is not None:
|
| 67 |
+
cls._instance.shutdown()
|
| 68 |
+
cls._instance = None
|
| 69 |
+
|
| 70 |
+
def create_session(
|
| 71 |
+
self,
|
| 72 |
+
user_id: str,
|
| 73 |
+
instance_id: str,
|
| 74 |
+
config: AgentConfig,
|
| 75 |
+
screenshot_dir: str,
|
| 76 |
+
) -> AgentRunner:
|
| 77 |
+
"""
|
| 78 |
+
Create a new agent session.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
user_id: Annotator user ID
|
| 82 |
+
instance_id: Annotation instance ID
|
| 83 |
+
config: Agent configuration
|
| 84 |
+
screenshot_dir: Directory to store screenshots
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
AgentRunner instance
|
| 88 |
+
|
| 89 |
+
Raises:
|
| 90 |
+
RuntimeError: If max sessions reached or session already exists
|
| 91 |
+
"""
|
| 92 |
+
session_key = f"{user_id}:{instance_id}"
|
| 93 |
+
|
| 94 |
+
with self._lock:
|
| 95 |
+
# Clean up expired sessions first
|
| 96 |
+
self._cleanup_expired_locked()
|
| 97 |
+
|
| 98 |
+
# Check for existing active session
|
| 99 |
+
if session_key in self._sessions:
|
| 100 |
+
existing = self._sessions[session_key]
|
| 101 |
+
if existing.state in (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER):
|
| 102 |
+
raise RuntimeError(
|
| 103 |
+
f"Active session already exists for {session_key}. "
|
| 104 |
+
f"Stop it first."
|
| 105 |
+
)
|
| 106 |
+
# Old completed/error session — remove it
|
| 107 |
+
del self._sessions[session_key]
|
| 108 |
+
del self._session_created[session_key]
|
| 109 |
+
if session_key in self._session_meta:
|
| 110 |
+
del self._session_meta[session_key]
|
| 111 |
+
|
| 112 |
+
# Check capacity
|
| 113 |
+
active_count = sum(
|
| 114 |
+
1
|
| 115 |
+
for s in self._sessions.values()
|
| 116 |
+
if s.state in (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER)
|
| 117 |
+
)
|
| 118 |
+
if active_count >= self.max_sessions:
|
| 119 |
+
raise RuntimeError(
|
| 120 |
+
f"Maximum concurrent sessions ({self.max_sessions}) reached"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
import uuid
|
| 124 |
+
session_id = str(uuid.uuid4())[:12]
|
| 125 |
+
runner = AgentRunner(session_id, config, screenshot_dir)
|
| 126 |
+
|
| 127 |
+
self._sessions[session_key] = runner
|
| 128 |
+
self._session_created[session_key] = time.time()
|
| 129 |
+
self._session_meta[session_key] = {
|
| 130 |
+
"user_id": user_id,
|
| 131 |
+
"instance_id": instance_id,
|
| 132 |
+
"session_id": session_id,
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
logger.info(
|
| 136 |
+
f"Created agent session {session_id} for {session_key}"
|
| 137 |
+
)
|
| 138 |
+
return runner
|
| 139 |
+
|
| 140 |
+
def get_session(self, session_id: str) -> Optional[AgentRunner]:
|
| 141 |
+
"""Get a session by its session_id."""
|
| 142 |
+
with self._lock:
|
| 143 |
+
for runner in self._sessions.values():
|
| 144 |
+
if runner.session_id == session_id:
|
| 145 |
+
return runner
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
def get_session_by_key(self, user_id: str, instance_id: str) -> Optional[AgentRunner]:
|
| 149 |
+
"""Get a session by user_id and instance_id."""
|
| 150 |
+
session_key = f"{user_id}:{instance_id}"
|
| 151 |
+
with self._lock:
|
| 152 |
+
return self._sessions.get(session_key)
|
| 153 |
+
|
| 154 |
+
def remove_session(self, session_id: str):
|
| 155 |
+
"""Remove a session by session_id."""
|
| 156 |
+
with self._lock:
|
| 157 |
+
key_to_remove = None
|
| 158 |
+
for key, runner in self._sessions.items():
|
| 159 |
+
if runner.session_id == session_id:
|
| 160 |
+
key_to_remove = key
|
| 161 |
+
break
|
| 162 |
+
if key_to_remove:
|
| 163 |
+
runner = self._sessions.pop(key_to_remove)
|
| 164 |
+
self._session_created.pop(key_to_remove, None)
|
| 165 |
+
self._session_meta.pop(key_to_remove, None)
|
| 166 |
+
runner.stop()
|
| 167 |
+
logger.info(f"Removed agent session {session_id}")
|
| 168 |
+
|
| 169 |
+
def list_sessions(self) -> list:
|
| 170 |
+
"""List all active sessions."""
|
| 171 |
+
with self._lock:
|
| 172 |
+
result = []
|
| 173 |
+
for key, runner in self._sessions.items():
|
| 174 |
+
meta = self._session_meta.get(key, {})
|
| 175 |
+
result.append({
|
| 176 |
+
"session_id": runner.session_id,
|
| 177 |
+
"user_id": meta.get("user_id"),
|
| 178 |
+
"instance_id": meta.get("instance_id"),
|
| 179 |
+
"state": runner.state.value,
|
| 180 |
+
"step_count": runner.step_count,
|
| 181 |
+
"created": self._session_created.get(key),
|
| 182 |
+
})
|
| 183 |
+
return result
|
| 184 |
+
|
| 185 |
+
def _cleanup_expired_locked(self):
|
| 186 |
+
"""Remove expired sessions. Must be called with self._lock held."""
|
| 187 |
+
now = time.time()
|
| 188 |
+
expired_keys = []
|
| 189 |
+
for key, created_at in self._session_created.items():
|
| 190 |
+
if now - created_at > self.session_ttl:
|
| 191 |
+
runner = self._sessions.get(key)
|
| 192 |
+
if runner and runner.state in (AgentState.COMPLETED, AgentState.ERROR, AgentState.IDLE):
|
| 193 |
+
expired_keys.append(key)
|
| 194 |
+
elif runner and now - created_at > self.session_ttl * 2:
|
| 195 |
+
# Force-stop sessions that have been running too long
|
| 196 |
+
runner.stop()
|
| 197 |
+
expired_keys.append(key)
|
| 198 |
+
|
| 199 |
+
for key in expired_keys:
|
| 200 |
+
self._sessions.pop(key, None)
|
| 201 |
+
self._session_created.pop(key, None)
|
| 202 |
+
self._session_meta.pop(key, None)
|
| 203 |
+
logger.info(f"Cleaned up expired session: {key}")
|
| 204 |
+
|
| 205 |
+
def _cleanup_loop(self):
|
| 206 |
+
"""Background cleanup thread."""
|
| 207 |
+
while not self._cleanup_stop.is_set():
|
| 208 |
+
self._cleanup_stop.wait(60) # Check every 60 seconds
|
| 209 |
+
if self._cleanup_stop.is_set():
|
| 210 |
+
break
|
| 211 |
+
with self._lock:
|
| 212 |
+
self._cleanup_expired_locked()
|
| 213 |
+
|
| 214 |
+
def shutdown(self):
|
| 215 |
+
"""Stop all sessions and cleanup thread."""
|
| 216 |
+
self._cleanup_stop.set()
|
| 217 |
+
with self._lock:
|
| 218 |
+
for key, runner in self._sessions.items():
|
| 219 |
+
try:
|
| 220 |
+
runner.stop()
|
| 221 |
+
except Exception as e:
|
| 222 |
+
logger.warning(f"Error stopping session {key}: {e}")
|
| 223 |
+
self._sessions.clear()
|
| 224 |
+
self._session_created.clear()
|
| 225 |
+
self._session_meta.clear()
|
| 226 |
+
logger.info("AgentRunnerManager shut down")
|
potato/agreement.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inter-Annotator Agreement Calculation Module
|
| 3 |
+
|
| 4 |
+
This module provides functionality for calculating inter-annotator agreement metrics,
|
| 5 |
+
including Krippendorff's alpha, Cohen's kappa (pairwise), and Fleiss' kappa
|
| 6 |
+
(N raters), from annotation data. It supports both rating agreement (interval
|
| 7 |
+
metric) and skip agreement (nominal metric) calculations.
|
| 8 |
+
|
| 9 |
+
The module processes annotation files in JSON format and outputs agreement statistics
|
| 10 |
+
along with a CSV file containing the processed annotation data.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
from itertools import combinations
|
| 15 |
+
import simpledorff
|
| 16 |
+
from simpledorff.metrics import *
|
| 17 |
+
import ujson
|
| 18 |
+
import pandas as pd
|
| 19 |
+
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_nans(shape):
|
| 25 |
+
"""
|
| 26 |
+
Create a numpy array filled with NaN values.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
shape: The shape of the array to create
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
numpy.ndarray: Array filled with NaN values
|
| 33 |
+
"""
|
| 34 |
+
ar = np.empty(shape)
|
| 35 |
+
ar[:] = np.NaN
|
| 36 |
+
return ar
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def cohen_kappa_pairwise(reliability_df):
|
| 40 |
+
"""
|
| 41 |
+
Compute Cohen's kappa for every pair of annotators and return aggregate stats.
|
| 42 |
+
|
| 43 |
+
Cohen's kappa is defined for exactly two raters. With N>2 raters we compute
|
| 44 |
+
kappa for each pair on the items they both rated, then return the mean and the
|
| 45 |
+
per-pair breakdown. Pairs that share fewer than 2 items are skipped.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
reliability_df: long-format DataFrame with columns
|
| 49 |
+
unit (item id), annotator (user), annotation (label value).
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
dict with keys: mean_kappa (float | None), pairs (list of
|
| 53 |
+
{annotator_a, annotator_b, kappa, n_items}), n_pairs_evaluated,
|
| 54 |
+
n_pairs_skipped.
|
| 55 |
+
"""
|
| 56 |
+
from sklearn.metrics import cohen_kappa_score
|
| 57 |
+
|
| 58 |
+
annotators = sorted(reliability_df["annotator"].unique())
|
| 59 |
+
pairs = []
|
| 60 |
+
skipped = 0
|
| 61 |
+
|
| 62 |
+
for a, b in combinations(annotators, 2):
|
| 63 |
+
a_rows = reliability_df[reliability_df["annotator"] == a].set_index("unit")["annotation"]
|
| 64 |
+
b_rows = reliability_df[reliability_df["annotator"] == b].set_index("unit")["annotation"]
|
| 65 |
+
shared = a_rows.index.intersection(b_rows.index)
|
| 66 |
+
if len(shared) < 2:
|
| 67 |
+
skipped += 1
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
y_a = a_rows.loc[shared].astype(str).tolist()
|
| 71 |
+
y_b = b_rows.loc[shared].astype(str).tolist()
|
| 72 |
+
try:
|
| 73 |
+
kappa = float(cohen_kappa_score(y_a, y_b))
|
| 74 |
+
except Exception:
|
| 75 |
+
skipped += 1
|
| 76 |
+
continue
|
| 77 |
+
pairs.append({
|
| 78 |
+
"annotator_a": a,
|
| 79 |
+
"annotator_b": b,
|
| 80 |
+
"kappa": round(kappa, 4),
|
| 81 |
+
"n_items": int(len(shared)),
|
| 82 |
+
})
|
| 83 |
+
|
| 84 |
+
mean_kappa = (sum(p["kappa"] for p in pairs) / len(pairs)) if pairs else None
|
| 85 |
+
return {
|
| 86 |
+
"mean_kappa": round(mean_kappa, 4) if mean_kappa is not None else None,
|
| 87 |
+
"pairs": pairs,
|
| 88 |
+
"n_pairs_evaluated": len(pairs),
|
| 89 |
+
"n_pairs_skipped": skipped,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def fleiss_kappa(reliability_df):
|
| 94 |
+
"""
|
| 95 |
+
Compute Fleiss' kappa for N raters over a categorical label set.
|
| 96 |
+
|
| 97 |
+
Fleiss' kappa assumes the same number of ratings per item but tolerates
|
| 98 |
+
different rater identities per item. Items with fewer than 2 ratings are
|
| 99 |
+
dropped; the remaining items are padded by repeating their available
|
| 100 |
+
ratings up to the per-item rater count (`n_raters = max ratings per item`).
|
| 101 |
+
When per-item rater counts vary widely the metric is approximate; we report
|
| 102 |
+
`n_raters` and `n_items_evaluated` so the caller can judge.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
reliability_df: long-format DataFrame with columns
|
| 106 |
+
unit (item id), annotator (user), annotation (label value).
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
dict with keys: kappa (float | None), n_items_evaluated (int),
|
| 110 |
+
n_raters (int), n_categories (int), interpretation (str).
|
| 111 |
+
"""
|
| 112 |
+
if reliability_df.empty:
|
| 113 |
+
return {"kappa": None, "n_items_evaluated": 0, "n_raters": 0,
|
| 114 |
+
"n_categories": 0, "interpretation": "No data"}
|
| 115 |
+
|
| 116 |
+
df = reliability_df.copy()
|
| 117 |
+
df["annotation"] = df["annotation"].astype(str)
|
| 118 |
+
|
| 119 |
+
counts_by_item = df.groupby(["unit", "annotation"]).size().unstack(fill_value=0)
|
| 120 |
+
items_with_ratings = counts_by_item.sum(axis=1)
|
| 121 |
+
counts_by_item = counts_by_item.loc[items_with_ratings >= 2]
|
| 122 |
+
|
| 123 |
+
if counts_by_item.empty:
|
| 124 |
+
return {"kappa": None, "n_items_evaluated": 0, "n_raters": 0,
|
| 125 |
+
"n_categories": int(df["annotation"].nunique()),
|
| 126 |
+
"interpretation": "No items with >=2 raters"}
|
| 127 |
+
|
| 128 |
+
n_raters = int(counts_by_item.sum(axis=1).max())
|
| 129 |
+
n_items = int(counts_by_item.shape[0])
|
| 130 |
+
n_categories = int(counts_by_item.shape[1])
|
| 131 |
+
|
| 132 |
+
matrix = counts_by_item.to_numpy(dtype=float)
|
| 133 |
+
row_sums = matrix.sum(axis=1, keepdims=True)
|
| 134 |
+
row_sums[row_sums == 0] = 1.0
|
| 135 |
+
matrix = matrix * (n_raters / row_sums)
|
| 136 |
+
|
| 137 |
+
p_j = matrix.sum(axis=0) / (n_items * n_raters)
|
| 138 |
+
if n_raters < 2:
|
| 139 |
+
return {"kappa": None, "n_items_evaluated": n_items, "n_raters": n_raters,
|
| 140 |
+
"n_categories": n_categories,
|
| 141 |
+
"interpretation": "Need >=2 raters per item"}
|
| 142 |
+
p_i = (np.sum(matrix ** 2, axis=1) - n_raters) / (n_raters * (n_raters - 1))
|
| 143 |
+
p_bar = float(p_i.mean())
|
| 144 |
+
p_e = float(np.sum(p_j ** 2))
|
| 145 |
+
|
| 146 |
+
if p_e >= 1.0:
|
| 147 |
+
kappa = 1.0 if p_bar >= 1.0 else 0.0
|
| 148 |
+
else:
|
| 149 |
+
kappa = (p_bar - p_e) / (1 - p_e)
|
| 150 |
+
|
| 151 |
+
return {
|
| 152 |
+
"kappa": round(float(kappa), 4),
|
| 153 |
+
"n_items_evaluated": n_items,
|
| 154 |
+
"n_raters": n_raters,
|
| 155 |
+
"n_categories": n_categories,
|
| 156 |
+
"interpretation": interpret_kappa(kappa),
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def interpret_kappa(kappa):
|
| 161 |
+
"""Landis & Koch (1977) interpretation bands for kappa-family metrics."""
|
| 162 |
+
if kappa is None:
|
| 163 |
+
return "No agreement computable"
|
| 164 |
+
if kappa < 0:
|
| 165 |
+
return "Worse than chance"
|
| 166 |
+
if kappa < 0.21:
|
| 167 |
+
return "Slight"
|
| 168 |
+
if kappa < 0.41:
|
| 169 |
+
return "Fair"
|
| 170 |
+
if kappa < 0.61:
|
| 171 |
+
return "Moderate"
|
| 172 |
+
if kappa < 0.81:
|
| 173 |
+
return "Substantial"
|
| 174 |
+
return "Almost perfect"
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def flatten(annotations):
|
| 178 |
+
"""
|
| 179 |
+
Flatten annotation data structure for processing.
|
| 180 |
+
|
| 181 |
+
Converts a list of annotation dictionaries into a format where each
|
| 182 |
+
annotation is a dictionary mapping user IDs to their labels.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
annotations: List of annotation dictionaries
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
list: Flattened annotation data structure
|
| 189 |
+
|
| 190 |
+
Example:
|
| 191 |
+
Input: [{"user": "user1", "label": "positive"}, {"user": "user2", "label": "negative"}]
|
| 192 |
+
Output: [{"user1": "positive", "user2": "negative"}]
|
| 193 |
+
"""
|
| 194 |
+
return [{a["user"]: a["label"] for a in ann} for ann in annotations]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def main(args):
|
| 198 |
+
"""
|
| 199 |
+
Main function for calculating inter-annotator agreement.
|
| 200 |
+
|
| 201 |
+
This function processes annotation data from a JSON file, calculates
|
| 202 |
+
Krippendorff's alpha for both rating agreement and skip agreement,
|
| 203 |
+
and outputs the results along with a CSV file of the processed data.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
args: Command line arguments containing file paths
|
| 207 |
+
|
| 208 |
+
Side Effects:
|
| 209 |
+
- Reads annotation data from input file
|
| 210 |
+
- Prints agreement statistics to console
|
| 211 |
+
- Writes processed data to output CSV file
|
| 212 |
+
|
| 213 |
+
The function processes the first 385 annotations by default and handles
|
| 214 |
+
missing annotations and skipped items appropriately.
|
| 215 |
+
"""
|
| 216 |
+
# Load annotation data from JSON file
|
| 217 |
+
with open(args.file, "r") as f:
|
| 218 |
+
annotations = [ujson.loads(line)["annotations"] for line in f]
|
| 219 |
+
|
| 220 |
+
# Extract unique user IDs from all annotations
|
| 221 |
+
users = set([a["user"] for ann in annotations for a in ann])
|
| 222 |
+
annotations = flatten(annotations)
|
| 223 |
+
|
| 224 |
+
# Limit to first 385 annotations (configurable limit)
|
| 225 |
+
annotations = annotations[:385]
|
| 226 |
+
|
| 227 |
+
# Create data matrix for agreement calculation
|
| 228 |
+
# Each row represents a user, each column represents an annotation
|
| 229 |
+
# -1 values indicate skipped annotations, NaN indicates missing annotations
|
| 230 |
+
data = [
|
| 231 |
+
[np.nan if user not in a or int(a[user]) == -1 else int(a[user]) for a in annotations]
|
| 232 |
+
for user in users
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
# Create skip data matrix (boolean indicating if annotation was skipped)
|
| 236 |
+
skip_data = [
|
| 237 |
+
[np.nan if user not in a else int(a[user]) < 0 for a in annotations] for user in users
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
# Calculate statistics for each user
|
| 241 |
+
labeled = ~np.isnan(data)
|
| 242 |
+
skipped = [
|
| 243 |
+
[False if user not in a else int(a[user]) < 0 for a in annotations] for user in users
|
| 244 |
+
]
|
| 245 |
+
|
| 246 |
+
# Print summary statistics
|
| 247 |
+
print("calculating over:")
|
| 248 |
+
for user, skip in zip(labeled, skipped):
|
| 249 |
+
print("labeled:", sum(user))
|
| 250 |
+
print("skipped:", sum(skip))
|
| 251 |
+
|
| 252 |
+
# Count instances where all users provided annotations
|
| 253 |
+
print(np.all(labeled, axis=0).sum())
|
| 254 |
+
|
| 255 |
+
# Calculate and print Krippendorff's alpha for rating agreement
|
| 256 |
+
# Uses interval metric for continuous rating scales
|
| 257 |
+
print("rating agreement:")
|
| 258 |
+
print(simpledorff.calculate_krippendorffs_alpha(pd.DataFrame(data),metric_fn=interval_metric))
|
| 259 |
+
|
| 260 |
+
# Calculate and print Krippendorff's alpha for skip agreement
|
| 261 |
+
# Uses nominal metric for binary skip/no-skip decisions
|
| 262 |
+
print("skip agreement:")
|
| 263 |
+
print(simpledorff.calculate_krippendorffs_alpha(pd.DataFrame(data),metric_fn=nominal_metric))
|
| 264 |
+
|
| 265 |
+
# Write processed data to CSV file
|
| 266 |
+
with open(args.outfile, "w") as f:
|
| 267 |
+
for row in zip(*data):
|
| 268 |
+
f.write(",".join([str(a) for a in row]) + "\n")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
# Set up command line argument parsing
|
| 273 |
+
parser = argparse.ArgumentParser(
|
| 274 |
+
description="Calculate Krippendorf's alpha from given JSON file of annotations"
|
| 275 |
+
)
|
| 276 |
+
parser.add_argument("file", help="path to JSON file")
|
| 277 |
+
parser.add_argument("outfile", help="write path to CSV")
|
| 278 |
+
main(parser.parse_args())
|
potato/ai/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .ai_help_wrapper import generate_ai_help_html
|
potato/ai/ai_cache.py
ADDED
|
@@ -0,0 +1,1473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Dict, Union
|
| 6 |
+
import requests
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import time
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
import threading
|
| 11 |
+
from builtins import open
|
| 12 |
+
from potato.server_utils.config_module import config
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
from potato.item_state_management import get_item_state_manager
|
| 17 |
+
from potato.ai.ai_endpoint import (
|
| 18 |
+
AIEndpointFactory,
|
| 19 |
+
Annotation_Type,
|
| 20 |
+
AnnotationInput,
|
| 21 |
+
ImageData,
|
| 22 |
+
VisualAnnotationInput,
|
| 23 |
+
ModelCapabilities,
|
| 24 |
+
)
|
| 25 |
+
from potato.ai.ollama_endpoint import OllamaEndpoint
|
| 26 |
+
from potato.ai.openrouter_endpoint import OpenRouterEndpoint
|
| 27 |
+
from potato.ai.ai_prompt import ModelManager, get_ai_prompt
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
AICACHEMANAGER = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _get_scheme_field(annotation_id: int, field: str, default=None):
|
| 34 |
+
"""Safely get a field from an annotation scheme with a clear error message."""
|
| 35 |
+
schemes = config.get("annotation_schemes", [])
|
| 36 |
+
if annotation_id >= len(schemes):
|
| 37 |
+
raise ValueError(
|
| 38 |
+
f"AI cache: annotation_id {annotation_id} out of range "
|
| 39 |
+
f"(only {len(schemes)} scheme(s) configured)"
|
| 40 |
+
)
|
| 41 |
+
scheme = schemes[annotation_id]
|
| 42 |
+
if default is not None:
|
| 43 |
+
return scheme.get(field, default)
|
| 44 |
+
if field not in scheme:
|
| 45 |
+
scheme_name = scheme.get("name", f"index {annotation_id}")
|
| 46 |
+
scheme_type = scheme.get("annotation_type", "unknown")
|
| 47 |
+
raise ValueError(
|
| 48 |
+
f"AI cache: annotation scheme '{scheme_name}' (type '{scheme_type}') "
|
| 49 |
+
f"missing required field '{field}'"
|
| 50 |
+
)
|
| 51 |
+
return scheme[field]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _get_instance_text(instance_id: int) -> str:
|
| 55 |
+
"""Get the text content from an instance using the configured text_key."""
|
| 56 |
+
item = get_item_state_manager().items()[instance_id]
|
| 57 |
+
item_data = item.get_data()
|
| 58 |
+
|
| 59 |
+
# Get the configured text_key
|
| 60 |
+
text_key = config.get("item_properties", {}).get("text_key", "text")
|
| 61 |
+
|
| 62 |
+
# Try the configured text_key first
|
| 63 |
+
if text_key in item_data:
|
| 64 |
+
return item_data[text_key]
|
| 65 |
+
|
| 66 |
+
# Fall back to common keys
|
| 67 |
+
for key in ['text', 'content', 'message']:
|
| 68 |
+
if key in item_data:
|
| 69 |
+
return item_data[key]
|
| 70 |
+
|
| 71 |
+
# Last resort: return any string value
|
| 72 |
+
for value in item_data.values():
|
| 73 |
+
if isinstance(value, str):
|
| 74 |
+
return value
|
| 75 |
+
|
| 76 |
+
return str(item_data)
|
| 77 |
+
|
| 78 |
+
def _is_image_url(text: str) -> bool:
|
| 79 |
+
"""Check if text appears to be an image URL."""
|
| 80 |
+
if not isinstance(text, str):
|
| 81 |
+
return False
|
| 82 |
+
text_lower = text.lower()
|
| 83 |
+
# Check for image extensions
|
| 84 |
+
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp']
|
| 85 |
+
if any(ext in text_lower for ext in image_extensions):
|
| 86 |
+
return True
|
| 87 |
+
# Check for common image hosting services
|
| 88 |
+
image_hosts = ['unsplash.com', 'imgur.com', 'flickr.com', 'picsum.photos']
|
| 89 |
+
if any(host in text_lower for host in image_hosts):
|
| 90 |
+
return True
|
| 91 |
+
# Check if URL starts with http and might be an image
|
| 92 |
+
if text_lower.startswith(('http://', 'https://')) and 'image' in text_lower:
|
| 93 |
+
return True
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
def _get_image_data_from_url(url: str) -> ImageData:
|
| 97 |
+
"""Download image from URL and return as ImageData.
|
| 98 |
+
|
| 99 |
+
Includes SSRF protection to prevent fetching from private/internal IPs.
|
| 100 |
+
"""
|
| 101 |
+
import base64
|
| 102 |
+
import ipaddress
|
| 103 |
+
import socket
|
| 104 |
+
from urllib.parse import urlparse
|
| 105 |
+
|
| 106 |
+
# SSRF protection: validate URL scheme and resolve hostname
|
| 107 |
+
try:
|
| 108 |
+
parsed = urlparse(url)
|
| 109 |
+
if parsed.scheme not in ('http', 'https'):
|
| 110 |
+
logger.warning(f"Blocked non-HTTP image URL: {url[:100]}")
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
hostname = parsed.hostname
|
| 114 |
+
if hostname:
|
| 115 |
+
addr_info = socket.getaddrinfo(hostname, None)
|
| 116 |
+
for info in addr_info:
|
| 117 |
+
ip_str = info[4][0]
|
| 118 |
+
try:
|
| 119 |
+
ip = ipaddress.ip_address(ip_str)
|
| 120 |
+
if ip.is_private or ip.is_loopback or ip.is_link_local:
|
| 121 |
+
logger.warning(
|
| 122 |
+
f"Blocked image URL resolving to private IP: "
|
| 123 |
+
f"{hostname} -> {ip_str}"
|
| 124 |
+
)
|
| 125 |
+
return None
|
| 126 |
+
except ValueError:
|
| 127 |
+
pass
|
| 128 |
+
except Exception as e:
|
| 129 |
+
logger.warning(f"Failed to validate image URL {url[:100]}: {e}")
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
response = requests.get(url, timeout=30)
|
| 134 |
+
response.raise_for_status()
|
| 135 |
+
b64_data = base64.b64encode(response.content).decode('utf-8')
|
| 136 |
+
# Determine mime type from content-type header or URL
|
| 137 |
+
content_type = response.headers.get('content-type', 'image/jpeg')
|
| 138 |
+
return ImageData(source='base64', data=b64_data, mime_type=content_type)
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Failed to download image from {url}: {e}")
|
| 141 |
+
return None
|
| 142 |
+
|
| 143 |
+
def init_ai_cache_manager():
|
| 144 |
+
global AICACHEMANAGER
|
| 145 |
+
if AICACHEMANAGER is None:
|
| 146 |
+
AICACHEMANAGER = AiCacheManager()
|
| 147 |
+
|
| 148 |
+
return AICACHEMANAGER
|
| 149 |
+
|
| 150 |
+
def get_ai_cache_manager():
|
| 151 |
+
"""Get the AI cache manager instance. Returns None if not initialized (AI support disabled)."""
|
| 152 |
+
global AICACHEMANAGER
|
| 153 |
+
return AICACHEMANAGER
|
| 154 |
+
|
| 155 |
+
def clear_ai_cache_manager():
|
| 156 |
+
"""Clear the AI cache manager singleton. Used for testing."""
|
| 157 |
+
global AICACHEMANAGER
|
| 158 |
+
AICACHEMANAGER = None
|
| 159 |
+
|
| 160 |
+
class AiCacheManager:
|
| 161 |
+
def __init__(self):
|
| 162 |
+
ai_support = config["ai_support"]
|
| 163 |
+
if not ai_support["enabled"]:
|
| 164 |
+
return
|
| 165 |
+
cache_config = ai_support.get("cache_config", {})
|
| 166 |
+
ai_config = ai_support.get("ai_config", {})
|
| 167 |
+
include = ai_config.get("include") or {}
|
| 168 |
+
special_include = include.get("special_include", None)
|
| 169 |
+
self.include_all = include.get("all", False)
|
| 170 |
+
self.special_includes = {}
|
| 171 |
+
|
| 172 |
+
self.model_manager = ModelManager()
|
| 173 |
+
self.model_manager.load_models_module()
|
| 174 |
+
|
| 175 |
+
if special_include:
|
| 176 |
+
for page_key, page_value in special_include.items():
|
| 177 |
+
# Convert string keys to integers for easier lookup
|
| 178 |
+
page_index = int(page_key)
|
| 179 |
+
self.special_includes[page_index] = {}
|
| 180 |
+
for annotation_id, annotation_types in page_value.items():
|
| 181 |
+
annotation_id_int = int(annotation_id)
|
| 182 |
+
self.special_includes[page_index][annotation_id_int] = annotation_types
|
| 183 |
+
|
| 184 |
+
# Disk cache configuration.
|
| 185 |
+
# F-028: tolerate a partial/absent ai_cache config (e.g. AI support
|
| 186 |
+
# enabled for ICL with no disk_cache block) instead of crashing boot
|
| 187 |
+
# with KeyError: 'disk_cache'.
|
| 188 |
+
disk_cache_cfg = cache_config.get("disk_cache", {}) if isinstance(cache_config, dict) else {}
|
| 189 |
+
self.disk_cache_enabled = disk_cache_cfg.get("enabled", False)
|
| 190 |
+
|
| 191 |
+
disk_cache_path = disk_cache_cfg.get("path")
|
| 192 |
+
if self.disk_cache_enabled and not disk_cache_path:
|
| 193 |
+
raise Exception("You have enable disk cache, but you did not specific the path!")
|
| 194 |
+
self.disk_persistence_path = disk_cache_path
|
| 195 |
+
|
| 196 |
+
# Validate cache path stays within task directory
|
| 197 |
+
if self.disk_persistence_path:
|
| 198 |
+
task_dir = os.path.abspath(config.get("task_dir", "."))
|
| 199 |
+
cache_abs = os.path.abspath(
|
| 200 |
+
os.path.join(task_dir, self.disk_persistence_path)
|
| 201 |
+
if not os.path.isabs(self.disk_persistence_path)
|
| 202 |
+
else self.disk_persistence_path
|
| 203 |
+
)
|
| 204 |
+
if not cache_abs.startswith(task_dir + os.sep) and cache_abs != task_dir:
|
| 205 |
+
raise ValueError(
|
| 206 |
+
f"Cache path '{self.disk_persistence_path}' resolves to "
|
| 207 |
+
f"'{cache_abs}' which is outside the task directory "
|
| 208 |
+
f"'{task_dir}'. Path traversal is not allowed."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Prefetch configuration — clamp to sane ranges.
|
| 212 |
+
# F-028: default to no prefetch when the prefetch block is absent
|
| 213 |
+
# (e.g. cache_config: {enabled: false}) instead of KeyError on boot.
|
| 214 |
+
prefetch_cfg = cache_config.get("prefetch", {}) if isinstance(cache_config, dict) else {}
|
| 215 |
+
self.warm_up_page_count = max(0, min(int(prefetch_cfg.get("warm_up_page_count", 0)), 10000))
|
| 216 |
+
self.prefetch_page_count_on_next = max(0, min(int(prefetch_cfg.get("on_next", 0)), 10000))
|
| 217 |
+
self.prefetch_page_count_on_prev = max(0, min(int(prefetch_cfg.get("on_prev", 0)), 10000))
|
| 218 |
+
|
| 219 |
+
# Option highlighting configuration
|
| 220 |
+
option_highlighting = ai_support.get("option_highlighting", {})
|
| 221 |
+
self.option_highlighting_enabled = option_highlighting.get("enabled", False)
|
| 222 |
+
self.option_highlighting_top_k = option_highlighting.get("top_k", 3)
|
| 223 |
+
self.option_highlighting_dim_opacity = option_highlighting.get("dim_opacity", 0.4)
|
| 224 |
+
self.option_highlighting_auto_apply = option_highlighting.get("auto_apply", True)
|
| 225 |
+
self.option_highlighting_schemas = option_highlighting.get("schemas", None) # None means all
|
| 226 |
+
# Prefetch count for option highlighting — clamp to sane range
|
| 227 |
+
self.option_highlighting_prefetch_count = max(0, min(
|
| 228 |
+
int(option_highlighting.get("prefetch_count", 20)), 10000
|
| 229 |
+
))
|
| 230 |
+
|
| 231 |
+
# Threading
|
| 232 |
+
self.in_progress = {}
|
| 233 |
+
self.lock = threading.RLock()
|
| 234 |
+
self.executor = ThreadPoolExecutor(max_workers=20)
|
| 235 |
+
|
| 236 |
+
AIEndpointFactory.register_endpoint("ollama", OllamaEndpoint)
|
| 237 |
+
AIEndpointFactory.register_endpoint("open_router", OpenRouterEndpoint)
|
| 238 |
+
|
| 239 |
+
# Register visual AI endpoints
|
| 240 |
+
try:
|
| 241 |
+
from potato.ai.yolo_endpoint import YOLOEndpoint
|
| 242 |
+
AIEndpointFactory.register_endpoint("yolo", YOLOEndpoint)
|
| 243 |
+
except ImportError:
|
| 244 |
+
logger.debug("YOLO endpoint not available (ultralytics not installed)")
|
| 245 |
+
|
| 246 |
+
try:
|
| 247 |
+
from potato.ai.ollama_vision_endpoint import OllamaVisionEndpoint
|
| 248 |
+
AIEndpointFactory.register_endpoint("ollama_vision", OllamaVisionEndpoint)
|
| 249 |
+
except ImportError:
|
| 250 |
+
logger.debug("Ollama Vision endpoint not available")
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
from potato.ai.openai_vision_endpoint import OpenAIVisionEndpoint
|
| 254 |
+
AIEndpointFactory.register_endpoint("openai_vision", OpenAIVisionEndpoint)
|
| 255 |
+
except ImportError:
|
| 256 |
+
logger.debug("OpenAI Vision endpoint not available")
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
from potato.ai.anthropic_vision_endpoint import AnthropicVisionEndpoint
|
| 260 |
+
AIEndpointFactory.register_endpoint("anthropic_vision", AnthropicVisionEndpoint)
|
| 261 |
+
except ImportError:
|
| 262 |
+
logger.debug("Anthropic Vision endpoint not available")
|
| 263 |
+
|
| 264 |
+
# Degrade gracefully if the AI backend (e.g. a local Ollama/vLLM server)
|
| 265 |
+
# is unreachable at boot: log a warning and serve the task with AI
|
| 266 |
+
# support disabled rather than aborting server startup.
|
| 267 |
+
try:
|
| 268 |
+
self.ai_endpoint = AIEndpointFactory.create_endpoint(config)
|
| 269 |
+
except Exception as e:
|
| 270 |
+
logger.warning(
|
| 271 |
+
"AI endpoint unavailable at startup (%s). Continuing with AI "
|
| 272 |
+
"support disabled. Check that your AI backend is running.", e
|
| 273 |
+
)
|
| 274 |
+
self.ai_endpoint = None
|
| 275 |
+
|
| 276 |
+
# Create visual endpoint if different from main endpoint
|
| 277 |
+
self.visual_endpoint = None
|
| 278 |
+
visual_endpoint_type = config.get("ai_support", {}).get("visual_endpoint_type")
|
| 279 |
+
if visual_endpoint_type and visual_endpoint_type != config.get("ai_support", {}).get("endpoint_type"):
|
| 280 |
+
visual_config = {
|
| 281 |
+
"ai_support": {
|
| 282 |
+
"enabled": True,
|
| 283 |
+
"endpoint_type": visual_endpoint_type,
|
| 284 |
+
"ai_config": config.get("ai_support", {}).get("visual_ai_config", config.get("ai_support", {}).get("ai_config", {}))
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
try:
|
| 288 |
+
self.visual_endpoint = AIEndpointFactory.create_endpoint(visual_config)
|
| 289 |
+
except Exception as e:
|
| 290 |
+
logger.warning(
|
| 291 |
+
"Visual AI endpoint unavailable at startup (%s). Continuing "
|
| 292 |
+
"without visual AI support.", e
|
| 293 |
+
)
|
| 294 |
+
self.visual_endpoint = None
|
| 295 |
+
|
| 296 |
+
annotation_scheme = config.get("annotation_schemes")
|
| 297 |
+
self.annotations = []
|
| 298 |
+
for scheme in annotation_scheme:
|
| 299 |
+
self.annotations.append(scheme)
|
| 300 |
+
|
| 301 |
+
# Check if main endpoint supports vision
|
| 302 |
+
self.endpoint_supports_vision = hasattr(self.ai_endpoint, 'query_with_image')
|
| 303 |
+
logger.info(f"AI endpoint supports vision: {self.endpoint_supports_vision}")
|
| 304 |
+
|
| 305 |
+
# Initialize cache
|
| 306 |
+
if self.disk_cache_enabled:
|
| 307 |
+
self.load_cache_from_disk()
|
| 308 |
+
self.start_warmup()
|
| 309 |
+
|
| 310 |
+
def _validate_assistant_compatibility(
|
| 311 |
+
self, instance_id: int, annotation_id: int, ai_assistant: str
|
| 312 |
+
) -> tuple:
|
| 313 |
+
"""
|
| 314 |
+
Validate that the AI assistant is compatible with the input type and model capabilities.
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
instance_id: The instance/item index
|
| 318 |
+
annotation_id: The annotation scheme index
|
| 319 |
+
ai_assistant: Type of assistance ('hint', 'keyword', 'rationale', 'detection', etc.)
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
Tuple of (is_valid: bool, error_message: str)
|
| 323 |
+
If valid, error_message is empty string.
|
| 324 |
+
"""
|
| 325 |
+
try:
|
| 326 |
+
text = _get_instance_text(instance_id)
|
| 327 |
+
is_image = _is_image_url(text)
|
| 328 |
+
|
| 329 |
+
# Determine which endpoint to use
|
| 330 |
+
if is_image and self.visual_endpoint:
|
| 331 |
+
endpoint = self.visual_endpoint
|
| 332 |
+
elif is_image and self.endpoint_supports_vision:
|
| 333 |
+
endpoint = self.ai_endpoint
|
| 334 |
+
else:
|
| 335 |
+
endpoint = self.ai_endpoint
|
| 336 |
+
|
| 337 |
+
# Get capabilities from endpoint
|
| 338 |
+
capabilities = getattr(endpoint, 'CAPABILITIES', None)
|
| 339 |
+
|
| 340 |
+
if capabilities is None:
|
| 341 |
+
# No capabilities declared - allow all (backward compatibility)
|
| 342 |
+
logger.debug(f"Endpoint {type(endpoint).__name__} has no CAPABILITIES, allowing {ai_assistant}")
|
| 343 |
+
return True, ""
|
| 344 |
+
|
| 345 |
+
# Check if the assistant type is supported
|
| 346 |
+
if not capabilities.supports_assistant(ai_assistant, is_image):
|
| 347 |
+
input_type = "image" if is_image else "text"
|
| 348 |
+
return False, (
|
| 349 |
+
f"Model {type(endpoint).__name__} does not support '{ai_assistant}' "
|
| 350 |
+
f"for {input_type} content"
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
return True, ""
|
| 354 |
+
|
| 355 |
+
except Exception as e:
|
| 356 |
+
logger.warning(f"Error validating assistant compatibility: {e}")
|
| 357 |
+
# On validation error, allow the request (fail open for now)
|
| 358 |
+
return True, ""
|
| 359 |
+
|
| 360 |
+
def get_endpoint_capabilities(self, for_image: bool = False) -> ModelCapabilities:
|
| 361 |
+
"""
|
| 362 |
+
Get the capabilities of the appropriate endpoint for the given input type.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
for_image: Whether the input is an image
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
ModelCapabilities instance, or a default permissive one if not declared
|
| 369 |
+
"""
|
| 370 |
+
if for_image and self.visual_endpoint:
|
| 371 |
+
endpoint = self.visual_endpoint
|
| 372 |
+
elif for_image and self.endpoint_supports_vision:
|
| 373 |
+
endpoint = self.ai_endpoint
|
| 374 |
+
else:
|
| 375 |
+
endpoint = self.ai_endpoint
|
| 376 |
+
|
| 377 |
+
capabilities = getattr(endpoint, 'CAPABILITIES', None)
|
| 378 |
+
if capabilities is None:
|
| 379 |
+
# Return permissive defaults for backward compatibility
|
| 380 |
+
return ModelCapabilities(
|
| 381 |
+
text_generation=True,
|
| 382 |
+
vision_input=for_image,
|
| 383 |
+
bounding_box_output=False,
|
| 384 |
+
text_classification=True,
|
| 385 |
+
image_classification=for_image,
|
| 386 |
+
rationale_generation=True,
|
| 387 |
+
keyword_extraction=not for_image,
|
| 388 |
+
)
|
| 389 |
+
return capabilities
|
| 390 |
+
|
| 391 |
+
def _get_ai_with_vision_support(self, text: str, prompt: str, output_format) -> str:
|
| 392 |
+
"""
|
| 393 |
+
Get AI response, using vision if text is an image URL and endpoint supports it.
|
| 394 |
+
"""
|
| 395 |
+
# Check if we should use vision
|
| 396 |
+
if self.endpoint_supports_vision and _is_image_url(text):
|
| 397 |
+
logger.debug(f"Using vision query for image URL: {text[:50]}...")
|
| 398 |
+
image_data = _get_image_data_from_url(text)
|
| 399 |
+
if image_data:
|
| 400 |
+
try:
|
| 401 |
+
return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
|
| 402 |
+
except Exception as e:
|
| 403 |
+
logger.error(f"Vision query failed: {e}")
|
| 404 |
+
# Fall back to text query
|
| 405 |
+
|
| 406 |
+
# Fall back to regular text query
|
| 407 |
+
return self.ai_endpoint.query(prompt, output_format)
|
| 408 |
+
|
| 409 |
+
def start_warmup(self):
|
| 410 |
+
self.start_prefetch(0, self.warm_up_page_count)
|
| 411 |
+
|
| 412 |
+
# Also prefetch option highlights if enabled
|
| 413 |
+
if self.option_highlighting_enabled:
|
| 414 |
+
self.start_option_highlight_prefetch(0, self.warm_up_page_count)
|
| 415 |
+
|
| 416 |
+
total = len(self.in_progress)
|
| 417 |
+
desc = "Preloading the AI"
|
| 418 |
+
|
| 419 |
+
progress_bar = tqdm(total=total, desc=desc, unit="item")
|
| 420 |
+
|
| 421 |
+
def count_completed():
|
| 422 |
+
return total - len(self.in_progress)
|
| 423 |
+
|
| 424 |
+
prev_done = 0
|
| 425 |
+
while self.in_progress:
|
| 426 |
+
current_done = count_completed()
|
| 427 |
+
progress_bar.update(current_done - prev_done)
|
| 428 |
+
prev_done = current_done
|
| 429 |
+
time.sleep(0.2)
|
| 430 |
+
|
| 431 |
+
final_done = count_completed()
|
| 432 |
+
if final_done > prev_done:
|
| 433 |
+
progress_bar.update(final_done - prev_done)
|
| 434 |
+
|
| 435 |
+
progress_bar.close()
|
| 436 |
+
|
| 437 |
+
def load_disk_cache_data(self, file_path: str) -> Dict[str, Any]:
|
| 438 |
+
"""loads the cache JSON from disk and returns a dictionary of stringified keys to values."""
|
| 439 |
+
try:
|
| 440 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 441 |
+
return json.load(f)
|
| 442 |
+
except Exception as e:
|
| 443 |
+
logger.error(f"Error loading disk cache: {e}")
|
| 444 |
+
return {}
|
| 445 |
+
|
| 446 |
+
def load_cache_from_disk(self):
|
| 447 |
+
"""Initializes disk cache file if it doesn't exist."""
|
| 448 |
+
if not self.disk_cache_enabled or not self.disk_persistence_path:
|
| 449 |
+
return
|
| 450 |
+
|
| 451 |
+
if os.path.exists(self.disk_persistence_path):
|
| 452 |
+
data = self.load_disk_cache_data(self.disk_persistence_path)
|
| 453 |
+
logger.info(f"Disk cache initialized with {len(data)} items")
|
| 454 |
+
else:
|
| 455 |
+
try:
|
| 456 |
+
# Create parent directory if it doesn't exist
|
| 457 |
+
os.makedirs(os.path.dirname(self.disk_persistence_path), exist_ok=True)
|
| 458 |
+
with open(self.disk_persistence_path, 'w', encoding='utf-8') as file:
|
| 459 |
+
json.dump({}, file)
|
| 460 |
+
logger.info(f"Initialized empty disk cache at {self.disk_persistence_path}")
|
| 461 |
+
except Exception as e:
|
| 462 |
+
logger.error(f"Failed to create disk cache: {e}")
|
| 463 |
+
|
| 464 |
+
def save_cache_to_disk(self, key, value):
|
| 465 |
+
"""saves a single key-value pair to disk cache using atomic write."""
|
| 466 |
+
if not self.disk_cache_enabled or not self.disk_persistence_path:
|
| 467 |
+
return
|
| 468 |
+
|
| 469 |
+
try:
|
| 470 |
+
os.makedirs(os.path.dirname(self.disk_persistence_path), exist_ok=True)
|
| 471 |
+
|
| 472 |
+
# Load existing disk data first
|
| 473 |
+
existing_disk_data = {}
|
| 474 |
+
if os.path.exists(self.disk_persistence_path):
|
| 475 |
+
existing_disk_data = self.load_disk_cache_data(self.disk_persistence_path)
|
| 476 |
+
|
| 477 |
+
# Add the new key-value pair
|
| 478 |
+
existing_disk_data[str(key)] = value
|
| 479 |
+
|
| 480 |
+
temp_path = self.disk_persistence_path + ".tmp"
|
| 481 |
+
with open(temp_path, 'w', encoding='utf-8') as f:
|
| 482 |
+
json.dump(existing_disk_data, f, indent=2, ensure_ascii=False)
|
| 483 |
+
os.rename(temp_path, self.disk_persistence_path)
|
| 484 |
+
except Exception as e:
|
| 485 |
+
logger.error(f"Error saving cache to disk: {e}")
|
| 486 |
+
|
| 487 |
+
def add_to_cache(self, key, value):
|
| 488 |
+
"""inserts a key-value into the disk cache."""
|
| 489 |
+
with self.lock:
|
| 490 |
+
if self.disk_cache_enabled:
|
| 491 |
+
self.save_cache_to_disk(key, value)
|
| 492 |
+
|
| 493 |
+
def get_from_cache(self, key):
|
| 494 |
+
"""Tries to retrieve the item from disk cache."""
|
| 495 |
+
with self.lock:
|
| 496 |
+
# Try disk cache
|
| 497 |
+
if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path):
|
| 498 |
+
try:
|
| 499 |
+
disk_data = self.load_disk_cache_data(self.disk_persistence_path)
|
| 500 |
+
key_str = str(key)
|
| 501 |
+
if key_str in disk_data:
|
| 502 |
+
return disk_data[key_str]
|
| 503 |
+
except Exception as e:
|
| 504 |
+
logger.error(f"Error reading from disk: {e}")
|
| 505 |
+
return None
|
| 506 |
+
|
| 507 |
+
def generate_likert(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 508 |
+
from string import Template
|
| 509 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 510 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 511 |
+
text = _get_instance_text(instance_id)
|
| 512 |
+
min_label = _get_scheme_field(annotation_id, "min_label")
|
| 513 |
+
max_label = _get_scheme_field(annotation_id, "max_label")
|
| 514 |
+
size = _get_scheme_field(annotation_id, "size")
|
| 515 |
+
|
| 516 |
+
ai_prompt = get_ai_prompt()
|
| 517 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 518 |
+
|
| 519 |
+
# Check if we should use vision endpoint for image-based content
|
| 520 |
+
if self.endpoint_supports_vision and _is_image_url(text):
|
| 521 |
+
logger.debug(f"Using vision for likert {ai_assistant} on image: {text[:50]}...")
|
| 522 |
+
image_data = _get_image_data_from_url(text)
|
| 523 |
+
if image_data:
|
| 524 |
+
# Build vision-specific prompts based on ai_assistant type
|
| 525 |
+
if ai_assistant == "hint":
|
| 526 |
+
prompt = f"""Look at this image and help with the following annotation task:
|
| 527 |
+
|
| 528 |
+
Task: {description}
|
| 529 |
+
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})
|
| 530 |
+
|
| 531 |
+
Please analyze the image and suggest an appropriate rating with a brief explanation.
|
| 532 |
+
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choice": "<rating label>"}}"""
|
| 533 |
+
elif ai_assistant == "rationale":
|
| 534 |
+
prompt = f"""Look at this image and explain the reasoning for different rating choices:
|
| 535 |
+
|
| 536 |
+
Task: {description}
|
| 537 |
+
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})
|
| 538 |
+
|
| 539 |
+
For each possible rating, explain what visual evidence in the image would support that rating.
|
| 540 |
+
Respond in JSON format: {{"rationales": [{{"label": "<rating>", "reasoning": "<explanation>"}}]}}"""
|
| 541 |
+
elif ai_assistant == "keyword":
|
| 542 |
+
prompt = f"""Look at this image and identify visual features relevant to the rating task:
|
| 543 |
+
|
| 544 |
+
Task: {description}
|
| 545 |
+
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})
|
| 546 |
+
|
| 547 |
+
Identify key visual elements that would influence the rating.
|
| 548 |
+
Respond in JSON format: {{"keywords": ["<visual_feature_1>", "<visual_feature_2>"]}}"""
|
| 549 |
+
else:
|
| 550 |
+
prompt = f"Analyze this image for: {description}"
|
| 551 |
+
|
| 552 |
+
try:
|
| 553 |
+
return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
|
| 554 |
+
except Exception as e:
|
| 555 |
+
logger.error(f"Vision query failed for likert {ai_assistant}: {e}")
|
| 556 |
+
|
| 557 |
+
# Fall back to standard text-based generation
|
| 558 |
+
data = AnnotationInput(
|
| 559 |
+
ai_assistant=ai_assistant,
|
| 560 |
+
annotation_type=annotation_type,
|
| 561 |
+
text=text,
|
| 562 |
+
description=description,
|
| 563 |
+
min_label=min_label,
|
| 564 |
+
max_label=max_label,
|
| 565 |
+
size=size
|
| 566 |
+
)
|
| 567 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 568 |
+
return res
|
| 569 |
+
|
| 570 |
+
def generate_multiselect(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 571 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 572 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 573 |
+
labels = _get_scheme_field(annotation_id, "labels")
|
| 574 |
+
text = _get_instance_text(instance_id)
|
| 575 |
+
|
| 576 |
+
ai_prompt = get_ai_prompt()
|
| 577 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 578 |
+
|
| 579 |
+
# Check if we should use vision endpoint for image-based content
|
| 580 |
+
if self.endpoint_supports_vision and _is_image_url(text):
|
| 581 |
+
logger.debug(f"Using vision for multiselect {ai_assistant} on image: {text[:50]}...")
|
| 582 |
+
image_data = _get_image_data_from_url(text)
|
| 583 |
+
if image_data:
|
| 584 |
+
# Format labels for the prompt
|
| 585 |
+
label_names = [l.get('name', l) if isinstance(l, dict) else l for l in labels]
|
| 586 |
+
labels_str = ', '.join(f'"{name}"' for name in label_names)
|
| 587 |
+
|
| 588 |
+
# Build vision-specific prompts based on ai_assistant type
|
| 589 |
+
if ai_assistant == "hint":
|
| 590 |
+
prompt = f"""Look at this image and help with the following annotation task:
|
| 591 |
+
|
| 592 |
+
Task: {description}
|
| 593 |
+
Available options (select all that apply): {labels_str}
|
| 594 |
+
|
| 595 |
+
Please analyze the image and suggest which options apply.
|
| 596 |
+
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choices": ["<option1>", "<option2>"]}}"""
|
| 597 |
+
elif ai_assistant == "rationale":
|
| 598 |
+
prompt = f"""Look at this image and explain the reasoning for each option:
|
| 599 |
+
|
| 600 |
+
Task: {description}
|
| 601 |
+
Available options: {labels_str}
|
| 602 |
+
|
| 603 |
+
For each option, explain what visual evidence supports or contradicts it.
|
| 604 |
+
Respond in JSON format: {{"rationales": [{{"label": "<option>", "reasoning": "<explanation>"}}]}}"""
|
| 605 |
+
elif ai_assistant == "keyword":
|
| 606 |
+
prompt = f"""Look at this image and identify visual features for each option:
|
| 607 |
+
|
| 608 |
+
Task: {description}
|
| 609 |
+
Available options: {labels_str}
|
| 610 |
+
|
| 611 |
+
For each option, identify visual cues that indicate its presence.
|
| 612 |
+
Respond in JSON format: {{"label_keywords": [{{"label": "<option>", "keywords": ["<feature1>", "<feature2>"]}}]}}"""
|
| 613 |
+
else:
|
| 614 |
+
prompt = f"Analyze this image for: {description}. Options: {labels_str}"
|
| 615 |
+
|
| 616 |
+
try:
|
| 617 |
+
return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
|
| 618 |
+
except Exception as e:
|
| 619 |
+
logger.error(f"Vision query failed for multiselect {ai_assistant}: {e}")
|
| 620 |
+
|
| 621 |
+
# Fall back to standard text-based generation
|
| 622 |
+
data = AnnotationInput(
|
| 623 |
+
ai_assistant=ai_assistant,
|
| 624 |
+
annotation_type=annotation_type,
|
| 625 |
+
text=text,
|
| 626 |
+
description=description,
|
| 627 |
+
labels=labels
|
| 628 |
+
)
|
| 629 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 630 |
+
return res
|
| 631 |
+
|
| 632 |
+
def generate_radio(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 633 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 634 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 635 |
+
text = _get_instance_text(instance_id)
|
| 636 |
+
labels = _get_scheme_field(annotation_id, "labels")
|
| 637 |
+
|
| 638 |
+
ai_prompt = get_ai_prompt()
|
| 639 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 640 |
+
|
| 641 |
+
# Check if we should use vision endpoint for image-based content
|
| 642 |
+
if self.endpoint_supports_vision and _is_image_url(text):
|
| 643 |
+
logger.debug(f"Using vision for radio {ai_assistant} on image: {text[:50]}...")
|
| 644 |
+
image_data = _get_image_data_from_url(text)
|
| 645 |
+
if image_data:
|
| 646 |
+
# Format labels for the prompt
|
| 647 |
+
label_names = [l.get('name', l) if isinstance(l, dict) else l for l in labels]
|
| 648 |
+
labels_str = ', '.join(f'"{name}"' for name in label_names)
|
| 649 |
+
|
| 650 |
+
# Build vision-specific prompts based on ai_assistant type
|
| 651 |
+
if ai_assistant == "hint":
|
| 652 |
+
prompt = f"""Look at this image and help with the following annotation task:
|
| 653 |
+
|
| 654 |
+
Task: {description}
|
| 655 |
+
Available options: {labels_str}
|
| 656 |
+
|
| 657 |
+
Please analyze the image and suggest the most appropriate option.
|
| 658 |
+
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choice": "<selected option>"}}"""
|
| 659 |
+
elif ai_assistant == "rationale":
|
| 660 |
+
prompt = f"""Look at this image and explain the reasoning for each option:
|
| 661 |
+
|
| 662 |
+
Task: {description}
|
| 663 |
+
Available options: {labels_str}
|
| 664 |
+
|
| 665 |
+
For each option, explain what visual evidence in the image supports or contradicts it.
|
| 666 |
+
Respond in JSON format: {{"rationales": [{{"label": "<option>", "reasoning": "<explanation>"}}]}}"""
|
| 667 |
+
elif ai_assistant == "keyword":
|
| 668 |
+
prompt = f"""Look at this image and identify visual features for each option:
|
| 669 |
+
|
| 670 |
+
Task: {description}
|
| 671 |
+
Available options: {labels_str}
|
| 672 |
+
|
| 673 |
+
For each option, identify visual cues that would indicate its presence.
|
| 674 |
+
Respond in JSON format: {{"label_keywords": [{{"label": "<option>", "keywords": ["<feature1>", "<feature2>"]}}]}}"""
|
| 675 |
+
else:
|
| 676 |
+
prompt = f"Analyze this image for: {description}. Options: {labels_str}"
|
| 677 |
+
|
| 678 |
+
try:
|
| 679 |
+
return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
|
| 680 |
+
except Exception as e:
|
| 681 |
+
logger.error(f"Vision query failed for radio {ai_assistant}: {e}")
|
| 682 |
+
|
| 683 |
+
# Fall back to standard text-based generation
|
| 684 |
+
data = AnnotationInput(
|
| 685 |
+
ai_assistant=ai_assistant,
|
| 686 |
+
annotation_type=annotation_type,
|
| 687 |
+
text=text,
|
| 688 |
+
description=description,
|
| 689 |
+
labels=labels
|
| 690 |
+
)
|
| 691 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 692 |
+
return res
|
| 693 |
+
|
| 694 |
+
def generate_number(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 695 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 696 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 697 |
+
text = _get_instance_text(instance_id)
|
| 698 |
+
|
| 699 |
+
data = AnnotationInput(
|
| 700 |
+
ai_assistant=ai_assistant,
|
| 701 |
+
annotation_type=annotation_type,
|
| 702 |
+
text=text,
|
| 703 |
+
description=description,
|
| 704 |
+
)
|
| 705 |
+
ai_prompt = get_ai_prompt();
|
| 706 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 707 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 708 |
+
return res
|
| 709 |
+
|
| 710 |
+
def generate_select(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 711 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 712 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 713 |
+
labels = _get_scheme_field(annotation_id, "labels")
|
| 714 |
+
text = _get_instance_text(instance_id)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
data = AnnotationInput(
|
| 718 |
+
ai_assistant=ai_assistant,
|
| 719 |
+
annotation_type=annotation_type,
|
| 720 |
+
text=text,
|
| 721 |
+
description=description,
|
| 722 |
+
labels=labels
|
| 723 |
+
)
|
| 724 |
+
ai_prompt = get_ai_prompt();
|
| 725 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 726 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 727 |
+
return res
|
| 728 |
+
|
| 729 |
+
def generate_slider(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 730 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 731 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 732 |
+
min_value = _get_scheme_field(annotation_id, "min_value")
|
| 733 |
+
max_value = _get_scheme_field(annotation_id, "max_value")
|
| 734 |
+
step = _get_scheme_field(annotation_id, "step", default=1)
|
| 735 |
+
text = _get_instance_text(instance_id)
|
| 736 |
+
|
| 737 |
+
data = AnnotationInput(
|
| 738 |
+
ai_assistant=ai_assistant,
|
| 739 |
+
annotation_type=annotation_type,
|
| 740 |
+
text=text,
|
| 741 |
+
description=description,
|
| 742 |
+
min_value=min_value,
|
| 743 |
+
max_value=max_value,
|
| 744 |
+
step=step
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
ai_prompt = get_ai_prompt();
|
| 748 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 749 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 750 |
+
return res
|
| 751 |
+
|
| 752 |
+
def generate_span(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 753 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 754 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 755 |
+
labels = _get_scheme_field(annotation_id, "labels")
|
| 756 |
+
text = _get_instance_text(instance_id)
|
| 757 |
+
|
| 758 |
+
data = AnnotationInput(
|
| 759 |
+
ai_assistant=ai_assistant,
|
| 760 |
+
annotation_type=annotation_type,
|
| 761 |
+
text=text,
|
| 762 |
+
description=description,
|
| 763 |
+
labels=labels
|
| 764 |
+
)
|
| 765 |
+
ai_prompt = get_ai_prompt();
|
| 766 |
+
logger.debug(f"Generating span annotation with labels: {labels}")
|
| 767 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 768 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 769 |
+
return res
|
| 770 |
+
|
| 771 |
+
def generate_textbox(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 772 |
+
logger.debug(f"Generating textbox for annotation_id: {annotation_id}")
|
| 773 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 774 |
+
description = _get_scheme_field(annotation_id, "description")
|
| 775 |
+
text = _get_instance_text(instance_id)
|
| 776 |
+
|
| 777 |
+
data = AnnotationInput(
|
| 778 |
+
ai_assistant=ai_assistant,
|
| 779 |
+
annotation_type=annotation_type,
|
| 780 |
+
text=text,
|
| 781 |
+
description=description,
|
| 782 |
+
)
|
| 783 |
+
ai_prompt = get_ai_prompt();
|
| 784 |
+
output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format"))
|
| 785 |
+
res = self.ai_endpoint.get_ai(data, output_format)
|
| 786 |
+
return res
|
| 787 |
+
|
| 788 |
+
def generate_image_annotation(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
|
| 789 |
+
"""Generate AI assistance for image annotation tasks.
|
| 790 |
+
|
| 791 |
+
Args:
|
| 792 |
+
instance_id: The instance/item index
|
| 793 |
+
annotation_id: The annotation scheme index
|
| 794 |
+
ai_assistant: Type of assistance ('detection', 'classification', 'hint', 'pre_annotate', etc.)
|
| 795 |
+
|
| 796 |
+
Returns:
|
| 797 |
+
Dict with AI suggestions (detections, classifications, hints, etc.)
|
| 798 |
+
"""
|
| 799 |
+
logger.debug(f"Generating image annotation for instance={instance_id}, annotation={annotation_id}, assistant={ai_assistant}")
|
| 800 |
+
|
| 801 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 802 |
+
description = _get_scheme_field(annotation_id, "description", default="")
|
| 803 |
+
labels = _get_scheme_field(annotation_id, "labels", default=[])
|
| 804 |
+
|
| 805 |
+
# Extract label names if labels are dicts
|
| 806 |
+
if labels and isinstance(labels[0], dict):
|
| 807 |
+
labels = [l.get("name", str(l)) for l in labels]
|
| 808 |
+
|
| 809 |
+
# Get image URL from item data
|
| 810 |
+
item_data = get_item_state_manager().items()[instance_id].get_data()
|
| 811 |
+
image_url = self._extract_image_url(item_data)
|
| 812 |
+
|
| 813 |
+
if not image_url:
|
| 814 |
+
return {"error": "No image URL found in instance data"}
|
| 815 |
+
|
| 816 |
+
# Determine which endpoint to use
|
| 817 |
+
endpoint = self._get_visual_endpoint()
|
| 818 |
+
if not endpoint:
|
| 819 |
+
return {"error": "No visual AI endpoint configured"}
|
| 820 |
+
|
| 821 |
+
# Check if endpoint supports visual queries
|
| 822 |
+
if not hasattr(endpoint, 'query_with_image'):
|
| 823 |
+
# Fall back to text-based hint
|
| 824 |
+
return self._generate_text_hint_for_visual(instance_id, annotation_id, ai_assistant)
|
| 825 |
+
|
| 826 |
+
# Prepare image data
|
| 827 |
+
image_data = self._prepare_image_data(image_url)
|
| 828 |
+
|
| 829 |
+
# Get confidence threshold from config
|
| 830 |
+
confidence_threshold = _get_scheme_field(annotation_id, "ai_support", default={}).get(
|
| 831 |
+
"confidence_threshold", 0.5
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
# Build VisualAnnotationInput
|
| 835 |
+
data = VisualAnnotationInput(
|
| 836 |
+
ai_assistant=ai_assistant,
|
| 837 |
+
annotation_type=annotation_type,
|
| 838 |
+
task_type=ai_assistant, # detection, classification, hint, etc.
|
| 839 |
+
image_data=image_data,
|
| 840 |
+
description=description,
|
| 841 |
+
labels=labels,
|
| 842 |
+
confidence_threshold=confidence_threshold
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# Get output format from prompt config
|
| 846 |
+
ai_prompt = get_ai_prompt()
|
| 847 |
+
prompt_config = ai_prompt.get(annotation_type, {}).get(ai_assistant, {})
|
| 848 |
+
output_format_name = prompt_config.get("output_format", "visual_detection")
|
| 849 |
+
output_format = self.model_manager.get_model_class_by_name(output_format_name)
|
| 850 |
+
|
| 851 |
+
# Query the visual endpoint
|
| 852 |
+
result = endpoint.get_visual_ai(data, output_format)
|
| 853 |
+
return result
|
| 854 |
+
|
| 855 |
+
def generate_video_annotation(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
|
| 856 |
+
"""Generate AI assistance for video annotation tasks.
|
| 857 |
+
|
| 858 |
+
Args:
|
| 859 |
+
instance_id: The instance/item index
|
| 860 |
+
annotation_id: The annotation scheme index
|
| 861 |
+
ai_assistant: Type of assistance ('scene_detection', 'frame_classification', etc.)
|
| 862 |
+
|
| 863 |
+
Returns:
|
| 864 |
+
Dict with AI suggestions (segments, keyframes, etc.)
|
| 865 |
+
"""
|
| 866 |
+
logger.debug(f"Generating video annotation for instance={instance_id}, annotation={annotation_id}, assistant={ai_assistant}")
|
| 867 |
+
|
| 868 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type")
|
| 869 |
+
description = _get_scheme_field(annotation_id, "description", default="")
|
| 870 |
+
labels = _get_scheme_field(annotation_id, "labels", default=[])
|
| 871 |
+
|
| 872 |
+
# Extract label names if labels are dicts
|
| 873 |
+
if labels and isinstance(labels[0], dict):
|
| 874 |
+
labels = [l.get("name", str(l)) for l in labels]
|
| 875 |
+
|
| 876 |
+
# Get video URL from item data
|
| 877 |
+
item_data = get_item_state_manager().items()[instance_id].get_data()
|
| 878 |
+
video_url = self._extract_video_url(item_data)
|
| 879 |
+
|
| 880 |
+
if not video_url:
|
| 881 |
+
return {"error": "No video URL found in instance data"}
|
| 882 |
+
|
| 883 |
+
# Determine which endpoint to use
|
| 884 |
+
endpoint = self._get_visual_endpoint()
|
| 885 |
+
if not endpoint:
|
| 886 |
+
return {"error": "No visual AI endpoint configured"}
|
| 887 |
+
|
| 888 |
+
# Check if endpoint supports visual queries
|
| 889 |
+
if not hasattr(endpoint, 'query_with_image'):
|
| 890 |
+
return self._generate_text_hint_for_visual(instance_id, annotation_id, ai_assistant)
|
| 891 |
+
|
| 892 |
+
# Extract video frames
|
| 893 |
+
try:
|
| 894 |
+
frames = endpoint.extract_video_frames(video_url)
|
| 895 |
+
video_metadata = endpoint.get_video_metadata(video_url)
|
| 896 |
+
except Exception as e:
|
| 897 |
+
logger.error(f"Failed to extract video frames: {e}")
|
| 898 |
+
return {"error": f"Failed to process video: {str(e)}"}
|
| 899 |
+
|
| 900 |
+
# Build VisualAnnotationInput
|
| 901 |
+
data = VisualAnnotationInput(
|
| 902 |
+
ai_assistant=ai_assistant,
|
| 903 |
+
annotation_type=annotation_type,
|
| 904 |
+
task_type=ai_assistant,
|
| 905 |
+
image_data=frames, # List of frame images
|
| 906 |
+
description=description,
|
| 907 |
+
labels=labels,
|
| 908 |
+
video_metadata=video_metadata
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
# Get output format
|
| 912 |
+
ai_prompt = get_ai_prompt()
|
| 913 |
+
prompt_config = ai_prompt.get(annotation_type, {}).get(ai_assistant, {})
|
| 914 |
+
output_format_name = prompt_config.get("output_format", "video_scene_detection")
|
| 915 |
+
output_format = self.model_manager.get_model_class_by_name(output_format_name)
|
| 916 |
+
|
| 917 |
+
# Query the visual endpoint
|
| 918 |
+
result = endpoint.get_visual_ai(data, output_format)
|
| 919 |
+
return result
|
| 920 |
+
|
| 921 |
+
def _get_visual_endpoint(self):
|
| 922 |
+
"""Get the appropriate endpoint for visual tasks."""
|
| 923 |
+
# Use dedicated visual endpoint if configured
|
| 924 |
+
if self.visual_endpoint:
|
| 925 |
+
return self.visual_endpoint
|
| 926 |
+
|
| 927 |
+
# Check if main endpoint supports vision
|
| 928 |
+
if hasattr(self.ai_endpoint, 'query_with_image'):
|
| 929 |
+
return self.ai_endpoint
|
| 930 |
+
|
| 931 |
+
# Try to find a visual endpoint from registered types
|
| 932 |
+
visual_types = ['yolo', 'ollama_vision', 'openai_vision', 'anthropic_vision']
|
| 933 |
+
for vtype in visual_types:
|
| 934 |
+
if vtype in AIEndpointFactory._endpoints:
|
| 935 |
+
try:
|
| 936 |
+
visual_config = {
|
| 937 |
+
"ai_support": {
|
| 938 |
+
"enabled": True,
|
| 939 |
+
"endpoint_type": vtype,
|
| 940 |
+
"ai_config": config.get("ai_support", {}).get("ai_config", {})
|
| 941 |
+
}
|
| 942 |
+
}
|
| 943 |
+
return AIEndpointFactory.create_endpoint(visual_config)
|
| 944 |
+
except Exception as e:
|
| 945 |
+
logger.debug(f"Could not create {vtype} endpoint: {e}")
|
| 946 |
+
continue
|
| 947 |
+
|
| 948 |
+
return None
|
| 949 |
+
|
| 950 |
+
def _extract_image_url(self, item_data: Dict) -> str:
|
| 951 |
+
"""Extract image URL from item data.
|
| 952 |
+
|
| 953 |
+
Looks for common field names that might contain image URLs.
|
| 954 |
+
"""
|
| 955 |
+
# Common field names for images
|
| 956 |
+
image_fields = ['image', 'image_url', 'img', 'img_url', 'url', 'path', 'file', 'src']
|
| 957 |
+
|
| 958 |
+
for field in image_fields:
|
| 959 |
+
if field in item_data:
|
| 960 |
+
value = item_data[field]
|
| 961 |
+
if isinstance(value, str) and (
|
| 962 |
+
value.startswith(('http://', 'https://', '/')) or
|
| 963 |
+
value.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
|
| 964 |
+
):
|
| 965 |
+
return value
|
| 966 |
+
|
| 967 |
+
# Check 'text' field for URL (common in simple configs)
|
| 968 |
+
if 'text' in item_data:
|
| 969 |
+
text = item_data['text']
|
| 970 |
+
if isinstance(text, str) and (
|
| 971 |
+
text.startswith(('http://', 'https://')) and
|
| 972 |
+
any(ext in text.lower() for ext in ['.jpg', '.jpeg', '.png', '.gif', '.webp'])
|
| 973 |
+
):
|
| 974 |
+
return text
|
| 975 |
+
|
| 976 |
+
return None
|
| 977 |
+
|
| 978 |
+
def _extract_video_url(self, item_data: Dict) -> str:
|
| 979 |
+
"""Extract video URL from item data."""
|
| 980 |
+
# Common field names for videos
|
| 981 |
+
video_fields = ['video', 'video_url', 'url', 'path', 'file', 'src', 'media']
|
| 982 |
+
|
| 983 |
+
for field in video_fields:
|
| 984 |
+
if field in item_data:
|
| 985 |
+
value = item_data[field]
|
| 986 |
+
if isinstance(value, str) and (
|
| 987 |
+
value.startswith(('http://', 'https://', '/')) or
|
| 988 |
+
value.endswith(('.mp4', '.webm', '.ogg', '.avi', '.mov'))
|
| 989 |
+
):
|
| 990 |
+
return value
|
| 991 |
+
|
| 992 |
+
# Check 'text' field for URL
|
| 993 |
+
if 'text' in item_data:
|
| 994 |
+
text = item_data['text']
|
| 995 |
+
if isinstance(text, str) and (
|
| 996 |
+
text.startswith(('http://', 'https://')) and
|
| 997 |
+
any(ext in text.lower() for ext in ['.mp4', '.webm', '.ogg', '.avi', '.mov'])
|
| 998 |
+
):
|
| 999 |
+
return text
|
| 1000 |
+
|
| 1001 |
+
return None
|
| 1002 |
+
|
| 1003 |
+
def _prepare_image_data(self, image_url: str) -> ImageData:
|
| 1004 |
+
"""Prepare ImageData from URL or path."""
|
| 1005 |
+
if image_url.startswith(('http://', 'https://')):
|
| 1006 |
+
return ImageData(source="url", data=image_url)
|
| 1007 |
+
else:
|
| 1008 |
+
# Local file path - encode as base64
|
| 1009 |
+
from potato.ai.visual_ai_endpoint import BaseVisualAIEndpoint
|
| 1010 |
+
return BaseVisualAIEndpoint.encode_image_to_base64(image_url)
|
| 1011 |
+
|
| 1012 |
+
def _generate_text_hint_for_visual(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
|
| 1013 |
+
"""Generate text-based hint when visual endpoint is not available."""
|
| 1014 |
+
description = config["annotation_schemes"][annotation_id].get("description", "")
|
| 1015 |
+
labels = config["annotation_schemes"][annotation_id].get("labels", [])
|
| 1016 |
+
|
| 1017 |
+
if labels and isinstance(labels[0], dict):
|
| 1018 |
+
labels = [l.get("name", str(l)) for l in labels]
|
| 1019 |
+
|
| 1020 |
+
return {
|
| 1021 |
+
"hint": f"Review the {'image' if 'image' in config['annotation_schemes'][annotation_id]['annotation_type'] else 'video'} carefully. "
|
| 1022 |
+
f"Look for: {', '.join(labels) if labels else 'relevant content'}. "
|
| 1023 |
+
f"Task: {description}",
|
| 1024 |
+
"suggestive_choice": ""
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
def is_option_highlighting_enabled_for_scheme(self, annotation_id: int) -> bool:
|
| 1028 |
+
"""Check if option highlighting is enabled for a specific annotation scheme."""
|
| 1029 |
+
if not self.option_highlighting_enabled:
|
| 1030 |
+
return False
|
| 1031 |
+
|
| 1032 |
+
scheme = config["annotation_schemes"][annotation_id]
|
| 1033 |
+
annotation_type = scheme.get("annotation_type", "")
|
| 1034 |
+
scheme_name = scheme.get("name", "")
|
| 1035 |
+
|
| 1036 |
+
# Only applicable to discrete option types
|
| 1037 |
+
discrete_types = ["radio", "multiselect", "likert", "select"]
|
| 1038 |
+
if annotation_type not in discrete_types:
|
| 1039 |
+
return False
|
| 1040 |
+
|
| 1041 |
+
# Check if schemas filter is set
|
| 1042 |
+
if self.option_highlighting_schemas is not None:
|
| 1043 |
+
if scheme_name not in self.option_highlighting_schemas:
|
| 1044 |
+
return False
|
| 1045 |
+
|
| 1046 |
+
return True
|
| 1047 |
+
|
| 1048 |
+
def get_option_highlighting_config(self) -> Dict:
|
| 1049 |
+
"""Get the option highlighting configuration for the frontend."""
|
| 1050 |
+
return {
|
| 1051 |
+
"enabled": self.option_highlighting_enabled,
|
| 1052 |
+
"top_k": self.option_highlighting_top_k,
|
| 1053 |
+
"dim_opacity": self.option_highlighting_dim_opacity,
|
| 1054 |
+
"auto_apply": self.option_highlighting_auto_apply,
|
| 1055 |
+
"schemas": self.option_highlighting_schemas,
|
| 1056 |
+
"prefetch_count": self.option_highlighting_prefetch_count,
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
def generate_option_highlights(self, instance_id: int, annotation_id: int) -> Dict:
|
| 1060 |
+
"""Generate option highlighting suggestions for an annotation.
|
| 1061 |
+
|
| 1062 |
+
Args:
|
| 1063 |
+
instance_id: The instance/item index
|
| 1064 |
+
annotation_id: The annotation scheme index
|
| 1065 |
+
|
| 1066 |
+
Returns:
|
| 1067 |
+
Dict with highlighted options and configuration:
|
| 1068 |
+
{
|
| 1069 |
+
"highlighted": ["option1", "option2"],
|
| 1070 |
+
"top_k": 3,
|
| 1071 |
+
"confidence": 0.85
|
| 1072 |
+
}
|
| 1073 |
+
"""
|
| 1074 |
+
from string import Template
|
| 1075 |
+
|
| 1076 |
+
if not self.is_option_highlighting_enabled_for_scheme(annotation_id):
|
| 1077 |
+
return {"error": "Option highlighting not enabled for this scheme"}
|
| 1078 |
+
|
| 1079 |
+
annotation_type = _get_scheme_field(annotation_id, "annotation_type", default="")
|
| 1080 |
+
description = _get_scheme_field(annotation_id, "description", default="")
|
| 1081 |
+
labels = _get_scheme_field(annotation_id, "labels", default=[])
|
| 1082 |
+
|
| 1083 |
+
# Extract label names
|
| 1084 |
+
if labels and isinstance(labels[0], dict):
|
| 1085 |
+
label_names = [l.get("name", str(l)) for l in labels]
|
| 1086 |
+
else:
|
| 1087 |
+
label_names = [str(l) for l in labels]
|
| 1088 |
+
|
| 1089 |
+
# For likert scales, generate label names from min/max labels
|
| 1090 |
+
if annotation_type == "likert":
|
| 1091 |
+
size = scheme.get("size", 5)
|
| 1092 |
+
min_label = scheme.get("min_label", "1")
|
| 1093 |
+
max_label = scheme.get("max_label", str(size))
|
| 1094 |
+
label_names = [f"{i+1} ({min_label if i == 0 else max_label if i == size-1 else ''})" for i in range(size)]
|
| 1095 |
+
# Clean up empty parentheses
|
| 1096 |
+
label_names = [l.replace(" ()", "") for l in label_names]
|
| 1097 |
+
|
| 1098 |
+
text = _get_instance_text(instance_id)
|
| 1099 |
+
top_k = min(self.option_highlighting_top_k, len(label_names))
|
| 1100 |
+
|
| 1101 |
+
# Get prompt template
|
| 1102 |
+
ai_prompt = get_ai_prompt()
|
| 1103 |
+
prompt_config = ai_prompt.get("option_highlight", {}).get("option_highlight", {})
|
| 1104 |
+
|
| 1105 |
+
if not prompt_config:
|
| 1106 |
+
return {"error": "Option highlight prompt not configured"}
|
| 1107 |
+
|
| 1108 |
+
prompt_template = prompt_config.get("prompt", "")
|
| 1109 |
+
output_format_name = prompt_config.get("output_format", "option_highlight")
|
| 1110 |
+
output_format = self.model_manager.get_model_class_by_name(output_format_name)
|
| 1111 |
+
|
| 1112 |
+
# Build the prompt with clear delimiters to mitigate prompt injection.
|
| 1113 |
+
# The user content is wrapped in XML-style tags so the LLM can
|
| 1114 |
+
# distinguish between instructions and untrusted data.
|
| 1115 |
+
delimited_text = (
|
| 1116 |
+
f"<user_content>\n{text}\n</user_content>"
|
| 1117 |
+
)
|
| 1118 |
+
template = Template(prompt_template)
|
| 1119 |
+
prompt = template.safe_substitute(
|
| 1120 |
+
text=delimited_text,
|
| 1121 |
+
description=description,
|
| 1122 |
+
labels=", ".join(label_names),
|
| 1123 |
+
top_k=top_k
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
# Query the AI endpoint
|
| 1127 |
+
try:
|
| 1128 |
+
result = self.ai_endpoint.query(prompt, output_format)
|
| 1129 |
+
logger.debug(f"Option highlight raw result: {result}")
|
| 1130 |
+
|
| 1131 |
+
# Parse the result
|
| 1132 |
+
if isinstance(result, str):
|
| 1133 |
+
import json as json_module
|
| 1134 |
+
try:
|
| 1135 |
+
# Try to parse JSON from the response
|
| 1136 |
+
result = json_module.loads(result)
|
| 1137 |
+
except json_module.JSONDecodeError:
|
| 1138 |
+
# Try to extract JSON from markdown code block
|
| 1139 |
+
if "```json" in result:
|
| 1140 |
+
json_start = result.find("```json") + 7
|
| 1141 |
+
json_end = result.find("```", json_start)
|
| 1142 |
+
result = json_module.loads(result[json_start:json_end].strip())
|
| 1143 |
+
elif "```" in result:
|
| 1144 |
+
json_start = result.find("```") + 3
|
| 1145 |
+
json_end = result.find("```", json_start)
|
| 1146 |
+
result = json_module.loads(result[json_start:json_end].strip())
|
| 1147 |
+
else:
|
| 1148 |
+
return {"error": f"Could not parse response: {result[:100]}"}
|
| 1149 |
+
|
| 1150 |
+
highlighted = result.get("highlighted_options", [])
|
| 1151 |
+
confidence = result.get("confidence", None)
|
| 1152 |
+
|
| 1153 |
+
# Validate highlighted options against available labels
|
| 1154 |
+
valid_highlighted = [opt for opt in highlighted if opt in label_names]
|
| 1155 |
+
|
| 1156 |
+
return {
|
| 1157 |
+
"highlighted": valid_highlighted[:top_k],
|
| 1158 |
+
"top_k": top_k,
|
| 1159 |
+
"confidence": confidence
|
| 1160 |
+
}
|
| 1161 |
+
|
| 1162 |
+
except Exception as e:
|
| 1163 |
+
logger.error(f"Error generating option highlights: {e}")
|
| 1164 |
+
return {"error": str(e)}
|
| 1165 |
+
|
| 1166 |
+
def get_option_highlights(self, instance_id: int, annotation_id: int) -> Dict:
|
| 1167 |
+
"""Get option highlights from cache or generate them.
|
| 1168 |
+
|
| 1169 |
+
Args:
|
| 1170 |
+
instance_id: The instance/item index
|
| 1171 |
+
annotation_id: The annotation scheme index
|
| 1172 |
+
|
| 1173 |
+
Returns:
|
| 1174 |
+
Dict with highlighted options
|
| 1175 |
+
"""
|
| 1176 |
+
key = (instance_id, annotation_id, "option_highlight")
|
| 1177 |
+
|
| 1178 |
+
# Try cache first
|
| 1179 |
+
if self.disk_cache_enabled:
|
| 1180 |
+
cached = self.get_from_cache(key)
|
| 1181 |
+
if cached is not None:
|
| 1182 |
+
logger.debug(f"Option highlight cache hit for {key}")
|
| 1183 |
+
return cached
|
| 1184 |
+
|
| 1185 |
+
# Generate
|
| 1186 |
+
result = self.generate_option_highlights(instance_id, annotation_id)
|
| 1187 |
+
|
| 1188 |
+
# Cache if successful
|
| 1189 |
+
if "error" not in result and self.disk_cache_enabled:
|
| 1190 |
+
self.add_to_cache(key, result)
|
| 1191 |
+
|
| 1192 |
+
return result
|
| 1193 |
+
|
| 1194 |
+
def start_option_highlight_prefetch(self, page_id: int, prefetch_amount: int = None):
|
| 1195 |
+
"""Prefetch option highlights for upcoming items.
|
| 1196 |
+
|
| 1197 |
+
Args:
|
| 1198 |
+
page_id: Current page/instance index
|
| 1199 |
+
prefetch_amount: Number of items to prefetch (uses config default if None)
|
| 1200 |
+
"""
|
| 1201 |
+
if not self.option_highlighting_enabled or not self.disk_cache_enabled:
|
| 1202 |
+
return
|
| 1203 |
+
|
| 1204 |
+
if prefetch_amount is None:
|
| 1205 |
+
prefetch_amount = self.option_highlighting_prefetch_count
|
| 1206 |
+
|
| 1207 |
+
ism = get_item_state_manager()
|
| 1208 |
+
with self.lock:
|
| 1209 |
+
# Calculate range
|
| 1210 |
+
if prefetch_amount >= 0:
|
| 1211 |
+
start_idx = page_id
|
| 1212 |
+
end_idx = min(start_idx + prefetch_amount, len(ism.items()))
|
| 1213 |
+
else:
|
| 1214 |
+
start_idx = max(page_id + prefetch_amount, 0)
|
| 1215 |
+
end_idx = page_id
|
| 1216 |
+
|
| 1217 |
+
keys = []
|
| 1218 |
+
for i in range(start_idx, end_idx):
|
| 1219 |
+
for annotation_id, scheme in enumerate(config["annotation_schemes"]):
|
| 1220 |
+
if self.is_option_highlighting_enabled_for_scheme(annotation_id):
|
| 1221 |
+
key = (i, annotation_id, "option_highlight")
|
| 1222 |
+
# Check if not already cached or in progress
|
| 1223 |
+
if self.get_from_cache(key) is None and key not in self.in_progress:
|
| 1224 |
+
keys.append(key)
|
| 1225 |
+
|
| 1226 |
+
# Submit prefetch jobs
|
| 1227 |
+
for key in keys:
|
| 1228 |
+
instance_id, annotation_id, _ = key
|
| 1229 |
+
future = self.executor.submit(self.generate_option_highlights, instance_id, annotation_id)
|
| 1230 |
+
self.in_progress[key] = future
|
| 1231 |
+
|
| 1232 |
+
def callback(fut, cache_key=key):
|
| 1233 |
+
with self.lock:
|
| 1234 |
+
try:
|
| 1235 |
+
result = fut.result()
|
| 1236 |
+
if "error" not in result:
|
| 1237 |
+
self.add_to_cache(cache_key, result)
|
| 1238 |
+
except Exception as e:
|
| 1239 |
+
logger.error(f"Option highlight prefetch failed for {cache_key}: {e}")
|
| 1240 |
+
self.in_progress.pop(cache_key, None)
|
| 1241 |
+
|
| 1242 |
+
future.add_done_callback(callback)
|
| 1243 |
+
|
| 1244 |
+
if keys:
|
| 1245 |
+
logger.debug(f"Started option highlight prefetch for {len(keys)} items")
|
| 1246 |
+
|
| 1247 |
+
def get_include_all(self):
|
| 1248 |
+
return self.include_all
|
| 1249 |
+
|
| 1250 |
+
def get_special_include(self, page_number_int, annotation_id_int):
|
| 1251 |
+
logger.debug(f"get_special_include: page={page_number_int}, annotation_id={annotation_id_int}")
|
| 1252 |
+
if not self.special_includes.get(page_number_int):
|
| 1253 |
+
return None
|
| 1254 |
+
elif not self.special_includes.get(page_number_int).get(annotation_id_int):
|
| 1255 |
+
return None
|
| 1256 |
+
return self.special_includes.get(page_number_int).get(annotation_id_int)
|
| 1257 |
+
|
| 1258 |
+
def start_prefetch(self, page_id, prefetch_amount):
|
| 1259 |
+
"""Prefetches a fixed number of upcoming items to warm the cache."""
|
| 1260 |
+
if not config.get("ai_support", {}).get("enabled") or not self.disk_cache_enabled:
|
| 1261 |
+
return
|
| 1262 |
+
|
| 1263 |
+
ism = get_item_state_manager()
|
| 1264 |
+
with self.lock:
|
| 1265 |
+
# Calculate range bounds
|
| 1266 |
+
if prefetch_amount >= 0:
|
| 1267 |
+
start_idx = page_id
|
| 1268 |
+
end_idx = min(start_idx + prefetch_amount, len(ism.items()))
|
| 1269 |
+
else:
|
| 1270 |
+
start_idx = max(page_id - prefetch_amount, 0)
|
| 1271 |
+
end_idx = page_id
|
| 1272 |
+
|
| 1273 |
+
logger.debug(f"Prefetch range: start_idx={start_idx}, end_idx={end_idx}")
|
| 1274 |
+
keys = []
|
| 1275 |
+
|
| 1276 |
+
for i in range(start_idx, end_idx):
|
| 1277 |
+
# Check if this page should be included
|
| 1278 |
+
if not self.should_include_page(i):
|
| 1279 |
+
continue
|
| 1280 |
+
|
| 1281 |
+
# Process each annotation scheme for this page
|
| 1282 |
+
for annotation_id, scheme in enumerate(config["annotation_schemes"]):
|
| 1283 |
+
if not self.should_include_scheme(i, annotation_id):
|
| 1284 |
+
continue
|
| 1285 |
+
|
| 1286 |
+
annotation_type = scheme["annotation_type"]
|
| 1287 |
+
ai_prompt = get_ai_prompt()
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
if not ai_prompt[annotation_type]:
|
| 1291 |
+
raise Exception(f"{annotation_type} is not defined in ai_prompt")
|
| 1292 |
+
|
| 1293 |
+
# Generate keys for this page/scheme combination
|
| 1294 |
+
scheme_keys = self.get_keys_for_scheme(i, annotation_type, annotation_id, ai_prompt)
|
| 1295 |
+
keys.extend(scheme_keys)
|
| 1296 |
+
|
| 1297 |
+
if keys:
|
| 1298 |
+
self.prefetch(keys)
|
| 1299 |
+
|
| 1300 |
+
def should_include_page(self, page_index):
|
| 1301 |
+
"""Determine if a page should be included based on include_all and special_includes."""
|
| 1302 |
+
if self.include_all:
|
| 1303 |
+
return True
|
| 1304 |
+
return page_index in self.special_includes
|
| 1305 |
+
|
| 1306 |
+
def should_include_scheme(self, page_index, annotation_id):
|
| 1307 |
+
"""Determine if a scheme should be included for a given page."""
|
| 1308 |
+
if self.include_all:
|
| 1309 |
+
return True
|
| 1310 |
+
|
| 1311 |
+
# Check if page is in special_includes and scheme is specified
|
| 1312 |
+
if page_index in self.special_includes:
|
| 1313 |
+
page_includes = self.special_includes[page_index]
|
| 1314 |
+
# Handle both list and dict formats for page_includes
|
| 1315 |
+
if isinstance(page_includes, dict):
|
| 1316 |
+
return annotation_id in page_includes
|
| 1317 |
+
elif isinstance(page_includes, list):
|
| 1318 |
+
return annotation_id in page_includes
|
| 1319 |
+
|
| 1320 |
+
return False
|
| 1321 |
+
|
| 1322 |
+
def get_keys_for_scheme(self, page_index, annotation_type, annotation_id, ai_prompt):
|
| 1323 |
+
"""Get all keys for a specific page combination."""
|
| 1324 |
+
keys = []
|
| 1325 |
+
|
| 1326 |
+
# Check if this page/annotation has specific overrides in special_includes
|
| 1327 |
+
if (page_index in self.special_includes and
|
| 1328 |
+
isinstance(self.special_includes[page_index], dict) and
|
| 1329 |
+
annotation_id in self.special_includes[page_index]):
|
| 1330 |
+
|
| 1331 |
+
# Use special_includes (overrides include_all setting)
|
| 1332 |
+
specified_keys = self.special_includes[page_index][annotation_id]
|
| 1333 |
+
for key in specified_keys:
|
| 1334 |
+
keys.append((page_index, annotation_id, key))
|
| 1335 |
+
elif self.include_all:
|
| 1336 |
+
# No specific override, so include all available keys for this annotation type
|
| 1337 |
+
for key in ai_prompt[annotation_type]:
|
| 1338 |
+
keys.append((page_index, annotation_id, key))
|
| 1339 |
+
# If include_all is False and no special_include entry, return empty keys
|
| 1340 |
+
|
| 1341 |
+
return keys
|
| 1342 |
+
|
| 1343 |
+
def prefetch(self, keys: list):
|
| 1344 |
+
"""checks if keys are already cached and asynchronously generates missing ones"""
|
| 1345 |
+
with self.lock:
|
| 1346 |
+
for key in keys:
|
| 1347 |
+
if self.get_from_cache(key) is None and key not in self.in_progress:
|
| 1348 |
+
# i, annotation_id, annotation_type, ai_prompt
|
| 1349 |
+
instance_id, annotation_id, ai_assistant = key
|
| 1350 |
+
|
| 1351 |
+
future = self.executor.submit(self.compute_help, instance_id, annotation_id, ai_assistant)
|
| 1352 |
+
self.in_progress[key] = future
|
| 1353 |
+
def callback(fut, cache_key=key):
|
| 1354 |
+
with self.lock:
|
| 1355 |
+
try:
|
| 1356 |
+
result = fut.result()
|
| 1357 |
+
self.add_to_cache(cache_key, result)
|
| 1358 |
+
except Exception as e:
|
| 1359 |
+
logger.error(f"Prefetch failed for key {cache_key}: {e}")
|
| 1360 |
+
self.in_progress.pop(cache_key, None)
|
| 1361 |
+
|
| 1362 |
+
future.add_done_callback(callback)
|
| 1363 |
+
|
| 1364 |
+
def get_ai_help(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
|
| 1365 |
+
"""retrieves AI help either from cache, waits for in-progress, or computes on-demand."""
|
| 1366 |
+
key = (instance_id, annotation_id, ai_assistant)
|
| 1367 |
+
|
| 1368 |
+
# Check if caching is enabled for this help type
|
| 1369 |
+
if not self.disk_cache_enabled:
|
| 1370 |
+
return self.compute_help(instance_id, annotation_id, ai_assistant)
|
| 1371 |
+
|
| 1372 |
+
# Try to get from cache if caching is enabled
|
| 1373 |
+
cached_value = self.get_from_cache(key)
|
| 1374 |
+
if cached_value is not None:
|
| 1375 |
+
logger.debug(f"Cache hit for key: {key}")
|
| 1376 |
+
return cached_value
|
| 1377 |
+
|
| 1378 |
+
with self.lock:
|
| 1379 |
+
if key in self.in_progress:
|
| 1380 |
+
future = self.in_progress[key]
|
| 1381 |
+
else:
|
| 1382 |
+
future = self.executor.submit(self.compute_help, instance_id, annotation_id, ai_assistant)
|
| 1383 |
+
self.in_progress[key] = future
|
| 1384 |
+
try:
|
| 1385 |
+
result = future.result(timeout=60)
|
| 1386 |
+
# Don't cache error responses
|
| 1387 |
+
is_error_response = (
|
| 1388 |
+
isinstance(result, str) and
|
| 1389 |
+
(result.startswith("Unable to generate") or
|
| 1390 |
+
result.startswith("Error:") or
|
| 1391 |
+
"error" in result.lower()[:50])
|
| 1392 |
+
)
|
| 1393 |
+
if self.disk_cache_enabled and not is_error_response:
|
| 1394 |
+
self.add_to_cache(key, result)
|
| 1395 |
+
elif is_error_response:
|
| 1396 |
+
logger.warning(f"Not caching error response for key {key}: {result[:100]}")
|
| 1397 |
+
with self.lock:
|
| 1398 |
+
self.in_progress.pop(key, None)
|
| 1399 |
+
return result
|
| 1400 |
+
except Exception as e:
|
| 1401 |
+
logger.error(f"Error computing help for key {key}: {e}")
|
| 1402 |
+
with self.lock:
|
| 1403 |
+
self.in_progress.pop(key, None)
|
| 1404 |
+
return f"Error: {str(e)}"
|
| 1405 |
+
|
| 1406 |
+
def compute_help(self, instance_id: int, annotation_id: int, ai_assistant: str):
|
| 1407 |
+
# Validate that the assistant type is compatible with the model and input
|
| 1408 |
+
is_valid, error_message = self._validate_assistant_compatibility(
|
| 1409 |
+
instance_id, annotation_id, ai_assistant
|
| 1410 |
+
)
|
| 1411 |
+
if not is_valid:
|
| 1412 |
+
logger.warning(f"Assistant compatibility check failed: {error_message}")
|
| 1413 |
+
return {"error": error_message}
|
| 1414 |
+
|
| 1415 |
+
annotation_type_str = config["annotation_schemes"][annotation_id]["annotation_type"]
|
| 1416 |
+
annotation_type = Annotation_Type(annotation_type_str)
|
| 1417 |
+
if annotation_type == Annotation_Type.LIKERT:
|
| 1418 |
+
return self.generate_likert(instance_id, annotation_id, ai_assistant)
|
| 1419 |
+
elif annotation_type == Annotation_Type.RADIO:
|
| 1420 |
+
return self.generate_radio(instance_id, annotation_id, ai_assistant)
|
| 1421 |
+
elif annotation_type == Annotation_Type.MULTISELECT:
|
| 1422 |
+
return self.generate_multiselect(instance_id, annotation_id, ai_assistant)
|
| 1423 |
+
elif annotation_type == Annotation_Type.NUMBER:
|
| 1424 |
+
return self.generate_number(instance_id, annotation_id, ai_assistant)
|
| 1425 |
+
elif annotation_type == Annotation_Type.SELECT:
|
| 1426 |
+
return self.generate_select(instance_id, annotation_id, ai_assistant)
|
| 1427 |
+
elif annotation_type == Annotation_Type.SLIDER:
|
| 1428 |
+
return self.generate_slider(instance_id, annotation_id, ai_assistant)
|
| 1429 |
+
elif annotation_type == Annotation_Type.SPAN:
|
| 1430 |
+
return self.generate_span(instance_id, annotation_id, ai_assistant)
|
| 1431 |
+
elif annotation_type == Annotation_Type.TEXTBOX:
|
| 1432 |
+
return self.generate_textbox(instance_id, annotation_id, ai_assistant)
|
| 1433 |
+
elif annotation_type == Annotation_Type.IMAGE_ANNOTATION:
|
| 1434 |
+
return self.generate_image_annotation(instance_id, annotation_id, ai_assistant)
|
| 1435 |
+
elif annotation_type == Annotation_Type.VIDEO_ANNOTATION:
|
| 1436 |
+
return self.generate_video_annotation(instance_id, annotation_id, ai_assistant)
|
| 1437 |
+
else:
|
| 1438 |
+
raise ValueError(f"Unknown annotation type: {annotation_type}")
|
| 1439 |
+
|
| 1440 |
+
def get_cache_stats(self) -> Dict[str, int]:
|
| 1441 |
+
"""returns statistics on disk cache and in-progress cache entries."""
|
| 1442 |
+
with self.lock:
|
| 1443 |
+
disk_count = 0
|
| 1444 |
+
if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path):
|
| 1445 |
+
try:
|
| 1446 |
+
disk_data = self.load_disk_cache_data(self.disk_persistence_path)
|
| 1447 |
+
disk_count = len(disk_data)
|
| 1448 |
+
except:
|
| 1449 |
+
pass
|
| 1450 |
+
|
| 1451 |
+
return {
|
| 1452 |
+
'disk_cache_enabled': self.disk_cache_enabled,
|
| 1453 |
+
'cached_items_disk': disk_count,
|
| 1454 |
+
'in_progress_items': len(self.in_progress)
|
| 1455 |
+
}
|
| 1456 |
+
|
| 1457 |
+
def clear_cache(self):
|
| 1458 |
+
"""clears disk cache and cancels any ongoing generation."""
|
| 1459 |
+
with self.lock:
|
| 1460 |
+
for future in self.in_progress.values():
|
| 1461 |
+
future.cancel()
|
| 1462 |
+
self.in_progress.clear()
|
| 1463 |
+
|
| 1464 |
+
if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path):
|
| 1465 |
+
try:
|
| 1466 |
+
os.remove(self.disk_persistence_path)
|
| 1467 |
+
logger.info("Disk cache file removed")
|
| 1468 |
+
except Exception as e:
|
| 1469 |
+
logger.error(f"Error removing disk cache file: {e}")
|
| 1470 |
+
logger.info("Cache cleared")
|
| 1471 |
+
|
| 1472 |
+
|
| 1473 |
+
|
potato/ai/ai_endpoint.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
Unified AI endpoint interface for various LLM providers.
|
| 4 |
+
|
| 5 |
+
This module provides a common interface for interacting with different LLM providers
|
| 6 |
+
including OpenAI, Anthropic, Hugging Face, Ollama, and VLLM endpoints.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from enum import Enum
|
| 11 |
+
import logging
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
import os
|
| 14 |
+
from typing import Dict, Any, Optional, List, Type, Union
|
| 15 |
+
import json
|
| 16 |
+
from string import Template
|
| 17 |
+
|
| 18 |
+
from pydantic import BaseModel
|
| 19 |
+
|
| 20 |
+
from .ai_prompt import get_ai_prompt
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Annotation_Type(Enum):
|
| 26 |
+
RADIO = "radio"
|
| 27 |
+
LIKERT = "likert"
|
| 28 |
+
NUMBER = "number"
|
| 29 |
+
TEXTBOX = "text"
|
| 30 |
+
MULTISELECT = "multiselect"
|
| 31 |
+
SPAN = "span"
|
| 32 |
+
SELECT = "select"
|
| 33 |
+
SLIDER = "slider"
|
| 34 |
+
IMAGE_ANNOTATION = "image_annotation"
|
| 35 |
+
VIDEO_ANNOTATION = "video_annotation"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class ImageData:
|
| 40 |
+
"""Data structure for image input to visual AI endpoints."""
|
| 41 |
+
source: str # 'url' | 'base64'
|
| 42 |
+
data: str # The URL or base64-encoded image data
|
| 43 |
+
width: Optional[int] = None
|
| 44 |
+
height: Optional[int] = None
|
| 45 |
+
mime_type: Optional[str] = None # e.g., 'image/jpeg', 'image/png'
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class VisualAnnotationInput:
|
| 50 |
+
"""Input data structure for visual annotation AI assistance."""
|
| 51 |
+
ai_assistant: str # 'detection', 'classification', 'hint', 'pre_annotate', etc.
|
| 52 |
+
annotation_type: str # 'image_annotation' | 'video_annotation'
|
| 53 |
+
task_type: str # Specific task: 'detection', 'classification', 'scene_detection', etc.
|
| 54 |
+
image_data: Union[ImageData, List[ImageData]] # Single image or list of frames
|
| 55 |
+
description: str # Task description from annotation scheme
|
| 56 |
+
labels: Optional[List[str]] = None # Available labels for the task
|
| 57 |
+
video_metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # fps, duration for video
|
| 58 |
+
region: Optional[Dict[str, float]] = None # Selected region for classification (x, y, width, height)
|
| 59 |
+
confidence_threshold: float = 0.5 # Minimum confidence for detections
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
|
| 63 |
+
class AnnotationInput:
|
| 64 |
+
ai_assistant: str
|
| 65 |
+
annotation_type: Annotation_Type
|
| 66 |
+
text: str
|
| 67 |
+
description: str
|
| 68 |
+
min_label: Optional[str] = ""
|
| 69 |
+
max_label: Optional[str] = ""
|
| 70 |
+
size: Optional[int] = -1
|
| 71 |
+
labels: Optional[List[str]] = None
|
| 72 |
+
min_value: Optional[int] = -1
|
| 73 |
+
max_value: Optional[int] = -1
|
| 74 |
+
step: Optional[int] = -1
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class ModelCapabilities:
|
| 79 |
+
"""
|
| 80 |
+
Declares what operations an AI endpoint can perform.
|
| 81 |
+
|
| 82 |
+
This dataclass is used to define the capabilities of different AI endpoints,
|
| 83 |
+
enabling the system to automatically filter AI assistant buttons and validate
|
| 84 |
+
requests based on what each model can actually do.
|
| 85 |
+
|
| 86 |
+
Attributes:
|
| 87 |
+
text_generation: Can generate text (hints, rationales, descriptions)
|
| 88 |
+
vision_input: Can process images as input
|
| 89 |
+
bounding_box_output: Can output precise coordinate detections
|
| 90 |
+
text_classification: Can classify text into categories
|
| 91 |
+
image_classification: Can classify images into categories
|
| 92 |
+
rationale_generation: Can generate explanations/rationales for labels
|
| 93 |
+
keyword_extraction: Can extract keywords from text (not applicable to images)
|
| 94 |
+
"""
|
| 95 |
+
text_generation: bool = False
|
| 96 |
+
vision_input: bool = False
|
| 97 |
+
bounding_box_output: bool = False
|
| 98 |
+
text_classification: bool = False
|
| 99 |
+
image_classification: bool = False
|
| 100 |
+
rationale_generation: bool = False
|
| 101 |
+
keyword_extraction: bool = False
|
| 102 |
+
|
| 103 |
+
def supports_assistant(self, assistant_type: str, has_image_input: bool = False) -> bool:
|
| 104 |
+
"""
|
| 105 |
+
Check if model supports a specific AI assistant type.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
assistant_type: The type of AI assistant ('hint', 'keyword', 'rationale',
|
| 109 |
+
'detection', 'pre_annotate', 'classification')
|
| 110 |
+
has_image_input: Whether the current content is an image
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
True if the model supports this assistant type for the given input type
|
| 114 |
+
"""
|
| 115 |
+
if assistant_type == "hint":
|
| 116 |
+
# Hints require text generation; for images, also need vision
|
| 117 |
+
if has_image_input:
|
| 118 |
+
return self.text_generation and self.vision_input
|
| 119 |
+
return self.text_generation
|
| 120 |
+
|
| 121 |
+
elif assistant_type == "keyword":
|
| 122 |
+
# Keywords require keyword extraction AND text input (not images)
|
| 123 |
+
# Keyword highlighting doesn't make sense for images
|
| 124 |
+
return self.keyword_extraction and not has_image_input
|
| 125 |
+
|
| 126 |
+
elif assistant_type == "rationale":
|
| 127 |
+
# Rationales require rationale generation; for images, also need vision
|
| 128 |
+
if has_image_input:
|
| 129 |
+
return self.rationale_generation and self.vision_input
|
| 130 |
+
return self.rationale_generation
|
| 131 |
+
|
| 132 |
+
elif assistant_type in ("detection", "detect", "pre_annotate"):
|
| 133 |
+
# Detection requires vision and bounding box output
|
| 134 |
+
return self.bounding_box_output and self.vision_input
|
| 135 |
+
|
| 136 |
+
elif assistant_type == "classification":
|
| 137 |
+
# Classification depends on input type
|
| 138 |
+
if has_image_input:
|
| 139 |
+
return self.image_classification and self.vision_input
|
| 140 |
+
return self.text_classification
|
| 141 |
+
|
| 142 |
+
# Unknown assistant type - default to False for safety
|
| 143 |
+
return False
|
| 144 |
+
|
| 145 |
+
def get_supported_assistants(self, has_image_input: bool = False) -> List[str]:
|
| 146 |
+
"""
|
| 147 |
+
Get list of assistant types supported for the given input type.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
has_image_input: Whether the current content is an image
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
List of supported assistant type names
|
| 154 |
+
"""
|
| 155 |
+
all_types = ["hint", "keyword", "rationale", "detection", "pre_annotate", "classification"]
|
| 156 |
+
return [t for t in all_types if self.supports_assistant(t, has_image_input)]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class AIEndpointError(Exception):
|
| 160 |
+
"""Base exception for AI endpoint errors."""
|
| 161 |
+
pass
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class AIEndpointConfigError(AIEndpointError):
|
| 165 |
+
"""Exception raised for configuration errors."""
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class AIEndpointRequestError(AIEndpointError):
|
| 170 |
+
"""Exception raised for request/API errors."""
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class BaseAIEndpoint(ABC):
|
| 175 |
+
"""
|
| 176 |
+
Abstract base class for AI endpoints.
|
| 177 |
+
|
| 178 |
+
All AI endpoint implementations should inherit from this class
|
| 179 |
+
and implement the required methods.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, config: Dict[str, Any]):
|
| 183 |
+
"""
|
| 184 |
+
Initialize the AI endpoint with configuration.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
config: Configuration dictionary containing endpoint-specific settings
|
| 188 |
+
"""
|
| 189 |
+
self.config = config
|
| 190 |
+
self.description = config.get("description", "")
|
| 191 |
+
self.annotation_type = config.get("annotation_type", "")
|
| 192 |
+
self.ai_config = config.get("ai_config", {})
|
| 193 |
+
|
| 194 |
+
# Model configuration
|
| 195 |
+
self.model = self.ai_config.get("model", self._get_default_model())
|
| 196 |
+
self.max_tokens = self.ai_config.get("max_tokens", 100)
|
| 197 |
+
self.temperature = self.ai_config.get("temperature", 0.1)
|
| 198 |
+
|
| 199 |
+
# prompt
|
| 200 |
+
self.prompts = get_ai_prompt()
|
| 201 |
+
|
| 202 |
+
# Initialize the client
|
| 203 |
+
self._initialize_client()
|
| 204 |
+
|
| 205 |
+
@abstractmethod
|
| 206 |
+
def _initialize_client(self) -> None:
|
| 207 |
+
"""Initialize the client for the specific AI provider."""
|
| 208 |
+
pass
|
| 209 |
+
|
| 210 |
+
@abstractmethod
|
| 211 |
+
def _get_default_model(self) -> str:
|
| 212 |
+
"""Get the default model name for this provider."""
|
| 213 |
+
pass
|
| 214 |
+
|
| 215 |
+
@abstractmethod
|
| 216 |
+
def query(self, prompt: str, output_format: Type[BaseModel]):
|
| 217 |
+
"""
|
| 218 |
+
Send a query to the AI model and return the response.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
prompt: The prompt to send to the model
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
The model's response as a string
|
| 225 |
+
|
| 226 |
+
Raises:
|
| 227 |
+
AIEndpointRequestError: If the request fails
|
| 228 |
+
"""
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
def chat_query_with_image(
|
| 232 |
+
self,
|
| 233 |
+
messages: List[Dict[str, Any]],
|
| 234 |
+
images: Optional[List["ImageData"]] = None,
|
| 235 |
+
) -> str:
|
| 236 |
+
"""
|
| 237 |
+
Send a multi-turn chat with interleaved images to the AI model.
|
| 238 |
+
|
| 239 |
+
Messages may contain content blocks (text + image) instead of plain strings.
|
| 240 |
+
Used by the live agent runner for vision-based agent loops.
|
| 241 |
+
|
| 242 |
+
Default implementation raises NotImplementedError — only vision-capable
|
| 243 |
+
endpoints should override this.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
messages: List of message dicts. 'content' may be a string or a list
|
| 247 |
+
of content blocks (e.g., {"type": "text", "text": "..."} or
|
| 248 |
+
{"type": "image", "source": {...}}).
|
| 249 |
+
images: Optional list of ImageData to include (alternative to inline images).
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
The model's response as a plain text string.
|
| 253 |
+
|
| 254 |
+
Raises:
|
| 255 |
+
NotImplementedError: If the endpoint doesn't support vision.
|
| 256 |
+
"""
|
| 257 |
+
raise NotImplementedError(
|
| 258 |
+
f"{self.__class__.__name__} does not support chat_query_with_image. "
|
| 259 |
+
f"Use a vision-capable endpoint (e.g., anthropic_vision)."
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
def chat_query(self, messages: List[Dict[str, str]]) -> str:
|
| 263 |
+
"""
|
| 264 |
+
Send a multi-turn chat conversation to the AI model.
|
| 265 |
+
|
| 266 |
+
Default implementation flattens messages into a single prompt and calls query().
|
| 267 |
+
Subclasses should override with native multi-turn support.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
messages: List of message dicts with 'role' and 'content' keys.
|
| 271 |
+
Roles: 'system', 'user', 'assistant'
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
The model's response as a plain text string.
|
| 275 |
+
"""
|
| 276 |
+
# Flatten messages into a single prompt
|
| 277 |
+
parts = []
|
| 278 |
+
for msg in messages:
|
| 279 |
+
role = msg.get("role", "user")
|
| 280 |
+
content = msg.get("content", "")
|
| 281 |
+
if role == "system":
|
| 282 |
+
parts.append(f"System: {content}")
|
| 283 |
+
elif role == "assistant":
|
| 284 |
+
parts.append(f"Assistant: {content}")
|
| 285 |
+
else:
|
| 286 |
+
parts.append(f"User: {content}")
|
| 287 |
+
prompt = "\n\n".join(parts) + "\n\nAssistant:"
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
result = self.query(prompt)
|
| 291 |
+
# query() may return parsed JSON or a string; ensure we return a string
|
| 292 |
+
if isinstance(result, dict):
|
| 293 |
+
return result.get("response", result.get("content", str(result)))
|
| 294 |
+
return str(result)
|
| 295 |
+
except Exception as e:
|
| 296 |
+
raise AIEndpointRequestError(f"Chat query failed: {e}")
|
| 297 |
+
|
| 298 |
+
def parseStringToJson(self, response_content: str) -> str:
|
| 299 |
+
"""
|
| 300 |
+
Parse structured output from any LLM response, with robust fallbacks.
|
| 301 |
+
|
| 302 |
+
Handles common issues across all endpoint types (ollama, vllm, openai, etc.):
|
| 303 |
+
1. Clean JSON responses -> direct parse
|
| 304 |
+
2. JSON wrapped in markdown code blocks (```json ... ```)
|
| 305 |
+
3. JSON embedded in surrounding prose text
|
| 306 |
+
4. Truncated JSON (max_tokens exceeded) -> extract complete key-value pairs
|
| 307 |
+
5. Plain text with no JSON structure -> return as {"response": text}
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
response_content: Raw response string from the LLM
|
| 311 |
+
|
| 312 |
+
Returns:
|
| 313 |
+
Parsed dict or the raw string if JSON extraction succeeds
|
| 314 |
+
|
| 315 |
+
Raises:
|
| 316 |
+
ValueError: Only if response is completely empty
|
| 317 |
+
"""
|
| 318 |
+
import re
|
| 319 |
+
|
| 320 |
+
# Handle empty or None content
|
| 321 |
+
if not response_content:
|
| 322 |
+
raise ValueError("Empty response content received from AI endpoint")
|
| 323 |
+
|
| 324 |
+
# If it's already a dict, return it
|
| 325 |
+
if isinstance(response_content, dict):
|
| 326 |
+
return response_content
|
| 327 |
+
|
| 328 |
+
# Convert to string if needed
|
| 329 |
+
content_str = str(response_content).strip()
|
| 330 |
+
if not content_str:
|
| 331 |
+
raise ValueError("Empty response content received from AI endpoint")
|
| 332 |
+
|
| 333 |
+
# Strategy 0: Strip thinking/reasoning blocks that wrap the actual output
|
| 334 |
+
# Many models (qwen3, deepseek, etc.) produce <think>...</think> blocks
|
| 335 |
+
cleaned = content_str
|
| 336 |
+
for tag in ['think', 'thinking', 'thought', 'inner_monologue']:
|
| 337 |
+
cleaned = re.sub(
|
| 338 |
+
rf'<{tag}>[\s\S]*?</{tag}>\s*',
|
| 339 |
+
'', cleaned, flags=re.IGNORECASE
|
| 340 |
+
).strip()
|
| 341 |
+
if cleaned and cleaned != content_str:
|
| 342 |
+
content_str = cleaned
|
| 343 |
+
|
| 344 |
+
# Strategy 1: Try direct JSON parse
|
| 345 |
+
try:
|
| 346 |
+
return json.loads(content_str)
|
| 347 |
+
except json.JSONDecodeError:
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
# Strategy 2: Extract from markdown code blocks
|
| 351 |
+
for pattern in [
|
| 352 |
+
r'```json\s*([\s\S]*?)\s*```',
|
| 353 |
+
r'```\s*([\s\S]*?)\s*```',
|
| 354 |
+
]:
|
| 355 |
+
match = re.search(pattern, content_str)
|
| 356 |
+
if match:
|
| 357 |
+
try:
|
| 358 |
+
return json.loads(match.group(1).strip())
|
| 359 |
+
except json.JSONDecodeError:
|
| 360 |
+
pass
|
| 361 |
+
|
| 362 |
+
# Strategy 3: Extract from tool call format
|
| 363 |
+
# Some models return: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
| 364 |
+
# or function_call blocks
|
| 365 |
+
for pattern in [
|
| 366 |
+
r'<tool_call>\s*([\s\S]*?)\s*</tool_call>',
|
| 367 |
+
r'<function_call>\s*([\s\S]*?)\s*</function_call>',
|
| 368 |
+
r'<output>\s*([\s\S]*?)\s*</output>',
|
| 369 |
+
r'<result>\s*([\s\S]*?)\s*</result>',
|
| 370 |
+
r'<answer>\s*([\s\S]*?)\s*</answer>',
|
| 371 |
+
]:
|
| 372 |
+
match = re.search(pattern, content_str, re.IGNORECASE)
|
| 373 |
+
if match:
|
| 374 |
+
try:
|
| 375 |
+
parsed = json.loads(match.group(1).strip())
|
| 376 |
+
# If it's a tool call wrapper, extract the arguments
|
| 377 |
+
if isinstance(parsed, dict) and 'arguments' in parsed:
|
| 378 |
+
return parsed['arguments']
|
| 379 |
+
return parsed
|
| 380 |
+
except json.JSONDecodeError:
|
| 381 |
+
pass
|
| 382 |
+
|
| 383 |
+
# Strategy 4: Find a JSON object anywhere in the text
|
| 384 |
+
# Greedy match for the outermost { ... }
|
| 385 |
+
match = re.search(r'\{[\s\S]*\}', content_str)
|
| 386 |
+
if match:
|
| 387 |
+
try:
|
| 388 |
+
return json.loads(match.group(0))
|
| 389 |
+
except json.JSONDecodeError:
|
| 390 |
+
pass
|
| 391 |
+
|
| 392 |
+
# Strategy 4: Salvage truncated JSON
|
| 393 |
+
# Extract complete "key": "value" and "key": number pairs
|
| 394 |
+
salvaged = self._salvage_key_value_pairs(content_str)
|
| 395 |
+
if salvaged:
|
| 396 |
+
logger.warning(
|
| 397 |
+
f"Salvaged {len(salvaged)} fields from truncated/malformed response"
|
| 398 |
+
)
|
| 399 |
+
return salvaged
|
| 400 |
+
|
| 401 |
+
# Strategy 6: Parse XML-style output
|
| 402 |
+
# Some models (especially larger ones) produce XML like:
|
| 403 |
+
# <label>joy</label><confidence>90</confidence>
|
| 404 |
+
# or <response><label>joy</label></response>
|
| 405 |
+
xml_result = self._parse_xml_to_dict(content_str)
|
| 406 |
+
if xml_result:
|
| 407 |
+
logger.info(
|
| 408 |
+
f"Parsed {len(xml_result)} fields from XML-style response"
|
| 409 |
+
)
|
| 410 |
+
return xml_result
|
| 411 |
+
|
| 412 |
+
# Strategy 7: Return raw text wrapped in a dict
|
| 413 |
+
logger.warning(
|
| 414 |
+
f"Could not parse JSON or XML from response ({len(content_str)} chars), "
|
| 415 |
+
f"returning as raw text"
|
| 416 |
+
)
|
| 417 |
+
return {"response": content_str}
|
| 418 |
+
|
| 419 |
+
@staticmethod
|
| 420 |
+
def _salvage_key_value_pairs(text: str) -> Optional[dict]:
|
| 421 |
+
"""Extract key-value pairs from truncated or malformed JSON.
|
| 422 |
+
|
| 423 |
+
Handles cases where max_tokens cuts off a response mid-field, e.g.:
|
| 424 |
+
{"label": "joy", "confidence": 90, "reasoning": "The text expres...
|
| 425 |
+
|
| 426 |
+
Returns:
|
| 427 |
+
Dict of extracted key-value pairs, or None if nothing found.
|
| 428 |
+
"""
|
| 429 |
+
import re
|
| 430 |
+
result = {}
|
| 431 |
+
|
| 432 |
+
# Extract "key": "value" pairs (string values)
|
| 433 |
+
for match in re.finditer(r'"(\w+)"\s*:\s*"([^"]*)"', text):
|
| 434 |
+
result[match.group(1)] = match.group(2)
|
| 435 |
+
|
| 436 |
+
# Extract "key": number pairs
|
| 437 |
+
for match in re.finditer(r'"(\w+)"\s*:\s*(-?\d+(?:\.\d+)?)\b', text):
|
| 438 |
+
key = match.group(1)
|
| 439 |
+
if key not in result:
|
| 440 |
+
try:
|
| 441 |
+
val = float(match.group(2))
|
| 442 |
+
result[key] = int(val) if val == int(val) else val
|
| 443 |
+
except ValueError:
|
| 444 |
+
pass
|
| 445 |
+
|
| 446 |
+
# Extract "key": true/false/null
|
| 447 |
+
for match in re.finditer(r'"(\w+)"\s*:\s*(true|false|null)\b', text):
|
| 448 |
+
key = match.group(1)
|
| 449 |
+
if key not in result:
|
| 450 |
+
val_str = match.group(2)
|
| 451 |
+
result[key] = (
|
| 452 |
+
True if val_str == 'true'
|
| 453 |
+
else False if val_str == 'false'
|
| 454 |
+
else None
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
return result if result else None
|
| 458 |
+
|
| 459 |
+
@staticmethod
|
| 460 |
+
def _parse_xml_to_dict(text: str) -> Optional[dict]:
|
| 461 |
+
"""Extract key-value pairs from XML-style LLM output.
|
| 462 |
+
|
| 463 |
+
Handles patterns like:
|
| 464 |
+
- <label>joy</label><confidence>90</confidence>
|
| 465 |
+
- <response><label>joy</label></response>
|
| 466 |
+
- Mixed XML with text: "The emotion is <label>joy</label>"
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
Dict of tag->content pairs, or None if no XML tags found.
|
| 470 |
+
"""
|
| 471 |
+
import re
|
| 472 |
+
result = {}
|
| 473 |
+
|
| 474 |
+
# Find all <tag>content</tag> pairs (non-nested simple tags)
|
| 475 |
+
for match in re.finditer(
|
| 476 |
+
r'<(\w+)>([^<]*)</\1>', text, re.IGNORECASE
|
| 477 |
+
):
|
| 478 |
+
tag = match.group(1).lower()
|
| 479 |
+
value = match.group(2).strip()
|
| 480 |
+
|
| 481 |
+
# Skip wrapper tags that contain other tags
|
| 482 |
+
if tag in ('response', 'output', 'result', 'answer', 'root'):
|
| 483 |
+
continue
|
| 484 |
+
|
| 485 |
+
# Try to parse numeric values
|
| 486 |
+
try:
|
| 487 |
+
if '.' in value:
|
| 488 |
+
result[tag] = float(value)
|
| 489 |
+
else:
|
| 490 |
+
result[tag] = int(value)
|
| 491 |
+
except ValueError:
|
| 492 |
+
# Boolean
|
| 493 |
+
if value.lower() in ('true', 'false'):
|
| 494 |
+
result[tag] = value.lower() == 'true'
|
| 495 |
+
else:
|
| 496 |
+
result[tag] = value
|
| 497 |
+
|
| 498 |
+
return result if result else None
|
| 499 |
+
|
| 500 |
+
def get_ai(self, data: AnnotationInput, output_format) -> str:
|
| 501 |
+
"""
|
| 502 |
+
Get a hint for annotating the given text.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
text: The text to get a hint for
|
| 506 |
+
|
| 507 |
+
Returns:
|
| 508 |
+
A helpful hint for annotation
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
try:
|
| 512 |
+
# Check if annotation type exists (comparing string against enum values)
|
| 513 |
+
valid_types = [e.value for e in Annotation_Type]
|
| 514 |
+
if data.annotation_type not in valid_types:
|
| 515 |
+
logger.warning(f"Annotation type '{data.annotation_type}' not found")
|
| 516 |
+
return "Unable to generate suggestion - annotation type not configured"
|
| 517 |
+
|
| 518 |
+
# Check if ai_assistant exists
|
| 519 |
+
ai_prompt = get_ai_prompt()
|
| 520 |
+
if data.ai_assistant not in ai_prompt[data.annotation_type]:
|
| 521 |
+
logger.warning(f"'ai_assistant' not found for {data.annotation_type}")
|
| 522 |
+
return "Unable to generate suggestion - prompt not configured"
|
| 523 |
+
|
| 524 |
+
template_str = self.prompts.get(data.annotation_type).get(data.ai_assistant).get("prompt")
|
| 525 |
+
template = Template(template_str)
|
| 526 |
+
prompt = template.substitute(
|
| 527 |
+
text=data.text,
|
| 528 |
+
description=data.description,
|
| 529 |
+
min_label=data.min_label,
|
| 530 |
+
max_label=data.max_label,
|
| 531 |
+
size=data.size,
|
| 532 |
+
labels=data.labels,
|
| 533 |
+
min_value=data.min_value,
|
| 534 |
+
max_value=data.max_value,
|
| 535 |
+
step=data.step
|
| 536 |
+
)
|
| 537 |
+
return self.query(prompt, output_format)
|
| 538 |
+
except Exception as e:
|
| 539 |
+
logger.error(f"[get_ai] AnnotationInput: {data}")
|
| 540 |
+
logger.error(f"[get_ai] Error for {data.annotation_type}/{data.ai_assistant}: {type(e).__name__}: {e}")
|
| 541 |
+
import traceback
|
| 542 |
+
logger.error(f"[get_ai] Traceback:\n{traceback.format_exc()}")
|
| 543 |
+
return "Unable to generate hint at this time."
|
| 544 |
+
|
| 545 |
+
def health_check(self) -> bool:
|
| 546 |
+
"""
|
| 547 |
+
Check if the AI endpoint is healthy and accessible.
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
True if the endpoint is healthy, False otherwise
|
| 551 |
+
"""
|
| 552 |
+
try:
|
| 553 |
+
# Simple test query
|
| 554 |
+
test_response = self.query("Hello")
|
| 555 |
+
return bool(test_response and test_response.strip())
|
| 556 |
+
except Exception as e:
|
| 557 |
+
logger.error(f"Health check failed: {e}")
|
| 558 |
+
return False
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class AIEndpointFactory:
|
| 562 |
+
"""
|
| 563 |
+
Factory class for creating AI endpoint instances.
|
| 564 |
+
"""
|
| 565 |
+
|
| 566 |
+
_endpoints = { }
|
| 567 |
+
|
| 568 |
+
@classmethod
|
| 569 |
+
def register_endpoint(cls, endpoint_type: str, endpoint_class: type):
|
| 570 |
+
"""Register a new endpoint type."""
|
| 571 |
+
cls._endpoints[endpoint_type] = endpoint_class
|
| 572 |
+
|
| 573 |
+
@classmethod
|
| 574 |
+
def create_endpoint(cls, config: Dict[str, Any]) -> Optional[BaseAIEndpoint]:
|
| 575 |
+
"""
|
| 576 |
+
Create an AI endpoint instance based on configuration.
|
| 577 |
+
|
| 578 |
+
Args:
|
| 579 |
+
config: Configuration dictionary containing ai_support settings
|
| 580 |
+
|
| 581 |
+
Returns:
|
| 582 |
+
An AI endpoint instance or None if AI support is disabled
|
| 583 |
+
|
| 584 |
+
Raises:
|
| 585 |
+
AIEndpointConfigError: If the configuration is invalid
|
| 586 |
+
"""
|
| 587 |
+
if not config.get("ai_support", {}).get("enabled", False):
|
| 588 |
+
return None
|
| 589 |
+
|
| 590 |
+
ai_support = config["ai_support"]
|
| 591 |
+
endpoint_type = ai_support.get("endpoint_type")
|
| 592 |
+
|
| 593 |
+
if not endpoint_type:
|
| 594 |
+
raise AIEndpointConfigError("endpoint_type is required when ai_support is enabled")
|
| 595 |
+
|
| 596 |
+
if endpoint_type not in cls._endpoints:
|
| 597 |
+
raise AIEndpointConfigError(f"Unknown endpoint type: {endpoint_type}")
|
| 598 |
+
|
| 599 |
+
# Prepare endpoint configuration
|
| 600 |
+
endpoint_config = {
|
| 601 |
+
"ai_config": ai_support.get("ai_config", {})
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
try:
|
| 605 |
+
endpoint_class = cls._endpoints[endpoint_type]
|
| 606 |
+
return endpoint_class(endpoint_config)
|
| 607 |
+
except Exception as e:
|
| 608 |
+
raise AIEndpointConfigError(f"Failed to create {endpoint_type} endpoint: {e}")
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Legacy function for backward compatibility
|
| 612 |
+
def get_ai_endpoint(config: dict):
|
| 613 |
+
"""
|
| 614 |
+
Get an AI endpoint instance (legacy function).
|
| 615 |
+
|
| 616 |
+
This function is maintained for backward compatibility.
|
| 617 |
+
New code should use AIEndpointFactory.create_endpoint().
|
| 618 |
+
"""
|
| 619 |
+
return AIEndpointFactory.create_endpoint(config)
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# Register built-in endpoints
|
| 623 |
+
try:
|
| 624 |
+
from .ollama_endpoint import OllamaEndpoint
|
| 625 |
+
AIEndpointFactory.register_endpoint("ollama", OllamaEndpoint)
|
| 626 |
+
except ImportError:
|
| 627 |
+
logger.debug("Ollama endpoint not available")
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
from .openai_endpoint import OpenAIEndpoint
|
| 631 |
+
AIEndpointFactory.register_endpoint("openai", OpenAIEndpoint)
|
| 632 |
+
except ImportError:
|
| 633 |
+
logger.debug("OpenAI endpoint not available")
|
| 634 |
+
|
| 635 |
+
try:
|
| 636 |
+
from .huggingface_endpoint import HuggingfaceEndpoint
|
| 637 |
+
AIEndpointFactory.register_endpoint("huggingface", HuggingfaceEndpoint)
|
| 638 |
+
except ImportError:
|
| 639 |
+
logger.debug("Hugging Face endpoint not available")
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
from .gemini_endpoint import GeminiEndpoint
|
| 643 |
+
AIEndpointFactory.register_endpoint("gemini", GeminiEndpoint)
|
| 644 |
+
except ImportError:
|
| 645 |
+
logger.debug("Gemini endpoint not available")
|
| 646 |
+
|
| 647 |
+
try:
|
| 648 |
+
from .anthropic_endpoint import AnthropicEndpoint
|
| 649 |
+
AIEndpointFactory.register_endpoint("anthropic", AnthropicEndpoint)
|
| 650 |
+
except ImportError:
|
| 651 |
+
logger.debug("Anthropic endpoint not available")
|
| 652 |
+
|
| 653 |
+
try:
|
| 654 |
+
from .vllm_endpoint import VLLMEndpoint
|
| 655 |
+
AIEndpointFactory.register_endpoint("vllm", VLLMEndpoint)
|
| 656 |
+
except ImportError:
|
| 657 |
+
logger.debug("VLLM endpoint not available")
|
| 658 |
+
|
| 659 |
+
# Register visual AI endpoints
|
| 660 |
+
try:
|
| 661 |
+
from .yolo_endpoint import YOLOEndpoint
|
| 662 |
+
AIEndpointFactory.register_endpoint("yolo", YOLOEndpoint)
|
| 663 |
+
except ImportError:
|
| 664 |
+
logger.debug("YOLO endpoint not available (ultralytics not installed)")
|
| 665 |
+
|
| 666 |
+
try:
|
| 667 |
+
from .ollama_vision_endpoint import OllamaVisionEndpoint
|
| 668 |
+
AIEndpointFactory.register_endpoint("ollama_vision", OllamaVisionEndpoint)
|
| 669 |
+
except ImportError:
|
| 670 |
+
logger.debug("Ollama Vision endpoint not available")
|
| 671 |
+
|
| 672 |
+
try:
|
| 673 |
+
from .openai_vision_endpoint import OpenAIVisionEndpoint
|
| 674 |
+
AIEndpointFactory.register_endpoint("openai_vision", OpenAIVisionEndpoint)
|
| 675 |
+
except ImportError:
|
| 676 |
+
logger.debug("OpenAI Vision endpoint not available")
|
| 677 |
+
|
| 678 |
+
try:
|
| 679 |
+
from .anthropic_vision_endpoint import AnthropicVisionEndpoint
|
| 680 |
+
AIEndpointFactory.register_endpoint("anthropic_vision", AnthropicVisionEndpoint)
|
| 681 |
+
except ImportError:
|
| 682 |
+
logger.debug("Anthropic Vision endpoint not available")
|
| 683 |
+
|
| 684 |
+
try:
|
| 685 |
+
from .openrouter_endpoint import OpenRouterEndpoint
|
| 686 |
+
AIEndpointFactory.register_endpoint("openrouter", OpenRouterEndpoint)
|
| 687 |
+
except ImportError:
|
| 688 |
+
logger.debug("OpenRouter endpoint not available")
|
potato/ai/ai_help_wrapper.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import render_template_string
|
| 2 |
+
from typing import Optional, Dict, Any, List
|
| 3 |
+
from potato.ai.ai_cache import get_ai_cache_manager, _is_image_url, _get_instance_text
|
| 4 |
+
from potato.ai.ai_prompt import get_ai_prompt
|
| 5 |
+
from potato.server_utils.config_module import config
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
# Global instance
|
| 11 |
+
DYNAMICAIHELP = None
|
| 12 |
+
|
| 13 |
+
def init_dynamic_ai_help():
|
| 14 |
+
import logging
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
logger.info(f"[init_dynamic_ai_help] Called. ai_support.enabled={config.get('ai_support', {}).get('enabled', False)}")
|
| 17 |
+
|
| 18 |
+
if not config["ai_support"]["enabled"]:
|
| 19 |
+
logger.info("[init_dynamic_ai_help] AI support disabled, returning")
|
| 20 |
+
return
|
| 21 |
+
global DYNAMICAIHELP
|
| 22 |
+
if DYNAMICAIHELP is None:
|
| 23 |
+
DYNAMICAIHELP = DynamicAIHelp()
|
| 24 |
+
logger.info(f"[init_dynamic_ai_help] Created DYNAMICAIHELP instance: {id(DYNAMICAIHELP)}")
|
| 25 |
+
else:
|
| 26 |
+
logger.info(f"[init_dynamic_ai_help] DYNAMICAIHELP already exists: {id(DYNAMICAIHELP)}")
|
| 27 |
+
|
| 28 |
+
return DYNAMICAIHELP
|
| 29 |
+
|
| 30 |
+
def get_dynamic_ai_help():
|
| 31 |
+
global DYNAMICAIHELP
|
| 32 |
+
return DYNAMICAIHELP
|
| 33 |
+
|
| 34 |
+
class DynamicAIHelp:
|
| 35 |
+
def __init__(self):
|
| 36 |
+
self.template = """
|
| 37 |
+
{% if ai_assistant %}
|
| 38 |
+
{{ ai_assistant | safe }}
|
| 39 |
+
{% elif error_message %}
|
| 40 |
+
<span class="error">{{ error_message }}</span>
|
| 41 |
+
{% endif %}
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def get_empty_wrapper(self):
|
| 45 |
+
return f'<div class="ai-help none"><div class="tooltip"></div></div>'
|
| 46 |
+
|
| 47 |
+
def generate_ai_assistant(self, ai_prompts, annotation_type, ai_assistant):
|
| 48 |
+
str_html = f'<div class="{ai_assistant} ai-assistant-containter">'
|
| 49 |
+
img_url = ai_prompts[annotation_type].get(ai_assistant).get("img")
|
| 50 |
+
if img_url:
|
| 51 |
+
# Use empty alt since the button already has a text label
|
| 52 |
+
str_html += f'<span class="ai-assistant-img"><img src="{img_url}" alt=""></span>'
|
| 53 |
+
name = ai_prompts[annotation_type].get(ai_assistant).get("name", ai_assistant.capitalize())
|
| 54 |
+
str_html += f'<span>{name}</span>'
|
| 55 |
+
str_html += "</div>"
|
| 56 |
+
return str_html
|
| 57 |
+
|
| 58 |
+
def _filter_assistants_by_capability(
|
| 59 |
+
self, ai_cache_manager, assistant_keys: List[str], is_image_content: bool
|
| 60 |
+
) -> List[str]:
|
| 61 |
+
"""
|
| 62 |
+
Filter assistant types based on model capabilities.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
ai_cache_manager: The AI cache manager instance
|
| 66 |
+
assistant_keys: List of assistant type keys to filter
|
| 67 |
+
is_image_content: Whether the current content is an image
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Filtered list of assistant keys that the model supports
|
| 71 |
+
"""
|
| 72 |
+
# Get capabilities from the cache manager
|
| 73 |
+
capabilities = ai_cache_manager.get_endpoint_capabilities(for_image=is_image_content)
|
| 74 |
+
|
| 75 |
+
filtered_keys = []
|
| 76 |
+
for key in assistant_keys:
|
| 77 |
+
if capabilities.supports_assistant(key, is_image_content):
|
| 78 |
+
filtered_keys.append(key)
|
| 79 |
+
else:
|
| 80 |
+
logger.debug(
|
| 81 |
+
f"[get_ai_help_data] Skipping '{key}' button - "
|
| 82 |
+
f"not supported for {'image' if is_image_content else 'text'} content"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
return filtered_keys
|
| 86 |
+
|
| 87 |
+
def get_ai_help_data(self, instance: int, annotation_id: int, annotation_type: str) -> Dict[str, Any]:
|
| 88 |
+
"""Get current AI help configuration with the new prompt structure"""
|
| 89 |
+
try:
|
| 90 |
+
context = {
|
| 91 |
+
'ai_assistant': None,
|
| 92 |
+
'error_message': None,
|
| 93 |
+
}
|
| 94 |
+
ai_prompts = get_ai_prompt()
|
| 95 |
+
logger.debug(f"[get_ai_help_data] ai_prompts keys: {list(ai_prompts.keys()) if ai_prompts else 'None'}")
|
| 96 |
+
|
| 97 |
+
if not ai_prompts:
|
| 98 |
+
context["error_message"] = f'No AI prompt configured'
|
| 99 |
+
logger.debug("[get_ai_help_data] No AI prompts configured")
|
| 100 |
+
return context
|
| 101 |
+
elif annotation_type not in ai_prompts:
|
| 102 |
+
context["error_message"] = f'annotation type {annotation_type} does not exist in ai_prompts'
|
| 103 |
+
logger.debug(f"[get_ai_help_data] annotation type {annotation_type} not in prompts")
|
| 104 |
+
return context
|
| 105 |
+
|
| 106 |
+
ai_cache_manager = get_ai_cache_manager()
|
| 107 |
+
logger.debug(f"[get_ai_help_data] ai_cache_manager: {ai_cache_manager is not None}")
|
| 108 |
+
|
| 109 |
+
if ai_cache_manager is None:
|
| 110 |
+
context["error_message"] = "AI cache manager not initialized"
|
| 111 |
+
logger.debug("[get_ai_help_data] AI cache manager is None")
|
| 112 |
+
return context
|
| 113 |
+
|
| 114 |
+
ai_assistant_html_parts = []
|
| 115 |
+
|
| 116 |
+
# Determine if content is an image for capability-based filtering
|
| 117 |
+
is_image_content = False
|
| 118 |
+
try:
|
| 119 |
+
text = _get_instance_text(instance)
|
| 120 |
+
is_image_content = _is_image_url(text)
|
| 121 |
+
if is_image_content:
|
| 122 |
+
logger.debug(f"[get_ai_help_data] Content is an image URL")
|
| 123 |
+
except Exception as e:
|
| 124 |
+
logger.debug(f"[get_ai_help_data] Could not determine if content is image: {e}")
|
| 125 |
+
|
| 126 |
+
# Check if user specified specific assistant types
|
| 127 |
+
special_include_types = ai_cache_manager.get_special_include(instance, annotation_id)
|
| 128 |
+
logger.debug(f"[get_ai_help_data] special_include_types: {special_include_types}")
|
| 129 |
+
|
| 130 |
+
if special_include_types:
|
| 131 |
+
# Generate HTML for specific included keys
|
| 132 |
+
logger.debug(f"[get_ai_help_data] Using special include types: {special_include_types}")
|
| 133 |
+
# Filter by capability
|
| 134 |
+
valid_keys = [k for k in special_include_types if k in ai_prompts[annotation_type]]
|
| 135 |
+
filtered_keys = self._filter_assistants_by_capability(
|
| 136 |
+
ai_cache_manager, valid_keys, is_image_content
|
| 137 |
+
)
|
| 138 |
+
for key in filtered_keys:
|
| 139 |
+
ai_assistant_html_parts.append(self.generate_ai_assistant(ai_prompts, annotation_type, key))
|
| 140 |
+
|
| 141 |
+
elif ai_cache_manager.get_include_all():
|
| 142 |
+
# Generate HTML for all keys in the annotation type
|
| 143 |
+
all_keys = list(ai_prompts[annotation_type].keys())
|
| 144 |
+
logger.debug(f"[get_ai_help_data] include_all=True, available keys: {all_keys}")
|
| 145 |
+
# Filter by capability
|
| 146 |
+
filtered_keys = self._filter_assistants_by_capability(
|
| 147 |
+
ai_cache_manager, all_keys, is_image_content
|
| 148 |
+
)
|
| 149 |
+
logger.debug(f"[get_ai_help_data] After capability filter: {filtered_keys}")
|
| 150 |
+
for key in filtered_keys:
|
| 151 |
+
ai_assistant_html_parts.append(self.generate_ai_assistant(ai_prompts, annotation_type, key))
|
| 152 |
+
else:
|
| 153 |
+
logger.debug("[get_ai_help_data] No special includes and include_all=False")
|
| 154 |
+
|
| 155 |
+
# Combine all HTML parts
|
| 156 |
+
ai_assistant_html = '<span>|</span>'.join(ai_assistant_html_parts) if ai_assistant_html_parts else None
|
| 157 |
+
logger.debug(f"[get_ai_help_data] ai_assistant_html_parts count: {len(ai_assistant_html_parts)}")
|
| 158 |
+
|
| 159 |
+
if ai_assistant_html:
|
| 160 |
+
context['ai_assistant'] = ai_assistant_html
|
| 161 |
+
|
| 162 |
+
logger.debug(f"[get_ai_help_data] Final context: ai_assistant={'set' if context['ai_assistant'] else 'None'}, error={'set' if context['error_message'] else 'None'}")
|
| 163 |
+
return context
|
| 164 |
+
except Exception as e:
|
| 165 |
+
logger.error(f"[get_ai_help_data] Exception: {e}", exc_info=True)
|
| 166 |
+
return {
|
| 167 |
+
'ai_assistant': None,
|
| 168 |
+
'error_message': f'Error loading AI help: {str(e)}',
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
def render(self, instance: int, annotation_id: int, annotation_type) -> str:
|
| 172 |
+
"""Render AI help HTML with current data"""
|
| 173 |
+
context = self.get_ai_help_data(instance, annotation_id, annotation_type)
|
| 174 |
+
context.update({
|
| 175 |
+
'instance': instance,
|
| 176 |
+
'annotation_id': annotation_id
|
| 177 |
+
})
|
| 178 |
+
return render_template_string(self.template, **context)
|
| 179 |
+
|
| 180 |
+
def generate_ai_help_html(instance: int, annotation_id: int, annotation_type: str) -> Optional[str]:
|
| 181 |
+
"""
|
| 182 |
+
Generates dynamic AI help HTML using template rendering.
|
| 183 |
+
Now works with the new prompt structure: {annotation_type: {prompt: ..., outputformat: ...}}
|
| 184 |
+
"""
|
| 185 |
+
import logging
|
| 186 |
+
logger = logging.getLogger(__name__)
|
| 187 |
+
|
| 188 |
+
if DYNAMICAIHELP is None:
|
| 189 |
+
logger.debug("[generate_ai_help_html] DYNAMICAIHELP is None - AI support not enabled")
|
| 190 |
+
return "" # AI support not enabled
|
| 191 |
+
|
| 192 |
+
result = DYNAMICAIHELP.render(instance, annotation_id, annotation_type)
|
| 193 |
+
logger.debug(f"[generate_ai_help_html] Rendered result: '{result[:100] if result else 'empty'}...'")
|
| 194 |
+
return result
|
| 195 |
+
|
| 196 |
+
def get_ai_wrapper():
|
| 197 |
+
import logging
|
| 198 |
+
logger = logging.getLogger(__name__)
|
| 199 |
+
helper = get_dynamic_ai_help()
|
| 200 |
+
logger.debug(f"[get_ai_wrapper] DYNAMICAIHELP is {'set' if helper else 'None'}")
|
| 201 |
+
result = helper.get_empty_wrapper() if helper else ""
|
| 202 |
+
logger.debug(f"[get_ai_wrapper] Returning: '{result[:50] if result else 'empty'}...'")
|
| 203 |
+
return result
|
potato/ai/ai_prompt.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional, Type
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from potato.server_utils.config_module import config
|
| 8 |
+
ANNOTATIONS = None
|
| 9 |
+
|
| 10 |
+
class ModelManager:
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.models_module = None
|
| 13 |
+
|
| 14 |
+
def load_models_module(self):
|
| 15 |
+
"""Load the models module if not already loaded"""
|
| 16 |
+
if self.models_module is None:
|
| 17 |
+
# absolute pathing
|
| 18 |
+
module_path = config.get("ai_support").get("model_module")
|
| 19 |
+
if module_path:
|
| 20 |
+
file_path = Path(module_path)
|
| 21 |
+
if not file_path.exists():
|
| 22 |
+
raise FileNotFoundError(f"Model module file not found: {file_path}")
|
| 23 |
+
module_name = file_path.stem
|
| 24 |
+
|
| 25 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 26 |
+
self.models_module = importlib.util.module_from_spec(spec)
|
| 27 |
+
spec.loader.exec_module(self.models_module)
|
| 28 |
+
|
| 29 |
+
else:
|
| 30 |
+
default_path = Path(__file__).resolve().parent / "prompt" / "models_module.py"
|
| 31 |
+
|
| 32 |
+
if not default_path.exists():
|
| 33 |
+
raise FileNotFoundError(f"Default model module file not found: {default_path}")
|
| 34 |
+
|
| 35 |
+
module_name = default_path.stem
|
| 36 |
+
|
| 37 |
+
spec = importlib.util.spec_from_file_location(module_name, default_path)
|
| 38 |
+
self.models_module = importlib.util.module_from_spec(spec)
|
| 39 |
+
spec.loader.exec_module(self.models_module)
|
| 40 |
+
|
| 41 |
+
return self.models_module
|
| 42 |
+
|
| 43 |
+
def get_model_class_by_name(self, name: str) -> Optional[Type[BaseModel]]:
|
| 44 |
+
"""
|
| 45 |
+
Return a Pydantic model class based on the provided name.
|
| 46 |
+
"""
|
| 47 |
+
models_module = self.load_models_module()
|
| 48 |
+
return models_module.CLASS_REGISTRY.get(name)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def init_ai_prompt(config):
|
| 52 |
+
global ANNOTATIONS
|
| 53 |
+
if not config["ai_support"]["enabled"]:
|
| 54 |
+
return
|
| 55 |
+
try:
|
| 56 |
+
annotation_paths = config.get("ai_support", {}).get("annotation_path")
|
| 57 |
+
|
| 58 |
+
ANNOTATIONS = {}
|
| 59 |
+
|
| 60 |
+
if annotation_paths:
|
| 61 |
+
# Load files from specified paths
|
| 62 |
+
for key, path in annotation_paths.items():
|
| 63 |
+
if path and os.path.exists(path):
|
| 64 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 65 |
+
ANNOTATIONS[key] = json.load(f)
|
| 66 |
+
else:
|
| 67 |
+
raise Exception(f"File path for annotations does not exist: {path}")
|
| 68 |
+
else:
|
| 69 |
+
# Load all JSON files from default directory (parent/prompt)
|
| 70 |
+
default_path = Path(__file__).resolve().parent / "prompt"
|
| 71 |
+
|
| 72 |
+
if default_path.exists() and default_path.is_dir():
|
| 73 |
+
# Find all JSON files in the directory
|
| 74 |
+
json_files = list(default_path.glob("*.json"))
|
| 75 |
+
|
| 76 |
+
if not json_files:
|
| 77 |
+
raise Exception(f"No JSON files found in default directory: {default_path}")
|
| 78 |
+
|
| 79 |
+
# Load each JSON file, using filename (without extension) as key
|
| 80 |
+
for file_path in json_files:
|
| 81 |
+
key = file_path.stem
|
| 82 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 83 |
+
ANNOTATIONS[key] = json.load(f)
|
| 84 |
+
else:
|
| 85 |
+
raise Exception(f"Default annotation directory does not exist: {default_path}")
|
| 86 |
+
|
| 87 |
+
except json.JSONDecodeError as e:
|
| 88 |
+
raise ValueError(f"Invalid JSON in annotation file: {e}")
|
| 89 |
+
except Exception as e:
|
| 90 |
+
raise RuntimeError(f"Unexpected error loading AI prompt: {e}")
|
| 91 |
+
|
| 92 |
+
def get_ai_prompt():
|
| 93 |
+
global ANNOTATIONS
|
| 94 |
+
return ANNOTATIONS
|
potato/ai/anthropic_endpoint.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anthropic AI endpoint implementation.
|
| 3 |
+
|
| 4 |
+
This module provides integration with Anthropic's Claude API for LLM inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
import anthropic
|
| 9 |
+
from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError, ModelCapabilities
|
| 10 |
+
|
| 11 |
+
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
| 12 |
+
DEFAULT_HINT_PROMPT = '''
|
| 13 |
+
You are assisting a user with an annotation task.
|
| 14 |
+
The annotation instruction is : {description}
|
| 15 |
+
The annotation task type is: {annotation_type}
|
| 16 |
+
The sentence (or item) to annotate is : {text}
|
| 17 |
+
Your goal is to generate a short, helpful hint that guides the annotator in how to think about the input — **without providing the answer**.
|
| 18 |
+
|
| 19 |
+
The hint should:
|
| 20 |
+
- Highlight key aspects of the input relevant to the task
|
| 21 |
+
- Encourage thoughtful reasoning or observation
|
| 22 |
+
- Point to subtle features (tone, wording, structure, implication) that matter for the annotation
|
| 23 |
+
- Be specific and informative, not vague or generic
|
| 24 |
+
'''
|
| 25 |
+
|
| 26 |
+
DEFAULT_KEYWORD_PROMPT = '''
|
| 27 |
+
You are assisting a user with an annotation task.
|
| 28 |
+
The annotation instruction is : {description}
|
| 29 |
+
The annotation task type is: {annotation_type}
|
| 30 |
+
The sentence (or item) to annotate is : {text}
|
| 31 |
+
Your goal is : Print out just a sequence of keywords, not sentences, in the text that most relate to the task. Do not explain your answer. Do not print out the entire text. If no part of the text relates to the task, print the empty string.
|
| 32 |
+
'''
|
| 33 |
+
|
| 34 |
+
class AnthropicEndpoint(BaseAIEndpoint):
|
| 35 |
+
"""Anthropic Claude endpoint for cloud-based LLM inference."""
|
| 36 |
+
|
| 37 |
+
# Capabilities declaration for text-based Anthropic Claude models
|
| 38 |
+
CAPABILITIES = ModelCapabilities(
|
| 39 |
+
text_generation=True,
|
| 40 |
+
vision_input=False,
|
| 41 |
+
bounding_box_output=False,
|
| 42 |
+
text_classification=True,
|
| 43 |
+
image_classification=False,
|
| 44 |
+
rationale_generation=True,
|
| 45 |
+
keyword_extraction=True,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def _initialize_client(self) -> None:
|
| 49 |
+
"""Initialize the Anthropic client."""
|
| 50 |
+
api_key = self.ai_config.get("api_key", "")
|
| 51 |
+
if not api_key:
|
| 52 |
+
raise AIEndpointRequestError("Anthropic API key is required")
|
| 53 |
+
|
| 54 |
+
# Default timeout of 30 seconds, configurable via ai_config
|
| 55 |
+
timeout = self.ai_config.get("timeout", 30)
|
| 56 |
+
self.client = anthropic.Anthropic(api_key=api_key, timeout=timeout)
|
| 57 |
+
|
| 58 |
+
def _get_default_model(self) -> str:
|
| 59 |
+
"""Get the default Anthropic model."""
|
| 60 |
+
return DEFAULT_MODEL
|
| 61 |
+
|
| 62 |
+
def _get_default_hint_prompt(self) -> str:
|
| 63 |
+
"""Get the default hint prompt for Anthropic."""
|
| 64 |
+
return DEFAULT_HINT_PROMPT
|
| 65 |
+
|
| 66 |
+
def _get_default_keyword_prompt(self) -> str:
|
| 67 |
+
"""Get the default keyword prompt for Anthropic."""
|
| 68 |
+
return DEFAULT_KEYWORD_PROMPT
|
| 69 |
+
|
| 70 |
+
def query(self, prompt: str) -> str:
|
| 71 |
+
"""
|
| 72 |
+
Send a query to Anthropic Claude and return the response.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
prompt: The prompt to send to the model
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
The model's response as a string
|
| 79 |
+
|
| 80 |
+
Raises:
|
| 81 |
+
AIEndpointRequestError: If the request fails
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
response = self.client.messages.create(
|
| 85 |
+
model=self.model,
|
| 86 |
+
max_tokens=self.max_tokens,
|
| 87 |
+
temperature=self.temperature,
|
| 88 |
+
messages=[{"role": "user", "content": prompt}]
|
| 89 |
+
)
|
| 90 |
+
return response.content[0].text
|
| 91 |
+
except Exception as e:
|
| 92 |
+
raise AIEndpointRequestError(f"Anthropic request failed: {e}")
|
| 93 |
+
|
| 94 |
+
def chat_query(self, messages: List[Dict[str, str]]) -> str:
|
| 95 |
+
"""Send a multi-turn chat to Anthropic using native messages API."""
|
| 96 |
+
try:
|
| 97 |
+
# Extract system message if present
|
| 98 |
+
system_text = ""
|
| 99 |
+
chat_messages = []
|
| 100 |
+
for msg in messages:
|
| 101 |
+
if msg["role"] == "system":
|
| 102 |
+
system_text = msg["content"]
|
| 103 |
+
else:
|
| 104 |
+
chat_messages.append({"role": msg["role"], "content": msg["content"]})
|
| 105 |
+
|
| 106 |
+
kwargs = {
|
| 107 |
+
"model": self.model,
|
| 108 |
+
"max_tokens": self.max_tokens,
|
| 109 |
+
"temperature": self.temperature,
|
| 110 |
+
"messages": chat_messages,
|
| 111 |
+
}
|
| 112 |
+
if system_text:
|
| 113 |
+
kwargs["system"] = system_text
|
| 114 |
+
|
| 115 |
+
response = self.client.messages.create(**kwargs)
|
| 116 |
+
return response.content[0].text
|
| 117 |
+
except Exception as e:
|
| 118 |
+
raise AIEndpointRequestError(f"Anthropic chat request failed: {e}")
|
potato/ai/anthropic_vision_endpoint.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anthropic Vision AI Endpoint
|
| 3 |
+
|
| 4 |
+
This module provides integration with Anthropic's Claude models for visual
|
| 5 |
+
analysis using the image content block format.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Any, Dict, List, Type, Union
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
|
| 14 |
+
from .ai_endpoint import AIEndpointRequestError, ImageData, ModelCapabilities
|
| 15 |
+
from .visual_ai_endpoint import BaseVisualAIEndpoint
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
DEFAULT_MODEL = "claude-sonnet-4-20250514"
|
| 20 |
+
|
| 21 |
+
# Supported image types for Claude
|
| 22 |
+
SUPPORTED_MEDIA_TYPES = [
|
| 23 |
+
"image/jpeg",
|
| 24 |
+
"image/png",
|
| 25 |
+
"image/gif",
|
| 26 |
+
"image/webp"
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AnthropicVisionEndpoint(BaseVisualAIEndpoint):
|
| 31 |
+
"""
|
| 32 |
+
Anthropic Vision endpoint for Claude models with vision capabilities.
|
| 33 |
+
|
| 34 |
+
Uses the image content block format for multimodal inputs.
|
| 35 |
+
|
| 36 |
+
Configuration options:
|
| 37 |
+
- model: Model to use (default: claude-sonnet-4-20250514)
|
| 38 |
+
- api_key: Anthropic API key (can also use ANTHROPIC_API_KEY env var)
|
| 39 |
+
- max_tokens: Maximum response tokens (default: 1024)
|
| 40 |
+
- temperature: Sampling temperature (default: 0.1)
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
# Capabilities declaration for Anthropic Claude vision models
|
| 44 |
+
# Claude models can understand images and generate detailed reasoning but bboxes are approximate
|
| 45 |
+
CAPABILITIES = ModelCapabilities(
|
| 46 |
+
text_generation=True,
|
| 47 |
+
vision_input=True,
|
| 48 |
+
bounding_box_output=False, # Claude bboxes are approximate, not precise
|
| 49 |
+
text_classification=True,
|
| 50 |
+
image_classification=True,
|
| 51 |
+
rationale_generation=True,
|
| 52 |
+
keyword_extraction=False, # Keywords don't apply to images
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def _initialize_client(self) -> None:
|
| 56 |
+
"""Initialize the Anthropic client."""
|
| 57 |
+
try:
|
| 58 |
+
import anthropic
|
| 59 |
+
except ImportError:
|
| 60 |
+
raise AIEndpointRequestError(
|
| 61 |
+
"anthropic package is required. Install it with: pip install anthropic"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
import os
|
| 65 |
+
|
| 66 |
+
api_key = self.ai_config.get("api_key") or os.environ.get("ANTHROPIC_API_KEY")
|
| 67 |
+
if not api_key:
|
| 68 |
+
raise AIEndpointRequestError(
|
| 69 |
+
"Anthropic API key is required. Set it in config or ANTHROPIC_API_KEY env var."
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
timeout = self.ai_config.get("timeout", 60)
|
| 73 |
+
|
| 74 |
+
self.client = anthropic.Anthropic(api_key=api_key, timeout=timeout)
|
| 75 |
+
logger.info(f"Anthropic Vision client initialized with model: {self.model}")
|
| 76 |
+
|
| 77 |
+
def _get_default_model(self) -> str:
|
| 78 |
+
"""Get the default Anthropic model."""
|
| 79 |
+
return DEFAULT_MODEL
|
| 80 |
+
|
| 81 |
+
def query(self, prompt: str, output_format: Type[BaseModel]) -> Any:
|
| 82 |
+
"""
|
| 83 |
+
Standard text query without images.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
prompt: Text prompt
|
| 87 |
+
output_format: Pydantic model for structured output
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
Parsed response
|
| 91 |
+
"""
|
| 92 |
+
try:
|
| 93 |
+
# Add JSON instruction to prompt
|
| 94 |
+
json_prompt = f"""{prompt}
|
| 95 |
+
|
| 96 |
+
Please respond with valid JSON matching this schema:
|
| 97 |
+
{output_format.model_json_schema()}"""
|
| 98 |
+
|
| 99 |
+
response = self.client.messages.create(
|
| 100 |
+
model=self.model,
|
| 101 |
+
max_tokens=self.max_tokens,
|
| 102 |
+
messages=[{"role": "user", "content": json_prompt}],
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
content = response.content[0].text
|
| 106 |
+
return self.parseStringToJson(content)
|
| 107 |
+
|
| 108 |
+
except Exception as e:
|
| 109 |
+
raise AIEndpointRequestError(f"Anthropic query failed: {e}")
|
| 110 |
+
|
| 111 |
+
def query_with_image(
|
| 112 |
+
self,
|
| 113 |
+
prompt: str,
|
| 114 |
+
image_data: Union[ImageData, List[ImageData]],
|
| 115 |
+
output_format: Type[BaseModel]
|
| 116 |
+
) -> Any:
|
| 117 |
+
"""
|
| 118 |
+
Send a query with image(s) to Claude vision model.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
prompt: Text prompt describing what to analyze
|
| 122 |
+
image_data: Single ImageData or list of ImageData
|
| 123 |
+
output_format: Pydantic model for structured output
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Parsed response according to output_format
|
| 127 |
+
|
| 128 |
+
Raises:
|
| 129 |
+
AIEndpointRequestError: If the request fails
|
| 130 |
+
"""
|
| 131 |
+
try:
|
| 132 |
+
# Prepare images
|
| 133 |
+
images = [image_data] if isinstance(image_data, ImageData) else image_data
|
| 134 |
+
|
| 135 |
+
# Build content array with images first, then text
|
| 136 |
+
content = []
|
| 137 |
+
|
| 138 |
+
for img in images:
|
| 139 |
+
image_block = self._build_image_block(img)
|
| 140 |
+
content.append(image_block)
|
| 141 |
+
|
| 142 |
+
# Add JSON instruction to prompt
|
| 143 |
+
json_prompt = f"""{prompt}
|
| 144 |
+
|
| 145 |
+
Please respond with valid JSON matching this schema:
|
| 146 |
+
{output_format.model_json_schema()}
|
| 147 |
+
|
| 148 |
+
Only return the JSON object, no other text."""
|
| 149 |
+
|
| 150 |
+
content.append({"type": "text", "text": json_prompt})
|
| 151 |
+
|
| 152 |
+
# Make request
|
| 153 |
+
response = self.client.messages.create(
|
| 154 |
+
model=self.model,
|
| 155 |
+
max_tokens=self.max_tokens,
|
| 156 |
+
messages=[{"role": "user", "content": content}],
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
response_content = response.content[0].text
|
| 160 |
+
logger.debug(f"Anthropic vision response: {response_content[:500] if response_content else 'empty'}")
|
| 161 |
+
|
| 162 |
+
return self.parseStringToJson(response_content)
|
| 163 |
+
|
| 164 |
+
except AIEndpointRequestError:
|
| 165 |
+
raise
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.error(f"Anthropic vision query failed: {e}")
|
| 168 |
+
import traceback
|
| 169 |
+
logger.error(traceback.format_exc())
|
| 170 |
+
raise AIEndpointRequestError(f"Anthropic vision query failed: {e}")
|
| 171 |
+
|
| 172 |
+
def chat_query_with_image(
|
| 173 |
+
self,
|
| 174 |
+
messages: List[Dict[str, Any]],
|
| 175 |
+
images: Any = None,
|
| 176 |
+
) -> str:
|
| 177 |
+
"""
|
| 178 |
+
Multi-turn chat with interleaved images for vision-based agent loops.
|
| 179 |
+
|
| 180 |
+
Messages may have 'content' as a string (text only) or a list of
|
| 181 |
+
content blocks (text + image dicts in Anthropic format).
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
messages: List of message dicts with 'role' and 'content'.
|
| 185 |
+
images: Unused (images are inline in messages).
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
The model's response as a plain text string.
|
| 189 |
+
"""
|
| 190 |
+
try:
|
| 191 |
+
system = ""
|
| 192 |
+
api_messages = []
|
| 193 |
+
|
| 194 |
+
for msg in messages:
|
| 195 |
+
if msg["role"] == "system":
|
| 196 |
+
system = msg["content"] if isinstance(msg["content"], str) else str(msg["content"])
|
| 197 |
+
else:
|
| 198 |
+
api_messages.append({
|
| 199 |
+
"role": msg["role"],
|
| 200 |
+
"content": msg["content"],
|
| 201 |
+
})
|
| 202 |
+
|
| 203 |
+
kwargs = {
|
| 204 |
+
"model": self.model,
|
| 205 |
+
"max_tokens": self.max_tokens,
|
| 206 |
+
"temperature": self.temperature,
|
| 207 |
+
"messages": api_messages,
|
| 208 |
+
}
|
| 209 |
+
if system:
|
| 210 |
+
kwargs["system"] = system
|
| 211 |
+
|
| 212 |
+
response = self.client.messages.create(**kwargs)
|
| 213 |
+
return response.content[0].text
|
| 214 |
+
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.error(f"Anthropic vision chat query failed: {e}")
|
| 217 |
+
raise AIEndpointRequestError(f"Anthropic vision chat query failed: {e}")
|
| 218 |
+
|
| 219 |
+
def _build_image_block(self, image_data: ImageData) -> Dict[str, Any]:
|
| 220 |
+
"""
|
| 221 |
+
Build image content block for Anthropic API.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
image_data: ImageData object
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
Dict with type: "image" and source content
|
| 228 |
+
"""
|
| 229 |
+
if image_data.source == "url":
|
| 230 |
+
# Claude supports URL sources directly
|
| 231 |
+
return {
|
| 232 |
+
"type": "image",
|
| 233 |
+
"source": {
|
| 234 |
+
"type": "url",
|
| 235 |
+
"url": image_data.data
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
elif image_data.source == "base64":
|
| 240 |
+
# Determine media type
|
| 241 |
+
media_type = image_data.mime_type or "image/jpeg"
|
| 242 |
+
|
| 243 |
+
# Validate media type
|
| 244 |
+
if media_type not in SUPPORTED_MEDIA_TYPES:
|
| 245 |
+
logger.warning(f"Media type {media_type} may not be supported. Using image/jpeg.")
|
| 246 |
+
media_type = "image/jpeg"
|
| 247 |
+
|
| 248 |
+
return {
|
| 249 |
+
"type": "image",
|
| 250 |
+
"source": {
|
| 251 |
+
"type": "base64",
|
| 252 |
+
"media_type": media_type,
|
| 253 |
+
"data": image_data.data
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
else:
|
| 258 |
+
raise AIEndpointRequestError(f"Unknown image source: {image_data.source}")
|
| 259 |
+
|
| 260 |
+
def analyze_image(
|
| 261 |
+
self,
|
| 262 |
+
image_path_or_url: str,
|
| 263 |
+
prompt: str,
|
| 264 |
+
output_format: Type[BaseModel] = None
|
| 265 |
+
) -> Any:
|
| 266 |
+
"""
|
| 267 |
+
Convenience method for analyzing a single image.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
image_path_or_url: Path to image file or URL
|
| 271 |
+
prompt: Analysis prompt
|
| 272 |
+
output_format: Optional output format model
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
Analysis result
|
| 276 |
+
"""
|
| 277 |
+
# Prepare image data
|
| 278 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 279 |
+
# Claude can use URLs directly
|
| 280 |
+
image_data = self.create_url_image_data(image_path_or_url)
|
| 281 |
+
else:
|
| 282 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 283 |
+
|
| 284 |
+
# Use a generic format if not specified
|
| 285 |
+
if output_format is None:
|
| 286 |
+
from .prompt.models_module import GeneralHintFormat
|
| 287 |
+
output_format = GeneralHintFormat
|
| 288 |
+
|
| 289 |
+
return self.query_with_image(prompt, image_data, output_format)
|
| 290 |
+
|
| 291 |
+
def detect_objects(
|
| 292 |
+
self,
|
| 293 |
+
image_path_or_url: str,
|
| 294 |
+
labels: List[str] = None
|
| 295 |
+
) -> Dict[str, Any]:
|
| 296 |
+
"""
|
| 297 |
+
Detect objects in an image and return bounding boxes.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
image_path_or_url: Path to image file or URL
|
| 301 |
+
labels: Optional list of labels to detect
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
Dict with detections list
|
| 305 |
+
"""
|
| 306 |
+
from .prompt.models_module import VisualDetectionFormat
|
| 307 |
+
|
| 308 |
+
labels_str = ", ".join(labels) if labels else "all visible objects"
|
| 309 |
+
|
| 310 |
+
prompt = f"""Analyze this image and detect objects. For each object, provide:
|
| 311 |
+
1. The label (from: {labels_str})
|
| 312 |
+
2. A bounding box with normalized coordinates (0-1 range)
|
| 313 |
+
3. Confidence score (0-1)
|
| 314 |
+
|
| 315 |
+
Return a JSON object with this exact structure:
|
| 316 |
+
{{
|
| 317 |
+
"detections": [
|
| 318 |
+
{{
|
| 319 |
+
"label": "object_name",
|
| 320 |
+
"bbox": {{"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.4}},
|
| 321 |
+
"confidence": 0.95
|
| 322 |
+
}}
|
| 323 |
+
]
|
| 324 |
+
}}
|
| 325 |
+
|
| 326 |
+
Important:
|
| 327 |
+
- Coordinates are normalized (0-1) where x,y is the top-left corner
|
| 328 |
+
- x increases left to right, y increases top to bottom
|
| 329 |
+
- width and height are also normalized (0-1)
|
| 330 |
+
- Only include objects you can clearly identify
|
| 331 |
+
- Estimate bounding boxes as accurately as possible"""
|
| 332 |
+
|
| 333 |
+
# Prepare image
|
| 334 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 335 |
+
image_data = self.create_url_image_data(image_path_or_url)
|
| 336 |
+
else:
|
| 337 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 338 |
+
|
| 339 |
+
return self.query_with_image(prompt, image_data, VisualDetectionFormat)
|
| 340 |
+
|
| 341 |
+
def get_annotation_hint(
|
| 342 |
+
self,
|
| 343 |
+
image_path_or_url: str,
|
| 344 |
+
task_description: str,
|
| 345 |
+
labels: List[str]
|
| 346 |
+
) -> Dict[str, Any]:
|
| 347 |
+
"""
|
| 348 |
+
Get a hint for annotating an image without revealing exact locations.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
image_path_or_url: Path to image file or URL
|
| 352 |
+
task_description: Description of the annotation task
|
| 353 |
+
labels: Available labels
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
Dict with hint text and optional suggested label
|
| 357 |
+
"""
|
| 358 |
+
labels_str = ", ".join(labels)
|
| 359 |
+
|
| 360 |
+
prompt = f"""You are helping an annotator with this task: {task_description}
|
| 361 |
+
|
| 362 |
+
Available labels: {labels_str}
|
| 363 |
+
|
| 364 |
+
Provide a helpful hint that guides the annotator without giving away the exact answer.
|
| 365 |
+
The hint should:
|
| 366 |
+
1. Point out relevant features to consider
|
| 367 |
+
2. Suggest what to look for
|
| 368 |
+
3. Not explicitly state the answer or exact locations
|
| 369 |
+
|
| 370 |
+
Return JSON:
|
| 371 |
+
{{
|
| 372 |
+
"hint": "Your helpful hint here",
|
| 373 |
+
"suggested_focus": "What area or aspect to focus on"
|
| 374 |
+
}}"""
|
| 375 |
+
|
| 376 |
+
# Prepare image
|
| 377 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 378 |
+
image_data = self.create_url_image_data(image_path_or_url)
|
| 379 |
+
else:
|
| 380 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 381 |
+
|
| 382 |
+
class HintFormat(BaseModel):
|
| 383 |
+
hint: str
|
| 384 |
+
suggested_focus: str
|
| 385 |
+
|
| 386 |
+
return self.query_with_image(prompt, image_data, HintFormat)
|
| 387 |
+
|
| 388 |
+
def health_check(self) -> bool:
|
| 389 |
+
"""
|
| 390 |
+
Check if the Anthropic API is accessible.
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
True if API is reachable, False otherwise
|
| 394 |
+
"""
|
| 395 |
+
try:
|
| 396 |
+
# Simple test message
|
| 397 |
+
self.client.messages.create(
|
| 398 |
+
model=self.model,
|
| 399 |
+
max_tokens=10,
|
| 400 |
+
messages=[{"role": "user", "content": "Hello"}]
|
| 401 |
+
)
|
| 402 |
+
return True
|
| 403 |
+
except Exception as e:
|
| 404 |
+
logger.error(f"Anthropic health check failed: {e}")
|
| 405 |
+
return False
|
potato/ai/gemini_endpoint.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Google Gemini AI endpoint implementation.
|
| 3 |
+
|
| 4 |
+
This module provides integration with Google's Gemini API for LLM inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from google import genai
|
| 8 |
+
from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError
|
| 9 |
+
|
| 10 |
+
DEFAULT_MODEL = "gemini-2.0-flash-exp"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GeminiEndpoint(BaseAIEndpoint):
|
| 14 |
+
"""Google Gemini endpoint for cloud-based LLM inference."""
|
| 15 |
+
|
| 16 |
+
def _initialize_client(self) -> None:
|
| 17 |
+
"""Initialize the Gemini client."""
|
| 18 |
+
api_key = self.ai_config.get("api_key", "")
|
| 19 |
+
if not api_key:
|
| 20 |
+
raise AIEndpointRequestError("Gemini API key is required")
|
| 21 |
+
|
| 22 |
+
# Default timeout of 30 seconds, configurable via ai_config
|
| 23 |
+
timeout = self.ai_config.get("timeout", 30)
|
| 24 |
+
self.client = genai.Client(
|
| 25 |
+
api_key=api_key,
|
| 26 |
+
http_options={'timeout': timeout}
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def _get_default_model(self) -> str:
|
| 30 |
+
"""Get the default Gemini model."""
|
| 31 |
+
return DEFAULT_MODEL
|
| 32 |
+
|
| 33 |
+
def query(self, prompt: str, prompt_format: dict) -> str:
|
| 34 |
+
"""
|
| 35 |
+
Send a query to Gemini and return the response.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
prompt: The prompt to send to the model
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
The model's response as a string
|
| 42 |
+
|
| 43 |
+
Raises:
|
| 44 |
+
AIEndpointRequestError: If the request fails
|
| 45 |
+
"""
|
| 46 |
+
try:
|
| 47 |
+
response = self.client.models.generate_content(
|
| 48 |
+
model=self.model,
|
| 49 |
+
contents=prompt,
|
| 50 |
+
generation_config={
|
| 51 |
+
'max_output_tokens': self.max_tokens,
|
| 52 |
+
'temperature': self.temperature,
|
| 53 |
+
'response_schema': prompt_format.model_json_schema(),
|
| 54 |
+
}
|
| 55 |
+
)
|
| 56 |
+
return response.text
|
| 57 |
+
except Exception as e:
|
| 58 |
+
raise AIEndpointRequestError(f"Gemini request failed: {e}")
|
potato/ai/huggingface_endpoint.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face AI endpoint implementation.
|
| 3 |
+
|
| 4 |
+
This module provides integration with Hugging Face's Inference API for LLM inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from huggingface_hub import InferenceClient
|
| 8 |
+
from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError
|
| 9 |
+
|
| 10 |
+
DEFAULT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
|
| 11 |
+
|
| 12 |
+
class HuggingfaceEndpoint(BaseAIEndpoint):
|
| 13 |
+
"""Hugging Face endpoint for cloud-based LLM inference."""
|
| 14 |
+
|
| 15 |
+
def _initialize_client(self) -> None:
|
| 16 |
+
"""Initialize the Hugging Face client."""
|
| 17 |
+
api_key = self.ai_config.get("api_key", "")
|
| 18 |
+
if not api_key:
|
| 19 |
+
raise AIEndpointRequestError("Hugging Face API key is required")
|
| 20 |
+
|
| 21 |
+
# Default timeout of 30 seconds, configurable via ai_config
|
| 22 |
+
timeout = self.ai_config.get("timeout", 30)
|
| 23 |
+
self.client = InferenceClient(
|
| 24 |
+
model=self.model,
|
| 25 |
+
token=api_key,
|
| 26 |
+
timeout=timeout
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def _get_default_model(self) -> str:
|
| 30 |
+
"""Get the default Hugging Face model."""
|
| 31 |
+
return DEFAULT_MODEL
|
| 32 |
+
|
| 33 |
+
def query(self, prompt: str, output_format: dict) -> str:
|
| 34 |
+
"""
|
| 35 |
+
Send a query to Hugging Face and return the response.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
prompt: The prompt to send to the model
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
The model's response as a string
|
| 42 |
+
|
| 43 |
+
Raises:
|
| 44 |
+
AIEndpointRequestError: If the request fails
|
| 45 |
+
"""
|
| 46 |
+
try:
|
| 47 |
+
response = self.client.chat_completion(
|
| 48 |
+
messages=[{"role": "user", "content": prompt}],
|
| 49 |
+
max_tokens=self.max_tokens,
|
| 50 |
+
temperature=self.temperature,
|
| 51 |
+
response_format= {
|
| 52 |
+
"type": "json_schema",
|
| 53 |
+
"json_schema": {
|
| 54 |
+
"name": "output_format",
|
| 55 |
+
"schema": output_format.model_json_schema(),
|
| 56 |
+
"strict": True,
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
)
|
| 60 |
+
return response.choices[0].message.content
|
| 61 |
+
except Exception as e:
|
| 62 |
+
raise AIEndpointRequestError(f"Hugging Face request failed: {e}")
|
potato/ai/icl_labeler.py
ADDED
|
@@ -0,0 +1,1110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In-Context Learning (ICL) Labeler Module
|
| 3 |
+
|
| 4 |
+
This module provides AI-assisted labeling using high-confidence human annotations
|
| 5 |
+
as in-context examples to prompt an LLM to label remaining data.
|
| 6 |
+
|
| 7 |
+
Key features:
|
| 8 |
+
- Identifies high-confidence examples where annotators agree
|
| 9 |
+
- Uses examples as in-context demonstrations for LLM labeling
|
| 10 |
+
- Tracks LLM confidence scores on predictions
|
| 11 |
+
- Routes subset of LLM labels to humans for verification
|
| 12 |
+
- Calculates and reports LLM accuracy based on verification
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
import threading
|
| 20 |
+
import time
|
| 21 |
+
from collections import Counter, defaultdict
|
| 22 |
+
from dataclasses import dataclass, field, asdict
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
from typing import Dict, List, Optional, Any, Tuple, Set
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class HighConfidenceExample:
|
| 31 |
+
"""A human-annotated example suitable for in-context learning."""
|
| 32 |
+
instance_id: str
|
| 33 |
+
text: str
|
| 34 |
+
schema_name: str
|
| 35 |
+
label: str
|
| 36 |
+
agreement_score: float # Proportion of annotators who chose this label
|
| 37 |
+
annotator_count: int
|
| 38 |
+
timestamp: datetime = field(default_factory=datetime.now)
|
| 39 |
+
|
| 40 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 41 |
+
"""Serialize to dictionary."""
|
| 42 |
+
return {
|
| 43 |
+
'instance_id': self.instance_id,
|
| 44 |
+
'text': self.text,
|
| 45 |
+
'schema_name': self.schema_name,
|
| 46 |
+
'label': self.label,
|
| 47 |
+
'agreement_score': self.agreement_score,
|
| 48 |
+
'annotator_count': self.annotator_count,
|
| 49 |
+
'timestamp': self.timestamp.isoformat()
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'HighConfidenceExample':
|
| 54 |
+
"""Deserialize from dictionary."""
|
| 55 |
+
timestamp = data.get('timestamp')
|
| 56 |
+
if isinstance(timestamp, str):
|
| 57 |
+
timestamp = datetime.fromisoformat(timestamp)
|
| 58 |
+
elif timestamp is None:
|
| 59 |
+
timestamp = datetime.now()
|
| 60 |
+
|
| 61 |
+
return cls(
|
| 62 |
+
instance_id=data['instance_id'],
|
| 63 |
+
text=data['text'],
|
| 64 |
+
schema_name=data['schema_name'],
|
| 65 |
+
label=data['label'],
|
| 66 |
+
agreement_score=data['agreement_score'],
|
| 67 |
+
annotator_count=data['annotator_count'],
|
| 68 |
+
timestamp=timestamp
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
|
| 73 |
+
class ICLPrediction:
|
| 74 |
+
"""Record of an LLM prediction using in-context learning."""
|
| 75 |
+
instance_id: str
|
| 76 |
+
schema_name: str
|
| 77 |
+
predicted_label: str
|
| 78 |
+
confidence_score: float # 0.0-1.0
|
| 79 |
+
timestamp: datetime = field(default_factory=datetime.now)
|
| 80 |
+
|
| 81 |
+
# In-context examples used
|
| 82 |
+
example_instance_ids: List[str] = field(default_factory=list)
|
| 83 |
+
|
| 84 |
+
# Verification tracking
|
| 85 |
+
verification_status: str = 'pending' # 'pending', 'verified_correct', 'verified_incorrect'
|
| 86 |
+
verified_by: Optional[str] = None
|
| 87 |
+
verified_at: Optional[datetime] = None
|
| 88 |
+
human_label: Optional[str] = None # Human's label if verified
|
| 89 |
+
|
| 90 |
+
# LLM metadata
|
| 91 |
+
model_name: str = ""
|
| 92 |
+
reasoning: str = ""
|
| 93 |
+
|
| 94 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 95 |
+
"""Serialize to dictionary."""
|
| 96 |
+
return {
|
| 97 |
+
'instance_id': self.instance_id,
|
| 98 |
+
'schema_name': self.schema_name,
|
| 99 |
+
'predicted_label': self.predicted_label,
|
| 100 |
+
'confidence_score': self.confidence_score,
|
| 101 |
+
'timestamp': self.timestamp.isoformat(),
|
| 102 |
+
'example_instance_ids': self.example_instance_ids,
|
| 103 |
+
'verification_status': self.verification_status,
|
| 104 |
+
'verified_by': self.verified_by,
|
| 105 |
+
'verified_at': self.verified_at.isoformat() if self.verified_at else None,
|
| 106 |
+
'human_label': self.human_label,
|
| 107 |
+
'model_name': self.model_name,
|
| 108 |
+
'reasoning': self.reasoning
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'ICLPrediction':
|
| 113 |
+
"""Deserialize from dictionary."""
|
| 114 |
+
timestamp = data.get('timestamp')
|
| 115 |
+
if isinstance(timestamp, str):
|
| 116 |
+
timestamp = datetime.fromisoformat(timestamp)
|
| 117 |
+
elif timestamp is None:
|
| 118 |
+
timestamp = datetime.now()
|
| 119 |
+
|
| 120 |
+
verified_at = data.get('verified_at')
|
| 121 |
+
if isinstance(verified_at, str):
|
| 122 |
+
verified_at = datetime.fromisoformat(verified_at)
|
| 123 |
+
|
| 124 |
+
return cls(
|
| 125 |
+
instance_id=data['instance_id'],
|
| 126 |
+
schema_name=data['schema_name'],
|
| 127 |
+
predicted_label=data['predicted_label'],
|
| 128 |
+
confidence_score=data['confidence_score'],
|
| 129 |
+
timestamp=timestamp,
|
| 130 |
+
example_instance_ids=data.get('example_instance_ids', []),
|
| 131 |
+
verification_status=data.get('verification_status', 'pending'),
|
| 132 |
+
verified_by=data.get('verified_by'),
|
| 133 |
+
verified_at=verified_at,
|
| 134 |
+
human_label=data.get('human_label'),
|
| 135 |
+
model_name=data.get('model_name', ''),
|
| 136 |
+
reasoning=data.get('reasoning', '')
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ICLLabeler:
|
| 141 |
+
"""
|
| 142 |
+
Manages in-context learning based labeling using high-confidence human annotations.
|
| 143 |
+
|
| 144 |
+
Workflow:
|
| 145 |
+
1. Monitors annotation progress for high-confidence examples
|
| 146 |
+
2. Periodically refreshes pool of high-confidence examples
|
| 147 |
+
3. Uses examples to prompt LLM for labeling unlabeled instances
|
| 148 |
+
4. Routes some LLM-labeled instances for human verification (blind)
|
| 149 |
+
5. Tracks accuracy metrics
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
_instance = None
|
| 153 |
+
_lock = threading.RLock()
|
| 154 |
+
|
| 155 |
+
def __new__(cls, *args, **kwargs):
|
| 156 |
+
"""Singleton pattern."""
|
| 157 |
+
if cls._instance is None:
|
| 158 |
+
with cls._lock:
|
| 159 |
+
if cls._instance is None:
|
| 160 |
+
cls._instance = super().__new__(cls)
|
| 161 |
+
cls._instance._initialized = False
|
| 162 |
+
return cls._instance
|
| 163 |
+
|
| 164 |
+
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
| 165 |
+
"""
|
| 166 |
+
Initialize the ICLLabeler.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
config: Configuration dictionary with settings
|
| 170 |
+
"""
|
| 171 |
+
if self._initialized:
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
self.config = config or {}
|
| 175 |
+
self._ai_endpoint = None
|
| 176 |
+
|
| 177 |
+
# Get ICL labeling config
|
| 178 |
+
icl_config = self.config.get('icl_labeling', {})
|
| 179 |
+
|
| 180 |
+
# Example selection config
|
| 181 |
+
example_config = icl_config.get('example_selection', {})
|
| 182 |
+
self.min_agreement_threshold = example_config.get('min_agreement_threshold', 0.8)
|
| 183 |
+
self.min_annotators_per_instance = example_config.get('min_annotators_per_instance', 2)
|
| 184 |
+
self.max_examples_per_schema = example_config.get('max_examples_per_schema', 10)
|
| 185 |
+
self.example_refresh_interval = example_config.get('refresh_interval_seconds', 300)
|
| 186 |
+
|
| 187 |
+
# LLM labeling config
|
| 188 |
+
llm_config = icl_config.get('llm_labeling', {})
|
| 189 |
+
self.batch_size = llm_config.get('batch_size', 20)
|
| 190 |
+
self.trigger_threshold = llm_config.get('trigger_threshold', 5)
|
| 191 |
+
self.confidence_threshold = llm_config.get('confidence_threshold', 0.7)
|
| 192 |
+
self.batch_interval = llm_config.get('batch_interval_seconds', 600)
|
| 193 |
+
|
| 194 |
+
# Limits to prevent labeling entire dataset at once
|
| 195 |
+
# This allows iterative improvement - verify accuracy before labeling more
|
| 196 |
+
self.max_total_labels = llm_config.get('max_total_labels', None) # Max instances to label total
|
| 197 |
+
self.max_unlabeled_ratio = llm_config.get('max_unlabeled_ratio', 0.5) # Max % of unlabeled to label
|
| 198 |
+
self.pause_on_low_accuracy = llm_config.get('pause_on_low_accuracy', True)
|
| 199 |
+
self.min_accuracy_threshold = llm_config.get('min_accuracy_threshold', 0.7) # Pause if accuracy below
|
| 200 |
+
|
| 201 |
+
# Verification config
|
| 202 |
+
verification_config = icl_config.get('verification', {})
|
| 203 |
+
self.verification_enabled = verification_config.get('enabled', True)
|
| 204 |
+
self.verification_sample_rate = verification_config.get('sample_rate', 0.2)
|
| 205 |
+
self.verification_strategy = verification_config.get('selection_strategy', 'low_confidence')
|
| 206 |
+
|
| 207 |
+
# Persistence config
|
| 208 |
+
persistence_config = icl_config.get('persistence', {})
|
| 209 |
+
self.predictions_file = persistence_config.get('predictions_file', 'icl_predictions.json')
|
| 210 |
+
|
| 211 |
+
# State
|
| 212 |
+
self.schema_to_examples: Dict[str, List[HighConfidenceExample]] = {}
|
| 213 |
+
self.predictions: Dict[str, Dict[str, ICLPrediction]] = {} # instance_id -> schema -> prediction
|
| 214 |
+
self.verification_queue: List[Tuple[str, str]] = [] # [(instance_id, schema_name), ...]
|
| 215 |
+
self.labeled_instance_ids: Set[str] = set() # Instances labeled by LLM
|
| 216 |
+
|
| 217 |
+
self.last_example_refresh: Optional[datetime] = None
|
| 218 |
+
self.last_batch_run: Optional[datetime] = None
|
| 219 |
+
|
| 220 |
+
# Background worker
|
| 221 |
+
self._worker_thread: Optional[threading.Thread] = None
|
| 222 |
+
self._stop_worker = threading.Event()
|
| 223 |
+
|
| 224 |
+
self._initialized = True
|
| 225 |
+
logger.info("ICLLabeler initialized")
|
| 226 |
+
|
| 227 |
+
def _get_ai_endpoint(self):
|
| 228 |
+
"""Get or create AI endpoint from config (reuses ai_support config)."""
|
| 229 |
+
if self._ai_endpoint is None:
|
| 230 |
+
from potato.ai.ai_endpoint import AIEndpointFactory
|
| 231 |
+
self._ai_endpoint = AIEndpointFactory.create_endpoint(self.config)
|
| 232 |
+
return self._ai_endpoint
|
| 233 |
+
|
| 234 |
+
def _get_annotation_schemes(self) -> List[Dict[str, Any]]:
|
| 235 |
+
"""Get annotation schemes from config."""
|
| 236 |
+
return self.config.get('annotation_schemes', [])
|
| 237 |
+
|
| 238 |
+
def _get_text_key(self) -> str:
|
| 239 |
+
"""Get the text key from item_properties."""
|
| 240 |
+
return self.config.get('item_properties', {}).get('text_key', 'text')
|
| 241 |
+
|
| 242 |
+
# === High-Confidence Example Collection ===
|
| 243 |
+
|
| 244 |
+
def refresh_high_confidence_examples(self) -> Dict[str, List[HighConfidenceExample]]:
|
| 245 |
+
"""
|
| 246 |
+
Scan annotations and identify high-confidence examples.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
Dictionary mapping schema name to list of high-confidence examples
|
| 250 |
+
"""
|
| 251 |
+
from potato.flask_server import get_users, get_user_state, get_item_state_manager
|
| 252 |
+
|
| 253 |
+
with self._lock:
|
| 254 |
+
new_examples: Dict[str, List[HighConfidenceExample]] = defaultdict(list)
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
ism = get_item_state_manager()
|
| 258 |
+
if ism is None:
|
| 259 |
+
logger.warning("ItemStateManager not available")
|
| 260 |
+
return new_examples
|
| 261 |
+
|
| 262 |
+
text_key = self._get_text_key()
|
| 263 |
+
schemas = self._get_annotation_schemes()
|
| 264 |
+
schema_names = [s.get('name') for s in schemas if s.get('name')]
|
| 265 |
+
|
| 266 |
+
# Collect all annotations per instance
|
| 267 |
+
instance_annotations: Dict[str, Dict[str, List[Tuple[str, Any]]]] = defaultdict(
|
| 268 |
+
lambda: defaultdict(list)
|
| 269 |
+
) # instance_id -> schema_name -> [(user_id, value), ...]
|
| 270 |
+
|
| 271 |
+
for username in get_users():
|
| 272 |
+
user_state = get_user_state(username)
|
| 273 |
+
if not user_state:
|
| 274 |
+
continue
|
| 275 |
+
|
| 276 |
+
all_annotations = user_state.get_all_annotations()
|
| 277 |
+
for instance_id, instance_data in all_annotations.items():
|
| 278 |
+
if 'labels' not in instance_data:
|
| 279 |
+
continue
|
| 280 |
+
|
| 281 |
+
for label, value in instance_data['labels'].items():
|
| 282 |
+
schema_name = label.get_schema() if hasattr(label, 'get_schema') else str(label)
|
| 283 |
+
if schema_name in schema_names:
|
| 284 |
+
instance_annotations[instance_id][schema_name].append((username, value))
|
| 285 |
+
|
| 286 |
+
# Find high-confidence examples
|
| 287 |
+
for instance_id, schema_data in instance_annotations.items():
|
| 288 |
+
for schema_name, annotations in schema_data.items():
|
| 289 |
+
annotator_count = len(annotations)
|
| 290 |
+
|
| 291 |
+
if annotator_count < self.min_annotators_per_instance:
|
| 292 |
+
continue
|
| 293 |
+
|
| 294 |
+
# Count votes per label
|
| 295 |
+
label_counts = Counter(value for _, value in annotations)
|
| 296 |
+
most_common_label, most_common_count = label_counts.most_common(1)[0]
|
| 297 |
+
|
| 298 |
+
# Calculate agreement
|
| 299 |
+
agreement_score = most_common_count / annotator_count
|
| 300 |
+
|
| 301 |
+
if agreement_score >= self.min_agreement_threshold:
|
| 302 |
+
# Get instance text
|
| 303 |
+
item = ism.get_item(instance_id)
|
| 304 |
+
instance_data = item.get_data() if item else None
|
| 305 |
+
if instance_data is None:
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
text = instance_data.get(text_key, '')
|
| 309 |
+
if not text:
|
| 310 |
+
continue
|
| 311 |
+
|
| 312 |
+
example = HighConfidenceExample(
|
| 313 |
+
instance_id=instance_id,
|
| 314 |
+
text=text,
|
| 315 |
+
schema_name=schema_name,
|
| 316 |
+
label=str(most_common_label),
|
| 317 |
+
agreement_score=agreement_score,
|
| 318 |
+
annotator_count=annotator_count
|
| 319 |
+
)
|
| 320 |
+
new_examples[schema_name].append(example)
|
| 321 |
+
|
| 322 |
+
# Select examples using coverage-based selection (CoverICL-inspired)
|
| 323 |
+
# or fall back to agreement-score sorting
|
| 324 |
+
for schema_name in new_examples:
|
| 325 |
+
candidates = new_examples[schema_name]
|
| 326 |
+
if len(candidates) > self.max_examples_per_schema:
|
| 327 |
+
selected = self._select_diverse_examples(
|
| 328 |
+
candidates, self.max_examples_per_schema
|
| 329 |
+
)
|
| 330 |
+
new_examples[schema_name] = selected
|
| 331 |
+
else:
|
| 332 |
+
new_examples[schema_name].sort(
|
| 333 |
+
key=lambda x: x.agreement_score, reverse=True
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
self.schema_to_examples = dict(new_examples)
|
| 337 |
+
self.last_example_refresh = datetime.now()
|
| 338 |
+
|
| 339 |
+
total_examples = sum(len(examples) for examples in new_examples.values())
|
| 340 |
+
logger.info(f"Refreshed high-confidence examples: {total_examples} examples across {len(new_examples)} schemas")
|
| 341 |
+
|
| 342 |
+
except Exception as e:
|
| 343 |
+
logger.error(f"Error refreshing examples: {e}")
|
| 344 |
+
|
| 345 |
+
return self.schema_to_examples
|
| 346 |
+
|
| 347 |
+
def get_examples_for_schema(self, schema_name: str) -> List[HighConfidenceExample]:
|
| 348 |
+
"""Get high-confidence examples for a specific schema."""
|
| 349 |
+
return self.schema_to_examples.get(schema_name, [])
|
| 350 |
+
|
| 351 |
+
def has_enough_examples(self, schema_name: str) -> bool:
|
| 352 |
+
"""Check if we have enough examples to start labeling."""
|
| 353 |
+
return len(self.get_examples_for_schema(schema_name)) >= self.trigger_threshold
|
| 354 |
+
|
| 355 |
+
def _select_diverse_examples(
|
| 356 |
+
self,
|
| 357 |
+
candidates: List[HighConfidenceExample],
|
| 358 |
+
k: int,
|
| 359 |
+
) -> List[HighConfidenceExample]:
|
| 360 |
+
"""CoverICL-inspired coverage-based example selection.
|
| 361 |
+
|
| 362 |
+
Uses greedy facility location to select examples that maximize
|
| 363 |
+
coverage of the instance embedding space, ensuring diverse and
|
| 364 |
+
representative ICL demonstrations.
|
| 365 |
+
|
| 366 |
+
Inspired by Mavromatis et al. (2024) CoverICL. Falls back to
|
| 367 |
+
agreement-score sorting if vectorization fails.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
candidates: Pool of high-confidence examples
|
| 371 |
+
k: Number of examples to select
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
List of selected examples maximizing coverage
|
| 375 |
+
"""
|
| 376 |
+
if len(candidates) <= k:
|
| 377 |
+
return candidates
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 381 |
+
from sklearn.metrics.pairwise import cosine_distances
|
| 382 |
+
|
| 383 |
+
texts = [c.text for c in candidates]
|
| 384 |
+
vectorizer = TfidfVectorizer(max_features=5000)
|
| 385 |
+
features = vectorizer.fit_transform(texts).toarray()
|
| 386 |
+
|
| 387 |
+
# Greedy facility location: iteratively pick the candidate that
|
| 388 |
+
# maximizes the minimum distance to already-selected examples
|
| 389 |
+
n = len(candidates)
|
| 390 |
+
selected_indices = []
|
| 391 |
+
|
| 392 |
+
# Start with the candidate that has highest agreement score
|
| 393 |
+
# (quality-weighted seed)
|
| 394 |
+
agreements = [c.agreement_score for c in candidates]
|
| 395 |
+
first = int(max(range(n), key=lambda i: agreements[i]))
|
| 396 |
+
selected_indices.append(first)
|
| 397 |
+
|
| 398 |
+
dist_matrix = cosine_distances(features)
|
| 399 |
+
|
| 400 |
+
for _ in range(k - 1):
|
| 401 |
+
# For each unselected candidate, compute min distance to selected set
|
| 402 |
+
best_idx = -1
|
| 403 |
+
best_score = -1.0
|
| 404 |
+
|
| 405 |
+
for i in range(n):
|
| 406 |
+
if i in selected_indices:
|
| 407 |
+
continue
|
| 408 |
+
min_dist = min(dist_matrix[i][j] for j in selected_indices)
|
| 409 |
+
# Weight by agreement score for quality-aware selection
|
| 410 |
+
score = min_dist * candidates[i].agreement_score
|
| 411 |
+
if score > best_score:
|
| 412 |
+
best_score = score
|
| 413 |
+
best_idx = i
|
| 414 |
+
|
| 415 |
+
if best_idx >= 0:
|
| 416 |
+
selected_indices.append(best_idx)
|
| 417 |
+
else:
|
| 418 |
+
break
|
| 419 |
+
|
| 420 |
+
selected = [candidates[i] for i in selected_indices]
|
| 421 |
+
logger.debug(f"CoverICL selection: {len(selected)} diverse examples from {n} candidates")
|
| 422 |
+
return selected
|
| 423 |
+
|
| 424 |
+
except Exception as e:
|
| 425 |
+
logger.warning(f"Coverage-based selection failed, using agreement sorting: {e}")
|
| 426 |
+
candidates.sort(key=lambda x: x.agreement_score, reverse=True)
|
| 427 |
+
return candidates[:k]
|
| 428 |
+
|
| 429 |
+
# === LLM Labeling ===
|
| 430 |
+
|
| 431 |
+
def label_instance(
|
| 432 |
+
self,
|
| 433 |
+
instance_id: str,
|
| 434 |
+
schema_name: str,
|
| 435 |
+
instance_text: str
|
| 436 |
+
) -> Optional[ICLPrediction]:
|
| 437 |
+
"""
|
| 438 |
+
Label a single instance using in-context learning.
|
| 439 |
+
|
| 440 |
+
Args:
|
| 441 |
+
instance_id: The instance to label
|
| 442 |
+
schema_name: The annotation schema to use
|
| 443 |
+
instance_text: The text to label
|
| 444 |
+
|
| 445 |
+
Returns:
|
| 446 |
+
ICLPrediction if successful, None otherwise
|
| 447 |
+
"""
|
| 448 |
+
from potato.ai.icl_prompt_builder import ICLPromptBuilder
|
| 449 |
+
|
| 450 |
+
examples = self.get_examples_for_schema(schema_name)
|
| 451 |
+
if not examples:
|
| 452 |
+
logger.warning(f"No examples available for schema {schema_name}")
|
| 453 |
+
return None
|
| 454 |
+
|
| 455 |
+
# Get schema info
|
| 456 |
+
schemas = self._get_annotation_schemes()
|
| 457 |
+
schema_info = next((s for s in schemas if s.get('name') == schema_name), None)
|
| 458 |
+
if not schema_info:
|
| 459 |
+
logger.warning(f"Schema {schema_name} not found in config")
|
| 460 |
+
return None
|
| 461 |
+
|
| 462 |
+
endpoint = self._get_ai_endpoint()
|
| 463 |
+
if endpoint is None:
|
| 464 |
+
logger.warning("AI endpoint not available")
|
| 465 |
+
return None
|
| 466 |
+
|
| 467 |
+
try:
|
| 468 |
+
# Build prompt
|
| 469 |
+
prompt_builder = ICLPromptBuilder()
|
| 470 |
+
prompt = prompt_builder.build_prompt(
|
| 471 |
+
schema=schema_info,
|
| 472 |
+
examples=examples,
|
| 473 |
+
target_text=instance_text
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
# Query LLM
|
| 477 |
+
from pydantic import BaseModel
|
| 478 |
+
|
| 479 |
+
class ICLResponse(BaseModel):
|
| 480 |
+
label: str
|
| 481 |
+
confidence: float
|
| 482 |
+
reasoning: str = ""
|
| 483 |
+
|
| 484 |
+
response = endpoint.query(prompt, ICLResponse)
|
| 485 |
+
|
| 486 |
+
# Parse response
|
| 487 |
+
if isinstance(response, str):
|
| 488 |
+
response_data = json.loads(response)
|
| 489 |
+
elif hasattr(response, 'model_dump'):
|
| 490 |
+
response_data = response.model_dump()
|
| 491 |
+
else:
|
| 492 |
+
response_data = response
|
| 493 |
+
|
| 494 |
+
predicted_label = response_data.get('label', '')
|
| 495 |
+
confidence = float(response_data.get('confidence', 0.5))
|
| 496 |
+
reasoning = response_data.get('reasoning', '')
|
| 497 |
+
|
| 498 |
+
# Validate label against schema
|
| 499 |
+
valid_labels = self._get_valid_labels(schema_info)
|
| 500 |
+
if valid_labels and predicted_label not in valid_labels:
|
| 501 |
+
# Try fuzzy matching
|
| 502 |
+
predicted_label = self._fuzzy_match_label(predicted_label, valid_labels)
|
| 503 |
+
if predicted_label is None:
|
| 504 |
+
logger.warning(f"LLM returned invalid label for {instance_id}")
|
| 505 |
+
return None
|
| 506 |
+
|
| 507 |
+
# Create prediction
|
| 508 |
+
prediction = ICLPrediction(
|
| 509 |
+
instance_id=instance_id,
|
| 510 |
+
schema_name=schema_name,
|
| 511 |
+
predicted_label=predicted_label,
|
| 512 |
+
confidence_score=min(1.0, max(0.0, confidence)),
|
| 513 |
+
example_instance_ids=[e.instance_id for e in examples],
|
| 514 |
+
model_name=endpoint.model if hasattr(endpoint, 'model') else '',
|
| 515 |
+
reasoning=reasoning
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Store prediction
|
| 519 |
+
with self._lock:
|
| 520 |
+
if instance_id not in self.predictions:
|
| 521 |
+
self.predictions[instance_id] = {}
|
| 522 |
+
self.predictions[instance_id][schema_name] = prediction
|
| 523 |
+
self.labeled_instance_ids.add(instance_id)
|
| 524 |
+
|
| 525 |
+
# Maybe add to verification queue
|
| 526 |
+
if self.verification_enabled and random.random() < self.verification_sample_rate:
|
| 527 |
+
self.verification_queue.append((instance_id, schema_name))
|
| 528 |
+
|
| 529 |
+
logger.debug(f"Labeled {instance_id} with {predicted_label} (confidence: {confidence:.2f})")
|
| 530 |
+
return prediction
|
| 531 |
+
|
| 532 |
+
except Exception as e:
|
| 533 |
+
logger.error(f"Error labeling instance {instance_id}: {e}")
|
| 534 |
+
return None
|
| 535 |
+
|
| 536 |
+
def _get_valid_labels(self, schema_info: Dict[str, Any]) -> List[str]:
|
| 537 |
+
"""Extract valid labels from schema info."""
|
| 538 |
+
labels = schema_info.get('labels', [])
|
| 539 |
+
valid_labels = []
|
| 540 |
+
for label in labels:
|
| 541 |
+
if isinstance(label, str):
|
| 542 |
+
valid_labels.append(label)
|
| 543 |
+
elif isinstance(label, dict):
|
| 544 |
+
valid_labels.append(label.get('name', str(label)))
|
| 545 |
+
return valid_labels
|
| 546 |
+
|
| 547 |
+
def _fuzzy_match_label(self, predicted: str, valid_labels: List[str]) -> Optional[str]:
|
| 548 |
+
"""Try to match predicted label to a valid label."""
|
| 549 |
+
predicted_lower = predicted.lower().strip()
|
| 550 |
+
for label in valid_labels:
|
| 551 |
+
if label.lower().strip() == predicted_lower:
|
| 552 |
+
return label
|
| 553 |
+
return None
|
| 554 |
+
|
| 555 |
+
def should_pause_labeling(self) -> Tuple[bool, str]:
|
| 556 |
+
"""
|
| 557 |
+
Check if labeling should be paused based on limits and accuracy.
|
| 558 |
+
|
| 559 |
+
Returns:
|
| 560 |
+
Tuple of (should_pause, reason)
|
| 561 |
+
"""
|
| 562 |
+
# Check if max total labels reached
|
| 563 |
+
if self.max_total_labels is not None:
|
| 564 |
+
current_count = len(self.labeled_instance_ids)
|
| 565 |
+
if current_count >= self.max_total_labels:
|
| 566 |
+
return True, f"Reached max_total_labels limit ({self.max_total_labels})"
|
| 567 |
+
|
| 568 |
+
# Check accuracy threshold
|
| 569 |
+
if self.pause_on_low_accuracy:
|
| 570 |
+
metrics = self.get_accuracy_metrics()
|
| 571 |
+
total_verified = metrics.get('total_verified', 0)
|
| 572 |
+
accuracy = metrics.get('accuracy')
|
| 573 |
+
|
| 574 |
+
# Only check accuracy if we have enough verifications
|
| 575 |
+
min_verifications = 10
|
| 576 |
+
if total_verified >= min_verifications and accuracy is not None:
|
| 577 |
+
if accuracy < self.min_accuracy_threshold:
|
| 578 |
+
return True, f"Accuracy ({accuracy:.1%}) below threshold ({self.min_accuracy_threshold:.1%})"
|
| 579 |
+
|
| 580 |
+
return False, ""
|
| 581 |
+
|
| 582 |
+
def get_remaining_label_capacity(self) -> int:
|
| 583 |
+
"""
|
| 584 |
+
Get how many more instances can be labeled.
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
Number of instances that can still be labeled, or -1 for unlimited
|
| 588 |
+
"""
|
| 589 |
+
from potato.item_state_management import get_item_state_manager
|
| 590 |
+
from potato.user_state_management import get_user_state_manager
|
| 591 |
+
|
| 592 |
+
try:
|
| 593 |
+
ism = get_item_state_manager()
|
| 594 |
+
except ValueError:
|
| 595 |
+
# ISM not initialized yet
|
| 596 |
+
return 0
|
| 597 |
+
if ism is None:
|
| 598 |
+
return 0
|
| 599 |
+
|
| 600 |
+
# Count unlabeled instances (not labeled by humans or LLM)
|
| 601 |
+
unlabeled_count = 0
|
| 602 |
+
usm = get_user_state_manager()
|
| 603 |
+
for instance_id in ism.instance_id_ordering:
|
| 604 |
+
if instance_id in self.labeled_instance_ids:
|
| 605 |
+
continue
|
| 606 |
+
|
| 607 |
+
has_human_annotation = False
|
| 608 |
+
for username in usm.get_all_users():
|
| 609 |
+
user_state = usm.get_user_state(username)
|
| 610 |
+
if user_state:
|
| 611 |
+
all_annotations = user_state.get_all_annotations()
|
| 612 |
+
if instance_id in all_annotations:
|
| 613 |
+
has_human_annotation = True
|
| 614 |
+
break
|
| 615 |
+
|
| 616 |
+
if not has_human_annotation:
|
| 617 |
+
unlabeled_count += 1
|
| 618 |
+
|
| 619 |
+
current_llm_labels = len(self.labeled_instance_ids)
|
| 620 |
+
|
| 621 |
+
# Calculate max based on ratio
|
| 622 |
+
max_from_ratio = int(unlabeled_count * self.max_unlabeled_ratio)
|
| 623 |
+
|
| 624 |
+
# Calculate max based on total limit
|
| 625 |
+
if self.max_total_labels is not None:
|
| 626 |
+
max_from_total = self.max_total_labels - current_llm_labels
|
| 627 |
+
return min(max_from_ratio, max_from_total)
|
| 628 |
+
|
| 629 |
+
return max_from_ratio
|
| 630 |
+
|
| 631 |
+
def batch_label_instances(self, schema_name: str) -> List[ICLPrediction]:
|
| 632 |
+
"""
|
| 633 |
+
Label multiple unlabeled instances for a schema.
|
| 634 |
+
|
| 635 |
+
Respects configured limits to prevent labeling entire dataset at once.
|
| 636 |
+
|
| 637 |
+
Returns:
|
| 638 |
+
List of successful predictions
|
| 639 |
+
"""
|
| 640 |
+
from potato.flask_server import get_item_state_manager, get_users, get_user_state
|
| 641 |
+
|
| 642 |
+
# Check if we should pause labeling
|
| 643 |
+
should_pause, reason = self.should_pause_labeling()
|
| 644 |
+
if should_pause:
|
| 645 |
+
logger.info(f"Labeling paused: {reason}")
|
| 646 |
+
return []
|
| 647 |
+
|
| 648 |
+
if not self.has_enough_examples(schema_name):
|
| 649 |
+
logger.info(f"Not enough examples for schema {schema_name}")
|
| 650 |
+
return []
|
| 651 |
+
|
| 652 |
+
ism = get_item_state_manager()
|
| 653 |
+
if ism is None:
|
| 654 |
+
return []
|
| 655 |
+
|
| 656 |
+
# Check remaining capacity
|
| 657 |
+
remaining_capacity = self.get_remaining_label_capacity()
|
| 658 |
+
if remaining_capacity <= 0:
|
| 659 |
+
logger.info("No remaining label capacity")
|
| 660 |
+
return []
|
| 661 |
+
|
| 662 |
+
# Limit batch size to remaining capacity
|
| 663 |
+
effective_batch_size = min(self.batch_size, remaining_capacity)
|
| 664 |
+
|
| 665 |
+
text_key = self._get_text_key()
|
| 666 |
+
predictions = []
|
| 667 |
+
|
| 668 |
+
# Find unlabeled instances
|
| 669 |
+
unlabeled_ids = []
|
| 670 |
+
for instance_id in ism.instance_id_ordering:
|
| 671 |
+
# Skip if already labeled by LLM
|
| 672 |
+
if instance_id in self.labeled_instance_ids:
|
| 673 |
+
continue
|
| 674 |
+
|
| 675 |
+
# Skip if already annotated by humans
|
| 676 |
+
has_human_annotation = False
|
| 677 |
+
for username in get_users():
|
| 678 |
+
user_state = get_user_state(username)
|
| 679 |
+
if user_state:
|
| 680 |
+
all_annotations = user_state.get_all_annotations()
|
| 681 |
+
if instance_id in all_annotations:
|
| 682 |
+
has_human_annotation = True
|
| 683 |
+
break
|
| 684 |
+
|
| 685 |
+
if not has_human_annotation:
|
| 686 |
+
unlabeled_ids.append(instance_id)
|
| 687 |
+
|
| 688 |
+
if len(unlabeled_ids) >= effective_batch_size:
|
| 689 |
+
break
|
| 690 |
+
|
| 691 |
+
# Label instances
|
| 692 |
+
for instance_id in unlabeled_ids:
|
| 693 |
+
item = ism.get_item(instance_id)
|
| 694 |
+
instance_data = item.get_data() if item else None
|
| 695 |
+
if instance_data is None:
|
| 696 |
+
continue
|
| 697 |
+
|
| 698 |
+
text = instance_data.get(text_key, '')
|
| 699 |
+
if not text:
|
| 700 |
+
continue
|
| 701 |
+
|
| 702 |
+
prediction = self.label_instance(instance_id, schema_name, text)
|
| 703 |
+
if prediction and prediction.confidence_score >= self.confidence_threshold:
|
| 704 |
+
predictions.append(prediction)
|
| 705 |
+
|
| 706 |
+
self.last_batch_run = datetime.now()
|
| 707 |
+
logger.info(f"Batch labeled {len(predictions)} instances for schema {schema_name}")
|
| 708 |
+
|
| 709 |
+
return predictions
|
| 710 |
+
|
| 711 |
+
# === Verification Workflow ===
|
| 712 |
+
|
| 713 |
+
def get_pending_verifications(self, count: int = 1) -> List[Tuple[str, str]]:
|
| 714 |
+
"""
|
| 715 |
+
Get instances pending human verification.
|
| 716 |
+
|
| 717 |
+
Args:
|
| 718 |
+
count: Number of verification tasks to return
|
| 719 |
+
|
| 720 |
+
Returns:
|
| 721 |
+
List of (instance_id, schema_name) tuples
|
| 722 |
+
"""
|
| 723 |
+
with self._lock:
|
| 724 |
+
if self.verification_strategy == 'low_confidence':
|
| 725 |
+
# Sort by confidence ascending
|
| 726 |
+
pending = [
|
| 727 |
+
(inst_id, schema)
|
| 728 |
+
for inst_id, schema in self.verification_queue
|
| 729 |
+
if (inst_id in self.predictions and
|
| 730 |
+
schema in self.predictions[inst_id] and
|
| 731 |
+
self.predictions[inst_id][schema].verification_status == 'pending')
|
| 732 |
+
]
|
| 733 |
+
pending.sort(
|
| 734 |
+
key=lambda x: self.predictions[x[0]][x[1]].confidence_score
|
| 735 |
+
)
|
| 736 |
+
return pending[:count]
|
| 737 |
+
|
| 738 |
+
elif self.verification_strategy == 'random':
|
| 739 |
+
pending = [
|
| 740 |
+
(inst_id, schema)
|
| 741 |
+
for inst_id, schema in self.verification_queue
|
| 742 |
+
if (inst_id in self.predictions and
|
| 743 |
+
schema in self.predictions[inst_id] and
|
| 744 |
+
self.predictions[inst_id][schema].verification_status == 'pending')
|
| 745 |
+
]
|
| 746 |
+
random.shuffle(pending)
|
| 747 |
+
return pending[:count]
|
| 748 |
+
|
| 749 |
+
else: # mixed
|
| 750 |
+
pending = [
|
| 751 |
+
(inst_id, schema)
|
| 752 |
+
for inst_id, schema in self.verification_queue
|
| 753 |
+
if (inst_id in self.predictions and
|
| 754 |
+
schema in self.predictions[inst_id] and
|
| 755 |
+
self.predictions[inst_id][schema].verification_status == 'pending')
|
| 756 |
+
]
|
| 757 |
+
# 50% low confidence, 50% random
|
| 758 |
+
pending.sort(
|
| 759 |
+
key=lambda x: self.predictions[x[0]][x[1]].confidence_score
|
| 760 |
+
)
|
| 761 |
+
half = count // 2
|
| 762 |
+
low_conf = pending[:half]
|
| 763 |
+
rest = pending[half:]
|
| 764 |
+
random.shuffle(rest)
|
| 765 |
+
return low_conf + rest[:count - half]
|
| 766 |
+
|
| 767 |
+
def record_verification(
|
| 768 |
+
self,
|
| 769 |
+
instance_id: str,
|
| 770 |
+
schema_name: str,
|
| 771 |
+
human_label: str,
|
| 772 |
+
verified_by: str
|
| 773 |
+
) -> bool:
|
| 774 |
+
"""
|
| 775 |
+
Record human verification of an LLM prediction.
|
| 776 |
+
|
| 777 |
+
Args:
|
| 778 |
+
instance_id: The verified instance
|
| 779 |
+
schema_name: The schema verified
|
| 780 |
+
human_label: The human's label
|
| 781 |
+
verified_by: Username of verifier
|
| 782 |
+
|
| 783 |
+
Returns:
|
| 784 |
+
True if verification recorded successfully
|
| 785 |
+
"""
|
| 786 |
+
with self._lock:
|
| 787 |
+
if instance_id not in self.predictions:
|
| 788 |
+
logger.warning(f"No prediction found for instance {instance_id}")
|
| 789 |
+
return False
|
| 790 |
+
|
| 791 |
+
if schema_name not in self.predictions[instance_id]:
|
| 792 |
+
logger.warning(f"No prediction found for schema {schema_name}")
|
| 793 |
+
return False
|
| 794 |
+
|
| 795 |
+
prediction = self.predictions[instance_id][schema_name]
|
| 796 |
+
prediction.human_label = human_label
|
| 797 |
+
prediction.verified_by = verified_by
|
| 798 |
+
prediction.verified_at = datetime.now()
|
| 799 |
+
|
| 800 |
+
if prediction.predicted_label == human_label:
|
| 801 |
+
prediction.verification_status = 'verified_correct'
|
| 802 |
+
else:
|
| 803 |
+
prediction.verification_status = 'verified_incorrect'
|
| 804 |
+
|
| 805 |
+
# Remove from verification queue
|
| 806 |
+
try:
|
| 807 |
+
self.verification_queue.remove((instance_id, schema_name))
|
| 808 |
+
except ValueError:
|
| 809 |
+
pass
|
| 810 |
+
|
| 811 |
+
logger.info(
|
| 812 |
+
f"Verification recorded for {instance_id}: "
|
| 813 |
+
f"predicted={prediction.predicted_label}, human={human_label}, "
|
| 814 |
+
f"status={prediction.verification_status}"
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
return True
|
| 818 |
+
|
| 819 |
+
# === Accuracy Tracking ===
|
| 820 |
+
|
| 821 |
+
def get_accuracy_metrics(self, schema_name: Optional[str] = None) -> Dict[str, Any]:
|
| 822 |
+
"""
|
| 823 |
+
Calculate accuracy metrics from verified predictions.
|
| 824 |
+
|
| 825 |
+
Args:
|
| 826 |
+
schema_name: Optional schema to filter by
|
| 827 |
+
|
| 828 |
+
Returns:
|
| 829 |
+
Dictionary with accuracy metrics
|
| 830 |
+
"""
|
| 831 |
+
with self._lock:
|
| 832 |
+
verified_correct = 0
|
| 833 |
+
verified_incorrect = 0
|
| 834 |
+
pending = 0
|
| 835 |
+
total_predictions = 0
|
| 836 |
+
|
| 837 |
+
confidence_correct = []
|
| 838 |
+
confidence_incorrect = []
|
| 839 |
+
|
| 840 |
+
for inst_id, schemas in self.predictions.items():
|
| 841 |
+
for s_name, prediction in schemas.items():
|
| 842 |
+
if schema_name and s_name != schema_name:
|
| 843 |
+
continue
|
| 844 |
+
|
| 845 |
+
total_predictions += 1
|
| 846 |
+
|
| 847 |
+
if prediction.verification_status == 'verified_correct':
|
| 848 |
+
verified_correct += 1
|
| 849 |
+
confidence_correct.append(prediction.confidence_score)
|
| 850 |
+
elif prediction.verification_status == 'verified_incorrect':
|
| 851 |
+
verified_incorrect += 1
|
| 852 |
+
confidence_incorrect.append(prediction.confidence_score)
|
| 853 |
+
else:
|
| 854 |
+
pending += 1
|
| 855 |
+
|
| 856 |
+
total_verified = verified_correct + verified_incorrect
|
| 857 |
+
accuracy = verified_correct / total_verified if total_verified > 0 else None
|
| 858 |
+
|
| 859 |
+
avg_confidence_correct = (
|
| 860 |
+
sum(confidence_correct) / len(confidence_correct)
|
| 861 |
+
if confidence_correct else None
|
| 862 |
+
)
|
| 863 |
+
avg_confidence_incorrect = (
|
| 864 |
+
sum(confidence_incorrect) / len(confidence_incorrect)
|
| 865 |
+
if confidence_incorrect else None
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
return {
|
| 869 |
+
'total_predictions': total_predictions,
|
| 870 |
+
'verified_correct': verified_correct,
|
| 871 |
+
'verified_incorrect': verified_incorrect,
|
| 872 |
+
'pending_verification': pending,
|
| 873 |
+
'total_verified': total_verified,
|
| 874 |
+
'accuracy': accuracy,
|
| 875 |
+
'avg_confidence_correct': avg_confidence_correct,
|
| 876 |
+
'avg_confidence_incorrect': avg_confidence_incorrect,
|
| 877 |
+
'schema_name': schema_name
|
| 878 |
+
}
|
| 879 |
+
|
| 880 |
+
def get_status(self) -> Dict[str, Any]:
|
| 881 |
+
"""Get overall ICL labeler status."""
|
| 882 |
+
with self._lock:
|
| 883 |
+
total_examples = sum(len(ex) for ex in self.schema_to_examples.values())
|
| 884 |
+
examples_by_schema = {
|
| 885 |
+
schema: len(examples)
|
| 886 |
+
for schema, examples in self.schema_to_examples.items()
|
| 887 |
+
}
|
| 888 |
+
|
| 889 |
+
# Check labeling status
|
| 890 |
+
should_pause, pause_reason = self.should_pause_labeling()
|
| 891 |
+
remaining_capacity = self.get_remaining_label_capacity()
|
| 892 |
+
|
| 893 |
+
return {
|
| 894 |
+
'enabled': self.config.get('icl_labeling', {}).get('enabled', False),
|
| 895 |
+
'total_examples': total_examples,
|
| 896 |
+
'examples_by_schema': examples_by_schema,
|
| 897 |
+
'total_predictions': sum(
|
| 898 |
+
len(schemas) for schemas in self.predictions.values()
|
| 899 |
+
),
|
| 900 |
+
'labeled_instances': len(self.labeled_instance_ids),
|
| 901 |
+
'verification_queue_size': len(self.verification_queue),
|
| 902 |
+
'last_example_refresh': (
|
| 903 |
+
self.last_example_refresh.isoformat()
|
| 904 |
+
if self.last_example_refresh else None
|
| 905 |
+
),
|
| 906 |
+
'last_batch_run': (
|
| 907 |
+
self.last_batch_run.isoformat()
|
| 908 |
+
if self.last_batch_run else None
|
| 909 |
+
),
|
| 910 |
+
'worker_running': (
|
| 911 |
+
self._worker_thread is not None and
|
| 912 |
+
self._worker_thread.is_alive()
|
| 913 |
+
),
|
| 914 |
+
'accuracy_metrics': self.get_accuracy_metrics(),
|
| 915 |
+
# Labeling limits status
|
| 916 |
+
'labeling_paused': should_pause,
|
| 917 |
+
'pause_reason': pause_reason,
|
| 918 |
+
'remaining_label_capacity': remaining_capacity,
|
| 919 |
+
'max_total_labels': self.max_total_labels,
|
| 920 |
+
'max_unlabeled_ratio': self.max_unlabeled_ratio,
|
| 921 |
+
'min_accuracy_threshold': self.min_accuracy_threshold
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
# === Background Worker ===
|
| 925 |
+
|
| 926 |
+
def start_background_worker(self) -> None:
|
| 927 |
+
"""Start the background worker thread."""
|
| 928 |
+
if self._worker_thread is not None and self._worker_thread.is_alive():
|
| 929 |
+
logger.warning("Background worker already running")
|
| 930 |
+
return
|
| 931 |
+
|
| 932 |
+
self._stop_worker.clear()
|
| 933 |
+
self._worker_thread = threading.Thread(
|
| 934 |
+
target=self._worker_loop,
|
| 935 |
+
name="ICLLabelerWorker",
|
| 936 |
+
daemon=True
|
| 937 |
+
)
|
| 938 |
+
self._worker_thread.start()
|
| 939 |
+
logger.info("Started ICL labeler background worker")
|
| 940 |
+
|
| 941 |
+
def stop_background_worker(self) -> None:
|
| 942 |
+
"""Stop the background worker thread."""
|
| 943 |
+
if self._worker_thread is None:
|
| 944 |
+
return
|
| 945 |
+
|
| 946 |
+
self._stop_worker.set()
|
| 947 |
+
self._worker_thread.join(timeout=5.0)
|
| 948 |
+
self._worker_thread = None
|
| 949 |
+
logger.info("Stopped ICL labeler background worker")
|
| 950 |
+
|
| 951 |
+
def _worker_loop(self) -> None:
|
| 952 |
+
"""Main loop for the background worker."""
|
| 953 |
+
logger.info(
|
| 954 |
+
f"ICL background worker started, "
|
| 955 |
+
f"example_refresh={self.example_refresh_interval}s, "
|
| 956 |
+
f"batch_interval={self.batch_interval}s"
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
last_example_refresh = 0
|
| 960 |
+
last_batch = 0
|
| 961 |
+
|
| 962 |
+
while not self._stop_worker.is_set():
|
| 963 |
+
try:
|
| 964 |
+
current_time = time.time()
|
| 965 |
+
|
| 966 |
+
# Refresh examples periodically
|
| 967 |
+
if current_time - last_example_refresh >= self.example_refresh_interval:
|
| 968 |
+
self.refresh_high_confidence_examples()
|
| 969 |
+
last_example_refresh = current_time
|
| 970 |
+
|
| 971 |
+
# Run batch labeling periodically
|
| 972 |
+
if current_time - last_batch >= self.batch_interval:
|
| 973 |
+
schemas = self._get_annotation_schemes()
|
| 974 |
+
for schema in schemas:
|
| 975 |
+
schema_name = schema.get('name')
|
| 976 |
+
if schema_name and self.has_enough_examples(schema_name):
|
| 977 |
+
predictions = self.batch_label_instances(schema_name)
|
| 978 |
+
if predictions:
|
| 979 |
+
self.save_state()
|
| 980 |
+
last_batch = current_time
|
| 981 |
+
|
| 982 |
+
except Exception as e:
|
| 983 |
+
logger.error(f"ICL background worker error: {e}")
|
| 984 |
+
|
| 985 |
+
# Wait for next interval or stop signal
|
| 986 |
+
self._stop_worker.wait(min(self.example_refresh_interval, self.batch_interval) / 2)
|
| 987 |
+
|
| 988 |
+
# === Persistence ===
|
| 989 |
+
|
| 990 |
+
def save_state(self) -> None:
|
| 991 |
+
"""Save current state to disk."""
|
| 992 |
+
task_dir = self.config.get('output_annotation_dir', '')
|
| 993 |
+
if not task_dir:
|
| 994 |
+
return
|
| 995 |
+
|
| 996 |
+
filepath = os.path.join(task_dir, self.predictions_file)
|
| 997 |
+
|
| 998 |
+
try:
|
| 999 |
+
with self._lock:
|
| 1000 |
+
state = {
|
| 1001 |
+
'predictions': {
|
| 1002 |
+
inst_id: {
|
| 1003 |
+
schema: pred.to_dict()
|
| 1004 |
+
for schema, pred in schemas.items()
|
| 1005 |
+
}
|
| 1006 |
+
for inst_id, schemas in self.predictions.items()
|
| 1007 |
+
},
|
| 1008 |
+
'examples': {
|
| 1009 |
+
schema: [ex.to_dict() for ex in examples]
|
| 1010 |
+
for schema, examples in self.schema_to_examples.items()
|
| 1011 |
+
},
|
| 1012 |
+
'verification_queue': self.verification_queue,
|
| 1013 |
+
'labeled_instance_ids': list(self.labeled_instance_ids),
|
| 1014 |
+
'last_example_refresh': (
|
| 1015 |
+
self.last_example_refresh.isoformat()
|
| 1016 |
+
if self.last_example_refresh else None
|
| 1017 |
+
),
|
| 1018 |
+
'last_batch_run': (
|
| 1019 |
+
self.last_batch_run.isoformat()
|
| 1020 |
+
if self.last_batch_run else None
|
| 1021 |
+
)
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
# Atomic write
|
| 1025 |
+
temp_path = filepath + '.tmp'
|
| 1026 |
+
with open(temp_path, 'w') as f:
|
| 1027 |
+
json.dump(state, f, indent=2)
|
| 1028 |
+
os.replace(temp_path, filepath)
|
| 1029 |
+
|
| 1030 |
+
logger.debug(f"Saved ICL state to {filepath}")
|
| 1031 |
+
|
| 1032 |
+
except Exception as e:
|
| 1033 |
+
logger.error(f"Error saving ICL state: {e}")
|
| 1034 |
+
|
| 1035 |
+
def load_state(self) -> None:
|
| 1036 |
+
"""Load state from disk."""
|
| 1037 |
+
task_dir = self.config.get('output_annotation_dir', '')
|
| 1038 |
+
if not task_dir:
|
| 1039 |
+
return
|
| 1040 |
+
|
| 1041 |
+
filepath = os.path.join(task_dir, self.predictions_file)
|
| 1042 |
+
|
| 1043 |
+
if not os.path.exists(filepath):
|
| 1044 |
+
return
|
| 1045 |
+
|
| 1046 |
+
try:
|
| 1047 |
+
with open(filepath, 'r') as f:
|
| 1048 |
+
state = json.load(f)
|
| 1049 |
+
|
| 1050 |
+
with self._lock:
|
| 1051 |
+
# Load predictions
|
| 1052 |
+
self.predictions = {}
|
| 1053 |
+
for inst_id, schemas in state.get('predictions', {}).items():
|
| 1054 |
+
self.predictions[inst_id] = {
|
| 1055 |
+
schema: ICLPrediction.from_dict(pred_data)
|
| 1056 |
+
for schema, pred_data in schemas.items()
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
# Load examples
|
| 1060 |
+
self.schema_to_examples = {}
|
| 1061 |
+
for schema, examples in state.get('examples', {}).items():
|
| 1062 |
+
self.schema_to_examples[schema] = [
|
| 1063 |
+
HighConfidenceExample.from_dict(ex) for ex in examples
|
| 1064 |
+
]
|
| 1065 |
+
|
| 1066 |
+
# Load other state
|
| 1067 |
+
self.verification_queue = [
|
| 1068 |
+
tuple(item) for item in state.get('verification_queue', [])
|
| 1069 |
+
]
|
| 1070 |
+
self.labeled_instance_ids = set(state.get('labeled_instance_ids', []))
|
| 1071 |
+
|
| 1072 |
+
if state.get('last_example_refresh'):
|
| 1073 |
+
self.last_example_refresh = datetime.fromisoformat(
|
| 1074 |
+
state['last_example_refresh']
|
| 1075 |
+
)
|
| 1076 |
+
if state.get('last_batch_run'):
|
| 1077 |
+
self.last_batch_run = datetime.fromisoformat(
|
| 1078 |
+
state['last_batch_run']
|
| 1079 |
+
)
|
| 1080 |
+
|
| 1081 |
+
logger.info(f"Loaded ICL state from {filepath}")
|
| 1082 |
+
|
| 1083 |
+
except Exception as e:
|
| 1084 |
+
logger.error(f"Error loading ICL state: {e}")
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
# Module-level singleton access
|
| 1088 |
+
_icl_labeler: Optional[ICLLabeler] = None
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
def init_icl_labeler(config: Dict[str, Any]) -> ICLLabeler:
|
| 1092 |
+
"""Initialize the global ICL labeler."""
|
| 1093 |
+
global _icl_labeler
|
| 1094 |
+
_icl_labeler = ICLLabeler(config)
|
| 1095 |
+
_icl_labeler.load_state()
|
| 1096 |
+
return _icl_labeler
|
| 1097 |
+
|
| 1098 |
+
|
| 1099 |
+
def get_icl_labeler() -> Optional[ICLLabeler]:
|
| 1100 |
+
"""Get the global ICL labeler instance."""
|
| 1101 |
+
return _icl_labeler
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
def clear_icl_labeler() -> None:
|
| 1105 |
+
"""Clear the global ICL labeler (for testing)."""
|
| 1106 |
+
global _icl_labeler
|
| 1107 |
+
if _icl_labeler:
|
| 1108 |
+
_icl_labeler.stop_background_worker()
|
| 1109 |
+
_icl_labeler = None
|
| 1110 |
+
ICLLabeler._instance = None
|
potato/ai/icl_prompt_builder.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In-Context Learning Prompt Builder
|
| 3 |
+
|
| 4 |
+
This module builds effective prompts for in-context learning based labeling.
|
| 5 |
+
It formats high-confidence examples and target instances into prompts that
|
| 6 |
+
elicit accurate label predictions with confidence scores.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import re
|
| 12 |
+
from typing import Dict, List, Any, Tuple, Optional, TYPE_CHECKING
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from potato.ai.icl_labeler import HighConfidenceExample
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ICLPromptBuilder:
|
| 21 |
+
"""
|
| 22 |
+
Builds effective prompts for in-context learning.
|
| 23 |
+
|
| 24 |
+
The prompt structure:
|
| 25 |
+
1. System instructions explaining the task
|
| 26 |
+
2. Schema description and available labels
|
| 27 |
+
3. High-confidence examples with their labels
|
| 28 |
+
4. Target text to label
|
| 29 |
+
5. Output format instructions (JSON with label, confidence, reasoning)
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, max_example_length: int = 500, max_target_length: int = 1000):
|
| 33 |
+
"""
|
| 34 |
+
Initialize the prompt builder.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
max_example_length: Maximum characters per example text
|
| 38 |
+
max_target_length: Maximum characters for target text
|
| 39 |
+
"""
|
| 40 |
+
self.max_example_length = max_example_length
|
| 41 |
+
self.max_target_length = max_target_length
|
| 42 |
+
|
| 43 |
+
def build_prompt(
|
| 44 |
+
self,
|
| 45 |
+
schema: Dict[str, Any],
|
| 46 |
+
examples: List['HighConfidenceExample'],
|
| 47 |
+
target_text: str
|
| 48 |
+
) -> str:
|
| 49 |
+
"""
|
| 50 |
+
Build a complete ICL prompt.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
schema: Annotation schema dictionary with name, description, labels
|
| 54 |
+
examples: List of high-confidence examples
|
| 55 |
+
target_text: The text to be labeled
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Complete prompt string
|
| 59 |
+
"""
|
| 60 |
+
parts = []
|
| 61 |
+
|
| 62 |
+
# System instructions
|
| 63 |
+
parts.append(self._build_system_prompt(schema))
|
| 64 |
+
|
| 65 |
+
# Examples section
|
| 66 |
+
if examples:
|
| 67 |
+
parts.append("\n## Examples\n")
|
| 68 |
+
parts.append("Here are examples of correctly labeled texts:\n")
|
| 69 |
+
for i, example in enumerate(examples, 1):
|
| 70 |
+
parts.append(self._format_example(example, i))
|
| 71 |
+
|
| 72 |
+
# Target text section
|
| 73 |
+
parts.append("\n## Your Task\n")
|
| 74 |
+
parts.append("Now label the following text:\n")
|
| 75 |
+
parts.append(f'Text: "{self._truncate_text(target_text, self.max_target_length)}"\n')
|
| 76 |
+
|
| 77 |
+
# Output format instructions
|
| 78 |
+
parts.append(self._build_output_instructions(schema))
|
| 79 |
+
|
| 80 |
+
return "\n".join(parts)
|
| 81 |
+
|
| 82 |
+
def _build_system_prompt(self, schema: Dict[str, Any]) -> str:
|
| 83 |
+
"""Build the system/instruction portion of the prompt."""
|
| 84 |
+
schema_name = schema.get('name', 'unknown')
|
| 85 |
+
description = schema.get('description', 'Label the text according to the schema.')
|
| 86 |
+
labels = self._get_labels_from_schema(schema)
|
| 87 |
+
annotation_type = schema.get('annotation_type', 'radio')
|
| 88 |
+
|
| 89 |
+
prompt = f"""You are an expert annotation assistant. Your task is to label text according to a specific annotation schema.
|
| 90 |
+
|
| 91 |
+
## Schema: {schema_name}
|
| 92 |
+
|
| 93 |
+
**Description:** {description}
|
| 94 |
+
|
| 95 |
+
**Available Labels:** {', '.join(labels)}
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
# Add type-specific instructions
|
| 99 |
+
if annotation_type == 'radio':
|
| 100 |
+
prompt += "\n**Task Type:** Single-choice classification. Select exactly ONE label.\n"
|
| 101 |
+
elif annotation_type == 'multiselect':
|
| 102 |
+
prompt += "\n**Task Type:** Multi-label classification. Select ALL applicable labels.\n"
|
| 103 |
+
elif annotation_type == 'likert':
|
| 104 |
+
prompt += "\n**Task Type:** Rating scale. Choose the most appropriate rating.\n"
|
| 105 |
+
|
| 106 |
+
return prompt
|
| 107 |
+
|
| 108 |
+
def _format_example(self, example: 'HighConfidenceExample', index: int) -> str:
|
| 109 |
+
"""Format a single example for the prompt."""
|
| 110 |
+
truncated_text = self._truncate_text(example.text, self.max_example_length)
|
| 111 |
+
|
| 112 |
+
return f"""
|
| 113 |
+
### Example {index}
|
| 114 |
+
Text: "{truncated_text}"
|
| 115 |
+
Label: **{example.label}**
|
| 116 |
+
(Agreement: {example.agreement_score:.0%} from {example.annotator_count} annotators)
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
def _build_output_instructions(self, schema: Dict[str, Any]) -> str:
|
| 120 |
+
"""Build instructions for the expected output format."""
|
| 121 |
+
labels = self._get_labels_from_schema(schema)
|
| 122 |
+
labels_json = json.dumps(labels)
|
| 123 |
+
|
| 124 |
+
return f"""
|
| 125 |
+
## Output Format
|
| 126 |
+
|
| 127 |
+
Respond with a JSON object containing:
|
| 128 |
+
- `label`: Your chosen label (must be one of: {labels_json})
|
| 129 |
+
- `confidence`: Your confidence score from 0.0 to 1.0
|
| 130 |
+
- 1.0 = Absolutely certain
|
| 131 |
+
- 0.7-0.9 = High confidence
|
| 132 |
+
- 0.5-0.7 = Moderate confidence
|
| 133 |
+
- 0.3-0.5 = Low confidence
|
| 134 |
+
- 0.0-0.3 = Very uncertain
|
| 135 |
+
- `reasoning`: Brief explanation for your choice (1-2 sentences)
|
| 136 |
+
|
| 137 |
+
**Important:**
|
| 138 |
+
- Only use labels from the provided list
|
| 139 |
+
- Be honest about your confidence - reflect your actual certainty, not a fixed value
|
| 140 |
+
- Base your decision on the examples and schema description
|
| 141 |
+
- Use the full 0.0–1.0 range: reserve 0.9+ for near-certain cases, use 0.4–0.6 when genuinely unsure
|
| 142 |
+
|
| 143 |
+
Example response (the confidence value here is illustrative only — yours should reflect actual certainty):
|
| 144 |
+
```json
|
| 145 |
+
{{"label": "example_label", "confidence": 0.72, "reasoning": "The text shows clear indicators of..."}}
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Now provide your response as JSON:
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
def _get_labels_from_schema(self, schema: Dict[str, Any]) -> List[str]:
|
| 152 |
+
"""Extract label names from schema definition."""
|
| 153 |
+
labels = schema.get('labels', [])
|
| 154 |
+
result = []
|
| 155 |
+
for label in labels:
|
| 156 |
+
if isinstance(label, str):
|
| 157 |
+
result.append(label)
|
| 158 |
+
elif isinstance(label, dict):
|
| 159 |
+
result.append(label.get('name', str(label)))
|
| 160 |
+
return result
|
| 161 |
+
|
| 162 |
+
def _truncate_text(self, text: str, max_length: int) -> str:
|
| 163 |
+
"""Truncate text to max length, preserving word boundaries."""
|
| 164 |
+
if len(text) <= max_length:
|
| 165 |
+
return text
|
| 166 |
+
|
| 167 |
+
truncated = text[:max_length]
|
| 168 |
+
# Try to break at word boundary
|
| 169 |
+
last_space = truncated.rfind(' ')
|
| 170 |
+
if last_space > max_length * 0.8:
|
| 171 |
+
truncated = truncated[:last_space]
|
| 172 |
+
|
| 173 |
+
return truncated + "..."
|
| 174 |
+
|
| 175 |
+
def parse_response(
|
| 176 |
+
self,
|
| 177 |
+
response: str,
|
| 178 |
+
schema: Dict[str, Any]
|
| 179 |
+
) -> Tuple[Optional[str], float, str]:
|
| 180 |
+
"""
|
| 181 |
+
Parse the LLM response to extract label, confidence, and reasoning.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
response: Raw response from LLM
|
| 185 |
+
schema: Schema for validation
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Tuple of (label, confidence, reasoning) or (None, 0.0, "") on failure
|
| 189 |
+
"""
|
| 190 |
+
try:
|
| 191 |
+
# Try to parse as JSON directly
|
| 192 |
+
data = self._extract_json(response)
|
| 193 |
+
if data:
|
| 194 |
+
label = data.get('label', '')
|
| 195 |
+
confidence = float(data.get('confidence', 0.5))
|
| 196 |
+
reasoning = data.get('reasoning', '')
|
| 197 |
+
|
| 198 |
+
# Validate label
|
| 199 |
+
valid_labels = self._get_labels_from_schema(schema)
|
| 200 |
+
if label in valid_labels:
|
| 201 |
+
return label, min(1.0, max(0.0, confidence)), reasoning
|
| 202 |
+
|
| 203 |
+
# Try fuzzy matching
|
| 204 |
+
matched = self._fuzzy_match_label(label, valid_labels)
|
| 205 |
+
if matched:
|
| 206 |
+
return matched, min(1.0, max(0.0, confidence)), reasoning
|
| 207 |
+
|
| 208 |
+
# Fallback: try to extract label from text
|
| 209 |
+
return self._extract_label_from_text(response, schema)
|
| 210 |
+
|
| 211 |
+
except Exception as e:
|
| 212 |
+
logger.warning(f"Error parsing response: {e}")
|
| 213 |
+
return None, 0.0, ""
|
| 214 |
+
|
| 215 |
+
def _extract_json(self, text: str) -> Optional[Dict[str, Any]]:
|
| 216 |
+
"""Extract JSON from text, handling markdown code blocks."""
|
| 217 |
+
# Try direct parse
|
| 218 |
+
try:
|
| 219 |
+
return json.loads(text)
|
| 220 |
+
except json.JSONDecodeError:
|
| 221 |
+
pass
|
| 222 |
+
|
| 223 |
+
# Try to find JSON in code blocks
|
| 224 |
+
json_patterns = [
|
| 225 |
+
r'```json\s*(.*?)\s*```',
|
| 226 |
+
r'```\s*(.*?)\s*```',
|
| 227 |
+
r'\{[^{}]*"label"[^{}]*\}'
|
| 228 |
+
]
|
| 229 |
+
|
| 230 |
+
for pattern in json_patterns:
|
| 231 |
+
match = re.search(pattern, text, re.DOTALL)
|
| 232 |
+
if match:
|
| 233 |
+
try:
|
| 234 |
+
json_str = match.group(1) if match.lastindex else match.group(0)
|
| 235 |
+
return json.loads(json_str)
|
| 236 |
+
except (json.JSONDecodeError, IndexError):
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
def _fuzzy_match_label(self, label: str, valid_labels: List[str]) -> Optional[str]:
|
| 242 |
+
"""Try to match label with case-insensitive comparison."""
|
| 243 |
+
label_lower = label.lower().strip()
|
| 244 |
+
for valid in valid_labels:
|
| 245 |
+
if valid.lower().strip() == label_lower:
|
| 246 |
+
return valid
|
| 247 |
+
return None
|
| 248 |
+
|
| 249 |
+
def _extract_label_from_text(
|
| 250 |
+
self,
|
| 251 |
+
text: str,
|
| 252 |
+
schema: Dict[str, Any]
|
| 253 |
+
) -> Tuple[Optional[str], float, str]:
|
| 254 |
+
"""Fallback: try to extract label directly from text."""
|
| 255 |
+
valid_labels = self._get_labels_from_schema(schema)
|
| 256 |
+
text_lower = text.lower()
|
| 257 |
+
|
| 258 |
+
for label in valid_labels:
|
| 259 |
+
# Look for label mentioned in text
|
| 260 |
+
if label.lower() in text_lower:
|
| 261 |
+
return label, 0.5, "Extracted from response text (low confidence)"
|
| 262 |
+
|
| 263 |
+
return None, 0.0, ""
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class MultiSelectPromptBuilder(ICLPromptBuilder):
|
| 267 |
+
"""
|
| 268 |
+
Specialized prompt builder for multi-select (multi-label) tasks.
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
def _build_output_instructions(self, schema: Dict[str, Any]) -> str:
|
| 272 |
+
"""Build output instructions for multi-select."""
|
| 273 |
+
labels = self._get_labels_from_schema(schema)
|
| 274 |
+
labels_json = json.dumps(labels)
|
| 275 |
+
|
| 276 |
+
return f"""
|
| 277 |
+
## Output Format
|
| 278 |
+
|
| 279 |
+
Respond with a JSON object containing:
|
| 280 |
+
- `labels`: Array of selected labels (from: {labels_json})
|
| 281 |
+
- `confidence`: Your overall confidence score from 0.0 to 1.0 — reflect actual certainty, not a fixed value
|
| 282 |
+
- `reasoning`: Brief explanation for your choices
|
| 283 |
+
|
| 284 |
+
Example response (the confidence value here is illustrative only — yours should reflect actual certainty):
|
| 285 |
+
```json
|
| 286 |
+
{{"labels": ["label1", "label2"], "confidence": 0.65, "reasoning": "The text exhibits both..."}}
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
Now provide your response as JSON:
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
def parse_response(
|
| 293 |
+
self,
|
| 294 |
+
response: str,
|
| 295 |
+
schema: Dict[str, Any]
|
| 296 |
+
) -> Tuple[Optional[List[str]], float, str]:
|
| 297 |
+
"""Parse multi-select response."""
|
| 298 |
+
try:
|
| 299 |
+
data = self._extract_json(response)
|
| 300 |
+
if data:
|
| 301 |
+
labels = data.get('labels', [])
|
| 302 |
+
confidence = float(data.get('confidence', 0.5))
|
| 303 |
+
reasoning = data.get('reasoning', '')
|
| 304 |
+
|
| 305 |
+
valid_labels = self._get_labels_from_schema(schema)
|
| 306 |
+
validated = [l for l in labels if l in valid_labels]
|
| 307 |
+
|
| 308 |
+
if validated:
|
| 309 |
+
return validated, min(1.0, max(0.0, confidence)), reasoning
|
| 310 |
+
|
| 311 |
+
return None, 0.0, ""
|
| 312 |
+
|
| 313 |
+
except Exception as e:
|
| 314 |
+
logger.warning(f"Error parsing multi-select response: {e}")
|
| 315 |
+
return None, 0.0, ""
|
potato/ai/judge.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM-as-Judge service for human-alignment.
|
| 3 |
+
|
| 4 |
+
Produces a judge verdict (label + confidence + reasoning) for an annotation
|
| 5 |
+
instance, given the schema (labels + description + an editable rubric) and,
|
| 6 |
+
optionally, few-shot examples drawn from high-agreement human labels. The
|
| 7 |
+
verdicts are compared against human labels elsewhere
|
| 8 |
+
(``potato/server_utils/judge_alignment.py``) to measure and calibrate
|
| 9 |
+
human↔judge agreement (Cohen's κ).
|
| 10 |
+
|
| 11 |
+
This deliberately does NOT reuse ``ICLLabeler`` as the judge — ICL auto-labels
|
| 12 |
+
from inter-annotator agreement, which would leak the gold labels we are trying
|
| 13 |
+
to measure the judge against. We only borrow ICLLabeler's *example selection*
|
| 14 |
+
for few-shot calibration, and we always exclude the instance being judged from
|
| 15 |
+
its own example set.
|
| 16 |
+
|
| 17 |
+
The judge call goes through the same ``AIEndpointFactory`` / ``BaseAIEndpoint``
|
| 18 |
+
machinery as every other AI feature (mirrors ``icl_labeler.label_instance``),
|
| 19 |
+
so it works with any configured provider.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import hashlib
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from typing import Any, Dict, List, Optional
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class JudgePrediction:
|
| 33 |
+
"""A single LLM-judge verdict for one instance + schema."""
|
| 34 |
+
|
| 35 |
+
instance_id: str
|
| 36 |
+
schema_name: str
|
| 37 |
+
predicted_label: str
|
| 38 |
+
confidence: float # 0.0–1.0
|
| 39 |
+
reasoning: str = ""
|
| 40 |
+
model_name: str = ""
|
| 41 |
+
prompt_version: str = ""
|
| 42 |
+
examples_used: List[str] = field(default_factory=list)
|
| 43 |
+
|
| 44 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 45 |
+
return {
|
| 46 |
+
"instance_id": self.instance_id,
|
| 47 |
+
"schema_name": self.schema_name,
|
| 48 |
+
"predicted_label": self.predicted_label,
|
| 49 |
+
"confidence": self.confidence,
|
| 50 |
+
"reasoning": self.reasoning,
|
| 51 |
+
"model_name": self.model_name,
|
| 52 |
+
"prompt_version": self.prompt_version,
|
| 53 |
+
"examples_used": self.examples_used,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def from_dict(cls, data: Dict[str, Any]) -> "JudgePrediction":
|
| 58 |
+
return cls(
|
| 59 |
+
instance_id=data["instance_id"],
|
| 60 |
+
schema_name=data["schema_name"],
|
| 61 |
+
predicted_label=data.get("predicted_label", ""),
|
| 62 |
+
confidence=float(data.get("confidence", 0.0)),
|
| 63 |
+
reasoning=data.get("reasoning", ""),
|
| 64 |
+
model_name=data.get("model_name", ""),
|
| 65 |
+
prompt_version=data.get("prompt_version", ""),
|
| 66 |
+
examples_used=data.get("examples_used", []),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def extract_labels(schema_info: Dict[str, Any]) -> List[str]:
|
| 71 |
+
"""Return the allowed label names for a categorical schema.
|
| 72 |
+
|
| 73 |
+
Supports ``radio``/``select``/``multiselect`` (``labels`` list of
|
| 74 |
+
str|dict) and ``likert`` (1..size). Returns ``[]`` for unsupported types.
|
| 75 |
+
"""
|
| 76 |
+
atype = schema_info.get("annotation_type", "")
|
| 77 |
+
if atype == "likert":
|
| 78 |
+
size = int(schema_info.get("size", 5))
|
| 79 |
+
return [str(i) for i in range(1, size + 1)]
|
| 80 |
+
labels = schema_info.get("labels", [])
|
| 81 |
+
out = []
|
| 82 |
+
for lab in labels:
|
| 83 |
+
if isinstance(lab, dict):
|
| 84 |
+
out.append(str(lab.get("name", "")))
|
| 85 |
+
else:
|
| 86 |
+
out.append(str(lab))
|
| 87 |
+
return [x for x in out if x]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def compute_prompt_version(rubric: str, schema_name: str, few_shot: bool) -> str:
|
| 91 |
+
"""Stable short hash identifying this judge configuration.
|
| 92 |
+
|
| 93 |
+
Editing the rubric (or toggling few-shot) yields a new version so the admin
|
| 94 |
+
report can track κ across prompt versions.
|
| 95 |
+
"""
|
| 96 |
+
basis = f"{schema_name}␟{int(bool(few_shot))}␟{rubric or ''}"
|
| 97 |
+
return "v_" + hashlib.sha1(basis.encode("utf-8")).hexdigest()[:10]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class JudgeService:
|
| 101 |
+
"""Builds judge prompts and queries the configured AI endpoint."""
|
| 102 |
+
|
| 103 |
+
def __init__(self, config: Dict[str, Any]):
|
| 104 |
+
self.config = config or {}
|
| 105 |
+
self.judge_config = self.config.get("judge_alignment", {}) or {}
|
| 106 |
+
self._endpoint = None
|
| 107 |
+
self._endpoint_initialized = False
|
| 108 |
+
|
| 109 |
+
# ----- endpoint -------------------------------------------------------
|
| 110 |
+
|
| 111 |
+
def _get_endpoint(self):
|
| 112 |
+
if not self._endpoint_initialized:
|
| 113 |
+
self._endpoint_initialized = True
|
| 114 |
+
try:
|
| 115 |
+
from potato.ai.ai_endpoint import AIEndpointFactory
|
| 116 |
+
# The judge endpoint config lives under judge_alignment, but
|
| 117 |
+
# fall back to the task's ai_support so a single endpoint can
|
| 118 |
+
# serve both. Shape mirrors what AIEndpointFactory expects.
|
| 119 |
+
ai_support = self.judge_config.get("ai_support") or self.config.get("ai_support")
|
| 120 |
+
if not ai_support:
|
| 121 |
+
logger.warning("Judge: no ai_support / judge_alignment.ai_support configured")
|
| 122 |
+
return None
|
| 123 |
+
self._endpoint = AIEndpointFactory.create_endpoint({"ai_support": ai_support})
|
| 124 |
+
except Exception as e:
|
| 125 |
+
logger.error(f"Judge: failed to create endpoint: {e}")
|
| 126 |
+
self._endpoint = None
|
| 127 |
+
return self._endpoint
|
| 128 |
+
|
| 129 |
+
# ----- prompt ---------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
def _schema_judge_config(self, schema_name: str) -> Dict[str, Any]:
|
| 132 |
+
per_schema = self.judge_config.get("schemas", {}) or {}
|
| 133 |
+
return per_schema.get(schema_name, {}) or {}
|
| 134 |
+
|
| 135 |
+
def get_rubric(self, schema_info: Dict[str, Any]) -> str:
|
| 136 |
+
"""Editable rubric for a schema; falls back to its description."""
|
| 137 |
+
sc = self._schema_judge_config(schema_info.get("name", ""))
|
| 138 |
+
return sc.get("rubric") or schema_info.get("description", "") or ""
|
| 139 |
+
|
| 140 |
+
def build_prompt(
|
| 141 |
+
self,
|
| 142 |
+
schema_info: Dict[str, Any],
|
| 143 |
+
instance_text: str,
|
| 144 |
+
few_shot_examples: Optional[List[Dict[str, str]]] = None,
|
| 145 |
+
) -> str:
|
| 146 |
+
"""Compose the judge prompt.
|
| 147 |
+
|
| 148 |
+
few_shot_examples: list of {"text": ..., "label": ...} gold exemplars
|
| 149 |
+
(already excluding the target instance).
|
| 150 |
+
"""
|
| 151 |
+
labels = extract_labels(schema_info)
|
| 152 |
+
rubric = self.get_rubric(schema_info)
|
| 153 |
+
parts = [
|
| 154 |
+
"You are an expert evaluator acting as an impartial judge.",
|
| 155 |
+
"Assign exactly one label to the item below, following the rubric.",
|
| 156 |
+
"",
|
| 157 |
+
f"Task: {schema_info.get('description', '')}".rstrip(),
|
| 158 |
+
f"Rubric: {rubric}".rstrip(),
|
| 159 |
+
"Allowed labels: " + ", ".join(labels) if labels else "",
|
| 160 |
+
]
|
| 161 |
+
if few_shot_examples:
|
| 162 |
+
parts.append("\nExamples (item → correct label):")
|
| 163 |
+
for ex in few_shot_examples:
|
| 164 |
+
parts.append(f"- {_truncate(ex.get('text', ''))} → {ex.get('label', '')}")
|
| 165 |
+
parts.append("\nItem to judge:")
|
| 166 |
+
parts.append(_truncate(instance_text, 4000))
|
| 167 |
+
parts.append(
|
| 168 |
+
'\nRespond as JSON: {"label": <one of the allowed labels>, '
|
| 169 |
+
'"confidence": <0.0-1.0>, "reasoning": <one sentence>}.'
|
| 170 |
+
)
|
| 171 |
+
return "\n".join(p for p in parts if p != "")
|
| 172 |
+
|
| 173 |
+
# ----- judging --------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def judge_instance(
|
| 176 |
+
self,
|
| 177 |
+
instance_id: str,
|
| 178 |
+
schema_info: Dict[str, Any],
|
| 179 |
+
instance_text: str,
|
| 180 |
+
few_shot_examples: Optional[List[Dict[str, str]]] = None,
|
| 181 |
+
prompt_version: Optional[str] = None,
|
| 182 |
+
) -> Optional[JudgePrediction]:
|
| 183 |
+
"""Query the judge for one instance. Returns None on failure."""
|
| 184 |
+
endpoint = self._get_endpoint()
|
| 185 |
+
if endpoint is None:
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
schema_name = schema_info.get("name", "")
|
| 189 |
+
valid_labels = extract_labels(schema_info)
|
| 190 |
+
rubric = self.get_rubric(schema_info)
|
| 191 |
+
if prompt_version is None:
|
| 192 |
+
prompt_version = compute_prompt_version(
|
| 193 |
+
rubric, schema_name, bool(few_shot_examples)
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
prompt = self.build_prompt(schema_info, instance_text, few_shot_examples)
|
| 197 |
+
|
| 198 |
+
try:
|
| 199 |
+
from pydantic import BaseModel
|
| 200 |
+
|
| 201 |
+
class JudgeVerdict(BaseModel):
|
| 202 |
+
label: str
|
| 203 |
+
confidence: float = 0.5
|
| 204 |
+
reasoning: str = ""
|
| 205 |
+
|
| 206 |
+
response = endpoint.query(prompt, JudgeVerdict)
|
| 207 |
+
if isinstance(response, str):
|
| 208 |
+
data = json.loads(response)
|
| 209 |
+
elif hasattr(response, "model_dump"):
|
| 210 |
+
data = response.model_dump()
|
| 211 |
+
elif hasattr(response, "dict"):
|
| 212 |
+
data = response.dict()
|
| 213 |
+
else:
|
| 214 |
+
data = response or {}
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.error(f"Judge: query/parse failed for {instance_id}/{schema_name}: {e}")
|
| 217 |
+
return None
|
| 218 |
+
|
| 219 |
+
predicted = str(data.get("label", "")).strip()
|
| 220 |
+
try:
|
| 221 |
+
confidence = float(data.get("confidence", 0.5))
|
| 222 |
+
except (TypeError, ValueError):
|
| 223 |
+
confidence = 0.5
|
| 224 |
+
confidence = min(1.0, max(0.0, confidence))
|
| 225 |
+
reasoning = str(data.get("reasoning", ""))
|
| 226 |
+
|
| 227 |
+
if valid_labels and predicted not in valid_labels:
|
| 228 |
+
matched = _fuzzy_match_label(predicted, valid_labels)
|
| 229 |
+
if matched is None:
|
| 230 |
+
logger.warning(
|
| 231 |
+
f"Judge: invalid label '{predicted}' for {instance_id}/{schema_name}"
|
| 232 |
+
)
|
| 233 |
+
return None
|
| 234 |
+
predicted = matched
|
| 235 |
+
|
| 236 |
+
return JudgePrediction(
|
| 237 |
+
instance_id=instance_id,
|
| 238 |
+
schema_name=schema_name,
|
| 239 |
+
predicted_label=predicted,
|
| 240 |
+
confidence=confidence,
|
| 241 |
+
reasoning=reasoning,
|
| 242 |
+
model_name=getattr(endpoint, "model", ""),
|
| 243 |
+
prompt_version=prompt_version,
|
| 244 |
+
examples_used=[e.get("id", "") for e in (few_shot_examples or []) if e.get("id")],
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _truncate(text: str, limit: int = 300) -> str:
|
| 249 |
+
text = str(text or "")
|
| 250 |
+
return text if len(text) <= limit else text[:limit] + "…"
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _fuzzy_match_label(predicted: str, valid_labels: List[str]) -> Optional[str]:
|
| 254 |
+
"""Case-insensitive / prefix match a model label to an allowed label."""
|
| 255 |
+
if not predicted:
|
| 256 |
+
return None
|
| 257 |
+
low = predicted.lower().strip()
|
| 258 |
+
for lab in valid_labels:
|
| 259 |
+
if lab.lower() == low:
|
| 260 |
+
return lab
|
| 261 |
+
for lab in valid_labels:
|
| 262 |
+
ll = lab.lower()
|
| 263 |
+
if ll in low or low in ll:
|
| 264 |
+
return lab
|
| 265 |
+
return None
|
potato/ai/llm_active_learning.py
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Integration for Active Learning
|
| 3 |
+
|
| 4 |
+
This module provides LLM-based active learning capabilities using VLLM endpoints.
|
| 5 |
+
It implements confidence-based instance selection and prediction using large language
|
| 6 |
+
models, with support for multiple confidence elicitation methods:
|
| 7 |
+
|
| 8 |
+
- **logprobs**: Extract token-level log probabilities from VLLM/OpenAI-compatible
|
| 9 |
+
endpoints for calibrated confidence scores.
|
| 10 |
+
- **verbalized**: Ask the LLM to self-report confidence on a 1-10 scale (default).
|
| 11 |
+
- **consistency**: Query the same instance N times with temperature > 0 and use
|
| 12 |
+
agreement rate as confidence (works with any endpoint).
|
| 13 |
+
|
| 14 |
+
References:
|
| 15 |
+
Tian et al. (2023) "Just Ask for Calibration: Strategies for Eliciting
|
| 16 |
+
Calibrated Confidence Scores from Language Models Fine-Tuned with Human
|
| 17 |
+
Feedback." EMNLP 2023.
|
| 18 |
+
|
| 19 |
+
Xiong et al. (2024) "Can LLMs Express Their Uncertainty? An Empirical
|
| 20 |
+
Evaluation of Confidence Elicitation in LLMs." ICLR 2024.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
import math
|
| 25 |
+
import time
|
| 26 |
+
import json
|
| 27 |
+
import requests
|
| 28 |
+
from collections import Counter
|
| 29 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 30 |
+
from dataclasses import dataclass, field
|
| 31 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 32 |
+
import numpy as np
|
| 33 |
+
|
| 34 |
+
from potato.active_learning_manager import TrainingMetrics
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _loads_lenient(content: str):
|
| 38 |
+
"""json.loads that tolerates markdown code fences and surrounding prose.
|
| 39 |
+
|
| 40 |
+
Many models (e.g. Gemma on vLLM) wrap JSON in ```json ... ``` fences even
|
| 41 |
+
when response_format=json_object is requested, which breaks a naive
|
| 42 |
+
json.loads(). Strip fences and, failing that, extract the first {...} block.
|
| 43 |
+
"""
|
| 44 |
+
import re
|
| 45 |
+
if content is None:
|
| 46 |
+
raise json.JSONDecodeError("empty content", "", 0)
|
| 47 |
+
s = content.strip()
|
| 48 |
+
# Strip a leading ```json / ``` fence and trailing ```
|
| 49 |
+
s = re.sub(r"^```(?:json|JSON)?\s*", "", s)
|
| 50 |
+
s = re.sub(r"\s*```$", "", s).strip()
|
| 51 |
+
try:
|
| 52 |
+
return json.loads(s)
|
| 53 |
+
except json.JSONDecodeError:
|
| 54 |
+
# Fall back to the first balanced-looking {...} object in the text.
|
| 55 |
+
m = re.search(r"\{.*\}", s, re.DOTALL)
|
| 56 |
+
if m:
|
| 57 |
+
return json.loads(m.group(0))
|
| 58 |
+
raise
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
|
| 62 |
+
class LLMPrediction:
|
| 63 |
+
"""Result of an LLM prediction."""
|
| 64 |
+
instance_id: str
|
| 65 |
+
predicted_label: str
|
| 66 |
+
confidence_score: float
|
| 67 |
+
raw_response: str
|
| 68 |
+
error_message: Optional[str] = None
|
| 69 |
+
confidence_method: str = "verbalized"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
|
| 73 |
+
class LLMConfig:
|
| 74 |
+
"""Configuration for LLM integration."""
|
| 75 |
+
endpoint_url: str
|
| 76 |
+
model_name: str
|
| 77 |
+
max_tokens: int = 512
|
| 78 |
+
temperature: float = 0.1
|
| 79 |
+
timeout: int = 30
|
| 80 |
+
batch_size: int = 10
|
| 81 |
+
retry_attempts: int = 3
|
| 82 |
+
retry_delay: float = 1.0
|
| 83 |
+
max_instances_per_request: int = 5
|
| 84 |
+
confidence_method: str = "verbalized" # logprobs | verbalized | consistency
|
| 85 |
+
consistency_samples: int = 3
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class LLMActiveLearning:
|
| 89 |
+
"""
|
| 90 |
+
LLM-based active learning implementation.
|
| 91 |
+
|
| 92 |
+
This class provides methods for:
|
| 93 |
+
- Querying LLMs for predictions and confidence scores
|
| 94 |
+
- Batch processing of instances
|
| 95 |
+
- Error handling and retry logic
|
| 96 |
+
- Integration with the active learning pipeline
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
def __init__(self, config: LLMConfig):
|
| 100 |
+
self.config = config
|
| 101 |
+
self.logger = logging.getLogger(__name__)
|
| 102 |
+
self.session = requests.Session()
|
| 103 |
+
|
| 104 |
+
# Configure session
|
| 105 |
+
self.session.timeout = config.timeout
|
| 106 |
+
|
| 107 |
+
# Test connection on initialization
|
| 108 |
+
self._test_connection()
|
| 109 |
+
|
| 110 |
+
def _test_connection(self):
|
| 111 |
+
"""Test the connection to the LLM endpoint."""
|
| 112 |
+
try:
|
| 113 |
+
test_payload = {
|
| 114 |
+
"model": self.config.model_name,
|
| 115 |
+
"messages": [{"role": "user", "content": "Hello"}],
|
| 116 |
+
"max_tokens": 10,
|
| 117 |
+
"temperature": 0.1
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
response = self.session.post(
|
| 121 |
+
self.config.endpoint_url,
|
| 122 |
+
json=test_payload,
|
| 123 |
+
timeout=5
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
if response.status_code == 200:
|
| 127 |
+
self.logger.info(f"Successfully connected to LLM endpoint: {self.config.endpoint_url}")
|
| 128 |
+
else:
|
| 129 |
+
self.logger.warning(f"LLM endpoint returned status {response.status_code}: {response.text}")
|
| 130 |
+
|
| 131 |
+
except Exception as e:
|
| 132 |
+
self.logger.error(f"Failed to connect to LLM endpoint: {e}")
|
| 133 |
+
# Don't raise - allow fallback to traditional methods
|
| 134 |
+
|
| 135 |
+
def predict_instances(self, instances: List[Dict[str, Any]],
|
| 136 |
+
annotation_instructions: str,
|
| 137 |
+
schema_name: str,
|
| 138 |
+
label_options: List[str]) -> List[LLMPrediction]:
|
| 139 |
+
"""
|
| 140 |
+
Predict labels and confidence scores for instances using LLM.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
instances: List of instances to predict
|
| 144 |
+
annotation_instructions: Instructions for the annotation task
|
| 145 |
+
schema_name: Name of the annotation schema
|
| 146 |
+
label_options: Available label options
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
List of LLM predictions with confidence scores
|
| 150 |
+
"""
|
| 151 |
+
if not instances:
|
| 152 |
+
return []
|
| 153 |
+
|
| 154 |
+
self.logger.info(f"Starting LLM prediction for {len(instances)} instances")
|
| 155 |
+
|
| 156 |
+
# Create prompts for each instance
|
| 157 |
+
prompts = self._create_prompts(instances, annotation_instructions, schema_name, label_options)
|
| 158 |
+
|
| 159 |
+
# Process in batches
|
| 160 |
+
all_predictions = []
|
| 161 |
+
|
| 162 |
+
for i in range(0, len(prompts), self.config.batch_size):
|
| 163 |
+
batch_prompts = prompts[i:i + self.config.batch_size]
|
| 164 |
+
batch_instances = instances[i:i + self.config.batch_size]
|
| 165 |
+
|
| 166 |
+
batch_predictions = self._process_batch(batch_prompts, batch_instances)
|
| 167 |
+
all_predictions.extend(batch_predictions)
|
| 168 |
+
|
| 169 |
+
# Small delay between batches to avoid overwhelming the endpoint
|
| 170 |
+
if i + self.config.batch_size < len(prompts):
|
| 171 |
+
time.sleep(0.1)
|
| 172 |
+
|
| 173 |
+
self.logger.info(f"Completed LLM prediction for {len(all_predictions)} instances")
|
| 174 |
+
return all_predictions
|
| 175 |
+
|
| 176 |
+
def _create_prompts(self, instances: List[Dict[str, Any]],
|
| 177 |
+
annotation_instructions: str,
|
| 178 |
+
schema_name: str,
|
| 179 |
+
label_options: List[str]) -> List[str]:
|
| 180 |
+
"""Create prompts for LLM prediction."""
|
| 181 |
+
prompts = []
|
| 182 |
+
|
| 183 |
+
# Create the base prompt template
|
| 184 |
+
base_prompt = self._create_base_prompt(annotation_instructions, schema_name, label_options)
|
| 185 |
+
|
| 186 |
+
for instance in instances:
|
| 187 |
+
# Extract text content
|
| 188 |
+
text_content = self._extract_text_content(instance)
|
| 189 |
+
|
| 190 |
+
# Create instance-specific prompt
|
| 191 |
+
prompt = f"{base_prompt}\n\nText to annotate:\n{text_content}\n\nPlease provide your prediction and confidence score."
|
| 192 |
+
|
| 193 |
+
prompts.append(prompt)
|
| 194 |
+
|
| 195 |
+
return prompts
|
| 196 |
+
|
| 197 |
+
def _create_base_prompt(self, annotation_instructions: str,
|
| 198 |
+
schema_name: str,
|
| 199 |
+
label_options: List[str]) -> str:
|
| 200 |
+
"""Create the base prompt for LLM prediction."""
|
| 201 |
+
prompt = f"""You are an expert annotator for a text classification task.
|
| 202 |
+
|
| 203 |
+
Task: {annotation_instructions}
|
| 204 |
+
|
| 205 |
+
Schema: {schema_name}
|
| 206 |
+
|
| 207 |
+
Available labels: {', '.join(label_options)}
|
| 208 |
+
|
| 209 |
+
For each text, please:
|
| 210 |
+
1. Analyze the text carefully
|
| 211 |
+
2. Choose the most appropriate label from the available options
|
| 212 |
+
3. Provide a confidence score from 1 to 10 (where 1 = very uncertain, 10 = very confident)
|
| 213 |
+
|
| 214 |
+
Please respond in the following JSON format:
|
| 215 |
+
{{
|
| 216 |
+
"label": "chosen_label",
|
| 217 |
+
"confidence": confidence_score,
|
| 218 |
+
"reasoning": "brief explanation of your choice"
|
| 219 |
+
}}
|
| 220 |
+
|
| 221 |
+
Example response:
|
| 222 |
+
{{
|
| 223 |
+
"label": "{label_options[0] if label_options else 'example'}",
|
| 224 |
+
"confidence": 8,
|
| 225 |
+
"reasoning": "The text clearly expresses positive sentiment based on the language used."
|
| 226 |
+
}}"""
|
| 227 |
+
|
| 228 |
+
return prompt
|
| 229 |
+
|
| 230 |
+
def _extract_text_content(self, instance: Dict[str, Any]) -> str:
|
| 231 |
+
"""Extract text content from an instance."""
|
| 232 |
+
# Try common text field names
|
| 233 |
+
text_fields = ['text', 'content', 'message', 'sentence', 'document']
|
| 234 |
+
|
| 235 |
+
for field in text_fields:
|
| 236 |
+
if field in instance:
|
| 237 |
+
content = instance[field]
|
| 238 |
+
if isinstance(content, str):
|
| 239 |
+
return content
|
| 240 |
+
elif isinstance(content, dict):
|
| 241 |
+
# Handle nested text fields
|
| 242 |
+
for nested_field in text_fields:
|
| 243 |
+
if nested_field in content:
|
| 244 |
+
return str(content[nested_field])
|
| 245 |
+
|
| 246 |
+
# Fallback: convert the entire instance to string
|
| 247 |
+
return str(instance)
|
| 248 |
+
|
| 249 |
+
def _process_batch(self, prompts: List[str], instances: List[Dict[str, Any]]) -> List[LLMPrediction]:
|
| 250 |
+
"""Process a batch of prompts."""
|
| 251 |
+
predictions = []
|
| 252 |
+
|
| 253 |
+
# Use ThreadPoolExecutor for parallel processing within the batch
|
| 254 |
+
with ThreadPoolExecutor(max_workers=min(len(prompts), 5)) as executor:
|
| 255 |
+
future_to_index = {
|
| 256 |
+
executor.submit(self._predict_single, prompt, instances[i]): i
|
| 257 |
+
for i, prompt in enumerate(prompts)
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
for future in as_completed(future_to_index):
|
| 261 |
+
index = future_to_index[future]
|
| 262 |
+
try:
|
| 263 |
+
prediction = future.result()
|
| 264 |
+
predictions.append(prediction)
|
| 265 |
+
except Exception as e:
|
| 266 |
+
self.logger.error(f"Error processing instance {index}: {e}")
|
| 267 |
+
# Create error prediction
|
| 268 |
+
error_prediction = LLMPrediction(
|
| 269 |
+
instance_id=instances[index].get('id', f'instance_{index}'),
|
| 270 |
+
predicted_label='',
|
| 271 |
+
confidence_score=0.1,
|
| 272 |
+
raw_response='',
|
| 273 |
+
error_message=str(e)
|
| 274 |
+
)
|
| 275 |
+
predictions.append(error_prediction)
|
| 276 |
+
|
| 277 |
+
return predictions
|
| 278 |
+
|
| 279 |
+
def _predict_single(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
|
| 280 |
+
"""Make a single prediction using the LLM.
|
| 281 |
+
|
| 282 |
+
Dispatches to the appropriate confidence method:
|
| 283 |
+
- logprobs: Extract token-level log probabilities
|
| 284 |
+
- consistency: Query N times, use agreement rate
|
| 285 |
+
- verbalized (default): Parse self-reported confidence from JSON
|
| 286 |
+
"""
|
| 287 |
+
method = self.config.confidence_method
|
| 288 |
+
|
| 289 |
+
if method == "consistency":
|
| 290 |
+
return self._predict_consistency(prompt, instance)
|
| 291 |
+
elif method == "logprobs":
|
| 292 |
+
return self._predict_with_logprobs(prompt, instance)
|
| 293 |
+
else:
|
| 294 |
+
return self._predict_verbalized(prompt, instance)
|
| 295 |
+
|
| 296 |
+
def _predict_verbalized(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
|
| 297 |
+
"""Original verbalized confidence method (1-10 scale)."""
|
| 298 |
+
instance_id = instance.get('id', 'unknown')
|
| 299 |
+
|
| 300 |
+
for attempt in range(self.config.retry_attempts):
|
| 301 |
+
try:
|
| 302 |
+
payload = {
|
| 303 |
+
"model": self.config.model_name,
|
| 304 |
+
"messages": [
|
| 305 |
+
{"role": "system", "content": "You are a helpful assistant that provides structured JSON responses."},
|
| 306 |
+
{"role": "user", "content": prompt}
|
| 307 |
+
],
|
| 308 |
+
"max_tokens": self.config.max_tokens,
|
| 309 |
+
"temperature": self.config.temperature,
|
| 310 |
+
"response_format": {"type": "json_object"}
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
response = self.session.post(
|
| 314 |
+
self.config.endpoint_url,
|
| 315 |
+
json=payload,
|
| 316 |
+
timeout=self.config.timeout
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
if response.status_code == 200:
|
| 320 |
+
result = response.json()
|
| 321 |
+
|
| 322 |
+
if 'choices' in result and len(result['choices']) > 0:
|
| 323 |
+
content = result['choices'][0]['message']['content']
|
| 324 |
+
|
| 325 |
+
try:
|
| 326 |
+
parsed_response = _loads_lenient(content)
|
| 327 |
+
predicted_label = parsed_response.get('label', '')
|
| 328 |
+
confidence_score = parsed_response.get('confidence', 1)
|
| 329 |
+
|
| 330 |
+
if not isinstance(confidence_score, (int, float)):
|
| 331 |
+
confidence_score = 1
|
| 332 |
+
else:
|
| 333 |
+
confidence_score = max(1, min(10, confidence_score)) / 10.0
|
| 334 |
+
|
| 335 |
+
return LLMPrediction(
|
| 336 |
+
instance_id=instance_id,
|
| 337 |
+
predicted_label=predicted_label,
|
| 338 |
+
confidence_score=confidence_score,
|
| 339 |
+
raw_response=content,
|
| 340 |
+
confidence_method="verbalized"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
except json.JSONDecodeError as e:
|
| 344 |
+
self.logger.warning(f"Failed to parse JSON response for instance {instance_id}: {e}")
|
| 345 |
+
return self._extract_from_raw_response(content, instance_id)
|
| 346 |
+
|
| 347 |
+
else:
|
| 348 |
+
raise Exception(f"Invalid response format: {result}")
|
| 349 |
+
|
| 350 |
+
else:
|
| 351 |
+
raise Exception(f"HTTP {response.status_code}: {response.text}")
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
self.logger.warning(f"Attempt {attempt + 1} failed for instance {instance_id}: {e}")
|
| 355 |
+
|
| 356 |
+
if attempt < self.config.retry_attempts - 1:
|
| 357 |
+
time.sleep(self.config.retry_delay * (attempt + 1))
|
| 358 |
+
else:
|
| 359 |
+
return LLMPrediction(
|
| 360 |
+
instance_id=instance_id,
|
| 361 |
+
predicted_label='',
|
| 362 |
+
confidence_score=0.1,
|
| 363 |
+
raw_response='',
|
| 364 |
+
error_message=f"All attempts failed: {e}",
|
| 365 |
+
confidence_method="verbalized"
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
return LLMPrediction(
|
| 369 |
+
instance_id=instance_id,
|
| 370 |
+
predicted_label='',
|
| 371 |
+
confidence_score=0.1,
|
| 372 |
+
raw_response='',
|
| 373 |
+
error_message="Unknown error",
|
| 374 |
+
confidence_method="verbalized"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
def _predict_with_logprobs(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
|
| 378 |
+
"""Extract confidence from token-level log probabilities.
|
| 379 |
+
|
| 380 |
+
Requests logprobs=True from VLLM/OpenAI-compatible endpoints and
|
| 381 |
+
computes confidence as exp(mean_logprob) over the label tokens.
|
| 382 |
+
Falls back to verbalized confidence if logprobs unavailable.
|
| 383 |
+
"""
|
| 384 |
+
instance_id = instance.get('id', 'unknown')
|
| 385 |
+
|
| 386 |
+
for attempt in range(self.config.retry_attempts):
|
| 387 |
+
try:
|
| 388 |
+
payload = {
|
| 389 |
+
"model": self.config.model_name,
|
| 390 |
+
"messages": [
|
| 391 |
+
{"role": "system", "content": "You are a helpful assistant that provides structured JSON responses."},
|
| 392 |
+
{"role": "user", "content": prompt}
|
| 393 |
+
],
|
| 394 |
+
"max_tokens": self.config.max_tokens,
|
| 395 |
+
"temperature": self.config.temperature,
|
| 396 |
+
"response_format": {"type": "json_object"},
|
| 397 |
+
"logprobs": True,
|
| 398 |
+
"top_logprobs": 5,
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
response = self.session.post(
|
| 402 |
+
self.config.endpoint_url,
|
| 403 |
+
json=payload,
|
| 404 |
+
timeout=self.config.timeout
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
if response.status_code == 200:
|
| 408 |
+
result = response.json()
|
| 409 |
+
|
| 410 |
+
if 'choices' not in result or len(result['choices']) == 0:
|
| 411 |
+
raise Exception(f"Invalid response format: {result}")
|
| 412 |
+
|
| 413 |
+
choice = result['choices'][0]
|
| 414 |
+
content = choice['message']['content']
|
| 415 |
+
|
| 416 |
+
# Parse label from JSON content
|
| 417 |
+
try:
|
| 418 |
+
parsed = _loads_lenient(content)
|
| 419 |
+
except json.JSONDecodeError:
|
| 420 |
+
return self._extract_from_raw_response(content, instance_id)
|
| 421 |
+
|
| 422 |
+
predicted_label = parsed.get('label', '')
|
| 423 |
+
|
| 424 |
+
# Try to extract logprobs
|
| 425 |
+
logprobs_data = choice.get('logprobs', {})
|
| 426 |
+
token_logprobs = logprobs_data.get('content', [])
|
| 427 |
+
|
| 428 |
+
if token_logprobs:
|
| 429 |
+
# Compute mean logprob across all tokens
|
| 430 |
+
log_probs = [
|
| 431 |
+
t['logprob'] for t in token_logprobs
|
| 432 |
+
if 'logprob' in t and t['logprob'] is not None
|
| 433 |
+
]
|
| 434 |
+
if log_probs:
|
| 435 |
+
mean_logprob = sum(log_probs) / len(log_probs)
|
| 436 |
+
confidence_score = min(1.0, max(0.0, math.exp(mean_logprob)))
|
| 437 |
+
else:
|
| 438 |
+
# No valid logprobs, fall back to verbalized
|
| 439 |
+
confidence_score = parsed.get('confidence', 5)
|
| 440 |
+
if isinstance(confidence_score, (int, float)):
|
| 441 |
+
confidence_score = max(1, min(10, confidence_score)) / 10.0
|
| 442 |
+
else:
|
| 443 |
+
confidence_score = 0.5
|
| 444 |
+
else:
|
| 445 |
+
# Endpoint didn't return logprobs, fall back to verbalized
|
| 446 |
+
self.logger.debug(f"No logprobs returned for {instance_id}, using verbalized")
|
| 447 |
+
confidence_score = parsed.get('confidence', 5)
|
| 448 |
+
if isinstance(confidence_score, (int, float)):
|
| 449 |
+
confidence_score = max(1, min(10, confidence_score)) / 10.0
|
| 450 |
+
else:
|
| 451 |
+
confidence_score = 0.5
|
| 452 |
+
|
| 453 |
+
return LLMPrediction(
|
| 454 |
+
instance_id=instance_id,
|
| 455 |
+
predicted_label=predicted_label,
|
| 456 |
+
confidence_score=confidence_score,
|
| 457 |
+
raw_response=content,
|
| 458 |
+
confidence_method="logprobs" if token_logprobs else "verbalized"
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
else:
|
| 462 |
+
raise Exception(f"HTTP {response.status_code}: {response.text}")
|
| 463 |
+
|
| 464 |
+
except Exception as e:
|
| 465 |
+
self.logger.warning(f"Logprobs attempt {attempt + 1} failed for {instance_id}: {e}")
|
| 466 |
+
if attempt < self.config.retry_attempts - 1:
|
| 467 |
+
time.sleep(self.config.retry_delay * (attempt + 1))
|
| 468 |
+
else:
|
| 469 |
+
return LLMPrediction(
|
| 470 |
+
instance_id=instance_id,
|
| 471 |
+
predicted_label='',
|
| 472 |
+
confidence_score=0.1,
|
| 473 |
+
raw_response='',
|
| 474 |
+
error_message=f"All logprob attempts failed: {e}",
|
| 475 |
+
confidence_method="logprobs"
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
return LLMPrediction(
|
| 479 |
+
instance_id=instance_id,
|
| 480 |
+
predicted_label='',
|
| 481 |
+
confidence_score=0.1,
|
| 482 |
+
raw_response='',
|
| 483 |
+
error_message="Unknown error",
|
| 484 |
+
confidence_method="logprobs"
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
def _predict_consistency(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
|
| 488 |
+
"""Consistency-based confidence: query N times, use agreement rate.
|
| 489 |
+
|
| 490 |
+
Works with any endpoint (including Anthropic, Ollama) that doesn't
|
| 491 |
+
support logprobs. Confidence = fraction of samples that agree on
|
| 492 |
+
the most common label.
|
| 493 |
+
"""
|
| 494 |
+
instance_id = instance.get('id', 'unknown')
|
| 495 |
+
n_samples = self.config.consistency_samples
|
| 496 |
+
|
| 497 |
+
labels = []
|
| 498 |
+
raw_responses = []
|
| 499 |
+
|
| 500 |
+
for _ in range(n_samples):
|
| 501 |
+
try:
|
| 502 |
+
payload = {
|
| 503 |
+
"model": self.config.model_name,
|
| 504 |
+
"messages": [
|
| 505 |
+
{"role": "system", "content": "You are a helpful assistant that provides structured JSON responses."},
|
| 506 |
+
{"role": "user", "content": prompt}
|
| 507 |
+
],
|
| 508 |
+
"max_tokens": self.config.max_tokens,
|
| 509 |
+
"temperature": max(0.5, self.config.temperature), # Need some randomness
|
| 510 |
+
"response_format": {"type": "json_object"}
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
response = self.session.post(
|
| 514 |
+
self.config.endpoint_url,
|
| 515 |
+
json=payload,
|
| 516 |
+
timeout=self.config.timeout
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
if response.status_code == 200:
|
| 520 |
+
result = response.json()
|
| 521 |
+
if 'choices' in result and len(result['choices']) > 0:
|
| 522 |
+
content = result['choices'][0]['message']['content']
|
| 523 |
+
raw_responses.append(content)
|
| 524 |
+
try:
|
| 525 |
+
parsed = _loads_lenient(content)
|
| 526 |
+
labels.append(parsed.get('label', ''))
|
| 527 |
+
except json.JSONDecodeError:
|
| 528 |
+
pass
|
| 529 |
+
|
| 530 |
+
except Exception as e:
|
| 531 |
+
self.logger.debug(f"Consistency sample failed for {instance_id}: {e}")
|
| 532 |
+
|
| 533 |
+
if not labels:
|
| 534 |
+
return LLMPrediction(
|
| 535 |
+
instance_id=instance_id,
|
| 536 |
+
predicted_label='',
|
| 537 |
+
confidence_score=0.1,
|
| 538 |
+
raw_response='',
|
| 539 |
+
error_message="All consistency samples failed",
|
| 540 |
+
confidence_method="consistency"
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# Most common label
|
| 544 |
+
label_counts = Counter(labels)
|
| 545 |
+
predicted_label, count = label_counts.most_common(1)[0]
|
| 546 |
+
confidence_score = count / len(labels)
|
| 547 |
+
|
| 548 |
+
return LLMPrediction(
|
| 549 |
+
instance_id=instance_id,
|
| 550 |
+
predicted_label=predicted_label,
|
| 551 |
+
confidence_score=confidence_score,
|
| 552 |
+
raw_response=raw_responses[0] if raw_responses else '',
|
| 553 |
+
confidence_method="consistency"
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
def _extract_from_raw_response(self, raw_response: str, instance_id: str) -> LLMPrediction:
|
| 557 |
+
"""Extract prediction from raw response when JSON parsing fails."""
|
| 558 |
+
try:
|
| 559 |
+
# Try to find label and confidence in the raw text
|
| 560 |
+
lines = raw_response.lower().split('\n')
|
| 561 |
+
|
| 562 |
+
predicted_label = ''
|
| 563 |
+
confidence_score = 0.1
|
| 564 |
+
|
| 565 |
+
for line in lines:
|
| 566 |
+
if 'label' in line and ':' in line:
|
| 567 |
+
label_part = line.split(':', 1)[1].strip().strip('"\'')
|
| 568 |
+
if label_part:
|
| 569 |
+
predicted_label = label_part
|
| 570 |
+
|
| 571 |
+
if 'confidence' in line and ':' in line:
|
| 572 |
+
conf_part = line.split(':', 1)[1].strip()
|
| 573 |
+
try:
|
| 574 |
+
conf_value = float(conf_part)
|
| 575 |
+
confidence_score = max(0.1, min(1.0, conf_value / 10.0))
|
| 576 |
+
except ValueError:
|
| 577 |
+
pass
|
| 578 |
+
|
| 579 |
+
return LLMPrediction(
|
| 580 |
+
instance_id=instance_id,
|
| 581 |
+
predicted_label=predicted_label,
|
| 582 |
+
confidence_score=confidence_score,
|
| 583 |
+
raw_response=raw_response
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
except Exception as e:
|
| 587 |
+
self.logger.error(f"Failed to extract from raw response for instance {instance_id}: {e}")
|
| 588 |
+
return LLMPrediction(
|
| 589 |
+
instance_id=instance_id,
|
| 590 |
+
predicted_label='',
|
| 591 |
+
confidence_score=0.1,
|
| 592 |
+
raw_response=raw_response,
|
| 593 |
+
error_message=f"Failed to extract prediction: {e}"
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
def calculate_confidence_distribution(self, predictions: List[LLMPrediction]) -> Dict[str, float]:
|
| 597 |
+
"""Calculate confidence score distribution from predictions."""
|
| 598 |
+
if not predictions:
|
| 599 |
+
return {}
|
| 600 |
+
|
| 601 |
+
# Filter out predictions with errors
|
| 602 |
+
valid_predictions = [p for p in predictions if p.error_message is None]
|
| 603 |
+
|
| 604 |
+
if not valid_predictions:
|
| 605 |
+
return {}
|
| 606 |
+
|
| 607 |
+
confidence_scores = [p.confidence_score for p in valid_predictions]
|
| 608 |
+
|
| 609 |
+
# Create histogram bins
|
| 610 |
+
bins = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
|
| 611 |
+
hist, _ = np.histogram(confidence_scores, bins=bins)
|
| 612 |
+
|
| 613 |
+
# Convert to percentages
|
| 614 |
+
total = len(confidence_scores)
|
| 615 |
+
distribution = {}
|
| 616 |
+
for i, count in enumerate(hist):
|
| 617 |
+
bin_label = f"{bins[i]:.1f}-{bins[i+1]:.1f}"
|
| 618 |
+
distribution[bin_label] = (count / total) * 100 if total > 0 else 0
|
| 619 |
+
|
| 620 |
+
return distribution
|
| 621 |
+
|
| 622 |
+
def get_prediction_stats(self, predictions: List[LLMPrediction]) -> Dict[str, Any]:
|
| 623 |
+
"""Get statistics about the predictions."""
|
| 624 |
+
if not predictions:
|
| 625 |
+
return {
|
| 626 |
+
"total_predictions": 0,
|
| 627 |
+
"successful_predictions": 0,
|
| 628 |
+
"error_rate": 0.0,
|
| 629 |
+
"average_confidence": 0.0,
|
| 630 |
+
"confidence_distribution": {}
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
total = len(predictions)
|
| 634 |
+
successful = len([p for p in predictions if p.error_message is None])
|
| 635 |
+
error_rate = (total - successful) / total if total > 0 else 0.0
|
| 636 |
+
|
| 637 |
+
valid_predictions = [p for p in predictions if p.error_message is None]
|
| 638 |
+
average_confidence = np.mean([p.confidence_score for p in valid_predictions]) if valid_predictions else 0.0
|
| 639 |
+
|
| 640 |
+
confidence_distribution = self.calculate_confidence_distribution(predictions)
|
| 641 |
+
|
| 642 |
+
return {
|
| 643 |
+
"total_predictions": total,
|
| 644 |
+
"successful_predictions": successful,
|
| 645 |
+
"error_rate": error_rate,
|
| 646 |
+
"average_confidence": average_confidence,
|
| 647 |
+
"confidence_distribution": confidence_distribution
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
class MockLLMActiveLearning(LLMActiveLearning):
|
| 652 |
+
"""
|
| 653 |
+
Mock LLM implementation for testing and development.
|
| 654 |
+
|
| 655 |
+
This class provides realistic mock responses for testing active learning
|
| 656 |
+
without requiring an actual LLM endpoint.
|
| 657 |
+
"""
|
| 658 |
+
|
| 659 |
+
def __init__(self, config: LLMConfig):
|
| 660 |
+
super().__init__(config)
|
| 661 |
+
self.logger.info("Using Mock LLM for active learning")
|
| 662 |
+
|
| 663 |
+
# Mock response patterns
|
| 664 |
+
self._mock_responses = [
|
| 665 |
+
{"label": "positive", "confidence": 8, "reasoning": "Clear positive sentiment"},
|
| 666 |
+
{"label": "negative", "confidence": 7, "reasoning": "Negative tone detected"},
|
| 667 |
+
{"label": "neutral", "confidence": 6, "reasoning": "Balanced perspective"},
|
| 668 |
+
{"label": "positive", "confidence": 9, "reasoning": "Very positive language"},
|
| 669 |
+
{"label": "negative", "confidence": 5, "reasoning": "Somewhat negative"},
|
| 670 |
+
{"label": "neutral", "confidence": 4, "reasoning": "Mixed signals"},
|
| 671 |
+
{"label": "positive", "confidence": 3, "reasoning": "Uncertain positive"},
|
| 672 |
+
{"label": "negative", "confidence": 8, "reasoning": "Clearly negative"},
|
| 673 |
+
{"label": "neutral", "confidence": 7, "reasoning": "Neutral stance"},
|
| 674 |
+
{"label": "positive", "confidence": 6, "reasoning": "Moderately positive"}
|
| 675 |
+
]
|
| 676 |
+
self._response_index = 0
|
| 677 |
+
|
| 678 |
+
def _test_connection(self):
|
| 679 |
+
"""Mock connection test."""
|
| 680 |
+
self.logger.info("Mock LLM connection test successful")
|
| 681 |
+
|
| 682 |
+
def _predict_single(self, prompt: str, instance: Dict[str, Any]) -> LLMPrediction:
|
| 683 |
+
"""Make a mock prediction."""
|
| 684 |
+
instance_id = instance.get('id', 'unknown')
|
| 685 |
+
|
| 686 |
+
# Simulate processing time
|
| 687 |
+
time.sleep(0.1)
|
| 688 |
+
|
| 689 |
+
# Get next mock response
|
| 690 |
+
mock_response = self._mock_responses[self._response_index % len(self._mock_responses)]
|
| 691 |
+
self._response_index += 1
|
| 692 |
+
|
| 693 |
+
# Add some randomness to confidence scores
|
| 694 |
+
confidence_variation = np.random.normal(0, 0.1)
|
| 695 |
+
confidence_score = max(0.1, min(1.0, (mock_response['confidence'] / 10.0) + confidence_variation))
|
| 696 |
+
|
| 697 |
+
return LLMPrediction(
|
| 698 |
+
instance_id=instance_id,
|
| 699 |
+
predicted_label=mock_response['label'],
|
| 700 |
+
confidence_score=confidence_score,
|
| 701 |
+
raw_response=json.dumps(mock_response)
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def create_llm_active_learning(config: Dict[str, Any]) -> LLMActiveLearning:
|
| 706 |
+
"""
|
| 707 |
+
Factory function to create LLM active learning instance.
|
| 708 |
+
|
| 709 |
+
Args:
|
| 710 |
+
config: LLM configuration dictionary
|
| 711 |
+
|
| 712 |
+
Returns:
|
| 713 |
+
LLMActiveLearning: Configured LLM active learning instance
|
| 714 |
+
"""
|
| 715 |
+
llm_config = LLMConfig(
|
| 716 |
+
endpoint_url=config.get('endpoint_url', ''),
|
| 717 |
+
model_name=config.get('model_name', ''),
|
| 718 |
+
max_tokens=config.get('max_tokens', 512),
|
| 719 |
+
temperature=config.get('temperature', 0.1),
|
| 720 |
+
timeout=config.get('timeout', 30),
|
| 721 |
+
batch_size=config.get('batch_size', 10),
|
| 722 |
+
retry_attempts=config.get('retry_attempts', 3),
|
| 723 |
+
retry_delay=config.get('retry_delay', 1.0),
|
| 724 |
+
max_instances_per_request=config.get('max_instances_per_request', 5),
|
| 725 |
+
confidence_method=config.get('confidence_method', 'verbalized'),
|
| 726 |
+
consistency_samples=config.get('consistency_samples', 3),
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
# Use mock implementation for testing or when endpoint is not available
|
| 730 |
+
if config.get('use_mock', False) or not llm_config.endpoint_url:
|
| 731 |
+
return MockLLMActiveLearning(llm_config)
|
| 732 |
+
else:
|
| 733 |
+
return LLMActiveLearning(llm_config)
|
potato/ai/ollama_endpoint.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ollama AI endpoint implementation.
|
| 3 |
+
|
| 4 |
+
This module provides integration with Ollama for local LLM inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from typing import Dict, List, Optional, Type
|
| 9 |
+
import ollama
|
| 10 |
+
from pydantic import BaseModel
|
| 11 |
+
from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError, ModelCapabilities
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
DEFAULT_MODEL = "llama3.2"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OllamaEndpoint(BaseAIEndpoint):
|
| 19 |
+
"""Ollama endpoint for local LLM inference."""
|
| 20 |
+
|
| 21 |
+
# Capabilities declaration for text-based Ollama models
|
| 22 |
+
CAPABILITIES = ModelCapabilities(
|
| 23 |
+
text_generation=True,
|
| 24 |
+
vision_input=False,
|
| 25 |
+
bounding_box_output=False,
|
| 26 |
+
text_classification=True,
|
| 27 |
+
image_classification=False,
|
| 28 |
+
rationale_generation=True,
|
| 29 |
+
keyword_extraction=True,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
def _initialize_client(self) -> None:
|
| 33 |
+
"""Initialize the Ollama client."""
|
| 34 |
+
# Default timeout of 60 seconds for local inference (can be slower)
|
| 35 |
+
timeout = self.ai_config.get("timeout", 60)
|
| 36 |
+
host = self.ai_config.get("base_url", "http://localhost:11434")
|
| 37 |
+
|
| 38 |
+
# Create client with timeout
|
| 39 |
+
self.client = ollama.Client(host=host, timeout=timeout)
|
| 40 |
+
|
| 41 |
+
# Check if Ollama is available
|
| 42 |
+
try:
|
| 43 |
+
self.client.list()
|
| 44 |
+
except Exception as e:
|
| 45 |
+
raise AIEndpointRequestError(f"Failed to connect to Ollama: {e}")
|
| 46 |
+
|
| 47 |
+
def _get_default_model(self) -> str:
|
| 48 |
+
"""Get the default Ollama model."""
|
| 49 |
+
return DEFAULT_MODEL
|
| 50 |
+
|
| 51 |
+
def query(self, prompt: str, output_format: Type[BaseModel]) -> str:
|
| 52 |
+
"""
|
| 53 |
+
Send a query to Ollama and return the response.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
prompt: The prompt to send to the model
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
The model's response as a string
|
| 60 |
+
|
| 61 |
+
Raises:
|
| 62 |
+
AIEndpointRequestError: If the request fails
|
| 63 |
+
"""
|
| 64 |
+
import logging
|
| 65 |
+
logger = logging.getLogger(__name__)
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
logger.debug(f"[Ollama] Querying model: {self.model}")
|
| 69 |
+
logger.debug(f"[Ollama] Prompt (first 200 chars): {prompt[:200]}...")
|
| 70 |
+
|
| 71 |
+
options = {
|
| 72 |
+
'temperature': self.temperature,
|
| 73 |
+
'num_predict': self.max_tokens
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
# Think mode: configurable via ai_config['think'], defaults to False
|
| 77 |
+
# for fast structured output (thinking wastes tokens on JSON tasks)
|
| 78 |
+
think = self.ai_config.get('think', False)
|
| 79 |
+
|
| 80 |
+
response = self.client.chat(
|
| 81 |
+
model=self.model,
|
| 82 |
+
messages=[{'role': 'user', 'content': prompt}],
|
| 83 |
+
options=options,
|
| 84 |
+
format=output_format.model_json_schema(),
|
| 85 |
+
think=think,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Log full response structure for debugging
|
| 89 |
+
logger.debug(f"[Ollama] Response type: {type(response)}")
|
| 90 |
+
logger.debug(f"[Ollama] Full response: {response}")
|
| 91 |
+
|
| 92 |
+
# Get the message object - handle both dict and object access
|
| 93 |
+
message = response.get('message') if hasattr(response, 'get') else getattr(response, 'message', None)
|
| 94 |
+
if message is None:
|
| 95 |
+
raise AIEndpointRequestError("No message in Ollama response")
|
| 96 |
+
|
| 97 |
+
# Get content - handle both dict and object access
|
| 98 |
+
content = message.get('content') if hasattr(message, 'get') else getattr(message, 'content', None)
|
| 99 |
+
|
| 100 |
+
# Some models put response in 'thinking' field - check that too
|
| 101 |
+
if not content and hasattr(message, 'thinking') and message.thinking:
|
| 102 |
+
logger.warning("[Ollama] Content empty but thinking field has data - model may need think=False")
|
| 103 |
+
# Try to extract JSON from thinking field as fallback
|
| 104 |
+
thinking_text = message.thinking
|
| 105 |
+
# Look for JSON in the thinking text
|
| 106 |
+
import re
|
| 107 |
+
json_match = re.search(r'\{[^{}]*\}', thinking_text)
|
| 108 |
+
if json_match:
|
| 109 |
+
content = json_match.group(0)
|
| 110 |
+
logger.debug(f"[Ollama] Extracted JSON from thinking: {content}")
|
| 111 |
+
|
| 112 |
+
logger.debug(f"[Ollama] Content type: {type(content)}")
|
| 113 |
+
logger.debug(f"[Ollama] Content value: {repr(content)[:200] if content else 'EMPTY'}")
|
| 114 |
+
|
| 115 |
+
# If content is already a dict (structured output), return it directly
|
| 116 |
+
if isinstance(content, dict):
|
| 117 |
+
logger.debug("[Ollama] Content is already a dict, returning directly")
|
| 118 |
+
return content
|
| 119 |
+
|
| 120 |
+
# Parse response using the base class's robust parser
|
| 121 |
+
# (handles truncated JSON, markdown blocks, plain text, etc.)
|
| 122 |
+
if content:
|
| 123 |
+
logger.debug(f"[Ollama] Response content (first 500 chars): {str(content)[:500]}")
|
| 124 |
+
return self.parseStringToJson(content)
|
| 125 |
+
else:
|
| 126 |
+
raise AIEndpointRequestError("Empty content from Ollama - try a different model or disable thinking mode")
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"[Ollama] Request failed: {e}")
|
| 129 |
+
raise AIEndpointRequestError(f"Ollama request failed: {e}")
|
| 130 |
+
|
| 131 |
+
def chat_query(self, messages: List[Dict[str, str]]) -> str:
|
| 132 |
+
"""Send a multi-turn chat to Ollama using native chat API."""
|
| 133 |
+
import logging
|
| 134 |
+
logger = logging.getLogger(__name__)
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
options = {
|
| 138 |
+
'temperature': self.temperature,
|
| 139 |
+
'num_predict': self.max_tokens,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
think = self.ai_config.get('think', False)
|
| 143 |
+
response = self.client.chat(
|
| 144 |
+
model=self.model,
|
| 145 |
+
messages=messages,
|
| 146 |
+
options=options,
|
| 147 |
+
think=think,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
message = response.get('message') if hasattr(response, 'get') else getattr(response, 'message', None)
|
| 151 |
+
if message is None:
|
| 152 |
+
raise AIEndpointRequestError("No message in Ollama chat response")
|
| 153 |
+
|
| 154 |
+
content = message.get('content') if hasattr(message, 'get') else getattr(message, 'content', None)
|
| 155 |
+
return content or ""
|
| 156 |
+
except Exception as e:
|
| 157 |
+
logger.error(f"[Ollama] Chat request failed: {e}")
|
| 158 |
+
raise AIEndpointRequestError(f"Ollama chat request failed: {e}")
|
| 159 |
+
|
| 160 |
+
|
potato/ai/ollama_vision_endpoint.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ollama Vision AI Endpoint
|
| 3 |
+
|
| 4 |
+
This module provides integration with Ollama vision models for local
|
| 5 |
+
visual AI inference. Supports LLaVA, Llama 3.2 Vision, BakLLaVA, and Qwen-VL models.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
from typing import Any, Dict, List, Type, Union
|
| 12 |
+
|
| 13 |
+
from pydantic import BaseModel
|
| 14 |
+
|
| 15 |
+
from .ai_endpoint import AIEndpointRequestError, ImageData, ModelCapabilities
|
| 16 |
+
from .visual_ai_endpoint import BaseVisualAIEndpoint
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Default vision model
|
| 21 |
+
DEFAULT_MODEL = "llava:latest"
|
| 22 |
+
|
| 23 |
+
# Models known to support vision (used only for warning suppression, not to restrict usage)
|
| 24 |
+
# Any Ollama model can be used - this list just prevents "may not support vision" warnings
|
| 25 |
+
# for models we know are vision-capable
|
| 26 |
+
VISION_MODELS = [
|
| 27 |
+
# LLaVA family
|
| 28 |
+
"llava",
|
| 29 |
+
"llava-llama3",
|
| 30 |
+
"llava-phi3",
|
| 31 |
+
"bakllava",
|
| 32 |
+
# Llama Vision
|
| 33 |
+
"llama3.2-vision",
|
| 34 |
+
# Qwen Vision-Language
|
| 35 |
+
"qwen2.5-vl",
|
| 36 |
+
"qwen2-vl",
|
| 37 |
+
"qwen3-vl",
|
| 38 |
+
# Other vision models
|
| 39 |
+
"moondream",
|
| 40 |
+
"minicpm-v",
|
| 41 |
+
"gemma3", # Gemma 3 has vision capabilities
|
| 42 |
+
"gemma4", # Gemma 4 family (e.g. gemma4:e4b) supports vision
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class OllamaVisionEndpoint(BaseVisualAIEndpoint):
|
| 47 |
+
"""
|
| 48 |
+
Ollama Vision endpoint for multimodal local inference.
|
| 49 |
+
|
| 50 |
+
Supports vision-capable models like LLaVA, Llama 3.2 Vision, BakLLaVA.
|
| 51 |
+
Images are sent as base64 in the 'images' field.
|
| 52 |
+
|
| 53 |
+
Configuration options:
|
| 54 |
+
- model: Vision model to use (default: llava:latest)
|
| 55 |
+
- base_url: Ollama server URL (default: http://localhost:11434)
|
| 56 |
+
- timeout: Request timeout in seconds (default: 120)
|
| 57 |
+
- max_tokens: Maximum response tokens (default: 500)
|
| 58 |
+
- temperature: Sampling temperature (default: 0.1)
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
# Capabilities declaration for vision-capable Ollama models (LLaVA, Qwen-VL, etc.)
|
| 62 |
+
# Note: VLLMs can generate text about images but cannot do precise bounding box detection
|
| 63 |
+
# Keyword extraction is disabled because it doesn't apply to image content
|
| 64 |
+
CAPABILITIES = ModelCapabilities(
|
| 65 |
+
text_generation=True,
|
| 66 |
+
vision_input=True,
|
| 67 |
+
bounding_box_output=False, # VLLMs are not reliable for precise bbox coordinates
|
| 68 |
+
text_classification=True,
|
| 69 |
+
image_classification=True,
|
| 70 |
+
rationale_generation=True,
|
| 71 |
+
keyword_extraction=False, # Keywords don't apply to images
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def _initialize_client(self) -> None:
|
| 75 |
+
"""Initialize the Ollama client."""
|
| 76 |
+
try:
|
| 77 |
+
import ollama
|
| 78 |
+
except ImportError:
|
| 79 |
+
raise AIEndpointRequestError(
|
| 80 |
+
"ollama package is required. Install it with: pip install ollama"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
timeout = self.ai_config.get("timeout", 120) # Vision models can be slower
|
| 84 |
+
host = self.ai_config.get("base_url", "http://localhost:11434")
|
| 85 |
+
|
| 86 |
+
self.client = ollama.Client(host=host, timeout=timeout)
|
| 87 |
+
|
| 88 |
+
# Verify connection and model availability
|
| 89 |
+
try:
|
| 90 |
+
models = self.client.list()
|
| 91 |
+
logger.info(f"Connected to Ollama at {host}")
|
| 92 |
+
|
| 93 |
+
# Check if the specified model is a known vision model
|
| 94 |
+
# This is just informational - any model can be used
|
| 95 |
+
model_lower = self.model.lower()
|
| 96 |
+
is_known_vision_model = any(vm in model_lower for vm in VISION_MODELS)
|
| 97 |
+
if not is_known_vision_model:
|
| 98 |
+
logger.info(
|
| 99 |
+
f"Model '{self.model}' not in known vision models list. "
|
| 100 |
+
f"This is fine if it supports vision - proceeding anyway."
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
raise AIEndpointRequestError(f"Failed to connect to Ollama: {e}")
|
| 105 |
+
|
| 106 |
+
def _get_default_model(self) -> str:
|
| 107 |
+
"""Get the default vision model."""
|
| 108 |
+
return DEFAULT_MODEL
|
| 109 |
+
|
| 110 |
+
def query(self, prompt: str, output_format: Type[BaseModel]) -> Any:
|
| 111 |
+
"""
|
| 112 |
+
Standard text query (falls back to text-only mode).
|
| 113 |
+
|
| 114 |
+
For vision tasks, use query_with_image() instead.
|
| 115 |
+
"""
|
| 116 |
+
try:
|
| 117 |
+
options = {
|
| 118 |
+
'temperature': self.temperature,
|
| 119 |
+
'num_predict': self.max_tokens
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
response = self.client.chat(
|
| 123 |
+
model=self.model,
|
| 124 |
+
messages=[{'role': 'user', 'content': prompt}],
|
| 125 |
+
options=options,
|
| 126 |
+
format=output_format.model_json_schema(),
|
| 127 |
+
think=False,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
message = response.get('message') if hasattr(response, 'get') else getattr(response, 'message', None)
|
| 131 |
+
if message is None:
|
| 132 |
+
raise AIEndpointRequestError("No message in Ollama response")
|
| 133 |
+
|
| 134 |
+
content = message.get('content') if hasattr(message, 'get') else getattr(message, 'content', None)
|
| 135 |
+
|
| 136 |
+
if isinstance(content, dict):
|
| 137 |
+
return content
|
| 138 |
+
|
| 139 |
+
if content:
|
| 140 |
+
return self.parseStringToJson(content)
|
| 141 |
+
else:
|
| 142 |
+
raise AIEndpointRequestError("Empty content from Ollama")
|
| 143 |
+
|
| 144 |
+
except Exception as e:
|
| 145 |
+
raise AIEndpointRequestError(f"Ollama query failed: {e}")
|
| 146 |
+
|
| 147 |
+
def query_with_image(
|
| 148 |
+
self,
|
| 149 |
+
prompt: str,
|
| 150 |
+
image_data: Union[ImageData, List[ImageData]],
|
| 151 |
+
output_format: Type[BaseModel]
|
| 152 |
+
) -> Any:
|
| 153 |
+
"""
|
| 154 |
+
Send a query with image(s) to Ollama vision model.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
prompt: Text prompt describing what to analyze
|
| 158 |
+
image_data: Single ImageData or list of ImageData
|
| 159 |
+
output_format: Pydantic model for structured output
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Parsed response according to output_format
|
| 163 |
+
|
| 164 |
+
Raises:
|
| 165 |
+
AIEndpointRequestError: If the request fails
|
| 166 |
+
"""
|
| 167 |
+
try:
|
| 168 |
+
# Prepare images
|
| 169 |
+
images = [image_data] if isinstance(image_data, ImageData) else image_data
|
| 170 |
+
|
| 171 |
+
# Convert to base64 if needed
|
| 172 |
+
image_base64_list = []
|
| 173 |
+
for img in images:
|
| 174 |
+
b64_data = self._get_base64_image(img)
|
| 175 |
+
image_base64_list.append(b64_data)
|
| 176 |
+
|
| 177 |
+
# Build message with images
|
| 178 |
+
options = {
|
| 179 |
+
'temperature': self.temperature,
|
| 180 |
+
'num_predict': self.max_tokens
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
# Ollama expects images as a list of base64 strings
|
| 184 |
+
response = self.client.chat(
|
| 185 |
+
model=self.model,
|
| 186 |
+
messages=[{
|
| 187 |
+
'role': 'user',
|
| 188 |
+
'content': prompt,
|
| 189 |
+
'images': image_base64_list
|
| 190 |
+
}],
|
| 191 |
+
options=options,
|
| 192 |
+
format=output_format.model_json_schema(),
|
| 193 |
+
think=False,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
logger.debug(f"Ollama vision response type: {type(response)}")
|
| 197 |
+
|
| 198 |
+
# Extract content from response
|
| 199 |
+
message = response.get('message') if hasattr(response, 'get') else getattr(response, 'message', None)
|
| 200 |
+
if message is None:
|
| 201 |
+
raise AIEndpointRequestError("No message in Ollama vision response")
|
| 202 |
+
|
| 203 |
+
content = message.get('content') if hasattr(message, 'get') else getattr(message, 'content', None)
|
| 204 |
+
|
| 205 |
+
logger.debug(f"Ollama vision content type: {type(content)}")
|
| 206 |
+
|
| 207 |
+
# Parse response
|
| 208 |
+
if isinstance(content, dict):
|
| 209 |
+
return content
|
| 210 |
+
|
| 211 |
+
if content:
|
| 212 |
+
return self.parseStringToJson(content)
|
| 213 |
+
else:
|
| 214 |
+
raise AIEndpointRequestError("Empty content from Ollama vision model")
|
| 215 |
+
|
| 216 |
+
except AIEndpointRequestError:
|
| 217 |
+
raise
|
| 218 |
+
except Exception as e:
|
| 219 |
+
logger.error(f"Ollama vision query failed: {e}")
|
| 220 |
+
import traceback
|
| 221 |
+
logger.error(traceback.format_exc())
|
| 222 |
+
raise AIEndpointRequestError(f"Ollama vision query failed: {e}")
|
| 223 |
+
|
| 224 |
+
def _get_base64_image(self, image_data: ImageData) -> str:
|
| 225 |
+
"""
|
| 226 |
+
Get base64-encoded image data.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
image_data: ImageData object
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Base64-encoded image string (without data URL prefix)
|
| 233 |
+
"""
|
| 234 |
+
if image_data.source == "base64":
|
| 235 |
+
# Already base64, just return the data
|
| 236 |
+
return image_data.data
|
| 237 |
+
|
| 238 |
+
elif image_data.source == "url":
|
| 239 |
+
# Download and convert to base64
|
| 240 |
+
downloaded = self.download_image_to_base64(image_data.data)
|
| 241 |
+
return downloaded.data
|
| 242 |
+
|
| 243 |
+
else:
|
| 244 |
+
raise AIEndpointRequestError(f"Unknown image source: {image_data.source}")
|
| 245 |
+
|
| 246 |
+
def analyze_image(
|
| 247 |
+
self,
|
| 248 |
+
image_path_or_url: str,
|
| 249 |
+
prompt: str,
|
| 250 |
+
output_format: Type[BaseModel] = None
|
| 251 |
+
) -> Any:
|
| 252 |
+
"""
|
| 253 |
+
Convenience method for analyzing a single image.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
image_path_or_url: Path to image file or URL
|
| 257 |
+
prompt: Analysis prompt
|
| 258 |
+
output_format: Optional output format model
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
Analysis result
|
| 262 |
+
"""
|
| 263 |
+
# Prepare image data
|
| 264 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 265 |
+
image_data = self.download_image_to_base64(image_path_or_url)
|
| 266 |
+
else:
|
| 267 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 268 |
+
|
| 269 |
+
# Use a generic format if not specified
|
| 270 |
+
if output_format is None:
|
| 271 |
+
from .prompt.models_module import GeneralHintFormat
|
| 272 |
+
output_format = GeneralHintFormat
|
| 273 |
+
|
| 274 |
+
return self.query_with_image(prompt, image_data, output_format)
|
| 275 |
+
|
| 276 |
+
def describe_image(self, image_path_or_url: str) -> str:
|
| 277 |
+
"""
|
| 278 |
+
Get a natural language description of an image.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
image_path_or_url: Path to image file or URL
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
Text description of the image
|
| 285 |
+
"""
|
| 286 |
+
# Use a simple model that returns text
|
| 287 |
+
class DescriptionFormat(BaseModel):
|
| 288 |
+
description: str
|
| 289 |
+
|
| 290 |
+
result = self.analyze_image(
|
| 291 |
+
image_path_or_url,
|
| 292 |
+
"Describe this image in detail. What objects, people, or scenes do you see?",
|
| 293 |
+
DescriptionFormat
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
if isinstance(result, dict) and "description" in result:
|
| 297 |
+
return result["description"]
|
| 298 |
+
return str(result)
|
| 299 |
+
|
| 300 |
+
def health_check(self) -> bool:
|
| 301 |
+
"""
|
| 302 |
+
Check if the Ollama vision model is available.
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
True if model is ready, False otherwise
|
| 306 |
+
"""
|
| 307 |
+
try:
|
| 308 |
+
# Try to list models
|
| 309 |
+
self.client.list()
|
| 310 |
+
return True
|
| 311 |
+
except Exception as e:
|
| 312 |
+
logger.error(f"Ollama vision health check failed: {e}")
|
| 313 |
+
return False
|
potato/ai/openai_endpoint.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI AI endpoint implementation.
|
| 3 |
+
|
| 4 |
+
This module provides integration with OpenAI's API for LLM inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from typing import Dict, List
|
| 9 |
+
from openai import OpenAI
|
| 10 |
+
from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError, ModelCapabilities
|
| 11 |
+
|
| 12 |
+
DEFAULT_MODEL = "gpt-4o-mini"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class OpenAIEndpoint(BaseAIEndpoint):
|
| 16 |
+
"""OpenAI endpoint for cloud-based LLM inference."""
|
| 17 |
+
|
| 18 |
+
# Capabilities declaration for text-based OpenAI models
|
| 19 |
+
CAPABILITIES = ModelCapabilities(
|
| 20 |
+
text_generation=True,
|
| 21 |
+
vision_input=False,
|
| 22 |
+
bounding_box_output=False,
|
| 23 |
+
text_classification=True,
|
| 24 |
+
image_classification=False,
|
| 25 |
+
rationale_generation=True,
|
| 26 |
+
keyword_extraction=True,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def _initialize_client(self) -> None:
|
| 30 |
+
"""Initialize the OpenAI client."""
|
| 31 |
+
# OpenAI-compatible servers (vLLM, llama.cpp, etc.) ignore the key
|
| 32 |
+
# but the SDK rejects an empty string, so accept a placeholder.
|
| 33 |
+
api_key = self.ai_config.get("api_key") or os.environ.get(
|
| 34 |
+
"OPENAI_API_KEY", ""
|
| 35 |
+
)
|
| 36 |
+
base_url = self.ai_config.get("base_url")
|
| 37 |
+
if not api_key:
|
| 38 |
+
if base_url:
|
| 39 |
+
api_key = "EMPTY" # non-empty placeholder for local servers
|
| 40 |
+
else:
|
| 41 |
+
raise AIEndpointRequestError("OpenAI API key is required")
|
| 42 |
+
|
| 43 |
+
# Default timeout of 30 seconds, configurable via ai_config
|
| 44 |
+
timeout = self.ai_config.get("timeout", 30)
|
| 45 |
+
client_kwargs = {"api_key": api_key, "timeout": timeout}
|
| 46 |
+
# Honor a custom base_url so this endpoint can target any
|
| 47 |
+
# OpenAI-compatible server (previously ignored -> always hit
|
| 48 |
+
# api.openai.com even when a local base_url was configured).
|
| 49 |
+
if base_url:
|
| 50 |
+
client_kwargs["base_url"] = base_url
|
| 51 |
+
self.client = OpenAI(**client_kwargs)
|
| 52 |
+
|
| 53 |
+
def _get_default_model(self) -> str:
|
| 54 |
+
"""Get the default OpenAI model."""
|
| 55 |
+
return DEFAULT_MODEL
|
| 56 |
+
|
| 57 |
+
def query(self, prompt: str, output_format: dict) -> str:
|
| 58 |
+
"""
|
| 59 |
+
Send a query to OpenAI and return the response.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
prompt: The prompt to send to the model
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
The model's response as a string
|
| 66 |
+
|
| 67 |
+
Raises:
|
| 68 |
+
AIEndpointRequestError: If the request fails
|
| 69 |
+
"""
|
| 70 |
+
try:
|
| 71 |
+
response = self.client.chat.completions.create(
|
| 72 |
+
model=self.model,
|
| 73 |
+
messages=[{"role": "user", "content": prompt}],
|
| 74 |
+
max_tokens=self.max_tokens,
|
| 75 |
+
temperature=self.temperature,
|
| 76 |
+
text_format=output_format.model_json_schema(),
|
| 77 |
+
)
|
| 78 |
+
return response.choices[0].message.content
|
| 79 |
+
except Exception as e:
|
| 80 |
+
raise AIEndpointRequestError(f"OpenAI request failed: {e}")
|
| 81 |
+
|
| 82 |
+
def chat_query(self, messages: List[Dict[str, str]]) -> str:
|
| 83 |
+
"""Send a multi-turn chat to OpenAI using native messages API."""
|
| 84 |
+
try:
|
| 85 |
+
response = self.client.chat.completions.create(
|
| 86 |
+
model=self.model,
|
| 87 |
+
messages=messages,
|
| 88 |
+
max_tokens=self.max_tokens,
|
| 89 |
+
temperature=self.temperature,
|
| 90 |
+
)
|
| 91 |
+
return response.choices[0].message.content
|
| 92 |
+
except Exception as e:
|
| 93 |
+
raise AIEndpointRequestError(f"OpenAI chat request failed: {e}")
|
| 94 |
+
|
potato/ai/openai_vision_endpoint.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI Vision AI Endpoint
|
| 3 |
+
|
| 4 |
+
This module provides integration with OpenAI's vision models (GPT-4o, GPT-4o-mini)
|
| 5 |
+
for visual analysis and annotation assistance.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Any, Dict, List, Type, Union
|
| 11 |
+
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
|
| 14 |
+
from .ai_endpoint import AIEndpointRequestError, ImageData, ModelCapabilities
|
| 15 |
+
from .visual_ai_endpoint import BaseVisualAIEndpoint
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
DEFAULT_MODEL = "gpt-4o"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class OpenAIVisionEndpoint(BaseVisualAIEndpoint):
|
| 23 |
+
"""
|
| 24 |
+
OpenAI Vision endpoint for GPT-4o and GPT-4o-mini vision capabilities.
|
| 25 |
+
|
| 26 |
+
Supports both URL and base64 image inputs using the image_url content type.
|
| 27 |
+
|
| 28 |
+
Configuration options:
|
| 29 |
+
- model: Model to use (gpt-4o, gpt-4o-mini) (default: gpt-4o)
|
| 30 |
+
- api_key: OpenAI API key (can also use OPENAI_API_KEY env var)
|
| 31 |
+
- max_tokens: Maximum response tokens (default: 1000)
|
| 32 |
+
- temperature: Sampling temperature (default: 0.1)
|
| 33 |
+
- detail: Image detail level - 'low', 'high', or 'auto' (default: auto)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
# Capabilities declaration for OpenAI vision models (GPT-4o, GPT-4o-mini)
|
| 37 |
+
# These models can understand images and generate text but bounding boxes are approximate
|
| 38 |
+
CAPABILITIES = ModelCapabilities(
|
| 39 |
+
text_generation=True,
|
| 40 |
+
vision_input=True,
|
| 41 |
+
bounding_box_output=False, # GPT-4V bboxes are approximate, not precise
|
| 42 |
+
text_classification=True,
|
| 43 |
+
image_classification=True,
|
| 44 |
+
rationale_generation=True,
|
| 45 |
+
keyword_extraction=False, # Keywords don't apply to images
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def _initialize_client(self) -> None:
|
| 49 |
+
"""Initialize the OpenAI client."""
|
| 50 |
+
try:
|
| 51 |
+
import openai
|
| 52 |
+
except ImportError:
|
| 53 |
+
raise AIEndpointRequestError(
|
| 54 |
+
"openai package is required. Install it with: pip install openai"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
import os
|
| 58 |
+
|
| 59 |
+
api_key = self.ai_config.get("api_key") or os.environ.get("OPENAI_API_KEY")
|
| 60 |
+
if not api_key:
|
| 61 |
+
raise AIEndpointRequestError(
|
| 62 |
+
"OpenAI API key is required. Set it in config or OPENAI_API_KEY env var."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
timeout = self.ai_config.get("timeout", 60)
|
| 66 |
+
self.detail = self.ai_config.get("detail", "auto")
|
| 67 |
+
|
| 68 |
+
self.client = openai.OpenAI(api_key=api_key, timeout=timeout)
|
| 69 |
+
logger.info(f"OpenAI Vision client initialized with model: {self.model}")
|
| 70 |
+
|
| 71 |
+
def _get_default_model(self) -> str:
|
| 72 |
+
"""Get the default OpenAI vision model."""
|
| 73 |
+
return DEFAULT_MODEL
|
| 74 |
+
|
| 75 |
+
def query(self, prompt: str, output_format: Type[BaseModel]) -> Any:
|
| 76 |
+
"""
|
| 77 |
+
Standard text query without images.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
prompt: Text prompt
|
| 81 |
+
output_format: Pydantic model for structured output
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
Parsed response
|
| 85 |
+
"""
|
| 86 |
+
try:
|
| 87 |
+
response = self.client.chat.completions.create(
|
| 88 |
+
model=self.model,
|
| 89 |
+
messages=[{"role": "user", "content": prompt}],
|
| 90 |
+
max_tokens=self.max_tokens,
|
| 91 |
+
temperature=self.temperature,
|
| 92 |
+
response_format={"type": "json_object"},
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
content = response.choices[0].message.content
|
| 96 |
+
return self.parseStringToJson(content)
|
| 97 |
+
|
| 98 |
+
except Exception as e:
|
| 99 |
+
raise AIEndpointRequestError(f"OpenAI query failed: {e}")
|
| 100 |
+
|
| 101 |
+
def query_with_image(
|
| 102 |
+
self,
|
| 103 |
+
prompt: str,
|
| 104 |
+
image_data: Union[ImageData, List[ImageData]],
|
| 105 |
+
output_format: Type[BaseModel]
|
| 106 |
+
) -> Any:
|
| 107 |
+
"""
|
| 108 |
+
Send a query with image(s) to OpenAI vision model.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
prompt: Text prompt describing what to analyze
|
| 112 |
+
image_data: Single ImageData or list of ImageData
|
| 113 |
+
output_format: Pydantic model for structured output
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Parsed response according to output_format
|
| 117 |
+
|
| 118 |
+
Raises:
|
| 119 |
+
AIEndpointRequestError: If the request fails
|
| 120 |
+
"""
|
| 121 |
+
try:
|
| 122 |
+
# Prepare images
|
| 123 |
+
images = [image_data] if isinstance(image_data, ImageData) else image_data
|
| 124 |
+
|
| 125 |
+
# Build content array with text and images
|
| 126 |
+
content = [{"type": "text", "text": prompt}]
|
| 127 |
+
|
| 128 |
+
for img in images:
|
| 129 |
+
image_content = self._build_image_content(img)
|
| 130 |
+
content.append(image_content)
|
| 131 |
+
|
| 132 |
+
# Make request
|
| 133 |
+
response = self.client.chat.completions.create(
|
| 134 |
+
model=self.model,
|
| 135 |
+
messages=[{"role": "user", "content": content}],
|
| 136 |
+
max_tokens=self.max_tokens,
|
| 137 |
+
temperature=self.temperature,
|
| 138 |
+
response_format={"type": "json_object"},
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
response_content = response.choices[0].message.content
|
| 142 |
+
logger.debug(f"OpenAI vision response: {response_content[:500] if response_content else 'empty'}")
|
| 143 |
+
|
| 144 |
+
return self.parseStringToJson(response_content)
|
| 145 |
+
|
| 146 |
+
except AIEndpointRequestError:
|
| 147 |
+
raise
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.error(f"OpenAI vision query failed: {e}")
|
| 150 |
+
import traceback
|
| 151 |
+
logger.error(traceback.format_exc())
|
| 152 |
+
raise AIEndpointRequestError(f"OpenAI vision query failed: {e}")
|
| 153 |
+
|
| 154 |
+
def _build_image_content(self, image_data: ImageData) -> Dict[str, Any]:
|
| 155 |
+
"""
|
| 156 |
+
Build image content block for OpenAI API.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
image_data: ImageData object
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Dict with type: "image_url" and image_url content
|
| 163 |
+
"""
|
| 164 |
+
if image_data.source == "url":
|
| 165 |
+
# Direct URL reference
|
| 166 |
+
return {
|
| 167 |
+
"type": "image_url",
|
| 168 |
+
"image_url": {
|
| 169 |
+
"url": image_data.data,
|
| 170 |
+
"detail": self.detail
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
elif image_data.source == "base64":
|
| 175 |
+
# Data URL format
|
| 176 |
+
mime_type = image_data.mime_type or "image/jpeg"
|
| 177 |
+
data_url = f"data:{mime_type};base64,{image_data.data}"
|
| 178 |
+
|
| 179 |
+
return {
|
| 180 |
+
"type": "image_url",
|
| 181 |
+
"image_url": {
|
| 182 |
+
"url": data_url,
|
| 183 |
+
"detail": self.detail
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
else:
|
| 188 |
+
raise AIEndpointRequestError(f"Unknown image source: {image_data.source}")
|
| 189 |
+
|
| 190 |
+
def analyze_image(
|
| 191 |
+
self,
|
| 192 |
+
image_path_or_url: str,
|
| 193 |
+
prompt: str,
|
| 194 |
+
output_format: Type[BaseModel] = None
|
| 195 |
+
) -> Any:
|
| 196 |
+
"""
|
| 197 |
+
Convenience method for analyzing a single image.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
image_path_or_url: Path to image file or URL
|
| 201 |
+
prompt: Analysis prompt
|
| 202 |
+
output_format: Optional output format model
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
Analysis result
|
| 206 |
+
"""
|
| 207 |
+
# Prepare image data - use URL directly if possible
|
| 208 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 209 |
+
image_data = self.create_url_image_data(image_path_or_url)
|
| 210 |
+
else:
|
| 211 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 212 |
+
|
| 213 |
+
# Use a generic format if not specified
|
| 214 |
+
if output_format is None:
|
| 215 |
+
from .prompt.models_module import GeneralHintFormat
|
| 216 |
+
output_format = GeneralHintFormat
|
| 217 |
+
|
| 218 |
+
return self.query_with_image(prompt, image_data, output_format)
|
| 219 |
+
|
| 220 |
+
def detect_objects(
|
| 221 |
+
self,
|
| 222 |
+
image_path_or_url: str,
|
| 223 |
+
labels: List[str] = None
|
| 224 |
+
) -> Dict[str, Any]:
|
| 225 |
+
"""
|
| 226 |
+
Detect objects in an image and return bounding boxes.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
image_path_or_url: Path to image file or URL
|
| 230 |
+
labels: Optional list of labels to detect
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Dict with detections list
|
| 234 |
+
"""
|
| 235 |
+
from .prompt.models_module import VisualDetectionFormat
|
| 236 |
+
|
| 237 |
+
labels_str = ", ".join(labels) if labels else "all visible objects"
|
| 238 |
+
|
| 239 |
+
prompt = f"""Analyze this image and detect objects. For each object, provide:
|
| 240 |
+
1. The label (from: {labels_str})
|
| 241 |
+
2. A bounding box with normalized coordinates (0-1 range)
|
| 242 |
+
3. Confidence score (0-1)
|
| 243 |
+
|
| 244 |
+
Return JSON with this structure:
|
| 245 |
+
{{
|
| 246 |
+
"detections": [
|
| 247 |
+
{{
|
| 248 |
+
"label": "object_name",
|
| 249 |
+
"bbox": {{"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.4}},
|
| 250 |
+
"confidence": 0.95
|
| 251 |
+
}}
|
| 252 |
+
]
|
| 253 |
+
}}
|
| 254 |
+
|
| 255 |
+
Coordinates are normalized (0-1) where x,y is the top-left corner.
|
| 256 |
+
Only include objects you can clearly identify with confidence > 0.5."""
|
| 257 |
+
|
| 258 |
+
# Prepare image
|
| 259 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 260 |
+
image_data = self.create_url_image_data(image_path_or_url)
|
| 261 |
+
else:
|
| 262 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 263 |
+
|
| 264 |
+
return self.query_with_image(prompt, image_data, VisualDetectionFormat)
|
| 265 |
+
|
| 266 |
+
def describe_region(
|
| 267 |
+
self,
|
| 268 |
+
image_path_or_url: str,
|
| 269 |
+
region: Dict[str, float],
|
| 270 |
+
labels: List[str] = None
|
| 271 |
+
) -> Dict[str, Any]:
|
| 272 |
+
"""
|
| 273 |
+
Describe or classify a specific region in an image.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
image_path_or_url: Path to image file or URL
|
| 277 |
+
region: Dict with x, y, width, height (normalized 0-1)
|
| 278 |
+
labels: Optional list of possible labels
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
Classification result with suggested label and confidence
|
| 282 |
+
"""
|
| 283 |
+
labels_str = ", ".join(labels) if labels else "any appropriate category"
|
| 284 |
+
|
| 285 |
+
prompt = f"""Look at the region marked in this image:
|
| 286 |
+
- Region: x={region['x']:.2f}, y={region['y']:.2f}, width={region['width']:.2f}, height={region['height']:.2f}
|
| 287 |
+
(Coordinates are normalized 0-1, where 0,0 is top-left)
|
| 288 |
+
|
| 289 |
+
Classify what you see in this region from these options: {labels_str}
|
| 290 |
+
|
| 291 |
+
Return JSON:
|
| 292 |
+
{{
|
| 293 |
+
"suggested_label": "label_name",
|
| 294 |
+
"confidence": 0.85,
|
| 295 |
+
"reasoning": "Brief explanation"
|
| 296 |
+
}}"""
|
| 297 |
+
|
| 298 |
+
# Prepare image
|
| 299 |
+
if image_path_or_url.startswith(("http://", "https://")):
|
| 300 |
+
image_data = self.create_url_image_data(image_path_or_url)
|
| 301 |
+
else:
|
| 302 |
+
image_data = self.encode_image_to_base64(image_path_or_url)
|
| 303 |
+
|
| 304 |
+
class RegionClassificationFormat(BaseModel):
|
| 305 |
+
suggested_label: str
|
| 306 |
+
confidence: float
|
| 307 |
+
reasoning: str
|
| 308 |
+
|
| 309 |
+
return self.query_with_image(prompt, image_data, RegionClassificationFormat)
|
| 310 |
+
|
| 311 |
+
def health_check(self) -> bool:
|
| 312 |
+
"""
|
| 313 |
+
Check if the OpenAI API is accessible.
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
True if API is reachable, False otherwise
|
| 317 |
+
"""
|
| 318 |
+
try:
|
| 319 |
+
# Simple models list check
|
| 320 |
+
self.client.models.list()
|
| 321 |
+
return True
|
| 322 |
+
except Exception as e:
|
| 323 |
+
logger.error(f"OpenAI health check failed: {e}")
|
| 324 |
+
return False
|
potato/ai/openrouter_endpoint.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenRouter endpoint implementation.
|
| 3 |
+
This module provides integration with OpenRouter's API for LLM inference.
|
| 4 |
+
"""
|
| 5 |
+
import requests
|
| 6 |
+
from .ai_endpoint import BaseAIEndpoint, AIEndpointRequestError
|
| 7 |
+
|
| 8 |
+
DEFAULT_MODEL = "openai/gpt-4o-mini"
|
| 9 |
+
|
| 10 |
+
class OpenRouterEndpoint(BaseAIEndpoint):
|
| 11 |
+
"""OpenRouter endpoint for cloud-based LLM inference."""
|
| 12 |
+
|
| 13 |
+
# Models that support structured output
|
| 14 |
+
STRUCTURED_OUTPUT_MODELS = {
|
| 15 |
+
"openai/gpt-4o",
|
| 16 |
+
"openai/gpt-4o-mini",
|
| 17 |
+
"openai/gpt-4-turbo",
|
| 18 |
+
"anthropic/claude-3-5-sonnet",
|
| 19 |
+
"deepseek/deepseek-r1:free"
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
def _initialize_client(self) -> None:
|
| 23 |
+
"""Initialize the OpenAI client."""
|
| 24 |
+
api_key = self.ai_config.get("api_key", "")
|
| 25 |
+
if not api_key:
|
| 26 |
+
raise AIEndpointRequestError("OpenRouter API key is required")
|
| 27 |
+
|
| 28 |
+
def _get_default_model(self) -> str:
|
| 29 |
+
"""Get the default OpenAI model."""
|
| 30 |
+
return DEFAULT_MODEL
|
| 31 |
+
|
| 32 |
+
def supports_structured_output(self) -> bool:
|
| 33 |
+
"""Check if the current model supports structured output."""
|
| 34 |
+
model = self.model or DEFAULT_MODEL
|
| 35 |
+
return any(model.startswith(prefix.split('/')[0]) or model in self.STRUCTURED_OUTPUT_MODELS
|
| 36 |
+
for prefix in self.STRUCTURED_OUTPUT_MODELS)
|
| 37 |
+
|
| 38 |
+
def query(self, prompt: str, output_format: dict) -> str:
|
| 39 |
+
"""
|
| 40 |
+
Send a query to OpenRouter and return the response.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
prompt: The prompt to send to the model (as messages list or string)
|
| 44 |
+
output_format: Pydantic model for structured output
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
The model's response as a string
|
| 48 |
+
|
| 49 |
+
Raises:
|
| 50 |
+
AIEndpointRequestError: If the request fails
|
| 51 |
+
"""
|
| 52 |
+
try:
|
| 53 |
+
url = "https://openrouter.ai/api/v1/chat/completions"
|
| 54 |
+
headers = {
|
| 55 |
+
"Authorization": f"Bearer {self.ai_config.get('api_key')}",
|
| 56 |
+
"Content-Type": "application/json"
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
messages = [{"role": "user", "content": prompt}]
|
| 60 |
+
schema = output_format.model_json_schema()
|
| 61 |
+
|
| 62 |
+
body = {
|
| 63 |
+
"model": self.model or DEFAULT_MODEL,
|
| 64 |
+
"max_tokens": self.max_tokens,
|
| 65 |
+
"temperature": self.temperature,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# Handle structured output based on model support
|
| 69 |
+
if self.supports_structured_output():
|
| 70 |
+
body["messages"] = messages
|
| 71 |
+
body["response_format"] = {
|
| 72 |
+
"type": "json_schema",
|
| 73 |
+
"json_schema": {
|
| 74 |
+
"name": "response",
|
| 75 |
+
"schema": schema,
|
| 76 |
+
"strict": True
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
else:
|
| 80 |
+
# If model does not support structured format, just send raw prompt
|
| 81 |
+
body["messages"] = messages
|
| 82 |
+
|
| 83 |
+
r = requests.post(url, headers=headers, json=body)
|
| 84 |
+
|
| 85 |
+
if r.status_code >= 400:
|
| 86 |
+
raise AIEndpointRequestError(f"OpenRouter error {r.status_code}: {r.text}")
|
| 87 |
+
|
| 88 |
+
data = r.json()
|
| 89 |
+
if self.supports_structured_output():
|
| 90 |
+
return self.parseStringToJson(data["choices"][0]["message"]["content"])
|
| 91 |
+
return data["choices"][0]["message"]["content"]
|
| 92 |
+
except Exception as e:
|
| 93 |
+
raise AIEndpointRequestError(f"OpenRouter request failed: {e}")
|
potato/ai/prompt/image_annotation.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"detect": {
|
| 3 |
+
"prompt": "TASK: Detect objects in this image that match the specified labels.\n\nDescription: ${description}\nLabels to detect: ${labels}\nMinimum confidence: ${confidence_threshold}\n\nFor each detected object, provide:\n1. The label (must be from the provided labels list)\n2. Bounding box with normalized coordinates (0-1 range)\n3. Confidence score (0-1)\n\nCoordinate system:\n- x,y is the top-left corner of the box\n- x increases left to right (0 = left edge, 1 = right edge)\n- y increases top to bottom (0 = top edge, 1 = bottom edge)\n- width and height are also normalized (0-1)\n\nReturn JSON:\n{\n \"detections\": [\n {\n \"label\": \"object_name\",\n \"bbox\": {\"x\": 0.1, \"y\": 0.2, \"width\": 0.3, \"height\": 0.4},\n \"confidence\": 0.95\n }\n ]\n}\n\nOnly include objects you can clearly identify with confidence >= ${confidence_threshold}.",
|
| 4 |
+
"output_format": "visual_detection",
|
| 5 |
+
"img": "/static/ai_assistant_img/detect.svg",
|
| 6 |
+
"name": "Detect"
|
| 7 |
+
},
|
| 8 |
+
"detection": {
|
| 9 |
+
"prompt": "TASK: Detect objects in this image that match the specified labels.\n\nDescription: ${description}\nLabels to detect: ${labels}\nMinimum confidence: ${confidence_threshold}\n\nFor each detected object, provide:\n1. The label (must be from the provided labels list)\n2. Bounding box with normalized coordinates (0-1 range)\n3. Confidence score (0-1)\n\nCoordinate system:\n- x,y is the top-left corner of the box\n- x increases left to right (0 = left edge, 1 = right edge)\n- y increases top to bottom (0 = top edge, 1 = bottom edge)\n- width and height are also normalized (0-1)\n\nReturn JSON:\n{\n \"detections\": [\n {\n \"label\": \"object_name\",\n \"bbox\": {\"x\": 0.1, \"y\": 0.2, \"width\": 0.3, \"height\": 0.4},\n \"confidence\": 0.95\n }\n ]\n}\n\nOnly include objects you can clearly identify with confidence >= ${confidence_threshold}.",
|
| 10 |
+
"output_format": "visual_detection",
|
| 11 |
+
"img": "/static/ai_assistant_img/detect.svg",
|
| 12 |
+
"name": "Detect"
|
| 13 |
+
},
|
| 14 |
+
"pre_annotate": {
|
| 15 |
+
"prompt": "TASK: Pre-annotate this image by detecting all objects that match the available labels.\n\nDescription: ${description}\nAvailable labels: ${labels}\n\nDetect ALL instances of objects matching these labels. Be thorough - it's better to include uncertain detections (the human annotator will review them).\n\nFor each detection provide:\n1. Label from the available list\n2. Bounding box (normalized 0-1 coordinates)\n3. Confidence score\n\nReturn JSON:\n{\n \"detections\": [\n {\n \"label\": \"label_name\",\n \"bbox\": {\"x\": 0.0, \"y\": 0.0, \"width\": 0.0, \"height\": 0.0},\n \"confidence\": 0.0\n }\n ]\n}\n\nInclude all potential detections, even with lower confidence. The annotator will verify.",
|
| 16 |
+
"output_format": "visual_detection",
|
| 17 |
+
"img": "/static/ai_assistant_img/auto.svg",
|
| 18 |
+
"name": "Auto"
|
| 19 |
+
},
|
| 20 |
+
"classification": {
|
| 21 |
+
"prompt": "TASK: Classify the specified region in this image.\n\nDescription: ${description}\nRegion: ${region}\nAvailable labels: ${labels}\n\nLook at the indicated region and determine which label best describes what you see there.\n\nReturn JSON:\n{\n \"suggested_label\": \"label_name\",\n \"confidence\": 0.85,\n \"reasoning\": \"Brief explanation of why this label fits\"\n}",
|
| 22 |
+
"output_format": "visual_classification",
|
| 23 |
+
"img": "/static/ai_assistant_img/classify.svg",
|
| 24 |
+
"name": "Classify"
|
| 25 |
+
},
|
| 26 |
+
"hint": {
|
| 27 |
+
"prompt": "TASK: Provide a helpful hint for annotating this image WITHOUT revealing exact answers.\n\nAnnotation task: ${description}\nAvailable labels: ${labels}\n\nProvide guidance that helps the annotator without giving away:\n- Exact object locations\n- Specific label assignments\n\nGood hints:\n- Point out relevant visual features\n- Suggest areas that deserve attention\n- Note potential challenges or ambiguities\n- Remind about edge cases\n\nReturn JSON:\n{\n \"hint\": \"Your helpful guidance here\",\n \"suggestive_choice\": \"optional_focus_area\"\n}",
|
| 28 |
+
"output_format": "default_hint",
|
| 29 |
+
"img": "/static/ai_assistant_img/blub.svg",
|
| 30 |
+
"name": "Hint"
|
| 31 |
+
},
|
| 32 |
+
"keyword": {
|
| 33 |
+
"prompt": "TASK: Identify visual keywords/features associated with each label in this image.\n\nAnnotation task: ${description}\nAvailable labels: ${labels}\n\nFor each label, identify visual cues or features that would indicate its presence.\n\nReturn JSON:\n{\n \"label_keywords\": [\n {\n \"label\": \"label_name\",\n \"keywords\": [\"visual_feature_1\", \"visual_feature_2\"]\n }\n ]\n}",
|
| 34 |
+
"output_format": "default_keyword",
|
| 35 |
+
"img": "/static/ai_assistant_img/highlight.svg",
|
| 36 |
+
"name": "Keywords"
|
| 37 |
+
},
|
| 38 |
+
"rationale": {
|
| 39 |
+
"prompt": "TASK: Provide rationale for how each label might apply to this image.\n\nAnnotation task: ${description}\nAvailable labels: ${labels}\n\nFor each label, explain what evidence in the image supports or contradicts its application.\n\nReturn JSON:\n{\n \"rationales\": [\n {\n \"label\": \"label_name\",\n \"reasoning\": \"Explanation of visual evidence for/against this label\"\n }\n ]\n}",
|
| 40 |
+
"output_format": "default_rationale",
|
| 41 |
+
"img": "/static/ai_assistant_img/question.svg",
|
| 42 |
+
"name": "Rationale"
|
| 43 |
+
}
|
| 44 |
+
}
|
potato/ai/prompt/likert.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hint": {
|
| 3 |
+
"prompt": "TASK: Generate annotation guidance for Likert scale task.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Scale: ${min_label} (1) to ${max_label} (${size})\n- Scale points: 1, 2, 3, ..., ${size}\n\nINSTRUCTIONS:\n1. Analyze the text for features relevant to the annotation task\n2. Generate a helpful hint that guides thinking WITHOUT revealing the answer\n3. Suggest a scale position based on your analysis\n\nHINT REQUIREMENTS:\n- Focus on specific textual evidence (word choice, tone, structure)\n- Point out subtle indicators the annotator should notice\n- Be concrete and actionable, not generic\n- Guide analytical thinking without bias toward any scale position",
|
| 4 |
+
"output_format": "default_hint",
|
| 5 |
+
"img": "/static/ai_assistant_img/blub.svg"
|
| 6 |
+
},
|
| 7 |
+
"keyword": {
|
| 8 |
+
"prompt": "TASK: Extract key words/phrases that guide Likert scale annotation decisions.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n- Scale: ${min_label} (1) to ${max_label} (${size})\n- Scale range: 1, 2, 3, ..., ${size}\n\nOBJECTIVE: Identify 3-5 most significant words/phrases that directly indicate scale positioning. Output as JSON array format.\n\nSELECTION CRITERIA:\n- Words that signal intensity/degree (extremely, slightly, moderately)\n- Sentiment markers (positive/negative indicators)\n- Qualifying language (hedges, certainty markers)\n- Context-specific terminology relevant to the annotation task\n- Structural indicators (but, however, although)\n\nPRIORITIZE:\n1. Words with clear scale implications\n2. Phrases that distinguish between scale levels\n3. Contextual clues that affect interpretation\n4. Intensity modifiers and qualifiers",
|
| 9 |
+
"output_format": "default_keyword",
|
| 10 |
+
"img": "/static/ai_assistant_img/highlight.svg"
|
| 11 |
+
},
|
| 12 |
+
"rationale": {
|
| 13 |
+
"name": "Rationale",
|
| 14 |
+
"prompt": "TASK: Generate rationales for different positions on the Likert scale.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Scale: ${min_label} (1) to ${max_label} (${size})\n\nINSTRUCTIONS:\nProvide rationales for different scale positions (low, middle, high). Explain what evidence in the text could support each position. Be balanced and objective.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing objects with \"label\" (scale position descriptor) and \"reasoning\" fields.\n\nEXAMPLE OUTPUT:\n{\"rationales\": [{\"label\": \"low (1-2)\", \"reasoning\": \"The negative tone and words like 'disappointing' suggest a low rating\"}, {\"label\": \"middle (3)\", \"reasoning\": \"Mixed signals with both positive and negative elements\"}, {\"label\": \"high (4-5)\", \"reasoning\": \"Strong positive language like 'excellent' supports a high rating\"}]}",
|
| 15 |
+
"output_format": "default_rationale",
|
| 16 |
+
"img": "/static/ai_assistant_img/question.svg"
|
| 17 |
+
}
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
potato/ai/prompt/models_module.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Type, Union, Dict, List
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class GeneralHintFormat(BaseModel):
|
| 6 |
+
hint: str
|
| 7 |
+
suggestive_choice: Union[str, int]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LabelKeywords(BaseModel):
|
| 11 |
+
"""Keywords/phrases associated with a specific label."""
|
| 12 |
+
label: str
|
| 13 |
+
keywords: List[str]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GeneralKeywordFormat(BaseModel):
|
| 17 |
+
"""Simplified keyword format: list of label -> keywords mappings.
|
| 18 |
+
|
| 19 |
+
Example output:
|
| 20 |
+
{
|
| 21 |
+
"label_keywords": [
|
| 22 |
+
{"label": "positive", "keywords": ["great", "love it", "excellent"]},
|
| 23 |
+
{"label": "negative", "keywords": ["terrible", "awful"]}
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
"""
|
| 27 |
+
label_keywords: List[LabelKeywords]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class GeneralRandomFormat(BaseModel):
|
| 31 |
+
"""Deprecated: Use GeneralRationaleFormat instead."""
|
| 32 |
+
random: str
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class LabelRationale(BaseModel):
|
| 36 |
+
"""Rationale/reasoning for why a specific label might apply."""
|
| 37 |
+
label: str
|
| 38 |
+
reasoning: str
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class GeneralRationaleFormat(BaseModel):
|
| 42 |
+
"""Rationale format: explanations for how each label might apply to the text.
|
| 43 |
+
|
| 44 |
+
Example output:
|
| 45 |
+
{
|
| 46 |
+
"rationales": [
|
| 47 |
+
{"label": "positive", "reasoning": "The phrase 'excellent quality' suggests satisfaction"},
|
| 48 |
+
{"label": "negative", "reasoning": "The mention of 'delayed shipping' indicates frustration"}
|
| 49 |
+
]
|
| 50 |
+
}
|
| 51 |
+
"""
|
| 52 |
+
rationales: List[LabelRationale]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ============================================================================
|
| 56 |
+
# Visual Annotation Output Formats
|
| 57 |
+
# ============================================================================
|
| 58 |
+
|
| 59 |
+
class BoundingBox(BaseModel):
|
| 60 |
+
"""Normalized bounding box coordinates (0-1 range).
|
| 61 |
+
|
| 62 |
+
x, y: top-left corner position
|
| 63 |
+
width, height: box dimensions
|
| 64 |
+
All values are normalized to image dimensions (0-1).
|
| 65 |
+
"""
|
| 66 |
+
x: float
|
| 67 |
+
y: float
|
| 68 |
+
width: float
|
| 69 |
+
height: float
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Detection(BaseModel):
|
| 73 |
+
"""Single object detection result.
|
| 74 |
+
|
| 75 |
+
Example:
|
| 76 |
+
{
|
| 77 |
+
"label": "person",
|
| 78 |
+
"bbox": {"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.5},
|
| 79 |
+
"confidence": 0.95
|
| 80 |
+
}
|
| 81 |
+
"""
|
| 82 |
+
label: str
|
| 83 |
+
bbox: BoundingBox
|
| 84 |
+
confidence: float
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class VisualDetectionFormat(BaseModel):
|
| 88 |
+
"""Object detection results for an image.
|
| 89 |
+
|
| 90 |
+
Example output:
|
| 91 |
+
{
|
| 92 |
+
"detections": [
|
| 93 |
+
{"label": "car", "bbox": {"x": 0.1, "y": 0.2, "width": 0.3, "height": 0.2}, "confidence": 0.92},
|
| 94 |
+
{"label": "person", "bbox": {"x": 0.5, "y": 0.3, "width": 0.1, "height": 0.4}, "confidence": 0.87}
|
| 95 |
+
]
|
| 96 |
+
}
|
| 97 |
+
"""
|
| 98 |
+
detections: List[Detection]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class VisualClassificationFormat(BaseModel):
|
| 102 |
+
"""Classification result for an image or region.
|
| 103 |
+
|
| 104 |
+
Example output:
|
| 105 |
+
{
|
| 106 |
+
"suggested_label": "cat",
|
| 107 |
+
"confidence": 0.89,
|
| 108 |
+
"reasoning": "The image shows a feline with pointed ears and whiskers"
|
| 109 |
+
}
|
| 110 |
+
"""
|
| 111 |
+
suggested_label: str
|
| 112 |
+
confidence: float
|
| 113 |
+
reasoning: Optional[str] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class VideoSegment(BaseModel):
|
| 117 |
+
"""Temporal segment in a video.
|
| 118 |
+
|
| 119 |
+
Times are in seconds.
|
| 120 |
+
"""
|
| 121 |
+
start_time: float
|
| 122 |
+
end_time: float
|
| 123 |
+
suggested_label: str
|
| 124 |
+
confidence: float
|
| 125 |
+
description: Optional[str] = None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class VideoSceneDetectionFormat(BaseModel):
|
| 129 |
+
"""Scene/segment detection results for a video.
|
| 130 |
+
|
| 131 |
+
Example output:
|
| 132 |
+
{
|
| 133 |
+
"segments": [
|
| 134 |
+
{"start_time": 0.0, "end_time": 5.5, "suggested_label": "intro", "confidence": 0.9},
|
| 135 |
+
{"start_time": 5.5, "end_time": 15.0, "suggested_label": "action", "confidence": 0.85}
|
| 136 |
+
]
|
| 137 |
+
}
|
| 138 |
+
"""
|
| 139 |
+
segments: List[VideoSegment]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class VideoKeyframe(BaseModel):
|
| 143 |
+
"""Keyframe annotation for a video.
|
| 144 |
+
|
| 145 |
+
timestamp: Time in seconds
|
| 146 |
+
"""
|
| 147 |
+
timestamp: float
|
| 148 |
+
suggested_label: str
|
| 149 |
+
confidence: float
|
| 150 |
+
reason: Optional[str] = None
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class VideoKeyframeDetectionFormat(BaseModel):
|
| 154 |
+
"""Keyframe detection results for a video.
|
| 155 |
+
|
| 156 |
+
Example output:
|
| 157 |
+
{
|
| 158 |
+
"keyframes": [
|
| 159 |
+
{"timestamp": 2.5, "suggested_label": "scene_change", "confidence": 0.95, "reason": "Major visual transition"},
|
| 160 |
+
{"timestamp": 8.0, "suggested_label": "action_peak", "confidence": 0.82, "reason": "Key moment in action"}
|
| 161 |
+
]
|
| 162 |
+
}
|
| 163 |
+
"""
|
| 164 |
+
keyframes: List[VideoKeyframe]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class TrackPosition(BaseModel):
|
| 168 |
+
"""Object position in a single frame for tracking."""
|
| 169 |
+
frame_index: int
|
| 170 |
+
bbox: BoundingBox
|
| 171 |
+
confidence: float
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ObjectTrack(BaseModel):
|
| 175 |
+
"""Tracked object across multiple frames."""
|
| 176 |
+
track_id: int
|
| 177 |
+
label: str
|
| 178 |
+
positions: List[TrackPosition]
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class VideoTrackingSuggestionFormat(BaseModel):
|
| 182 |
+
"""Object tracking suggestions for a video.
|
| 183 |
+
|
| 184 |
+
Example output:
|
| 185 |
+
{
|
| 186 |
+
"tracks": [
|
| 187 |
+
{
|
| 188 |
+
"track_id": 1,
|
| 189 |
+
"label": "person",
|
| 190 |
+
"positions": [
|
| 191 |
+
{"frame_index": 0, "bbox": {"x": 0.1, "y": 0.2, "width": 0.15, "height": 0.3}, "confidence": 0.9},
|
| 192 |
+
{"frame_index": 1, "bbox": {"x": 0.12, "y": 0.22, "width": 0.15, "height": 0.3}, "confidence": 0.88}
|
| 193 |
+
]
|
| 194 |
+
}
|
| 195 |
+
]
|
| 196 |
+
}
|
| 197 |
+
"""
|
| 198 |
+
tracks: List[ObjectTrack]
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class FrameDetections(BaseModel):
|
| 202 |
+
"""Detections for a single video frame."""
|
| 203 |
+
frame_index: int
|
| 204 |
+
detections: List[Detection]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class MultiFrameDetectionFormat(BaseModel):
|
| 208 |
+
"""Detection results across multiple video frames.
|
| 209 |
+
|
| 210 |
+
Used when running detection on sampled video frames.
|
| 211 |
+
"""
|
| 212 |
+
frames: List[FrameDetections]
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ============================================================================
|
| 216 |
+
# Class Registry
|
| 217 |
+
# ============================================================================
|
| 218 |
+
|
| 219 |
+
# ============================================================================
|
| 220 |
+
# Option Highlighting Output Format
|
| 221 |
+
# ============================================================================
|
| 222 |
+
|
| 223 |
+
class OptionHighlightFormat(BaseModel):
|
| 224 |
+
"""LLM response for option highlighting.
|
| 225 |
+
|
| 226 |
+
Used to identify the most likely correct options for a discrete annotation task.
|
| 227 |
+
The highlighted options are shown at full opacity while others are dimmed.
|
| 228 |
+
|
| 229 |
+
Example output:
|
| 230 |
+
{
|
| 231 |
+
"highlighted_options": ["positive", "neutral"],
|
| 232 |
+
"confidence": 0.85
|
| 233 |
+
}
|
| 234 |
+
"""
|
| 235 |
+
highlighted_options: List[str] # Top-k most likely option names/values
|
| 236 |
+
confidence: Optional[float] = None # Optional overall confidence score (0-1)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
CLASS_REGISTRY = {
|
| 240 |
+
# Text annotation formats
|
| 241 |
+
"default_hint": GeneralHintFormat,
|
| 242 |
+
"default_keyword": GeneralKeywordFormat,
|
| 243 |
+
"default_random": GeneralRandomFormat, # Keep for backwards compatibility
|
| 244 |
+
"default_rationale": GeneralRationaleFormat,
|
| 245 |
+
|
| 246 |
+
# Option highlighting format
|
| 247 |
+
"option_highlight": OptionHighlightFormat,
|
| 248 |
+
|
| 249 |
+
# Visual annotation formats - Image
|
| 250 |
+
"visual_detection": VisualDetectionFormat,
|
| 251 |
+
"visual_classification": VisualClassificationFormat,
|
| 252 |
+
|
| 253 |
+
# Visual annotation formats - Video
|
| 254 |
+
"video_scene_detection": VideoSceneDetectionFormat,
|
| 255 |
+
"video_keyframe_detection": VideoKeyframeDetectionFormat,
|
| 256 |
+
"video_tracking_suggestion": VideoTrackingSuggestionFormat,
|
| 257 |
+
"multi_frame_detection": MultiFrameDetectionFormat,
|
| 258 |
+
}
|
potato/ai/prompt/multiselect.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hint": {
|
| 3 |
+
"prompt": "TASK: Generate guidance for multiple label selection.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nINSTRUCTIONS:\n1. Analyze text for features that may correspond to multiple labels\n2. Guide toward identifying ALL applicable categories\n3. Focus on overlapping characteristics and comprehensive analysis\n\nHINT REQUIREMENTS:\n- Identify indicators for each potential label category\n- Point out that multiple selections may be appropriate\n- Guide systematic evaluation of all label options\n- Highlight overlapping or complementary features",
|
| 4 |
+
"output_format": "default_hint",
|
| 5 |
+
"img": "/static/ai_assistant_img/blub.svg"
|
| 6 |
+
},
|
| 7 |
+
"keyword": {
|
| 8 |
+
"prompt": "TASK: Extract words/phrases that indicate multiple applicable labels.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nOBJECTIVE: Identify terms that support selection of multiple labels.\n\nSELECTION CRITERIA:\n- Words that indicate multiple categories simultaneously\n- Terms supporting different label aspects\n- Overlapping category indicators\n- Comprehensive feature markers\n- Multi-faceted descriptors",
|
| 9 |
+
"output_format": "default_keyword",
|
| 10 |
+
"img": "/static/ai_assistant_img/highlight.svg"
|
| 11 |
+
},
|
| 12 |
+
"rationale": {
|
| 13 |
+
"name": "Rationale",
|
| 14 |
+
"prompt": "TASK: Generate rationales explaining why each label might apply to this text.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nCRITICAL REQUIREMENT:\nYou MUST provide a rationale for EVERY label listed above. Your output must have exactly one entry per label.\n\nINSTRUCTIONS:\nFor EACH available label (ALL of them), provide a brief rationale explaining what evidence in the text could support selecting that label. Since multiple labels can be selected, focus on independent evidence for each label. If a label doesn't apply, explain why.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing one object per label, each with \"label\" and \"reasoning\" fields.\n\nEXAMPLE OUTPUT:\n{\"rationales\": [{\"label\": \"category1\", \"reasoning\": \"The text mentions X which relates to this category\"}, {\"label\": \"category2\", \"reasoning\": \"The phrase Y suggests this also applies\"}, {\"label\": \"category3\", \"reasoning\": \"No direct evidence for this label in the text\"}]}",
|
| 15 |
+
"output_format": "default_rationale",
|
| 16 |
+
"img": "/static/ai_assistant_img/question.svg"
|
| 17 |
+
}
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
potato/ai/prompt/number.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hint": {
|
| 3 |
+
"prompt": "TASK: Generate guidance for numerical value extraction/estimation.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n\nINSTRUCTIONS:\n1. Analyze the text for numerical clues or quantifiable elements\n2. Guide the annotator toward identifying the correct numerical value\n3. Focus on mathematical, statistical, or countable aspects\n\nHINT REQUIREMENTS:\n- Identify numerical indicators, quantities, or measurable elements\n- Point out calculation methods or counting strategies\n- Highlight context that affects numerical interpretation\n- Guide toward systematic analysis approach",
|
| 4 |
+
"output_format": "default_hint",
|
| 5 |
+
"img": "/static/ai_assistant_img/blub.svg"
|
| 6 |
+
},
|
| 7 |
+
"keyword": {
|
| 8 |
+
"prompt": "TASK: Extract words/phrases that contain or indicate numerical information.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n\nOBJECTIVE: Identify words/phrases that directly relate to the numerical answer.\n\nSELECTION CRITERIA:\n- Explicit numbers, quantities, or measurements\n- Words indicating amount, frequency, or degree\n- Mathematical or statistical terminology\n- Comparative language (more, less, double, half)\n- Time references, percentages, ratios",
|
| 9 |
+
"output_format": "default_keyword",
|
| 10 |
+
"img": "/static/ai_assistant_img/highlight.svg"
|
| 11 |
+
},
|
| 12 |
+
"rationale": {
|
| 13 |
+
"name": "Rationale",
|
| 14 |
+
"prompt": "TASK: Generate rationales for different approaches to determining the numerical answer.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n\nINSTRUCTIONS:\nProvide different rationales or approaches for determining the numerical value. Consider different interpretations or calculation methods if applicable.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing objects with \"label\" (approach name) and \"reasoning\" fields.\n\nEXAMPLE OUTPUT:\n{\"rationales\": [{\"label\": \"literal count\", \"reasoning\": \"Counting explicit mentions gives X\"}, {\"label\": \"inclusive interpretation\", \"reasoning\": \"Including implicit references increases the count to Y\"}]}",
|
| 15 |
+
"output_format": "default_rationale",
|
| 16 |
+
"img": "/static/ai_assistant_img/question.svg"
|
| 17 |
+
}
|
| 18 |
+
}
|
potato/ai/prompt/option_highlight.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"option_highlight": {
|
| 3 |
+
"prompt": "TASK: Identify the most likely correct options for an annotation task.\n\nINPUT DETAILS:\n- Content to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available options: ${labels}\n- Number of options to highlight: ${top_k}\n\nINSTRUCTIONS:\n1. Analyze the content carefully in the context of the annotation task\n2. Consider which ${top_k} options are most likely to be correct based on:\n - Direct evidence in the content\n - Contextual clues and tone\n - Domain knowledge relevant to the task\n3. Select exactly ${top_k} options (or fewer if fewer options exist)\n\nIMPORTANT:\n- Return ONLY the option names/values exactly as they appear in the available options list\n- Do not modify, paraphrase, or abbreviate the option names\n- If you're uncertain, choose options that seem most plausible given the content\n\nOUTPUT FORMAT:\nReturn a JSON object with:\n- \"highlighted_options\": array of ${top_k} option names from the available options\n- \"confidence\": your confidence score from 0.0 to 1.0\n\nEXAMPLE (if options are: positive, negative, neutral):\n{\"highlighted_options\": [\"positive\", \"neutral\"], \"confidence\": 0.75}",
|
| 4 |
+
"output_format": "option_highlight",
|
| 5 |
+
"img": "/static/ai_assistant_img/highlight.svg",
|
| 6 |
+
"name": "Option Highlight"
|
| 7 |
+
}
|
| 8 |
+
}
|
potato/ai/prompt/radio.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hint": {
|
| 3 |
+
"prompt": "TASK: Generate annotation guidance for single-choice selection.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nINSTRUCTIONS:\n1. Analyze the text for features that distinguish between the available labels\n2. Generate a helpful hint that guides classification WITHOUT revealing the answer\n3. Focus on decision-making criteria between options\n\nHINT REQUIREMENTS:\n- Identify key textual indicators that differentiate between label options\n- Point out distinguishing features (style, content, context)\n- Guide analytical thinking toward the classification criteria\n- Be specific about what to examine, not generic",
|
| 4 |
+
"output_format": "default_hint",
|
| 5 |
+
"img": "/static/ai_assistant_img/blub.svg"
|
| 6 |
+
},
|
| 7 |
+
"keyword": {
|
| 8 |
+
"prompt": "TASK: Identify words or short phrases in the text that relate to each label.\n\nINPUT DETAILS:\n- Text: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nINSTRUCTIONS:\nFor each label, find words or short phrases from the text that indicate or relate to that label. Only include words/phrases that actually appear in the text. Return 1-5 keywords per label. If no words relate to a label, return an empty list for that label.\n\nEXAMPLE OUTPUT FORMAT:\n{\"label_keywords\": [{\"label\": \"positive\", \"keywords\": [\"great\", \"love it\"]}, {\"label\": \"negative\", \"keywords\": [\"terrible\"]}]}",
|
| 9 |
+
"output_format": "default_keyword",
|
| 10 |
+
"img": "/static/ai_assistant_img/highlight.svg"
|
| 11 |
+
},
|
| 12 |
+
"rationale": {
|
| 13 |
+
"name": "Rationale",
|
| 14 |
+
"prompt": "TASK: Generate rationales explaining why each label might apply to this text.\n\nINPUT DETAILS:\n- Text to annotate: \"${text}\"\n- Annotation task: ${description}\n- Available labels: ${labels}\n\nCRITICAL REQUIREMENT:\nYou MUST provide a rationale for EVERY label listed above. Your output must have exactly one entry per label.\n\nINSTRUCTIONS:\nFor EACH available label (ALL of them), provide a brief rationale explaining what evidence in the text could support choosing that label. Be balanced and objective - present the case for each label fairly. If a label doesn't strongly apply, explain what would need to be present.\n\nOUTPUT FORMAT:\nReturn a JSON object with \"rationales\" array containing one object per label, each with \"label\" and \"reasoning\" fields.\n\nEXAMPLE (if labels are: positive, negative, neutral, mixed):\n{\"rationales\": [{\"label\": \"positive\", \"reasoning\": \"The phrase 'excellent service' suggests satisfaction\"}, {\"label\": \"negative\", \"reasoning\": \"No clear negative indicators present\"}, {\"label\": \"neutral\", \"reasoning\": \"The tone is more emotional than neutral\"}, {\"label\": \"mixed\", \"reasoning\": \"Would need both positive and negative elements\"}]}",
|
| 15 |
+
"output_format": "default_rationale",
|
| 16 |
+
"img": "/static/ai_assistant_img/question.svg"
|
| 17 |
+
}
|
| 18 |
+
}
|