import os
import re
import sys
import logging
import traceback
from threading import Thread
from typing import Any, Iterator, List, Optional

import torch
import hydra
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from modelscope.hub.snapshot_download import snapshot_download
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.retrievers import ContextualCompressionRetriever
from BCEmbedding.tools.langchain import BCERerank

# Make the project root importable so `rag.chroma_db` and `download_models` resolve
# regardless of the directory this script is launched from.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from rag.chroma_db import get_chroma_db
from download_models import download_model


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    filename='ragchat.log',
                    filemode='w')


class InternLM(LLM):
    """LangChain LLM wrapper around a locally loaded InternLM2 chat model."""
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None
    llm_system_prompt: str = ""

    def __init__(self, model_path: str, llm_system_prompt: str):
        super().__init__()
        logging.info(f"Loading model from local path: {model_path} ...")
        try:
            # Load the tokenizer and fp16 model weights onto the GPU.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True,
                                                              torch_dtype=torch.float16).cuda()
            self.model.eval()
            self.llm_system_prompt = llm_system_prompt
            logging.info("Finished loading the local model.")
        except Exception as e:
            logging.error(f"Failed to load the model: {e}")
            raise

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # The system prompt is injected as a prior conversation turn. (InternLM2's
        # remote-code chat() also exposes a `meta_instruction` argument that could
        # carry it instead.)
        system_prompt = self.llm_system_prompt
        messages = [("system", system_prompt)]
        response, _ = self.model.chat(self.tokenizer, prompt, history=messages)
        return response

    def stream(self, prompt: str) -> Iterator[str]:
        """Yield the response incrementally by running generate() in a background thread."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").cuda()
        # skip_prompt=True keeps the echoed input prompt out of the streamed output.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=2048,
            streamer=streamer,
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            yield new_text

    @property
    def _llm_type(self) -> str:
        return "InternLM2"
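
# A minimal standalone usage sketch for the InternLM wrapper above; the model path and
# system prompt here are placeholders, not values from this repo's config:
#
#     llm = InternLM(model_path="./models/internlm2-chat-7b",
#                    llm_system_prompt="You are a helpful assistant.")
#     print(llm.invoke("你好"))                 # one-shot call via the LangChain LLM interface
#     for chunk in llm.stream("你好"):           # token-by-token streaming
#         print(chunk, end="", flush=True)
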
class WuleRAG:
    """
    Holds the retrieval question-answering chain: embedding retriever + BCE reranker + InternLM generator.
    """
    def __init__(self, data_source_dir, db_persist_directory, base_model, embeddings_model,
                 reranker_model, rag_prompt_template):
        # Generator LLM (an InternLM wrapper instance).
        self.llm = base_model

        # Embedding model: fetch bce-embedding-base_v1 from ModelScope if it is not on disk.
        if not os.path.exists(embeddings_model):
            if embeddings_model.endswith("bce-embedding-base_v1"):
                save_dir = os.path.dirname(os.path.dirname(embeddings_model))
                embeddings_model = snapshot_download("maidalun/bce-embedding-base_v1",
                                                     cache_dir=save_dir, revision='master')
                logging.info(f"bce-embedding model not found locally, downloaded from ModelScope to {embeddings_model}")
            else:
                raise ValueError(f"{embeddings_model} does not exist, please fix the path or re-download the model.")

        self.embeddings = HuggingFaceEmbeddings(model_name=embeddings_model,
                                                model_kwargs={"device": "cuda"},
                                                encode_kwargs={"batch_size": 1, "normalize_embeddings": True})

        # Reranker model: fetch bce-reranker-base_v1 from ModelScope if it is not on disk.
        if not os.path.exists(reranker_model):
            if reranker_model.endswith("bce-reranker-base_v1"):
                save_dir = os.path.dirname(os.path.dirname(reranker_model))
                reranker_model = snapshot_download("maidalun/bce-reranker-base_v1",
                                                   cache_dir=save_dir, revision='master')
                logging.info(f"reranker model not found locally, downloaded from ModelScope to {reranker_model}")
            else:
                raise ValueError(f"{reranker_model} does not exist, please fix the path or re-download the model.")
        reranker_args = {'model': reranker_model, 'top_n': 5, 'device': 'cuda', "use_fp16": True}
        self.reranker = BCERerank(**reranker_args)

        # Build (or load) the Chroma vector store from the source documents.
        vectordb = get_chroma_db(data_source_dir, db_persist_directory, self.embeddings)

        # First stage: vector similarity retrieval with a score threshold.
        self.retriever = vectordb.as_retriever(search_kwargs={"k": 3, "score_threshold": 0.6},
                                               search_type="similarity_score_threshold")

        # Second stage: rerank the retrieved chunks with the BCE reranker.
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=self.retriever
        )

        # Prompt that receives the retrieved context and the user question.
        self.PROMPT = PromptTemplate(
            template=rag_prompt_template, input_variables=["context", "question"]
        )
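
        # The template text lives in the Hydra config (configs/rag_cfg.yaml). It must expose
        # the two variables declared above; an illustrative (not the actual) template:
        #
        #     rag_prompt_template: |
        #       Use the following context to answer the question.
        #       Context: {context}
        #       Question: {question}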

        # "stuff" chain: all reranked chunks are stuffed into a single prompt for the LLM.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.compression_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.PROMPT}
        )

    def query_stream(self, query: str) -> Iterator[str]:
        """Retrieve and rerank documents, format the RAG prompt manually, then stream the LLM reply."""
        docs = self.compression_retriever.get_relevant_documents(query)
        context = "\n\n".join(doc.page_content for doc in docs)
        prompt = self.PROMPT.format(context=context, question=query)
        return self.llm.stream(prompt)

    def query(self, question):
        """
        Answer via the retrieval QA chain; if no relevant documents are found, the model falls
        back to its own answer.

        Usage example:
            question = '黑神话悟空发售时间和团队?主要讲了什么故事?'
            result = self.qa_chain({"query": question})
            print(result["result"])
        """
        if not question:
            return "请提供个有用的问题。"

        try:
            result = self.qa_chain.invoke({"query": question})

            if 'result' in result:
                answer = result['result']
                # Strip the boilerplate prefix "根据提供的信息," that the model tends to emit.
                final_answer = re.sub(r'^根据提供的信息,\s?', '', answer, flags=re.M).strip()
                return final_answer
            else:
                logging.error("Error: 'result' field not found in the result.")
                return "悟了悟了目前无法提供答案,请稍后再试。"
        except Exception as e:
            logging.error(f"An error occurred: {e}\n{traceback.format_exc()}")
            return "悟了悟了遇到了一些技术问题,正在修复中。"


@hydra.main(version_base=None, config_path="../configs", config_name="rag_cfg")
def main(config):
    # Unpack the Hydra config.
    data_source_dir = config.data_source_dir
    db_persist_directory = config.db_persist_directory
    llm_model = config.llm_model
    embeddings_model = config.embeddings_model
    reranker_model = config.reranker_model
    llm_system_prompt = config.llm_system_prompt
    rag_prompt_template = config.rag_prompt_template

    # Download the LLM weights if they are not present locally.
    if not os.path.exists(llm_model):
        download_model(llm_model_path=llm_model)

    # Load the base LLM and assemble the RAG pipeline.
    base_model = InternLM(model_path=llm_model, llm_system_prompt=llm_system_prompt)
    wulewule_rag = WuleRAG(data_source_dir, db_persist_directory, base_model, embeddings_model,
                           reranker_model, rag_prompt_template)

    question = """黑神话悟空发售时间和团队?主要讲了什么故事?"""

    if config.stream_response:
        logging.info("Streaming response:")
        for chunk in wulewule_rag.query_stream(question):
            print(chunk, end='', flush=True)
        print("\n")
    else:
        response = wulewule_rag.query(question)
        logging.info(f"question: {question}\n wulewule answer:\n{response}")


if __name__ == "__main__":
    main()
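
# To run, launch this script from the project root; Hydra key=value overrides work from the
# command line, e.g. (script path is a placeholder):
#
#     python path/to/this_script.py stream_response=true
#
# configs/rag_cfg.yaml is expected to define: data_source_dir, db_persist_directory, llm_model,
# embeddings_model, reranker_model, llm_system_prompt, rag_prompt_template and stream_response.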