Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -17,7 +17,7 @@ from typing import Optional, Annotated
|
|
| 17 |
from langchain_core.runnables import RunnableConfig
|
| 18 |
from langgraph.prebuilt import InjectedState
|
| 19 |
from document_rag_router import router as document_rag_router
|
| 20 |
-
from document_rag_router import QueryInput, query_collection, SearchResult
|
| 21 |
from fastapi import HTTPException
|
| 22 |
import requests
|
| 23 |
from sse_starlette.sse import EventSourceResponse
|
|
@@ -136,13 +136,40 @@ model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
|
|
| 136 |
|
| 137 |
# Create a prompt template for formatting
|
| 138 |
prompt = ChatPromptTemplate.from_messages([
|
| 139 |
-
("system", "You are a helpful AI assistant.
|
| 140 |
("placeholder", "{messages}"),
|
| 141 |
])
|
| 142 |
|
| 143 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
return prompt.invoke({
|
| 145 |
-
"
|
| 146 |
"messages": state["messages"]
|
| 147 |
})
|
| 148 |
|
|
|
|
| 17 |
from langchain_core.runnables import RunnableConfig
|
| 18 |
from langgraph.prebuilt import InjectedState
|
| 19 |
from document_rag_router import router as document_rag_router
|
| 20 |
+
from document_rag_router import QueryInput, query_collection, SearchResult,db
|
| 21 |
from fastapi import HTTPException
|
| 22 |
import requests
|
| 23 |
from sse_starlette.sse import EventSourceResponse
|
|
|
|
| 136 |
|
| 137 |
# Create a prompt template for formatting
|
| 138 |
prompt = ChatPromptTemplate.from_messages([
|
| 139 |
+
("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}"),
|
| 140 |
("placeholder", "{messages}"),
|
| 141 |
])
|
| 142 |
|
| 143 |
+
async def get_collection_files(collection_id: str, user_id: str) -> str:
|
| 144 |
+
"""Get list of files in the specified collection"""
|
| 145 |
+
try:
|
| 146 |
+
# Get the full collection name
|
| 147 |
+
collection_name = f"{user_id}_{collection_id}"
|
| 148 |
+
|
| 149 |
+
# Open the table and convert to pandas
|
| 150 |
+
table = db.open_table(collection_name)
|
| 151 |
+
df = table.to_pandas()
|
| 152 |
+
|
| 153 |
+
# Get unique file names
|
| 154 |
+
unique_files = df['file_name'].unique()
|
| 155 |
+
|
| 156 |
+
# Join the file names into a string
|
| 157 |
+
return ", ".join(unique_files)
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logging.error(f"Error getting collection files: {str(e)}")
|
| 160 |
+
return f"Error getting files: {str(e)}"
|
| 161 |
+
|
| 162 |
+
async def format_for_model(state):
|
| 163 |
+
# Get collection_id and user_id from the state's configurable
|
| 164 |
+
config = state.get("configurable", {})
|
| 165 |
+
collection_id = config.get("collection_id")
|
| 166 |
+
user_id = config.get("user_id")
|
| 167 |
+
|
| 168 |
+
# Get files in the collection
|
| 169 |
+
collection_files = await get_collection_files(collection_id, user_id) if collection_id and user_id else "No files available"
|
| 170 |
+
|
| 171 |
return prompt.invoke({
|
| 172 |
+
"collection_files": collection_files,
|
| 173 |
"messages": state["messages"]
|
| 174 |
})
|
| 175 |
|