| | from abc import ABCMeta, abstractmethod |
| | from typing import Union |
| |
|
| | from langchain_community.embeddings import HuggingFaceEmbeddings |
| | from transformers import AutoTokenizer |
| | from langchain.text_splitter import TokenTextSplitter |
| | from langchain_core.documents import Document |
| | from torch.cuda import is_available |
| |
|
| |
|
class BaseDB(metaclass=ABCMeta):
    """Abstract base for vector databases backed by HuggingFace embeddings.

    Owns the embedding model, the matching tokenizer, and token-aware text
    splitting. Concrete subclasses implement the storage-specific operations:
    ``init_db``, ``addStories``, ``deleteStoriesByMeta``, ``searchBySim``,
    and ``searchByMeta``.
    """

    def __init__(
        self,
        embedding_name: Union[str, None] = None,
        persist_dir: Union[str, None] = None,
    ) -> None:
        """Load the embedding model and tokenizer, then initialize the store.

        Args:
            embedding_name: HuggingFace model id for the embedding model.
                Defaults to "BAAI/bge-small-zh-v1.5".
            persist_dir: Directory where the concrete DB persists its data.
                Defaults to "data".
        """
        super().__init__()

        # Handle to the underlying store; assigned by the subclass in init_db().
        self.client = None

        self.persist_dir = persist_dir if persist_dir else "data"

        if not embedding_name:
            embedding_name = "BAAI/bge-small-zh-v1.5"

        # Prefer GPU when available; HuggingFaceEmbeddings receives the
        # device selection through model_kwargs.
        model_kwargs = {"device": "cuda" if is_available() else "cpu"}

        self.embedding = HuggingFaceEmbeddings(
            model_name=embedding_name, model_kwargs=model_kwargs
        )
        # Tokenizer must match the embedding model so that chunk sizes
        # measured in tokens agree with the model's own tokenization.
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)

        self.init_db()

    @abstractmethod
    def init_db(self):
        """Create/open the underlying store and assign ``self.client``."""

    def text_splitter(
        self,
        text: Union[str, Document],
        chunk_size: int = 300,
        chunk_overlap: int = 10,
    ):
        """Split *text* into token-based chunks using the model's tokenizer.

        Args:
            text: A raw string or a langchain ``Document``.
            chunk_size: Maximum chunk length in tokens.
            chunk_overlap: Number of overlapping tokens between chunks.

        Returns:
            ``list[Document]`` when given a ``Document``;
            ``list[str]`` when given a plain string.

        Raises:
            ValueError: If *text* is neither ``str`` nor ``Document``.
        """
        # Build the splitter once instead of duplicating it per branch.
        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # BUGFIX: split_documents expects an iterable of Documents;
            # the original passed a bare Document, which is not iterable.
            return splitter.split_documents([text])
        if isinstance(text, str):
            return splitter.split_text(text)
        raise ValueError("text must be a str or Document")

    @abstractmethod
    def addStories(self, stories, metas=None):
        """Embed and insert *stories* (with optional metadata) into the store."""

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        """Delete stored entries whose metadata matches *metas*."""

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        """Similarity-search for *query*, filtered by *metas*; return up to
        *n_results* hits (documents only when *only_return_document*)."""

    @abstractmethod
    def searchByMeta(self, metas=None):
        """Fetch entries by metadata filter only, with no similarity ranking."""
| |
|