"""
MySQL-backed UserState implementation for Potato annotation platform.

This module provides a database-backed implementation of the UserState
interface, storing all user state data in MySQL tables for persistence
and scalability.
"""

import json
import logging
import threading
from typing import Dict, List, Set, Any, Optional, Tuple
from collections import defaultdict

from potato.user_state_management import UserState
from potato.phase import UserPhase
from potato.item_state_management import Item, Label, SpanAnnotation
from .connection import DatabaseManager

logger = logging.getLogger(__name__)


class MysqlUserState(UserState):
    """
    MySQL-backed implementation of UserState.

    This class stores all user state data in MySQL tables, providing
    persistence and scalability for annotation workflows.

    Backend feature parity gap: link annotations (SpanLink) and event
    annotations (EventAnnotation) are only supported by InMemoryUserState.
    The MySQL schema has no link_annotations or event_annotations tables,
    and this class does not implement add_link_annotation /
    add_event_annotation / get_*_annotations. Configurations that use
    span_link or event_annotation schemas will fail with AttributeError
    on the MySQL backend.
    """

    def __init__(self, user_id: str, db_manager: DatabaseManager, max_assignments: int = -1):
        """
        Initialize the MySQL user state.

        Args:
            user_id: Unique identifier for the user
            db_manager: Database manager instance
            max_assignments: Maximum number of assignments for this user
                (a negative value disables the limit check in
                ``has_remaining_assignments``)
        """
        self.user_id = user_id
        self.db_manager = db_manager
        self.max_assignments = max_assignments

        # Thread-safe cache lock
        self._cache_lock = threading.Lock()

        # Cache for performance (reset under _cache_lock by _invalidate_cache).
        # Initialised before the first database round-trip so the object is
        # fully formed by the time any other method runs.
        self._instance_ordering_cache = None
        self._current_phase_cache = None
        self._current_page_cache = None
        self._current_instance_index_cache = None

        # Ensure user exists in database
        self._ensure_user_exists()

    def _ensure_user_exists(self):
        """Create the user's ``user_states`` row if it doesn't exist.

        New rows start in the 'LOGIN' phase with no page and a
        ``current_instance_index`` of -1 (nothing assigned yet).
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            # INSERT IGNORE: presumably user_id is a unique key, so an
            # existing row is left untouched -- verify against the schema.
            cursor.execute("""
                INSERT IGNORE INTO user_states
                (user_id, current_phase, current_page, current_instance_index, max_assignments)
                VALUES (%s, %s, %s, %s, %s)
            """, (self.user_id, 'LOGIN', None, -1, self.max_assignments))
            conn.commit()

    def _invalidate_cache(self):
        """Invalidate cached data (thread-safe)."""
        with self._cache_lock:
            self._instance_ordering_cache = None
            self._current_phase_cache = None
            self._current_page_cache = None
            self._current_instance_index_cache = None

    def advance_to_phase(self, phase: UserPhase, page: str) -> None:
        """Advance the user to a new phase and page.

        Args:
            phase: Phase to move to; persisted as ``str(phase)``
            page: Page name within that phase
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE user_states
                SET current_phase = %s, current_page = %s
                WHERE user_id = %s
            """, (str(phase), page, self.user_id))
            conn.commit()
        self._invalidate_cache()

    def assign_instance(self, item: Item) -> None:
        """Assign an instance to the user for annotation.

        The instance is appended after the user's existing assignments.
        Assigning an instance the user already has is a no-op.

        Args:
            item: Item to assign; only ``item.get_id()`` is stored
        """
        instance_id = item.get_id()

        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()

            # Check if already assigned
            cursor.execute("""
                SELECT COUNT(*) FROM user_instance_assignments
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))
            result = cursor.fetchone()
            if result is not None and result[0] > 0:
                return  # Already assigned

            # Get next assignment order (0 when the user has none yet)
            cursor.execute("""
                SELECT COALESCE(MAX(assignment_order), -1) + 1
                FROM user_instance_assignments
                WHERE user_id = %s
            """, (self.user_id,))
            result = cursor.fetchone()
            next_order = result[0] if result is not None else 0

            # Insert assignment
            cursor.execute("""
                INSERT INTO user_instance_assignments
                (user_id, instance_id, assignment_order)
                VALUES (%s, %s, %s)
            """, (self.user_id, instance_id, next_order))

            # Update current instance index if this is the first assignment
            cursor.execute("""
                SELECT current_instance_index FROM user_states
                WHERE user_id = %s
            """, (self.user_id,))
            result = cursor.fetchone()
            current_index = result[0] if result is not None else -1
            if current_index == -1:
                cursor.execute("""
                    UPDATE user_states
                    SET current_instance_index = 0
                    WHERE user_id = %s
                """, (self.user_id,))

            conn.commit()
        self._invalidate_cache()

    def assign_instance_at_index(self, item: Item, index: int) -> bool:
        """Insert ``item`` at ``index`` in the user's assignment ordering.

        Used by quality-control injection (attention checks, gold standards).
        Returns False if the item is already assigned. Raises IndexError if
        ``index`` is outside [0, current_assignment_count].

        Args:
            item: Item to insert; only ``item.get_id()`` is stored
            index: Zero-based position the item should occupy

        Returns:
            True if the item was inserted, False if it was already assigned.

        Raises:
            IndexError: If ``index`` is negative or greater than the number
                of existing assignments.
        """
        instance_id = item.get_id()

        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()

            cursor.execute("""
                SELECT COUNT(*) FROM user_instance_assignments
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))
            already = cursor.fetchone()
            if already is not None and already[0] > 0:
                return False

            cursor.execute("""
                SELECT COUNT(*) FROM user_instance_assignments
                WHERE user_id = %s
            """, (self.user_id,))
            count_result = cursor.fetchone()
            current_count = count_result[0] if count_result is not None else 0
            if index < 0 or index > current_count:
                raise IndexError(
                    f"assign_instance_at_index: index {index} out of range "
                    f"[0, {current_count}]"
                )

            # Shift later orders up to make room for the insert.
            cursor.execute("""
                UPDATE user_instance_assignments
                SET assignment_order = assignment_order + 1
                WHERE user_id = %s AND assignment_order >= %s
            """, (self.user_id, index))

            cursor.execute("""
                INSERT INTO user_instance_assignments
                (user_id, instance_id, assignment_order)
                VALUES (%s, %s, %s)
            """, (self.user_id, instance_id, index))

            # Rebalance the user's cursor: an insert at or before the current
            # position pushes the current item one slot further along.
            cursor.execute("""
                SELECT current_instance_index FROM user_states
                WHERE user_id = %s
            """, (self.user_id,))
            current_index_result = cursor.fetchone()
            current_index = (
                current_index_result[0] if current_index_result is not None else -1
            )
            if current_index == -1:
                new_index = 0
            elif current_index >= index:
                new_index = current_index + 1
            else:
                new_index = current_index

            cursor.execute("""
                UPDATE user_states
                SET current_instance_index = %s
                WHERE user_id = %s
            """, (new_index, self.user_id))

            conn.commit()
        self._invalidate_cache()
        return True

    def unassign_instance(self, instance_id: str) -> bool:
        """Remove an instance assignment from the user.

        Later assignments are shifted down by one to keep the ordering
        contiguous, and the user's current index is adjusted so it stays in
        range (or is reset to -1 when no assignments remain).

        Args:
            instance_id: ID of the instance to unassign

        Returns:
            True if the assignment was removed, False if the instance was
            not assigned to this user.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()

            cursor.execute("""
                SELECT assignment_order FROM user_instance_assignments
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))
            result = cursor.fetchone()
            if result is None:
                return False
            removed_order = result[0]

            cursor.execute("""
                DELETE FROM user_instance_assignments
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))

            cursor.execute("""
                UPDATE user_instance_assignments
                SET assignment_order = assignment_order - 1
                WHERE user_id = %s AND assignment_order > %s
            """, (self.user_id, removed_order))

            # This transaction has not modified current_instance_index yet,
            # so the value returned here (served from the cache, or read
            # through a separate connection) is the pre-removal position that
            # the index-rebalance math below needs. The COUNT(*) that follows
            # runs on this transaction's cursor and therefore sees the DELETE.
            current_index = self.get_current_instance_index()

            cursor.execute("""
                SELECT COUNT(*) FROM user_instance_assignments
                WHERE user_id = %s
            """, (self.user_id,))
            count_result = cursor.fetchone()
            assignment_count = count_result[0] if count_result is not None else 0

            if assignment_count == 0:
                new_index = -1
            elif current_index > removed_order:
                new_index = current_index - 1
            elif current_index == removed_order:
                new_index = min(removed_order, assignment_count - 1)
            else:
                new_index = min(current_index, assignment_count - 1)

            cursor.execute("""
                UPDATE user_states
                SET current_instance_index = %s
                WHERE user_id = %s
            """, (new_index, self.user_id))

            conn.commit()
        self._invalidate_cache()
        return True

    def get_current_instance(self) -> Optional[Item]:
        """Get the current instance the user is annotating.

        Returns:
            The item at the user's current index, or None when the index is
            negative or past the end of the assignment ordering.
        """
        current_index = self.get_current_instance_index()
        if current_index < 0:
            return None

        instance_ordering = self._get_instance_ordering()
        if current_index >= len(instance_ordering):
            return None

        instance_id = instance_ordering[current_index]
        from potato.item_state_management import get_item_state_manager
        return get_item_state_manager().get_item(instance_id)

    def get_current_instance_index(self) -> int:
        """Get the current instance index.

        Returns:
            The cached index when available, otherwise the value stored in
            ``user_states`` (-1 if the user row is missing).
        """
        if self._current_instance_index_cache is not None:
            return self._current_instance_index_cache

        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT current_instance_index FROM user_states
                WHERE user_id = %s
            """, (self.user_id,))
            result = cursor.fetchone()
            self._current_instance_index_cache = result[0] if result else -1
            return self._current_instance_index_cache

    def get_user_id(self) -> str:
        """Get the user ID."""
        return self.user_id

    def goto_prev_instance(self) -> bool:
        """Move to the previous instance.

        Returns:
            True if the index was decremented, False if the user is already
            at the first instance (or has no current instance).
        """
        current_index = self.get_current_instance_index()
        if current_index > 0:
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE user_states
                    SET current_instance_index = %s
                    WHERE user_id = %s
                """, (current_index - 1, self.user_id))
                conn.commit()
            self._invalidate_cache()
            return True
        return False

    def goto_next_instance(self) -> bool:
        """Move to the next instance.

        Returns:
            True if the index was incremented, False if the user is already
            at the last assigned instance.
        """
        current_index = self.get_current_instance_index()
        instance_ordering = self._get_instance_ordering()

        if current_index < len(instance_ordering) - 1:
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE user_states
                    SET current_instance_index = %s
                    WHERE user_id = %s
                """, (current_index + 1, self.user_id))
                conn.commit()
            self._invalidate_cache()
            return True
        return False

    def go_to_index(self, instance_index: int) -> None:
        """Move to a specific instance index.

        Out-of-range indices are ignored silently.

        Args:
            instance_index: Zero-based position in the assignment ordering
        """
        instance_ordering = self._get_instance_ordering()
        if 0 <= instance_index < len(instance_ordering):
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE user_states
                    SET current_instance_index = %s
                    WHERE user_id = %s
                """, (instance_index, self.user_id))
                conn.commit()
            self._invalidate_cache()

    def get_all_annotations(self) -> Dict[str, Dict[str, Any]]:
        """Get all annotations for this user.

        Returns:
            A dict keyed by instance ID. Each value has a ``"labels"`` dict
            (schema name -> label name -> label value) and a ``"spans"`` dict
            (schema name -> span name -> ``{"title", "start", "end"}``).
        """
        annotations = {}

        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()

            # Get label annotations
            cursor.execute("""
                SELECT instance_id, schema_name, label_name, label_value
                FROM label_annotations
                WHERE user_id = %s
            """, (self.user_id,))
            for row in cursor.fetchall():
                instance_id, schema_name, label_name, label_value = row
                if instance_id not in annotations:
                    annotations[instance_id] = {"labels": {}, "spans": {}}
                if schema_name not in annotations[instance_id]["labels"]:
                    annotations[instance_id]["labels"][schema_name] = {}
                annotations[instance_id]["labels"][schema_name][label_name] = label_value

            # Get span annotations
            cursor.execute("""
                SELECT instance_id, schema_name, span_name, span_title, start_pos, end_pos
                FROM span_annotations
                WHERE user_id = %s
            """, (self.user_id,))
            for row in cursor.fetchall():
                instance_id, schema_name, span_name, span_title, start_pos, end_pos = row
                if instance_id not in annotations:
                    annotations[instance_id] = {"labels": {}, "spans": {}}
                if schema_name not in annotations[instance_id]["spans"]:
                    annotations[instance_id]["spans"][schema_name] = {}
                annotations[instance_id]["spans"][schema_name][span_name] = {
                    "title": span_title,
                    "start": start_pos,
                    "end": end_pos
                }

        return annotations

    def get_label_annotations(self, instance_id: str) -> Dict[Label, Any]:
        """Get label annotations for a specific instance.

        Args:
            instance_id: ID of the instance to look up

        Returns:
            Mapping from ``Label(schema_name, label_name)`` to the stored
            label value.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT schema_name, label_name, label_value
                FROM label_annotations
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))

            annotations = {}
            for row in cursor.fetchall():
                schema_name, label_name, label_value = row
                label = Label(schema_name, label_name)
                annotations[label] = label_value
            return annotations

    def get_span_annotations(self, instance_id: str) -> Dict[SpanAnnotation, Any]:
        """Get span annotations for a specific instance.

        Args:
            instance_id: ID of the instance to look up

        Returns:
            Mapping from each reconstructed ``SpanAnnotation`` to True.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT schema_name, span_name, span_title, start_pos, end_pos,
                       kb_id, kb_source, kb_label
                FROM span_annotations
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))

            annotations = {}
            for row in cursor.fetchall():
                schema_name, span_name, span_title, start_pos, end_pos, \
                    kb_id, kb_source, kb_label = row
                span = SpanAnnotation(
                    schema_name, span_name, span_title, start_pos, end_pos,
                    kb_id=kb_id, kb_source=kb_source, kb_label=kb_label,
                )
                annotations[span] = True  # Span annotations are boolean
            return annotations

    def get_current_phase_and_page(self) -> Tuple[UserPhase, Optional[str]]:
        """Get the current phase and page.

        Returns:
            ``(phase, page)`` from the cache when available, otherwise from
            ``user_states``. Falls back to ``(UserPhase.LOGIN, None)`` when
            the user row is missing.
        """
        # The phase cache alone decides a hit: the page is legitimately None
        # for some rows (the initial row is created with a NULL page), so
        # requiring a non-None page would make those users miss every time.
        cached_phase = self._current_phase_cache
        if cached_phase is not None:
            return cached_phase, self._current_page_cache

        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT current_phase, current_page FROM user_states
                WHERE user_id = %s
            """, (self.user_id,))
            result = cursor.fetchone()

        if not result:
            return UserPhase.LOGIN, None

        phase_str, page = result
        phase = UserPhase.fromstr(phase_str)
        # Store the page before the phase so that a reader who observes the
        # cached phase always finds the matching page already in place.
        self._current_page_cache = page
        self._current_phase_cache = phase
        return phase, page

    def get_annotation_count(self) -> int:
        """Get the number of annotated instances.

        An instance counts once whether it has label annotations, span
        annotations, or both.

        Returns:
            Number of distinct instance IDs present in ``label_annotations``
            or ``span_annotations`` for this user.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            # Count over the UNION of instance IDs rather than summing two
            # per-table counts: UNION removes duplicate rows, so equal counts
            # would collapse into one row, and instances annotated in both
            # tables would otherwise be counted twice.
            cursor.execute("""
                SELECT COUNT(*) FROM (
                    SELECT instance_id FROM label_annotations WHERE user_id = %s
                    UNION
                    SELECT instance_id FROM span_annotations WHERE user_id = %s
                ) AS annotated_instances
            """, (self.user_id, self.user_id))
            result = cursor.fetchone()
            return result[0] if result is not None else 0

    def get_assigned_instance_count(self) -> int:
        """Get the number of assigned instances."""
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT COUNT(*) FROM user_instance_assignments
                WHERE user_id = %s
            """, (self.user_id,))
            result = cursor.fetchone()
            return result[0] if result is not None else 0

    def get_assigned_instance_ids(self) -> Set[str]:
        """Get the set of assigned instance IDs.

        The result is a set, so the assignment order is not preserved; use
        ``_get_instance_ordering`` when order matters.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT instance_id FROM user_instance_assignments
                WHERE user_id = %s
                ORDER BY assignment_order
            """, (self.user_id,))
            return {row[0] for row in cursor.fetchall()}

    def add_label_annotation(self, instance_id: str, label: Label, value: Any) -> None:
        """Add a label annotation.

        During the ANNOTATION phase the value is written to
        ``label_annotations`` for ``instance_id``; in any other phase it is
        written to ``phase_annotations`` keyed by the current phase and page.
        An existing row for the same key has its value overwritten.

        Args:
            instance_id: ID of the instance being annotated
            label: Label identifying the schema and label name
            value: Annotation value; persisted as ``str(value)``
        """
        phase, page = self.get_current_phase_and_page()

        if phase == UserPhase.ANNOTATION:
            # Store in label_annotations table
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO label_annotations
                    (user_id, instance_id, schema_name, label_name, label_value)
                    VALUES (%s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE label_value = VALUES(label_value)
                """, (self.user_id, instance_id, label.get_schema(),
                      label.get_name(), str(value)))
                conn.commit()
        else:
            # Store in phase_annotations table
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO phase_annotations
                    (user_id, phase_name, page_name, schema_name, label_name, label_value)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE label_value = VALUES(label_value)
                """, (self.user_id, str(phase), page, label.get_schema(),
                      label.get_name(), str(value)))
                conn.commit()

    def add_span_annotation(self, instance_id: str, span: SpanAnnotation, value: Any) -> None:
        """Add a span annotation.

        During the ANNOTATION phase the span is written to
        ``span_annotations``; in any other phase its title/start/end are
        serialised as JSON into ``phase_annotations``. ``value`` is accepted
        for interface compatibility but is not stored.

        Args:
            instance_id: ID of the instance being annotated
            span: Span to persist
            value: Unused by this backend
        """
        phase, page = self.get_current_phase_and_page()

        if phase == UserPhase.ANNOTATION:
            # Store in span_annotations table
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO span_annotations
                    (user_id, instance_id, schema_name, span_name, span_title,
                     start_pos, end_pos, kb_id, kb_source, kb_label)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        span_title = VALUES(span_title),
                        start_pos = VALUES(start_pos),
                        end_pos = VALUES(end_pos),
                        kb_id = VALUES(kb_id),
                        kb_source = VALUES(kb_source),
                        kb_label = VALUES(kb_label)
                """, (self.user_id, instance_id, span.get_schema(), span.get_name(),
                      span.get_title(), span.get_start(), span.get_end(),
                      # kb_* attributes are optional on the span object.
                      getattr(span, 'kb_id', None),
                      getattr(span, 'kb_source', None),
                      getattr(span, 'kb_label', None)))
                conn.commit()
        else:
            # For non-annotation phases, store in phase_annotations as JSON
            span_data = {
                "title": span.get_title(),
                "start": span.get_start(),
                "end": span.get_end()
            }
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO phase_annotations
                    (user_id, phase_name, page_name, schema_name, label_name, label_value)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE label_value = VALUES(label_value)
                """, (self.user_id, str(phase), page, span.get_schema(),
                      span.get_name(), json.dumps(span_data)))
                conn.commit()

    def get_annotated_instance_ids(self) -> Set[str]:
        """Get the set of annotated instance IDs.

        Returns:
            IDs of instances with at least one label or span annotation by
            this user.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT DISTINCT instance_id FROM label_annotations WHERE user_id = %s
                UNION
                SELECT DISTINCT instance_id FROM span_annotations WHERE user_id = %s
            """, (self.user_id, self.user_id))
            return {row[0] for row in cursor.fetchall()}

    def has_annotated(self, instance_id: str) -> bool:
        """Check if the user has annotated a specific instance.

        Args:
            instance_id: ID of the instance to check

        Returns:
            True if the instance has any label or span annotation by this
            user.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            # UNION may merge two equal counts into one row; that is harmless
            # here because only "is any count positive" matters.
            cursor.execute("""
                SELECT COUNT(*) FROM label_annotations
                WHERE user_id = %s AND instance_id = %s
                UNION
                SELECT COUNT(*) FROM span_annotations
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id, self.user_id, instance_id))
            results = cursor.fetchall()
            return any(result[0] > 0 for result in results)

    def clear_all_annotations(self) -> None:
        """Clear all annotations for this user.

        Also removes the user's phase annotations, behavioral data and
        cached AI hints.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM label_annotations WHERE user_id = %s", (self.user_id,))
            cursor.execute("DELETE FROM span_annotations WHERE user_id = %s", (self.user_id,))
            cursor.execute("DELETE FROM phase_annotations WHERE user_id = %s", (self.user_id,))
            cursor.execute("DELETE FROM behavioral_data WHERE user_id = %s", (self.user_id,))
            cursor.execute("DELETE FROM ai_hints WHERE user_id = %s", (self.user_id,))
            conn.commit()

    def clear_instance_annotations(self, instance_id: str) -> None:
        """Clear all annotations for one instance.

        Removes the user's label annotations, span annotations, behavioral
        data and cached AI hints for ``instance_id``. Phase annotations are
        not touched.

        Args:
            instance_id: ID of the instance whose data should be removed
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM label_annotations WHERE user_id = %s AND instance_id = %s",
                (self.user_id, instance_id),
            )
            cursor.execute(
                "DELETE FROM span_annotations WHERE user_id = %s AND instance_id = %s",
                (self.user_id, instance_id),
            )
            cursor.execute(
                "DELETE FROM behavioral_data WHERE user_id = %s AND instance_id = %s",
                (self.user_id, instance_id),
            )
            cursor.execute(
                "DELETE FROM ai_hints WHERE user_id = %s AND instance_id = %s",
                (self.user_id, instance_id),
            )
            conn.commit()

    def has_assignments(self) -> bool:
        """Check if the user has any assignments."""
        return self.get_assigned_instance_count() > 0

    def has_remaining_assignments(self) -> bool:
        """Check if the user has remaining assignments.

        Returns:
            True if the item state manager reports unlabeled items for this
            user and, when ``max_assignments`` is non-negative, the user's
            annotation count is still below that limit.
        """
        from potato.item_state_management import get_item_state_manager
        has_available_items = get_item_state_manager().has_unlabeled_items_for_user(self)

        if self.max_assignments >= 0:
            return self.get_annotation_count() < self.max_assignments and has_available_items
        return has_available_items

    def set_max_assignments(self, max_assignments: int) -> None:
        """Set the maximum number of assignments.

        Updates both the in-memory value and the ``user_states`` row.

        Args:
            max_assignments: New limit for this user
        """
        self.max_assignments = max_assignments
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE user_states
                SET max_assignments = %s
                WHERE user_id = %s
            """, (max_assignments, self.user_id))
            conn.commit()

    def get_max_assignments(self) -> int:
        """Get the maximum number of assignments."""
        return self.max_assignments

    def hint_exists(self, instance_id: str) -> bool:
        """Check if a hint exists for an instance.

        Args:
            instance_id: ID of the instance to check
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT COUNT(*) FROM ai_hints
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))
            result = cursor.fetchone()
            return result is not None and result[0] > 0

    def get_hint(self, instance_id: str) -> Optional[str]:
        """Get the hint for an instance.

        Args:
            instance_id: ID of the instance to look up

        Returns:
            The stored hint text, or None if no hint row exists.
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT hint_text FROM ai_hints
                WHERE user_id = %s AND instance_id = %s
            """, (self.user_id, instance_id))
            result = cursor.fetchone()
            return result[0] if result else None

    def cache_hint(self, instance_id: str, hint: str) -> None:
        """Cache a hint for an instance, replacing any existing hint.

        Args:
            instance_id: ID of the instance the hint belongs to
            hint: Hint text to store
        """
        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO ai_hints (user_id, instance_id, hint_text)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE hint_text = VALUES(hint_text)
            """, (self.user_id, instance_id, hint))
            conn.commit()

    def _get_instance_ordering(self) -> List[str]:
        """Get the ordered list of assigned instance IDs.

        Returns:
            Instance IDs sorted by ``assignment_order``; served from the
            cache when available.
        """
        if self._instance_ordering_cache is not None:
            return self._instance_ordering_cache

        with self.db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT instance_id FROM user_instance_assignments
                WHERE user_id = %s
                ORDER BY assignment_order
            """, (self.user_id,))
            self._instance_ordering_cache = [row[0] for row in cursor.fetchall()]
            return self._instance_ordering_cache

    def is_at_end_index(self) -> bool:
        """Check if the user is at the end of their assignments.

        Note that a user with no assignments and an index of -1 also
        satisfies this check (-1 == 0 - 1).
        """
        current_index = self.get_current_instance_index()
        instance_ordering = self._get_instance_ordering()
        return current_index == len(instance_ordering) - 1

    def go_back(self) -> bool:
        """Move back to the previous instance."""
        return self.goto_prev_instance()

    def go_forward(self) -> bool:
        """Move forward to the next instance."""
        return self.goto_next_instance()

    def get_current_instance_id(self) -> Optional[str]:
        """Get the ID of the current instance, or None if there is none."""
        current_instance = self.get_current_instance()
        return current_instance.get_id() if current_instance else None

    def get_labels(self) -> Dict[str, Dict[str, str]]:
        """Get all labels (deprecated, use get_all_annotations).

        Returns:
            Mapping from instance ID to that instance's ``"labels"`` dict as
            produced by ``get_all_annotations``.
        """
        annotations = self.get_all_annotations()
        labels = {}
        for instance_id, data in annotations.items():
            labels[instance_id] = data.get("labels", {})
        return labels