import asyncio
import json
import os
from pathlib import Path
from typing import Any, Optional

from fastapi import (
    APIRouter,
    Depends,
    File,
    Form,
    HTTPException,
    Query,
    UploadFile,
)
from fastapi.responses import FileResponse, Response
from loguru import logger
from surreal_commands import execute_command_sync, get_command_status

from api.command_service import CommandService
from api.models import (
    AssetModel,
    CreateSourceInsightRequest,
    SourceCreate,
    SourceInsightResponse,
    SourceListResponse,
    SourceResponse,
    SourceStatusResponse,
    SourceUpdate,
)
from commands.source_commands import SourceProcessingInput
from open_notebook.config import UPLOADS_FOLDER
from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.notebook import Notebook, Source
from open_notebook.domain.transformation import Transformation
from open_notebook.exceptions import InvalidInputError

router = APIRouter()


def generate_unique_filename(original_filename: str, upload_folder: str) -> str:
    """Generate a unique filename, matching the Streamlit app's behavior
    (append a counter if the file already exists)."""
    file_path = Path(upload_folder)
    file_path.mkdir(parents=True, exist_ok=True)

    # Split the filename into stem and extension
    stem = Path(original_filename).stem
    suffix = Path(original_filename).suffix

    # Append " (n)" to the stem until an unused name is found
    counter = 0
    while True:
        if counter == 0:
            new_filename = original_filename
        else:
            new_filename = f"{stem} ({counter}){suffix}"
        full_path = file_path / new_filename
        if not full_path.exists():
            return str(full_path)
        counter += 1


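# For illustration: if "report.pdf" (a hypothetical filename) already exists
# in the uploads folder, the next call returns a path ending in
# "report (1).pdf", then "report (2).pdf", and so on.

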
async def save_uploaded_file(upload_file: UploadFile) -> str:
    """Save an uploaded file to the uploads folder and return its path."""
    if not upload_file.filename:
        raise ValueError("No filename provided")

    # Generate a unique destination path
    file_path = generate_unique_filename(upload_file.filename, UPLOADS_FOLDER)

    try:
        # Save the file (the entire upload is read into memory)
        with open(file_path, "wb") as f:
            content = await upload_file.read()
            f.write(content)
        logger.info(f"Saved uploaded file to: {file_path}")
        return file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file: {e}")
        # Clean up the partial file if it exists
        if os.path.exists(file_path):
            os.unlink(file_path)
        raise


def parse_source_form_data(
    type: str = Form(...),
    notebook_id: Optional[str] = Form(None),
    notebooks: Optional[str] = Form(None),  # JSON string of notebook IDs
    url: Optional[str] = Form(None),
    content: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    transformations: Optional[str] = Form(None),  # JSON string of transformation IDs
    embed: str = Form("false"),  # Accepted as string, converted to bool
    delete_source: str = Form("false"),  # Accepted as string, converted to bool
    async_processing: str = Form("false"),  # Accepted as string, converted to bool
    file: Optional[UploadFile] = File(None),
) -> tuple[SourceCreate, Optional[UploadFile]]:
    """Parse form data into a SourceCreate model and return the upload file separately."""

    # Convert string booleans to actual booleans
    def str_to_bool(value: str) -> bool:
        return value.lower() in ("true", "1", "yes", "on")

    embed_bool = str_to_bool(embed)
    delete_source_bool = str_to_bool(delete_source)
    async_processing_bool = str_to_bool(async_processing)

    # Parse JSON-encoded list fields
    notebooks_list = None
    if notebooks:
        try:
            notebooks_list = json.loads(notebooks)
        except json.JSONDecodeError:
            logger.error(f"Invalid JSON in notebooks field: {notebooks}")
            raise ValueError("Invalid JSON in notebooks field")

    transformations_list = []
    if transformations:
        try:
            transformations_list = json.loads(transformations)
        except json.JSONDecodeError:
            logger.error(f"Invalid JSON in transformations field: {transformations}")
            raise ValueError("Invalid JSON in transformations field")

    # Create the SourceCreate instance
    try:
        source_data = SourceCreate(
            type=type,
            notebook_id=notebook_id,
            notebooks=notebooks_list,
            url=url,
            content=content,
            title=title,
            file_path=None,  # Set later if a file is uploaded
            transformations=transformations_list,
            embed=embed_bool,
            delete_source=delete_source_bool,
            async_processing=async_processing_bool,
        )
    except Exception as e:
        logger.error(f"Failed to create SourceCreate instance: {e}")
        raise

    return source_data, file


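# For illustration, a hypothetical multipart request this dependency accepts.
# Booleans are sent as strings and list fields as JSON-encoded strings; the
# base URL, route path, IDs, and filename are placeholders:
#
#   curl -X POST "$API_URL/sources" \
#     -F "type=upload" \
#     -F 'notebooks=["notebook:abc123"]' \
#     -F 'transformations=["transformation:summary"]' \
#     -F "embed=true" \
#     -F "async_processing=true" \
#     -F "file=@report.pdf"

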
async def get_sources(
    notebook_id: Optional[str] = Query(None, description="Filter by notebook ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of sources to return (1-100)"),
    offset: int = Query(0, ge=0, description="Number of sources to skip"),
    sort_by: str = Query("updated", description="Field to sort by (created or updated)"),
    sort_order: str = Query("desc", description="Sort order (asc or desc)"),
):
    """Get sources with pagination and sorting support."""
    try:
        # Validate sort parameters
        if sort_by not in ["created", "updated"]:
            raise HTTPException(status_code=400, detail="sort_by must be 'created' or 'updated'")
        if sort_order.lower() not in ["asc", "desc"]:
            raise HTTPException(status_code=400, detail="sort_order must be 'asc' or 'desc'")

        # Build the ORDER BY clause (sort_by and sort_order are validated
        # above, so interpolating them into the query is safe)
        order_clause = f"ORDER BY {sort_by} {sort_order.upper()}"

        # Build the query
        if notebook_id:
            # Verify the notebook exists first
            notebook = await Notebook.get(notebook_id)
            if not notebook:
                raise HTTPException(status_code=404, detail="Notebook not found")

            # Query sources for a specific notebook, including the command field
            query = f"""
                SELECT id, asset, created, title, updated, topics, command,
                    (SELECT VALUE count() FROM source_insight WHERE source = $parent.id GROUP ALL)[0].count OR 0 AS insights_count,
                    ((SELECT VALUE id FROM source_embedding WHERE source = $parent.id LIMIT 1)) != NONE AS embedded
                FROM (SELECT VALUE in FROM reference WHERE out = $notebook_id)
                {order_clause}
                LIMIT $limit START $offset
            """
            result = await repo_query(
                query,
                {
                    "notebook_id": ensure_record_id(notebook_id),
                    "limit": limit,
                    "offset": offset,
                },
            )
        else:
            # Query all sources, including the command field
            query = f"""
                SELECT id, asset, created, title, updated, topics, command,
                    (SELECT VALUE count() FROM source_insight WHERE source = $parent.id GROUP ALL)[0].count OR 0 AS insights_count,
                    ((SELECT VALUE id FROM source_embedding WHERE source = $parent.id LIMIT 1)) != NONE AS embedded
                FROM source
                {order_clause}
                LIMIT $limit START $offset
            """
            result = await repo_query(query, {"limit": limit, "offset": offset})

        # Extract command IDs for batch status fetching
        command_ids = []
        for row in result:
            command = row.get("command")
            if command:
                command_ids.append(str(command))

        # Batch fetch command statuses
        command_statuses = {}
        if command_ids:
            try:
                # The status call is per-command, so fetch statuses
                # individually but with bounded concurrency
                async def get_status_safe(command_id: str):
                    try:
                        status = await get_command_status(command_id)
                        return (command_id, status)
                    except Exception as e:
                        logger.warning(f"Failed to get status for command {command_id}: {e}")
                        return (command_id, None)

                # Limit concurrent requests to avoid overwhelming the command service
                semaphore = asyncio.Semaphore(10)

                async def get_status_with_limit(command_id: str):
                    async with semaphore:
                        return await get_status_safe(command_id)

                # Fetch statuses concurrently, bounded by the semaphore
                status_tasks = [get_status_with_limit(cmd_id) for cmd_id in command_ids]
                status_results = await asyncio.gather(*status_tasks, return_exceptions=True)

                # Process the results
                for result_item in status_results:
                    if isinstance(result_item, Exception):
                        continue
                    if isinstance(result_item, tuple) and len(result_item) == 2:
                        cmd_id, status = result_item
                        command_statuses[cmd_id] = status
            except Exception as e:
                logger.warning(f"Failed to batch fetch command statuses: {e}")

        # Convert rows to response models
        response_list = []
        for row in result:
            command = row.get("command")
            command_id = str(command) if command else None
            status = None
            processing_info = None

            # Get status information if a command exists
            if command_id and command_id in command_statuses:
                status_obj = command_statuses[command_id]
                if status_obj:
                    status = status_obj.status
                    # Extract execution metadata from the nested result structure
                    result_data: dict[str, Any] | None = getattr(status_obj, "result", None)
                    execution_metadata: dict[str, Any] = (
                        result_data.get("execution_metadata", {})
                        if isinstance(result_data, dict)
                        else {}
                    )
                    processing_info = {
                        "started_at": execution_metadata.get("started_at"),
                        "completed_at": execution_metadata.get("completed_at"),
                        "error": getattr(status_obj, "error_message", None),
                    }
            elif command_id:
                # A command exists but its status couldn't be fetched
                status = "unknown"

            response_list.append(
                SourceListResponse(
                    id=row["id"],
                    title=row.get("title"),
                    topics=row.get("topics") or [],
                    asset=AssetModel(
                        file_path=row["asset"].get("file_path") if row.get("asset") else None,
                        url=row["asset"].get("url") if row.get("asset") else None,
                    )
                    if row.get("asset")
                    else None,
                    embedded=row.get("embedded", False),
                    embedded_chunks=0,  # Removed from query - not needed in list view
                    insights_count=row.get("insights_count", 0),
                    created=str(row["created"]),
                    updated=str(row["updated"]),
                    # Status fields
                    command_id=command_id,
                    status=status,
                    processing_info=processing_info,
                )
            )

        return response_list
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching sources: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching sources: {str(e)}")


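# For illustration, a hypothetical paginated request against this handler
# (base URL and route path are placeholders):
#
#   GET $API_URL/sources?notebook_id=notebook:abc123&limit=20&offset=40&sort_by=created&sort_order=asc

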
async def create_source(
    form_data: tuple[SourceCreate, Optional[UploadFile]] = Depends(parse_source_form_data),
):
    """Create a new source, supporting both JSON and multipart form data."""
    source_data, upload_file = form_data
    # Initialize before the try block so the cleanup handlers below can
    # reference file_path even when an error occurs before any upload
    file_path = None
    try:
        # Verify all specified notebooks exist (backward-compatibility support)
        for notebook_id in source_data.notebooks or []:
            notebook = await Notebook.get(notebook_id)
            if not notebook:
                raise HTTPException(status_code=404, detail=f"Notebook {notebook_id} not found")

        # Handle the file upload if one was provided
        if upload_file and source_data.type == "upload":
            try:
                file_path = await save_uploaded_file(upload_file)
            except Exception as e:
                logger.error(f"File upload failed: {e}")
                raise HTTPException(status_code=400, detail=f"File upload failed: {str(e)}")

        # Prepare content_state for processing
        content_state: dict[str, Any] = {}
        if source_data.type == "link":
            if not source_data.url:
                raise HTTPException(status_code=400, detail="URL is required for link type")
            content_state["url"] = source_data.url
        elif source_data.type == "upload":
            # Use the uploaded file path or the provided file_path (backward compatibility)
            final_file_path = file_path or source_data.file_path
            if not final_file_path:
                raise HTTPException(
                    status_code=400,
                    detail="File upload or file_path is required for upload type",
                )
            content_state["file_path"] = final_file_path
            content_state["delete_source"] = source_data.delete_source
        elif source_data.type == "text":
            if not source_data.content:
                raise HTTPException(status_code=400, detail="Content is required for text type")
            content_state["content"] = source_data.content
        else:
            raise HTTPException(
                status_code=400,
                detail="Invalid source type. Must be link, upload, or text",
            )

        # Validate that the requested transformations exist
        transformation_ids = source_data.transformations or []
        for trans_id in transformation_ids:
            transformation = await Transformation.get(trans_id)
            if not transformation:
                raise HTTPException(status_code=404, detail=f"Transformation {trans_id} not found")

        # Branch based on the processing mode
        if source_data.async_processing:
            # ASYNC PATH: Create the source record first, then queue a command
            logger.info("Using async processing path")

            # Create a minimal source record - let SurrealDB generate the ID
            source = Source(
                title=source_data.title or "Processing...",
                topics=[],
            )
            await source.save()

            # Add the source to its notebooks immediately so it appears in the UI
            # The source_graph will skip adding duplicates
            for notebook_id in source_data.notebooks or []:
                await source.add_to_notebook(notebook_id)

            try:
                # Import command modules to ensure they're registered
                import commands.source_commands  # noqa: F401

                # Submit the command for background processing
                command_input = SourceProcessingInput(
                    source_id=str(source.id),
                    content_state=content_state,
                    notebook_ids=source_data.notebooks,
                    transformations=transformation_ids,
                    embed=source_data.embed,
                )
                command_id = await CommandService.submit_command_job(
                    "open_notebook",  # app name
                    "process_source",  # command name
                    command_input.model_dump(),
                )
                logger.info(f"Submitted async processing command: {command_id}")

                # Update the source with the command reference immediately
                # (command_id already includes the 'command:' prefix)
                source.command = ensure_record_id(command_id)
                await source.save()

                # Return the source with command info
                return SourceResponse(
                    id=source.id or "",
                    title=source.title,
                    topics=source.topics or [],
                    asset=None,  # Populated after processing
                    full_text=None,  # Populated after processing
                    embedded=False,  # Updated after processing
                    embedded_chunks=0,
                    created=str(source.created),
                    updated=str(source.updated),
                    command_id=command_id,
                    status="new",
                    processing_info={"async": True, "queued": True},
                )
            except Exception as e:
                logger.error(f"Failed to submit async processing command: {e}")
                # Clean up the source record on command submission failure
                try:
                    await source.delete()
                except Exception:
                    pass
                # Clean up the uploaded file if we created it
                if file_path and upload_file:
                    try:
                        os.unlink(file_path)
                    except Exception:
                        pass
                raise HTTPException(status_code=500, detail=f"Failed to queue processing: {str(e)}")
        else:
            # SYNC PATH: Execute synchronously using execute_command_sync
            logger.info("Using sync processing path")
            try:
                # Import command modules to ensure they're registered
                import commands.source_commands  # noqa: F401

                # Create the source record - let SurrealDB generate the ID
                source = Source(
                    title=source_data.title or "Processing...",
                    topics=[],
                )
                await source.save()

                # Add the source to its notebooks immediately so it appears in the UI
                # The source_graph will skip adding duplicates
                for notebook_id in source_data.notebooks or []:
                    await source.add_to_notebook(notebook_id)

                # Execute the command synchronously
                command_input = SourceProcessingInput(
                    source_id=str(source.id),
                    content_state=content_state,
                    notebook_ids=source_data.notebooks,
                    transformations=transformation_ids,
                    embed=source_data.embed,
                )
                result = execute_command_sync(
                    "open_notebook",  # app name
                    "process_source",  # command name
                    command_input.model_dump(),
                    timeout=300,  # 5-minute timeout for sync processing
                )

                if not result.is_success():
                    logger.error(f"Sync processing failed: {result.error_message}")
                    # Clean up the source record
                    try:
                        await source.delete()
                    except Exception:
                        pass
                    # Clean up the uploaded file if we created it
                    if file_path and upload_file:
                        try:
                            os.unlink(file_path)
                        except Exception:
                            pass
                    raise HTTPException(
                        status_code=500,
                        detail=f"Processing failed: {result.error_message}",
                    )

                # Get the processed source
                if not source.id:
                    raise HTTPException(status_code=500, detail="Source ID is missing")
                processed_source = await Source.get(source.id)
                if not processed_source:
                    raise HTTPException(status_code=500, detail="Processed source not found")

                embedded_chunks = await processed_source.get_embedded_chunks()
                return SourceResponse(
                    id=processed_source.id or "",
                    title=processed_source.title,
                    topics=processed_source.topics or [],
                    asset=AssetModel(
                        file_path=processed_source.asset.file_path
                        if processed_source.asset
                        else None,
                        url=processed_source.asset.url if processed_source.asset else None,
                    )
                    if processed_source.asset
                    else None,
                    full_text=processed_source.full_text,
                    embedded=embedded_chunks > 0,
                    embedded_chunks=embedded_chunks,
                    created=str(processed_source.created),
                    updated=str(processed_source.updated),
                    # No command_id or status for sync processing (legacy behavior)
                )
            except Exception as e:
                logger.error(f"Sync processing failed: {e}")
                # Clean up the uploaded file if we created it
                if file_path and upload_file:
                    try:
                        os.unlink(file_path)
                    except Exception:
                        pass
                raise
    except HTTPException:
        # Clean up the uploaded file on HTTP exceptions if we created it
        if file_path and upload_file:
            try:
                os.unlink(file_path)
            except Exception:
                pass
        raise
    except InvalidInputError as e:
        # Clean up the uploaded file on validation errors if we created it
        if file_path and upload_file:
            try:
                os.unlink(file_path)
            except Exception:
                pass
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error creating source: {str(e)}")
        # Clean up the uploaded file on unexpected errors if we created it
        if file_path and upload_file:
            try:
                os.unlink(file_path)
            except Exception:
                pass
        raise HTTPException(status_code=500, detail=f"Error creating source: {str(e)}")


async def create_source_json(source_data: SourceCreate):
    """Create a new source from a JSON payload (legacy endpoint kept for backward compatibility)."""
    # Convert to the form-data format and delegate to the main endpoint
    form_data = (source_data, None)
    return await create_source(form_data)


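# For illustration, a hypothetical JSON payload for this legacy endpoint
# (IDs and URL are placeholders; field names follow the SourceCreate model):
#
#   {
#     "type": "link",
#     "url": "https://example.com/article",
#     "notebooks": ["notebook:abc123"],
#     "transformations": [],
#     "embed": true,
#     "async_processing": false
#   }

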
async def _resolve_source_file(source_id: str) -> tuple[str, str]:
    """Validate the on-disk file for a source and return (resolved_path, filename)."""
    source = await Source.get(source_id)
    if not source:
        raise HTTPException(status_code=404, detail="Source not found")

    file_path = source.asset.file_path if source.asset else None
    if not file_path:
        raise HTTPException(status_code=404, detail="Source has no file to download")

    # Resolve symlinks and confine downloads to the uploads directory;
    # comparing against the root plus a separator avoids matching sibling
    # directories that merely share the prefix (e.g. "/data/uploads_evil")
    safe_root = os.path.realpath(UPLOADS_FOLDER)
    resolved_path = os.path.realpath(file_path)
    if not resolved_path.startswith(safe_root + os.sep):
        logger.warning(
            f"Blocked download outside uploads directory for source {source_id}: {resolved_path}"
        )
        raise HTTPException(status_code=403, detail="Access to file denied")

    if not os.path.exists(resolved_path):
        raise HTTPException(status_code=404, detail="File not found on server")

    filename = os.path.basename(resolved_path)
    return resolved_path, filename


def _is_source_file_available(source: Source) -> Optional[bool]:
    """Return True/False for file availability, or None if the source has no file asset."""
    if not source or not source.asset or not source.asset.file_path:
        return None

    file_path = source.asset.file_path
    safe_root = os.path.realpath(UPLOADS_FOLDER)
    resolved_path = os.path.realpath(file_path)
    if not resolved_path.startswith(safe_root + os.sep):
        return False
    return os.path.exists(resolved_path)


async def get_source(source_id: str):
    """Get a specific source by ID."""
    try:
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Get status information if a command exists
        status = None
        processing_info = None
        if source.command:
            try:
                status = await source.get_status()
                processing_info = await source.get_processing_progress()
            except Exception as e:
                logger.warning(f"Failed to get status for source {source_id}: {e}")
                status = "unknown"

        embedded_chunks = await source.get_embedded_chunks()

        # Get the associated notebooks
        notebooks_query = await repo_query(
            "SELECT VALUE out FROM reference WHERE in = $source_id",
            {"source_id": ensure_record_id(source.id or source_id)},
        )
        notebook_ids = [str(nb_id) for nb_id in notebooks_query] if notebooks_query else []

        return SourceResponse(
            id=source.id or "",
            title=source.title,
            topics=source.topics or [],
            asset=AssetModel(
                file_path=source.asset.file_path if source.asset else None,
                url=source.asset.url if source.asset else None,
            )
            if source.asset
            else None,
            full_text=source.full_text,
            embedded=embedded_chunks > 0,
            embedded_chunks=embedded_chunks,
            file_available=_is_source_file_available(source),
            created=str(source.created),
            updated=str(source.updated),
            # Status fields
            command_id=str(source.command) if source.command else None,
            status=status,
            processing_info=processing_info,
            # Notebook associations
            notebooks=notebook_ids,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching source {source_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching source: {str(e)}")


async def check_source_file(source_id: str):
    """Check whether a source has a downloadable file."""
    try:
        await _resolve_source_file(source_id)
        return Response(status_code=200)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error checking file for source {source_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to verify file")


async def download_source_file(source_id: str):
    """Download the original file associated with an uploaded source."""
    try:
        resolved_path, filename = await _resolve_source_file(source_id)
        return FileResponse(
            path=resolved_path,
            filename=filename,
            media_type="application/octet-stream",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error downloading file for source {source_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to download source file")


async def get_source_status(source_id: str):
    """Get the processing status for a source."""
    try:
        # First, verify the source exists
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Check whether this is a legacy source (no command)
        if not source.command:
            return SourceStatusResponse(
                status=None,
                message="Legacy source (completed before async processing)",
                processing_info=None,
                command_id=None,
            )

        # Get the command status and processing info
        try:
            status = await source.get_status()
            processing_info = await source.get_processing_progress()

            # Generate a descriptive message based on the status
            if status == "completed":
                message = "Source processing completed successfully"
            elif status == "failed":
                message = "Source processing failed"
            elif status == "running":
                message = "Source processing in progress"
            elif status == "queued":
                message = "Source processing queued"
            elif status == "unknown":
                message = "Source processing status unknown"
            else:
                message = f"Source processing status: {status}"

            return SourceStatusResponse(
                status=status,
                message=message,
                processing_info=processing_info,
                command_id=str(source.command) if source.command else None,
            )
        except Exception as e:
            logger.warning(f"Failed to get status for source {source_id}: {e}")
            return SourceStatusResponse(
                status="unknown",
                message="Failed to retrieve processing status",
                processing_info=None,
                command_id=str(source.command) if source.command else None,
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching status for source {source_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching source status: {str(e)}"
        )


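# For illustration, a hypothetical client polling loop for a source created
# with async_processing=true (base URL, route path, and interval are
# placeholders, and `httpx` / `time` imports are assumed on the client side):
#
#   while True:
#       resp = httpx.get(f"{API_URL}/sources/{source_id}/status").json()
#       if resp["status"] in ("completed", "failed"):
#           break
#       time.sleep(2)

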
async def update_source(source_id: str, source_update: SourceUpdate):
    """Update a source."""
    try:
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Update only the provided fields
        if source_update.title is not None:
            source.title = source_update.title
        if source_update.topics is not None:
            source.topics = source_update.topics
        await source.save()

        embedded_chunks = await source.get_embedded_chunks()
        return SourceResponse(
            id=source.id or "",
            title=source.title,
            topics=source.topics or [],
            asset=AssetModel(
                file_path=source.asset.file_path if source.asset else None,
                url=source.asset.url if source.asset else None,
            )
            if source.asset
            else None,
            full_text=source.full_text,
            embedded=embedded_chunks > 0,
            embedded_chunks=embedded_chunks,
            created=str(source.created),
            updated=str(source.updated),
        )
    except HTTPException:
        raise
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error updating source {source_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error updating source: {str(e)}")


async def retry_source_processing(source_id: str):
    """Retry processing for a failed or stuck source."""
    try:
        # First, verify the source exists
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Check whether the source already has a running command; re-raise the
        # HTTPException explicitly so the broad handler below doesn't swallow it
        if source.command:
            try:
                status = await source.get_status()
                if status in ["running", "queued"]:
                    raise HTTPException(
                        status_code=400,
                        detail="Source is already processing. Cannot retry while processing is active.",
                    )
            except HTTPException:
                raise
            except Exception as e:
                logger.warning(f"Failed to check current status for source {source_id}: {e}")
                # Continue with the retry if the status can't be checked

        # Get the notebooks this source belongs to (reference edges use in/out,
        # matching the queries in get_sources and get_source above)
        references = await repo_query(
            "SELECT VALUE out FROM reference WHERE in = $source_id",
            {"source_id": ensure_record_id(source_id)},
        )
        notebook_ids = [str(nb_id) for nb_id in references]
        if not notebook_ids:
            raise HTTPException(
                status_code=400, detail="Source is not associated with any notebooks"
            )

        # Prepare content_state based on the source asset
        content_state = {}
        if source.asset:
            if source.asset.file_path:
                content_state = {
                    "file_path": source.asset.file_path,
                    "delete_source": False,  # Don't delete on retry
                }
            elif source.asset.url:
                content_state = {"url": source.asset.url}
            else:
                raise HTTPException(
                    status_code=400, detail="Source asset has no file_path or url"
                )
        else:
            # Treat it as a text source if full_text is available
            if source.full_text:
                content_state = {"content": source.full_text}
            else:
                raise HTTPException(
                    status_code=400, detail="Cannot determine source content for retry"
                )

        try:
            # Import command modules to ensure they're registered
            import commands.source_commands  # noqa: F401

            # Submit a new command for background processing
            command_input = SourceProcessingInput(
                source_id=str(source.id),
                content_state=content_state,
                notebook_ids=notebook_ids,
                transformations=[],  # Use default transformations on retry
                embed=True,  # Always embed on retry
            )
            command_id = await CommandService.submit_command_job(
                "open_notebook",  # app name
                "process_source",  # command name
                command_input.model_dump(),
            )
            logger.info(
                f"Submitted retry processing command: {command_id} for source {source_id}"
            )

            # Update the source with the new command ID
            # (command_id already includes the 'command:' prefix, as in create_source)
            source.command = ensure_record_id(command_id)
            await source.save()

            # Get the current embedded chunks count
            embedded_chunks = await source.get_embedded_chunks()

            # Return the updated source response
            return SourceResponse(
                id=source.id or "",
                title=source.title,
                topics=source.topics or [],
                asset=AssetModel(
                    file_path=source.asset.file_path if source.asset else None,
                    url=source.asset.url if source.asset else None,
                )
                if source.asset
                else None,
                full_text=source.full_text,
                embedded=embedded_chunks > 0,
                embedded_chunks=embedded_chunks,
                created=str(source.created),
                updated=str(source.updated),
                command_id=command_id,
                status="queued",
                processing_info={"retry": True, "queued": True},
            )
        except Exception as e:
            logger.error(
                f"Failed to submit retry processing command for source {source_id}: {e}"
            )
            raise HTTPException(
                status_code=500, detail=f"Failed to queue retry processing: {str(e)}"
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error retrying source processing for {source_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error retrying source processing: {str(e)}"
        )


async def delete_source(source_id: str):
    """Delete a source."""
    try:
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        await source.delete()
        return {"message": "Source deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting source {source_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error deleting source: {str(e)}")


async def get_source_insights(source_id: str):
    """Get all insights for a specific source."""
    try:
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        insights = await source.get_insights()
        return [
            SourceInsightResponse(
                id=insight.id or "",
                source_id=source_id,
                insight_type=insight.insight_type,
                content=insight.content,
                created=str(insight.created),
                updated=str(insight.updated),
            )
            for insight in insights
        ]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching insights for source {source_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching insights: {str(e)}"
        )


async def create_source_insight(source_id: str, request: CreateSourceInsightRequest):
    """Create a new insight for a source by running a transformation."""
    try:
        # Get the source
        source = await Source.get(source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Get the transformation
        transformation = await Transformation.get(request.transformation_id)
        if not transformation:
            raise HTTPException(status_code=404, detail="Transformation not found")

        # Run the transformation graph
        from open_notebook.graphs.transformation import graph as transform_graph

        await transform_graph.ainvoke(
            input=dict(source=source, transformation=transformation)  # type: ignore[arg-type]
        )

        # Get the newly created insight (last one)
        insights = await source.get_insights()
        if insights:
            newest = insights[-1]
            return SourceInsightResponse(
                id=newest.id or "",
                source_id=source_id,
                insight_type=newest.insight_type,
                content=newest.content,
                created=str(newest.created),
                updated=str(newest.updated),
            )
        else:
            raise HTTPException(status_code=500, detail="Failed to create insight")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating insight for source {source_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error creating insight: {str(e)}")
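

# For illustration, a hypothetical request body for create_source_insight
# (the transformation ID is a placeholder; the field name follows the
# CreateSourceInsightRequest model used above):
#
#   {"transformation_id": "transformation:summary"}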