from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
import gradio as gr
import os

from llm.CustomRetriever import CustomRetriever
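
# `CustomRetriever` lives in llm/CustomRetriever.py and is not shown in this
# file. Below is a minimal sketch of what a score-threshold retriever could
# look like, inferred from the vectorstore=/thold= keyword arguments used
# further down; the class name, the candidate count k, and the score
# direction are assumptions (FAISS's similarity_search_with_score returns a
# distance, so lower means more similar).
from typing import Any, List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema.retriever import BaseRetriever
from langchain_core.documents import Document


class ThresholdRetrieverSketch(BaseRetriever):
    vectorstore: Any     # any vectorstore exposing similarity_search_with_score()
    thold: float = 0.8   # keep only documents whose score is within this threshold
    k: int = 4           # hypothetical number of candidates to fetch

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs_and_scores = self.vectorstore.similarity_search_with_score(query, k=self.k)
        return [doc for doc, score in docs_and_scores if score <= self.thold]
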
API_TOKEN = os.getenv("TOKEN")

# Initialize the LangChain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
                        thold=0.8, progress=gr.Progress()):
    llm = HuggingFaceEndpoint(
        huggingfacehub_api_token=API_TOKEN,
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=CustomRetriever(vectorstore=vdb, thold=thold),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

# Initialize the LLM
def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, thold, progress=gr.Progress()):
    llm_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # alternative: "mistralai/Mistral-7B-Instruct-v0.2"
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, thold)
    return qa_chain
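
# Note: HuggingFaceEndpoint calls the hosted Hugging Face Inference API, so the
# TOKEN environment variable must hold a valid access token, and gated models
# such as Meta-Llama-3 additionally require accepting the license on the Hub.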

def format_chat_history(chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
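
# With ConversationBufferMemory attached above, the chain tracks chat history
# itself; a helper like this is typically only needed when history is passed
# to the chain manually as plain strings.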

def postprocess(response):
    try:
        result = response["answer"]
        # TODO: a binary classification model should replace this string check.
        if "I don't know" not in result:
            for doc in response["source_documents"]:
                file_doc = "\n\nFile: " + doc.metadata["source"].split("/")[-1]
                page = "\nPage: " + str(doc.metadata["page"])
                content = "\nFragment: " + doc.page_content.strip()
                result += file_doc + page + content
        return result
    except Exception:
        return "I don't know."