Spaces:
Sleeping
Sleeping
Upload 11 files
Browse files- Dockerfile +25 -0
- chat_management.py +94 -0
- document_loaders.py +193 -0
- embedding.py +11 -0
- llm_initialization.py +17 -0
- main.py +355 -0
- prompt_templates.py +242 -0
- requirements.txt +23 -0
- retrieval_chain.py +110 -0
- text_splitter.py +30 -0
- vector_store.py +234 -0
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base image using Python 3.9
|
| 2 |
+
FROM python:3.9
|
| 3 |
+
|
| 4 |
+
# Create a new user to run the app
|
| 5 |
+
RUN useradd -m -u 1000 user
|
| 6 |
+
USER user
|
| 7 |
+
|
| 8 |
+
# Set environment variables
|
| 9 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 10 |
+
|
| 11 |
+
# Set the working directory
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
# Copy the requirements and install dependencies
|
| 15 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
| 16 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Copy the rest of the application
|
| 19 |
+
COPY --chown=user . /app
|
| 20 |
+
|
| 21 |
+
# Expose port 7860 for the application
|
| 22 |
+
EXPOSE 7860
|
| 23 |
+
|
| 24 |
+
# Command to run the FastAPI app using uvicorn
|
| 25 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
chat_management.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
from pymongo import MongoClient
|
| 3 |
+
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ChatManagement:
|
| 7 |
+
def __init__(self, cluster_url, database_name, collection_name):
|
| 8 |
+
"""
|
| 9 |
+
Initializes the ChatManagement class with MongoDB connection details.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
cluster_url (str): MongoDB cluster URL.
|
| 13 |
+
database_name (str): Name of the database.
|
| 14 |
+
collection_name (str): Name of the collection.
|
| 15 |
+
"""
|
| 16 |
+
self.connection_string = cluster_url
|
| 17 |
+
self.database_name = database_name
|
| 18 |
+
self.collection_name = collection_name
|
| 19 |
+
self.chat_sessions = {} # Dictionary to store chat history objects for each session
|
| 20 |
+
|
| 21 |
+
def create_new_chat(self):
|
| 22 |
+
"""
|
| 23 |
+
Creates a new chat session by initializing a MongoDBChatMessageHistory object.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
str: The unique chat ID.
|
| 27 |
+
"""
|
| 28 |
+
# Generate a unique chat ID
|
| 29 |
+
chat_id = str(uuid.uuid4())
|
| 30 |
+
|
| 31 |
+
# Initialize MongoDBChatMessageHistory for the chat session
|
| 32 |
+
chat_message_history = MongoDBChatMessageHistory(
|
| 33 |
+
session_id=chat_id,
|
| 34 |
+
connection_string=self.connection_string,
|
| 35 |
+
database_name=self.database_name,
|
| 36 |
+
collection_name=self.collection_name,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Store the chat_message_history object in the session dictionary
|
| 40 |
+
self.chat_sessions[chat_id] = chat_message_history
|
| 41 |
+
return chat_id
|
| 42 |
+
|
| 43 |
+
def get_chat_history(self, chat_id):
|
| 44 |
+
"""
|
| 45 |
+
Retrieves the MongoDBChatMessageHistory object for a given chat session by its chat ID.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
chat_id (str): The unique ID of the chat session.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
MongoDBChatMessageHistory or None: The chat history object of the chat session, or None if not found.
|
| 52 |
+
"""
|
| 53 |
+
# Check if the chat session is already in memory
|
| 54 |
+
if chat_id in self.chat_sessions:
|
| 55 |
+
return self.chat_sessions[chat_id]
|
| 56 |
+
|
| 57 |
+
# If not in memory, try to fetch from the database
|
| 58 |
+
chat_message_history = MongoDBChatMessageHistory(
|
| 59 |
+
session_id=chat_id,
|
| 60 |
+
connection_string=self.connection_string,
|
| 61 |
+
database_name=self.database_name,
|
| 62 |
+
collection_name=self.collection_name,
|
| 63 |
+
)
|
| 64 |
+
if chat_message_history.messages: # Check if the session exists in the database
|
| 65 |
+
self.chat_sessions[chat_id] = chat_message_history
|
| 66 |
+
return chat_message_history
|
| 67 |
+
|
| 68 |
+
return None # Chat session not found
|
| 69 |
+
|
| 70 |
+
def initialize_chat_history(self, chat_id):
|
| 71 |
+
"""
|
| 72 |
+
Initializes a new chat history for the given chat ID if it does not already exist.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
chat_id (str): The unique ID of the chat session.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
MongoDBChatMessageHistory: The initialized chat history object.
|
| 79 |
+
"""
|
| 80 |
+
# If the chat history already exists, return it
|
| 81 |
+
if chat_id in self.chat_sessions:
|
| 82 |
+
return self.chat_sessions[chat_id]
|
| 83 |
+
|
| 84 |
+
# Otherwise, create a new chat history
|
| 85 |
+
chat_message_history = MongoDBChatMessageHistory(
|
| 86 |
+
session_id=chat_id,
|
| 87 |
+
connection_string=self.connection_string,
|
| 88 |
+
database_name=self.database_name,
|
| 89 |
+
collection_name=self.collection_name,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Save the new chat session to the session dictionary
|
| 93 |
+
self.chat_sessions[chat_id] = chat_message_history
|
| 94 |
+
return chat_message_history
|
document_loaders.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.document_loaders import (CSVLoader, WikipediaLoader, UnstructuredURLLoader,
|
| 2 |
+
YoutubeLoader, PyPDFLoader, BSHTMLLoader,
|
| 3 |
+
Docx2txtLoader, UnstructuredMarkdownLoader)
|
| 4 |
+
|
| 5 |
+
from langchain_unstructured import UnstructuredLoader
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DocumentLoader:
|
| 9 |
+
def load_unstructured(self, path):
|
| 10 |
+
"""
|
| 11 |
+
Load data from a file at the specified path:
|
| 12 |
+
|
| 13 |
+
supported files:
|
| 14 |
+
"csv", "doc", "docx", "epub", "image", "md", "msg", "odt", "org", "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
path (str): The file paths
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
The loaded data.
|
| 22 |
+
|
| 23 |
+
Exceptions:
|
| 24 |
+
Prints an error message if the loading fails.
|
| 25 |
+
"""
|
| 26 |
+
try:
|
| 27 |
+
loader = UnstructuredLoader(path)
|
| 28 |
+
data = loader.load()
|
| 29 |
+
return data
|
| 30 |
+
except Exception as e:
|
| 31 |
+
print(f"Error loading Unstructured: {e}")
|
| 32 |
+
|
| 33 |
+
def load_csv(self, path):
|
| 34 |
+
"""
|
| 35 |
+
Load data from a CSV file at the specified path.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
path (str): The file path to the CSV file.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
The loaded CSV data.
|
| 42 |
+
|
| 43 |
+
Exceptions:
|
| 44 |
+
Prints an error message if the CSV loading fails.
|
| 45 |
+
"""
|
| 46 |
+
try:
|
| 47 |
+
loader = CSVLoader(file_path=path)
|
| 48 |
+
data = loader.load()
|
| 49 |
+
return data
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Error loading CSV: {e}")
|
| 52 |
+
|
| 53 |
+
def wikipedia_query(self, search_query):
|
| 54 |
+
"""
|
| 55 |
+
Query Wikipedia using a given search term and return the results.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
search_query (str): The search term to query on Wikipedia.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
The query results.
|
| 62 |
+
|
| 63 |
+
Exceptions:
|
| 64 |
+
Prints an error message if the Wikipedia query fails.
|
| 65 |
+
"""
|
| 66 |
+
try:
|
| 67 |
+
data = WikipediaLoader(query=search_query, load_max_docs=2).load()
|
| 68 |
+
return data
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"Error querying Wikipedia: {e}")
|
| 71 |
+
|
| 72 |
+
def load_urls(self, urls):
|
| 73 |
+
"""
|
| 74 |
+
Load and parse content from a list of URLs.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
urls (list): A list of URLs to load.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
The loaded data from the URLs.
|
| 81 |
+
|
| 82 |
+
Exceptions:
|
| 83 |
+
Prints an error message if loading URLs fails.
|
| 84 |
+
"""
|
| 85 |
+
try:
|
| 86 |
+
loader = UnstructuredURLLoader(urls=urls)
|
| 87 |
+
data = loader.load()
|
| 88 |
+
return data
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f"Error loading URLs: {e}")
|
| 91 |
+
|
| 92 |
+
def load_YouTubeVideo(self, urls):
|
| 93 |
+
"""
|
| 94 |
+
Load YouTube video information from provided URLs.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
urls (list): A list of YouTube video URLs.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
The loaded documents from the YouTube URLs.
|
| 101 |
+
|
| 102 |
+
Exceptions:
|
| 103 |
+
Prints an error message if loading YouTube videos fails.
|
| 104 |
+
"""
|
| 105 |
+
try:
|
| 106 |
+
loader = YoutubeLoader.from_youtube_url(
|
| 107 |
+
urls, add_video_info=True, language=["en", "pt", "zh-Hans", "es", "ur", "hi"],
|
| 108 |
+
translation="en")
|
| 109 |
+
documents = loader.load()
|
| 110 |
+
return documents
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f"Error loading YouTube video: {e}")
|
| 113 |
+
|
| 114 |
+
def load_pdf(self, path):
|
| 115 |
+
"""
|
| 116 |
+
Load data from a PDF file at the specified path.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
path (str): The file path to the PDF file.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
The loaded and split PDF pages.
|
| 123 |
+
|
| 124 |
+
Exceptions:
|
| 125 |
+
Prints an error message if the PDF loading fails.
|
| 126 |
+
"""
|
| 127 |
+
try:
|
| 128 |
+
loader = PyPDFLoader(path)
|
| 129 |
+
pages = loader.load_and_split()
|
| 130 |
+
return pages
|
| 131 |
+
except Exception as e:
|
| 132 |
+
print(f"Error loading PDF: {e}")
|
| 133 |
+
|
| 134 |
+
def load_text_from_html(self, path):
|
| 135 |
+
"""
|
| 136 |
+
Load and parse text content from an HTML file at the specified path.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
path (str): The file path to the HTML file.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
The loaded HTML data.
|
| 143 |
+
|
| 144 |
+
Exceptions:
|
| 145 |
+
Prints an error message if loading text from HTML fails.
|
| 146 |
+
"""
|
| 147 |
+
try:
|
| 148 |
+
loader = BSHTMLLoader(path)
|
| 149 |
+
data = loader.load()
|
| 150 |
+
return data
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f"Error loading text from HTML: {e}")
|
| 153 |
+
|
| 154 |
+
def load_markdown(self, path):
|
| 155 |
+
"""
|
| 156 |
+
Load data from a Markdown file at the specified path.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
path (str): The file path to the Markdown file.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
The loaded Markdown data.
|
| 163 |
+
|
| 164 |
+
Exceptions:
|
| 165 |
+
Prints an error message if loading Markdown fails.
|
| 166 |
+
"""
|
| 167 |
+
try:
|
| 168 |
+
loader = UnstructuredMarkdownLoader(path)
|
| 169 |
+
data = loader.load()
|
| 170 |
+
return data
|
| 171 |
+
except Exception as e:
|
| 172 |
+
print(f"Error loading Markdown: {e}")
|
| 173 |
+
|
| 174 |
+
def load_doc(self, path):
|
| 175 |
+
"""
|
| 176 |
+
Load data from a DOCX file at the specified path.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
path (str): The file path to the DOCX file.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
The loaded DOCX data.
|
| 183 |
+
|
| 184 |
+
Exceptions:
|
| 185 |
+
Prints an error message if loading DOCX fails.
|
| 186 |
+
"""
|
| 187 |
+
try:
|
| 188 |
+
loader = Docx2txtLoader(path)
|
| 189 |
+
data = loader.load()
|
| 190 |
+
return data
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(f"Error loading DOCX: {e}")
|
| 193 |
+
|
embedding.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 2 |
+
|
| 3 |
+
def get_embeddings():
|
| 4 |
+
# Initialize HuggingFace embeddings
|
| 5 |
+
model_name = "BAAI/bge-small-en"
|
| 6 |
+
model_kwargs = {"device": "cpu"}
|
| 7 |
+
encode_kwargs = {"normalize_embeddings": True}
|
| 8 |
+
embeddings = HuggingFaceEmbeddings(
|
| 9 |
+
model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
|
| 10 |
+
)
|
| 11 |
+
return embeddings
|
llm_initialization.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_groq import ChatGroq
|
| 2 |
+
|
| 3 |
+
def get_llm():
|
| 4 |
+
"""
|
| 5 |
+
Returns the language model instance (LLM) using ChatGroq API.
|
| 6 |
+
The LLM used is Llama 3.1 with a versatile 70 billion parameters model.
|
| 7 |
+
|
| 8 |
+
Returns:
|
| 9 |
+
llm (ChatGroq): An instance of the ChatGroq LLM.
|
| 10 |
+
"""
|
| 11 |
+
llm = ChatGroq(
|
| 12 |
+
model="llama-3.3-70b-versatile",
|
| 13 |
+
temperature=0,
|
| 14 |
+
max_tokens=1024,
|
| 15 |
+
api_key='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'
|
| 16 |
+
)
|
| 17 |
+
return llm
|
main.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import zipfile
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
| 7 |
+
from fastapi.responses import FileResponse, StreamingResponse
|
| 8 |
+
|
| 9 |
+
from llm_initialization import get_llm
|
| 10 |
+
from embedding import get_embeddings
|
| 11 |
+
from document_loaders import DocumentLoader
|
| 12 |
+
from text_splitter import TextSplitter
|
| 13 |
+
from vector_store import VectorStoreManager
|
| 14 |
+
from prompt_templates import PromptTemplates
|
| 15 |
+
from chat_management import ChatManagement
|
| 16 |
+
from retrieval_chain import RetrievalChain
|
| 17 |
+
from urllib.parse import quote_plus
|
| 18 |
+
from dotenv import load_dotenv
|
| 19 |
+
from pymongo import MongoClient
|
| 20 |
+
|
| 21 |
+
# Load environment variables
|
| 22 |
+
load_dotenv()
|
| 23 |
+
MONGO_PASSWORD = quote_plus(os.getenv("MONGO_PASSWORD"))
|
| 24 |
+
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
|
| 25 |
+
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
|
| 26 |
+
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")
|
| 27 |
+
|
| 28 |
+
app = FastAPI(title="VectorStore & Document Management API")
|
| 29 |
+
|
| 30 |
+
# Global variables (initialized on startup)
|
| 31 |
+
llm = None
|
| 32 |
+
embeddings = None
|
| 33 |
+
chat_manager = None
|
| 34 |
+
document_loader = None
|
| 35 |
+
text_splitter = None
|
| 36 |
+
vector_store_manager = None
|
| 37 |
+
vector_store = None
|
| 38 |
+
k = 3 # Number of documents to retrieve per query
|
| 39 |
+
|
| 40 |
+
# Global MongoDB collection to store retrieval chain configuration per chat session.
|
| 41 |
+
chat_chains_collection = None
|
| 42 |
+
|
| 43 |
+
# ----------------------- Startup Event -----------------------
|
| 44 |
+
@app.on_event("startup")
|
| 45 |
+
async def startup_event():
|
| 46 |
+
global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store, chat_chains_collection
|
| 47 |
+
|
| 48 |
+
print("Starting up: Initializing components...")
|
| 49 |
+
|
| 50 |
+
# Initialize LLM and embeddings
|
| 51 |
+
llm = get_llm()
|
| 52 |
+
print("LLM initialized.")
|
| 53 |
+
embeddings = get_embeddings()
|
| 54 |
+
print("Embeddings initialized.")
|
| 55 |
+
|
| 56 |
+
# Setup chat management
|
| 57 |
+
chat_manager = ChatManagement(
|
| 58 |
+
cluster_url=MONGO_CLUSTER_URL,
|
| 59 |
+
database_name=MONGO_DATABASE_NAME,
|
| 60 |
+
collection_name=MONGO_COLLECTION_NAME,
|
| 61 |
+
)
|
| 62 |
+
print("Chat management initialized.")
|
| 63 |
+
|
| 64 |
+
# Initialize document loader and text splitter
|
| 65 |
+
document_loader = DocumentLoader()
|
| 66 |
+
text_splitter = TextSplitter()
|
| 67 |
+
print("Document loader and text splitter initialized.")
|
| 68 |
+
|
| 69 |
+
# Initialize vector store manager and ensure vectorstore is set
|
| 70 |
+
vector_store_manager = VectorStoreManager(embeddings)
|
| 71 |
+
vector_store = vector_store_manager.vectorstore # Now properly initialized
|
| 72 |
+
print("Vector store initialized.")
|
| 73 |
+
|
| 74 |
+
# Connect to MongoDB and get the collection.
|
| 75 |
+
client = MongoClient(MONGO_CLUSTER_URL)
|
| 76 |
+
db = client[MONGO_DATABASE_NAME]
|
| 77 |
+
chat_chains_collection = db["chat_chains"]
|
| 78 |
+
print("Chat chains collection initialized in MongoDB.")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ----------------------- Root Endpoint -----------------------
|
| 82 |
+
@app.get("/")
|
| 83 |
+
def root():
|
| 84 |
+
"""
|
| 85 |
+
Root endpoint that returns a welcome message.
|
| 86 |
+
"""
|
| 87 |
+
return {"message": "Welcome to the VectorStore & Document Management API!"}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ----------------------- New Chat Endpoint -----------------------
|
| 91 |
+
@app.post("/new_chat")
|
| 92 |
+
def new_chat():
|
| 93 |
+
"""
|
| 94 |
+
Create a new chat session.
|
| 95 |
+
"""
|
| 96 |
+
new_chat_id = chat_manager.create_new_chat()
|
| 97 |
+
return {"chat_id": new_chat_id}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ----------------------- Create Chain Endpoint -----------------------
|
| 101 |
+
@app.post("/create_chain")
|
| 102 |
+
def create_chain(
|
| 103 |
+
chat_id: str = Query(..., description="Existing chat session ID"),
|
| 104 |
+
template: str = Query(
|
| 105 |
+
"quiz_solving",
|
| 106 |
+
description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
|
| 107 |
+
),
|
| 108 |
+
):
|
| 109 |
+
global chat_chains_collection # Ensure we reference the global variable
|
| 110 |
+
|
| 111 |
+
valid_templates = [
|
| 112 |
+
"quiz_solving",
|
| 113 |
+
"assignment_solving",
|
| 114 |
+
"paper_solving",
|
| 115 |
+
"quiz_creation",
|
| 116 |
+
"assignment_creation",
|
| 117 |
+
"paper_creation",
|
| 118 |
+
]
|
| 119 |
+
if template not in valid_templates:
|
| 120 |
+
raise HTTPException(status_code=400, detail="Invalid template selection.")
|
| 121 |
+
|
| 122 |
+
# Upsert the configuration document for this chat session.
|
| 123 |
+
chat_chains_collection.update_one(
|
| 124 |
+
{"chat_id": chat_id}, {"$set": {"template": template}}, upsert=True
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ----------------------- Chat Endpoint -----------------------
|
| 131 |
+
@app.get("/chat")
|
| 132 |
+
def chat(query: str, chat_id: str = Query(..., description="Chat session ID created via /new_chat and configured via /create_chain")):
|
| 133 |
+
"""
|
| 134 |
+
Process a chat query using the retrieval chain associated with the given chat_id.
|
| 135 |
+
|
| 136 |
+
This endpoint uses the following code:
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
stream_generator = retrieval_chain.stream_chat_response(
|
| 140 |
+
query=query,
|
| 141 |
+
chat_id=chat_id,
|
| 142 |
+
get_chat_history=chat_manager.get_chat_history,
|
| 143 |
+
initialize_chat_history=chat_manager.initialize_chat_history,
|
| 144 |
+
)
|
| 145 |
+
except Exception as e:
|
| 146 |
+
raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
|
| 147 |
+
|
| 148 |
+
return StreamingResponse(stream_generator, media_type="text/event-stream")
|
| 149 |
+
|
| 150 |
+
It first retrieves the configuration from MongoDB, re-creates the chain, and then streams the response.
|
| 151 |
+
"""
|
| 152 |
+
# Retrieve the chat configuration from MongoDB.
|
| 153 |
+
config = chat_chains_collection.find_one({"chat_id": chat_id})
|
| 154 |
+
if not config:
|
| 155 |
+
raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")
|
| 156 |
+
|
| 157 |
+
template = config.get("template", "quiz_solving")
|
| 158 |
+
if template == "quiz_solving":
|
| 159 |
+
prompt = PromptTemplates.get_quiz_solving_prompt()
|
| 160 |
+
elif template == "assignment_solving":
|
| 161 |
+
prompt = PromptTemplates.get_assignment_solving_prompt()
|
| 162 |
+
elif template == "paper_solving":
|
| 163 |
+
prompt = PromptTemplates.get_paper_solving_prompt()
|
| 164 |
+
elif template == "quiz_creation":
|
| 165 |
+
prompt = PromptTemplates.get_quiz_creation_prompt()
|
| 166 |
+
elif template == "assignment_creation":
|
| 167 |
+
prompt = PromptTemplates.get_assignment_creation_prompt()
|
| 168 |
+
elif template == "paper_creation":
|
| 169 |
+
prompt = PromptTemplates.get_paper_creation_prompt()
|
| 170 |
+
else:
|
| 171 |
+
raise HTTPException(status_code=400, detail="Invalid chat configuration.")
|
| 172 |
+
|
| 173 |
+
# Re-create the retrieval chain for this chat session.
|
| 174 |
+
retrieval_chain = RetrievalChain(
|
| 175 |
+
llm,
|
| 176 |
+
vector_store.as_retriever(search_kwargs={"k": k}),
|
| 177 |
+
prompt,
|
| 178 |
+
verbose=True,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
stream_generator = retrieval_chain.stream_chat_response(
|
| 183 |
+
query=query,
|
| 184 |
+
chat_id=chat_id,
|
| 185 |
+
get_chat_history=chat_manager.get_chat_history,
|
| 186 |
+
initialize_chat_history=chat_manager.initialize_chat_history,
|
| 187 |
+
)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
|
| 190 |
+
|
| 191 |
+
return StreamingResponse(stream_generator, media_type="text/event-stream")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ----------------------- Add Document Endpoint -----------------------
|
| 195 |
+
from typing import Any, Optional
|
| 196 |
+
|
| 197 |
+
@app.post("/add_document")
|
| 198 |
+
async def add_document(
|
| 199 |
+
file: Optional[Any] = File(None),
|
| 200 |
+
wiki_query: Optional[str] = Query(None),
|
| 201 |
+
wiki_url: Optional[str] = Query(None)
|
| 202 |
+
):
|
| 203 |
+
"""
|
| 204 |
+
Upload a document OR load data from a Wikipedia query or URL.
|
| 205 |
+
|
| 206 |
+
- If a file is provided, the document is loaded from the file.
|
| 207 |
+
- If 'wiki_query' is provided, the Wikipedia page(s) are loaded using document_loader.wikipedia_query.
|
| 208 |
+
- If 'wiki_url' is provided, the URL is loaded using document_loader.load_urls.
|
| 209 |
+
|
| 210 |
+
The loaded document(s) are then split into chunks and added to the vector store.
|
| 211 |
+
"""
|
| 212 |
+
# If file is provided but not as an UploadFile (e.g. an empty string), set it to None.
|
| 213 |
+
if not isinstance(file, UploadFile):
|
| 214 |
+
file = None
|
| 215 |
+
|
| 216 |
+
# Ensure at least one input is provided.
|
| 217 |
+
if file is None and wiki_query is None and wiki_url is None:
|
| 218 |
+
raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")
|
| 219 |
+
|
| 220 |
+
# Load document(s) based on input priority: file > wiki_query > wiki_url.
|
| 221 |
+
if file is not None:
|
| 222 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
| 223 |
+
contents = await file.read()
|
| 224 |
+
tmp.write(contents)
|
| 225 |
+
tmp_filename = tmp.name
|
| 226 |
+
|
| 227 |
+
ext = file.filename.split(".")[-1].lower()
|
| 228 |
+
try:
|
| 229 |
+
if ext == "pdf":
|
| 230 |
+
documents = document_loader.load_pdf(tmp_filename)
|
| 231 |
+
elif ext == "csv":
|
| 232 |
+
documents = document_loader.load_csv(tmp_filename)
|
| 233 |
+
elif ext in ["doc", "docx"]:
|
| 234 |
+
documents = document_loader.load_doc(tmp_filename)
|
| 235 |
+
elif ext in ["html", "htm"]:
|
| 236 |
+
documents = document_loader.load_text_from_html(tmp_filename)
|
| 237 |
+
elif ext in ["md", "markdown"]:
|
| 238 |
+
documents = document_loader.load_markdown(tmp_filename)
|
| 239 |
+
else:
|
| 240 |
+
documents = document_loader.load_unstructured(tmp_filename)
|
| 241 |
+
except Exception as e:
|
| 242 |
+
os.remove(tmp_filename)
|
| 243 |
+
raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}")
|
| 244 |
+
os.remove(tmp_filename)
|
| 245 |
+
elif wiki_query is not None:
|
| 246 |
+
try:
|
| 247 |
+
documents = document_loader.wikipedia_query(wiki_query)
|
| 248 |
+
except Exception as e:
|
| 249 |
+
raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}")
|
| 250 |
+
elif wiki_url is not None:
|
| 251 |
+
try:
|
| 252 |
+
documents = document_loader.load_urls([wiki_url])
|
| 253 |
+
except Exception as e:
|
| 254 |
+
raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}")
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
chunks = text_splitter.split_documents(documents)
|
| 258 |
+
except Exception as e:
|
| 259 |
+
raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}")
|
| 260 |
+
|
| 261 |
+
try:
|
| 262 |
+
ids = vector_store_manager.add_documents(chunks)
|
| 263 |
+
except Exception as e:
|
| 264 |
+
raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}")
|
| 265 |
+
|
| 266 |
+
return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ----------------------- Delete Document Endpoint -----------------------
|
| 270 |
+
@app.post("/delete_document")
|
| 271 |
+
def delete_document(ids: List[str]):
|
| 272 |
+
"""
|
| 273 |
+
Delete document(s) from the vector store using their IDs.
|
| 274 |
+
"""
|
| 275 |
+
try:
|
| 276 |
+
success = vector_store_manager.delete_documents(ids)
|
| 277 |
+
except Exception as e:
|
| 278 |
+
raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}")
|
| 279 |
+
if not success:
|
| 280 |
+
raise HTTPException(status_code=400, detail="Failed to delete documents.")
|
| 281 |
+
return {"message": f"Deleted documents with IDs: {ids}"}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# ----------------------- Save Vectorstore Endpoint -----------------------
|
| 285 |
+
@app.get("/save_vectorstore")
|
| 286 |
+
def save_vectorstore():
|
| 287 |
+
"""
|
| 288 |
+
Save the current vector store locally.
|
| 289 |
+
If it is a directory, it will be zipped.
|
| 290 |
+
Returns the file as a downloadable response.
|
| 291 |
+
"""
|
| 292 |
+
try:
|
| 293 |
+
save_result = vector_store_manager.save("faiss_index")
|
| 294 |
+
except Exception as e:
|
| 295 |
+
raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(e)}")
|
| 296 |
+
return FileResponse(
|
| 297 |
+
path=save_result["file_path"],
|
| 298 |
+
media_type=save_result["media_type"],
|
| 299 |
+
filename=save_result["serve_filename"],
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# ----------------------- Load Vectorstore Endpoint -----------------------
|
| 304 |
+
@app.post("/load_vectorstore")
|
| 305 |
+
async def load_vectorstore(file: UploadFile = File(...)):
|
| 306 |
+
"""
|
| 307 |
+
Load a vector store from an uploaded file (raw or zipped).
|
| 308 |
+
This will replace the current vector store.
|
| 309 |
+
"""
|
| 310 |
+
tmp_filename = None
|
| 311 |
+
try:
|
| 312 |
+
# Save the uploaded file content to a temporary file.
|
| 313 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
| 314 |
+
file_bytes = await file.read() # await to get bytes
|
| 315 |
+
tmp.write(file_bytes)
|
| 316 |
+
tmp_filename = tmp.name
|
| 317 |
+
|
| 318 |
+
instance, message = VectorStoreManager.load(tmp_filename, embeddings)
|
| 319 |
+
except Exception as e:
|
| 320 |
+
raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}")
|
| 321 |
+
finally:
|
| 322 |
+
if tmp_filename and os.path.exists(tmp_filename):
|
| 323 |
+
os.remove(tmp_filename)
|
| 324 |
+
global vector_store_manager
|
| 325 |
+
vector_store_manager = instance
|
| 326 |
+
return {"message": message}
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# ----------------------- Merge Vectorstore Endpoint -----------------------
|
| 330 |
+
@app.post("/merge_vectorstore")
|
| 331 |
+
async def merge_vectorstore(file: UploadFile = File(...)):
|
| 332 |
+
"""
|
| 333 |
+
Merge an uploaded vector store (raw or zipped) into the current vector store.
|
| 334 |
+
"""
|
| 335 |
+
tmp_filename = None
|
| 336 |
+
try:
|
| 337 |
+
# Save the uploaded file content to a temporary file.
|
| 338 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
| 339 |
+
file_bytes = await file.read() # Await the file.read() coroutine!
|
| 340 |
+
tmp.write(file_bytes)
|
| 341 |
+
tmp_filename = tmp.name
|
| 342 |
+
|
| 343 |
+
# Pass the filename (a string) to the merge method.
|
| 344 |
+
result = vector_store_manager.merge(tmp_filename, embeddings)
|
| 345 |
+
except Exception as e:
|
| 346 |
+
raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}")
|
| 347 |
+
finally:
|
| 348 |
+
if tmp_filename and os.path.exists(tmp_filename):
|
| 349 |
+
os.remove(tmp_filename)
|
| 350 |
+
return result
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
|
| 354 |
+
import uvicorn
|
| 355 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
prompt_templates.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.prompts import ChatPromptTemplate
|
| 2 |
+
|
| 3 |
+
class PromptTemplates:
|
| 4 |
+
"""
|
| 5 |
+
A class to encapsulate various prompt templates for solving assignments, papers, creating quizzes, and assignments.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
@staticmethod
|
| 9 |
+
def get_quiz_solving_prompt():
|
| 10 |
+
|
| 11 |
+
quiz_solving_prompt = '''
|
| 12 |
+
You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.
|
| 13 |
+
Use the following retrieved context to answer the user's question.
|
| 14 |
+
If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.
|
| 15 |
+
|
| 16 |
+
Guidelines:
|
| 17 |
+
1. Extract key information from the context to form a coherent response.
|
| 18 |
+
2. Maintain a clear and professional tone.
|
| 19 |
+
3. If the question requires clarification, specify it politely.
|
| 20 |
+
|
| 21 |
+
Retrieved context:
|
| 22 |
+
{context}
|
| 23 |
+
|
| 24 |
+
User's question:
|
| 25 |
+
{question}
|
| 26 |
+
|
| 27 |
+
Your response:
|
| 28 |
+
'''
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Create a prompt template to pass the context and user input to the chain
|
| 32 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 33 |
+
[
|
| 34 |
+
("system", quiz_solving_prompt),
|
| 35 |
+
("human", "{question}"),
|
| 36 |
+
]
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
return prompt
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
def get_assignment_solving_prompt():
|
| 43 |
+
# Prompt template for solving assignments
|
| 44 |
+
assignment_solving_prompt = '''
|
| 45 |
+
You are an expert assistant specializing in solving academic assignments with clarity and precision.
|
| 46 |
+
Your task is to provide step-by-step solutions and detailed explanations that align with the given requirements.
|
| 47 |
+
|
| 48 |
+
Retrieved context:
|
| 49 |
+
{context}
|
| 50 |
+
|
| 51 |
+
Assignment Details:
|
| 52 |
+
{question}
|
| 53 |
+
|
| 54 |
+
Guidelines:
|
| 55 |
+
1. **Understand the Problem:** Carefully analyze the assignment details to identify the objective and requirements.
|
| 56 |
+
2. **Provide a Step-by-Step Solution:** Break down the solution into clear, logical steps. Use examples where appropriate.
|
| 57 |
+
3. **Explain Your Reasoning:** Include concise explanations for each step to enhance understanding.
|
| 58 |
+
4. **Follow Formatting Rules:** Ensure the response matches any specified formatting or citation guidelines.
|
| 59 |
+
5. **Maintain Academic Integrity:** Do not fabricate information, copy content verbatim without attribution, or complete the task in a way that breaches academic honesty policies.
|
| 60 |
+
|
| 61 |
+
Deliverable:
|
| 62 |
+
Provide the final answer in the format outlined in the assignment description. Where relevant, include:
|
| 63 |
+
- A brief introduction summarizing the approach.
|
| 64 |
+
- Calculations or code (if applicable).
|
| 65 |
+
- Any necessary diagrams, tables, or figures (use textual descriptions for diagrams if unavailable).
|
| 66 |
+
- A conclusion summarizing the findings.
|
| 67 |
+
|
| 68 |
+
If the assignment details are incomplete or ambiguous, specify what additional information is required to proceed.
|
| 69 |
+
|
| 70 |
+
Assignment Response:
|
| 71 |
+
'''
|
| 72 |
+
|
| 73 |
+
# Create a prompt template to pass the context and user input to the chain
|
| 74 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 75 |
+
[
|
| 76 |
+
("system", assignment_solving_prompt),
|
| 77 |
+
("human", "{question}"),
|
| 78 |
+
]
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
return prompt
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def get_paper_solving_prompt():
|
| 86 |
+
# Prompt template for solving papers
|
| 87 |
+
paper_solving_prompt = '''
|
| 88 |
+
You are an expert assistant specialized in solving academic papers with precision and clarity.
|
| 89 |
+
Your task is to provide well-structured answers to the questions in the paper, ensuring accuracy, depth, and adherence to any specified instructions.
|
| 90 |
+
|
| 91 |
+
Retrieved context:
|
| 92 |
+
{context}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
Paper Information:
|
| 96 |
+
{question}
|
| 97 |
+
|
| 98 |
+
Instructions:
|
| 99 |
+
1. **Understand Each Question:** Read each question carefully and identify its requirements, keywords, and scope.
|
| 100 |
+
2. **Structured Responses:** Provide answers in a clear, logical structure (e.g., Introduction, Body, Conclusion).
|
| 101 |
+
3. **Depth and Accuracy:** Support answers with explanations, examples, calculations, or references where applicable.
|
| 102 |
+
4. **Formatting Guidelines:** Adhere to any specified format or style (e.g., bullet points, paragraphs, equations).
|
| 103 |
+
5. **Time Efficiency:** If the paper is timed, prioritize accuracy and completeness over excessive detail.
|
| 104 |
+
6. **Clarify Ambiguities:** If any question is unclear, mention the assumptions made while answering.
|
| 105 |
+
7. **Ethical Guidelines:** Ensure the answers are original and aligned with academic integrity standards.
|
| 106 |
+
|
| 107 |
+
Deliverables:
|
| 108 |
+
- Answer all questions to the best of your ability.
|
| 109 |
+
- Include relevant diagrams, tables, or code (describe diagrams in text if unavailable).
|
| 110 |
+
- Summarize key points in a conclusion where applicable.
|
| 111 |
+
- Clearly number and label answers to match the questions in the paper.
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
If the paper includes multiple sections, label each section and solve sequentially.
|
| 115 |
+
|
| 116 |
+
Paper Solution:
|
| 117 |
+
'''
|
| 118 |
+
|
| 119 |
+
# Create a prompt template to pass the context and user input to the chain
|
| 120 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 121 |
+
[
|
| 122 |
+
("system", paper_solving_prompt),
|
| 123 |
+
("human", "{question}"),
|
| 124 |
+
]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return prompt
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def get_quiz_creation_prompt():
|
| 131 |
+
# Prompt template for creating a quiz
|
| 132 |
+
quiz_creation_prompt = '''
|
| 133 |
+
You are an expert assistant specializing in creating engaging and educational quizzes for students.
|
| 134 |
+
Your task is to design a quiz based on the topic, difficulty level, and format specified by the teacher.
|
| 135 |
+
|
| 136 |
+
Retrieved context:
|
| 137 |
+
{context}
|
| 138 |
+
|
| 139 |
+
Quiz Details:
|
| 140 |
+
Topic: {question}
|
| 141 |
+
|
| 142 |
+
Guidelines for Quiz Creation:
|
| 143 |
+
1. **Relevance to Topic:** Ensure all questions are directly related to the specified topic.
|
| 144 |
+
2. **Clear and Concise Wording:** Write questions clearly and concisely to avoid ambiguity.
|
| 145 |
+
3. **Diverse Question Types:** Incorporate a variety of question types if specified.
|
| 146 |
+
4. **Appropriate Difficulty:** Tailor the complexity of the questions to match the target audience and difficulty level.
|
| 147 |
+
5. **Answer Key:** Provide correct answers or explanations for each question.
|
| 148 |
+
|
| 149 |
+
Deliverables:
|
| 150 |
+
- A complete quiz with numbered questions.
|
| 151 |
+
- An answer key with correct answers and explanations where relevant.
|
| 152 |
+
|
| 153 |
+
Quiz:
|
| 154 |
+
'''
|
| 155 |
+
|
| 156 |
+
# Create a prompt template to pass the context and user input to the chain
|
| 157 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 158 |
+
[
|
| 159 |
+
("system", quiz_creation_prompt),
|
| 160 |
+
("human", "{question}"),
|
| 161 |
+
]
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return prompt
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@staticmethod
|
| 168 |
+
def get_assignment_creation_prompt():
|
| 169 |
+
# Prompt template for creating an assignment
|
| 170 |
+
assignment_creation_prompt = '''
|
| 171 |
+
You are an expert assistant specializing in designing assignments that align with the educational goals and requirements of teachers.
|
| 172 |
+
Your task is to create a comprehensive assignment based on the provided topic, target audience, and desired outcomes.
|
| 173 |
+
|
| 174 |
+
Retrieved context:
|
| 175 |
+
{context}
|
| 176 |
+
|
| 177 |
+
Assignment Details:
|
| 178 |
+
Topic: {question}
|
| 179 |
+
|
| 180 |
+
Guidelines for Assignment Creation:
|
| 181 |
+
1. **Alignment with Topic:** Ensure all tasks/questions are closely related to the specified topic and designed to achieve the teacher’s learning objectives.
|
| 182 |
+
2. **Clear Instructions:** Provide detailed and clear instructions for each question or task.
|
| 183 |
+
3. **Encourage Critical Thinking:** Include questions or tasks that require analysis, creativity, and application of knowledge where appropriate.
|
| 184 |
+
4. **Variety of Tasks:** Incorporate diverse question types (e.g., short answers, essays, practical tasks) as per the specified format.
|
| 185 |
+
5. **Grading Rubric (Optional):** Include a grading rubric or evaluation criteria if specified in the instructions.
|
| 186 |
+
|
| 187 |
+
Deliverables:
|
| 188 |
+
- A detailed assignment with numbered tasks/questions.
|
| 189 |
+
- Any required supporting materials (e.g., diagrams, data tables, references).
|
| 190 |
+
- (Optional) A grading rubric or expected outcomes for each task.
|
| 191 |
+
|
| 192 |
+
Assignment:
|
| 193 |
+
'''
|
| 194 |
+
|
| 195 |
+
# Create a prompt template to pass the context and user input to the chain
|
| 196 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 197 |
+
[
|
| 198 |
+
("system", assignment_creation_prompt),
|
| 199 |
+
("human", "{question}"),
|
| 200 |
+
]
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
return prompt
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@staticmethod
|
| 207 |
+
def get_paper_creation_prompt():
|
| 208 |
+
# Prompt template for creating an academic paper
|
| 209 |
+
paper_creation_prompt = '''
|
| 210 |
+
You are an expert assistant specializing in designing comprehensive academic papers tailored to the educational goals and requirements of teachers.
|
| 211 |
+
Your task is to create a complete paper based on the specified topic, audience, format, and difficulty level.
|
| 212 |
+
|
| 213 |
+
Retrieved context:
|
| 214 |
+
{context}
|
| 215 |
+
|
| 216 |
+
Paper Details:
|
| 217 |
+
Subject/Topic: {question}
|
| 218 |
+
|
| 219 |
+
Guidelines for Paper Creation:
|
| 220 |
+
1. **Relevance and Alignment:** Ensure all questions align with the specified subject/topic and are tailored to the target audience’s curriculum or learning objectives.
|
| 221 |
+
2. **Clear Wording:** Write questions in clear, concise language to avoid ambiguity or confusion.
|
| 222 |
+
3. **Diverse Question Types:** Incorporate a variety of question formats as specified (e.g., multiple-choice, fill-in-the-blank, long-form essays).
|
| 223 |
+
4. **Grading and Marks Allocation:** Provide a suggested mark allocation for each question, ensuring it reflects the question's complexity and time required.
|
| 224 |
+
5. **Answer Key:** Include correct answers or model responses for objective and descriptive questions (optional).
|
| 225 |
+
|
| 226 |
+
Deliverables:
|
| 227 |
+
- A complete paper with numbered questions, organized by sections if required.
|
| 228 |
+
- An answer key or marking scheme (if requested).
|
| 229 |
+
- Any supporting materials (e.g., diagrams, charts, or data tables) if applicable.
|
| 230 |
+
|
| 231 |
+
Paper:
|
| 232 |
+
'''
|
| 233 |
+
|
| 234 |
+
# Create a prompt template to pass the context and user input to the chain
|
| 235 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 236 |
+
[
|
| 237 |
+
("system", paper_creation_prompt),
|
| 238 |
+
("human", "{question}"),
|
| 239 |
+
]
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
return prompt
|
requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
# python-jose
|
| 3 |
+
python-dotenv
|
| 4 |
+
# bcrypt
|
| 5 |
+
# passlib
|
| 6 |
+
uvicorn
|
| 7 |
+
# pyjwt
|
| 8 |
+
python-multipart
|
| 9 |
+
# pydantic[email]
|
| 10 |
+
pymongo
|
| 11 |
+
faiss-cpu
|
| 12 |
+
sentence_transformers
|
| 13 |
+
langchain_groq
|
| 14 |
+
langchain-community
|
| 15 |
+
langchain_unstructured
|
| 16 |
+
unstructured[all-docs]
|
| 17 |
+
unstructured[docx]
|
| 18 |
+
unstructured
|
| 19 |
+
unstructured[pdf]
|
| 20 |
+
langchain-mongodb
|
| 21 |
+
langchain_huggingface
|
| 22 |
+
wikipedia
|
| 23 |
+
docx2txt
|
retrieval_chain.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 2 |
+
from langchain.prompts import ChatPromptTemplate
|
| 3 |
+
|
| 4 |
+
class RetrievalChain:
|
| 5 |
+
def __init__(self, llm, retriever, user_prompt, verbose=False):
|
| 6 |
+
"""
|
| 7 |
+
Initializes the RetrievalChain with an LLM and retriever.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
llm: Language model to use for the conversational chain.
|
| 11 |
+
retriever: Retriever object to fetch relevant documents.
|
| 12 |
+
user_prompt: Custom prompt to guide the chain.
|
| 13 |
+
verbose (bool): Whether to print verbose chain outputs.
|
| 14 |
+
"""
|
| 15 |
+
self.llm = llm
|
| 16 |
+
|
| 17 |
+
self.chain = ConversationalRetrievalChain.from_llm(
|
| 18 |
+
llm=llm,
|
| 19 |
+
retriever=retriever,
|
| 20 |
+
return_source_documents=True,
|
| 21 |
+
chain_type='stuff',
|
| 22 |
+
combine_docs_chain_kwargs={"prompt": user_prompt},
|
| 23 |
+
verbose=verbose,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
def summarize_messages(self, chat_history):
|
| 27 |
+
"""
|
| 28 |
+
Summarizes the chat history into a concise message.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
chat_history: The chat history object for the session.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
bool: True if summarization is successful, False otherwise.
|
| 35 |
+
"""
|
| 36 |
+
stored_messages = chat_history.messages
|
| 37 |
+
if len(stored_messages) == 0:
|
| 38 |
+
return False
|
| 39 |
+
|
| 40 |
+
summarization_prompt = ChatPromptTemplate.from_messages(
|
| 41 |
+
[
|
| 42 |
+
("placeholder", "{chat_history}"),
|
| 43 |
+
(
|
| 44 |
+
"human",
|
| 45 |
+
"Summarize the above chat messages into a single concise message. Include only the important specific details.",
|
| 46 |
+
),
|
| 47 |
+
]
|
| 48 |
+
)
|
| 49 |
+
# Create a chain for summarization by piping the prompt into the language model.
|
| 50 |
+
summarization_chain = summarization_prompt | self.llm
|
| 51 |
+
summary_message = summarization_chain.invoke({"chat_history": stored_messages})
|
| 52 |
+
|
| 53 |
+
chat_history.clear() # Clear the existing chat history
|
| 54 |
+
chat_history.add_ai_message(summary_message.content) # Add the summary message as the first entry
|
| 55 |
+
return True
|
| 56 |
+
|
| 57 |
+
def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
|
| 58 |
+
"""
|
| 59 |
+
Streams the response to a query in real-time for a given chat session using SSE formatting.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
query (str): The user's query.
|
| 63 |
+
chat_id (str): The unique ID of the chat session.
|
| 64 |
+
get_chat_history (function): Function to retrieve chat history by chat ID.
|
| 65 |
+
initialize_chat_history (function): Function to initialize a new chat history.
|
| 66 |
+
|
| 67 |
+
Yields:
|
| 68 |
+
str: Server-Sent Event (SSE) formatted string for each chunk of the response.
|
| 69 |
+
"""
|
| 70 |
+
# Retrieve the chat history for the session.
|
| 71 |
+
chat_message_history = get_chat_history(chat_id)
|
| 72 |
+
if not chat_message_history:
|
| 73 |
+
# If no chat history exists, initialize one.
|
| 74 |
+
chat_message_history = initialize_chat_history(chat_id)
|
| 75 |
+
|
| 76 |
+
# Optionally summarize previous messages.
|
| 77 |
+
self.summarize_messages(chat_message_history)
|
| 78 |
+
chat_history = chat_message_history.messages
|
| 79 |
+
|
| 80 |
+
# Prepare input data for the conversational retrieval chain.
|
| 81 |
+
input_data_for_chain = {
|
| 82 |
+
"question": query,
|
| 83 |
+
"chat_history": chat_history
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
# Add the user query to the chat history.
|
| 87 |
+
chat_message_history.add_user_message(query)
|
| 88 |
+
|
| 89 |
+
# Execute the chain in streaming mode (this assumes the chain supports a `stream` method).
|
| 90 |
+
response_stream = self.chain.stream(input_data_for_chain)
|
| 91 |
+
|
| 92 |
+
accumulated_response = ""
|
| 93 |
+
# Process the response stream and yield SSE events.
|
| 94 |
+
for chunk in response_stream:
|
| 95 |
+
if 'answer' in chunk:
|
| 96 |
+
accumulated_response += chunk['answer']
|
| 97 |
+
# Format the SSE event.
|
| 98 |
+
sse_event = f"data: {chunk['answer']}\n\n"
|
| 99 |
+
yield sse_event
|
| 100 |
+
else:
|
| 101 |
+
# Yield an SSE event with debug info if the chunk structure is unexpected.
|
| 102 |
+
debug_msg = f"Unexpected chunk structure: {chunk}"
|
| 103 |
+
yield f"data: {debug_msg}\n\n"
|
| 104 |
+
|
| 105 |
+
# Once streaming is complete, update chat history with the final response.
|
| 106 |
+
if accumulated_response:
|
| 107 |
+
chat_message_history.add_ai_message(accumulated_response)
|
| 108 |
+
else:
|
| 109 |
+
yield "data: No valid response content was generated.\n\n"
|
| 110 |
+
|
text_splitter.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 2 |
+
|
| 3 |
+
class TextSplitter:
|
| 4 |
+
def __init__(self, chunk_size=1024, chunk_overlap=100):
|
| 5 |
+
"""
|
| 6 |
+
Initialize the TextSplitter with a specific chunk size and overlap.
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
chunk_size (int): The size of each text chunk.
|
| 10 |
+
chunk_overlap (int): The overlap size between chunks.
|
| 11 |
+
"""
|
| 12 |
+
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
| 13 |
+
|
| 14 |
+
def split_documents(self, documents):
|
| 15 |
+
"""
|
| 16 |
+
Split the provided documents into chunks based on the chunk size and overlap.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
documents (list): A list of documents to be split.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
A list of split documents.
|
| 23 |
+
|
| 24 |
+
Exceptions:
|
| 25 |
+
Prints an error message if splitting documents fails.
|
| 26 |
+
"""
|
| 27 |
+
try:
|
| 28 |
+
return self.text_splitter.split_documents(documents)
|
| 29 |
+
except Exception as e:
|
| 30 |
+
print(f"Error splitting documents: {e}")
|
vector_store.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import shutil
|
| 4 |
+
import tempfile
|
| 5 |
+
import zipfile
|
| 6 |
+
|
| 7 |
+
from faiss import IndexFlatL2
|
| 8 |
+
|
| 9 |
+
from langchain_community.vectorstores import FAISS
|
| 10 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VectorStoreManager:
|
| 14 |
+
def __init__(self, embeddings=None):
|
| 15 |
+
"""
|
| 16 |
+
Initializes the VectorStoreManager with a FAISS vector store.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
embeddings (Embeddings, optional): Embeddings model used for the vector store.
|
| 20 |
+
"""
|
| 21 |
+
self.vectorstore = None
|
| 22 |
+
if embeddings:
|
| 23 |
+
self.vectorstore = self.create_vectorstore(embeddings)
|
| 24 |
+
|
| 25 |
+
def create_vectorstore(self, embeddings):
|
| 26 |
+
"""
|
| 27 |
+
Creates and initializes a FAISS vector store.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
embeddings (Embeddings): Embeddings model used for the vector store.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
FAISS: Initialized vector store.
|
| 34 |
+
"""
|
| 35 |
+
# Define vector store dimensions based on embeddings
|
| 36 |
+
dimensions = len(embeddings.embed_query("dummy"))
|
| 37 |
+
|
| 38 |
+
# Initialize FAISS vector store
|
| 39 |
+
vectorstore = FAISS(
|
| 40 |
+
embedding_function=embeddings,
|
| 41 |
+
index=IndexFlatL2(dimensions),
|
| 42 |
+
docstore=InMemoryDocstore(),
|
| 43 |
+
index_to_docstore_id={},
|
| 44 |
+
normalize_L2=False
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
print("Created a new FAISS vector store.")
|
| 48 |
+
return vectorstore
|
| 49 |
+
|
| 50 |
+
def add_documents(self, documents):
|
| 51 |
+
"""
|
| 52 |
+
Adds new documents to the FAISS vector store, each document with a unique UUID.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
documents (list): List of Document objects to be added to the vector store.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
list: List of UUIDs corresponding to the added documents.
|
| 59 |
+
"""
|
| 60 |
+
if not self.vectorstore:
|
| 61 |
+
raise ValueError("Vector store is not initialized. Please create or load a vector store first.")
|
| 62 |
+
|
| 63 |
+
uuids = [str(uuid.uuid4()) for _ in range(len(documents))]
|
| 64 |
+
self.vectorstore.add_documents(documents=documents, ids=uuids)
|
| 65 |
+
|
| 66 |
+
print(f"Added {len(documents)} documents to the vector store with IDs: {uuids}")
|
| 67 |
+
return uuids
|
| 68 |
+
|
| 69 |
+
def delete_documents(self, ids):
|
| 70 |
+
"""
|
| 71 |
+
Deletes documents from the FAISS vector store using their unique IDs.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
ids (list): List of UUIDs corresponding to the documents to be deleted.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
bool: True if the documents were successfully deleted, False otherwise.
|
| 78 |
+
"""
|
| 79 |
+
if not self.vectorstore:
|
| 80 |
+
raise ValueError("Vector store is not initialized. Please create or load a vector store first.")
|
| 81 |
+
|
| 82 |
+
if not ids:
|
| 83 |
+
print("No document IDs provided for deletion.")
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
success = self.vectorstore.delete(ids=ids)
|
| 87 |
+
if success:
|
| 88 |
+
print(f"Successfully deleted documents with IDs: {ids}")
|
| 89 |
+
else:
|
| 90 |
+
print(f"Failed to delete documents with IDs: {ids}")
|
| 91 |
+
return success
|
| 92 |
+
|
| 93 |
+
def save(self, filename="faiss_index"):
|
| 94 |
+
"""
|
| 95 |
+
Saves the current FAISS vector store locally. If the saved store is a directory,
|
| 96 |
+
it compresses it into a ZIP archive.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
filename (str): The filename or directory name where the vector store will be saved.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
dict: A dictionary with details about the saved file including file path and media type.
|
| 103 |
+
"""
|
| 104 |
+
if not self.vectorstore:
|
| 105 |
+
raise ValueError("Vector store is not initialized. Please create or load a vector store first.")
|
| 106 |
+
|
| 107 |
+
# Save the vectorstore locally
|
| 108 |
+
self.vectorstore.save_local(filename)
|
| 109 |
+
print(f"Vector store saved to {filename}")
|
| 110 |
+
|
| 111 |
+
if not os.path.exists(filename):
|
| 112 |
+
raise FileNotFoundError("Saved vectorstore not found.")
|
| 113 |
+
|
| 114 |
+
# If the saved vectorstore is a directory, compress it into a zip file.
|
| 115 |
+
if os.path.isdir(filename):
|
| 116 |
+
zip_filename = filename + ".zip"
|
| 117 |
+
shutil.make_archive(filename, 'zip', filename)
|
| 118 |
+
return {
|
| 119 |
+
"file_path": zip_filename,
|
| 120 |
+
"media_type": "application/zip",
|
| 121 |
+
"serve_filename": os.path.basename(zip_filename),
|
| 122 |
+
"original": filename,
|
| 123 |
+
}
|
| 124 |
+
else:
|
| 125 |
+
return {
|
| 126 |
+
"file_path": filename,
|
| 127 |
+
"media_type": "application/octet-stream",
|
| 128 |
+
"serve_filename": os.path.basename(filename),
|
| 129 |
+
"original": filename,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
@staticmethod
|
| 133 |
+
def load(file_input, embeddings):
|
| 134 |
+
"""
|
| 135 |
+
Loads a FAISS vector store from an uploaded file or a filename.
|
| 136 |
+
If file_input is a file-like object, it is saved to a temporary file.
|
| 137 |
+
If it's a string (filename), it is used directly.
|
| 138 |
+
"""
|
| 139 |
+
# Check if file_input is a string (filename) or a file-like object.
|
| 140 |
+
if isinstance(file_input, str):
|
| 141 |
+
tmp_filename = file_input
|
| 142 |
+
else:
|
| 143 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
| 144 |
+
tmp.write(file_input.read())
|
| 145 |
+
tmp_filename = tmp.name
|
| 146 |
+
|
| 147 |
+
try:
|
| 148 |
+
if zipfile.is_zipfile(tmp_filename):
|
| 149 |
+
with tempfile.TemporaryDirectory() as extract_dir:
|
| 150 |
+
with zipfile.ZipFile(tmp_filename, 'r') as zip_ref:
|
| 151 |
+
zip_ref.extractall(extract_dir)
|
| 152 |
+
extracted_items = os.listdir(extract_dir)
|
| 153 |
+
if len(extracted_items) == 1:
|
| 154 |
+
potential_dir = os.path.join(extract_dir, extracted_items[0])
|
| 155 |
+
if os.path.isdir(potential_dir):
|
| 156 |
+
vectorstore_dir = potential_dir
|
| 157 |
+
else:
|
| 158 |
+
vectorstore_dir = extract_dir
|
| 159 |
+
else:
|
| 160 |
+
vectorstore_dir = extract_dir
|
| 161 |
+
|
| 162 |
+
new_vectorstore = FAISS.load_local(vectorstore_dir, embeddings, allow_dangerous_deserialization=True)
|
| 163 |
+
message = "Vector store loaded successfully from ZIP."
|
| 164 |
+
else:
|
| 165 |
+
new_vectorstore = FAISS.load_local(tmp_filename, embeddings, allow_dangerous_deserialization=True)
|
| 166 |
+
message = "Vector store loaded successfully."
|
| 167 |
+
except Exception as e:
|
| 168 |
+
raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}")
|
| 169 |
+
finally:
|
| 170 |
+
# Only remove the temp file if we created it here (i.e. file_input was not a filename)
|
| 171 |
+
if not isinstance(file_input, str) and os.path.exists(tmp_filename):
|
| 172 |
+
os.remove(tmp_filename)
|
| 173 |
+
|
| 174 |
+
instance = VectorStoreManager()
|
| 175 |
+
instance.vectorstore = new_vectorstore
|
| 176 |
+
print(message)
|
| 177 |
+
return instance, message
|
| 178 |
+
|
| 179 |
+
def merge(self, file_input, embeddings):
|
| 180 |
+
"""
|
| 181 |
+
Merges an uploaded vector store file into the current FAISS vector store.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
file_input (Union[file-like object, str]): An object with a .read() method or a filename (str).
|
| 185 |
+
embeddings (Embeddings): Embeddings model used for loading the vector store.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
dict: A dictionary containing a message indicating successful merging.
|
| 189 |
+
"""
|
| 190 |
+
# Determine if file_input is a filename (str) or a file-like object.
|
| 191 |
+
if isinstance(file_input, str):
|
| 192 |
+
tmp_filename = file_input
|
| 193 |
+
temp_created = False
|
| 194 |
+
else:
|
| 195 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
| 196 |
+
tmp.write(file_input.read())
|
| 197 |
+
tmp_filename = tmp.name
|
| 198 |
+
temp_created = True
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
# Check if the file is a ZIP archive.
|
| 202 |
+
if zipfile.is_zipfile(tmp_filename):
|
| 203 |
+
with tempfile.TemporaryDirectory() as extract_dir:
|
| 204 |
+
with zipfile.ZipFile(tmp_filename, 'r') as zip_ref:
|
| 205 |
+
zip_ref.extractall(extract_dir)
|
| 206 |
+
extracted_items = os.listdir(extract_dir)
|
| 207 |
+
if len(extracted_items) == 1:
|
| 208 |
+
potential_dir = os.path.join(extract_dir, extracted_items[0])
|
| 209 |
+
if os.path.isdir(potential_dir):
|
| 210 |
+
vectorstore_dir = potential_dir
|
| 211 |
+
else:
|
| 212 |
+
vectorstore_dir = extract_dir
|
| 213 |
+
else:
|
| 214 |
+
vectorstore_dir = extract_dir
|
| 215 |
+
|
| 216 |
+
source_store = FAISS.load_local(
|
| 217 |
+
vectorstore_dir, embeddings, allow_dangerous_deserialization=True
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
source_store = FAISS.load_local(
|
| 221 |
+
tmp_filename, embeddings, allow_dangerous_deserialization=True
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if not self.vectorstore:
|
| 225 |
+
raise ValueError("Vector store is not initialized. Please create or load a vector store first.")
|
| 226 |
+
|
| 227 |
+
self.vectorstore.merge_from(source_store)
|
| 228 |
+
print("Successfully merged the source vector store into the current vector store.")
|
| 229 |
+
except Exception as e:
|
| 230 |
+
raise Exception(f"Error merging vectorstore: {str(e)}")
|
| 231 |
+
finally:
|
| 232 |
+
if temp_created and os.path.exists(tmp_filename):
|
| 233 |
+
os.remove(tmp_filename)
|
| 234 |
+
return {"message": "Vector stores merged successfully"}
|