import concurrent.futures
import threading
import torch
from datetime import datetime
import json
import gradio as gr
import re
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from langchain.document_loaders import DirectoryLoader, TextLoader  # document loading
from langchain.text_splitter import RecursiveCharacterTextSplitter  # chunking for retrieval

class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder):
        # Load and chunk the documents, embed them into a FAISS index,
        # and prepare the quantized LLM used for answer generation.
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.gpu_index = self.create_faiss_index()
        self.llm = self.initialize_llm(lm_model_id)
        self.cancel_flag = threading.Event()

    def load_documents(self, folder_path):
        # Read every text file in the folder and split it into overlapping chunks
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
        all_splits = text_splitter.split_documents(documents)
        print("Length of documents:", len(documents))
        print("Length of all_splits:", len(all_splits))
        # Print a few chunks as a quick sanity check (slicing avoids an IndexError on small corpora)
        for split in all_splits[:5]:
            print(split.page_content)
        return all_splits

    def create_faiss_index(self):
        # Embed every chunk and index the vectors with an exact L2 (flat) index,
        # then move the index to GPU 0 for faster search
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        gpu_resource = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
        return gpu_index

    def initialize_llm(self, model_id):
        # Quantize the model to 4-bit NF4 so it fits on a single GPU
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        generate_text = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            do_sample=True,  # sampling must be enabled for the temperature below to take effect
            temperature=0.6,
            max_new_tokens=256,
        )
        return generate_text

    def generate_response_with_timeout(self, model_inputs):
        # Run generation in a worker thread and wait on a Future so the call
        # can be abandoned after 60 seconds
        def target(future):
            if self.cancel_flag.is_set():
                return
            generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
            if not self.cancel_flag.is_set():
                future.set_result(generated_ids)
            else:
                future.set_exception(TimeoutError("Text generation process was canceled"))

        future = concurrent.futures.Future()
        thread = threading.Thread(target=target, args=(future,))
        thread.start()
        try:
            generated_ids = future.result(timeout=60)  # 60-second timeout
            return generated_ids
        except concurrent.futures.TimeoutError:
            self.cancel_flag.set()
            raise TimeoutError("Text generation process timed out")

    def qa_infer_gradio(self, query):
        # Reset the cancel flag for the new query
        self.cancel_flag.clear()
        content = ""  # defined up front so the timeout handler below can still return it
        try:
            # Embed the query and retrieve the 5 nearest chunks from the FAISS index
            query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
            distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
            for idx in indices[0]:
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"

            # The retrieved chunks are placed in the prompt so the instruction about
            # "retrieved chunks" has context to work with. Literal <s>/</s> markers are
            # dropped because apply_chat_template adds the special tokens itself, and a
            # stray </s> right after "Solution:" would short-circuit the regex below.
            prompt = f"""Here are the retrieved chunks:
{content}
Here's my question:
Query: {query}
RETURN ONLY THE SOLUTION. IF THERE IS NO RELATABLE ANSWER IN THE RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
Solution:"""

            messages = [{"role": "user", "content": prompt}]
            encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.llm.device)

            start_time = datetime.now()
            generated_ids = self.generate_response_with_timeout(model_inputs)
            elapsed_time = datetime.now() - start_time

            decoded = self.llm.tokenizer.batch_decode(generated_ids)
            generated_response = decoded[0]
            # Keep only the text between "Solution:" and the end-of-sequence token
            match = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
            if match:
                solution_text = match.group(1).strip()
            else:
                solution_text = "NO SOLUTION AVAILABLE"

            print("Generated response:", generated_response)
            print("Time elapsed:", elapsed_time)
            print("Device in use:", self.llm.device)
            return solution_text, content
        except TimeoutError:
            return "timeout", content