Spaces:
Sleeping
Sleeping
| import time | |
| from typing import Any, Dict, List, Optional | |
| from loguru import logger | |
| from pydantic import BaseModel | |
| from surreal_commands import CommandInput, CommandOutput, command | |
| from open_notebook.database.repository import ensure_record_id | |
| from open_notebook.domain.notebook import Source | |
| from open_notebook.domain.transformation import Transformation | |
| try: | |
| from open_notebook.graphs.source import source_graph | |
| except ImportError as e: | |
| logger.error(f"Failed to import source_graph: {e}") | |
| raise ValueError("source_graph not available") | |
| def full_model_dump(model): | |
| if isinstance(model, BaseModel): | |
| return model.model_dump() | |
| elif isinstance(model, dict): | |
| return {k: full_model_dump(v) for k, v in model.items()} | |
| elif isinstance(model, list): | |
| return [full_model_dump(item) for item in model] | |
| else: | |
| return model | |
| class SourceProcessingInput(CommandInput): | |
| source_id: str | |
| content_state: Dict[str, Any] | |
| notebook_ids: List[str] | |
| transformations: List[str] | |
| embed: bool | |
| class SourceProcessingOutput(CommandOutput): | |
| success: bool | |
| source_id: str | |
| embedded_chunks: int = 0 | |
| insights_created: int = 0 | |
| processing_time: float | |
| error_message: Optional[str] = None | |
| async def process_source_command( | |
| input_data: SourceProcessingInput, | |
| ) -> SourceProcessingOutput: | |
| """ | |
| Process source content using the source_graph workflow | |
| """ | |
| start_time = time.time() | |
| try: | |
| logger.info(f"Starting source processing for source: {input_data.source_id}") | |
| logger.info(f"Notebook IDs: {input_data.notebook_ids}") | |
| logger.info(f"Transformations: {input_data.transformations}") | |
| logger.info(f"Embed: {input_data.embed}") | |
| # 1. Load transformation objects from IDs | |
| transformations = [] | |
| for trans_id in input_data.transformations: | |
| logger.info(f"Loading transformation: {trans_id}") | |
| transformation = await Transformation.get(trans_id) | |
| if not transformation: | |
| raise ValueError(f"Transformation '{trans_id}' not found") | |
| transformations.append(transformation) | |
| logger.info(f"Loaded {len(transformations)} transformations") | |
| # 2. Get existing source record to update its command field | |
| source = await Source.get(input_data.source_id) | |
| if not source: | |
| raise ValueError(f"Source '{input_data.source_id}' not found") | |
| # Update source with command reference | |
| source.command = ( | |
| ensure_record_id(input_data.execution_context.command_id) | |
| if input_data.execution_context | |
| else None | |
| ) | |
| await source.save() | |
| logger.info(f"Updated source {source.id} with command reference") | |
| # 3. Process source with all notebooks | |
| logger.info(f"Processing source with {len(input_data.notebook_ids)} notebooks") | |
| # Execute source_graph with all notebooks | |
| result = await source_graph.ainvoke( | |
| { # type: ignore[arg-type] | |
| "content_state": input_data.content_state, | |
| "notebook_ids": input_data.notebook_ids, # Use notebook_ids (plural) as expected by SourceState | |
| "apply_transformations": transformations, | |
| "embed": input_data.embed, | |
| "source_id": input_data.source_id, # Add the source_id to the state | |
| } | |
| ) | |
| processed_source = result["source"] | |
| # 4. Gather processing results (notebook associations handled by source_graph) | |
| embedded_chunks = ( | |
| await processed_source.get_embedded_chunks() if input_data.embed else 0 | |
| ) | |
| insights_list = await processed_source.get_insights() | |
| insights_created = len(insights_list) | |
| processing_time = time.time() - start_time | |
| logger.info( | |
| f"Successfully processed source: {processed_source.id} in {processing_time:.2f}s" | |
| ) | |
| logger.info( | |
| f"Created {insights_created} insights and {embedded_chunks} embedded chunks" | |
| ) | |
| return SourceProcessingOutput( | |
| success=True, | |
| source_id=str(processed_source.id), | |
| embedded_chunks=embedded_chunks, | |
| insights_created=insights_created, | |
| processing_time=processing_time, | |
| ) | |
| except RuntimeError as e: | |
| # Transaction conflicts should be retried by surreal-commands | |
| logger.warning(f"Transaction conflict, will retry: {e}") | |
| raise | |
| except Exception as e: | |
| # Other errors are permanent failures | |
| processing_time = time.time() - start_time | |
| logger.error(f"Source processing failed: {e}") | |
| return SourceProcessingOutput( | |
| success=False, | |
| source_id=input_data.source_id, | |
| processing_time=processing_time, | |
| error_message=str(e), | |
| ) | |