# Conversational RAG pipeline for a Hugging Face Space:
# Zephyr-7B answers user questions over a FAISS + BM25 ensemble retriever,
# while Mistral-7B condenses follow-up messages into standalone questions.
import os
import pickle
from typing import Any, Optional, Dict, List

from huggingface_hub import InferenceClient
from langchain.chains import ConversationalRetrievalChain
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline, EmbeddingsFilter
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from pydantic import BaseModel, Field

import config
from memory import memory3
# Hub repo ids: Zephyr answers the user; Mistral rewrites follow-up questions.
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# Hugging Face API token from the environment; None if "apiToken" is unset.
hf_token = os.getenv("apiToken")
# Generation parameters for the chat model (long, relatively creative answers).
kwargs = {"max_new_tokens":500, "temperature":0.9, "top_p":0.95, "repetition_penalty":1.0, "do_sample":True}
# Generation parameters for question reformulation (short, more deterministic).
reform_kwargs = {"max_new_tokens":50, "temperature":0.5, "top_p":0.9, "repetition_penalty":1.0, "do_sample":True}
class KwArgsModel(BaseModel):
    """Pydantic mixin carrying arbitrary text-generation keyword arguments."""

    # Per-request generation parameters forwarded to the inference endpoint.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain-compatible LLM backed by a ``huggingface_hub.InferenceClient``.

    Lets a hosted Hugging Face text-generation endpoint be used anywhere
    LangChain expects an ``LLM`` instance.
    """

    # Hub repo id of the model served by the inference endpoint.
    model_name: str
    # Pre-configured client used for every generation call.
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Build the underlying client and initialize the pydantic model.

        Args:
            model_name: Hub repo id of the hosted model.
            hf_token: Hugging Face API token used to authenticate requests.
            kwargs: Optional generation parameters (temperature, top_p, ...).
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        # Fix: passing kwargs=None would override KwArgsModel's default_factory
        # and fail pydantic validation of Dict[str, Any]; substitute {} instead.
        # Note: the original also passed hf_token= here, but no such field is
        # declared on any base model, so it was silently discarded — dropped.
        super().__init__(
            model_name=model_name,
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for ``prompt``.

        Raises:
            ValueError: if ``stop`` sequences are supplied — the underlying
                streaming call here does not support them.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens and concatenate them; return_full_text=False drops the
        # prompt echo so only newly generated text is returned.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return ''.join(response_gen)

    @property
    def _llm_type(self) -> str:
        # Fix: LangChain's LLM base declares _llm_type as a property and the
        # framework reads it without calling; the original plain method yielded
        # a bound-method object wherever the framework expected a string.
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        # Same property fix as _llm_type; used by LangChain for cache keys
        # and serialization.
        return {"model_name": self.model_name}
# Two LLM wrappers: one answers the user, one condenses follow-up questions.
chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)

# Prompt for the "stuff" combine-docs chain; must expose exactly these variables.
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question", "chat_history"]
)
chain_type_kwargs = {"prompt": PROMPT}

# Dense retrieval: local FAISS index queried with HuggingFace default embeddings,
# returning the top-5 most similar chunks.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever=vectorstore.as_retriever(search_type="similarity", search_kwargs={"k":5})
# Contextual-compression pipeline applied to retrieved documents:
# split into ~300-char sentence chunks, drop near-duplicate chunks, then keep
# only chunks whose embedding similarity to the query exceeds 0.5.
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)

# Sparse (lexical BM25) retrieval over the same corpus, loaded from a pickle.
# NOTE(review): pickle.load is only acceptable because docs_data.pkl is a
# local, trusted artifact — never load pickles from untrusted sources.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)
bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)

# Blend dense and sparse results with equal weight.
ensemble_retriever = EnsembleRetriever(retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
# Prompt for the condense-question step (run by reform_llm): rewrites a
# follow-up message into a standalone one given the chat history.
# NOTE(review): the "[/INST]" embedded mid-sentence looks like a leftover from
# the Mistral instruction format — confirm before changing; it is part of the
# runtime prompt text, so altering it changes model behavior.
custom_template = """Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged[/INST].
Chat History:
{chat_history}
Follow-up user message: {question}
Rewritten user message:"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)
# Conversational RAG chain:
#  1. reform_llm rewrites the follow-up into a standalone question
#     (CUSTOM_QUESTION_PROMPT), using the raw history (get_chat_history is
#     the identity function, so history is passed through unformatted);
#  2. ensemble_retriever fetches supporting documents;
#  3. chat_llm answers with the "stuff" combine-docs prompt (PROMPT).
# Source documents are returned alongside the answer, and turns are persisted
# in the shared memory3 object.
chat_chain = ConversationalRetrievalChain.from_llm(llm=chat_llm,
                       chain_type="stuff",
                       retriever=ensemble_retriever,
                       combine_docs_chain_kwargs=chain_type_kwargs,
                       return_source_documents=True,
                       get_chat_history=lambda h : h,
                       condense_question_prompt=CUSTOM_QUESTION_PROMPT,
                       memory=memory3,
                       condense_question_llm = reform_llm
                       )