#!/usr/bin/env -S uv run python
"""
Script to generate embeddings for existing messages that don't already have embeddings.

Note: When generating embeddings for messages, two limits defined in the settings
need to be considered:

1. MAX_EMBEDDING_TOKENS: the maximum number of tokens that can be included in a
   single message for which an embedding is generated. If a message exceeds this
   limit, it will be chunked into multiple embeddings.
2. MAX_EMBEDDING_TOKENS_PER_REQUEST: the maximum total number of tokens that can
   be included in a single request to the embedding provider. If the total number
   of tokens across all messages in a batch exceeds this limit, the batch will
   need to be split into multiple batches.

Usage:
    python scripts/generate_message_embeddings.py [--batch-size N] [--workspace-name WORKSPACE] [--session-name SESSION] [--peer-name PEER]
"""

import argparse
import asyncio
import os
import sys

# Add the project root to the path so that the `src` package resolves when the
# script is run directly from the scripts directory.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

import tiktoken  # noqa: E402
from sqlalchemy import select  # noqa: E402
from sqlalchemy.ext.asyncio import AsyncSession  # noqa: E402

from src import models  # noqa: E402
from src.dependencies import tracked_db  # noqa: E402
from src.embedding_client import embedding_client  # noqa: E402


def _positive_int(value: str) -> int:
    """Parse a command-line value as a strictly positive integer.

    Used as an argparse ``type=`` callable so that invalid batch sizes are
    rejected at parse time with a clear usage error, instead of failing later
    inside ``range()`` (zero) or silently processing nothing (negative).

    Args:
        value: Raw argument string from the command line

    Returns:
        The parsed integer, guaranteed to be >= 1

    Raises:
        argparse.ArgumentTypeError: If the value is not an integer or is < 1
    """
    try:
        parsed = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid integer value: {value!r}"
        ) from None
    if parsed < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {parsed}"
        )
    return parsed


async def get_messages_without_embeddings(
    db: AsyncSession,
    workspace_name: str | None = None,
    session_name: str | None = None,
    peer_name: str | None = None,
) -> list[models.Message]:
    """
    Get all messages that don't have embeddings yet.

    Args:
        db: Database session
        workspace_name: Optional workspace name filter
        session_name: Optional session name filter
        peer_name: Optional peer name filter

    Returns:
        List of messages without embeddings, ordered by message id
    """
    # Anti-join: LEFT OUTER JOIN onto the embeddings table and keep only the
    # rows where no embedding matched.
    stmt = (
        select(models.Message)
        .outerjoin(
            models.MessageEmbedding,
            models.Message.public_id == models.MessageEmbedding.message_id,
        )
        .where(models.MessageEmbedding.message_id.is_(None))  # No embedding exists
        .order_by(models.Message.id)
    )

    # Apply filters if provided (an empty string is treated as "no filter")
    if workspace_name:
        stmt = stmt.where(models.Message.workspace_name == workspace_name)
    if session_name:
        stmt = stmt.where(models.Message.session_name == session_name)
    if peer_name:
        stmt = stmt.where(models.Message.peer_name == peer_name)

    result = await db.execute(stmt)
    return list(result.scalars().all())


async def create_embeddings_for_messages(
    db: AsyncSession,
    messages: list[models.Message],
) -> int:
    """
    Create embeddings for a batch of messages.

    One MessageEmbedding row is created per embedding returned for a message,
    and all rows for the batch are committed together.

    Args:
        db: Database session
        messages: List of messages to create embeddings for

    Returns:
        Number of embeddings created
    """
    if not messages:
        return 0

    # Initialize tiktoken encoding (same as used in MessageCreate schema)
    encoding = tiktoken.get_encoding("o200k_base")

    # Map each message id to (text, token ids) for the batch embedding call
    id_resource_dict = {
        message.public_id: (
            message.content,
            encoding.encode(message.content),
        )
        for message in messages
    }

    # Generate embeddings; each id maps to a list of embeddings
    embedding_dict = await embedding_client.batch_embed(id_resource_dict)

    # Create MessageEmbedding objects. A message whose id is missing from the
    # result contributes no rows.
    embedding_objects: list[models.MessageEmbedding] = []
    for message in messages:
        embedding_objects.extend(
            models.MessageEmbedding(
                content=message.content,
                embedding=embedding,
                message_id=message.public_id,
                workspace_name=message.workspace_name,
                session_name=message.session_name,
                peer_name=message.peer_name,
            )
            for embedding in embedding_dict.get(message.public_id, [])
        )

    # Add to database; skip the commit entirely when nothing was produced
    if embedding_objects:
        db.add_all(embedding_objects)
        await db.commit()

    return len(embedding_objects)


async def main() -> None:
    """
    Parse command-line options and backfill embeddings for matching messages.

    Messages without embeddings are loaded once, then processed in batches of
    ``--batch-size``; each batch is committed separately. On any error the
    message is printed to stderr and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Generate embeddings for messages that don't already have them",
    )
    parser.add_argument(
        "--batch-size",
        type=_positive_int,  # rejects zero/negative values at parse time
        default=50,
        help="Number of messages to process in each batch (default: 50)",
    )
    parser.add_argument(
        "--workspace-name",
        help="Only process messages from this workspace",
    )
    parser.add_argument(
        "--session-name",
        help="Only process messages from this session",
    )
    parser.add_argument(
        "--peer-name",
        help="Only process messages from this peer",
    )

    args = parser.parse_args()

    print("Generating embeddings for messages...")
    if args.workspace_name:
        print(f"  Filtering by workspace: {args.workspace_name}")
    else:
        print("  Processing all workspaces")

    if args.session_name:
        print(f"  Filtering by session: {args.session_name}")
    if args.peer_name:
        print(f"  Filtering by peer: {args.peer_name}")

    # Use tracked_db context manager for proper database session handling
    async with tracked_db("generate_embeddings") as db:
        try:
            # Get messages without embeddings
            print("Finding messages without embeddings...")
            messages = await get_messages_without_embeddings(
                db, args.workspace_name, args.session_name, args.peer_name
            )

            if not messages:
                print("No messages found that need embeddings.")
                return

            print(f"Found {len(messages)} messages without embeddings.")

            # Process in batches
            batch_size = args.batch_size
            total_batches = (len(messages) + batch_size - 1) // batch_size  # ceil
            total_embeddings = 0

            # NOTE(review): each batch commits on the shared session. If that
            # session is configured with expire_on_commit=True, the remaining
            # preloaded messages may be expired after the first commit -
            # confirm the session settings used by tracked_db.
            for batch_num, start in enumerate(
                range(0, len(messages), batch_size), start=1
            ):
                batch = messages[start : start + batch_size]

                print(
                    f"Processing batch {batch_num}/{total_batches} ({len(batch)} messages)..."
                )

                embeddings_created = await create_embeddings_for_messages(db, batch)
                total_embeddings += embeddings_created

                print(
                    f"  Created {embeddings_created} embeddings for batch {batch_num}"
                )

            print(
                f"\nCompleted! Created {total_embeddings} embeddings for {len(messages)} messages."
            )

        except Exception as e:
            # Top-level boundary: report on stderr and signal failure to the shell
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())