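"""Route helpers for a RAG service: scrape pages and PDFs into Milvus, then answer queries via Groq."""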
import json
import numpy as np
from scripts.llm.runner import run_groq
from scripts.scrapper.page import page_to_docs
from scripts.routers.services import hash_url, clean_redis
from scripts.llm.services import save_history, load_history
from scripts.scrapper.pdf import pdf_to_docs, pdf_file_to_docs
# ! --------------------------------------Typing Annotations--------------------------------------
from logging import Logger
from sentence_transformers import SentenceTransformer
from pymilvus import MilvusClient
from redis import Redis
from groq import Groq
async def add_to_milvus(
    documents: list,
    milvus_client: MilvusClient,
    embedding_model: SentenceTransformer,
    url_prefix: int,
    url: str,
    api_key: str
) -> list:
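    """Embed up to 100_000 document chunks and insert them into the caller's Milvus collection.

    Returns the list of generated row ids so the caller can map the url to its rows in Redis.
    """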
    # Guard against re-creating an existing collection; the dimension must match
    # the embedding model's output size (384 here).
    if not milvus_client.has_collection(collection_name=api_key):
        milvus_client.create_collection(collection_name=api_key, dimension=384)
    texts = [document['text'] for document in documents]
    # Embed at most 100_000 chunks; zip below truncates documents to match.
    embeddings: np.ndarray = embedding_model.encode(texts[:100_000], show_progress_bar=True)
    ids = []
    # The original loop incremented the chunk counter twice per document,
    # skipping every other id; enumerate gives each chunk a consecutive index.
    for chunk_counter, (document, embedding) in enumerate(zip(documents, embeddings), start=1):
        # Build a unique row id from the url hash and the 1-based chunk index.
        id_ = url_prefix * 100_000 + chunk_counter
        ids.append(id_)
        data = {'id': id_, 'vector': embedding}
        data.update(document)  # carries over 'text', 'source', and any other fields
        milvus_client.insert(collection_name=api_key, data=[data])
    if len(documents) > 100_000:
        print(f'Warning: document from {url} had {len(documents)} chunks, but only the first 100_000 were processed')
    return ids
async def scrape_page_route(
    url: str,
    api_key: str,
    logger: Logger,
    embedding_model: SentenceTransformer,
    milvus_client: MilvusClient,
    image_model,
    url_redis_client: Redis,
    scrape_images: bool = False
) -> None:
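    """Scrape a web page, index its chunks in Milvus, and record the row ids in Redis."""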
    url_prefix: int = await hash_url(url)
    # Drop any rows previously indexed for this url before re-inserting.
    await clean_redis(url, url_redis_client, milvus_client, api_key)
    documents: list = await page_to_docs(url, scrape_images, image_model)
    logger.info(f'Added {len(documents)} documents for {api_key} from {url}')
    ids: list = await add_to_milvus(documents, milvus_client, embedding_model, url_prefix, url, api_key)
    # Remember which Milvus ids belong to this url so they can be cleaned up later.
    url_redis_client.set(url, json.dumps(ids))
async def scrape_pdf_route(
    url: str,
    api_key: str,
    logger: Logger,
    embedding_model: SentenceTransformer,
    milvus_client: MilvusClient,
    image_model,
    url_redis_client: Redis,
    scrape_images: bool = False
) -> None:
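    """Download a PDF from a url, index its chunks in Milvus, and record the row ids in Redis."""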
    url_prefix: int = await hash_url(url)
    await clean_redis(url, url_redis_client, milvus_client, api_key)
    documents: list = await pdf_to_docs(url, scrape_images, image_model)
    logger.info(f'Added {len(documents)} documents for {api_key} from {url}')
    ids: list = await add_to_milvus(documents, milvus_client, embedding_model, url_prefix, url, api_key)
    url_redis_client.set(url, json.dumps(ids))
async def scrape_pdf_file_route(
    filename: str,
    api_key: str,
    logger: Logger,
    contents: bytes,
    embedding_model: SentenceTransformer,
    milvus_client: MilvusClient,
    url_redis_client: Redis,
) -> None:
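    """Persist an uploaded PDF to disk, index its chunks in Milvus, and record the row ids in Redis."""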
    url_prefix: int = await hash_url(filename)
    # Assumed fix: the original omitted api_key here, unlike the other clean_redis call sites.
    await clean_redis(filename, url_redis_client, milvus_client, api_key)
    filename = f'assets/pdfs/{filename}'
    with open(filename, 'wb') as pdf_file:
        pdf_file.write(contents)
    documents: list = await pdf_file_to_docs(filename)
    logger.info(f'Added {len(documents)} documents for {api_key} from {filename}')
    ids: list = await add_to_milvus(documents, milvus_client, embedding_model, url_prefix, filename, api_key)
    url_redis_client.set(filename, json.dumps(ids))
async def ask_route(
    query: str,
    session_id: str,
    api_key: str,
    logger: Logger,
    embedding_model: SentenceTransformer,
    milvus_client: MilvusClient,
    chat_redis_client: Redis,
    groq_client: Groq,
) -> str:
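    """Answer a query with RAG: retrieve the closest chunks from Milvus, build a prompt,
    run it through Groq, and persist the chat history in Redis."""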
    query_embeddings = embedding_model.encode(query)
    # Retrieve the two nearest chunks from the caller's collection.
    results: list = milvus_client.search(
        collection_name=api_key,
        data=[query_embeddings],
        limit=2,
        output_fields=['text', 'source']
    )[0]
    context = '\n'.join(
        f"Content: {row['entity']['text']}\nSource: {row['entity']['source']}"
        for row in results
    )
    with open('assets/database/prompt/rag.md') as rag_prompt_file:
        prompt = rag_prompt_file.read()
    history: list = await load_history(chat_redis_client, session_id)
    # Seed a fresh session with the system prompt.
    if not history:
        history = [{'role': 'system', 'content': prompt}]
    history.append({
        'role': 'user',
        'content': f'''
Context: {context}
Query: {query}
'''
    })
    response: str = await run_groq(history, groq_client)
    history.append({
        'role': 'assistant',
        'content': response
    })
    logger.info(f'Answered: {response} for {api_key} for {query}')
    await save_history(chat_redis_client, history, session_id)
    return response