import time from typing import Any, Dict, List, Optional from loguru import logger from pydantic import BaseModel from surreal_commands import CommandInput, CommandOutput, command from open_notebook.database.repository import ensure_record_id from open_notebook.domain.notebook import Source from open_notebook.domain.transformation import Transformation try: from open_notebook.graphs.source import source_graph except ImportError as e: logger.error(f"Failed to import source_graph: {e}") raise ValueError("source_graph not available") def full_model_dump(model): if isinstance(model, BaseModel): return model.model_dump() elif isinstance(model, dict): return {k: full_model_dump(v) for k, v in model.items()} elif isinstance(model, list): return [full_model_dump(item) for item in model] else: return model class SourceProcessingInput(CommandInput): source_id: str content_state: Dict[str, Any] notebook_ids: List[str] transformations: List[str] embed: bool class SourceProcessingOutput(CommandOutput): success: bool source_id: str embedded_chunks: int = 0 insights_created: int = 0 processing_time: float error_message: Optional[str] = None @command( "process_source", app="open_notebook", retry={ "max_attempts": 5, "wait_strategy": "exponential_jitter", "wait_min": 1, "wait_max": 30, "retry_on": [RuntimeError], }, ) async def process_source_command( input_data: SourceProcessingInput, ) -> SourceProcessingOutput: """ Process source content using the source_graph workflow """ start_time = time.time() try: logger.info(f"Starting source processing for source: {input_data.source_id}") logger.info(f"Notebook IDs: {input_data.notebook_ids}") logger.info(f"Transformations: {input_data.transformations}") logger.info(f"Embed: {input_data.embed}") # 1. Load transformation objects from IDs transformations = [] for trans_id in input_data.transformations: logger.info(f"Loading transformation: {trans_id}") transformation = await Transformation.get(trans_id) if not transformation: raise ValueError(f"Transformation '{trans_id}' not found") transformations.append(transformation) logger.info(f"Loaded {len(transformations)} transformations") # 2. Get existing source record to update its command field source = await Source.get(input_data.source_id) if not source: raise ValueError(f"Source '{input_data.source_id}' not found") # Update source with command reference source.command = ( ensure_record_id(input_data.execution_context.command_id) if input_data.execution_context else None ) await source.save() logger.info(f"Updated source {source.id} with command reference") # 3. Process source with all notebooks logger.info(f"Processing source with {len(input_data.notebook_ids)} notebooks") # Execute source_graph with all notebooks result = await source_graph.ainvoke( { # type: ignore[arg-type] "content_state": input_data.content_state, "notebook_ids": input_data.notebook_ids, # Use notebook_ids (plural) as expected by SourceState "apply_transformations": transformations, "embed": input_data.embed, "source_id": input_data.source_id, # Add the source_id to the state } ) processed_source = result["source"] # 4. Gather processing results (notebook associations handled by source_graph) embedded_chunks = ( await processed_source.get_embedded_chunks() if input_data.embed else 0 ) insights_list = await processed_source.get_insights() insights_created = len(insights_list) processing_time = time.time() - start_time logger.info( f"Successfully processed source: {processed_source.id} in {processing_time:.2f}s" ) logger.info( f"Created {insights_created} insights and {embedded_chunks} embedded chunks" ) return SourceProcessingOutput( success=True, source_id=str(processed_source.id), embedded_chunks=embedded_chunks, insights_created=insights_created, processing_time=processing_time, ) except RuntimeError as e: # Transaction conflicts should be retried by surreal-commands logger.warning(f"Transaction conflict, will retry: {e}") raise except Exception as e: # Other errors are permanent failures processing_time = time.time() - start_time logger.error(f"Source processing failed: {e}") return SourceProcessingOutput( success=False, source_id=input_data.source_id, processing_time=processing_time, error_message=str(e), )