# NOTE: recovered source. The original page header ("Spaces: Runtime error")
# was build-status text from the hosting platform, not part of the program.
# DOCS
# https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent
import json
import logging.config
import os
import re
import uuid
from typing import Annotated, Optional

import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import InjectedState, create_react_agent
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from document_rag_router import QueryInput, SearchResult, db, query_collection
from document_rag_router import router as document_rag_router
# Configure logging at application startup: root, uvicorn and fastapi loggers
# all emit to stdout at DEBUG level with one timestamped format.
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "default",
            "level": "DEBUG",
        }
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console"],
    },
    "loggers": {
        "uvicorn": {"handlers": ["console"], "level": "DEBUG"},
        "fastapi": {"handlers": ["console"], "level": "DEBUG"},
    },
})

# Module-level logger (one logger per module, keyed by module name).
logger = logging.getLogger(__name__)
# FastAPI application with the document-RAG routes mounted and a fully open
# CORS policy.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec for credentialed requests — confirm whether
# credentials are actually required here.
app = FastAPI()
app.include_router(document_rag_router)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_current_files():
    """Get list of files in current directory"""
    try:
        # Comma-join the raw directory listing of the working directory.
        return ", ".join(os.listdir('.'))
    except Exception as e:
        # Best-effort helper: report the failure as a string, never raise.
        return f"Error getting files: {str(e)}"
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Hard-coded demo data: anyone whose name contains "bob" is 42,
    # everyone else is 41.
    return "42 years old" if "bob" in name.lower() else "41 years old"
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages
        config: Runtime config; its ``configurable`` dict must carry
            ``collection_id`` and ``user_id`` identifying the collection.

    Returns:
        A string rendering of the top matches (text, distance, trimmed
        metadata), or an error message on failure.
    """
    # collection_id and user_id travel in the per-thread config, not in args.
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"

    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        response = await query_collection(input_data)

        # query_collection returns a Pydantic model; keep only the fields the
        # LLM needs so the tool output stays compact.
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index"),
                },
            }
            for r in response.results
        ]
        return str(results)
    except Exception as e:
        # Fixed: was a bare print(e), which lost the traceback and bypassed
        # the configured logging. The error string is returned to the agent
        # so the model can pause and ask the user for help.
        logger.exception("query_documents failed")
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> "list[SearchResult] | str":
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages
        config: Runtime config; its ``configurable`` dict must carry
            ``collection_id`` and ``user_id``.

    Returns:
        The raw list of SearchResult objects on success, or an error string.
        (The previous ``-> SearchResult`` annotation was wrong on both
        counts: the success path returns ``response.results``, a list, and
        both failure paths return ``str``.)
    """
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"

    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        response = await query_collection(input_data)
        return response.results
    except Exception as e:
        # Fixed: was a bare print(e); route through the module logger instead.
        logger.exception("query_documents_raw failed")
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
# In-memory conversation checkpointer (state is keyed by thread_id).
memory = MemorySaver()

# Streaming chat model that backs the agent.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)

# System prompt injected by format_for_model; {collection_files} is filled
# with the file list of the active collection on every turn.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}, use query_documents tool to answer user queries from the document. In case a summary is requested, create multiple queries for different plausible sections of the document"),
    ("placeholder", "{messages}"),
])
# NOTE(review): these imports duplicate the top-of-file imports (requests,
# logging, Optional); only RequestException/Timeout are new and they are
# used solely by the commented-out legacy implementation below. Kept for
# compatibility — safe to remove once that is confirmed.
import requests
from requests.exceptions import RequestException, Timeout
import logging
from typing import Optional
# Legacy HTTP-based implementation, superseded by the lancedb-backed
# get_collection_files defined below; kept commented for reference.
# def get_collection_files(collection_id: str, user_id: str) -> str:
#     """
#     Synchronously get list of files in the specified collection using the external API
#     with proper timeout and error handling.
#     """
#     try:
#         url = "https://pvanand-documind-api-v2.hf.space/rag/get_collection_files"
#         params = {
#             "collection_id": collection_id,
#             "user_id": user_id
#         }
#         headers = {
#             'accept': 'application/json'
#         }
#         logger.debug(f"Requesting collection files for user {user_id}, collection {collection_id}")
#         # Set timeout to 5 seconds
#         response = requests.post(url, params=params, headers=headers, data='', timeout=5)
#         if response.status_code == 200:
#             logger.info(f"Successfully retrieved collection files: {response.text[:100]}...")
#             return response.text
#         else:
#             logger.error(f"API error (status {response.status_code}): {response.text}")
#             return f"Error fetching files (status {response.status_code})"
#     except Timeout:
#         logger.error("Timeout while fetching collection files")
#         return "Error: Request timed out"
#     except RequestException as e:
#         logger.error(f"Network error fetching collection files: {str(e)}")
#         return f"Error: Network issue - {str(e)}"
#     except Exception as e:
#         logger.error(f"Error fetching collection files: {str(e)}", exc_info=True)
#         return f"Error fetching files: {str(e)}"
def get_collection_files(collection_id: str, user_id: str) -> str:
    """Get list of files in the specified collection"""
    try:
        # Tables are namespaced as "<user_id>_<collection_id>" in the store.
        collection_name = f"{user_id}_{collection_id}"
        frame = db.open_table(collection_name).to_pandas()
        print(frame.head())
        # Deduplicate file names and return them comma-separated.
        return ", ".join(frame['file_name'].unique())
    except Exception as e:
        # Best-effort: callers get an error string rather than an exception.
        logging.error(f"Error getting collection files: {str(e)}")
        return f"Error getting files: {str(e)}"
def format_for_model(state: dict, config: Optional[RunnableConfig] = None):
    """
    Format the input state and config for the model.

    Args:
        state: The current state dictionary containing messages
        config: Optional RunnableConfig containing thread configuration

    Returns:
        The prompt value produced by ``prompt.invoke``. (The previous
        ``-> list[BaseMessage]`` annotation was inaccurate: ``invoke`` on a
        ChatPromptTemplate returns a prompt value, not a plain message list.)
    """
    # collection_id/user_id travel in the per-thread config, not in state.
    thread_config = config.get("configurable", {}) if config else {}
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    try:
        if collection_id and user_id:
            collection_files = get_collection_files(collection_id, user_id)
        else:
            collection_files = "No files available"

        logger.info(
            f"Fetching collection for userid {user_id} and collection_id {collection_id} || Results: {collection_files[:100]}..."
        )
        return prompt.invoke({
            "collection_files": collection_files,
            "messages": state.get("messages", []),
        })
    except Exception as e:
        logger.error(f"Error in format_for_model: {str(e)}", exc_info=True)
        # Fall back to a minimal prompt so the agent keeps working.
        return prompt.invoke({
            "collection_files": "Error fetching files",
            "messages": state.get("messages", []),
        })
async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a repr'd tool-input dict.

    Matches the leading ``{'key': 'value'`` portion of *tool_input*. On a
    match returns ``{key: value}``; otherwise the raw string wrapped in a
    one-element list.
    """
    first_pair = re.search(r"{\s*'([^']+)':\s*'([^']+)'", tool_input)
    if first_pair is None:
        return [tool_input]
    name, content = first_pair.groups()
    return {name: content}
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    When *tool_output* mentions ``query_documents``, locate the embedded
    ``[{...}]`` repr of the result list, parse it with ``ast.literal_eval``,
    and return only the ``text`` and ``document_id`` of each hit. Any other
    output (or a parse failure) is returned as a string.
    """
    if "query_documents" in tool_output:
        import ast  # local import, only needed on this path

        try:
            start = tool_output.find("[{")
            end = tool_output.rfind("}]") + 2
            # BUG FIX: the previous guard was `end > 0`, which is true even
            # when "}]" is absent (rfind() == -1 -> end == 1) and produced a
            # garbage slice; require the closing marker to follow the start.
            if start >= 0 and end > start:
                results = ast.literal_eval(tool_output[start:end])
                return [
                    {"text": r["text"], "document_id": r["metadata"]["document_id"]}
                    for r in results
                ]
        except SyntaxError as e:
            return f"Error parsing document results: {str(e)}"
        except Exception as e:
            return f"Error processing results: {str(e)}"
    # Pass everything else through unchanged.
    return tool_output
# ReAct agent: streaming gpt-4o-mini plus the document-retrieval tool, with
# per-thread memory and a state modifier that injects the collection's file
# list into the system prompt.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=format_for_model,
)
class ChatInput(BaseModel):
    # The user's message text.
    message: str
    # Conversation thread id; a fresh UUID is generated when omitted.
    thread_id: Optional[str] = None
    # Target document collection and its owner, forwarded to the tools
    # via the run config.
    collection_id: Optional[str] = None
    user_id: Optional[str] = None
async def chat(input_data: ChatInput):
    """Stream agent events for one user message as server-sent events.

    Emits JSON lines of type ``token`` (model output chunks), ``tool_start``
    and ``tool_end`` (raw stringified tool payloads).

    NOTE(review): no @app route decorator is visible in this file — confirm
    how this handler is registered with the application.
    """
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_message = HumanMessage(content=input_data.message)

    async def event_stream():
        async for event in agent.astream_events(
            {"messages": [user_message]}, config, version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                token = event["data"]["chunk"].content
                if token:
                    yield json.dumps({"type": "token", "content": token})
            elif kind == "on_tool_start":
                yield json.dumps({
                    "type": "tool_start",
                    "tool": event["name"],
                    "input": str(event["data"].get("input", "")),
                })
            elif kind == "on_tool_end":
                yield json.dumps({
                    "type": "tool_end",
                    "tool": event["name"],
                    "output": str(event["data"].get("output", "")),
                })

    return EventSourceResponse(
        event_stream(),
        media_type="text/event-stream",
    )
async def chat2(input_data: ChatInput):
    """Stream agent events as SSE with cleaned tool payloads.

    Unlike ``chat``: tool inputs are parsed into a dict via
    ``clean_tool_input``, and document-query results are re-fetched in
    structured form so the client gets JSON instead of a Python repr.

    NOTE(review): no @app route decorator is visible in this file — confirm
    how this handler is registered with the application.
    """
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_message = HumanMessage(content=input_data.message)

    async def event_stream():
        async for event in agent.astream_events(
            {"messages": [user_message]}, config, version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                token = event["data"]["chunk"].content
                if token:
                    yield json.dumps({"type": "token", "content": token})
            elif kind == "on_tool_start":
                cleaned = await clean_tool_input(str(event["data"].get("input", "")))
                yield json.dumps({
                    "type": "tool_start",
                    "tool": event["name"],
                    "inputs": cleaned,
                })
            elif kind == "on_tool_end":
                if "query_documents" in event["name"]:
                    print(event)
                    # NOTE(review): this re-runs the search, passing the repr
                    # of the tool-input dict (braces and key included) as the
                    # query string — presumably to obtain structured results;
                    # confirm this matches the intended query text.
                    raw_output = await query_documents_raw(
                        str(event["data"].get("input", "")), config
                    )
                    try:
                        structured = [
                            {
                                "text": result.text,
                                "distance": result.distance,
                                "metadata": result.metadata,
                            }
                            for result in raw_output
                        ]
                        # NOTE(review): the payload is json.dumps'd twice
                        # (inner results + outer envelope); clients must
                        # decode "output" a second time.
                        yield json.dumps({
                            "type": "tool_end",
                            "tool": event["name"],
                            "output": json.dumps(structured),
                        })
                    except Exception as e:
                        print(e)
                        yield json.dumps({
                            "type": "tool_end",
                            "tool": event["name"],
                            "output": str(raw_output),
                        })
                else:
                    cleaned_output = await clean_tool_response(
                        str(event["data"].get("output", ""))
                    )
                    yield json.dumps({
                        "type": "tool_end",
                        "tool": event["name"],
                        "output": cleaned_output,
                    })

    return EventSourceResponse(
        event_stream(),
        media_type="text/event-stream",
    )
async def health_check():
    """Liveness probe; always reports a healthy status."""
    return {"status": "healthy"}
if __name__ == "__main__":
    # Run the ASGI app directly when executed as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)