Spaces:
Sleeping
Sleeping
rag-chat
#4
by
irashperera
- opened
- .gitignore +1 -0
- app.py +0 -56
- utils/create_vectordb.py +9 -21
.gitignore
CHANGED
|
@@ -3,3 +3,4 @@ venv
|
|
| 3 |
__pycache__
|
| 4 |
.vscode
|
| 5 |
corpus
|
|
|
|
|
|
| 3 |
__pycache__
|
| 4 |
.vscode
|
| 5 |
corpus
|
| 6 |
+
|
app.py
CHANGED
|
@@ -4,11 +4,6 @@ from langgraph.agents.rag_agent.graph import graph as rag_graph
|
|
| 4 |
from fastapi import Request
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
|
| 7 |
-
from langchain_core.documents import Document
|
| 8 |
-
|
| 9 |
-
from utils.create_vectordb import create_chroma_db_and_document,query_chroma_db
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
|
| 14 |
|
|
@@ -37,63 +32,12 @@ async def summarize(request: Request):
|
|
| 37 |
notes = data.get("notes")
|
| 38 |
return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
|
| 39 |
|
| 40 |
-
|
| 41 |
-
@app.post("/save_summary")
|
| 42 |
-
async def save_summary(request: Request):
|
| 43 |
-
data = await request.json()
|
| 44 |
-
summary = data.get("summary", "")
|
| 45 |
-
post_id = data.get("post_id", None)
|
| 46 |
-
title = data.get("title", "")
|
| 47 |
-
category = data.get("category", "")
|
| 48 |
-
tags = data.get("tags", [])
|
| 49 |
-
references = data.get("references", [])
|
| 50 |
-
|
| 51 |
-
page_content = f"""
|
| 52 |
-
Title: {title}
|
| 53 |
-
Category: {category}
|
| 54 |
-
Tags: {', '.join(tags)}
|
| 55 |
-
Summary: {summary}
|
| 56 |
-
"""
|
| 57 |
-
|
| 58 |
-
document = Document(
|
| 59 |
-
page_content=page_content,
|
| 60 |
-
id = str(post_id)
|
| 61 |
-
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
is_added = create_chroma_db_and_document(document)
|
| 65 |
-
|
| 66 |
-
if not is_added:
|
| 67 |
-
return {"error": "Failed to save summary to the database." , "status": "error"}
|
| 68 |
-
|
| 69 |
-
return {"message": "Summary saved successfully." , "status": "success"}
|
| 70 |
-
|
| 71 |
-
@app.post("/summaries")
|
| 72 |
-
async def get_summaries(request: Request):
|
| 73 |
-
|
| 74 |
-
data = await request.json()
|
| 75 |
-
print(data)
|
| 76 |
-
query = data.get("query" , "")
|
| 77 |
-
|
| 78 |
-
print(f"Query received: {query}")
|
| 79 |
-
results = query_chroma_db(query=query)
|
| 80 |
-
return results
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
@app.post("/chat")
|
| 86 |
async def chat(request: Request):
|
| 87 |
data = await request.json()
|
| 88 |
-
|
| 89 |
-
print(f"Chat request data: {data}")
|
| 90 |
-
|
| 91 |
user_input = data.get("message", "")
|
| 92 |
chat_history = data.get("chat_history", [])
|
| 93 |
|
| 94 |
-
print(f"User input: {user_input}")
|
| 95 |
-
print(f"Chat history: {chat_history}")
|
| 96 |
-
|
| 97 |
# Invoke the RAG chatbot graph
|
| 98 |
result = rag_graph.invoke({
|
| 99 |
"user_input": user_input,
|
|
|
|
| 4 |
from fastapi import Request
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
|
|
|
|
| 32 |
notes = data.get("notes")
|
| 33 |
return graph.invoke({"urls": urls, "codes": codes, "notes": notes})
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
@app.post("/chat")
|
| 36 |
async def chat(request: Request):
|
| 37 |
data = await request.json()
|
|
|
|
|
|
|
|
|
|
| 38 |
user_input = data.get("message", "")
|
| 39 |
chat_history = data.get("chat_history", [])
|
| 40 |
|
|
|
|
|
|
|
|
|
|
| 41 |
# Invoke the RAG chatbot graph
|
| 42 |
result = rag_graph.invoke({
|
| 43 |
"user_input": user_input,
|
utils/create_vectordb.py
CHANGED
|
@@ -54,7 +54,7 @@ def split_documents(documents, chunk_size=1000, chunk_overlap=200):
|
|
| 54 |
|
| 55 |
return splits
|
| 56 |
|
| 57 |
-
def create_chroma_db_and_document(document, collection_name="corpus_collection", db_dir=DB_DIR):
|
| 58 |
"""Create a Chroma vector database from documents."""
|
| 59 |
# Initialize the Gemini embedding function
|
| 60 |
gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
|
|
@@ -75,25 +75,17 @@ def create_chroma_db_and_document(document, collection_name="corpus_collection",
|
|
| 75 |
embedding_function=gemini_ef
|
| 76 |
)
|
| 77 |
print(f"Created new collection: {collection_name}")
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
try:
|
| 81 |
|
|
|
|
|
|
|
| 82 |
collection.add(
|
| 83 |
-
documents
|
| 84 |
-
|
|
|
|
| 85 |
)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
except Exception as e:
|
| 91 |
-
print(f"Error adding document to collection: {e}")
|
| 92 |
-
|
| 93 |
-
return False
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
|
| 98 |
def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
|
| 99 |
"""Query the Chroma vector database."""
|
|
@@ -144,10 +136,6 @@ def main():
|
|
| 144 |
print(f"Source: {metadata.get('source', 'Unknown')}")
|
| 145 |
|
| 146 |
print("\nVector database creation and testing complete!")
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
|
| 152 |
if __name__ == "__main__":
|
| 153 |
main()
|
|
|
|
| 54 |
|
| 55 |
return splits
|
| 56 |
|
| 57 |
+
def create_chroma_db(documents, collection_name="corpus_collection", db_dir=DB_DIR):
|
| 58 |
"""Create a Chroma vector database from documents."""
|
| 59 |
# Initialize the Gemini embedding function
|
| 60 |
gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
|
|
|
|
| 75 |
embedding_function=gemini_ef
|
| 76 |
)
|
| 77 |
print(f"Created new collection: {collection_name}")
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
# Add documents to collection
|
| 80 |
+
for i, doc in enumerate(documents):
|
| 81 |
collection.add(
|
| 82 |
+
documents=[doc.page_content],
|
| 83 |
+
metadatas=[doc.metadata],
|
| 84 |
+
ids=[f"doc_{i}"]
|
| 85 |
)
|
| 86 |
+
|
| 87 |
+
print(f"Added {len(documents)} documents to collection {collection_name}")
|
| 88 |
+
return collection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
|
| 91 |
"""Query the Chroma vector database."""
|
|
|
|
| 136 |
print(f"Source: {metadata.get('source', 'Unknown')}")
|
| 137 |
|
| 138 |
print("\nVector database creation and testing complete!")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
if __name__ == "__main__":
|
| 141 |
main()
|