import json
from typing import AsyncGenerator

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from api.models import AskRequest, AskResponse, DirectAskRequest, SearchRequest, SearchResponse
from open_notebook.domain.models import Model, model_manager
from open_notebook.domain.notebook import text_search, vector_search
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
from open_notebook.graphs.ask import graph as ask_graph

router = APIRouter()


@router.post("/search", response_model=SearchResponse)
async def search_knowledge_base(search_request: SearchRequest):
    """Search the knowledge base using text or vector search."""
    try:
        if search_request.type == "vector":
            # Check if embedding model is available for vector search
            if not await model_manager.get_embedding_model():
                raise HTTPException(
                    status_code=400,
                    detail="Vector search requires an embedding model. Please configure one in the Models section.",
                )
            results = await vector_search(
                keyword=search_request.query,
                results=search_request.limit,
                source=search_request.search_sources,
                note=search_request.search_notes,
                minimum_score=search_request.minimum_score,
            )
        else:
            # Text search
            results = await text_search(
                keyword=search_request.query,
                results=search_request.limit,
                source=search_request.search_sources,
                note=search_request.search_notes,
            )

        return SearchResponse(
            results=results or [],
            total_count=len(results) if results else 0,
            search_type=search_request.type,
        )
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except DatabaseOperationError as e:
        logger.error(f"Database error during search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error during search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")


async def stream_ask_response(
    question: str, strategy_model: Model, answer_model: Model, final_answer_model: Model
) -> AsyncGenerator[str, None]:
    """Stream the ask response as Server-Sent Events."""
    try:
        final_answer = None
        async for chunk in ask_graph.astream(
            input=dict(question=question),  # type: ignore[arg-type]
            config=dict(
                configurable=dict(
                    strategy_model=strategy_model.id,
                    answer_model=answer_model.id,
                    final_answer_model=final_answer_model.id,
                )
            ),
            stream_mode="updates",
        ):
            if "agent" in chunk:
                strategy_data = {
                    "type": "strategy",
                    "reasoning": chunk["agent"]["strategy"].reasoning,
                    "searches": [
                        {"term": search.term, "instructions": search.instructions}
                        for search in chunk["agent"]["strategy"].searches
                    ],
                }
                yield f"data: {json.dumps(strategy_data)}\n\n"
            elif "provide_answer" in chunk:
                for answer in chunk["provide_answer"]["answers"]:
                    answer_data = {"type": "answer", "content": answer}
                    yield f"data: {json.dumps(answer_data)}\n\n"
            elif "write_final_answer" in chunk:
                final_answer = chunk["write_final_answer"]["final_answer"]
                final_data = {"type": "final_answer", "content": final_answer}
                yield f"data: {json.dumps(final_data)}\n\n"

        # Send completion signal
        completion_data = {"type": "complete", "final_answer": final_answer}
        yield f"data: {json.dumps(completion_data)}\n\n"

    except Exception as e:
        logger.error(f"Error in ask streaming: {str(e)}")
        error_data = {"type": "error", "message": str(e)}
        yield f"data: {json.dumps(error_data)}\n\n"

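# Events emitted by stream_ask_response are newline-delimited SSE messages.
# Each "data:" payload is a JSON object with a "type" field, typically in the
# order the graph produces them:
#   {"type": "strategy", "reasoning": ..., "searches": [{"term": ..., "instructions": ...}, ...]}
#   {"type": "answer", "content": ...}            (one event per intermediate answer)
#   {"type": "final_answer", "content": ...}
#   {"type": "complete", "final_answer": ...}     (always last on success)
# On failure, a single {"type": "error", "message": ...} event is emitted instead.
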
@router.post("/search/ask")
async def ask_knowledge_base(ask_request: AskRequest):
    """Ask the knowledge base a question using AI models."""
    try:
        # Validate models exist
        strategy_model = await Model.get(ask_request.strategy_model)
        answer_model = await Model.get(ask_request.answer_model)
        final_answer_model = await Model.get(ask_request.final_answer_model)

        if not strategy_model:
            raise HTTPException(
                status_code=400,
                detail=f"Strategy model {ask_request.strategy_model} not found",
            )
        if not answer_model:
            raise HTTPException(
                status_code=400,
                detail=f"Answer model {ask_request.answer_model} not found",
            )
        if not final_answer_model:
            raise HTTPException(
                status_code=400,
                detail=f"Final answer model {ask_request.final_answer_model} not found",
            )

        # Check if embedding model is available
        if not await model_manager.get_embedding_model():
            raise HTTPException(
                status_code=400,
                detail="Ask feature requires an embedding model. Please configure one in the Models section.",
            )

        # For streaming response; use the SSE media type so clients and proxies
        # treat the body as an event stream rather than buffering it
        return StreamingResponse(
            stream_ask_response(
                ask_request.question, strategy_model, answer_model, final_answer_model
            ),
            media_type="text/event-stream",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in ask endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(e)}")


@router.post("/search/ask/simple", response_model=AskResponse)
async def ask_knowledge_base_simple(ask_request: AskRequest):
    """Ask the knowledge base a question and return a simple response (non-streaming)."""
    try:
        # Validate models exist
        strategy_model = await Model.get(ask_request.strategy_model)
        answer_model = await Model.get(ask_request.answer_model)
        final_answer_model = await Model.get(ask_request.final_answer_model)

        if not strategy_model:
            raise HTTPException(
                status_code=400,
                detail=f"Strategy model {ask_request.strategy_model} not found",
            )
        if not answer_model:
            raise HTTPException(
                status_code=400,
                detail=f"Answer model {ask_request.answer_model} not found",
            )
        if not final_answer_model:
            raise HTTPException(
                status_code=400,
                detail=f"Final answer model {ask_request.final_answer_model} not found",
            )

        # Check if embedding model is available
        if not await model_manager.get_embedding_model():
            raise HTTPException(
                status_code=400,
                detail="Ask feature requires an embedding model. Please configure one in the Models section.",
            )

        # Run the ask graph and get final result
        final_answer = None
        async for chunk in ask_graph.astream(
            input=dict(question=ask_request.question),  # type: ignore[arg-type]
            config=dict(
                configurable=dict(
                    strategy_model=strategy_model.id,
                    answer_model=answer_model.id,
                    final_answer_model=final_answer_model.id,
                )
            ),
            stream_mode="updates",
        ):
            if "write_final_answer" in chunk:
                final_answer = chunk["write_final_answer"]["final_answer"]

        if not final_answer:
            raise HTTPException(status_code=500, detail="No answer generated")

        return AskResponse(answer=final_answer, question=ask_request.question)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in ask simple endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(e)}")

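# Example request body for the two endpoints above (field names as read from
# AskRequest in this module; the model IDs are placeholders, not real records):
#
#   {
#       "question": "What are the main conclusions across my sources?",
#       "strategy_model": "<model-id>",
#       "answer_model": "<model-id>",
#       "final_answer_model": "<model-id>"
#   }
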
@router.post("/search/ask/direct", response_model=AskResponse)
async def ask_ai_direct(request: DirectAskRequest):
    """
    Ask AI directly without RAG/knowledge retrieval.

    This uses the LLM's general knowledge to answer questions.
    """
    try:
        from open_notebook.graphs.utils import provision_langchain_model

        # Use the default chat model for direct questions
        model = await provision_langchain_model(
            content=request.question,
            model_id=request.model_id,
            default_type="chat",
            max_tokens=2048,
        )

        # Create prompt for direct AI response
        prompt = f"""You are a helpful AI assistant. Answer the following question directly using your knowledge. Be clear, concise, and informative.

Question: {request.question}

Answer:"""

        response = await model.ainvoke(prompt)
        answer = response.content if hasattr(response, 'content') else str(response)

        return AskResponse(answer=answer, question=request.question)
    except Exception as e:
        logger.error(f"Error in direct AI endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Direct AI query failed: {str(e)}")
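

# Example (sketch): consuming the /search/ask SSE stream with httpx. Kept as a
# comment because httpx may not be a dependency of this service; the base URL
# below is an assumption to adjust for your deployment.
#
#   import json
#   import httpx
#
#   async def consume_ask_stream(payload: dict) -> str | None:
#       """POST to /search/ask and return the final answer, if any."""
#       final_answer = None
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               "POST", "http://localhost:8000/search/ask", json=payload
#           ) as response:
#               async for line in response.aiter_lines():
#                   if not line.startswith("data: "):
#                       continue
#                   event = json.loads(line[len("data: "):])
#                   if event["type"] == "error":
#                       raise RuntimeError(event["message"])
#                   if event["type"] == "complete":
#                       final_answer = event["final_answer"]
#       return final_answer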