from __future__ import annotations import json import logging import os from typing import Any, Dict, Union import requests from tqdm import tqdm import time from concurrent.futures import ThreadPoolExecutor import threading from builtins import open from potato.server_utils.config_module import config logger = logging.getLogger(__name__) from potato.item_state_management import get_item_state_manager from potato.ai.ai_endpoint import ( AIEndpointFactory, Annotation_Type, AnnotationInput, ImageData, VisualAnnotationInput, ModelCapabilities, ) from potato.ai.ollama_endpoint import OllamaEndpoint from potato.ai.openrouter_endpoint import OpenRouterEndpoint from potato.ai.ai_prompt import ModelManager, get_ai_prompt AICACHEMANAGER = None def _get_scheme_field(annotation_id: int, field: str, default=None): """Safely get a field from an annotation scheme with a clear error message.""" schemes = config.get("annotation_schemes", []) if annotation_id >= len(schemes): raise ValueError( f"AI cache: annotation_id {annotation_id} out of range " f"(only {len(schemes)} scheme(s) configured)" ) scheme = schemes[annotation_id] if default is not None: return scheme.get(field, default) if field not in scheme: scheme_name = scheme.get("name", f"index {annotation_id}") scheme_type = scheme.get("annotation_type", "unknown") raise ValueError( f"AI cache: annotation scheme '{scheme_name}' (type '{scheme_type}') " f"missing required field '{field}'" ) return scheme[field] def _get_instance_text(instance_id: int) -> str: """Get the text content from an instance using the configured text_key.""" item = get_item_state_manager().items()[instance_id] item_data = item.get_data() # Get the configured text_key text_key = config.get("item_properties", {}).get("text_key", "text") # Try the configured text_key first if text_key in item_data: return item_data[text_key] # Fall back to common keys for key in ['text', 'content', 'message']: if key in item_data: return item_data[key] # Last resort: return any string value for value in item_data.values(): if isinstance(value, str): return value return str(item_data) def _is_image_url(text: str) -> bool: """Check if text appears to be an image URL.""" if not isinstance(text, str): return False text_lower = text.lower() # Check for image extensions image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp'] if any(ext in text_lower for ext in image_extensions): return True # Check for common image hosting services image_hosts = ['unsplash.com', 'imgur.com', 'flickr.com', 'picsum.photos'] if any(host in text_lower for host in image_hosts): return True # Check if URL starts with http and might be an image if text_lower.startswith(('http://', 'https://')) and 'image' in text_lower: return True return False def _get_image_data_from_url(url: str) -> ImageData: """Download image from URL and return as ImageData. Includes SSRF protection to prevent fetching from private/internal IPs. """ import base64 import ipaddress import socket from urllib.parse import urlparse # SSRF protection: validate URL scheme and resolve hostname try: parsed = urlparse(url) if parsed.scheme not in ('http', 'https'): logger.warning(f"Blocked non-HTTP image URL: {url[:100]}") return None hostname = parsed.hostname if hostname: addr_info = socket.getaddrinfo(hostname, None) for info in addr_info: ip_str = info[4][0] try: ip = ipaddress.ip_address(ip_str) if ip.is_private or ip.is_loopback or ip.is_link_local: logger.warning( f"Blocked image URL resolving to private IP: " f"{hostname} -> {ip_str}" ) return None except ValueError: pass except Exception as e: logger.warning(f"Failed to validate image URL {url[:100]}: {e}") return None try: response = requests.get(url, timeout=30) response.raise_for_status() b64_data = base64.b64encode(response.content).decode('utf-8') # Determine mime type from content-type header or URL content_type = response.headers.get('content-type', 'image/jpeg') return ImageData(source='base64', data=b64_data, mime_type=content_type) except Exception as e: logger.error(f"Failed to download image from {url}: {e}") return None def init_ai_cache_manager(): global AICACHEMANAGER if AICACHEMANAGER is None: AICACHEMANAGER = AiCacheManager() return AICACHEMANAGER def get_ai_cache_manager(): """Get the AI cache manager instance. Returns None if not initialized (AI support disabled).""" global AICACHEMANAGER return AICACHEMANAGER def clear_ai_cache_manager(): """Clear the AI cache manager singleton. Used for testing.""" global AICACHEMANAGER AICACHEMANAGER = None class AiCacheManager: def __init__(self): ai_support = config["ai_support"] if not ai_support["enabled"]: return cache_config = ai_support.get("cache_config", {}) ai_config = ai_support.get("ai_config", {}) include = ai_config.get("include") or {} special_include = include.get("special_include", None) self.include_all = include.get("all", False) self.special_includes = {} self.model_manager = ModelManager() self.model_manager.load_models_module() if special_include: for page_key, page_value in special_include.items(): # Convert string keys to integers for easier lookup page_index = int(page_key) self.special_includes[page_index] = {} for annotation_id, annotation_types in page_value.items(): annotation_id_int = int(annotation_id) self.special_includes[page_index][annotation_id_int] = annotation_types # Disk cache configuration. # F-028: tolerate a partial/absent ai_cache config (e.g. AI support # enabled for ICL with no disk_cache block) instead of crashing boot # with KeyError: 'disk_cache'. disk_cache_cfg = cache_config.get("disk_cache", {}) if isinstance(cache_config, dict) else {} self.disk_cache_enabled = disk_cache_cfg.get("enabled", False) disk_cache_path = disk_cache_cfg.get("path") if self.disk_cache_enabled and not disk_cache_path: raise Exception("You have enable disk cache, but you did not specific the path!") self.disk_persistence_path = disk_cache_path # Validate cache path stays within task directory if self.disk_persistence_path: task_dir = os.path.abspath(config.get("task_dir", ".")) cache_abs = os.path.abspath( os.path.join(task_dir, self.disk_persistence_path) if not os.path.isabs(self.disk_persistence_path) else self.disk_persistence_path ) if not cache_abs.startswith(task_dir + os.sep) and cache_abs != task_dir: raise ValueError( f"Cache path '{self.disk_persistence_path}' resolves to " f"'{cache_abs}' which is outside the task directory " f"'{task_dir}'. Path traversal is not allowed." ) # Prefetch configuration — clamp to sane ranges. # F-028: default to no prefetch when the prefetch block is absent # (e.g. cache_config: {enabled: false}) instead of KeyError on boot. prefetch_cfg = cache_config.get("prefetch", {}) if isinstance(cache_config, dict) else {} self.warm_up_page_count = max(0, min(int(prefetch_cfg.get("warm_up_page_count", 0)), 10000)) self.prefetch_page_count_on_next = max(0, min(int(prefetch_cfg.get("on_next", 0)), 10000)) self.prefetch_page_count_on_prev = max(0, min(int(prefetch_cfg.get("on_prev", 0)), 10000)) # Option highlighting configuration option_highlighting = ai_support.get("option_highlighting", {}) self.option_highlighting_enabled = option_highlighting.get("enabled", False) self.option_highlighting_top_k = option_highlighting.get("top_k", 3) self.option_highlighting_dim_opacity = option_highlighting.get("dim_opacity", 0.4) self.option_highlighting_auto_apply = option_highlighting.get("auto_apply", True) self.option_highlighting_schemas = option_highlighting.get("schemas", None) # None means all # Prefetch count for option highlighting — clamp to sane range self.option_highlighting_prefetch_count = max(0, min( int(option_highlighting.get("prefetch_count", 20)), 10000 )) # Threading self.in_progress = {} self.lock = threading.RLock() self.executor = ThreadPoolExecutor(max_workers=20) AIEndpointFactory.register_endpoint("ollama", OllamaEndpoint) AIEndpointFactory.register_endpoint("open_router", OpenRouterEndpoint) # Register visual AI endpoints try: from potato.ai.yolo_endpoint import YOLOEndpoint AIEndpointFactory.register_endpoint("yolo", YOLOEndpoint) except ImportError: logger.debug("YOLO endpoint not available (ultralytics not installed)") try: from potato.ai.ollama_vision_endpoint import OllamaVisionEndpoint AIEndpointFactory.register_endpoint("ollama_vision", OllamaVisionEndpoint) except ImportError: logger.debug("Ollama Vision endpoint not available") try: from potato.ai.openai_vision_endpoint import OpenAIVisionEndpoint AIEndpointFactory.register_endpoint("openai_vision", OpenAIVisionEndpoint) except ImportError: logger.debug("OpenAI Vision endpoint not available") try: from potato.ai.anthropic_vision_endpoint import AnthropicVisionEndpoint AIEndpointFactory.register_endpoint("anthropic_vision", AnthropicVisionEndpoint) except ImportError: logger.debug("Anthropic Vision endpoint not available") # Degrade gracefully if the AI backend (e.g. a local Ollama/vLLM server) # is unreachable at boot: log a warning and serve the task with AI # support disabled rather than aborting server startup. try: self.ai_endpoint = AIEndpointFactory.create_endpoint(config) except Exception as e: logger.warning( "AI endpoint unavailable at startup (%s). Continuing with AI " "support disabled. Check that your AI backend is running.", e ) self.ai_endpoint = None # Create visual endpoint if different from main endpoint self.visual_endpoint = None visual_endpoint_type = config.get("ai_support", {}).get("visual_endpoint_type") if visual_endpoint_type and visual_endpoint_type != config.get("ai_support", {}).get("endpoint_type"): visual_config = { "ai_support": { "enabled": True, "endpoint_type": visual_endpoint_type, "ai_config": config.get("ai_support", {}).get("visual_ai_config", config.get("ai_support", {}).get("ai_config", {})) } } try: self.visual_endpoint = AIEndpointFactory.create_endpoint(visual_config) except Exception as e: logger.warning( "Visual AI endpoint unavailable at startup (%s). Continuing " "without visual AI support.", e ) self.visual_endpoint = None annotation_scheme = config.get("annotation_schemes") self.annotations = [] for scheme in annotation_scheme: self.annotations.append(scheme) # Check if main endpoint supports vision self.endpoint_supports_vision = hasattr(self.ai_endpoint, 'query_with_image') logger.info(f"AI endpoint supports vision: {self.endpoint_supports_vision}") # Initialize cache if self.disk_cache_enabled: self.load_cache_from_disk() self.start_warmup() def _validate_assistant_compatibility( self, instance_id: int, annotation_id: int, ai_assistant: str ) -> tuple: """ Validate that the AI assistant is compatible with the input type and model capabilities. Args: instance_id: The instance/item index annotation_id: The annotation scheme index ai_assistant: Type of assistance ('hint', 'keyword', 'rationale', 'detection', etc.) Returns: Tuple of (is_valid: bool, error_message: str) If valid, error_message is empty string. """ try: text = _get_instance_text(instance_id) is_image = _is_image_url(text) # Determine which endpoint to use if is_image and self.visual_endpoint: endpoint = self.visual_endpoint elif is_image and self.endpoint_supports_vision: endpoint = self.ai_endpoint else: endpoint = self.ai_endpoint # Get capabilities from endpoint capabilities = getattr(endpoint, 'CAPABILITIES', None) if capabilities is None: # No capabilities declared - allow all (backward compatibility) logger.debug(f"Endpoint {type(endpoint).__name__} has no CAPABILITIES, allowing {ai_assistant}") return True, "" # Check if the assistant type is supported if not capabilities.supports_assistant(ai_assistant, is_image): input_type = "image" if is_image else "text" return False, ( f"Model {type(endpoint).__name__} does not support '{ai_assistant}' " f"for {input_type} content" ) return True, "" except Exception as e: logger.warning(f"Error validating assistant compatibility: {e}") # On validation error, allow the request (fail open for now) return True, "" def get_endpoint_capabilities(self, for_image: bool = False) -> ModelCapabilities: """ Get the capabilities of the appropriate endpoint for the given input type. Args: for_image: Whether the input is an image Returns: ModelCapabilities instance, or a default permissive one if not declared """ if for_image and self.visual_endpoint: endpoint = self.visual_endpoint elif for_image and self.endpoint_supports_vision: endpoint = self.ai_endpoint else: endpoint = self.ai_endpoint capabilities = getattr(endpoint, 'CAPABILITIES', None) if capabilities is None: # Return permissive defaults for backward compatibility return ModelCapabilities( text_generation=True, vision_input=for_image, bounding_box_output=False, text_classification=True, image_classification=for_image, rationale_generation=True, keyword_extraction=not for_image, ) return capabilities def _get_ai_with_vision_support(self, text: str, prompt: str, output_format) -> str: """ Get AI response, using vision if text is an image URL and endpoint supports it. """ # Check if we should use vision if self.endpoint_supports_vision and _is_image_url(text): logger.debug(f"Using vision query for image URL: {text[:50]}...") image_data = _get_image_data_from_url(text) if image_data: try: return self.ai_endpoint.query_with_image(prompt, image_data, output_format) except Exception as e: logger.error(f"Vision query failed: {e}") # Fall back to text query # Fall back to regular text query return self.ai_endpoint.query(prompt, output_format) def start_warmup(self): self.start_prefetch(0, self.warm_up_page_count) # Also prefetch option highlights if enabled if self.option_highlighting_enabled: self.start_option_highlight_prefetch(0, self.warm_up_page_count) total = len(self.in_progress) desc = "Preloading the AI" progress_bar = tqdm(total=total, desc=desc, unit="item") def count_completed(): return total - len(self.in_progress) prev_done = 0 while self.in_progress: current_done = count_completed() progress_bar.update(current_done - prev_done) prev_done = current_done time.sleep(0.2) final_done = count_completed() if final_done > prev_done: progress_bar.update(final_done - prev_done) progress_bar.close() def load_disk_cache_data(self, file_path: str) -> Dict[str, Any]: """loads the cache JSON from disk and returns a dictionary of stringified keys to values.""" try: with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: logger.error(f"Error loading disk cache: {e}") return {} def load_cache_from_disk(self): """Initializes disk cache file if it doesn't exist.""" if not self.disk_cache_enabled or not self.disk_persistence_path: return if os.path.exists(self.disk_persistence_path): data = self.load_disk_cache_data(self.disk_persistence_path) logger.info(f"Disk cache initialized with {len(data)} items") else: try: # Create parent directory if it doesn't exist os.makedirs(os.path.dirname(self.disk_persistence_path), exist_ok=True) with open(self.disk_persistence_path, 'w', encoding='utf-8') as file: json.dump({}, file) logger.info(f"Initialized empty disk cache at {self.disk_persistence_path}") except Exception as e: logger.error(f"Failed to create disk cache: {e}") def save_cache_to_disk(self, key, value): """saves a single key-value pair to disk cache using atomic write.""" if not self.disk_cache_enabled or not self.disk_persistence_path: return try: os.makedirs(os.path.dirname(self.disk_persistence_path), exist_ok=True) # Load existing disk data first existing_disk_data = {} if os.path.exists(self.disk_persistence_path): existing_disk_data = self.load_disk_cache_data(self.disk_persistence_path) # Add the new key-value pair existing_disk_data[str(key)] = value temp_path = self.disk_persistence_path + ".tmp" with open(temp_path, 'w', encoding='utf-8') as f: json.dump(existing_disk_data, f, indent=2, ensure_ascii=False) os.rename(temp_path, self.disk_persistence_path) except Exception as e: logger.error(f"Error saving cache to disk: {e}") def add_to_cache(self, key, value): """inserts a key-value into the disk cache.""" with self.lock: if self.disk_cache_enabled: self.save_cache_to_disk(key, value) def get_from_cache(self, key): """Tries to retrieve the item from disk cache.""" with self.lock: # Try disk cache if self.disk_cache_enabled and self.disk_persistence_path and os.path.exists(self.disk_persistence_path): try: disk_data = self.load_disk_cache_data(self.disk_persistence_path) key_str = str(key) if key_str in disk_data: return disk_data[key_str] except Exception as e: logger.error(f"Error reading from disk: {e}") return None def generate_likert(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str: from string import Template annotation_type = _get_scheme_field(annotation_id, "annotation_type") description = _get_scheme_field(annotation_id, "description") text = _get_instance_text(instance_id) min_label = _get_scheme_field(annotation_id, "min_label") max_label = _get_scheme_field(annotation_id, "max_label") size = _get_scheme_field(annotation_id, "size") ai_prompt = get_ai_prompt() output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format")) # Check if we should use vision endpoint for image-based content if self.endpoint_supports_vision and _is_image_url(text): logger.debug(f"Using vision for likert {ai_assistant} on image: {text[:50]}...") image_data = _get_image_data_from_url(text) if image_data: # Build vision-specific prompts based on ai_assistant type if ai_assistant == "hint": prompt = f"""Look at this image and help with the following annotation task: Task: {description} Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size}) Please analyze the image and suggest an appropriate rating with a brief explanation. Respond in JSON format: {{"hint": "", "suggestive_choice": ""}}""" elif ai_assistant == "rationale": prompt = f"""Look at this image and explain the reasoning for different rating choices: Task: {description} Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size}) For each possible rating, explain what visual evidence in the image would support that rating. Respond in JSON format: {{"rationales": [{{"label": "", "reasoning": ""}}]}}""" elif ai_assistant == "keyword": prompt = f"""Look at this image and identify visual features relevant to the rating task: Task: {description} Rating scale: {size} points, from "{min_label}" (1) to "{max_label}" ({size}) Identify key visual elements that would influence the rating. Respond in JSON format: {{"keywords": ["", ""]}}""" else: prompt = f"Analyze this image for: {description}" try: return self.ai_endpoint.query_with_image(prompt, image_data, output_format) except Exception as e: logger.error(f"Vision query failed for likert {ai_assistant}: {e}") # Fall back to standard text-based generation data = AnnotationInput( ai_assistant=ai_assistant, annotation_type=annotation_type, text=text, description=description, min_label=min_label, max_label=max_label, size=size ) res = self.ai_endpoint.get_ai(data, output_format) return res def generate_multiselect(self, instance_id: int, annotation_id: int, ai_assistant: str) -> str: annotation_type = _get_scheme_field(annotation_id, "annotation_type") description = _get_scheme_field(annotation_id, "description") labels = _get_scheme_field(annotation_id, "labels") text = _get_instance_text(instance_id) ai_prompt = get_ai_prompt() output_format = self.model_manager.get_model_class_by_name(ai_prompt[annotation_type].get(ai_assistant).get("output_format")) # Check if we should use vision endpoint for image-based content if self.endpoint_supports_vision and _is_image_url(text): logger.debug(f"Using vision for multiselect {ai_assistant} on image: {text[:50]}...") image_data = _get_image_data_from_url(text) if image_data: # Format labels for the prompt label_names = [l.get('name', l) if isinstance(l, dict) else l for l in labels] labels_str = ', '.join(f'"{name}"' for name in label_names) # Build vision-specific prompts based on ai_assistant type if ai_assistant == "hint": prompt = f"""Look at this image and help with the following annotation task: Task: {description} Available options (select all that apply): {labels_str} Please analyze the image and suggest which options apply. Respond in JSON format: {{"hint": "", "suggestive_choices": ["", ""]}}""" elif ai_assistant == "rationale": prompt = f"""Look at this image and explain the reasoning for each option: Task: {description} Available options: {labels_str} For each option, explain what visual evidence supports or contradicts it. Respond in JSON format: {{"rationales": [{{"label": "