# Hugging Face Spaces status header captured by the scraper ("Spaces: Sleeping");
# not part of the application code.
| import uuid | |
| from fastapi import FastAPI | |
| from fastapi.responses import StreamingResponse | |
| from langchain_core.messages import ( | |
| BaseMessage, | |
| HumanMessage, | |
| trim_messages, | |
| ) | |
| from langchain_core.tools import tool | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.prebuilt import create_react_agent | |
| from pydantic import BaseModel | |
| import json | |
| from typing import Optional, Annotated | |
| from langchain_core.runnables import RunnableConfig | |
| from langgraph.prebuilt import InjectedState | |
| from document_rag_router import router as document_rag_router | |
| from document_rag_router import QueryInput, query_collection, SearchResult | |
| from fastapi import HTTPException | |
| import requests | |
| from sse_starlette.sse import EventSourceResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import re | |
app = FastAPI()
app.include_router(document_rag_router)
# Wide-open CORS (any origin, any method/header, credentials allowed) —
# convenient for local development; lock down allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Demo lookup: "bob" (case-insensitive, anywhere in the name) is the only
    # recognized user; everyone else gets the default answer.
    lowered = name.lower()
    return "42 years old" if "bob" in lowered else "41 years old"
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages
    """
    # Routing info (collection and owner) travels in the thread-scoped config.
    settings = config.get("configurable", {})
    collection_id = settings.get("collection_id")
    user_id = settings.get("user_id")
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        response = await query_collection(
            QueryInput(
                collection_id=collection_id,
                query=query,
                user_id=user_id,
                top_k=6,
            )
        )
        # `response` is a Pydantic model; flatten each hit down to the fields
        # the agent actually needs.
        hits = [
            {
                "text": hit.text,
                "distance": hit.distance,
                "metadata": {
                    "document_id": hit.metadata.get("document_id"),
                    "chunk_index": hit.metadata.get("location", {}).get("chunk_index"),
                },
            }
            for hit in response.results
        ]
        return str(hits)
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> SearchResult:
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages

    Returns the raw result objects from query_collection (or an error string
    on failure — note the annotation does not reflect the error path).
    """
    # Routing info (collection and owner) travels in the thread-scoped config.
    settings = config.get("configurable", {})
    collection_id = settings.get("collection_id")
    user_id = settings.get("user_id")
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        response = await query_collection(
            QueryInput(
                collection_id=collection_id,
                query=query,
                user_id=user_id,
                top_k=6,
            )
        )
        return response.results
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
# In-process, in-memory checkpointer — conversation state is lost on restart.
memory = MemorySaver()
# streaming=True so the chat endpoints can forward token-by-token chunks.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
def state_modifier(state) -> list[BaseMessage]:
    """Trim the conversation history before it is handed to the model.

    NOTE(review): token_counter=len counts *messages*, not tokens, so
    max_tokens=16000 effectively allows up to 16000 messages — confirm this
    is intentional; a real token counter would trim far sooner.
    """
    return trim_messages(
        state["messages"],
        token_counter=len,
        max_tokens=16000,
        strategy="last",          # keep the most recent messages
        start_on="human",         # trimmed window must begin with a human turn
        include_system=True,      # always preserve the system message
        allow_partial=False,
    )
# ReAct-style agent with the RAG retrieval tool; per-thread state is
# checkpointed in `memory` and history is trimmed by `state_modifier`.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=state_modifier,
)
class ChatInput(BaseModel):
    """Request body for the chat endpoints."""

    message: str  # the user's message for this turn
    thread_id: Optional[str] = None  # resume this conversation; new UUID when omitted
    collection_id: Optional[str] = None  # document collection to query (RAG)
    user_id: Optional[str] = None  # owner of the collection
async def chat(input_data: ChatInput):
    """Run the agent on one user message and stream events over SSE.

    Each SSE message is a JSON object with a "type" of:
      - "token": a streamed chunk of model output
      - "tool_start" / "tool_end": tool invocation boundaries
    """
    # NOTE(review): no @app.post route decorator is visible for this handler —
    # confirm the route is registered elsewhere.
    thread_id = input_data.thread_id or str(uuid.uuid4())
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_message = HumanMessage(content=input_data.message)

    async def event_stream():
        async for event in agent.astream_events(
            {"messages": [user_message]}, run_config, version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                chunk_text = event["data"]["chunk"].content
                if chunk_text:
                    yield json.dumps({"type": "token", "content": chunk_text})
            elif kind == "on_tool_start":
                yield json.dumps(
                    {
                        "type": "tool_start",
                        "tool": event["name"],
                        "input": str(event["data"].get("input", "")),
                    }
                )
            elif kind == "on_tool_end":
                yield json.dumps(
                    {
                        "type": "tool_end",
                        "tool": event["name"],
                        "output": str(event["data"].get("output", "")),
                    }
                )

    return EventSourceResponse(event_stream(), media_type="text/event-stream")
async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a repr'd tool-input dict.

    Returns {key: value} when the input looks like "{'k': 'v', ...}";
    otherwise the raw string wrapped in a single-element list.
    """
    # Matches only the first 'key': 'value' pair of a dict repr; does not
    # handle escaped quotes or empty values.
    first_pair = re.search(r"{\s*'([^']+)':\s*'([^']+)'", tool_input)
    if first_pair is None:
        return [tool_input]
    return {first_pair.group(1): first_pair.group(2)}
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    When *tool_output* mentions query_documents and embeds a repr'd list of
    result dicts ("[{...}]"), parse it and keep only each result's text and
    document_id. Anything else is returned unchanged; parse failures return
    an error string.
    """
    if "query_documents" not in tool_output:
        return tool_output
    try:
        # Safely evaluate the embedded Python literal (never eval()).
        import ast

        print(tool_output)
        # Locate the "[{...}]" list repr inside the event text.
        start = tool_output.find("[{")
        end = tool_output.rfind("}]") + 2
        # BUGFIX: the original tested `end > 0`, which is true even when
        # "}]" is absent (rfind -> -1, so end == 1), producing a bogus slice
        # and a spurious error return; require the closer after the opener.
        if start >= 0 and end > start:
            results = ast.literal_eval(tool_output[start:end])
            # Return only the fields the client needs.
            return [
                {"text": r["text"], "document_id": r["metadata"]["document_id"]}
                for r in results
            ]
    except SyntaxError as e:
        print(f"Syntax error in parsing: {e}")
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        print(f"General error: {e}")
        return f"Error processing results: {str(e)}"
    return tool_output
async def chat2(input_data: ChatInput):
    """Variant of the chat stream that post-processes tool events.

    Tool inputs are reduced to their first key/value pair; for the
    query_documents tool the collection is re-queried to obtain a
    JSON-serializable copy of the results for the client.
    """
    # NOTE(review): no @app.post route decorator is visible for this handler —
    # confirm the route is registered elsewhere.
    thread_id = input_data.thread_id or str(uuid.uuid4())
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_message = HumanMessage(content=input_data.message)

    async def event_stream():
        async for event in agent.astream_events(
            {"messages": [user_message]}, run_config, version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                chunk_text = event["data"]["chunk"].content
                if chunk_text:
                    yield json.dumps({"type": "token", "content": chunk_text})
            elif kind == "on_tool_start":
                cleaned_input = await clean_tool_input(str(event["data"].get("input", "")))
                yield json.dumps(
                    {"type": "tool_start", "tool": event["name"], "inputs": cleaned_input}
                )
            elif kind == "on_tool_end":
                if "query_documents" in event["name"]:
                    print(event)
                    # NOTE(review): this re-runs the search with the *repr of
                    # the tool-input dict* as the query string, presumably to
                    # get structured results — confirm this is intended.
                    raw_output = await query_documents_raw(
                        str(event["data"].get("input", "")), run_config
                    )
                    try:
                        structured = [
                            {
                                "text": hit.text,
                                "distance": hit.distance,
                                "metadata": hit.metadata,
                            }
                            for hit in raw_output
                        ]
                        yield json.dumps(
                            {
                                "type": "tool_end",
                                "tool": event["name"],
                                "output": json.dumps(structured),
                            }
                        )
                    except Exception as e:
                        print(e)
                        yield json.dumps(
                            {
                                "type": "tool_end",
                                "tool": event["name"],
                                "output": str(raw_output),
                            }
                        )
                else:
                    cleaned_output = await clean_tool_response(
                        str(event["data"].get("output", ""))
                    )
                    yield json.dumps(
                        {
                            "type": "tool_end",
                            "tool": event["name"],
                            "output": cleaned_output,
                        }
                    )

    return EventSourceResponse(event_stream(), media_type="text/event-stream")
async def health_check():
    """Liveness probe; always reports healthy."""
    # NOTE(review): no route decorator visible — confirm registration elsewhere.
    return {"status": "healthy"}