Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| from typing import List, Optional, Any | |
| from langchain.schema import Document | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from loguru import logger | |
| from tqdm import tqdm | |
| from src.config import local_embedding, retrieve_proxy, chunk_overlap, chunk_size, hf_emb_model_name | |
| from src import shared | |
| from src.utils import excel_to_string, get_files_hash, load_pkl, save_pkl | |
# Absolute path of the directory that contains this module; construct_index
# keeps its FAISS cache in an "index" subdirectory of it.
pwd_path = os.path.dirname(os.path.abspath(__file__))
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive text splitter tuned for Chinese (and mixed Chinese/English) text.

    Adapted from: https://github.com/chatchat-space/Langchain-Chatchat/tree/master

    The text is split on the first separator (in priority order) that occurs in
    it; pieces that are still at least ``chunk_size`` long are split again with
    the remaining, finer separators. When ``keep_separator`` is true, each
    separator match stays attached to the end of the piece that precedes it.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            separators: Separator patterns in priority order, coarsest first.
                Defaults to: blank line, newline, full-width sentence endings,
                ASCII sentence endings followed by whitespace, semicolons, commas.
            keep_separator: Whether matched separators are kept in the chunks.
            is_separator_regex: Whether ``separators`` are regex patterns; when
                false they are ``re.escape``-d before use.
            **kwargs: Forwarded to ``RecursiveCharacterTextSplitter``
                (e.g. ``chunk_size``, ``chunk_overlap``).
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Full-width punctuation is spelled with \u escapes so it cannot be
        # silently normalised to ASCII: a bare ASCII "?" in an alternation is an
        # invalid regex ("nothing to repeat"), and ASCII ";"/"," would make the
        # "followed by whitespace" alternatives pointless.
        self._separators = separators or [
            "\n\n",
            "\n",
            "\u3002|\uff01|\uff1f",  # ideographic full stop, full-width ! and ?
            r"\.\s|\!\s|\?\s",  # ASCII . ! ? followed by whitespace
            "\uff1b|;\\s",  # full-width semicolon, or ASCII ";" + whitespace
            "\uff0c|,\\s",  # full-width comma, or ASCII "," + whitespace
        ]
        self._is_separator_regex = is_separator_regex

    @staticmethod
    def _split_text_with_regex_from_end(
        text: str, separator: str, keep_separator: bool
    ) -> List[str]:
        """Split ``text`` on the regex ``separator``.

        A static method: it uses no instance state, and ``_split_text`` calls it
        through ``self`` (without the decorator that call raised ``TypeError``).

        Args:
            text: Text to split.
            separator: Regex pattern; an empty string splits into characters.
            keep_separator: If true, each separator match is appended to the
                piece before it instead of being discarded.

        Returns:
            The non-empty pieces, in order.
        """
        # Now that we have the separator, split the text
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result:
                # re.split yields [piece, sep, piece, sep, ..., piece].
                _splits = re.split(f"({separator})", text)
                splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
                if len(_splits) % 2 == 1:
                    # The trailing piece after the last separator has no partner.
                    splits += _splits[-1:]
            else:
                splits = re.split(separator, text)
        else:
            splits = list(text)
        return [s for s in splits if s != ""]

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks.

        Picks the first separator in ``separators`` that occurs in ``text``,
        splits on it, merges consecutive short pieces with ``_merge_splits`` and
        recurses with the remaining (finer) separators on pieces whose length is
        at least ``chunk_size``. Returned chunks are stripped, non-empty, and
        have runs of newlines collapsed to a single newline.
        """
        final_chunks = []
        # Get appropriate separator to use; fall back to the finest one.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = self._split_text_with_regex_from_end(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        # Kept separators are already attached to the pieces, so join with "".
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # No finer separator left: keep the oversized piece as-is.
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip() != ""]
def _read_pdf_text(filepath):
    """Return the text of the PDF at ``filepath``.

    Tries the project's ``src.pdf_func.parse_pdf`` first, honouring the
    ``advance_docs["pdf"]["two_column"]`` setting. If that fails for any reason,
    falls back to page-by-page PyPDF2 extraction.
    """
    try:
        from src.pdf_func import parse_pdf
        from src.config import advance_docs
        two_column = advance_docs["pdf"].get("two_column", False)
        return parse_pdf(filepath, two_column).text
    except Exception as e:
        # Best effort: log why the layout-aware parser failed, then fall back.
        logger.debug(f"parse_pdf failed for {filepath}, falling back to PyPDF2: {e}")
    import PyPDF2
    with open(filepath, "rb") as pdf_file:
        reader = PyPDF2.PdfReader(pdf_file)
        # Guard against extract_text() returning None for a page.
        return "".join(page.extract_text() or "" for page in tqdm(reader.pages))


def get_documents(file_paths):
    """Load the given files and split them into chunked ``Document`` objects.

    Supported types (matched case-insensitively by extension): .pdf, .docx,
    .pptx, .epub, .xlsx; anything else is read as a UTF-8 text file.

    Args:
        file_paths: Iterable of upload objects exposing ``.name`` (the path on
            disk) or plain ``str``/``os.PathLike`` paths.

    Returns:
        A flat list of ``Document`` chunks produced by
        ``ChineseRecursiveTextSplitter``. A file that fails to load is logged
        and skipped instead of aborting the whole batch.
    """
    text_splitter = ChineseRecursiveTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    documents = []
    logger.debug("Loading documents...")
    logger.debug(f"file_paths: {file_paths}")
    for file in file_paths:
        # Accept plain paths as well as upload objects carrying the path in .name.
        if isinstance(file, (str, os.PathLike)):
            filepath = os.fspath(file)
        else:
            filepath = file.name
        filename = os.path.basename(filepath)
        # Lower-case so e.g. "REPORT.PDF" is not treated as a text file.
        file_type = os.path.splitext(filename)[1].lower()
        logger.info(f"loading file: {filename}")
        texts = None
        try:
            if file_type == ".pdf":
                logger.debug("Loading PDF...")
                pdftext = _read_pdf_text(filepath)
                texts = [Document(page_content=pdftext, metadata={"source": filepath})]
            elif file_type == ".docx":
                logger.debug("Loading Word...")
                from langchain_community.document_loaders import UnstructuredWordDocumentLoader
                texts = UnstructuredWordDocumentLoader(filepath).load()
            elif file_type == ".pptx":
                logger.debug("Loading PowerPoint...")
                from langchain_community.document_loaders import UnstructuredPowerPointLoader
                texts = UnstructuredPowerPointLoader(filepath).load()
            elif file_type == ".epub":
                logger.debug("Loading EPUB...")
                from langchain_community.document_loaders import UnstructuredEPubLoader
                texts = UnstructuredEPubLoader(filepath).load()
            elif file_type == ".xlsx":
                logger.debug("Loading Excel...")
                texts = [
                    Document(page_content=elem, metadata={"source": filepath})
                    for elem in excel_to_string(filepath)
                ]
            else:
                logger.debug("Loading text file...")
                from langchain_community.document_loaders import TextLoader
                texts = TextLoader(filepath, "utf8").load()
            logger.debug(f"text size: {len(texts)}, text top3: {texts[:3]}")
        except Exception as e:
            # Skip this file but keep going; include the traceback for debugging.
            logger.exception(f"Error loading file: {filename}, {e}")
        if texts is not None:
            documents.extend(text_splitter.split_documents(texts))
    logger.debug(f"Documents loaded. documents size: {len(documents)}, top3: {documents[:3]}")
    return documents
def construct_index(api_key, files, load_from_cache_if_possible=True):
    """Build a FAISS vector index over ``files``, or load a cached one.

    The cache lives in ``<pwd_path>/index/<get_files_hash(files)>``. The FAISS
    data is written there with ``save_local`` and the chunked documents are
    pickled to ``docs.pkl`` in the same directory.

    Args:
        api_key: OpenAI API key, exported as ``OPENAI_API_KEY``. A placeholder
            value is exported when it is empty.
        files: Uploaded files; passed to ``get_files_hash`` and ``get_documents``.
        load_from_cache_if_possible: Reuse the cached index for the same files
            when it exists and loads cleanly.

    Returns:
        ``(index, documents)`` on success, or ``None`` if building the index
        failed (the error is logged). Callers must check for ``None`` before
        unpacking.
    """
    import inspect
    from langchain_community.vectorstores import FAISS

    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    else:
        os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
    index_name = get_files_hash(files)
    index_dir = os.path.join(pwd_path, 'index')
    index_path = os.path.join(index_dir, index_name)
    doc_file = os.path.join(index_path, 'docs.pkl')

    if local_embedding:
        # Imported only when needed, so the OpenAI path does not depend on it.
        from langchain_community.embeddings import HuggingFaceEmbeddings
        embeddings = HuggingFaceEmbeddings(model_name=hf_emb_model_name)
    else:
        from langchain_community.embeddings import OpenAIEmbeddings
        if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
            embeddings = OpenAIEmbeddings(
                openai_api_base=shared.state.openai_api_base,
                openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key)
            )
        else:
            embeddings = OpenAIEmbeddings(
                deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"],
                openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
                model=os.environ["AZURE_EMBEDDING_MODEL_NAME"],
                openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
                openai_api_type="azure"
            )

    # Ensure the index directory exists.
    os.makedirs(index_dir, exist_ok=True)
    if os.path.exists(index_path) and load_from_cache_if_possible:
        try:
            logger.info("找到了缓存的索引文件,加载中……")
            load_kwargs = {}
            if "allow_dangerous_deserialization" in inspect.signature(FAISS.load_local).parameters:
                # Newer langchain_community refuses to unpickle a saved index
                # unless the caller opts in. This cache is written by this very
                # function into our own index directory, so it is trusted here.
                load_kwargs["allow_dangerous_deserialization"] = True
            index = FAISS.load_local(index_path, embeddings, **load_kwargs)
            documents = load_pkl(doc_file)
            return index, documents
        except Exception as e:
            # Missing, corrupt or incompatible cache: fall through and rebuild.
            logger.error(f"加载缓存的索引文件失败,重新构建索引…… 错误: {e}")
    try:
        documents = get_documents(files)
        if not documents:
            # FAISS.from_documents cannot build an index from nothing.
            logger.error("No documents could be loaded from the given files; index not built.")
            return None
        logger.info("构建索引中……")
        with retrieve_proxy():
            index = FAISS.from_documents(documents, embeddings)
        logger.debug("索引构建完成!")
        os.makedirs(index_path, exist_ok=True)
        index.save_local(index_path)
        logger.debug("索引已保存至本地!")
        save_pkl(documents, doc_file)
        logger.debug("索引文档已保存至本地!")
        return index, documents
    except Exception as e:
        logger.error(f"索引构建失败!错误: {e}")
        return None