Spaces:
Sleeping
Sleeping
| """ | |
| Knowledge Base API Router | |
| ========================= | |
| Handles knowledge base CRUD operations, file uploads, and initialization. | |
| """ | |
| import asyncio | |
| from datetime import datetime | |
| import os | |
| from pathlib import Path | |
| import traceback | |
| from uuid import uuid4 | |
| from fastapi import ( | |
| APIRouter, | |
| BackgroundTasks, | |
| File, | |
| Form, | |
| HTTPException, | |
| UploadFile, | |
| WebSocket, | |
| WebSocketDisconnect, | |
| ) | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from deeptutor.api.utils.progress_broadcaster import ProgressBroadcaster | |
| from deeptutor.api.utils.task_id_manager import TaskIDManager | |
| from deeptutor.api.utils.task_log_stream import capture_task_logs, get_task_stream_manager | |
| from deeptutor.knowledge.add_documents import DocumentAdder | |
| from deeptutor.knowledge.initializer import KnowledgeBaseInitializer | |
| from deeptutor.knowledge.manager import KnowledgeBaseManager | |
| from deeptutor.knowledge.progress_tracker import ProgressStage, ProgressTracker | |
| from deeptutor.services.rag.components.routing import FileTypeRouter | |
| from deeptutor.services.rag.factory import DEFAULT_PROVIDER, has_pipeline, normalize_provider_name | |
| from deeptutor.utils.document_validator import DocumentValidator | |
| from deeptutor.utils.error_utils import format_exception_message | |
| from deeptutor.logging import get_logger | |
| from deeptutor.services.config import PROJECT_ROOT, load_config_with_main | |
| # Initialize logger with config | |
| config = load_config_with_main("main.yaml", PROJECT_ROOT) | |
| log_dir = config.get("paths", {}).get("user_log_dir") or config.get("logging", {}).get("log_dir") | |
| logger = get_logger("Knowledge", level="INFO", log_dir=log_dir) | |
| router = APIRouter() | |
| # Constants for byte conversions | |
| BYTES_PER_GB = 1024**3 | |
| BYTES_PER_MB = 1024**2 | |
| def format_bytes_human_readable(size_bytes: int) -> str: | |
| """Format bytes into human-readable string (GB, MB, or bytes).""" | |
| if size_bytes >= BYTES_PER_GB: | |
| return f"{size_bytes / BYTES_PER_GB:.1f} GB" | |
| elif size_bytes >= BYTES_PER_MB: | |
| return f"{size_bytes / BYTES_PER_MB:.1f} MB" | |
| else: | |
| return f"{size_bytes} bytes" | |
| _kb_base_dir = PROJECT_ROOT / "data" / "knowledge_bases" | |
| # Lazy initialization | |
| kb_manager = None | |
| def get_kb_manager(): | |
| """Get KnowledgeBaseManager instance (lazy init)""" | |
| global kb_manager | |
| if kb_manager is None: | |
| kb_manager = KnowledgeBaseManager(base_dir=str(_kb_base_dir)) | |
| return kb_manager | |
| class KnowledgeBaseInfo(BaseModel): | |
| name: str | |
| is_default: bool | |
| statistics: dict | |
| status: str | None = None | |
| progress: dict | None = None | |
| class LinkFolderRequest(BaseModel): | |
| """Request model for linking a local folder to a KB.""" | |
| folder_path: str | |
| class LinkedFolderInfo(BaseModel): | |
| """Response model for linked folder information.""" | |
| id: str | |
| path: str | |
| added_at: str | |
| file_count: int | |
| def _build_unique_task_id(task_type: str, task_key_prefix: str) -> str: | |
| task_manager = TaskIDManager.get_instance() | |
| task_key = f"{task_key_prefix}_{datetime.now().isoformat()}_{uuid4().hex[:8]}" | |
| return task_manager.generate_task_id(task_type, task_key) | |
| def _save_uploaded_files( | |
| files: list[UploadFile], | |
| target_dir: Path, | |
| allowed_extensions: set[str] | None = None, | |
| ) -> tuple[list[str], list[str]]: | |
| uploaded_files: list[str] = [] | |
| uploaded_file_paths: list[str] = [] | |
| for file in files: | |
| file_path = None | |
| original_filename = file.filename or "upload" | |
| try: | |
| sanitized_filename = DocumentValidator.validate_upload_safety( | |
| original_filename, None, allowed_extensions=allowed_extensions | |
| ) | |
| file.filename = sanitized_filename | |
| file_path = target_dir / sanitized_filename | |
| max_size = DocumentValidator.MAX_FILE_SIZE | |
| written_bytes = 0 | |
| with open(file_path, "wb") as buffer: | |
| for chunk in iter(lambda: file.file.read(8192), b""): | |
| written_bytes += len(chunk) | |
| if written_bytes > max_size: | |
| size_str = format_bytes_human_readable(max_size) | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"File '{sanitized_filename}' exceeds maximum size limit of {size_str}", | |
| ) | |
| buffer.write(chunk) | |
| DocumentValidator.validate_upload_safety( | |
| sanitized_filename, written_bytes, allowed_extensions=allowed_extensions | |
| ) | |
| uploaded_files.append(sanitized_filename) | |
| uploaded_file_paths.append(str(file_path)) | |
| except Exception as e: | |
| if file_path and file_path.exists(): | |
| try: | |
| os.unlink(file_path) | |
| except OSError: | |
| pass | |
| error_message = ( | |
| f"Validation failed for file '{original_filename}': {format_exception_message(e)}" | |
| ) | |
| logger.error(error_message, exc_info=True) | |
| raise HTTPException(status_code=400, detail=error_message) from e | |
| return uploaded_files, uploaded_file_paths | |
| def _task_log(task_id: str, message: str, level: str = "info") -> None: | |
| manager = get_task_stream_manager() | |
| manager.ensure_task(task_id) | |
| manager.emit_log(task_id, message) | |
| log_method = getattr(logger, level, None) | |
| if callable(log_method): | |
| log_method(f"[{task_id}] {message}") | |
| else: | |
| logger.info(f"[{task_id}] {message}") | |
| def _validate_registered_provider(raw_provider: str | None) -> str: | |
| candidate = (raw_provider or DEFAULT_PROVIDER).strip().lower() | |
| if not candidate: | |
| candidate = DEFAULT_PROVIDER | |
| if not has_pipeline(candidate): | |
| from deeptutor.services.rag.service import RAGService | |
| available_providers = [item["id"] for item in RAGService.list_providers()] | |
| raise HTTPException( | |
| status_code=400, | |
| detail=( | |
| f"Unsupported RAG provider '{candidate}'. " | |
| f"Available providers: {available_providers}" | |
| ), | |
| ) | |
| return normalize_provider_name(candidate) | |
| def _load_kb_entry_or_404(manager: KnowledgeBaseManager, kb_name: str) -> dict: | |
| manager.config = manager._load_config() | |
| kb_entry = manager.config.get("knowledge_bases", {}).get(kb_name) | |
| if kb_entry is None: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| return kb_entry | |
| def _assert_kb_writable_or_409(kb_name: str, kb_entry: dict) -> None: | |
| if bool(kb_entry.get("needs_reindex", False)): | |
| raise HTTPException( | |
| status_code=409, | |
| detail=( | |
| f"Knowledge base '{kb_name}' uses legacy index format and needs reindex " | |
| "before accepting incremental uploads." | |
| ), | |
| ) | |
| async def run_initialization_task(initializer: KnowledgeBaseInitializer, task_id: str): | |
| """Background task for knowledge base initialization""" | |
| task_manager = TaskIDManager.get_instance() | |
| task_stream_manager = get_task_stream_manager() | |
| task_stream_manager.ensure_task(task_id) | |
| with capture_task_logs(task_id): | |
| try: | |
| if not initializer.progress_tracker: | |
| initializer.progress_tracker = ProgressTracker( | |
| initializer.kb_name, initializer.base_dir | |
| ) | |
| initializer.progress_tracker.task_id = task_id | |
| _task_log(task_id, f"Initializing knowledge base '{initializer.kb_name}'") | |
| await initializer.process_documents() | |
| _task_log(task_id, "Document processing complete") | |
| initializer.extract_numbered_items() | |
| _task_log(task_id, "Finalizing initialization") | |
| initializer.progress_tracker.update( | |
| ProgressStage.COMPLETED, "Knowledge base initialization complete!", current=1, total=1 | |
| ) | |
| manager = get_kb_manager() | |
| manager.update_kb_status( | |
| name=initializer.kb_name, | |
| status="ready", | |
| progress={ | |
| "stage": "completed", | |
| "message": "Knowledge base initialization complete!", | |
| "percent": 100, | |
| "current": 1, | |
| "total": 1, | |
| "task_id": task_id, | |
| "timestamp": datetime.now().isoformat(), | |
| }, | |
| ) | |
| _task_log(task_id, f"Knowledge base '{initializer.kb_name}' initialized", level="success") | |
| task_manager.update_task_status(task_id, "completed") | |
| task_stream_manager.emit_complete( | |
| task_id, f"Knowledge base '{initializer.kb_name}' initialization complete" | |
| ) | |
| except Exception as e: | |
| error_msg = str(e) | |
| _task_log(task_id, f"Initialization failed: {error_msg}", level="error") | |
| task_manager.update_task_status(task_id, "error", error=error_msg) | |
| manager = get_kb_manager() | |
| manager.update_kb_status( | |
| name=initializer.kb_name, | |
| status="error", | |
| progress={ | |
| "stage": "error", | |
| "message": f"Initialization failed: {error_msg}", | |
| "percent": 0, | |
| "error": error_msg, | |
| "task_id": task_id, | |
| "timestamp": datetime.now().isoformat(), | |
| }, | |
| ) | |
| if initializer.progress_tracker: | |
| initializer.progress_tracker.update( | |
| ProgressStage.ERROR, f"Initialization failed: {error_msg}", error=error_msg | |
| ) | |
| task_stream_manager.emit_failed(task_id, error_msg) | |
| async def run_upload_processing_task( | |
| kb_name: str, | |
| base_dir: str, | |
| uploaded_file_paths: list[str], | |
| task_id: str, | |
| rag_provider: str = None, | |
| folder_id: str = None, | |
| ): | |
| """Background task for processing uploaded files. | |
| Args: | |
| kb_name: Knowledge base name | |
| base_dir: Base directory for knowledge bases | |
| uploaded_file_paths: List of file paths to process | |
| rag_provider: RAG provider (ignored - we use the one from KB metadata) | |
| folder_id: Optional folder ID for sync state update | |
| """ | |
| task_manager = TaskIDManager.get_instance() | |
| task_stream_manager = get_task_stream_manager() | |
| task_stream_manager.ensure_task(task_id) | |
| progress_tracker = ProgressTracker(kb_name, Path(base_dir)) | |
| progress_tracker.task_id = task_id | |
| with capture_task_logs(task_id): | |
| try: | |
| _task_log(task_id, f"Processing {len(uploaded_file_paths)} file(s) for KB '{kb_name}'") | |
| progress_tracker.update( | |
| ProgressStage.PROCESSING_DOCUMENTS, | |
| f"Processing {len(uploaded_file_paths)} files...", | |
| current=0, | |
| total=len(uploaded_file_paths), | |
| ) | |
| adder = DocumentAdder( | |
| kb_name=kb_name, | |
| base_dir=base_dir, | |
| progress_tracker=progress_tracker, | |
| rag_provider=rag_provider, | |
| ) | |
| staged_files = adder.add_documents(uploaded_file_paths, allow_duplicates=False) | |
| _task_log(task_id, f"Staged {len(staged_files)} new file(s)") | |
| if not staged_files: | |
| _task_log(task_id, "No new files to process (all duplicates or invalid)") | |
| progress_tracker.update( | |
| ProgressStage.COMPLETED, | |
| "No new files to process (all duplicates or invalid)", | |
| current=0, | |
| total=0, | |
| ) | |
| task_manager.update_task_status(task_id, "completed") | |
| task_stream_manager.emit_complete( | |
| task_id, "No new files to process (all duplicates or invalid)" | |
| ) | |
| return | |
| processed_files = await adder.process_new_documents(staged_files) | |
| _task_log(task_id, f"Indexed {len(processed_files)} file(s)") | |
| if processed_files: | |
| progress_tracker.update( | |
| ProgressStage.EXTRACTING_ITEMS, | |
| "Extracting numbered items...", | |
| current=0, | |
| total=len(processed_files), | |
| ) | |
| adder.extract_numbered_items_for_new_docs(processed_files, batch_size=20) | |
| adder.update_metadata(len(processed_files) if processed_files else 0) | |
| if folder_id and processed_files: | |
| try: | |
| manager = get_kb_manager() | |
| manager.update_folder_sync_state( | |
| kb_name, folder_id, [str(f) for f in processed_files] | |
| ) | |
| _task_log(task_id, f"Updated folder sync state: {folder_id}") | |
| except Exception as sync_err: | |
| _task_log(task_id, f"Folder sync state update failed: {sync_err}", level="warning") | |
| num_processed = len(processed_files) if processed_files else 0 | |
| progress_tracker.update( | |
| ProgressStage.COMPLETED, | |
| f"Successfully processed {num_processed} files!", | |
| current=num_processed, | |
| total=num_processed, | |
| ) | |
| _task_log(task_id, f"Processed {num_processed} file(s) for '{kb_name}'", level="success") | |
| task_manager.update_task_status(task_id, "completed") | |
| task_stream_manager.emit_complete( | |
| task_id, f"Successfully processed {num_processed} files for '{kb_name}'" | |
| ) | |
| except Exception as e: | |
| error_msg = f"Upload processing failed (KB '{kb_name}'): {e}" | |
| _task_log(task_id, error_msg, level="error") | |
| task_manager.update_task_status(task_id, "error", error=error_msg) | |
| progress_tracker.update( | |
| ProgressStage.ERROR, f"Processing failed: {error_msg}", error=error_msg | |
| ) | |
| task_stream_manager.emit_failed(task_id, error_msg) | |
| async def health_check(): | |
| """Health check endpoint""" | |
| try: | |
| manager = get_kb_manager() | |
| config_exists = manager.config_file.exists() | |
| kb_count = len(manager.list_knowledge_bases()) | |
| return { | |
| "status": "ok", | |
| "config_file": str(manager.config_file), | |
| "config_exists": config_exists, | |
| "base_dir": str(manager.base_dir), | |
| "base_dir_exists": manager.base_dir.exists(), | |
| "knowledge_bases_count": kb_count, | |
| } | |
| except Exception as e: | |
| return {"status": "error", "error": str(e), "traceback": traceback.format_exc()} | |
| async def get_rag_providers(): | |
| """Get list of available RAG providers.""" | |
| try: | |
| from deeptutor.services.rag.service import RAGService | |
| providers = RAGService.list_providers() | |
| return {"providers": providers} | |
| except Exception as e: | |
| logger.error(f"Error getting RAG providers: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_all_kb_configs(): | |
| """Get all knowledge base configurations from centralized config file.""" | |
| try: | |
| from deeptutor.services.config import get_kb_config_service | |
| service = get_kb_config_service() | |
| return service.get_all_configs() | |
| except Exception as e: | |
| logger.error(f"Error getting KB configs: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_kb_config(kb_name: str): | |
| """Get configuration for a specific knowledge base.""" | |
| try: | |
| from deeptutor.services.config import get_kb_config_service | |
| service = get_kb_config_service() | |
| config = service.get_kb_config(kb_name) | |
| return {"kb_name": kb_name, "config": config} | |
| except Exception as e: | |
| logger.error(f"Error getting config for KB '{kb_name}': {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def update_kb_config(kb_name: str, config: dict): | |
| """Update configuration for a specific knowledge base.""" | |
| try: | |
| from deeptutor.services.config import get_kb_config_service | |
| if "rag_provider" in config: | |
| config["rag_provider"] = _validate_registered_provider(config.get("rag_provider")) | |
| service = get_kb_config_service() | |
| service.set_kb_config(kb_name, config) | |
| return {"status": "success", "kb_name": kb_name, "config": service.get_kb_config(kb_name)} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error updating config for KB '{kb_name}': {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def sync_configs_from_metadata(): | |
| """Sync all KB configurations from their metadata.json files to centralized config.""" | |
| try: | |
| from deeptutor.services.config import get_kb_config_service | |
| service = get_kb_config_service() | |
| service.sync_all_from_metadata(_kb_base_dir) | |
| return {"status": "success", "message": "Configurations synced from metadata files"} | |
| except Exception as e: | |
| logger.error(f"Error syncing configs: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_default_kb(): | |
| """Get the default knowledge base.""" | |
| try: | |
| manager = get_kb_manager() | |
| default_kb = manager.get_default() | |
| return {"default_kb": default_kb} | |
| except Exception as e: | |
| logger.error(f"Error getting default KB: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def set_default_kb(kb_name: str): | |
| """Set the default knowledge base.""" | |
| try: | |
| manager = get_kb_manager() | |
| # Verify KB exists | |
| if kb_name not in manager.list_knowledge_bases(): | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| manager.set_default(kb_name) | |
| return {"status": "success", "default_kb": kb_name} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error setting default KB: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def list_knowledge_bases(): | |
| """List all available knowledge bases with their details.""" | |
| try: | |
| manager = get_kb_manager() | |
| kb_names = manager.list_knowledge_bases() | |
| logger.debug(f"Found {len(kb_names)} knowledge bases: {kb_names}") | |
| if not kb_names: | |
| logger.debug("No knowledge bases found, returning empty list") | |
| return [] | |
| result = [] | |
| errors = [] | |
| for name in kb_names: | |
| try: | |
| info = manager.get_info(name) | |
| logger.debug(f"Successfully got info for KB '{name}': {info.get('statistics', {})}") | |
| result.append( | |
| KnowledgeBaseInfo( | |
| name=info["name"], | |
| is_default=info["is_default"], | |
| statistics=info.get("statistics", {}), | |
| status=info.get("status"), | |
| progress=info.get("progress"), | |
| ) | |
| ) | |
| except Exception as e: | |
| error_msg = f"Error getting info for KB '{name}': {e}" | |
| errors.append(error_msg) | |
| logger.warning(f"{error_msg}\n{traceback.format_exc()}") | |
| try: | |
| kb_dir = manager.base_dir / name | |
| if kb_dir.exists(): | |
| logger.debug(f"KB '{name}' directory exists, creating fallback info") | |
| result.append( | |
| KnowledgeBaseInfo( | |
| name=name, | |
| is_default=name == manager.get_default(), | |
| statistics={ | |
| "raw_documents": 0, | |
| "images": 0, | |
| "content_lists": 0, | |
| "rag_initialized": False, | |
| }, | |
| status="unknown", | |
| progress=None, | |
| ) | |
| ) | |
| except Exception as fallback_err: | |
| logger.error(f"Fallback also failed for KB '{name}': {fallback_err}") | |
| if errors and not result: | |
| error_detail = f"Failed to load knowledge bases. Errors: {'; '.join(errors)}" | |
| logger.error(error_detail) | |
| raise HTTPException(status_code=500, detail=error_detail) | |
| if errors: | |
| logger.warning( | |
| f"Some KBs had errors, returning {len(result)} results. Errors: {errors}" | |
| ) | |
| logger.debug(f"Returning {len(result)} knowledge bases") | |
| return result | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| error_msg = f"Error listing knowledge bases: {e}" | |
| logger.error(f"{error_msg}\n{traceback.format_exc()}") | |
| raise HTTPException(status_code=500, detail=f"Failed to list knowledge bases: {e!s}") | |
| async def get_knowledge_base_details(kb_name: str): | |
| """Get detailed info for a specific KB.""" | |
| try: | |
| manager = get_kb_manager() | |
| return manager.get_info(kb_name) | |
| except ValueError: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def delete_knowledge_base(kb_name: str): | |
| """Delete a knowledge base.""" | |
| try: | |
| manager = get_kb_manager() | |
| success = manager.delete_knowledge_base(kb_name, confirm=True) | |
| if not success: | |
| raise HTTPException(status_code=400, detail="Failed to delete knowledge base") | |
| logger.info(f"KB '{kb_name}' deleted") | |
| return {"message": f"Knowledge base '{kb_name}' deleted successfully"} | |
| except ValueError: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def stream_task_logs(task_id: str): | |
| """Stream task-specific logs for knowledge-base operations.""" | |
| manager = get_task_stream_manager() | |
| manager.ensure_task(task_id) | |
| return StreamingResponse( | |
| manager.stream(task_id), | |
| media_type="text/event-stream", | |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, | |
| ) | |
| async def upload_files( | |
| kb_name: str, | |
| background_tasks: BackgroundTasks, | |
| files: list[UploadFile] = File(...), | |
| rag_provider: str = Form(None), | |
| ): | |
| """Upload files to a knowledge base and process them in background.""" | |
| try: | |
| manager = get_kb_manager() | |
| kb_path = manager.get_knowledge_base_path(kb_name) | |
| raw_dir = kb_path / "raw" | |
| raw_dir.mkdir(parents=True, exist_ok=True) | |
| requested_provider = None | |
| if rag_provider is not None and str(rag_provider).strip(): | |
| requested_provider = _validate_registered_provider(rag_provider) | |
| kb_entry = _load_kb_entry_or_404(manager, kb_name) | |
| _assert_kb_writable_or_409(kb_name, kb_entry) | |
| kb_provider = _validate_registered_provider(kb_entry.get("rag_provider") or DEFAULT_PROVIDER) | |
| if requested_provider and requested_provider != kb_provider: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=( | |
| f"Requested provider '{requested_provider}' does not match KB provider '{kb_provider}'. " | |
| "Update KB config first." | |
| ), | |
| ) | |
| allowed_extensions = FileTypeRouter.get_extensions_for_provider(kb_provider) | |
| uploaded_files, uploaded_file_paths = _save_uploaded_files( | |
| files, raw_dir, allowed_extensions=allowed_extensions | |
| ) | |
| task_id = _build_unique_task_id("kb_upload", kb_name) | |
| get_task_stream_manager().ensure_task(task_id) | |
| logger.info(f"Uploading {len(uploaded_files)} files to KB '{kb_name}'") | |
| background_tasks.add_task( | |
| run_upload_processing_task, | |
| kb_name=kb_name, | |
| base_dir=str(_kb_base_dir), | |
| uploaded_file_paths=uploaded_file_paths, | |
| task_id=task_id, | |
| rag_provider=kb_provider, | |
| ) | |
| return { | |
| "message": f"Uploaded {len(uploaded_files)} files. Processing in background.", | |
| "files": uploaded_files, | |
| "task_id": task_id, | |
| } | |
| except HTTPException: | |
| raise | |
| except ValueError: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| except Exception as e: | |
| # Unexpected failure (Server error) | |
| formatted_error = format_exception_message(e) | |
| raise HTTPException(status_code=500, detail=formatted_error) from e | |
| async def create_knowledge_base( | |
| background_tasks: BackgroundTasks, | |
| name: str = Form(...), | |
| files: list[UploadFile] = File(...), | |
| rag_provider: str = Form(DEFAULT_PROVIDER), | |
| ): | |
| """Create a new knowledge base and initialize it with files.""" | |
| try: | |
| manager = get_kb_manager() | |
| if name in manager.list_knowledge_bases(): | |
| raise HTTPException(status_code=400, detail=f"Knowledge base '{name}' already exists") | |
| rag_provider = _validate_registered_provider(rag_provider) | |
| logger.info(f"Creating KB: {name}") | |
| task_id = _build_unique_task_id("kb_init", name) | |
| get_task_stream_manager().ensure_task(task_id) | |
| # Register KB to kb_config.json immediately with "initializing" status | |
| # This ensures the KB appears in the list right away | |
| manager.update_kb_status( | |
| name=name, | |
| status="initializing", | |
| progress={ | |
| "stage": "initializing", | |
| "message": "Initializing knowledge base...", | |
| "percent": 0, | |
| "current": 0, | |
| "total": len(files), | |
| "task_id": task_id, | |
| }, | |
| ) | |
| # Also store rag_provider in config (reload and update) | |
| manager.config = manager._load_config() | |
| if name in manager.config.get("knowledge_bases", {}): | |
| manager.config["knowledge_bases"][name]["rag_provider"] = rag_provider | |
| manager.config["knowledge_bases"][name]["needs_reindex"] = False | |
| manager._save_config() | |
| progress_tracker = ProgressTracker(name, _kb_base_dir) | |
| initializer = KnowledgeBaseInitializer( | |
| kb_name=name, | |
| base_dir=str(_kb_base_dir), | |
| progress_tracker=progress_tracker, | |
| rag_provider=rag_provider, | |
| ) | |
| initializer.create_directory_structure() | |
| progress_tracker.task_id = task_id | |
| manager = get_kb_manager() | |
| if name not in manager.list_knowledge_bases(): | |
| logger.warning(f"KB {name} not found in config, registering manually") | |
| initializer._register_to_config() | |
| allowed_extensions = FileTypeRouter.get_extensions_for_provider(rag_provider) | |
| uploaded_files, _ = _save_uploaded_files( | |
| files, initializer.raw_dir, allowed_extensions=allowed_extensions | |
| ) | |
| progress_tracker.update( | |
| ProgressStage.PROCESSING_DOCUMENTS, | |
| f"Saved {len(uploaded_files)} files, preparing to process...", | |
| current=0, | |
| total=len(uploaded_files), | |
| ) | |
| background_tasks.add_task(run_initialization_task, initializer, task_id) | |
| logger.success(f"KB '{name}' created, processing {len(uploaded_files)} files in background") | |
| return { | |
| "message": f"Knowledge base '{name}' created. Processing {len(uploaded_files)} files in background.", | |
| "name": name, | |
| "files": uploaded_files, | |
| "task_id": task_id, | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Failed to create KB: {e}") | |
| logger.debug(traceback.format_exc()) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_progress(kb_name: str): | |
| """Get initialization progress for a knowledge base""" | |
| try: | |
| progress_tracker = ProgressTracker(kb_name, _kb_base_dir) | |
| progress = progress_tracker.get_progress() | |
| if progress is None: | |
| return {"status": "not_started", "message": "Initialization not started"} | |
| return progress | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def clear_progress(kb_name: str): | |
| """Clear progress file for a knowledge base (useful for stuck states)""" | |
| try: | |
| progress_tracker = ProgressTracker(kb_name, _kb_base_dir) | |
| progress_tracker.clear() | |
| return {"status": "success", "message": f"Progress cleared for {kb_name}"} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def websocket_progress(websocket: WebSocket, kb_name: str): | |
| """WebSocket endpoint for real-time progress updates""" | |
| await websocket.accept() | |
| broadcaster = ProgressBroadcaster.get_instance() | |
| try: | |
| await broadcaster.connect(kb_name, websocket) | |
| progress_tracker = ProgressTracker(kb_name, _kb_base_dir) | |
| initial_progress = progress_tracker.get_progress() | |
| expected_task_id = websocket.query_params.get("task_id") | |
| kb_dir = _kb_base_dir / kb_name | |
| llamaindex_storage_dir = kb_dir / "llamaindex_storage" | |
| kb_is_ready = llamaindex_storage_dir.exists() and llamaindex_storage_dir.is_dir() | |
| # Fast path: no active task — send current state and close immediately | |
| # This prevents infinite polling loops for ready or legacy KBs. | |
| has_active_task = False | |
| if initial_progress: | |
| stage = initial_progress.get("stage") | |
| if stage not in ("completed", "error", None): | |
| ts = initial_progress.get("timestamp") | |
| if ts: | |
| try: | |
| age = (datetime.now() - datetime.fromisoformat(ts)).total_seconds() | |
| has_active_task = age < 120 | |
| except Exception: | |
| pass | |
| if not has_active_task and not expected_task_id: | |
| if kb_is_ready: | |
| await websocket.send_json({ | |
| "type": "progress", | |
| "data": { | |
| "stage": "completed", | |
| "message": "Knowledge base is ready.", | |
| "percent": 100, | |
| "current": 1, | |
| "total": 1, | |
| }, | |
| }) | |
| else: | |
| await websocket.send_json({ | |
| "type": "progress", | |
| "data": initial_progress or { | |
| "stage": "error", | |
| "message": "Knowledge base needs reindex or initialization.", | |
| }, | |
| }) | |
| return | |
| if initial_progress: | |
| stage = initial_progress.get("stage") | |
| timestamp = initial_progress.get("timestamp") | |
| progress_task_id = initial_progress.get("task_id") | |
| should_send = False | |
| if expected_task_id and progress_task_id and progress_task_id != expected_task_id: | |
| should_send = False | |
| elif stage == "error" or not kb_is_ready: | |
| should_send = True | |
| elif stage != "completed" and timestamp: | |
| try: | |
| progress_time = datetime.fromisoformat(timestamp) | |
| now = datetime.now() | |
| age_seconds = (now - progress_time).total_seconds() | |
| if age_seconds < 300: | |
| should_send = True | |
| except Exception: | |
| pass | |
| if should_send: | |
| await websocket.send_json({"type": "progress", "data": initial_progress}) | |
| last_progress = initial_progress | |
| last_timestamp = initial_progress.get("timestamp") if initial_progress else None | |
| while True: | |
| try: | |
| try: | |
| await asyncio.wait_for(websocket.receive_text(), timeout=1.0) | |
| except asyncio.TimeoutError: | |
| current_progress = progress_tracker.get_progress() | |
| if current_progress: | |
| progress_task_id = current_progress.get("task_id") | |
| if expected_task_id and progress_task_id and progress_task_id != expected_task_id: | |
| continue | |
| current_timestamp = current_progress.get("timestamp") | |
| if current_timestamp != last_timestamp: | |
| await websocket.send_json( | |
| {"type": "progress", "data": current_progress} | |
| ) | |
| last_progress = current_progress | |
| last_timestamp = current_timestamp | |
| if current_progress.get("stage") in ["completed", "error"]: | |
| await asyncio.sleep(3) | |
| break | |
| continue | |
| except WebSocketDisconnect: | |
| break | |
| except Exception: | |
| break | |
| except Exception as e: | |
| logger.debug(f"Progress WS error: {e}") | |
| try: | |
| await websocket.send_json({"type": "error", "message": str(e)}) | |
| except Exception: | |
| pass | |
| finally: | |
| await broadcaster.disconnect(kb_name, websocket) | |
| try: | |
| await websocket.close() | |
| except Exception: | |
| pass | |
| async def link_folder(kb_name: str, request: LinkFolderRequest): | |
| """ | |
| Link a local folder to a knowledge base. | |
| This allows syncing documents from a local folder (which can be | |
| synced with SharePoint, Google Drive, OneLake, etc.) to the KB. | |
| The folder path supports: | |
| - Absolute paths: /Users/name/Documents or C:\\Users\\name\\Documents | |
| - Home directory: ~/Documents | |
| - Relative paths (resolved from server working directory) | |
| """ | |
| try: | |
| manager = get_kb_manager() | |
| folder_info = manager.link_folder(kb_name, request.folder_path) | |
| logger.info(f"Linked folder '{request.folder_path}' to KB '{kb_name}'") | |
| return LinkedFolderInfo(**folder_info) | |
| except ValueError as e: | |
| error_msg = str(e) | |
| if "not found" in error_msg.lower(): | |
| raise HTTPException(status_code=404, detail=error_msg) | |
| raise HTTPException(status_code=400, detail=error_msg) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_linked_folders(kb_name: str): | |
| """Get list of linked folders for a knowledge base.""" | |
| try: | |
| manager = get_kb_manager() | |
| folders = manager.get_linked_folders(kb_name) | |
| return [LinkedFolderInfo(**f) for f in folders] | |
| except ValueError: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def unlink_folder(kb_name: str, folder_id: str): | |
| """Unlink a folder from a knowledge base.""" | |
| try: | |
| manager = get_kb_manager() | |
| success = manager.unlink_folder(kb_name, folder_id) | |
| if not success: | |
| raise HTTPException(status_code=404, detail=f"Folder '{folder_id}' not found") | |
| logger.info(f"Unlinked folder '{folder_id}' from KB '{kb_name}'") | |
| return {"message": "Folder unlinked successfully", "folder_id": folder_id} | |
| except ValueError: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def sync_folder(kb_name: str, folder_id: str, background_tasks: BackgroundTasks): | |
| """ | |
| Sync files from a linked folder to the knowledge base. | |
| This scans the linked folder for supported documents and processes | |
| any new files that haven't been added yet. | |
| """ | |
| try: | |
| manager = get_kb_manager() | |
| kb_entry = _load_kb_entry_or_404(manager, kb_name) | |
| _assert_kb_writable_or_409(kb_name, kb_entry) | |
| kb_provider = _validate_registered_provider(kb_entry.get("rag_provider") or DEFAULT_PROVIDER) | |
| # Get linked folders and find the one with matching ID | |
| folders = manager.get_linked_folders(kb_name) | |
| folder_info = next((f for f in folders if f["id"] == folder_id), None) | |
| if not folder_info: | |
| raise HTTPException(status_code=404, detail=f"Linked folder '{folder_id}' not found") | |
| folder_path = folder_info["path"] | |
| # Check for changes (new or modified files) | |
| changes = manager.detect_folder_changes(kb_name, folder_id) | |
| files_to_process = changes["new_files"] + changes["modified_files"] | |
| if not files_to_process: | |
| return {"message": "No new or modified files to sync", "files": [], "file_count": 0} | |
| logger.info( | |
| f"Syncing {len(files_to_process)} files from folder '{folder_path}' to KB '{kb_name}'" | |
| ) | |
| task_id = _build_unique_task_id("kb_upload", f"{kb_name}_folder_{folder_id}") | |
| get_task_stream_manager().ensure_task(task_id) | |
| # NOTE: We DO NOT update sync state here anymore. | |
| # It is updated in run_upload_processing_task only after successful processing. | |
| # This prevents marking files as synced if processing fails (race condition fix). | |
| # Add background task to process files | |
| background_tasks.add_task( | |
| run_upload_processing_task, | |
| kb_name=kb_name, | |
| base_dir=str(_kb_base_dir), | |
| uploaded_file_paths=files_to_process, | |
| task_id=task_id, | |
| rag_provider=kb_provider, | |
| folder_id=folder_id, # Pass folder_id to update state on success | |
| ) | |
| return { | |
| "message": f"Syncing {len(files_to_process)} files from linked folder", | |
| "folder_path": folder_path, | |
| "new_files": changes["new_count"], | |
| "modified_files": changes["modified_count"], | |
| "file_count": len(files_to_process), | |
| "task_id": task_id, | |
| } | |
| except HTTPException: | |
| raise | |
| except ValueError: | |
| raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found") | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |