ashishbangwal committed on
Commit
6ce472c
·
1 Parent(s): 9178a11
Files changed (8) hide show
  1. .gitignore +1 -0
  2. Dockerfile +15 -0
  3. components/Database.py +136 -0
  4. components/LLM.py +181 -0
  5. components/prompts.json +3 -0
  6. components/utils.py +173 -0
  7. main.py +235 -0
  8. requirements.txt +108 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .venv
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as a parent image
FROM python:3.10.9

# Set the working directory in the container
WORKDIR /app

# Copy only requirements.txt first so the dependency-install layer below is
# cached and reused unless requirements.txt itself changes
COPY requirements.txt /app

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the current directory contents into the container at /app
COPY . /app

# Serve the FastAPI application on all interfaces, port 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
components/Database.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
+ Contains a wrapper class for the ChromaDB client that can process and store documents and retrieve document chunks.
+ """
4
+
5
+ from io import BytesIO
6
+ from typing import List
7
+ from typing_extensions import Literal
8
+ import uuid
9
+ import warnings
10
+ import chromadb
11
+ import re
12
+ from .utils import (
13
+ generate_file_id,
14
+ chunk_document,
15
+ generate_embedding,
16
+ extract_content_from_docx,
17
+ extract_content_from_pdf,
18
+ )
19
+
20
+
21
class AdvancedClient:
    """Wrapper around a persistent ChromaDB client.

    Each uploaded file becomes one collection named after a content-derived
    file ID; ``retrieve_chunks`` then queries the currently selected
    collections and merges the closest chunks across them.
    """

    # Embedding model identifier forwarded to generate_embedding().
    # NOTE(review): generate_embedding() currently ignores this argument and
    # always calls Together's BGE model — confirm before relying on it.
    EMBEDDING_MODEL = "znbang/bge:small-en-v1.5-q8_0"

    def __init__(self, vector_database_path: str = "vectorDB") -> None:
        self.client = chromadb.PersistentClient(path=vector_database_path)
        # Names of collections already persisted on disk.
        self.exsisting_collections = [
            collection.name for collection in self.client.list_collections()
        ]
        # Collection names the next retrieve_chunks() call should search.
        self.selected_collections = []

    def create_or_get_collection(
        self,
        file_names: List[str],
        file_types: List[Literal["pdf", "docx"]],
        file_datas,
    ):
        """Ingest files (creating collections for unseen ones) and select them.

        Files are identified by a content-derived hash, so re-uploading a
        known file just re-selects its existing collection.

        Raises:
            Exception: if a file type other than pdf/docx is supplied.
        """
        selected = []
        for file_name, file_type, file_data in zip(file_names, file_types, file_datas):
            file_id = generate_file_id(file_bytes=file_data)

            if file_id not in self.exsisting_collections:
                # Parse and embed BEFORE creating the collection so a failed
                # upload cannot leave an empty collection behind (the original
                # created the collection first).
                if file_type == "pdf":
                    document, pil_images = extract_content_from_pdf(BytesIO(file_data))
                    chunks = chunk_document(document)
                    metadatas = []
                    for chunk in chunks:
                        imgs_found = re.findall(
                            pattern=r"<img\s+src='([^']*)'>", string=chunk
                        )
                        # Placeholders reference indices into pil_images.
                        chunk_imgs = [pil_images[int(idx)] for idx in imgs_found]
                        metadatas.append(
                            {"images": str(chunk_imgs), "file_name": file_name}
                        )
                elif file_type == "docx":
                    document = extract_content_from_docx(BytesIO(file_data))
                    chunks = chunk_document(document)
                    metadatas = [{"file_name": file_name} for _ in chunks]
                else:
                    raise Exception(
                        f"Given format '.{file_type}' is currently not supported."
                    )

                ids = [f"{uuid.uuid4()}_id_{x}" for x in range(1, len(chunks) + 1)]
                embeddings = generate_embedding(
                    chunks, embedding_model=self.EMBEDDING_MODEL
                )

                collection = self.client.create_collection(name=file_id)
                collection.add(
                    ids=ids,
                    embeddings=embeddings,  # type: ignore
                    documents=chunks,
                    metadatas=metadatas,  # type: ignore
                )
                # Track the new name so a duplicate file in the same batch is
                # not re-created (create_collection raises on a name clash).
                self.exsisting_collections.append(file_id)

            selected.append(file_id)

        self.selected_collections = selected

    def retrieve_chunks(self, query: str, number_of_chunks: int = 3):
        """Return the ``number_of_chunks`` closest chunks across the selected
        collections (smaller distance = more similar)."""
        if len(self.selected_collections) == 0:
            warnings.warn(
                message=f"No collection is selected using all the exsisting collections, total collections : {len(self.exsisting_collections)}"
            )
            # NOTE(review): falls back to a collection literally named
            # "UNION"; nothing in this file creates it — confirm it exists.
            collections = [self.client.get_collection("UNION")]
            self.selected_collections = collections
        else:
            collections = [
                self.client.get_collection(collection_name)
                for collection_name in self.selected_collections
            ]

        query_emb = generate_embedding(
            [query], embedding_model=self.EMBEDDING_MODEL
        )

        retrieved_docs = []
        for collection in collections:
            # BUG FIX: n_results was hard-coded to 5, so number_of_chunks > 5
            # could never be honored. Fetching number_of_chunks per collection
            # is sufficient: the global top-k is always contained in the union
            # of the per-collection top-k results.
            results = collection.query(
                query_embeddings=query_emb,
                n_results=number_of_chunks,
                include=["documents", "metadatas", "distances"],
            )
            for i in range(len(results["ids"][0])):
                retrieved_docs.append(
                    {
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                        "collection": collection.name,
                    }
                )

        retrieved_docs.sort(key=lambda x: x["distance"])
        return retrieved_docs[:number_of_chunks]
components/LLM.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contain Classes for LLM inference for RAG pipeline.
3
+ """
4
+
5
+ ### ** Make input output tokens as class properties **
6
+
7
+ from openai import OpenAI
8
+ from .utils import count_tokens
9
+ import json
10
+
11
+
12
class rLLM:
    """RAG-oriented LLM wrapper over Together's OpenAI-compatible API."""

    def __init__(self, llm_name: str, api_key: str) -> None:
        self.llm_name = llm_name
        self.llm_client = OpenAI(
            api_key=api_key, base_url="https://api.together.xyz/v1"
        )
        # System prompt for the main RAG answer generator, loaded once.
        with open("components/prompts.json", "r") as file:
            self.sys_prompt = json.load(file)["SysPrompt"]

    def _build_messages(self, system_prompt, message_history, user_content):
        """Assemble a chat-message list: system prompt, then up to five prior
        turns (the newest history entry is excluded — the caller re-sends it
        as ``user_content``), then the new user message."""
        messages = [{"role": "system", "content": system_prompt}]
        for message in message_history[-6:-1]:
            if message["role"] == "assistant":
                # Re-wrap assistant turns so only role/content are forwarded.
                messages.append({"role": "assistant", "content": message["content"]})
            else:
                messages.append(message)
        messages.append({"role": "user", "content": user_content})
        return messages

    def generate_rag_response(self, context: str, prompt: str, message_history):
        """Stream a grounded answer for ``prompt`` given retrieved ``context``.

        Yields ``(0, text_chunk)`` while streaming, then a final
        ``(1, (full_text, input_token_count, output_token_count))``.
        """
        # BUG FIX: the context and the QUERY label were glued together
        # ("...{context}QUERY:"); use the same "\n\n" separator as
        # generate_rag_chat_response.
        messages = self._build_messages(
            self.sys_prompt,
            message_history,
            f"CONTEXT:\n{context}\n\nQUERY:\n{prompt}",
        )

        stream = self.llm_client.chat.completions.create(
            model=self.llm_name,
            messages=messages,
            stream=True,
        )

        output = ""
        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                output += content
                yield 0, content

        input_token_count = count_tokens(
            string="\n".join([i["content"] for i in messages])
        )
        output_token_count = count_tokens(string=output)

        yield 1, (output, input_token_count, output_token_count)

    def HyDE(self, query: str, message_history):
        """Generate a hypothetical answer for ``query`` (HyDE): the richer
        text is embedded instead of the raw query to improve retrieval."""
        system_prompt = """You are an AI assistant specifically designed to generate hypothetical answers for semantic search. Your primary function is to create concise Maximum 100-150 words, informative, and relevant responses to user queries. Make sure to capture the original intent of the user query (by including keywords present in user query) as these responses will be used to generate embeddings for improved semantic search results.
"""
        messages = self._build_messages(
            system_prompt, message_history, f"\n\nQUERY:\n{query}"
        )

        response = self.llm_client.chat.completions.create(
            model="meta-llama/Llama-3-8b-chat-hf",
            messages=messages,
            max_tokens=500,
        )
        return response.choices[0].message.content

    ### NOT IN USE

    def generate_rag_chat_response(self, context: str, prompt: str, message_history):
        """NOT IN USE CURRENTLY.

        Like generate_rag_response but with a chat-style system prompt that
        asks for short conversational answers. Same (flag, payload) yields.
        """
        system_prompt = """You are a helpful legal compliance CHAT assistant designed to answer and resolve user query in chat format hence quick and *small* responses.

Instructions:
1. Use the provided CONTEXT to inform your responses, citing specific parts when relevant.
2. If unable to answer the QUERY, politely inform the user and suggest what additional information might help.
3. Give a small/chatty format response.
4. Try to give decisive responses that can help user to make informed decision.
5. Format responses for readability, include bold words to give weightage.

Don't add phrases like "According to provided context.." etcectra. """

        messages = self._build_messages(
            system_prompt,
            message_history,
            f"CONTEXT:\n{context}\n\nQUERY:\n{prompt}",
        )

        stream = self.llm_client.chat.completions.create(
            model=self.llm_name,
            messages=messages,
            stream=True,
        )

        output = ""
        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                output += content
                yield 0, content

        input_token_count = count_tokens(
            string="\n".join([i["content"] for i in messages])
        )
        output_token_count = count_tokens(string=output)

        yield 1, (output, input_token_count, output_token_count)

    def rephrase_query(self, query: str, message_history):
        """NOT IN USE CURRENTLY.

        Rephrase/rewrite the user query to include more semantics, improving
        semantic-search-based retrieval. Returns the rewritten query text.
        """
        system_prompt = """You are an AI assistant specifically designed to rewrite the user QUERY for semantic search. Your primary function is to create more comprehensive, semantically rich query *while maintaining the original intent of the user*. These responses will be used to generate embeddings for improved semantic search results.
Do not include any other comments or text, other than the repharased/rewritten query.
"""
        messages = self._build_messages(
            system_prompt, message_history, f"\n\nQUERY:\n{query}"
        )

        response = self.llm_client.chat.completions.create(
            model="meta-llama/Llama-3-8b-chat-hf",
            messages=messages,
        )
        return response.choices[0].message.content
components/prompts.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "SysPrompt": "You are an advanced AI assistant specializing in Retrieval-Augmented Generation (RAG) for question answering. Your primary function is to provide accurate, relevant, and contextually appropriate answers to user queries by leveraging both your general knowledge and the CONTEXT retrieved from a given knowledge base.\n \n## Your Responsibilities:\n1. Carefully analyze user query to understand the intent and context.\n2. Synthesize provided CONTEXT with your general knowledge to formulate comprehensive answers.\n4. Clearly distinguish between information from the knowledge base and your general knowledge or inferences.\n5. When uncertain, acknowledge the limitations of your knowledge or the available information.\n6. Maintain a professional, helpful, and courteous demeanor in all interactions.\n\n## Guidelines for Responses:\n1. Provide concise answers for straightforward questions, and more detailed explanations for complex topics.\n2. Use bullet points or numbered lists for clarity when presenting multiple pieces of information.\n3. Include relevant examples or analogies to aid understanding when appropriate.\n4. If the query cannot be fully answered with the available information, explain what is known and what remains uncertain.\n\nRemember, your goal is to assist users by providing accurate, helpful, and contextually relevant information. Always strive to enhance the user's understanding of the topic at hand."
3
+ }
components/utils.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
+ Contains utility functions for the LLM and Database modules, along with some other miscellaneous functions.
+ """
4
+
5
+ from pymupdf import pymupdf
6
+ from docx import Document
7
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+ import tiktoken
9
+ import base64
10
+ import hashlib
11
+ import ollama
12
+ from typing import List
13
+ from openai import OpenAI
14
+ import os
15
+
16
+ TOGETHER_API = str(os.getenv("TOGETHER_API_KEY"))
17
+
18
+
19
def get_preview_pdf(file_bytes: bytes):
    """Return the first 3 pages of a PDF as raw PDF bytes.

    Args:
        file_bytes: the complete PDF file content.
    """
    doc = pymupdf.open(stream=file_bytes, filetype="pdf")
    try:
        sliced_doc = pymupdf.open()
        try:
            # Clamp the range so PDFs shorter than 3 pages are handled;
            # close both documents to avoid leaking native resources.
            sliced_doc.insert_pdf(doc, from_page=0, to_page=min(2, doc.page_count - 1))
            return sliced_doc.tobytes()
        finally:
            sliced_doc.close()
    finally:
        doc.close()
27
+
28
+
29
def count_tokens(string: str) -> int:
    """Count the cl100k_base (tiktoken) tokens in *string*."""
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text=string)
    return len(tokens)
34
+
35
+
36
def create_refrences(retrieved_docs):
    """Build a markdown 'references' section from retrieved chunks.

    Each entry shows the source file name and the chunk text; placeholders of
    the form <img src='N'> are replaced with inline base64 image data when the
    chunk metadata carries image bytes.
    """
    import ast

    refrences = ""
    for doc in retrieved_docs:
        # Metadata stores images as the repr of a list of bytes objects.
        # SECURITY FIX: parse it back with ast.literal_eval — eval() would
        # execute arbitrary code if metadata were ever attacker-controlled.
        try:
            chunk_imgs = ast.literal_eval(doc["metadata"]["images"])
        except (KeyError, ValueError, SyntaxError):
            # No "images" key (e.g. docx chunks) or unparsable payload.
            chunk_imgs = None
        chunk = doc["document"]

        if chunk_imgs:
            chunk_split = chunk.split("<img src='")
            chunk_with_img = ""

            if len(chunk_split) > 1:
                # NOTE(review): this reassembly looks correct only for a
                # single image per chunk; with multiple images the middle
                # segments repeat — confirm before relying on multi-image chunks.
                for i in range(0, len(chunk_split) - 1):
                    img_bytes = chunk_imgs[i]
                    base64_str = base64.b64encode(img_bytes).decode("utf-8")
                    chunk_with_img += (
                        chunk_split[i].strip()
                        + f"\n<img src='data:image/png;base64,{base64_str}'>\n"
                        + chunk_split[i + 1][3:]
                    )
            else:
                chunk_with_img = chunk

            refrences += (
                f"###### {doc['metadata']['file_name']}\n\n{chunk_with_img}\n\n"
            )
        else:
            chunk = doc["document"]
            refrences += f"###### {doc['metadata']['file_name']}\n\n{chunk}\n\n**Distance : {doc['distance']}**\n\n"

    return refrences
71
+
72
+
73
def generate_file_id(file_bytes):
    """Derive a deterministic ID for a file from its content.

    NOTE(review): only the leading 4096 bytes are hashed, so files sharing a
    4 KiB prefix collide — confirm this is acceptable for deduplication. The
    digest is truncated to 63 hex characters (presumably to satisfy a
    collection-name length limit; verify against the vector store).
    """
    digest = hashlib.sha256(file_bytes[:4096]).hexdigest()
    return str(digest[:63])
80
+
81
+
82
def extract_content_from_docx(docx_content):
    """Extract plain text from a DOCX file: paragraph texts joined by newlines."""
    doc = Document(docx_content)
    paragraph_texts = [paragraph.text for paragraph in doc.paragraphs]
    return "\n".join(paragraph_texts)
90
+
91
+
92
def extract_content_from_pdf(pdf_content):
    """Extract content (Image + text) from PDF files.

    Returns a tuple (DOCUMENT, pil_images): DOCUMENT is the page text in
    top-to-bottom order with inline placeholders of the form <img src='N'>
    where N indexes into pil_images; pil_images is a list of raw image byte
    strings (despite the name, these are bytes from extract_image, not PIL
    objects).
    """
    doc = pymupdf.open(stream=pdf_content, filetype="pdf")
    DOCUMENT = ""
    pil_images = []

    for page in doc:

        blocks = page.get_text_blocks()  # type: ignore
        images = page.get_images()  # type: ignore

        # Create a list of all elements (text blocks and images) with their positions
        # Each element is (bbox, payload, kind): text payload is the block text,
        # image payload is the image xref.
        elements = [(block[:4], block[4], "text") for block in blocks]

        img_list = []
        for img in images:
            try:
                # First placement rectangle of this image on the page.
                img_bbox = page.get_image_rects(img[0])[0]  # type: ignore
                if len(img_bbox) > 0:
                    img_data = (img_bbox, img[0], "image")
                    img_list.append(img_data)
                else:
                    continue
            except Exception as e:
                # e.g. the image has no placement rect on this page; skip it.
                print("Exception :", e)
                pass

        elements.extend(img_list)

        # Sort elements by their vertical position (top coordinate)
        elements.sort(key=lambda x: x[0][1])

        for element in elements:
            if element[2] == "text":
                DOCUMENT += element[1]
            else:
                # Resolve the image xref to raw bytes and emit a placeholder
                # pointing at its index in pil_images.
                xref = element[1]
                base_image = doc.extract_image(xref)
                image_bytes = base_image["image"]

                # Save the image
                image = image_bytes
                pil_images.append(image)
                DOCUMENT += f"\n<img src='{len(pil_images)-1}'>\n\n"
    return DOCUMENT, pil_images
138
+
139
+
140
def chunk_document(document, chunk_size=200, overlap=10, encoding_name="cl100k_base"):
    """Split a document into token-bounded chunks (recursive splitting).

    Args:
        document: the full text to split.
        chunk_size: max tokens per chunk.
        overlap: tokens of overlap between consecutive chunks.
        encoding_name: tiktoken encoding used to measure tokens.
    """
    # BUG FIX: from_tiktoken_encoder is a classmethod — calling it on a
    # configured instance silently discarded that instance's separators and
    # keep_separator settings. Pass them through so they actually take effect.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name=encoding_name,
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
        keep_separator=True,
    )
    return splitter.split_text(document)
150
+
151
+
152
def generate_embedding_ollama(
    texts: List[str], embedding_model: str
) -> List[List[float]]:
    """Embed each text with a local Ollama model, one request per text."""
    return [
        list(ollama.embeddings(model=embedding_model, prompt=text)["embedding"])
        for text in texts
    ]
163
+
164
+
165
def generate_embedding(texts: List[str], embedding_model: str) -> List[List[float]]:
    """Generate embeddings for the given texts via Together's OpenAI-compatible API.

    NOTE(review): the ``embedding_model`` argument is ignored — the model is
    hard-coded to "BAAI/bge-large-en-v1.5". Callers pass an Ollama-style model
    name that would not be valid here; confirm intent before wiring it through.
    """
    # A new client is built per call; TOGETHER_API comes from the environment.
    client = OpenAI(api_key=TOGETHER_API, base_url="https://api.together.xyz/v1")
    embeddings_response = client.embeddings.create(
        input=texts, model="BAAI/bge-large-en-v1.5"
    ).data
    embeddings = [i.embedding for i in embeddings_response]
    return embeddings
main.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.responses import StreamingResponse
2
+ from fastapi import FastAPI, HTTPException
3
+ import os
4
+ import base64
5
+
6
+ from pydantic import BaseModel, Field
7
+ from typing import List, Dict
8
+ from typing_extensions import Literal
9
+
10
+ import logging
11
+ import sqlite3
12
+ import time
13
+ import asyncio
14
+
15
+ from components.LLM import rLLM
16
+ from components.Database import AdvancedClient
17
+ from components.utils import create_refrences
18
+
19
+ # LLM API key
20
+ TOGETHER_API = str(os.getenv("TOGETHER_API_KEY"))
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.WARNING,
25
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
26
+ handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+
30
+ app = FastAPI()
31
+
32
+ # SQLite setup
33
+ DB_PATH = "data/conversations.db"
34
+
35
+ # In-memory storage for conversations
36
+ conversations: Dict[str, List[Dict[str, str]]] = {}
37
+ COLLECTIONS: Dict[str, List[str]] = {}
38
+ last_activity: Dict[str, float] = {}
39
+
40
+
41
+ # initialize SQLite database
42
def init_db():
    """Create the data directory and the conversations table if missing."""
    logger.info("Initializing database")
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    try:
        c = conn.cursor()
        # BUG FIX: the original was missing the comma after "lastmessage TEXT",
        # so SQLite parsed "timestamp DATETIME" as part of the lastmessage
        # column's type and the timestamp column was never created.
        # NOTE(review): CREATE TABLE IF NOT EXISTS will not repair an already
        # created table — existing databases keep the old schema.
        c.execute(
            """CREATE TABLE IF NOT EXISTS conversations
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      conversation_id TEXT,
                      collections TEXT,
                      lastmessage TEXT,
                      timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)"""
        )
        conn.commit()
    finally:
        # Close even if table creation fails.
        conn.close()
    logger.info("Database initialized successfully")
58
+
59
+
60
+ init_db()
61
+
62
+
63
def update_db(conversation_id, collections, message):
    """Insert a row recording the conversation, its pipe-joined collection
    IDs, and the last message text."""
    logger.info(f"Updating database for conversation: {conversation_id}")
    connection = sqlite3.connect(DB_PATH)
    cursor = connection.cursor()
    row = (conversation_id, collections, message)
    cursor.execute(
        """INSERT INTO conversations (conversation_id, collections, lastmessage)
                 VALUES (?, ?, ?)""",
        row,
    )
    connection.commit()
    connection.close()
    logger.info("Database updated successfully")
75
+
76
+
77
def get_collection_from_db(conversation_id):
    """Return the stored collections string for a conversation, or None."""
    conn = sqlite3.connect(DB_PATH)
    try:
        # Connection.execute creates the cursor implicitly.
        row = conn.execute(
            """SELECT collections FROM conversations WHERE conversation_id = ?""",
            (conversation_id,),
        ).fetchone()
        return row[0] if row else None
    finally:
        conn.close()
92
+
93
+
94
async def clear_inactive_conversations():
    """Background task: every 60 seconds, drop in-memory state for
    conversations that have been idle for more than 30 minutes."""
    while True:
        logger.info("Clearing inactive conversations")
        current_time = time.time()
        inactive_convos = [
            conv_id
            for conv_id, last_time in last_activity.items()
            if current_time - last_time > 1800  # 30 minutes
        ]
        for conv_id in inactive_convos:
            conversations.pop(conv_id, None)
            last_activity.pop(conv_id, None)
            # BUG FIX: the original referenced lowercase `collections`, which
            # is not defined at module scope (NameError at runtime); the
            # per-conversation map is COLLECTIONS.
            COLLECTIONS.pop(conv_id, None)
        logger.info(f"Cleared {len(inactive_convos)} inactive conversations")
        await asyncio.sleep(60)  # Check every minute
112
+
113
+
114
@app.on_event("startup")
async def startup_event():
    """Launch the periodic cleanup loop when the application starts."""
    logger.info("Starting up the application")
    # Fire-and-forget: runs for the lifetime of the event loop.
    asyncio.create_task(clear_inactive_conversations())
118
+
119
+
120
class UploadedFiles(BaseModel):
    """Payload for /initiate_conversation: three parallel lists describing
    the uploaded files, consumed positionally (index i of each list refers
    to the same file)."""
    ConversationID: str = Field(examples=["123e4567-e89b-12d3-a456-426614174000"])
    FileNames: List[str] = Field(examples=[["file_1.pdf", "file_2.docx"]])
    FileTypes: List[Literal["pdf", "docx"]] = Field(examples=[["pdf", "docx"]])
    # Base64-encoded file contents; decoded server-side.
    FileData: List[str]


class UserInput(BaseModel):
    """Payload for /get_response: one query within an existing conversation."""
    ConversationID: str = Field(examples=["123e4567-e89b-12d3-a456-426614174000"])
    Query: str = Field(examples=["What is IT ACT 2000?"])


class ChunkResponse(BaseModel):
    """One streamed fragment of the model's answer."""
    chunk: str = Field(examples=["This is", "streaming"])


class CompletedResponse(BaseModel):
    """Final answer summary with token-usage counts."""
    FullResponse: str = Field(examples=["This is a complete response"])
    InputToken: int = Field(examples=[1024, 2048])
    OutputToken: int = Field(examples=[4096, 7000])
140
+
141
+
142
@app.post("/initiate_conversation")
async def get_conversation_id(files: UploadedFiles):
    """Register a conversation: decode the uploaded files, ingest them into
    the vector store, and remember the resulting collection IDs both in
    memory and in SQLite. Returns True on success."""
    # The three lists are consumed positionally; a length mismatch would
    # otherwise be silently truncated by zip() downstream.
    if not (len(files.FileNames) == len(files.FileTypes) == len(files.FileData)):
        raise HTTPException(
            status_code=400,
            detail="FileNames, FileTypes and FileData must have the same length.",
        )

    # Decoding bytes data. b64decode raises binascii.Error, a ValueError
    # subclass, on malformed input — report 400 instead of an unhandled 500.
    try:
        data = [base64.b64decode(b) for b in files.FileData]
    except ValueError:
        raise HTTPException(status_code=400, detail="FileData must be valid base64.")

    vector_db = AdvancedClient()
    vector_db.create_or_get_collection(
        file_names=files.FileNames,
        file_types=files.FileTypes,
        file_datas=data,
    )

    file_ids = vector_db.selected_collections

    # update in-memory data
    COLLECTIONS[files.ConversationID] = file_ids
    conversations[files.ConversationID] = []
    last_activity[files.ConversationID] = time.time()

    # update SQL data
    update_db(
        conversation_id=files.ConversationID,
        collections="|".join(file_ids),
        message="NONE",
    )
    return True
167
+
168
+
169
@app.post("/get_response")
async def get_response_streaming(user_query: UserInput):
    """Stream a RAG answer for the query, followed by a <REFRENCES> block.

    Looks up the conversation's collections in memory, falling back to
    SQLite; 404s if the conversation was never registered.
    """
    llm = rLLM(llm_name="meta-llama/Llama-3-8b-chat-hf", api_key=TOGETHER_API)

    conv_id = user_query.ConversationID
    collection_to_use = None
    if conv_id in COLLECTIONS:
        collection_to_use = COLLECTIONS[conv_id]
        last_activity[conv_id] = time.time()
    else:
        stored = get_collection_from_db(conv_id)
        if stored:
            collection_to_use = stored.split("|")

    if collection_to_use is None:
        # BUG FIX: the original *returned* the HTTPException (serialized as a
        # 200 body) and left collection_to_use unbound when the conversation
        # was missing from both memory and the DB.
        raise HTTPException(
            status_code=404,
            detail="Conversation ID does not exist, please register one with /initiate_conversation endpoint.",
        )

    vector_db = AdvancedClient()
    # update database to user conversation's documents
    vector_db.selected_collections = collection_to_use

    # setdefault aliases the stored list, so appends below are visible to it.
    conversation_history = conversations.setdefault(conv_id, [])

    rephrased_query = llm.HyDE(
        query=user_query.Query, message_history=conversation_history
    )

    retrieved_docs = vector_db.retrieve_chunks(query=rephrased_query)

    conversations[conv_id].append({"role": "user", "content": user_query.Query})

    context = ""
    for i, doc in enumerate(retrieved_docs, start=1):
        context += f"Refrence {i}\n\n" + doc["document"] + "\n\n"

    def streaming():
        """Relay model chunks; on completion, record the answer and append
        the references block to the stream."""
        for completed, chunk in llm.generate_rag_response(
            context=context,
            prompt=user_query.Query,
            message_history=conversation_history,
        ):
            if completed:
                full_response, input_token, output_token = chunk
                conversations[conv_id].append(
                    {"role": "assistant", "content": full_response}
                )
                logger.info(msg=f"Input:{input_token} \nOuptut:{output_token}")
                yield "\n\n<REFRENCES>\n" + create_refrences(
                    retrieved_docs
                ) + "\n</REFRENCES>"
            else:
                yield chunk

    return StreamingResponse(streaming(), media_type="text/event-stream")
requirements.txt ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.7.0
2
+ anyio==4.4.0
3
+ asgiref==3.8.1
4
+ backoff==2.2.1
5
+ bcrypt==4.2.0
6
+ build==1.2.1
7
+ cachetools==5.4.0
8
+ certifi==2024.7.4
9
+ charset-normalizer==3.3.2
10
+ chroma-hnswlib==0.7.6
11
+ chromadb==0.5.5
12
+ click==8.1.7
13
+ coloredlogs==15.0.1
14
+ Deprecated==1.2.14
15
+ distro==1.9.0
16
+ dnspython==2.6.1
17
+ email_validator==2.2.0
18
+ exceptiongroup==1.2.2
19
+ fastapi==0.112.0
20
+ fastapi-cli==0.0.5
21
+ filelock==3.15.4
22
+ flatbuffers==24.3.25
23
+ fsspec==2024.6.1
24
+ google-auth==2.32.0
25
+ googleapis-common-protos==1.63.2
26
+ grpcio==1.65.4
27
+ h11==0.14.0
28
+ httpcore==1.0.5
29
+ httptools==0.6.1
30
+ httpx==0.27.0
31
+ huggingface-hub==0.24.5
32
+ humanfriendly==10.0
33
+ idna==3.7
34
+ importlib_metadata==8.0.0
35
+ importlib_resources==6.4.0
36
+ Jinja2==3.1.4
37
+ jsonpatch==1.33
38
+ jsonpointer==3.0.0
39
+ kubernetes==30.1.0
40
+ langchain-core==0.2.28
41
+ langchain-text-splitters==0.2.2
42
+ langsmith==0.1.96
43
+ lxml==5.2.2
44
+ markdown-it-py==3.0.0
45
+ MarkupSafe==2.1.5
46
+ mdurl==0.1.2
47
+ mmh3==4.1.0
48
+ monotonic==1.6
49
+ mpmath==1.3.0
50
+ numpy==1.26.4
51
+ oauthlib==3.2.2
52
+ ollama==0.3.1
53
+ onnxruntime==1.18.1
54
+ openai==1.38.0
55
+ opentelemetry-api==1.26.0
56
+ opentelemetry-exporter-otlp-proto-common==1.26.0
57
+ opentelemetry-exporter-otlp-proto-grpc==1.26.0
58
+ opentelemetry-instrumentation==0.47b0
59
+ opentelemetry-instrumentation-asgi==0.47b0
60
+ opentelemetry-instrumentation-fastapi==0.47b0
61
+ opentelemetry-proto==1.26.0
62
+ opentelemetry-sdk==1.26.0
63
+ opentelemetry-semantic-conventions==0.47b0
64
+ opentelemetry-util-http==0.47b0
65
+ orjson==3.10.6
66
+ overrides==7.7.0
67
+ packaging==24.1
68
+ posthog==3.5.0
69
+ protobuf==4.25.4
70
+ pyasn1==0.6.0
71
+ pyasn1_modules==0.4.0
72
+ pydantic==2.8.2
73
+ pydantic_core==2.20.1
74
+ Pygments==2.18.0
75
+ PyMuPDF==1.24.9
76
+ PyMuPDFb==1.24.9
77
+ PyPika==0.48.9
78
+ pyproject_hooks==1.1.0
79
+ python-dateutil==2.9.0.post0
80
+ python-docx==1.1.2
81
+ python-dotenv==1.0.1
82
+ python-multipart==0.0.9
83
+ PyYAML==6.0.1
84
+ regex==2024.7.24
85
+ requests==2.32.3
86
+ requests-oauthlib==2.0.0
87
+ rich==13.7.1
88
+ rsa==4.9
89
+ shellingham==1.5.4
90
+ six==1.16.0
91
+ sniffio==1.3.1
92
+ starlette==0.37.2
93
+ sympy==1.13.1
94
+ tenacity==8.5.0
95
+ tiktoken==0.7.0
96
+ tokenizers==0.19.1
97
+ tomli==2.0.1
98
+ tqdm==4.66.5
99
+ typer==0.12.3
100
+ typing_extensions==4.12.2
101
+ urllib3==2.2.2
102
+ uvicorn==0.30.5
103
+ uvloop==0.19.0
104
+ watchfiles==0.22.0
105
+ websocket-client==1.8.0
106
+ websockets==12.0
107
+ wrapt==1.16.0
108
+ zipp==3.19.2