Spaces:
Sleeping
Sleeping
| import json | |
| from typing import AsyncGenerator | |
| from fastapi import APIRouter, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from loguru import logger | |
| from api.models import AskRequest, AskResponse, DirectAskRequest, SearchRequest, SearchResponse | |
| from open_notebook.domain.models import Model, model_manager | |
| from open_notebook.domain.notebook import text_search, vector_search | |
| from open_notebook.exceptions import DatabaseOperationError, InvalidInputError | |
| from open_notebook.graphs.ask import graph as ask_graph | |
| router = APIRouter() | |
async def search_knowledge_base(search_request: SearchRequest):
    """Search the knowledge base using text or vector search.

    Args:
        search_request: Carries the query string, search ``type`` ("vector"
            or anything else for text search), result ``limit``, toggles for
            searching sources/notes, and (vector only) a ``minimum_score``.

    Returns:
        SearchResponse with the matched items, their count, and the search type.

    Raises:
        HTTPException: 400 for invalid input or a missing embedding model,
            500 for database or unexpected failures.
    """
    try:
        if search_request.type == "vector":
            # Vector search cannot run without a configured embedding model.
            if not await model_manager.get_embedding_model():
                raise HTTPException(
                    status_code=400,
                    detail="Vector search requires an embedding model. Please configure one in the Models section.",
                )
            results = await vector_search(
                keyword=search_request.query,
                results=search_request.limit,
                source=search_request.search_sources,
                note=search_request.search_notes,
                minimum_score=search_request.minimum_score,
            )
        else:
            # Plain full-text search; no model requirement.
            results = await text_search(
                keyword=search_request.query,
                results=search_request.limit,
                source=search_request.search_sources,
                note=search_request.search_notes,
            )
        return SearchResponse(
            results=results or [],
            total_count=len(results) if results else 0,
            search_type=search_request.type,
        )
    except HTTPException:
        # Bug fix: without this re-raise, the 400 raised above for a missing
        # embedding model was swallowed by the generic handler below and
        # surfaced to clients as a 500. This also matches the error-handling
        # style of the other endpoints in this module.
        raise
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except DatabaseOperationError as e:
        logger.error(f"Database error during search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
    except Exception as e:
        logger.error(f"Unexpected error during search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
async def stream_ask_response(
    question: str, strategy_model: Model, answer_model: Model, final_answer_model: Model
) -> AsyncGenerator[str, None]:
    """Stream the ask response as Server-Sent Events.

    Yields one SSE ``data:`` frame per graph update (strategy, per-search
    answers, final answer), then a terminal "complete" frame. Errors are
    reported in-band as an "error" frame rather than raised.
    """

    def sse(payload: dict) -> str:
        # Serialize one payload as a Server-Sent Events data frame.
        return f"data: {json.dumps(payload)}\n\n"

    final_answer = None
    try:
        graph_config = dict(
            configurable=dict(
                strategy_model=strategy_model.id,
                answer_model=answer_model.id,
                final_answer_model=final_answer_model.id,
            )
        )
        async for update in ask_graph.astream(
            input=dict(question=question),  # type: ignore[arg-type]
            config=graph_config,
            stream_mode="updates",
        ):
            if "agent" in update:
                strategy = update["agent"]["strategy"]
                yield sse(
                    {
                        "type": "strategy",
                        "reasoning": strategy.reasoning,
                        "searches": [
                            {"term": s.term, "instructions": s.instructions}
                            for s in strategy.searches
                        ],
                    }
                )
            elif "provide_answer" in update:
                for partial_answer in update["provide_answer"]["answers"]:
                    yield sse({"type": "answer", "content": partial_answer})
            elif "write_final_answer" in update:
                final_answer = update["write_final_answer"]["final_answer"]
                yield sse({"type": "final_answer", "content": final_answer})
        # Tell the client the stream finished normally.
        yield sse({"type": "complete", "final_answer": final_answer})
    except Exception as e:
        logger.error(f"Error in ask streaming: {str(e)}")
        yield sse({"type": "error", "message": str(e)})
async def ask_knowledge_base(ask_request: AskRequest):
    """Ask the knowledge base a question using AI models.

    Validates the three requested models and the embedding-model requirement,
    then returns a StreamingResponse of Server-Sent Events produced by
    stream_ask_response.

    Raises:
        HTTPException: 400 when a referenced model is missing or no embedding
            model is configured; 500 for unexpected failures.
    """
    try:
        # Resolve every model referenced by the request up front.
        strategy_model = await Model.get(ask_request.strategy_model)
        answer_model = await Model.get(ask_request.answer_model)
        final_answer_model = await Model.get(ask_request.final_answer_model)

        # Reject the request if any referenced model could not be found.
        model_checks = (
            (strategy_model, "Strategy model", ask_request.strategy_model),
            (answer_model, "Answer model", ask_request.answer_model),
            (final_answer_model, "Final answer model", ask_request.final_answer_model),
        )
        for resolved, label, requested_id in model_checks:
            if not resolved:
                raise HTTPException(
                    status_code=400,
                    detail=f"{label} {requested_id} not found",
                )

        # The ask graph performs vector retrieval, so an embedding model is mandatory.
        if not await model_manager.get_embedding_model():
            raise HTTPException(
                status_code=400,
                detail="Ask feature requires an embedding model. Please configure one in the Models section.",
            )

        # Hand off to the SSE generator for the streaming response.
        return StreamingResponse(
            stream_ask_response(
                ask_request.question, strategy_model, answer_model, final_answer_model
            ),
            media_type="text/plain",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in ask endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(e)}")
async def ask_knowledge_base_simple(ask_request: AskRequest):
    """Ask the knowledge base a question and return a simple response (non-streaming).

    Runs the same ask graph as the streaming endpoint but discards the
    intermediate updates, returning only the final answer.

    Raises:
        HTTPException: 400 when a referenced model is missing or no embedding
            model is configured; 500 when no answer is produced or on
            unexpected failures.
    """
    try:
        # Resolve every model referenced by the request up front.
        strategy_model = await Model.get(ask_request.strategy_model)
        answer_model = await Model.get(ask_request.answer_model)
        final_answer_model = await Model.get(ask_request.final_answer_model)

        # Reject the request if any referenced model could not be found.
        model_checks = (
            (strategy_model, "Strategy model", ask_request.strategy_model),
            (answer_model, "Answer model", ask_request.answer_model),
            (final_answer_model, "Final answer model", ask_request.final_answer_model),
        )
        for resolved, label, requested_id in model_checks:
            if not resolved:
                raise HTTPException(
                    status_code=400,
                    detail=f"{label} {requested_id} not found",
                )

        # The ask graph performs vector retrieval, so an embedding model is mandatory.
        if not await model_manager.get_embedding_model():
            raise HTTPException(
                status_code=400,
                detail="Ask feature requires an embedding model. Please configure one in the Models section.",
            )

        # Drain the graph's update stream, keeping only the final answer.
        final_answer = None
        async for update in ask_graph.astream(
            input=dict(question=ask_request.question),  # type: ignore[arg-type]
            config=dict(
                configurable=dict(
                    strategy_model=strategy_model.id,
                    answer_model=answer_model.id,
                    final_answer_model=final_answer_model.id,
                )
            ),
            stream_mode="updates",
        ):
            if "write_final_answer" in update:
                final_answer = update["write_final_answer"]["final_answer"]

        if not final_answer:
            raise HTTPException(status_code=500, detail="No answer generated")

        return AskResponse(answer=final_answer, question=ask_request.question)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in ask simple endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(e)}")
async def ask_ai_direct(request: DirectAskRequest):
    """
    Ask AI directly without RAG/knowledge retrieval.
    This uses the LLM's general knowledge to answer questions.
    """
    try:
        # Imported locally to avoid pulling the graphs stack at module import time.
        from open_notebook.graphs.utils import provision_langchain_model

        # Provision the requested (or default) chat model for a one-shot answer.
        llm = await provision_langchain_model(
            content=request.question,
            model_id=request.model_id,
            default_type="chat",
            max_tokens=2048
        )

        # Build the direct-answer prompt around the user's question.
        prompt = f"""You are a helpful AI assistant. Answer the following question directly using your knowledge.
Be clear, concise, and informative.
Question: {request.question}
Answer:"""

        result = await llm.ainvoke(prompt)
        # LangChain chat models return a message with .content; fall back to str().
        answer = getattr(result, 'content', str(result))

        return AskResponse(answer=answer, question=request.question)
    except Exception as e:
        logger.error(f"Error in direct AI endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Direct AI query failed: {str(e)}")