router / deeptutor /api /routers /knowledge.py
Piyush1225's picture
push
5df8a73
"""
Knowledge Base API Router
=========================
Handles knowledge base CRUD operations, file uploads, and initialization.
"""
import asyncio
from datetime import datetime
import os
from pathlib import Path
import traceback
from uuid import uuid4
from fastapi import (
APIRouter,
BackgroundTasks,
File,
Form,
HTTPException,
UploadFile,
WebSocket,
WebSocketDisconnect,
)
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from deeptutor.api.utils.progress_broadcaster import ProgressBroadcaster
from deeptutor.api.utils.task_id_manager import TaskIDManager
from deeptutor.api.utils.task_log_stream import capture_task_logs, get_task_stream_manager
from deeptutor.knowledge.add_documents import DocumentAdder
from deeptutor.knowledge.initializer import KnowledgeBaseInitializer
from deeptutor.knowledge.manager import KnowledgeBaseManager
from deeptutor.knowledge.progress_tracker import ProgressStage, ProgressTracker
from deeptutor.services.rag.components.routing import FileTypeRouter
from deeptutor.services.rag.factory import DEFAULT_PROVIDER, has_pipeline, normalize_provider_name
from deeptutor.utils.document_validator import DocumentValidator
from deeptutor.utils.error_utils import format_exception_message
from deeptutor.logging import get_logger
from deeptutor.services.config import PROJECT_ROOT, load_config_with_main
# Initialize logger with config
config = load_config_with_main("main.yaml", PROJECT_ROOT)
log_dir = config.get("paths", {}).get("user_log_dir") or config.get("logging", {}).get("log_dir")
logger = get_logger("Knowledge", level="INFO", log_dir=log_dir)
router = APIRouter()
# Constants for byte conversions
BYTES_PER_GB = 1024**3
BYTES_PER_MB = 1024**2
def format_bytes_human_readable(size_bytes: int) -> str:
"""Format bytes into human-readable string (GB, MB, or bytes)."""
if size_bytes >= BYTES_PER_GB:
return f"{size_bytes / BYTES_PER_GB:.1f} GB"
elif size_bytes >= BYTES_PER_MB:
return f"{size_bytes / BYTES_PER_MB:.1f} MB"
else:
return f"{size_bytes} bytes"
_kb_base_dir = PROJECT_ROOT / "data" / "knowledge_bases"
# Lazy initialization
kb_manager = None
def get_kb_manager():
"""Get KnowledgeBaseManager instance (lazy init)"""
global kb_manager
if kb_manager is None:
kb_manager = KnowledgeBaseManager(base_dir=str(_kb_base_dir))
return kb_manager
class KnowledgeBaseInfo(BaseModel):
name: str
is_default: bool
statistics: dict
status: str | None = None
progress: dict | None = None
class LinkFolderRequest(BaseModel):
"""Request model for linking a local folder to a KB."""
folder_path: str
class LinkedFolderInfo(BaseModel):
"""Response model for linked folder information."""
id: str
path: str
added_at: str
file_count: int
def _build_unique_task_id(task_type: str, task_key_prefix: str) -> str:
task_manager = TaskIDManager.get_instance()
task_key = f"{task_key_prefix}_{datetime.now().isoformat()}_{uuid4().hex[:8]}"
return task_manager.generate_task_id(task_type, task_key)
def _save_uploaded_files(
files: list[UploadFile],
target_dir: Path,
allowed_extensions: set[str] | None = None,
) -> tuple[list[str], list[str]]:
uploaded_files: list[str] = []
uploaded_file_paths: list[str] = []
for file in files:
file_path = None
original_filename = file.filename or "upload"
try:
sanitized_filename = DocumentValidator.validate_upload_safety(
original_filename, None, allowed_extensions=allowed_extensions
)
file.filename = sanitized_filename
file_path = target_dir / sanitized_filename
max_size = DocumentValidator.MAX_FILE_SIZE
written_bytes = 0
with open(file_path, "wb") as buffer:
for chunk in iter(lambda: file.file.read(8192), b""):
written_bytes += len(chunk)
if written_bytes > max_size:
size_str = format_bytes_human_readable(max_size)
raise HTTPException(
status_code=400,
detail=f"File '{sanitized_filename}' exceeds maximum size limit of {size_str}",
)
buffer.write(chunk)
DocumentValidator.validate_upload_safety(
sanitized_filename, written_bytes, allowed_extensions=allowed_extensions
)
uploaded_files.append(sanitized_filename)
uploaded_file_paths.append(str(file_path))
except Exception as e:
if file_path and file_path.exists():
try:
os.unlink(file_path)
except OSError:
pass
error_message = (
f"Validation failed for file '{original_filename}': {format_exception_message(e)}"
)
logger.error(error_message, exc_info=True)
raise HTTPException(status_code=400, detail=error_message) from e
return uploaded_files, uploaded_file_paths
def _task_log(task_id: str, message: str, level: str = "info") -> None:
manager = get_task_stream_manager()
manager.ensure_task(task_id)
manager.emit_log(task_id, message)
log_method = getattr(logger, level, None)
if callable(log_method):
log_method(f"[{task_id}] {message}")
else:
logger.info(f"[{task_id}] {message}")
def _validate_registered_provider(raw_provider: str | None) -> str:
candidate = (raw_provider or DEFAULT_PROVIDER).strip().lower()
if not candidate:
candidate = DEFAULT_PROVIDER
if not has_pipeline(candidate):
from deeptutor.services.rag.service import RAGService
available_providers = [item["id"] for item in RAGService.list_providers()]
raise HTTPException(
status_code=400,
detail=(
f"Unsupported RAG provider '{candidate}'. "
f"Available providers: {available_providers}"
),
)
return normalize_provider_name(candidate)
def _load_kb_entry_or_404(manager: KnowledgeBaseManager, kb_name: str) -> dict:
manager.config = manager._load_config()
kb_entry = manager.config.get("knowledge_bases", {}).get(kb_name)
if kb_entry is None:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
return kb_entry
def _assert_kb_writable_or_409(kb_name: str, kb_entry: dict) -> None:
if bool(kb_entry.get("needs_reindex", False)):
raise HTTPException(
status_code=409,
detail=(
f"Knowledge base '{kb_name}' uses legacy index format and needs reindex "
"before accepting incremental uploads."
),
)
async def run_initialization_task(initializer: KnowledgeBaseInitializer, task_id: str):
"""Background task for knowledge base initialization"""
task_manager = TaskIDManager.get_instance()
task_stream_manager = get_task_stream_manager()
task_stream_manager.ensure_task(task_id)
with capture_task_logs(task_id):
try:
if not initializer.progress_tracker:
initializer.progress_tracker = ProgressTracker(
initializer.kb_name, initializer.base_dir
)
initializer.progress_tracker.task_id = task_id
_task_log(task_id, f"Initializing knowledge base '{initializer.kb_name}'")
await initializer.process_documents()
_task_log(task_id, "Document processing complete")
initializer.extract_numbered_items()
_task_log(task_id, "Finalizing initialization")
initializer.progress_tracker.update(
ProgressStage.COMPLETED, "Knowledge base initialization complete!", current=1, total=1
)
manager = get_kb_manager()
manager.update_kb_status(
name=initializer.kb_name,
status="ready",
progress={
"stage": "completed",
"message": "Knowledge base initialization complete!",
"percent": 100,
"current": 1,
"total": 1,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
},
)
_task_log(task_id, f"Knowledge base '{initializer.kb_name}' initialized", level="success")
task_manager.update_task_status(task_id, "completed")
task_stream_manager.emit_complete(
task_id, f"Knowledge base '{initializer.kb_name}' initialization complete"
)
except Exception as e:
error_msg = str(e)
_task_log(task_id, f"Initialization failed: {error_msg}", level="error")
task_manager.update_task_status(task_id, "error", error=error_msg)
manager = get_kb_manager()
manager.update_kb_status(
name=initializer.kb_name,
status="error",
progress={
"stage": "error",
"message": f"Initialization failed: {error_msg}",
"percent": 0,
"error": error_msg,
"task_id": task_id,
"timestamp": datetime.now().isoformat(),
},
)
if initializer.progress_tracker:
initializer.progress_tracker.update(
ProgressStage.ERROR, f"Initialization failed: {error_msg}", error=error_msg
)
task_stream_manager.emit_failed(task_id, error_msg)
async def run_upload_processing_task(
kb_name: str,
base_dir: str,
uploaded_file_paths: list[str],
task_id: str,
rag_provider: str = None,
folder_id: str = None,
):
"""Background task for processing uploaded files.
Args:
kb_name: Knowledge base name
base_dir: Base directory for knowledge bases
uploaded_file_paths: List of file paths to process
rag_provider: RAG provider (ignored - we use the one from KB metadata)
folder_id: Optional folder ID for sync state update
"""
task_manager = TaskIDManager.get_instance()
task_stream_manager = get_task_stream_manager()
task_stream_manager.ensure_task(task_id)
progress_tracker = ProgressTracker(kb_name, Path(base_dir))
progress_tracker.task_id = task_id
with capture_task_logs(task_id):
try:
_task_log(task_id, f"Processing {len(uploaded_file_paths)} file(s) for KB '{kb_name}'")
progress_tracker.update(
ProgressStage.PROCESSING_DOCUMENTS,
f"Processing {len(uploaded_file_paths)} files...",
current=0,
total=len(uploaded_file_paths),
)
adder = DocumentAdder(
kb_name=kb_name,
base_dir=base_dir,
progress_tracker=progress_tracker,
rag_provider=rag_provider,
)
staged_files = adder.add_documents(uploaded_file_paths, allow_duplicates=False)
_task_log(task_id, f"Staged {len(staged_files)} new file(s)")
if not staged_files:
_task_log(task_id, "No new files to process (all duplicates or invalid)")
progress_tracker.update(
ProgressStage.COMPLETED,
"No new files to process (all duplicates or invalid)",
current=0,
total=0,
)
task_manager.update_task_status(task_id, "completed")
task_stream_manager.emit_complete(
task_id, "No new files to process (all duplicates or invalid)"
)
return
processed_files = await adder.process_new_documents(staged_files)
_task_log(task_id, f"Indexed {len(processed_files)} file(s)")
if processed_files:
progress_tracker.update(
ProgressStage.EXTRACTING_ITEMS,
"Extracting numbered items...",
current=0,
total=len(processed_files),
)
adder.extract_numbered_items_for_new_docs(processed_files, batch_size=20)
adder.update_metadata(len(processed_files) if processed_files else 0)
if folder_id and processed_files:
try:
manager = get_kb_manager()
manager.update_folder_sync_state(
kb_name, folder_id, [str(f) for f in processed_files]
)
_task_log(task_id, f"Updated folder sync state: {folder_id}")
except Exception as sync_err:
_task_log(task_id, f"Folder sync state update failed: {sync_err}", level="warning")
num_processed = len(processed_files) if processed_files else 0
progress_tracker.update(
ProgressStage.COMPLETED,
f"Successfully processed {num_processed} files!",
current=num_processed,
total=num_processed,
)
_task_log(task_id, f"Processed {num_processed} file(s) for '{kb_name}'", level="success")
task_manager.update_task_status(task_id, "completed")
task_stream_manager.emit_complete(
task_id, f"Successfully processed {num_processed} files for '{kb_name}'"
)
except Exception as e:
error_msg = f"Upload processing failed (KB '{kb_name}'): {e}"
_task_log(task_id, error_msg, level="error")
task_manager.update_task_status(task_id, "error", error=error_msg)
progress_tracker.update(
ProgressStage.ERROR, f"Processing failed: {error_msg}", error=error_msg
)
task_stream_manager.emit_failed(task_id, error_msg)
@router.get("/health")
async def health_check():
"""Health check endpoint"""
try:
manager = get_kb_manager()
config_exists = manager.config_file.exists()
kb_count = len(manager.list_knowledge_bases())
return {
"status": "ok",
"config_file": str(manager.config_file),
"config_exists": config_exists,
"base_dir": str(manager.base_dir),
"base_dir_exists": manager.base_dir.exists(),
"knowledge_bases_count": kb_count,
}
except Exception as e:
return {"status": "error", "error": str(e), "traceback": traceback.format_exc()}
@router.get("/rag-providers")
async def get_rag_providers():
"""Get list of available RAG providers."""
try:
from deeptutor.services.rag.service import RAGService
providers = RAGService.list_providers()
return {"providers": providers}
except Exception as e:
logger.error(f"Error getting RAG providers: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/configs")
async def get_all_kb_configs():
"""Get all knowledge base configurations from centralized config file."""
try:
from deeptutor.services.config import get_kb_config_service
service = get_kb_config_service()
return service.get_all_configs()
except Exception as e:
logger.error(f"Error getting KB configs: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{kb_name}/config")
async def get_kb_config(kb_name: str):
"""Get configuration for a specific knowledge base."""
try:
from deeptutor.services.config import get_kb_config_service
service = get_kb_config_service()
config = service.get_kb_config(kb_name)
return {"kb_name": kb_name, "config": config}
except Exception as e:
logger.error(f"Error getting config for KB '{kb_name}': {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.put("/{kb_name}/config")
async def update_kb_config(kb_name: str, config: dict):
"""Update configuration for a specific knowledge base."""
try:
from deeptutor.services.config import get_kb_config_service
if "rag_provider" in config:
config["rag_provider"] = _validate_registered_provider(config.get("rag_provider"))
service = get_kb_config_service()
service.set_kb_config(kb_name, config)
return {"status": "success", "kb_name": kb_name, "config": service.get_kb_config(kb_name)}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error updating config for KB '{kb_name}': {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/configs/sync")
async def sync_configs_from_metadata():
"""Sync all KB configurations from their metadata.json files to centralized config."""
try:
from deeptutor.services.config import get_kb_config_service
service = get_kb_config_service()
service.sync_all_from_metadata(_kb_base_dir)
return {"status": "success", "message": "Configurations synced from metadata files"}
except Exception as e:
logger.error(f"Error syncing configs: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/default")
async def get_default_kb():
"""Get the default knowledge base."""
try:
manager = get_kb_manager()
default_kb = manager.get_default()
return {"default_kb": default_kb}
except Exception as e:
logger.error(f"Error getting default KB: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.put("/default/{kb_name}")
async def set_default_kb(kb_name: str):
"""Set the default knowledge base."""
try:
manager = get_kb_manager()
# Verify KB exists
if kb_name not in manager.list_knowledge_bases():
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
manager.set_default(kb_name)
return {"status": "success", "default_kb": kb_name}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error setting default KB: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/list", response_model=list[KnowledgeBaseInfo])
async def list_knowledge_bases():
"""List all available knowledge bases with their details."""
try:
manager = get_kb_manager()
kb_names = manager.list_knowledge_bases()
logger.debug(f"Found {len(kb_names)} knowledge bases: {kb_names}")
if not kb_names:
logger.debug("No knowledge bases found, returning empty list")
return []
result = []
errors = []
for name in kb_names:
try:
info = manager.get_info(name)
logger.debug(f"Successfully got info for KB '{name}': {info.get('statistics', {})}")
result.append(
KnowledgeBaseInfo(
name=info["name"],
is_default=info["is_default"],
statistics=info.get("statistics", {}),
status=info.get("status"),
progress=info.get("progress"),
)
)
except Exception as e:
error_msg = f"Error getting info for KB '{name}': {e}"
errors.append(error_msg)
logger.warning(f"{error_msg}\n{traceback.format_exc()}")
try:
kb_dir = manager.base_dir / name
if kb_dir.exists():
logger.debug(f"KB '{name}' directory exists, creating fallback info")
result.append(
KnowledgeBaseInfo(
name=name,
is_default=name == manager.get_default(),
statistics={
"raw_documents": 0,
"images": 0,
"content_lists": 0,
"rag_initialized": False,
},
status="unknown",
progress=None,
)
)
except Exception as fallback_err:
logger.error(f"Fallback also failed for KB '{name}': {fallback_err}")
if errors and not result:
error_detail = f"Failed to load knowledge bases. Errors: {'; '.join(errors)}"
logger.error(error_detail)
raise HTTPException(status_code=500, detail=error_detail)
if errors:
logger.warning(
f"Some KBs had errors, returning {len(result)} results. Errors: {errors}"
)
logger.debug(f"Returning {len(result)} knowledge bases")
return result
except HTTPException:
raise
except Exception as e:
error_msg = f"Error listing knowledge bases: {e}"
logger.error(f"{error_msg}\n{traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Failed to list knowledge bases: {e!s}")
@router.get("/{kb_name}")
async def get_knowledge_base_details(kb_name: str):
"""Get detailed info for a specific KB."""
try:
manager = get_kb_manager()
return manager.get_info(kb_name)
except ValueError:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{kb_name}")
async def delete_knowledge_base(kb_name: str):
"""Delete a knowledge base."""
try:
manager = get_kb_manager()
success = manager.delete_knowledge_base(kb_name, confirm=True)
if not success:
raise HTTPException(status_code=400, detail="Failed to delete knowledge base")
logger.info(f"KB '{kb_name}' deleted")
return {"message": f"Knowledge base '{kb_name}' deleted successfully"}
except ValueError:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/tasks/{task_id}/stream")
async def stream_task_logs(task_id: str):
"""Stream task-specific logs for knowledge-base operations."""
manager = get_task_stream_manager()
manager.ensure_task(task_id)
return StreamingResponse(
manager.stream(task_id),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@router.post("/{kb_name}/upload")
async def upload_files(
kb_name: str,
background_tasks: BackgroundTasks,
files: list[UploadFile] = File(...),
rag_provider: str = Form(None),
):
"""Upload files to a knowledge base and process them in background."""
try:
manager = get_kb_manager()
kb_path = manager.get_knowledge_base_path(kb_name)
raw_dir = kb_path / "raw"
raw_dir.mkdir(parents=True, exist_ok=True)
requested_provider = None
if rag_provider is not None and str(rag_provider).strip():
requested_provider = _validate_registered_provider(rag_provider)
kb_entry = _load_kb_entry_or_404(manager, kb_name)
_assert_kb_writable_or_409(kb_name, kb_entry)
kb_provider = _validate_registered_provider(kb_entry.get("rag_provider") or DEFAULT_PROVIDER)
if requested_provider and requested_provider != kb_provider:
raise HTTPException(
status_code=400,
detail=(
f"Requested provider '{requested_provider}' does not match KB provider '{kb_provider}'. "
"Update KB config first."
),
)
allowed_extensions = FileTypeRouter.get_extensions_for_provider(kb_provider)
uploaded_files, uploaded_file_paths = _save_uploaded_files(
files, raw_dir, allowed_extensions=allowed_extensions
)
task_id = _build_unique_task_id("kb_upload", kb_name)
get_task_stream_manager().ensure_task(task_id)
logger.info(f"Uploading {len(uploaded_files)} files to KB '{kb_name}'")
background_tasks.add_task(
run_upload_processing_task,
kb_name=kb_name,
base_dir=str(_kb_base_dir),
uploaded_file_paths=uploaded_file_paths,
task_id=task_id,
rag_provider=kb_provider,
)
return {
"message": f"Uploaded {len(uploaded_files)} files. Processing in background.",
"files": uploaded_files,
"task_id": task_id,
}
except HTTPException:
raise
except ValueError:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
except Exception as e:
# Unexpected failure (Server error)
formatted_error = format_exception_message(e)
raise HTTPException(status_code=500, detail=formatted_error) from e
@router.post("/create")
async def create_knowledge_base(
background_tasks: BackgroundTasks,
name: str = Form(...),
files: list[UploadFile] = File(...),
rag_provider: str = Form(DEFAULT_PROVIDER),
):
"""Create a new knowledge base and initialize it with files."""
try:
manager = get_kb_manager()
if name in manager.list_knowledge_bases():
raise HTTPException(status_code=400, detail=f"Knowledge base '{name}' already exists")
rag_provider = _validate_registered_provider(rag_provider)
logger.info(f"Creating KB: {name}")
task_id = _build_unique_task_id("kb_init", name)
get_task_stream_manager().ensure_task(task_id)
# Register KB to kb_config.json immediately with "initializing" status
# This ensures the KB appears in the list right away
manager.update_kb_status(
name=name,
status="initializing",
progress={
"stage": "initializing",
"message": "Initializing knowledge base...",
"percent": 0,
"current": 0,
"total": len(files),
"task_id": task_id,
},
)
# Also store rag_provider in config (reload and update)
manager.config = manager._load_config()
if name in manager.config.get("knowledge_bases", {}):
manager.config["knowledge_bases"][name]["rag_provider"] = rag_provider
manager.config["knowledge_bases"][name]["needs_reindex"] = False
manager._save_config()
progress_tracker = ProgressTracker(name, _kb_base_dir)
initializer = KnowledgeBaseInitializer(
kb_name=name,
base_dir=str(_kb_base_dir),
progress_tracker=progress_tracker,
rag_provider=rag_provider,
)
initializer.create_directory_structure()
progress_tracker.task_id = task_id
manager = get_kb_manager()
if name not in manager.list_knowledge_bases():
logger.warning(f"KB {name} not found in config, registering manually")
initializer._register_to_config()
allowed_extensions = FileTypeRouter.get_extensions_for_provider(rag_provider)
uploaded_files, _ = _save_uploaded_files(
files, initializer.raw_dir, allowed_extensions=allowed_extensions
)
progress_tracker.update(
ProgressStage.PROCESSING_DOCUMENTS,
f"Saved {len(uploaded_files)} files, preparing to process...",
current=0,
total=len(uploaded_files),
)
background_tasks.add_task(run_initialization_task, initializer, task_id)
logger.success(f"KB '{name}' created, processing {len(uploaded_files)} files in background")
return {
"message": f"Knowledge base '{name}' created. Processing {len(uploaded_files)} files in background.",
"name": name,
"files": uploaded_files,
"task_id": task_id,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to create KB: {e}")
logger.debug(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{kb_name}/progress")
async def get_progress(kb_name: str):
"""Get initialization progress for a knowledge base"""
try:
progress_tracker = ProgressTracker(kb_name, _kb_base_dir)
progress = progress_tracker.get_progress()
if progress is None:
return {"status": "not_started", "message": "Initialization not started"}
return progress
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{kb_name}/progress/clear")
async def clear_progress(kb_name: str):
"""Clear progress file for a knowledge base (useful for stuck states)"""
try:
progress_tracker = ProgressTracker(kb_name, _kb_base_dir)
progress_tracker.clear()
return {"status": "success", "message": f"Progress cleared for {kb_name}"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.websocket("/{kb_name}/progress/ws")
async def websocket_progress(websocket: WebSocket, kb_name: str):
"""WebSocket endpoint for real-time progress updates"""
await websocket.accept()
broadcaster = ProgressBroadcaster.get_instance()
try:
await broadcaster.connect(kb_name, websocket)
progress_tracker = ProgressTracker(kb_name, _kb_base_dir)
initial_progress = progress_tracker.get_progress()
expected_task_id = websocket.query_params.get("task_id")
kb_dir = _kb_base_dir / kb_name
llamaindex_storage_dir = kb_dir / "llamaindex_storage"
kb_is_ready = llamaindex_storage_dir.exists() and llamaindex_storage_dir.is_dir()
# Fast path: no active task — send current state and close immediately
# This prevents infinite polling loops for ready or legacy KBs.
has_active_task = False
if initial_progress:
stage = initial_progress.get("stage")
if stage not in ("completed", "error", None):
ts = initial_progress.get("timestamp")
if ts:
try:
age = (datetime.now() - datetime.fromisoformat(ts)).total_seconds()
has_active_task = age < 120
except Exception:
pass
if not has_active_task and not expected_task_id:
if kb_is_ready:
await websocket.send_json({
"type": "progress",
"data": {
"stage": "completed",
"message": "Knowledge base is ready.",
"percent": 100,
"current": 1,
"total": 1,
},
})
else:
await websocket.send_json({
"type": "progress",
"data": initial_progress or {
"stage": "error",
"message": "Knowledge base needs reindex or initialization.",
},
})
return
if initial_progress:
stage = initial_progress.get("stage")
timestamp = initial_progress.get("timestamp")
progress_task_id = initial_progress.get("task_id")
should_send = False
if expected_task_id and progress_task_id and progress_task_id != expected_task_id:
should_send = False
elif stage == "error" or not kb_is_ready:
should_send = True
elif stage != "completed" and timestamp:
try:
progress_time = datetime.fromisoformat(timestamp)
now = datetime.now()
age_seconds = (now - progress_time).total_seconds()
if age_seconds < 300:
should_send = True
except Exception:
pass
if should_send:
await websocket.send_json({"type": "progress", "data": initial_progress})
last_progress = initial_progress
last_timestamp = initial_progress.get("timestamp") if initial_progress else None
while True:
try:
try:
await asyncio.wait_for(websocket.receive_text(), timeout=1.0)
except asyncio.TimeoutError:
current_progress = progress_tracker.get_progress()
if current_progress:
progress_task_id = current_progress.get("task_id")
if expected_task_id and progress_task_id and progress_task_id != expected_task_id:
continue
current_timestamp = current_progress.get("timestamp")
if current_timestamp != last_timestamp:
await websocket.send_json(
{"type": "progress", "data": current_progress}
)
last_progress = current_progress
last_timestamp = current_timestamp
if current_progress.get("stage") in ["completed", "error"]:
await asyncio.sleep(3)
break
continue
except WebSocketDisconnect:
break
except Exception:
break
except Exception as e:
logger.debug(f"Progress WS error: {e}")
try:
await websocket.send_json({"type": "error", "message": str(e)})
except Exception:
pass
finally:
await broadcaster.disconnect(kb_name, websocket)
try:
await websocket.close()
except Exception:
pass
@router.post("/{kb_name}/link-folder", response_model=LinkedFolderInfo)
async def link_folder(kb_name: str, request: LinkFolderRequest):
"""
Link a local folder to a knowledge base.
This allows syncing documents from a local folder (which can be
synced with SharePoint, Google Drive, OneLake, etc.) to the KB.
The folder path supports:
- Absolute paths: /Users/name/Documents or C:\\Users\\name\\Documents
- Home directory: ~/Documents
- Relative paths (resolved from server working directory)
"""
try:
manager = get_kb_manager()
folder_info = manager.link_folder(kb_name, request.folder_path)
logger.info(f"Linked folder '{request.folder_path}' to KB '{kb_name}'")
return LinkedFolderInfo(**folder_info)
except ValueError as e:
error_msg = str(e)
if "not found" in error_msg.lower():
raise HTTPException(status_code=404, detail=error_msg)
raise HTTPException(status_code=400, detail=error_msg)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{kb_name}/linked-folders", response_model=list[LinkedFolderInfo])
async def get_linked_folders(kb_name: str):
"""Get list of linked folders for a knowledge base."""
try:
manager = get_kb_manager()
folders = manager.get_linked_folders(kb_name)
return [LinkedFolderInfo(**f) for f in folders]
except ValueError:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{kb_name}/linked-folders/{folder_id}")
async def unlink_folder(kb_name: str, folder_id: str):
"""Unlink a folder from a knowledge base."""
try:
manager = get_kb_manager()
success = manager.unlink_folder(kb_name, folder_id)
if not success:
raise HTTPException(status_code=404, detail=f"Folder '{folder_id}' not found")
logger.info(f"Unlinked folder '{folder_id}' from KB '{kb_name}'")
return {"message": "Folder unlinked successfully", "folder_id": folder_id}
except ValueError:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{kb_name}/sync-folder/{folder_id}")
async def sync_folder(kb_name: str, folder_id: str, background_tasks: BackgroundTasks):
"""
Sync files from a linked folder to the knowledge base.
This scans the linked folder for supported documents and processes
any new files that haven't been added yet.
"""
try:
manager = get_kb_manager()
kb_entry = _load_kb_entry_or_404(manager, kb_name)
_assert_kb_writable_or_409(kb_name, kb_entry)
kb_provider = _validate_registered_provider(kb_entry.get("rag_provider") or DEFAULT_PROVIDER)
# Get linked folders and find the one with matching ID
folders = manager.get_linked_folders(kb_name)
folder_info = next((f for f in folders if f["id"] == folder_id), None)
if not folder_info:
raise HTTPException(status_code=404, detail=f"Linked folder '{folder_id}' not found")
folder_path = folder_info["path"]
# Check for changes (new or modified files)
changes = manager.detect_folder_changes(kb_name, folder_id)
files_to_process = changes["new_files"] + changes["modified_files"]
if not files_to_process:
return {"message": "No new or modified files to sync", "files": [], "file_count": 0}
logger.info(
f"Syncing {len(files_to_process)} files from folder '{folder_path}' to KB '{kb_name}'"
)
task_id = _build_unique_task_id("kb_upload", f"{kb_name}_folder_{folder_id}")
get_task_stream_manager().ensure_task(task_id)
# NOTE: We DO NOT update sync state here anymore.
# It is updated in run_upload_processing_task only after successful processing.
# This prevents marking files as synced if processing fails (race condition fix).
# Add background task to process files
background_tasks.add_task(
run_upload_processing_task,
kb_name=kb_name,
base_dir=str(_kb_base_dir),
uploaded_file_paths=files_to_process,
task_id=task_id,
rag_provider=kb_provider,
folder_id=folder_id, # Pass folder_id to update state on success
)
return {
"message": f"Syncing {len(files_to_process)} files from linked folder",
"folder_path": folder_path,
"new_files": changes["new_count"],
"modified_files": changes["modified_count"],
"file_count": len(files_to_process),
"task_id": task_id,
}
except HTTPException:
raise
except ValueError:
raise HTTPException(status_code=404, detail=f"Knowledge base '{kb_name}' not found")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))