import hashlib
import re
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional

import lmdb
import orjson
from loguru import logger

from ..models import ContentItem, ConversationInStore, Message
from ..utils import g_config
from ..utils.helper import (
    extract_tool_calls,
    remove_tool_call_blocks,
)
from ..utils.singleton import Singleton


def _hash_message(message: Message) -> str:
    """
    Generate a stable, canonical hash for a single message.

    Strips system hints, thoughts, and tool call blocks so that identical
    logical content produces the same hash regardless of formatting.
    """
    core_data = {
        "role": message.role,
        "name": message.name,
        "tool_call_id": message.tool_call_id,
    }

    content = message.content
    if not content:
        core_data["content"] = None
    elif isinstance(content, str):
        normalized = content.replace("\r\n", "\n")
        normalized = LMDBConversationStore.remove_think_tags(normalized)
        normalized = remove_tool_call_blocks(normalized).strip()
        core_data["content"] = normalized if normalized else None
    elif isinstance(content, list):
        text_parts = []
        for item in content:
            text_val = ""
            if isinstance(item, ContentItem) and item.type == "text":
                text_val = item.text or ""
            elif isinstance(item, dict) and item.get("type") == "text":
                text_val = item.get("text") or ""

            if text_val:
                text_val = text_val.replace("\r\n", "\n")
                text_val = LMDBConversationStore.remove_think_tags(text_val)
                text_val = remove_tool_call_blocks(text_val).strip()
                if text_val:
                    text_parts.append(text_val)
            elif isinstance(item, ContentItem) and item.type in ("image_url", "file"):
                # For non-text items, include their unique markers to distinguish them
                if item.type == "image_url":
                    text_parts.append(
                        f"[image_url:{item.image_url.get('url') if item.image_url else ''}]"
                    )
                elif item.type == "file":
                    text_parts.append(
                        f"[file:{(item.file.get('url') or item.file.get('filename')) if item.file else ''}]"
                    )
            else:
                # Fallback for other dict-based content parts
                part_type = item.get("type") if isinstance(item, dict) else None
                if part_type == "image_url":
                    url = item.get("image_url", {}).get("url")
                    text_parts.append(f"[image_url:{url}]")
                elif part_type == "file":
                    url = item.get("file", {}).get("url") or item.get("file", {}).get("filename")
                    text_parts.append(f"[file:{url}]")
        combined_text = "\n".join(text_parts).replace("\r\n", "\n").strip()
        core_data["content"] = combined_text if combined_text else None

    if message.tool_calls:
        calls_data = []
        for tc in message.tool_calls:
            args = tc.function.arguments or "{}"
            try:
                # Canonicalize JSON arguments so key order cannot change the hash
                parsed = orjson.loads(args)
                canon_args = orjson.dumps(parsed, option=orjson.OPT_SORT_KEYS).decode("utf-8")
            except orjson.JSONDecodeError:
                canon_args = args
            calls_data.append(
                {
                    "name": tc.function.name,
                    "arguments": canon_args,
                }
            )
        calls_data.sort(key=lambda x: (x["name"], x["arguments"]))
        core_data["tool_calls"] = calls_data
    else:
        core_data["tool_calls"] = None

    message_bytes = orjson.dumps(core_data, option=orjson.OPT_SORT_KEYS)
    digest = hashlib.sha256(message_bytes).hexdigest()
    return digest
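

# Illustrative (not part of the store): two messages whose only difference is
# internal <think> markup hash identically. The Message kwargs here are
# assumptions based on the fields this function reads.
#
#   a = Message(role="assistant", content="<think>draft</think>Hello")
#   b = Message(role="assistant", content="Hello")
#   assert _hash_message(a) == _hash_message(b)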


def _hash_conversation(client_id: str, model: str, messages: List[Message]) -> str:
    """Generate a hash for a list of messages and a model name, tied to a specific client_id."""
    combined_hash = hashlib.sha256()
    combined_hash.update((client_id or "").encode("utf-8"))
    combined_hash.update((model or "").encode("utf-8"))
    for message in messages:
        message_hash = _hash_message(message)
        combined_hash.update(message_hash.encode("utf-8"))
    return combined_hash.hexdigest()
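

# The conversation hash is bound to both client_id and model, so the same
# message list stored under a different client or model maps to a different
# key. Hypothetical values for illustration:
#
#   h1 = _hash_conversation("client-a", "gemini-pro", msgs)
#   h2 = _hash_conversation("client-b", "gemini-pro", msgs)
#   assert h1 != h2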


class LMDBConversationStore(metaclass=Singleton):
    """LMDB-based storage for Message lists with hash-based key-value operations."""

    HASH_LOOKUP_PREFIX = "hash:"

    def __init__(
        self,
        db_path: Optional[str] = None,
        max_db_size: Optional[int] = None,
        retention_days: Optional[int] = None,
    ):
        """
        Initialize the LMDB store.

        Args:
            db_path: Path to the LMDB database directory
            max_db_size: Maximum database size in bytes (default: 256 MB)
            retention_days: Number of days to retain conversations (default: 14, 0 disables cleanup)
        """
        if db_path is None:
            db_path = g_config.storage.path
        if max_db_size is None:
            max_db_size = g_config.storage.max_size
        if retention_days is None:
            retention_days = g_config.storage.retention_days
        self.db_path: Path = Path(db_path)
        self.max_db_size: int = max_db_size
        self.retention_days: int = max(0, int(retention_days))
        self._env: lmdb.Environment | None = None
        self._ensure_db_path()
        self._init_environment()

    def _ensure_db_path(self) -> None:
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    def _init_environment(self) -> None:
        try:
            self._env = lmdb.open(
                str(self.db_path),
                map_size=self.max_db_size,
                max_dbs=3,
                writemap=True,  # use a writeable memory map for faster writes
                readahead=False,  # skip OS readahead; access pattern is random
                meminit=False,  # skip zero-initializing buffers before writing
            )
            logger.info(f"LMDB environment initialized at {self.db_path}")
        except Exception as e:
            logger.error(f"Failed to initialize LMDB environment: {e}")
            raise

    @contextmanager
    def _get_transaction(self, write: bool = False):
        if not self._env:
            raise RuntimeError("LMDB environment not initialized")
        txn: lmdb.Transaction = self._env.begin(write=write)
        try:
            yield txn
            if write:
                txn.commit()
        except Exception:
            if write:
                txn.abort()
            raise
        # Read-only transactions are cleaned up when the txn object is released.
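
    # Because _get_transaction is a @contextmanager generator, callers use it as:
    #
    #   with self._get_transaction(write=True) as txn:
    #       txn.put(b"key", b"value")
    #
    # Write transactions commit on clean exit and abort if the block raises.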

    def store(
        self,
        conv: ConversationInStore,
        custom_key: Optional[str] = None,
    ) -> str:
        """
        Store a conversation model in LMDB.

        Args:
            conv: Conversation model to store
            custom_key: Optional custom key; if not provided, the conversation hash is used

        Returns:
            str: The key used to store the messages (hash or custom key)
        """
        if not conv:
            raise ValueError("Conversation cannot be empty")
        # Sanitize messages before computing hash and storing to ensure consistency
        # with the search (find) logic, which also sanitizes its prefix.
        sanitized_messages = self.sanitize_assistant_messages(conv.messages)
        conv.messages = sanitized_messages
        # Generate hash for the message list
        message_hash = _hash_conversation(conv.client_id, conv.model, conv.messages)
        storage_key = custom_key or message_hash
        now = datetime.now()
        if conv.created_at is None:
            conv.created_at = now
        conv.updated_at = now
        value = orjson.dumps(conv.model_dump(mode="json"))
        try:
            with self._get_transaction(write=True) as txn:
                txn.put(storage_key.encode("utf-8"), value, overwrite=True)
                # Secondary index: map the conversation hash to the storage key
                txn.put(
                    f"{self.HASH_LOOKUP_PREFIX}{message_hash}".encode("utf-8"),
                    storage_key.encode("utf-8"),
                )
                logger.debug(f"Stored {len(conv.messages)} messages with key: {storage_key[:12]}")
                return storage_key
        except Exception as e:
            logger.error(f"Failed to store messages with key {storage_key[:12]}: {e}")
            raise
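
    # Usage sketch (illustrative; the ConversationInStore kwargs are assumptions
    # inferred from the attributes accessed above):
    #
    #   store = LMDBConversationStore()
    #   conv = ConversationInStore(client_id="c1", model="gemini-pro", messages=msgs)
    #   key = store.store(conv)
    #   assert store.exists(key)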

    def get(self, key: str) -> Optional[ConversationInStore]:
        """
        Retrieve conversation data by key.

        Args:
            key: Storage key (hash or custom key)

        Returns:
            Conversation or None if not found
        """
        try:
            with self._get_transaction(write=False) as txn:
                data = txn.get(key.encode("utf-8"), default=None)
                if not data:
                    return None
                storage_data = orjson.loads(data)  # type: ignore
                conv = ConversationInStore.model_validate(storage_data)
                logger.debug(f"Retrieved {len(conv.messages)} messages with key: {key[:12]}")
                return conv
        except Exception as e:
            logger.error(f"Failed to retrieve messages with key {key[:12]}: {e}")
            return None

    def find(self, model: str, messages: List[Message]) -> Optional[ConversationInStore]:
        """
        Search for stored conversation data by message list.
        """
        if not messages:
            return None

        # --- Find with raw messages ---
        if conv := self._find_by_message_list(model, messages):
            logger.debug(f"Session found for '{model}' with {len(messages)} raw messages.")
            return conv

        # --- Find with cleaned messages ---
        cleaned_messages = self.sanitize_assistant_messages(messages)
        if cleaned_messages != messages:
            if conv := self._find_by_message_list(model, cleaned_messages):
                logger.debug(
                    f"Session found for '{model}' with {len(cleaned_messages)} cleaned messages."
                )
                return conv

        logger.debug(f"No session found for '{model}' with {len(messages)} messages.")
        return None
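
    # find() is deliberately two-pass: an incoming history may or may not have
    # been sanitized already, so the raw prefix is hashed first and the cleaned
    # form is only tried when it differs. Illustrative call:
    #
    #   conv = store.find("gemini-pro", request_messages)
    #   if conv is None:
    #       ...  # no stored session matches this prefix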

    def _find_by_message_list(
        self,
        model: str,
        messages: List[Message],
    ) -> Optional[ConversationInStore]:
        """Internal find implementation based on a message list."""
        for c in g_config.gemini.clients:
            message_hash = _hash_conversation(c.id, model, messages)
            key = f"{self.HASH_LOOKUP_PREFIX}{message_hash}"
            try:
                with self._get_transaction(write=False) as txn:
                    if mapped := txn.get(key.encode("utf-8")):  # type: ignore
                        return self.get(mapped.decode("utf-8"))  # type: ignore
            except Exception as e:
                logger.error(
                    f"Failed to retrieve messages by message list for hash {message_hash} and client {c.id}: {e}"
                )
                continue
            # Fall back to treating the hash itself as the storage key
            if conv := self.get(message_hash):
                return conv
        return None

    def exists(self, key: str) -> bool:
        """
        Check if a key exists in the store.

        Args:
            key: Storage key to check

        Returns:
            bool: True if key exists, False otherwise
        """
        try:
            with self._get_transaction(write=False) as txn:
                return txn.get(key.encode("utf-8")) is not None
        except Exception as e:
            logger.error(f"Failed to check existence of key {key}: {e}")
            return False

    def delete(self, key: str) -> Optional[ConversationInStore]:
        """
        Delete a conversation model by key.

        Args:
            key: Storage key to delete

        Returns:
            ConversationInStore: The deleted conversation data, or None if not found
        """
        try:
            with self._get_transaction(write=True) as txn:
                # Get data first to clean up the hash mapping
                data = txn.get(key.encode("utf-8"))
                if not data:
                    return None
                storage_data = orjson.loads(data)  # type: ignore
                conv = ConversationInStore.model_validate(storage_data)
                message_hash = _hash_conversation(conv.client_id, conv.model, conv.messages)
                # Delete main data
                txn.delete(key.encode("utf-8"))
                # Clean up the hash mapping if it exists
                if message_hash and key != message_hash:
                    txn.delete(f"{self.HASH_LOOKUP_PREFIX}{message_hash}".encode("utf-8"))
                logger.debug(f"Deleted messages with key: {key[:12]}")
                return conv
        except Exception as e:
            logger.error(f"Failed to delete messages with key {key[:12]}: {e}")
            return None
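
    # Note: delete() returns the removed conversation rather than a bool, so
    # callers can inspect or log what was evicted:
    #
    #   if (gone := store.delete(key)) is not None:
    #       logger.info(f"Evicted conversation with {len(gone.messages)} messages")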

    def keys(self, prefix: str = "", limit: Optional[int] = None) -> List[str]:
        """
        List all keys in the store, optionally filtered by prefix.

        Args:
            prefix: Optional prefix to filter keys
            limit: Optional limit on the number of keys returned

        Returns:
            List of keys
        """
        keys = []
        try:
            with self._get_transaction(write=False) as txn:
                cursor = txn.cursor()
                cursor.first()
                count = 0
                for key, _ in cursor:
                    key_str = key.decode("utf-8")
                    # Skip internal hash mappings
                    if key_str.startswith(self.HASH_LOOKUP_PREFIX):
                        continue
                    if not prefix or key_str.startswith(prefix):
                        keys.append(key_str)
                        count += 1
                        if limit and count >= limit:
                            break
        except Exception as e:
            logger.error(f"Failed to list keys: {e}")
        return keys

    def cleanup_expired(self, retention_days: Optional[int] = None) -> int:
        """
        Delete conversations older than the given retention period.

        Args:
            retention_days: Optional override for the retention period in days.

        Returns:
            Number of conversations removed.
        """
        retention_value = (
            self.retention_days if retention_days is None else max(0, int(retention_days))
        )
        if retention_value <= 0:
            logger.debug("Retention cleanup skipped because retention is disabled.")
            return 0

        cutoff = datetime.now() - timedelta(days=retention_value)
        expired_entries: list[tuple[str, ConversationInStore]] = []

        # Pass 1: scan read-only for expired records
        try:
            with self._get_transaction(write=False) as txn:
                cursor = txn.cursor()
                for key_bytes, value_bytes in cursor:
                    key_str = key_bytes.decode("utf-8")
                    if key_str.startswith(self.HASH_LOOKUP_PREFIX):
                        continue
                    try:
                        storage_data = orjson.loads(value_bytes)  # type: ignore[arg-type]
                        conv = ConversationInStore.model_validate(storage_data)
                    except Exception as exc:
                        logger.warning(f"Failed to decode record for key {key_str}: {exc}")
                        continue
                    timestamp = conv.created_at or conv.updated_at
                    if not timestamp:
                        continue
                    if timestamp < cutoff:
                        expired_entries.append((key_str, conv))
        except Exception as exc:
            logger.error(f"Failed to scan LMDB for retention cleanup: {exc}")
            raise

        if not expired_entries:
            return 0

        # Pass 2: delete expired records and their hash mappings in one write txn
        removed = 0
        try:
            with self._get_transaction(write=True) as txn:
                for key_str, conv in expired_entries:
                    key_bytes = key_str.encode("utf-8")
                    if not txn.delete(key_bytes):
                        continue
                    message_hash = _hash_conversation(conv.client_id, conv.model, conv.messages)
                    if message_hash and key_str != message_hash:
                        txn.delete(f"{self.HASH_LOOKUP_PREFIX}{message_hash}".encode("utf-8"))
                    removed += 1
        except Exception as exc:
            logger.error(f"Failed to delete expired conversations: {exc}")
            raise

        if removed:
            logger.info(
                f"LMDB retention cleanup removed {removed} conversation(s) older than {cutoff.isoformat()}."
            )
        return removed
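
    # Illustrative: run cleanup from a periodic job. A retention of 0 disables
    # cleanup entirely, so calling this unconditionally is safe:
    #
    #   removed = LMDBConversationStore().cleanup_expired()
    #   removed = LMDBConversationStore().cleanup_expired(retention_days=30)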

    def stats(self) -> Dict[str, Any]:
        """
        Get database statistics.

        Returns:
            Dict with database statistics
        """
        if not self._env:
            logger.error("LMDB environment not initialized")
            return {}
        try:
            return self._env.stat()
        except Exception as e:
            logger.error(f"Failed to get database stats: {e}")
            return {}

    def close(self) -> None:
        """Close the LMDB environment."""
        if self._env:
            self._env.close()
            self._env = None
            logger.info("LMDB environment closed")

    def __del__(self):
        """Cleanup on destruction."""
        self.close()

    @staticmethod
    def remove_think_tags(text: str) -> str:
        """
        Remove all <think>...</think> tags and strip whitespace.
        """
        if not text:
            return text
        # Remove all think blocks anywhere in the text
        cleaned_content = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
        return cleaned_content.strip()

    @staticmethod
    def sanitize_assistant_messages(messages: list[Message]) -> list[Message]:
        """
        Produce a canonical history where assistant messages are cleaned of
        internal markers and tool call blocks are moved to metadata.
        """
        cleaned_messages = []
        for msg in messages:
            if msg.role == "assistant":
                if isinstance(msg.content, str):
                    text = LMDBConversationStore.remove_think_tags(msg.content)
                    tool_calls = msg.tool_calls
                    if not tool_calls:
                        text, tool_calls = extract_tool_calls(text)
                    else:
                        text = remove_tool_call_blocks(text).strip()
                    normalized_content = text.strip() or None
                    if normalized_content != msg.content or tool_calls != msg.tool_calls:
                        cleaned_msg = msg.model_copy(
                            update={
                                "content": normalized_content,
                                "tool_calls": tool_calls or None,
                            }
                        )
                        cleaned_messages.append(cleaned_msg)
                    else:
                        cleaned_messages.append(msg)
                elif isinstance(msg.content, list):
                    new_content = []
                    all_extracted_calls = list(msg.tool_calls or [])
                    changed = False
                    for item in msg.content:
                        if isinstance(item, ContentItem) and item.type == "text" and item.text:
                            text = LMDBConversationStore.remove_think_tags(item.text)
                            if not msg.tool_calls:
                                text, extracted = extract_tool_calls(text)
                                if extracted:
                                    all_extracted_calls.extend(extracted)
                                    changed = True
                            else:
                                text = remove_tool_call_blocks(text).strip()
                            if text != item.text:
                                changed = True
                                item = item.model_copy(update={"text": text.strip() or None})
                        new_content.append(item)
                    if changed:
                        cleaned_messages.append(
                            msg.model_copy(
                                update={
                                    "content": new_content,
                                    "tool_calls": all_extracted_calls or None,
                                }
                            )
                        )
                    else:
                        cleaned_messages.append(msg)
                else:
                    cleaned_messages.append(msg)
            else:
                cleaned_messages.append(msg)
        return cleaned_messages
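

# End-to-end sketch (illustrative only). Constructor kwargs are assumptions
# inferred from the attributes this module reads, g_config must point at a
# writable storage path, and "c1" stands in for a configured gemini client id:
#
#   store = LMDBConversationStore()
#   msgs = [
#       Message(role="user", content="Hi"),
#       Message(role="assistant", content="<think>plan</think>Hello!"),
#   ]
#   conv = ConversationInStore(client_id="c1", model="gemini-pro", messages=msgs)
#   key = store.store(conv)
#   # find() tolerates the raw (unsanitized) history via its two-pass search
#   assert store.find("gemini-pro", msgs) is not None
#   store.delete(key)
#   store.close()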