# Recovered from a Hugging Face Space page (status: "Spaces: Runtime error") — app source follows.
# Standard library
import os

# Third-party
import requests
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_core.vectorstores import VectorStoreRetriever
class LLMModel:
    """RAG question-answering helper around Llama-2.

    Retrieves context documents from a vector store and answers questions
    either through a local CTransformers model (currently disabled) or the
    hosted Hugging Face Inference API (``text_generation`` /
    ``question_answering`` / raw HTTP streaming).
    """

    # Local GGUF model coordinates — only used by the commented-out
    # CTransformers path below.
    base_model = "TheBloke/Llama-2-7B-GGUF"
    specific_model = "llama-2-7b.Q4_K_M.gguf"
    token_model = "meta-llama/Llama-2-7b-hf"

    # Generation config for the local model path.
    llm_config = {'context_length': 2048, 'max_new_tokens': 1024, 'temperature': 0.3, 'top_p': 1.0}

    # System prompt for generating synthetic Q&A pairs from a document chunk.
    question_answer_system_prompt = """You are a helpful question answer assistant. Given the following context and a question, provide a set of potential questions and answers.
Keep answers brief and well-structured. Do not give one word answers."""

    # System prompt for answering a user question from retrieved Q&A context.
    final_assistant_system_prompt = """You are a helpful assistant. Given the following list of relevant questions and answers, generate an answer based on this list only.
Keep answers brief and well-structured. Do not give one word answers.
If the answer is not found in the list, kindly state "I don't know.". Don't try to make up an answer."""

    # Llama-2 chat template used by the local RetrievalQA chain.
    template = """<s>[INST] <<SYS>>
You are a question answer assistant. Given the following context and a question, generate an answer based on this context only.
Keep answers brief and well-structured. Do not give one word answers.
If the answer is not found in the context, kindly state "I don't know.". Don't try to make up an answer.
<</SYS>>
Context: {context}
Question: Give me a step by step explanation of {question}[/INST]
Answer:"""
    qa_chain_prompt = PromptTemplate.from_template(template)

    retriever = None

    # Inference API credentials/endpoint come from the environment.
    hf_token = os.getenv('HF_TOKEN')
    api_url = os.getenv('API_URL')
    headers = {"Authorization": f"Bearer {hf_token}"}
    client = InferenceClient(api_url)

    # llm = CTransformers(model=base_model, model_file=specific_model, config=llm_config, hf=True)
    llm = None

    def __init__(self, retriever: VectorStoreRetriever):
        """Bind the vector-store retriever used to fetch context documents."""
        self.retriever = retriever

    def create_qa_chain(self):
        """Build a "stuff" RetrievalQA chain over ``self.llm`` and the retriever.

        NOTE(review): ``self.llm`` is ``None`` while the CTransformers line
        above stays commented out, so the returned chain cannot actually run
        until a local LLM is re-enabled.
        """
        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.qa_chain_prompt},
        )

    def format_retrieved_docs(self, docs):
        """Render each retrieved document as a "Document/Content" snippet.

        Documents lacking a ``source`` metadata entry are skipped. Returns a
        list of strings (callers join them into a single context string).
        """
        return [
            f"""Document: {doc.metadata['source']}\nContent: {doc.page_content}\n\n"""
            for doc in docs
            if "source" in doc.metadata
        ]

    def format_query(self, question, context, system_prompt):
        """Wrap system prompt, context and question in Llama-2 [INST] format."""
        prompt = f"""[INST] {system_prompt}
Context: {context}
Question: Give me a step by step explanation of {question}[/INST]"""
        return prompt

    def format_question(self, question):
        """Retrieve context for ``question`` and build the final prompt.

        Bug fix: the retrieved snippets are joined into one string before
        interpolation; previously the raw Python list was embedded in the
        prompt (e.g. ``Context: ['Document: ...', ...]``), matching the
        already-correct join in ``answer_question_inference``.
        """
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        return self.format_query(question, formatted_docs, self.final_assistant_system_prompt)

    def get_potential_question_answer(self, document_chunk: str):
        """Generate synthetic Q&A pairs for a document chunk via the Inference API."""
        prompt = self.format_query("potential questions and answers.", document_chunk, self.question_answer_system_prompt)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference_text_gen(self, question):
        """Answer ``question`` with the Inference API text-generation endpoint."""
        prompt = self.format_question(question)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference(self, question):
        """Answer ``question`` with the extractive question-answering endpoint.

        Returns a user-facing message instead of calling the API when no
        context documents could be retrieved (e.g. nothing uploaded yet).
        """
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        if not formatted_docs:
            return "No uploaded documents. Please try uploading a document on the left side."
        # Debug output of the assembled context; kept from the original code.
        print(formatted_docs)
        return self.client.question_answering(question=question, context=formatted_docs)

    def answer_question_api(self, question):
        """Stream a raw Inference API response for ``question``, chunk by chunk.

        NOTE(review): the HTTP status is not checked before streaming —
        consider ``resp.raise_for_status()``; confirm desired error handling.
        """
        formatted_prompt = self.format_question(question)
        resp = requests.post(self.api_url, headers=self.headers, json={"inputs": formatted_prompt}, stream=True)
        for chunk in resp.iter_content():
            yield chunk