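'''Route services : scrape web pages and PDFs into Milvus collections and answer queries over them with Groq.'''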
import json
import numpy as np
from scripts.llm.runner import run_groq
from scripts.scrapper.page import page_to_docs
from scripts.routers.services import hash_url , clean_redis
from scripts.llm.services import save_history , load_history
from scripts.scrapper.pdf import pdf_to_docs , pdf_file_to_docs
# ! --------------------------------------Typing Annotations--------------------------------------
from logging import Logger
from sentence_transformers import SentenceTransformer
from pymilvus import MilvusClient
from redis import Redis
from groq import Groq
async def add_to_milvus(
documents : list ,
milvus_client : MilvusClient ,
embedding_model : SentenceTransformer ,
url_prefix : int ,
url : str ,
api_key : str
) -> list :
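    '''
    Embed chunk texts and insert them into the Milvus collection named after
    the API key. Each chunk gets the deterministic id
    url_prefix * 100_000 + chunk position (1-based) , so a URL's chunks can be
    looked up and deleted later. Returns the list of inserted ids.
    '''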
    # Only create the collection on first use ; recreating an existing one can raise in pymilvus
    if not milvus_client.has_collection(api_key) : milvus_client.create_collection(collection_name = api_key , dimension = 384)
texts = [document['text'] for document in documents]
embeddings : np.ndarray = embedding_model.encode(texts[ : 100_000] , show_progress_bar = True)
ids = []
chunk_counter = 1
for document , embedding in zip(documents , embeddings) :
id_ = url_prefix * 100_000 + chunk_counter
chunk_counter += 1
ids.append(id_)
data = {
'id' : id_ ,
'vector' : embedding
}
        data.update(document)
milvus_client.insert(
collection_name = api_key ,
data = [data]
)
    if len(documents) > 100_000 : print(f'Warning : Document from {url} had {len(documents)} chunks , but only the first 100_000 were embedded')
return ids
async def scrape_page_route(
url : str ,
api_key : str ,
logger : Logger ,
embedding_model : SentenceTransformer ,
milvus_client : MilvusClient ,
image_model ,
url_redis_client : Redis ,
scrape_images : bool = False
) -> None :
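    '''
    Scrape a web page into chunk documents , clear any previous entries for the
    URL , embed the chunks into Milvus and record the inserted ids in Redis
    under the URL.
    '''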
url_prefix : int = await hash_url(url)
_ : None = await clean_redis(
url ,
url_redis_client ,
milvus_client ,
api_key
)
documents : list = await page_to_docs(url , scrape_images , image_model)
    logger.info(f'Added {len(documents)} documents for {api_key} from {url}')
ids : list = await add_to_milvus(documents , milvus_client , embedding_model , url_prefix , url , api_key)
url_redis_client.set(url , json.dumps(ids))
async def scrape_pdf_route(
url : str ,
api_key : str ,
logger : Logger ,
embedding_model : SentenceTransformer ,
milvus_client : MilvusClient ,
image_model ,
url_redis_client : Redis ,
scrape_images : bool = False
) -> None :
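    '''
    Same flow as scrape_page_route , but the documents come from a PDF fetched
    from the URL.
    '''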
url_prefix : int = await hash_url(url)
_ : None = await clean_redis(
url ,
url_redis_client ,
milvus_client ,
api_key
)
documents : list = await pdf_to_docs(url , scrape_images , image_model)
    logger.info(f'Added {len(documents)} documents for {api_key} from {url}')
ids : list = await add_to_milvus(documents , milvus_client , embedding_model , url_prefix , url , api_key)
url_redis_client.set(url , json.dumps(ids))
async def scrape_pdf_file_route(
filename : str ,
api_key : str ,
logger : Logger ,
contents : bytes ,
embedding_model : SentenceTransformer ,
milvus_client : MilvusClient ,
url_redis_client : Redis ,
) -> None :
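    '''
    Persist an uploaded PDF under assets/pdfs , convert it to chunk documents ,
    embed them into Milvus and record the inserted ids in Redis under the
    filename.
    '''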
url_prefix : int = await hash_url(filename)
    # The other routes pass api_key so clean_redis can drop stale Milvus rows ; do the same here
    _ : None = await clean_redis(
        filename ,
        url_redis_client ,
        milvus_client ,
        api_key
    )
filename = f'assets/pdfs/{filename}'
with open(filename , 'wb') as pdf_file : pdf_file.write(contents)
documents : list = await pdf_file_to_docs(filename)
    logger.info(f'Added {len(documents)} documents for {api_key} from {filename}')
ids : list = await add_to_milvus(documents , milvus_client , embedding_model , url_prefix , filename , api_key)
url_redis_client.set(filename , json.dumps(ids))
async def ask_route(
query : str ,
session_id : str ,
api_key : str ,
logger : Logger ,
embedding_model : SentenceTransformer ,
milvus_client : MilvusClient ,
chat_redis_client : Redis ,
groq_client : Groq ,
) -> str :
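    '''
    Answer a query with retrieval augmented generation : embed the query ,
    fetch the two closest chunks from the caller's Milvus collection , prepend
    them as context to the session's chat history and run it through Groq.
    Persists the updated history and returns the model response.
    '''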
query_embeddings = embedding_model.encode(query)
results : list = milvus_client.search(
collection_name = api_key ,
data = [query_embeddings] ,
limit = 2 ,
output_fields = ['text' , 'source']
)[0]
    context = '\n'.join([f'''Content : {row['entity']['text']}\nSource : {row['entity']['source']}''' for row in results])
with open('assets/database/prompt/rag.md') as rag_prompt_file : prompt = rag_prompt_file.read()
history : list = await load_history(chat_redis_client , session_id)
if history == [] : history = [
{
'role' : 'system' ,
'content' : prompt
}
]
history.append({
'role' : 'user' ,
'content' : f'''
Context : {context}
Query : {query}
'''
})
response : str = await run_groq(history , groq_client)
history.append({
'role' : 'assistant' ,
'content' : response
})
    logger.info(f'Answered : {response} : for {api_key} for {query}')
await save_history(chat_redis_client , history , session_id)
return response