codebook / potato /ai /ai_cache.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
64.3 kB
from __future__ import annotations
import json
import logging
import os
from typing import Any, Dict, Union
import requests
from tqdm import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
import threading
from builtins import open
from potato.server_utils.config_module import config
logger = logging.getLogger(__name__)
from potato.item_state_management import get_item_state_manager
from potato.ai.ai_endpoint import (
AIEndpointFactory,
Annotation_Type,
AnnotationInput,
ImageData,
VisualAnnotationInput,
ModelCapabilities,
)
from potato.ai.ollama_endpoint import OllamaEndpoint
from potato.ai.openrouter_endpoint import OpenRouterEndpoint
from potato.ai.ai_prompt import ModelManager, get_ai_prompt
AICACHEMANAGER = None
def _get_scheme_field(annotation_id: int, field: str, default=None):
    """Read ``field`` from the annotation scheme at index ``annotation_id``.

    Args:
        annotation_id: Index into ``config["annotation_schemes"]``.
        field: Key to read from the scheme.
        default: When not ``None`` the field is optional and this value is
            returned if the key is absent. When ``None`` the field is required.

    Returns:
        The scheme's value for ``field`` (or ``default`` for optional fields).

    Raises:
        ValueError: if ``annotation_id`` is past the end of the configured
            schemes, or a required field is missing from the scheme.
    """
    all_schemes = config.get("annotation_schemes", [])
    if annotation_id >= len(all_schemes):
        raise ValueError(
            f"AI cache: annotation_id {annotation_id} out of range "
            f"(only {len(all_schemes)} scheme(s) configured)"
        )

    scheme = all_schemes[annotation_id]

    # Optional field: fall back to the caller-supplied default.
    if default is not None:
        return scheme.get(field, default)

    # Required field: hand it back when present ...
    if field in scheme:
        return scheme[field]

    # ... otherwise fail with enough context to locate the broken scheme.
    scheme_name = scheme.get("name", f"index {annotation_id}")
    scheme_type = scheme.get("annotation_type", "unknown")
    raise ValueError(
        f"AI cache: annotation scheme '{scheme_name}' (type '{scheme_type}') "
        f"missing required field '{field}'"
    )
def _get_instance_text(instance_id: int) -> str:
    """Return the text content of an instance.

    Lookup order in the item's data: the configured
    ``item_properties.text_key`` (default ``"text"``), then ``text``,
    ``content`` and ``message``, then the first string value found, and
    finally the string form of the whole data mapping.
    """
    record = get_item_state_manager().items()[instance_id].get_data()
    preferred_key = config.get("item_properties", {}).get("text_key", "text")

    # Configured key first, then the common fallbacks, in that order.
    for candidate in (preferred_key, 'text', 'content', 'message'):
        if candidate in record:
            return record[candidate]

    # Last resort: any string value at all.
    for candidate_value in record.values():
        if isinstance(candidate_value, str):
            return candidate_value

    return str(record)
def _is_image_url(text: str) -> bool:
"""Check if text appears to be an image URL."""
if not isinstance(text, str):
return False
text_lower = text.lower()
# Check for image extensions
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp']
if any(ext in text_lower for ext in image_extensions):
return True
# Check for common image hosting services
image_hosts = ['unsplash.com', 'imgur.com', 'flickr.com', 'picsum.photos']
if any(host in text_lower for host in image_hosts):
return True
# Check if URL starts with http and might be an image
if text_lower.startswith(('http://', 'https://')) and 'image' in text_lower:
return True
return False
def _get_image_data_from_url(url: str) -> ImageData:
"""Download image from URL and return as ImageData.
Includes SSRF protection to prevent fetching from private/internal IPs.
"""
import base64
import ipaddress
import socket
from urllib.parse import urlparse
# SSRF protection: validate URL scheme and resolve hostname
try:
parsed = urlparse(url)
if parsed.scheme not in ('http', 'https'):
logger.warning(f"Blocked non-HTTP image URL: {url[:100]}")
return None
hostname = parsed.hostname
if hostname:
addr_info = socket.getaddrinfo(hostname, None)
for info in addr_info:
ip_str = info[4][0]
try:
ip = ipaddress.ip_address(ip_str)
if ip.is_private or ip.is_loopback or ip.is_link_local:
logger.warning(
f"Blocked image URL resolving to private IP: "
f"{hostname} -> {ip_str}"
)
return None
except ValueError:
pass
except Exception as e:
logger.warning(f"Failed to validate image URL {url[:100]}: {e}")
return None
try:
response = requests.get(url, timeout=30)
response.raise_for_status()
b64_data = base64.b64encode(response.content).decode('utf-8')
# Determine mime type from content-type header or URL
content_type = response.headers.get('content-type', 'image/jpeg')
return ImageData(source='base64', data=b64_data, mime_type=content_type)
except Exception as e:
logger.error(f"Failed to download image from {url}: {e}")
return None
def init_ai_cache_manager():
    """Create the module-wide ``AiCacheManager`` on first use and return it.

    Later calls return the already-created instance.
    """
    global AICACHEMANAGER
    if AICACHEMANAGER is not None:
        return AICACHEMANAGER
    AICACHEMANAGER = AiCacheManager()
    return AICACHEMANAGER
def get_ai_cache_manager():
    """Return the AI cache manager singleton.

    Yields ``None`` when the manager was never initialized (AI support disabled).
    """
    # Reading a module global needs no ``global`` declaration.
    return AICACHEMANAGER
def clear_ai_cache_manager():
    """Drop the AI cache manager singleton so the next init builds a new one.

    Intended for tests.
    """
    global AICACHEMANAGER
    AICACHEMANAGER = None
class AiCacheManager:
def __init__(self):
    """Configure caching, prefetching and AI endpoints from ``config["ai_support"]``.

    When ``ai_support["enabled"]`` is falsy the constructor returns at once
    and none of the attributes below are set.

    Otherwise it, in order: loads the prompt models, parses the ``include``
    settings, validates the disk-cache path, reads the prefetch and
    option-highlighting settings, registers the endpoint types, creates the
    main (and optionally a separate visual) endpoint, initialises the disk
    cache, and finally runs ``start_warmup``.

    Raises:
        KeyError: if ``config`` has no ``"ai_support"`` entry or that entry
            has no ``"enabled"`` key.
        Exception: if the disk cache is enabled without a path.
        ValueError: if the disk cache path resolves outside the task directory.
    """
    ai_support = config["ai_support"]
    if not ai_support["enabled"]:
        return

    cache_config = ai_support.get("cache_config", {})
    ai_config = ai_support.get("ai_config", {})
    include = ai_config.get("include") or {}
    special_include = include.get("special_include", None)
    self.include_all = include.get("all", False)
    self.special_includes = {}
    self.model_manager = ModelManager()
    self.model_manager.load_models_module()

    if special_include:
        for page_key, page_value in special_include.items():
            # Convert string keys to integers for easier lookup
            page_index = int(page_key)
            self.special_includes[page_index] = {}
            for annotation_id, annotation_types in page_value.items():
                annotation_id_int = int(annotation_id)
                self.special_includes[page_index][annotation_id_int] = annotation_types

    # Disk cache configuration.
    # F-028: tolerate a partial/absent ai_cache config (e.g. AI support
    # enabled for ICL with no disk_cache block) instead of crashing boot
    # with KeyError: 'disk_cache'.
    disk_cache_cfg = cache_config.get("disk_cache", {}) if isinstance(cache_config, dict) else {}
    self.disk_cache_enabled = disk_cache_cfg.get("enabled", False)
    disk_cache_path = disk_cache_cfg.get("path")
    if self.disk_cache_enabled and not disk_cache_path:
        raise Exception("You have enabled the disk cache, but you did not specify the path!")
    self.disk_persistence_path = disk_cache_path

    # Validate cache path stays within task directory
    if self.disk_persistence_path:
        task_dir = os.path.abspath(config.get("task_dir", "."))
        if os.path.isabs(self.disk_persistence_path):
            cache_abs = os.path.abspath(self.disk_persistence_path)
        else:
            cache_abs = os.path.abspath(
                os.path.join(task_dir, self.disk_persistence_path)
            )
        # commonpath compares whole path components, so it also works when
        # task_dir is a filesystem root (where "task_dir + os.sep" never
        # matches). It raises ValueError for paths on different drives.
        try:
            inside_task_dir = os.path.commonpath([task_dir, cache_abs]) == task_dir
        except ValueError:
            inside_task_dir = False
        if not inside_task_dir:
            raise ValueError(
                f"Cache path '{self.disk_persistence_path}' resolves to "
                f"'{cache_abs}' which is outside the task directory "
                f"'{task_dir}'. Path traversal is not allowed."
            )
        # NOTE(review): the path is validated relative to task_dir but stored
        # and later opened as configured (i.e. relative to the process cwd
        # when it is not absolute) - confirm the two always coincide.

    # Prefetch configuration - clamp to sane ranges.
    # F-028: default to no prefetch when the prefetch block is absent
    # (e.g. cache_config: {enabled: false}) instead of KeyError on boot.
    prefetch_cfg = cache_config.get("prefetch", {}) if isinstance(cache_config, dict) else {}
    self.warm_up_page_count = max(0, min(int(prefetch_cfg.get("warm_up_page_count", 0)), 10000))
    self.prefetch_page_count_on_next = max(0, min(int(prefetch_cfg.get("on_next", 0)), 10000))
    self.prefetch_page_count_on_prev = max(0, min(int(prefetch_cfg.get("on_prev", 0)), 10000))

    # Option highlighting configuration
    option_highlighting = ai_support.get("option_highlighting", {})
    self.option_highlighting_enabled = option_highlighting.get("enabled", False)
    self.option_highlighting_top_k = option_highlighting.get("top_k", 3)
    self.option_highlighting_dim_opacity = option_highlighting.get("dim_opacity", 0.4)
    self.option_highlighting_auto_apply = option_highlighting.get("auto_apply", True)
    self.option_highlighting_schemas = option_highlighting.get("schemas", None)  # None means all
    # Prefetch count for option highlighting - clamp to sane range
    self.option_highlighting_prefetch_count = max(0, min(
        int(option_highlighting.get("prefetch_count", 20)), 10000
    ))

    # Threading
    self.in_progress = {}
    self.lock = threading.RLock()
    self.executor = ThreadPoolExecutor(max_workers=20)

    AIEndpointFactory.register_endpoint("ollama", OllamaEndpoint)
    AIEndpointFactory.register_endpoint("open_router", OpenRouterEndpoint)

    # Register visual AI endpoints; each is optional and skipped when its
    # module cannot be imported.
    try:
        from potato.ai.yolo_endpoint import YOLOEndpoint
        AIEndpointFactory.register_endpoint("yolo", YOLOEndpoint)
    except ImportError:
        logger.debug("YOLO endpoint not available (ultralytics not installed)")
    try:
        from potato.ai.ollama_vision_endpoint import OllamaVisionEndpoint
        AIEndpointFactory.register_endpoint("ollama_vision", OllamaVisionEndpoint)
    except ImportError:
        logger.debug("Ollama Vision endpoint not available")
    try:
        from potato.ai.openai_vision_endpoint import OpenAIVisionEndpoint
        AIEndpointFactory.register_endpoint("openai_vision", OpenAIVisionEndpoint)
    except ImportError:
        logger.debug("OpenAI Vision endpoint not available")
    try:
        from potato.ai.anthropic_vision_endpoint import AnthropicVisionEndpoint
        AIEndpointFactory.register_endpoint("anthropic_vision", AnthropicVisionEndpoint)
    except ImportError:
        logger.debug("Anthropic Vision endpoint not available")

    # Degrade gracefully if the AI backend (e.g. a local Ollama/vLLM server)
    # is unreachable at boot: log a warning and serve the task with AI
    # support disabled rather than aborting server startup.
    try:
        self.ai_endpoint = AIEndpointFactory.create_endpoint(config)
    except Exception as e:
        logger.warning(
            "AI endpoint unavailable at startup (%s). Continuing with AI "
            "support disabled. Check that your AI backend is running.", e
        )
        self.ai_endpoint = None

    # Create visual endpoint if different from main endpoint
    self.visual_endpoint = None
    visual_endpoint_type = config.get("ai_support", {}).get("visual_endpoint_type")
    if visual_endpoint_type and visual_endpoint_type != config.get("ai_support", {}).get("endpoint_type"):
        visual_config = {
            "ai_support": {
                "enabled": True,
                "endpoint_type": visual_endpoint_type,
                # Falls back to the main ai_config when no visual one is given.
                "ai_config": config.get("ai_support", {}).get(
                    "visual_ai_config",
                    config.get("ai_support", {}).get("ai_config", {}),
                ),
            }
        }
        try:
            self.visual_endpoint = AIEndpointFactory.create_endpoint(visual_config)
        except Exception as e:
            logger.warning(
                "Visual AI endpoint unavailable at startup (%s). Continuing "
                "without visual AI support.", e
            )
            self.visual_endpoint = None

    # Copy the configured schemes; a missing entry yields an empty list
    # instead of failing on iteration over None.
    self.annotations = list(config.get("annotation_schemes") or [])

    # Check if main endpoint supports vision
    self.endpoint_supports_vision = hasattr(self.ai_endpoint, 'query_with_image')
    logger.info(f"AI endpoint supports vision: {self.endpoint_supports_vision}")

    # Initialize cache
    if self.disk_cache_enabled:
        self.load_cache_from_disk()

    self.start_warmup()
def _validate_assistant_compatibility(
    self, instance_id: int, annotation_id: int, ai_assistant: str
) -> tuple:
    """
    Check that ``ai_assistant`` can be served for this instance's input type.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('hint', 'keyword', 'rationale', 'detection', etc.)

    Returns:
        Tuple of (is_valid: bool, error_message: str).
        ``error_message`` is the empty string when valid. Any error raised
        during validation is logged and treated as valid (fail open).
    """
    try:
        content = _get_instance_text(instance_id)
        is_image = _is_image_url(content)

        # Images go to the dedicated visual endpoint when one exists;
        # everything else (including images without one) uses the main endpoint.
        if is_image and self.visual_endpoint:
            endpoint = self.visual_endpoint
        else:
            endpoint = self.ai_endpoint

        declared = getattr(endpoint, 'CAPABILITIES', None)
        if declared is None:
            # No capabilities declared - allow all (backward compatibility)
            logger.debug(f"Endpoint {type(endpoint).__name__} has no CAPABILITIES, allowing {ai_assistant}")
            return True, ""

        if declared.supports_assistant(ai_assistant, is_image):
            return True, ""

        input_type = "image" if is_image else "text"
        return False, (
            f"Model {type(endpoint).__name__} does not support '{ai_assistant}' "
            f"for {input_type} content"
        )
    except Exception as e:
        logger.warning(f"Error validating assistant compatibility: {e}")
        # On validation error, allow the request (fail open for now)
        return True, ""
def get_endpoint_capabilities(self, for_image: bool = False) -> ModelCapabilities:
    """
    Return the capabilities of the endpoint that would handle the input type.

    Args:
        for_image: Whether the input is an image

    Returns:
        The endpoint's ``CAPABILITIES`` when it declares them, otherwise a
        permissive ``ModelCapabilities`` kept for backward compatibility.
    """
    # The dedicated visual endpoint wins for images; otherwise the main one.
    use_visual = for_image and self.visual_endpoint
    endpoint = self.visual_endpoint if use_visual else self.ai_endpoint

    declared = getattr(endpoint, 'CAPABILITIES', None)
    if declared is not None:
        return declared

    # Return permissive defaults for backward compatibility
    return ModelCapabilities(
        text_generation=True,
        vision_input=for_image,
        bounding_box_output=False,
        text_classification=True,
        image_classification=for_image,
        rationale_generation=True,
        keyword_extraction=not for_image,
    )
def _get_ai_with_vision_support(self, text: str, prompt: str, output_format) -> str:
"""
Get AI response, using vision if text is an image URL and endpoint supports it.
"""
# Check if we should use vision
if self.endpoint_supports_vision and _is_image_url(text):
logger.debug(f"Using vision query for image URL: {text[:50]}...")
image_data = _get_image_data_from_url(text)
if image_data:
try:
return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
except Exception as e:
logger.error(f"Vision query failed: {e}")
# Fall back to text query
# Fall back to regular text query
return self.ai_endpoint.query(prompt, output_format)
def start_warmup(self):
    """Kick off warm-up prefetching and block until it has drained.

    Starts the prefetch for the first ``warm_up_page_count`` pages (plus the
    option-highlight prefetch when enabled), then polls ``in_progress`` every
    0.2 seconds, showing a progress bar, until it is empty.
    """
    self.start_prefetch(0, self.warm_up_page_count)

    # Also prefetch option highlights if enabled
    if self.option_highlighting_enabled:
        self.start_option_highlight_prefetch(0, self.warm_up_page_count)

    queued = len(self.in_progress)
    bar = tqdm(total=queued, desc="Preloading the AI", unit="item")

    reported = 0
    while self.in_progress:
        finished = queued - len(self.in_progress)
        bar.update(finished - reported)
        reported = finished
        time.sleep(0.2)

    # Account for anything that completed after the last poll.
    finished = queued - len(self.in_progress)
    if finished > reported:
        bar.update(finished - reported)
    bar.close()
def load_disk_cache_data(self, file_path: str) -> Dict[str, Any]:
    """Read the cache JSON at ``file_path`` and return its parsed content.

    Any failure (missing file, invalid JSON, ...) is logged and results in
    an empty dictionary.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as cache_file:
            cached = json.load(cache_file)
    except Exception as e:
        logger.error(f"Error loading disk cache: {e}")
        return {}
    return cached
def load_cache_from_disk(self):
    """Make sure the disk cache file exists, creating an empty one if needed.

    Does nothing when the disk cache is disabled or no path is configured.
    If the file already exists it is read once and only its entry count is
    logged. Otherwise the parent directory (if the path has one) and an
    empty JSON cache file are created; a failure to do so is logged, not
    raised.
    """
    if not self.disk_cache_enabled or not self.disk_persistence_path:
        return

    if os.path.exists(self.disk_persistence_path):
        data = self.load_disk_cache_data(self.disk_persistence_path)
        logger.info(f"Disk cache initialized with {len(data)} items")
        return

    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # so only create the parent when there is one.
        parent_dir = os.path.dirname(self.disk_persistence_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(self.disk_persistence_path, 'w', encoding='utf-8') as file:
            json.dump({}, file)
        logger.info(f"Initialized empty disk cache at {self.disk_persistence_path}")
    except Exception as e:
        logger.error(f"Failed to create disk cache: {e}")
def save_cache_to_disk(self, key, value):
    """Persist one key-value pair into the disk cache file.

    The whole cache is re-read, updated under ``str(key)`` and written to a
    ``.tmp`` sibling that then replaces the real file, so readers never see
    a half-written cache. Does nothing when the disk cache is disabled or no
    path is configured. Errors are logged, not raised.

    Args:
        key: Cache key; stored under its ``str()`` form.
        value: JSON-serialisable value to store.
    """
    if not self.disk_cache_enabled or not self.disk_persistence_path:
        return
    try:
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # so only create the parent when there is one.
        parent_dir = os.path.dirname(self.disk_persistence_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Load existing disk data first
        existing_disk_data = {}
        if os.path.exists(self.disk_persistence_path):
            existing_disk_data = self.load_disk_cache_data(self.disk_persistence_path)

        # Add the new key-value pair
        existing_disk_data[str(key)] = value

        temp_path = self.disk_persistence_path + ".tmp"
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(existing_disk_data, f, indent=2, ensure_ascii=False)
        # os.replace overwrites an existing destination on every platform;
        # os.rename fails on Windows when the cache file already exists.
        os.replace(temp_path, self.disk_persistence_path)
    except Exception as e:
        logger.error(f"Error saving cache to disk: {e}")
def add_to_cache(self, key, value):
    """Store ``key`` -> ``value`` in the disk cache while holding the lock.

    Nothing is stored when the disk cache is disabled.
    """
    with self.lock:
        if not self.disk_cache_enabled:
            return
        self.save_cache_to_disk(key, value)
def get_from_cache(self, key):
    """Look ``key`` up in the disk cache, returning ``None`` on a miss.

    The lookup uses ``str(key)``. A disabled cache, an unset or missing
    cache file, and read errors (which are logged) all count as a miss.
    """
    with self.lock:
        cache_usable = (
            self.disk_cache_enabled
            and self.disk_persistence_path
            and os.path.exists(self.disk_persistence_path)
        )
        if not cache_usable:
            return None
        try:
            stored = self.load_disk_cache_data(self.disk_persistence_path)
            lookup = str(key)
            if lookup in stored:
                return stored[lookup]
        except Exception as e:
            logger.error(f"Error reading from disk: {e}")
        return None
def generate_likert(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a likert annotation scheme.

    When the main endpoint supports vision and the instance text looks like
    an image URL, the image is downloaded and a vision prompt is sent. If
    that path is not taken, the download fails, or the vision query raises,
    the standard text request is used instead.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('hint', 'rationale', 'keyword', ...)

    Returns:
        Whatever the endpoint returns for the query.

    Raises:
        ValueError: if the scheme lacks one of the required fields
            (``annotation_type``, ``description``, ``min_label``,
            ``max_label``, ``size``).
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)
    min_label = _get_scheme_field(annotation_id, "min_label")
    max_label = _get_scheme_field(annotation_id, "max_label")
    size = _get_scheme_field(annotation_id, "size")

    ai_prompt = get_ai_prompt()
    output_format = self.model_manager.get_model_class_by_name(
        ai_prompt[annotation_type].get(ai_assistant).get("output_format")
    )

    # Check if we should use vision endpoint for image-based content
    if self.endpoint_supports_vision and _is_image_url(text):
        logger.debug(f"Using vision for likert {ai_assistant} on image: {text[:50]}...")
        image_data = _get_image_data_from_url(text)
        if image_data:
            # Build vision-specific prompts based on ai_assistant type
            if ai_assistant == "hint":
                prompt = f"""Look at this image and help with the following annotation task:
Task: {description}
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})
Please analyze the image and suggest an appropriate rating with a brief explanation.
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choice": "<rating label>"}}"""
            elif ai_assistant == "rationale":
                prompt = f"""Look at this image and explain the reasoning for different rating choices:
Task: {description}
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})
For each possible rating, explain what visual evidence in the image would support that rating.
Respond in JSON format: {{"rationales": [{{"label": "<rating>", "reasoning": "<explanation>"}}]}}"""
            elif ai_assistant == "keyword":
                prompt = f"""Look at this image and identify visual features relevant to the rating task:
Task: {description}
Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size})
Identify key visual elements that would influence the rating.
Respond in JSON format: {{"keywords": ["<visual_feature_1>", "<visual_feature_2>"]}}"""
            else:
                prompt = f"Analyze this image for: {description}"
            try:
                return self.ai_endpoint.query_with_image(prompt, image_data, output_format)
            except Exception as e:
                # Logged; the text-based request below takes over.
                logger.error(f"Vision query failed for likert {ai_assistant}: {e}")

    # Fall back to standard text-based generation
    data = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        min_label=min_label,
        max_label=max_label,
        size=size
    )
    res = self.ai_endpoint.get_ai(data, output_format)
    return res
def generate_multiselect(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a multiselect annotation scheme.

    An image-URL instance is sent through the main endpoint's vision query
    when that endpoint supports vision and the image downloads. In every
    other case, and when the vision query raises, the standard text request
    is issued.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    labels = _get_scheme_field(annotation_id, "labels")
    text = _get_instance_text(instance_id)

    prompt_table = get_ai_prompt()
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)

    # Only download the image when the vision path is actually available.
    picture = None
    if self.endpoint_supports_vision and _is_image_url(text):
        logger.debug(f"Using vision for multiselect {ai_assistant} on image: {text[:50]}...")
        picture = _get_image_data_from_url(text)

    if picture:
        # Labels may be plain values or dicts carrying a 'name'.
        label_names = [
            entry.get('name', entry) if isinstance(entry, dict) else entry
            for entry in labels
        ]
        labels_str = ', '.join(f'"{name}"' for name in label_names)

        # Vision prompt depends on the kind of assistance requested.
        if ai_assistant == "hint":
            prompt = f"""Look at this image and help with the following annotation task:
Task: {description}
Available options (select all that apply): {labels_str}
Please analyze the image and suggest which options apply.
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choices": ["<option1>", "<option2>"]}}"""
        elif ai_assistant == "rationale":
            prompt = f"""Look at this image and explain the reasoning for each option:
Task: {description}
Available options: {labels_str}
For each option, explain what visual evidence supports or contradicts it.
Respond in JSON format: {{"rationales": [{{"label": "<option>", "reasoning": "<explanation>"}}]}}"""
        elif ai_assistant == "keyword":
            prompt = f"""Look at this image and identify visual features for each option:
Task: {description}
Available options: {labels_str}
For each option, identify visual cues that indicate its presence.
Respond in JSON format: {{"label_keywords": [{{"label": "<option>", "keywords": ["<feature1>", "<feature2>"]}}]}}"""
        else:
            prompt = f"Analyze this image for: {description}. Options: {labels_str}"

        try:
            return self.ai_endpoint.query_with_image(prompt, picture, output_format)
        except Exception as e:
            logger.error(f"Vision query failed for multiselect {ai_assistant}: {e}")

    # Fall back to standard text-based generation
    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        labels=labels
    )
    return self.ai_endpoint.get_ai(request, output_format)
def generate_radio(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a radio (single choice) annotation scheme.

    An image-URL instance is sent through the main endpoint's vision query
    when that endpoint supports vision and the image downloads. In every
    other case, and when the vision query raises, the standard text request
    is issued.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)
    labels = _get_scheme_field(annotation_id, "labels")

    prompt_table = get_ai_prompt()
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)

    # Only download the image when the vision path is actually available.
    picture = None
    if self.endpoint_supports_vision and _is_image_url(text):
        logger.debug(f"Using vision for radio {ai_assistant} on image: {text[:50]}...")
        picture = _get_image_data_from_url(text)

    if picture:
        # Labels may be plain values or dicts carrying a 'name'.
        label_names = [
            entry.get('name', entry) if isinstance(entry, dict) else entry
            for entry in labels
        ]
        labels_str = ', '.join(f'"{name}"' for name in label_names)

        # Vision prompt depends on the kind of assistance requested.
        if ai_assistant == "hint":
            prompt = f"""Look at this image and help with the following annotation task:
Task: {description}
Available options: {labels_str}
Please analyze the image and suggest the most appropriate option.
Respond in JSON format: {{"hint": "<explanation>", "suggestive_choice": "<selected option>"}}"""
        elif ai_assistant == "rationale":
            prompt = f"""Look at this image and explain the reasoning for each option:
Task: {description}
Available options: {labels_str}
For each option, explain what visual evidence in the image supports or contradicts it.
Respond in JSON format: {{"rationales": [{{"label": "<option>", "reasoning": "<explanation>"}}]}}"""
        elif ai_assistant == "keyword":
            prompt = f"""Look at this image and identify visual features for each option:
Task: {description}
Available options: {labels_str}
For each option, identify visual cues that would indicate its presence.
Respond in JSON format: {{"label_keywords": [{{"label": "<option>", "keywords": ["<feature1>", "<feature2>"]}}]}}"""
        else:
            prompt = f"Analyze this image for: {description}. Options: {labels_str}"

        try:
            return self.ai_endpoint.query_with_image(prompt, picture, output_format)
        except Exception as e:
            logger.error(f"Vision query failed for radio {ai_assistant}: {e}")

    # Fall back to standard text-based generation
    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        labels=labels
    )
    return self.ai_endpoint.get_ai(request, output_format)
def generate_number(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a number annotation scheme.

    Builds a text request from the scheme's type and description plus the
    instance text and sends it to the main endpoint.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
    )

    # The response model is named in the prompt config for this type/assistant.
    prompt_table = get_ai_prompt()
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)
    return self.ai_endpoint.get_ai(request, output_format)
def generate_select(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a select (dropdown) annotation scheme.

    Builds a text request from the scheme's type, description and labels
    plus the instance text and sends it to the main endpoint.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    labels = _get_scheme_field(annotation_id, "labels")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        labels=labels
    )

    # The response model is named in the prompt config for this type/assistant.
    prompt_table = get_ai_prompt()
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)
    return self.ai_endpoint.get_ai(request, output_format)
def generate_slider(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a slider annotation scheme.

    ``min_value`` and ``max_value`` are required on the scheme; ``step``
    defaults to 1 when absent.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    min_value = _get_scheme_field(annotation_id, "min_value")
    max_value = _get_scheme_field(annotation_id, "max_value")
    step = _get_scheme_field(annotation_id, "step", default=1)
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        min_value=min_value,
        max_value=max_value,
        step=step
    )

    # The response model is named in the prompt config for this type/assistant.
    prompt_table = get_ai_prompt()
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)
    return self.ai_endpoint.get_ai(request, output_format)
def generate_span(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a span annotation scheme.

    Builds a text request from the scheme's type, description and labels
    plus the instance text and sends it to the main endpoint.
    """
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    labels = _get_scheme_field(annotation_id, "labels")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
        labels=labels
    )

    prompt_table = get_ai_prompt()
    logger.debug(f"Generating span annotation with labels: {labels}")
    # The response model is named in the prompt config for this type/assistant.
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)
    return self.ai_endpoint.get_ai(request, output_format)
def generate_textbox(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Generate AI assistance for a textbox annotation scheme.

    Builds a text request from the scheme's type and description plus the
    instance text and sends it to the main endpoint.
    """
    logger.debug(f"Generating textbox for annotation_id: {annotation_id}")
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description")
    text = _get_instance_text(instance_id)

    request = AnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        text=text,
        description=description,
    )

    # The response model is named in the prompt config for this type/assistant.
    prompt_table = get_ai_prompt()
    format_name = prompt_table[annotation_type].get(ai_assistant).get("output_format")
    output_format = self.model_manager.get_model_class_by_name(format_name)
    return self.ai_endpoint.get_ai(request, output_format)
def generate_image_annotation(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
    """Generate AI assistance for image annotation tasks.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('detection', 'classification', 'hint', 'pre_annotate', etc.)

    Returns:
        The visual endpoint's result, or a dict with an ``"error"`` key when
        the instance has no image URL or no visual endpoint is available.
        An endpoint without ``query_with_image`` yields the text-based hint
        instead.
    """
    logger.debug(f"Generating image annotation for instance={instance_id}, annotation={annotation_id}, assistant={ai_assistant}")
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description", default="")
    labels = _get_scheme_field(annotation_id, "labels", default=[])

    # Dict-style labels are reduced to their names.
    if labels and isinstance(labels[0], dict):
        labels = [entry.get("name", str(entry)) for entry in labels]

    # Locate the image in the item's data.
    record = get_item_state_manager().items()[instance_id].get_data()
    image_url = self._extract_image_url(record)
    if not image_url:
        return {"error": "No image URL found in instance data"}

    endpoint = self._get_visual_endpoint()
    if not endpoint:
        return {"error": "No visual AI endpoint configured"}

    # Without image query support, fall back to a text-based hint.
    if not hasattr(endpoint, 'query_with_image'):
        return self._generate_text_hint_for_visual(instance_id, annotation_id, ai_assistant)

    image_data = self._prepare_image_data(image_url)

    # Per-scheme confidence threshold, 0.5 when not configured.
    scheme_ai_support = _get_scheme_field(annotation_id, "ai_support", default={})
    confidence_threshold = scheme_ai_support.get("confidence_threshold", 0.5)

    request = VisualAnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        task_type=ai_assistant,  # detection, classification, hint, etc.
        image_data=image_data,
        description=description,
        labels=labels,
        confidence_threshold=confidence_threshold
    )

    # Response model comes from the prompt config, with a detection default.
    prompt_config = get_ai_prompt().get(annotation_type, {}).get(ai_assistant, {})
    output_format = self.model_manager.get_model_class_by_name(
        prompt_config.get("output_format", "visual_detection")
    )

    return endpoint.get_visual_ai(request, output_format)
def generate_video_annotation(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
    """Generate AI assistance for video annotation tasks.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Type of assistance ('scene_detection', 'frame_classification', etc.)

    Returns:
        The visual endpoint's result, or a dict with an ``"error"`` key when
        the instance has no video URL, no visual endpoint is available, or
        frame extraction fails. An endpoint without ``query_with_image``
        yields the text-based hint instead.
    """
    logger.debug(f"Generating video annotation for instance={instance_id}, annotation={annotation_id}, assistant={ai_assistant}")
    annotation_type = _get_scheme_field(annotation_id, "annotation_type")
    description = _get_scheme_field(annotation_id, "description", default="")
    labels = _get_scheme_field(annotation_id, "labels", default=[])

    # Dict-style labels are reduced to their names.
    if labels and isinstance(labels[0], dict):
        labels = [entry.get("name", str(entry)) for entry in labels]

    # Locate the video in the item's data.
    record = get_item_state_manager().items()[instance_id].get_data()
    video_url = self._extract_video_url(record)
    if not video_url:
        return {"error": "No video URL found in instance data"}

    endpoint = self._get_visual_endpoint()
    if not endpoint:
        return {"error": "No visual AI endpoint configured"}

    # Without image query support, fall back to a text-based hint.
    if not hasattr(endpoint, 'query_with_image'):
        return self._generate_text_hint_for_visual(instance_id, annotation_id, ai_assistant)

    # Frames and metadata both come from the endpoint itself.
    try:
        frames = endpoint.extract_video_frames(video_url)
        video_metadata = endpoint.get_video_metadata(video_url)
    except Exception as e:
        logger.error(f"Failed to extract video frames: {e}")
        return {"error": f"Failed to process video: {str(e)}"}

    request = VisualAnnotationInput(
        ai_assistant=ai_assistant,
        annotation_type=annotation_type,
        task_type=ai_assistant,
        image_data=frames,  # List of frame images
        description=description,
        labels=labels,
        video_metadata=video_metadata
    )

    # Response model comes from the prompt config, with a scene-detection default.
    prompt_config = get_ai_prompt().get(annotation_type, {}).get(ai_assistant, {})
    output_format = self.model_manager.get_model_class_by_name(
        prompt_config.get("output_format", "video_scene_detection")
    )

    return endpoint.get_visual_ai(request, output_format)
def _get_visual_endpoint(self):
"""Get the appropriate endpoint for visual tasks."""
# Use dedicated visual endpoint if configured
if self.visual_endpoint:
return self.visual_endpoint
# Check if main endpoint supports vision
if hasattr(self.ai_endpoint, 'query_with_image'):
return self.ai_endpoint
# Try to find a visual endpoint from registered types
visual_types = ['yolo', 'ollama_vision', 'openai_vision', 'anthropic_vision']
for vtype in visual_types:
if vtype in AIEndpointFactory._endpoints:
try:
visual_config = {
"ai_support": {
"enabled": True,
"endpoint_type": vtype,
"ai_config": config.get("ai_support", {}).get("ai_config", {})
}
}
return AIEndpointFactory.create_endpoint(visual_config)
except Exception as e:
logger.debug(f"Could not create {vtype} endpoint: {e}")
continue
return None
def _extract_image_url(self, item_data: Dict) -> str:
"""Extract image URL from item data.
Looks for common field names that might contain image URLs.
"""
# Common field names for images
image_fields = ['image', 'image_url', 'img', 'img_url', 'url', 'path', 'file', 'src']
for field in image_fields:
if field in item_data:
value = item_data[field]
if isinstance(value, str) and (
value.startswith(('http://', 'https://', '/')) or
value.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp'))
):
return value
# Check 'text' field for URL (common in simple configs)
if 'text' in item_data:
text = item_data['text']
if isinstance(text, str) and (
text.startswith(('http://', 'https://')) and
any(ext in text.lower() for ext in ['.jpg', '.jpeg', '.png', '.gif', '.webp'])
):
return text
return None
def _extract_video_url(self, item_data: Dict) -> str:
"""Extract video URL from item data."""
# Common field names for videos
video_fields = ['video', 'video_url', 'url', 'path', 'file', 'src', 'media']
for field in video_fields:
if field in item_data:
value = item_data[field]
if isinstance(value, str) and (
value.startswith(('http://', 'https://', '/')) or
value.endswith(('.mp4', '.webm', '.ogg', '.avi', '.mov'))
):
return value
# Check 'text' field for URL
if 'text' in item_data:
text = item_data['text']
if isinstance(text, str) and (
text.startswith(('http://', 'https://')) and
any(ext in text.lower() for ext in ['.mp4', '.webm', '.ogg', '.avi', '.mov'])
):
return text
return None
def _prepare_image_data(self, image_url: str) -> ImageData:
    """Build the image payload for a URL or a local path.

    http(s) URLs are wrapped as ``ImageData(source="url", ...)``; anything
    else is handed to ``BaseVisualAIEndpoint.encode_image_to_base64``.
    """
    is_remote = image_url.startswith(('http://', 'https://'))
    if not is_remote:
        # Imported lazily, only when a non-URL value needs encoding.
        from potato.ai.visual_ai_endpoint import BaseVisualAIEndpoint
        return BaseVisualAIEndpoint.encode_image_to_base64(image_url)
    return ImageData(source="url", data=image_url)
def _generate_text_hint_for_visual(self, instance_id: int, annotation_id: int, ai_assistant: str) -> Dict:
    """Build a plain-text hint for a visual task from the scheme config alone.

    Only ``annotation_id`` is used; ``instance_id`` and ``ai_assistant`` are
    accepted but not read.

    Returns:
        Dict with a "hint" sentence assembled from the scheme's labels and
        description, and an empty "suggestive_choice".
    """
    schemes = config["annotation_schemes"]
    task_description = schemes[annotation_id].get("description", "")
    label_names = schemes[annotation_id].get("labels", [])
    if label_names and isinstance(label_names[0], dict):
        label_names = [entry.get("name", str(entry)) for entry in label_names]

    # Any annotation type without 'image' in its name is worded as a video task.
    media_kind = 'image' if 'image' in schemes[annotation_id]['annotation_type'] else 'video'
    focus = ', '.join(label_names) if label_names else 'relevant content'
    hint_text = (
        f"Review the {media_kind} carefully. "
        f"Look for: {focus}. "
        f"Task: {task_description}"
    )
    return {"hint": hint_text, "suggestive_choice": ""}
def is_option_highlighting_enabled_for_scheme(self, annotation_id: int) -> bool:
    """Tell whether option highlighting applies to one annotation scheme.

    It applies when the feature is switched on, the scheme has a discrete
    option type (radio, multiselect, likert, select) and, if a schema filter
    is set, the scheme's name is in that filter.
    """
    if not self.option_highlighting_enabled:
        return False

    scheme = config["annotation_schemes"][annotation_id]
    if scheme.get("annotation_type", "") not in ("radio", "multiselect", "likert", "select"):
        return False

    # A filter of None means "no restriction by scheme name".
    allowed_names = self.option_highlighting_schemas
    if allowed_names is None:
        return True
    return scheme.get("name", "") in allowed_names
def get_option_highlighting_config(self) -> Dict:
    """Return the option highlighting settings as a dict for the frontend.

    Each key maps to the instance attribute ``option_highlighting_<key>``.
    """
    exported = ("enabled", "top_k", "dim_opacity", "auto_apply", "schemas", "prefetch_count")
    return {name: getattr(self, f"option_highlighting_{name}") for name in exported}
def generate_option_highlights(self, instance_id: int, annotation_id: int) -> Dict:
    """Generate option highlighting suggestions for an annotation.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
    Returns:
        On success, a dict with the highlighted options and configuration:
            {
                "highlighted": ["option1", "option2"],
                "top_k": 3,
                "confidence": 0.85
            }
        "highlighted" only contains options that exactly match one of the
        scheme's option names and holds at most top_k entries.
        On failure, a dict of the form {"error": "<message>"}.
    """
    import json as json_module
    from string import Template

    if not self.is_option_highlighting_enabled_for_scheme(annotation_id):
        return {"error": "Option highlighting not enabled for this scheme"}

    annotation_type = _get_scheme_field(annotation_id, "annotation_type", default="")
    description = _get_scheme_field(annotation_id, "description", default="")
    labels = _get_scheme_field(annotation_id, "labels", default=[])

    # Extract label names
    if labels and isinstance(labels[0], dict):
        label_names = [label.get("name", str(label)) for label in labels]
    else:
        label_names = [str(label) for label in labels]

    # For likert scales, generate label names from min/max labels.
    # The scale settings are read through _get_scheme_field, like the fields
    # above; there is no local scheme object in this method.
    if annotation_type == "likert":
        size = _get_scheme_field(annotation_id, "size", default=5)
        min_label = _get_scheme_field(annotation_id, "min_label", default="1")
        max_label = _get_scheme_field(annotation_id, "max_label", default=str(size))
        label_names = []
        for i in range(size):
            if i == 0:
                anchor = min_label
            elif i == size - 1:
                anchor = max_label
            else:
                anchor = ''
            # Points without an anchor label are rendered as the bare number.
            label_names.append(f"{i+1} ({anchor})".replace(" ()", ""))

    text = _get_instance_text(instance_id)
    top_k = min(self.option_highlighting_top_k, len(label_names))

    # Get prompt template
    ai_prompt = get_ai_prompt()
    prompt_config = ai_prompt.get("option_highlight", {}).get("option_highlight", {})
    if not prompt_config:
        return {"error": "Option highlight prompt not configured"}
    prompt_template = prompt_config.get("prompt", "")
    output_format_name = prompt_config.get("output_format", "option_highlight")
    output_format = self.model_manager.get_model_class_by_name(output_format_name)

    # Build the prompt with clear delimiters to mitigate prompt injection.
    # The user content is wrapped in XML-style tags so the LLM can
    # distinguish between instructions and untrusted data.
    delimited_text = f"<user_content>\n{text}\n</user_content>"
    prompt = Template(prompt_template).safe_substitute(
        text=delimited_text,
        description=description,
        labels=", ".join(label_names),
        top_k=top_k,
    )

    # Query the AI endpoint
    try:
        result = self.ai_endpoint.query(prompt, output_format)
        logger.debug(f"Option highlight raw result: {result}")

        # Parse the result
        if isinstance(result, str):
            try:
                result = json_module.loads(result)
            except json_module.JSONDecodeError:
                # Try to extract JSON from a markdown code block
                fence = "```json" if "```json" in result else "```"
                if fence not in result:
                    return {"error": f"Could not parse response: {result[:100]}"}
                json_start = result.find(fence) + len(fence)
                json_end = result.find("```", json_start)
                if json_end == -1:
                    # No closing fence: take everything after the opening one
                    # instead of silently dropping the last character.
                    json_end = len(result)
                result = json_module.loads(result[json_start:json_end].strip())

        # A null "highlighted_options" is treated like a missing one.
        highlighted = result.get("highlighted_options") or []
        confidence = result.get("confidence", None)

        # Validate highlighted options against available labels
        valid_highlighted = [opt for opt in highlighted if opt in label_names]
        return {
            "highlighted": valid_highlighted[:top_k],
            "top_k": top_k,
            "confidence": confidence
        }
    except Exception as e:
        logger.error(f"Error generating option highlights: {e}")
        return {"error": str(e)}
def get_option_highlights(self, instance_id: int, annotation_id: int) -> Dict:
    """Return option highlights, preferring the cache over a fresh query.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
    Returns:
        Dict with highlighted options (or an "error" entry, which is never
        written to the cache).
    """
    cache_key = (instance_id, annotation_id, "option_highlight")

    if self.disk_cache_enabled:
        hit = self.get_from_cache(cache_key)
        if hit is not None:
            logger.debug(f"Option highlight cache hit for {cache_key}")
            return hit

    highlights = self.generate_option_highlights(instance_id, annotation_id)

    # Only successful results are stored.
    succeeded = "error" not in highlights
    if succeeded and self.disk_cache_enabled:
        self.add_to_cache(cache_key, highlights)
    return highlights
def start_option_highlight_prefetch(self, page_id: int, prefetch_amount: int = None):
    """Queue background generation of option highlights around a page.

    Args:
        page_id: Current page/instance index
        prefetch_amount: Number of items to prefetch (uses config default if None).
            A non-negative value covers items starting at page_id; a negative
            value covers the items just before page_id.
    """
    if not (self.option_highlighting_enabled and self.disk_cache_enabled):
        return

    amount = self.option_highlighting_prefetch_count if prefetch_amount is None else prefetch_amount
    item_manager = get_item_state_manager()

    with self.lock:
        # Work out the half-open index window [first, last).
        if amount >= 0:
            first = page_id
            last = min(first + amount, len(item_manager.items()))
        else:
            first = max(page_id + amount, 0)
            last = page_id

        candidates = (
            (item_idx, scheme_idx, "option_highlight")
            for item_idx in range(first, last)
            for scheme_idx, _ in enumerate(config["annotation_schemes"])
            if self.is_option_highlighting_enabled_for_scheme(scheme_idx)
        )
        # Keep only keys that are neither cached nor already being generated.
        pending = [
            cache_key for cache_key in candidates
            if self.get_from_cache(cache_key) is None and cache_key not in self.in_progress
        ]

        for cache_key in pending:
            item_idx, scheme_idx, _ = cache_key
            job = self.executor.submit(self.generate_option_highlights, item_idx, scheme_idx)
            self.in_progress[cache_key] = job

            # cache_key is bound as a default so each callback keeps its own key.
            def on_done(fut, cache_key=cache_key):
                with self.lock:
                    try:
                        outcome = fut.result()
                        if "error" not in outcome:
                            self.add_to_cache(cache_key, outcome)
                    except Exception as e:
                        logger.error(f"Option highlight prefetch failed for {cache_key}: {e}")
                    self.in_progress.pop(cache_key, None)

            job.add_done_callback(on_done)

    if pending:
        logger.debug(f"Started option highlight prefetch for {len(pending)} items")
def get_include_all(self):
    """Return the ``include_all`` setting stored on this instance."""
    return self.include_all
def get_special_include(self, page_number_int, annotation_id_int):
    """Look up the special-include entry for a page/annotation pair.

    Returns:
        The stored entry, or None when the page or the annotation has no
        entry (an empty/falsy entry is also reported as None).
    """
    logger.debug(f"get_special_include: page={page_number_int}, annotation_id={annotation_id_int}")
    page_entry = self.special_includes.get(page_number_int)
    if not page_entry:
        return None
    scheme_entry = page_entry.get(annotation_id_int)
    if not scheme_entry:
        return None
    return scheme_entry
def start_prefetch(self, page_id, prefetch_amount):
    """Prefetches a fixed number of items around a page to warm the cache.

    Does nothing unless AI support and the disk cache are both enabled.

    Args:
        page_id: Current page/instance index.
        prefetch_amount: Number of items to prefetch. A non-negative value
            covers items starting at page_id (clipped to the item count); a
            negative value covers the abs(prefetch_amount) items just before
            page_id (clipped at index 0).
    Raises:
        Exception: If an included scheme's annotation type has no entry in
            the AI prompt configuration.
    """
    if not config.get("ai_support", {}).get("enabled") or not self.disk_cache_enabled:
        return

    ism = get_item_state_manager()

    with self.lock:
        # Calculate range bounds
        if prefetch_amount >= 0:
            start_idx = page_id
            end_idx = min(start_idx + prefetch_amount, len(ism.items()))
        else:
            # prefetch_amount is negative here, so adding it steps backwards.
            # Subtracting it would put start_idx past end_idx and leave the
            # range empty.
            start_idx = max(page_id + prefetch_amount, 0)
            end_idx = page_id
        logger.debug(f"Prefetch range: start_idx={start_idx}, end_idx={end_idx}")

        keys = []
        for i in range(start_idx, end_idx):
            # Check if this page should be included
            if not self.should_include_page(i):
                continue
            # Process each annotation scheme for this page
            for annotation_id, scheme in enumerate(config["annotation_schemes"]):
                if not self.should_include_scheme(i, annotation_id):
                    continue
                annotation_type = scheme["annotation_type"]
                ai_prompt = get_ai_prompt()
                # .get() so that a missing entry reaches the explicit error
                # below instead of surfacing as a bare KeyError.
                if not ai_prompt.get(annotation_type):
                    raise Exception(f"{annotation_type} is not defined in ai_prompt")
                # Generate keys for this page/scheme combination
                keys.extend(self.get_keys_for_scheme(i, annotation_type, annotation_id, ai_prompt))

    # prefetch() acquires self.lock itself, so it is called after the lock is
    # released rather than relying on the lock being re-entrant.
    if keys:
        self.prefetch(keys)
def should_include_page(self, page_index):
    """Tell whether a page takes part in prefetching.

    Every page qualifies when ``include_all`` is set; otherwise only pages
    that have an entry in ``special_includes``.
    """
    return True if self.include_all else page_index in self.special_includes
def should_include_scheme(self, page_index, annotation_id):
    """Tell whether a scheme takes part in prefetching for a given page.

    Every scheme qualifies when ``include_all`` is set. Otherwise the page
    must have a ``special_includes`` entry (a dict or a list) that contains
    the annotation id.
    """
    if self.include_all:
        return True
    if page_index not in self.special_includes:
        return False

    page_entry = self.special_includes[page_index]
    # Dict entries are keyed by annotation id; list entries hold the ids.
    if isinstance(page_entry, (dict, list)):
        return annotation_id in page_entry
    return False
def get_keys_for_scheme(self, page_index, annotation_type, annotation_id, ai_prompt):
    """Build the cache keys to prefetch for one page/scheme combination.

    A dict-style ``special_includes`` entry that names this annotation id
    takes precedence and supplies the assistant names. Otherwise, with
    ``include_all`` set, every assistant configured for the annotation type
    in ``ai_prompt`` is used. In all other cases no keys are produced.

    Returns:
        List of ``(page_index, annotation_id, assistant_name)`` tuples.
    """
    page_entry = self.special_includes.get(page_index)
    if isinstance(page_entry, dict) and annotation_id in page_entry:
        # Explicit per-page override wins over the include_all setting.
        assistant_names = page_entry[annotation_id]
    elif self.include_all:
        assistant_names = ai_prompt[annotation_type]
    else:
        assistant_names = ()
    return [(page_index, annotation_id, name) for name in assistant_names]
def prefetch(self, keys: list):
    """Asynchronously generate AI help for every key that is not cached yet.

    Keys that are already cached or already being generated are skipped.
    Results that look like error responses are not written to the cache, so
    a later request recomputes them instead of replaying a stale error.

    Args:
        keys: List of ``(instance_id, annotation_id, ai_assistant)`` tuples.
    """
    def looks_like_error(result) -> bool:
        # Dict results signal failure through an "error" entry; string
        # results are checked with the same heuristic get_ai_help applies.
        if isinstance(result, dict):
            return "error" in result
        if isinstance(result, str):
            return (
                result.startswith("Unable to generate")
                or result.startswith("Error:")
                or "error" in result.lower()[:50]
            )
        return False

    with self.lock:
        for key in keys:
            if self.get_from_cache(key) is not None or key in self.in_progress:
                continue
            instance_id, annotation_id, ai_assistant = key
            future = self.executor.submit(self.compute_help, instance_id, annotation_id, ai_assistant)
            self.in_progress[key] = future

            # cache_key is bound as a default so each callback keeps its own key.
            def callback(fut, cache_key=key):
                with self.lock:
                    try:
                        result = fut.result()
                        # Error responses are skipped silently, as in the
                        # option-highlight prefetch callback.
                        if not looks_like_error(result):
                            self.add_to_cache(cache_key, result)
                    except Exception as e:
                        logger.error(f"Prefetch failed for key {cache_key}: {e}")
                    finally:
                        self.in_progress.pop(cache_key, None)

            future.add_done_callback(callback)
def get_ai_help(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str:
    """Retrieve AI help from the cache, an in-progress job, or on demand.

    With the disk cache disabled the help is computed synchronously and
    returned as-is. Otherwise a cached value is returned when present; if
    not, the call waits up to 60 seconds for an already running job for the
    same key or for a newly submitted one.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Name of the assistant/help type to run
    Returns:
        Whatever compute_help produced (this may be a dict, e.g.
        ``{"error": ...}``), or the string ``"Error: <message>"`` when
        waiting for or caching the result raised.
    """
    key = (instance_id, annotation_id, ai_assistant)

    # Without caching there is nothing to look up or deduplicate.
    if not self.disk_cache_enabled:
        return self.compute_help(instance_id, annotation_id, ai_assistant)

    cached_value = self.get_from_cache(key)
    if cached_value is not None:
        logger.debug(f"Cache hit for key: {key}")
        return cached_value

    # Join an in-flight computation for this key, or start one.
    with self.lock:
        future = self.in_progress.get(key)
        if future is None:
            future = self.executor.submit(self.compute_help, instance_id, annotation_id, ai_assistant)
            self.in_progress[key] = future

    try:
        result = future.result(timeout=60)

        # Don't cache error responses. compute_help reports failures as
        # {"error": ...} dicts, so those count as errors alongside the
        # string heuristics.
        if isinstance(result, dict):
            is_error_response = "error" in result
        else:
            is_error_response = (
                isinstance(result, str) and
                (result.startswith("Unable to generate") or
                 result.startswith("Error:") or
                 "error" in result.lower()[:50])
            )

        if self.disk_cache_enabled and not is_error_response:
            self.add_to_cache(key, result)
        elif is_error_response:
            # str() first: the result may be a dict, which cannot be sliced.
            logger.warning(f"Not caching error response for key {key}: {str(result)[:100]}")
        return result
    except Exception as e:
        logger.error(f"Error computing help for key {key}: {e}")
        return f"Error: {str(e)}"
    finally:
        with self.lock:
            self.in_progress.pop(key, None)
def compute_help(self, instance_id: int, annotation_id: int, ai_assistant: str):
    """Run the generator that matches the scheme's annotation type.

    Args:
        instance_id: The instance/item index
        annotation_id: The annotation scheme index
        ai_assistant: Name of the assistant/help type to run
    Returns:
        The selected generator's result, or ``{"error": <message>}`` when the
        assistant compatibility check fails.
    Raises:
        ValueError: If no generator is mapped to the annotation type.
    """
    # Validate that the assistant type is compatible with the model and input
    compatible, reason = self._validate_assistant_compatibility(
        instance_id, annotation_id, ai_assistant
    )
    if not compatible:
        logger.warning(f"Assistant compatibility check failed: {reason}")
        return {"error": reason}

    type_name = config["annotation_schemes"][annotation_id]["annotation_type"]
    annotation_type = Annotation_Type(type_name)

    # Ordered (type, generator method name) pairs; the first match is used and
    # the method is only looked up once its type has matched.
    generators = (
        (Annotation_Type.LIKERT, "generate_likert"),
        (Annotation_Type.RADIO, "generate_radio"),
        (Annotation_Type.MULTISELECT, "generate_multiselect"),
        (Annotation_Type.NUMBER, "generate_number"),
        (Annotation_Type.SELECT, "generate_select"),
        (Annotation_Type.SLIDER, "generate_slider"),
        (Annotation_Type.SPAN, "generate_span"),
        (Annotation_Type.TEXTBOX, "generate_textbox"),
        (Annotation_Type.IMAGE_ANNOTATION, "generate_image_annotation"),
        (Annotation_Type.VIDEO_ANNOTATION, "generate_video_annotation"),
    )
    for known_type, method_name in generators:
        if annotation_type == known_type:
            return getattr(self, method_name)(instance_id, annotation_id, ai_assistant)
    raise ValueError(f"Unknown annotation type: {annotation_type}")
def get_cache_stats(self) -> Dict[str, int]:
    """Return statistics on disk cache and in-progress cache entries.

    Returns:
        Dict with 'disk_cache_enabled' (the flag as stored),
        'cached_items_disk' (number of entries loaded from the disk cache
        file, 0 when caching is off, the file is missing or it cannot be
        read) and 'in_progress_items' (number of pending jobs).
    """
    with self.lock:
        disk_count = 0
        if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path):
            try:
                disk_data = self.load_disk_cache_data(self.disk_persistence_path)
                disk_count = len(disk_data)
            except Exception as e:
                # Stats are best-effort: report 0 but leave a trace. Unlike a
                # bare except, this lets SystemExit/KeyboardInterrupt through.
                logger.warning(f"Could not read disk cache for stats: {e}")
        return {
            'disk_cache_enabled': self.disk_cache_enabled,
            'cached_items_disk': disk_count,
            'in_progress_items': len(self.in_progress)
        }
def clear_cache(self):
    """Clear the disk cache and cancel any generation that has not started.

    Jobs that are already running cannot be cancelled; they are only dropped
    from the in-progress table.
    """
    # Snapshot and empty the table first. Future.cancel() runs done-callbacks
    # synchronously for jobs that have not started, and the prefetch callbacks
    # take self.lock and pop from in_progress. Cancelling while iterating the
    # live dict under the lock would therefore mutate it mid-iteration, or
    # block if the lock is not re-entrant.
    with self.lock:
        pending = list(self.in_progress.values())
        self.in_progress.clear()

    for future in pending:
        future.cancel()

    with self.lock:
        if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path):
            try:
                os.remove(self.disk_persistence_path)
                logger.info("Disk cache file removed")
            except Exception as e:
                logger.error(f"Error removing disk cache file: {e}")
    logger.info("Cache cleared")