Spaces:
Runtime error
Runtime error
| import yaml | |
| import fitz | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import Chroma | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.document_loaders import PyPDFLoader | |
| from langchain.prompts import PromptTemplate | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| import spaces | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType | |
| from datasets import Dataset, load_from_disk | |
| import faiss | |
| import numpy as np | |
| from pastebin_api import get_protected_content | |
class RAGbot:
    """Persona-driven retrieval-augmented chatbot used behind a Gradio UI.

    A query is embedded, the nearest text chunks are fetched either from a
    Milvus/Zilliz collection or from a local FAISS index paired with a
    dataset loaded from disk, and a text-generation pipeline answers with the
    retrieved chunks and a persona text placed in the system prompt.

    Heavy resources (embedding model, generation pipeline, persona text,
    vector-store connection) are created lazily on first use so that a
    request does not fail merely because a ``load_*`` method was not called
    beforehand.
    """

    def __init__(self, config_path="config.yaml"):
        """Initialise default settings and read the YAML configuration.

        Args:
            config_path: Path to a YAML file providing the keys
                ``modelEmbeddings``, ``autoTokenizer``,
                ``autoModelForCausalLM``, ``zilliz`` and ``personaPasteKey``.

        Raises:
            OSError: If the configuration file cannot be opened.
            KeyError: If a required configuration key is missing.
        """
        self.processed = False
        self.page = 0
        self.chat_history = []
        self.prompt = None
        self.documents = None
        self.embeddings = None
        self.zilliz_vectordb = None
        self.hf_vectordb = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.chain = None
        self.chunk_size = 512
        self.overlap_percentage = 50
        self.max_chunks_in_context = 2
        self.current_context = None
        # The misspelt attribute names below are kept on purpose: code
        # outside this class may read them.
        self.model_temperatue = 0.5
        self.format_seperator = "\n\n--\n\n"
        self.pipe = None
        # Defined up front so the lazy-load check in create_organic_response
        # never hits a missing attribute.
        self.persona_text = None

        with open(config_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)

        self.model_embeddings = config["modelEmbeddings"]
        self.auto_tokenizer = config["autoTokenizer"]
        self.auto_model_for_causal_lm = config["autoModelForCausalLM"]
        self.zilliz_config = config["zilliz"]
        self.persona_paste_key = config["personaPasteKey"]

    def connect_to_zilliz(self):
        """Open a secure connection and bind the configured collection."""
        connections.connect(
            host=self.zilliz_config["host"],
            port=self.zilliz_config["port"],
            user=self.zilliz_config["user"],
            password=self.zilliz_config["password"],
            secure=True,
        )
        self.zilliz_vectordb = Collection(self.zilliz_config["collection"])

    def load_embeddings(self):
        """Instantiate the embedding model named in the configuration."""
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_embeddings)

    def load_hf_vectordb(self, dataset_path, index_path):
        """Load a dataset from disk and its FAISS index as a ``(dataset, index)`` pair."""
        dataset = load_from_disk(dataset_path)
        index = faiss.read_index(index_path)
        self.hf_vectordb = (dataset, index)

    def load_tokenizer(self):
        """Load the tokenizer named in the configuration."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.auto_tokenizer)

    def create_organic_pipeline(self):
        """Build the text-generation pipeline (bfloat16 weights, ``cuda`` device)."""
        self.pipe = pipeline(
            "text-generation",
            model=self.auto_model_for_causal_lm,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device="cuda",
        )

    @staticmethod
    def _row_to_text(row):
        """Return the text of one retrieved dataset row as a string."""
        if isinstance(row, str):
            return row
        if isinstance(row, dict):
            # NOTE(review): assumes the chunk text is stored under "text",
            # mirroring the field read from the Zilliz hits - confirm against
            # the dataset schema.
            text = row.get("text")
            if text is not None:
                return str(text)
        return str(row)

    def get_organic_context(self, query, use_hf=False):
        """Retrieve the chunks closest to ``query`` and store them joined.

        The joined text is written to ``self.current_context``; nothing is
        returned.

        Args:
            query: The user's question.
            use_hf: When true, search the local FAISS index held in
                ``self.hf_vectordb``; otherwise search the Zilliz collection.

        Raises:
            RuntimeError: If ``use_hf`` is true but no index has been loaded.
        """
        # The HF path never goes through process_file, so the embedding model
        # may not exist yet.
        if self.embeddings is None:
            self.load_embeddings()
        query_vector = self.embeddings.embed_query(query)

        if use_hf:
            if self.hf_vectordb is None:
                raise RuntimeError(
                    "No HuggingFace index loaded; call load_hf_vectordb() first"
                )
            dataset, index = self.hf_vectordb
            # FAISS expects a 2-D float32 array; a plain Python list of floats
            # would otherwise become float64.
            vectors = np.asarray([query_vector], dtype=np.float32)
            _distances, indices = index.search(vectors, self.max_chunks_in_context)
            # FAISS pads missing neighbours with -1; int() converts numpy
            # integers into plain row indices.
            chunks = [
                self._row_to_text(dataset[int(i)]) for i in indices[0] if i >= 0
            ]
        else:
            if self.zilliz_vectordb is None:
                self.connect_to_zilliz()
            result = self.zilliz_vectordb.search(
                data=[query_vector],
                anns_field="embeddings",
                param={"metric_type": "IP", "params": {"nprobe": 10}},
                limit=self.max_chunks_in_context,
                expr=None,
                # Without this the hits carry no "text" field to read back.
                output_fields=["text"],
            )
            chunks = [hit.entity.get('text') for hit in result[0]]

        self.current_context = self.format_seperator.join(chunks)

    def load_persona_data(self):
        """Fetch the persona YAML via ``get_protected_content`` and store its ``persona_text``."""
        persona_content = get_protected_content(self.persona_paste_key)
        persona_data = yaml.safe_load(persona_content)
        self.persona_text = persona_data["persona_text"]

    def create_organic_response(self, history, query, use_hf=False):
        """Generate an answer to ``query`` from retrieved context and persona.

        Args:
            history: Accepted for interface compatibility; not used.
            query: The user's question.
            use_hf: Forwarded to :meth:`get_organic_context`.

        Returns:
            The generated text with the prompt prefix stripped.
        """
        # The persona is part of the system prompt, so it must be available
        # before the prompt is built. It is fetched once and then reused;
        # call load_persona_data() explicitly to refresh it.
        if self.persona_text is None:
            self.load_persona_data()
        if self.pipe is None:
            self.create_organic_pipeline()

        self.get_organic_context(query, use_hf=use_hf)
        messages = [
            {"role": "system", "content": f"Based on the given context, answer the user's question while maintaining the persona:\n{self.persona_text}\n\nContext:\n{self.current_context}"},
            {"role": "user", "content": query},
        ]
        prompt = self.pipe.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # NOTE(review): the sampling temperature is fixed here and the
        # user-supplied self.model_temperatue is not consulted - confirm
        # whether that is intentional.
        temp = 0.1
        outputs = self.pipe(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=temp,
            top_p=0.9,
        )
        # The pipeline output starts with the prompt; keep only the new text.
        return outputs[0]["generated_text"][len(prompt):]

    def process_file(self, file):
        """Load the uploaded PDF and prepare embeddings and the Zilliz connection."""
        self.documents = PyPDFLoader(file.name).load()
        self.load_embeddings()
        self.connect_to_zilliz()

    def generate_response(self, history, query, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context, use_hf_index=False, hf_dataset_path=None, hf_index_path=None):
        """Answer ``query`` and append the reply to the last chat turn.

        Args:
            history: Chat history as a list of ``[user, bot]`` pairs (tuples
                are accepted and converted); modified in place.
            query: The user's question.
            file: Uploaded PDF; required unless ``use_hf_index`` is true.
            chunk_size: Stored on the instance.
            chunk_overlap_percentage: Stored on the instance.
            model_temperature: Stored on the instance.
            max_chunks_in_context: Number of chunks to retrieve.
            use_hf_index: Use the local dataset/FAISS index instead of Zilliz.
            hf_dataset_path: Dataset directory, required with ``use_hf_index``.
            hf_index_path: FAISS index file, required with ``use_hf_index``.

        Returns:
            A ``(history, "")`` tuple; the empty string is the second output
            value handed back to the caller.

        Raises:
            gr.Error: If the query, the file, or the HF paths are missing.
        """
        self.chunk_size = chunk_size
        self.overlap_percentage = chunk_overlap_percentage
        self.model_temperatue = model_temperature
        self.max_chunks_in_context = max_chunks_in_context

        if not query:
            raise gr.Error(message='Submit a question')

        if use_hf_index:
            if not hf_dataset_path or not hf_index_path:
                raise gr.Error(message='Provide HuggingFace dataset and index paths')
            self.load_hf_vectordb(hf_dataset_path, hf_index_path)
            result = self.create_organic_response(history="", query=query, use_hf=True)
        else:
            if not file:
                raise gr.Error(message='Upload a PDF')
            if not self.processed:
                self.process_file(file)
                self.processed = True
            result = self.create_organic_response(history="", query=query)

        # create_organic_response has already ensured the persona is loaded.
        result = f"{self.persona_text}\n\n{result}"

        # Rebuild the last turn as a list: the stored pair may be an immutable
        # tuple and its reply slot may be None. The reply is appended in one
        # step rather than character by character.
        if not history:
            history.append([query, ""])
        user_message, bot_message = history[-1]
        history[-1] = [user_message, (bot_message or "") + result]
        return history, ""

    def render_file(self, file, chunk_size, chunk_overlap_percentage, model_temperature, max_chunks_in_context):
        """Render page ``self.page`` of the PDF to a PIL image at 300 DPI.

        The four settings arguments are stored on the instance.
        """
        self.chunk_size = chunk_size
        self.overlap_percentage = chunk_overlap_percentage
        self.model_temperatue = model_temperature
        self.max_chunks_in_context = max_chunks_in_context

        # The context manager closes the document once the page is rasterised;
        # the image is built inside the block so the pixmap data is consumed
        # before the document is released.
        with fitz.open(file.name) as doc:
            page = doc[self.page]
            pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
            image = Image.frombytes('RGB', (pix.width, pix.height), pix.samples)
        return image

    def add_text(self, history, text):
        """Append a new user turn with an empty reply slot to ``history``.

        Raises:
            gr.Error: If ``text`` is empty.
        """
        if not text:
            raise gr.Error('Enter text')
        # A list (not a tuple) so the reply slot can be filled in later.
        history.append([text, ''])
        return history