# -*- coding: utf-8 -*-
"""
@author: XuMing(xuming624@qq.com)
@description: Retrieval-augmented generation (RAG) over local document files.
"""
import argparse
import hashlib
import os
import re
from threading import Thread
from typing import Union, List

import jieba
import torch
from loguru import logger
from peft import PeftModel
from similarities import (
    EnsembleSimilarity,
    BertSimilarity,
    BM25Similarity,
    TfidfSimilarity
)
from similarities.similarity import SimilarityABC
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    LlamaTokenizer,
    LlamaForCausalLM,
    TextIteratorStreamer,
    GenerationConfig,
    AutoModelForSequenceClassification,
)

jieba.setLogLevel("ERROR")

MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}

PROMPT_TEMPLATE = """Basándose únicamente en la información proporcionada a continuación, responda a las preguntas del usuario de manera concisa y profesional.
No se debe responder a preguntas relacionadas con sentimientos, emociones, temas personales o cualquier información que no esté explícitamente presente en el contenido proporcionado.
Si la pregunta se refiere a un artículo específico y no se encuentra en el contenido proporcionado, diga: "No se puede encontrar el artículo solicitado en la información conocida".
Contenido conocido:
{context_str}
Pregunta:
{query_str}
"""


class SentenceSplitter:
    """Split text into chunks of roughly `chunk_size` characters, with optional overlap between chunks."""

    def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_text(self, text: str) -> List[str]:
        if self._is_has_chinese(text):
            return self._split_chinese_text(text)
        else:
            return self._split_english_text(text)

    def _split_chinese_text(self, text: str) -> List[str]:
        sentence_endings = {'\n', '。', '!', '?', ';', '…'}  # sentence-ending punctuation
        chunks, current_chunk = [], ''
        for word in jieba.cut(text):
            if len(current_chunk) + len(word) > self.chunk_size:
                chunks.append(current_chunk.strip())
                current_chunk = word
            else:
                current_chunk += word
            if word[-1] in sentence_endings and len(current_chunk) > self.chunk_size - self.chunk_overlap:
                chunks.append(current_chunk.strip())
                current_chunk = ''
        if current_chunk:
            chunks.append(current_chunk.strip())
        if self.chunk_overlap > 0 and len(chunks) > 1:
            chunks = self._handle_overlap(chunks)
        return chunks

    def _split_english_text(self, text: str) -> List[str]:
        # Split English text into sentences with a regex, then pack sentences into chunks
        sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
        chunks = []
        current_chunk = ''
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= self.chunk_size:
                current_chunk += (' ' if current_chunk else '') + sentence
            else:
                if len(sentence) > self.chunk_size:
                    # Flush the pending chunk first, then hard-split the over-long sentence
                    if current_chunk:
                        chunks.append(current_chunk)
                    for i in range(0, len(sentence), self.chunk_size):
                        chunks.append(sentence[i:i + self.chunk_size])
                    current_chunk = ''
                else:
                    chunks.append(current_chunk)
                    current_chunk = sentence
        if current_chunk:  # Add the last chunk
            chunks.append(current_chunk)
        if self.chunk_overlap > 0 and len(chunks) > 1:
            chunks = self._handle_overlap(chunks)
        return chunks

    def _is_has_chinese(self, text: str) -> bool:
        # Check whether the text contains any Chinese characters
        return any("\u4e00" <= ch <= "\u9fff" for ch in text)

    def _handle_overlap(self, chunks: List[str]) -> List[str]:
        # Append the start of the next chunk to each chunk so neighboring chunks overlap
        overlapped_chunks = []
        for i in range(len(chunks) - 1):
            chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap]
            overlapped_chunks.append(chunk.strip())
        overlapped_chunks.append(chunks[-1])
        return overlapped_chunks
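
# Illustrative usage of SentenceSplitter (a sketch, not part of the pipeline below;
# the sample text and sizes here are made up):
#   splitter = SentenceSplitter(chunk_size=100, chunk_overlap=20)
#   chunks = splitter.split_text("First sentence. Second sentence. Third sentence.")
# Each chunk stays under roughly 100 characters, and neighboring chunks share up to
# 20 characters so that retrieval hits keep some surrounding context.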


class Rag:
    def __init__(
            self,
            similarity_model: SimilarityABC = None,
            generate_model_type: str = "auto",
            generate_model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct",
            lora_model_name_or_path: str = None,
            corpus_files: Union[str, List[str]] = None,
            save_corpus_emb_dir: str = "./corpus_embs/",
            device: str = None,
            int8: bool = False,
            int4: bool = False,
            chunk_size: int = 250,
            chunk_overlap: int = 0,
            rerank_model_name_or_path: str = None,
            enable_history: bool = False,
            num_expand_context_chunk: int = 2,
            similarity_top_k: int = 10,
            rerank_top_k: int = 3,
    ):
        """
        Init RAG model.
        :param similarity_model: similarity model, default None; if set, it is used instead of the default EnsembleSimilarity
        :param generate_model_type: generate model type
        :param generate_model_name_or_path: generate model name or path
        :param lora_model_name_or_path: LoRA model name or path
        :param corpus_files: corpus files
        :param save_corpus_emb_dir: dir to save corpus embeddings, default ./corpus_embs/
        :param device: device, default None, auto select gpu or cpu
        :param int8: use int8 quantization, default False
        :param int4: use int4 quantization, default False
        :param chunk_size: chunk size, default 250
        :param chunk_overlap: chunk overlap, default 0; cannot be > 0 when num_expand_context_chunk > 0
        :param rerank_model_name_or_path: rerank model name or path, default 'BAAI/bge-reranker-large'
        :param enable_history: enable history, default False
        :param num_expand_context_chunk: number of neighboring chunks used to expand each hit, default 2; 0 disables expansion
        :param similarity_top_k: similarity_top_k, default 10, number of corpus chunks retrieved by the similarity model
        :param rerank_top_k: rerank_top_k, default 3, number of chunks kept after reranking
        """
        if torch.cuda.is_available():
            default_device = torch.device(0)
        elif torch.backends.mps.is_available():
            # Apple MPS is detected but the pipeline stays on CPU here
            default_device = torch.device('cpu')
        else:
            default_device = torch.device('cpu')
        self.device = device or default_device
        if num_expand_context_chunk > 0 and chunk_overlap > 0:
            logger.warning("'num_expand_context_chunk' and 'chunk_overlap' cannot both be greater than zero. "
                           "'chunk_overlap' has been set to zero by default.")
            chunk_overlap = 0
        self.text_splitter = SentenceSplitter(chunk_size, chunk_overlap)
        if similarity_model is not None:
            self.sim_model = similarity_model
        else:
            m1 = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual", device=self.device)
            m2 = BM25Similarity()
            m3 = TfidfSimilarity()
            # Adjust the weights according to your retrieval results
            default_sim_model = EnsembleSimilarity(similarities=[m1, m2, m3], weights=[0.5, 0.5, 0.5], c=2)
            self.sim_model = default_sim_model
        self.gen_model, self.tokenizer = self._init_gen_model(
            generate_model_type,
            generate_model_name_or_path,
            peft_name=lora_model_name_or_path,
            int8=int8,
            int4=int4,
        )
        self.history = []
        self.corpus_files = corpus_files
        if corpus_files:
            self.add_corpus(corpus_files)
        self.save_corpus_emb_dir = save_corpus_emb_dir
        if rerank_model_name_or_path is None:
            rerank_model_name_or_path = "BAAI/bge-reranker-large"
        if rerank_model_name_or_path:
            self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
            self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
            self.rerank_model.to(self.device)
            self.rerank_model.eval()
        else:
            self.rerank_model = None
            self.rerank_tokenizer = None
        self.enable_history = enable_history
        self.similarity_top_k = similarity_top_k
        self.num_expand_context_chunk = num_expand_context_chunk
        self.rerank_top_k = rerank_top_k

    def __str__(self):
        return f"Similarity model: {self.sim_model}, Generate model: {self.gen_model}"

    def _init_gen_model(
            self,
            gen_model_type: str,
            gen_model_name_or_path: str,
            peft_name: str = None,
            int8: bool = False,
            int4: bool = False,
    ):
        """Init generate model."""
        if int8 or int4:
            device_map = None
        else:
            device_map = "auto"
        model_class, tokenizer_class = MODEL_CLASSES[gen_model_type]
        tokenizer = tokenizer_class.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
        model = model_class.from_pretrained(
            gen_model_name_or_path,
            load_in_8bit=int8 if gen_model_type not in ['baichuan', 'chatglm'] else False,
            load_in_4bit=int4 if gen_model_type not in ['baichuan', 'chatglm'] else False,
            torch_dtype="auto",
            device_map=device_map,
            trust_remote_code=True,
        )
        if self.device == torch.device('cpu'):
            model.float()
        if gen_model_type in ['baichuan', 'chatglm']:
            if int4:
                model = model.quantize(4).cuda()
            elif int8:
                model = model.quantize(8).cuda()
        try:
            model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
        except Exception as e:
            logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}")
        if peft_name:
            model = PeftModel.from_pretrained(
                model,
                peft_name,
                torch_dtype="auto",
            )
            logger.info(f"Loaded peft model from {peft_name}")
        model.eval()
        return model, tokenizer

    def _get_chat_input(self):
        messages = []
        for conv in self.history:
            if conv and len(conv) > 0 and conv[0]:
                messages.append({'role': 'user', 'content': conv[0]})
            if conv and len(conv) > 1 and conv[1]:
                messages.append({'role': 'assistant', 'content': conv[1]})
        input_ids = self.tokenizer.apply_chat_template(
            conversation=messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors='pt'
        )
        return input_ids.to(self.gen_model.device)

    def stream_generate_answer(
            self,
            max_new_tokens=512,
            temperature=0.7,
            repetition_penalty=1.0,
            context_len=2048
    ):
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        input_ids = self._get_chat_input()
        max_src_len = context_len - max_new_tokens - 8
        # Keep only the last max_src_len tokens of the prompt (truncate along the sequence dimension)
        input_ids = input_ids[:, -max_src_len:]
        generation_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            streamer=streamer,
        )
        thread = Thread(target=self.gen_model.generate, kwargs=generation_kwargs)
        thread.start()
        yield from streamer

    def add_corpus(self, files: Union[str, List[str]]):
        """Load document files into the similarity model's corpus."""
        if isinstance(files, str):
            files = [files]
        for doc_file in files:
            if doc_file.endswith('.pdf'):
                corpus = self.extract_text_from_pdf(doc_file)
            elif doc_file.endswith('.docx'):
                corpus = self.extract_text_from_docx(doc_file)
            elif doc_file.endswith('.md'):
                corpus = self.extract_text_from_markdown(doc_file)
            else:
                corpus = self.extract_text_from_txt(doc_file)
            full_text = '\n'.join(corpus)
            chunks = self.text_splitter.split_text(full_text)
            self.sim_model.add_corpus(chunks)
        self.corpus_files = files
        logger.debug(f"files: {files}, corpus size: {len(self.sim_model.corpus)}, top3: "
                     f"{list(self.sim_model.corpus.values())[:3]}")

    @staticmethod
    def get_file_hash(fpaths):
        """Get an MD5 hash for the given file(s); used to name the embeddings cache dir."""
        hasher = hashlib.md5()
        target_file_data = bytes()
        if isinstance(fpaths, str):
            fpaths = [fpaths]
        for fpath in fpaths:
            with open(fpath, 'rb') as file:
                chunk = file.read(1024 * 1024)  # read only the first 1MB
                hasher.update(chunk)
                target_file_data += chunk
        hash_name = hasher.hexdigest()[:32]
        return hash_name

    @staticmethod
    def extract_text_from_pdf(file_path: str):
        """Extract text content from a PDF file."""
        import PyPDF2
        contents = []
        with open(file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            for page in pdf_reader.pages:
                page_text = page.extract_text().strip()
                raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
                new_text = ''
                for text in raw_text:
                    if new_text:
                        new_text += ' '
                    new_text += text
                    if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
                                    '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
                        contents.append(new_text)
                        new_text = ''
                if new_text:
                    contents.append(new_text)
        return contents

    @staticmethod
    def extract_text_from_txt(file_path: str):
        """Extract text content from a TXT file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            contents = [text.strip() for text in f.readlines() if text.strip()]
        return contents

    @staticmethod
    def extract_text_from_docx(file_path: str):
        """Extract text content from a DOCX file."""
        import docx
        document = docx.Document(file_path)
        contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
        return contents

    @staticmethod
    def extract_text_from_markdown(file_path: str):
        """Extract text content from a Markdown file."""
        import markdown
        from bs4 import BeautifulSoup
        with open(file_path, 'r', encoding='utf-8') as f:
            markdown_text = f.read()
        html = markdown.markdown(markdown_text)
        soup = BeautifulSoup(html, 'html.parser')
        contents = [text.strip() for text in soup.get_text().splitlines() if text.strip()]
        return contents

    @staticmethod
    def _add_source_numbers(lst):
        """Add source numbers to a list of strings."""
        return [f'[{idx + 1}]\t "{item}"' for idx, item in enumerate(lst)]

    def _get_reranker_score(self, query: str, reference_results: List[str]):
        """Get reranker scores for (query, reference) pairs."""
        pairs = []
        for reference in reference_results:
            pairs.append([query, reference])
        with torch.no_grad():
            inputs = self.rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            inputs_on_device = {k: v.to(self.rerank_model.device) for k, v in inputs.items()}
            scores = self.rerank_model(**inputs_on_device, return_dict=True).logits.view(-1).float()
        return scores

    def get_reference_results(self, query: str):
        """Retrieve reference chunks for the query: exact "Artículo X" match first, then similarity search, reranking and context expansion."""
        # Check whether the query mentions a specific "Artículo X"
        exact_match = None
        if re.search(r'Artículo\s*\d+', query, re.IGNORECASE):
            # Extract the specific "Artículo X" term from the query
            term = re.search(r'Artículo\s*\d+', query, re.IGNORECASE).group()
            # Look for an exact occurrence of the term in the corpus
            for corpus_id, content in self.sim_model.corpus.items():
                # Require word boundaries around the term to avoid partial matches
                if re.search(r'\b' + re.escape(term) + r'\b', content, re.IGNORECASE):
                    exact_match = content
                    break
        if exact_match:
            # If an exact match is found, return it as the only context
            return [exact_match]
        reference_results = []
        sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
        # Get reference results from the corpus
        hit_chunk_dict = dict()
        for c in sim_contents:
            for id_score_dict in c:
                corpus_id = id_score_dict['corpus_id']
                hit_chunk = id_score_dict["corpus_doc"]
                reference_results.append(hit_chunk)
                hit_chunk_dict[corpus_id] = hit_chunk
        if reference_results:
            if self.rerank_model is not None:
                # Rerank reference results
                rerank_scores = self._get_reranker_score(query, reference_results)
                logger.debug(f"rerank_scores: {rerank_scores}")
                # Keep the rerank top k chunks
                reference_results = [reference for reference, score in sorted(
                    zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k]
                hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if
                                  hit_chunk in reference_results}
            # Expand each hit with neighboring chunks for more context
            if self.num_expand_context_chunk > 0:
                new_reference_results = []
                for corpus_id, hit_chunk in hit_chunk_dict.items():
                    expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk
                    for i in range(self.num_expand_context_chunk):
                        expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '')
                    new_reference_results.append(expanded_reference)
                reference_results = new_reference_results
        return reference_results

    def predict_stream(
            self,
            query: str,
            max_length: int = 512,
            context_len: int = 2048,
            temperature: float = 0.7,
    ):
        """Generate the answer as a stream of growing partial responses."""
        stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
        if not self.enable_history:
            self.history = []
        if self.sim_model.corpus:
            reference_results = self.get_reference_results(query)
            if reference_results:
                reference_results = self._add_source_numbers(reference_results)
                context_str = '\n'.join(reference_results)
            else:
                context_str = ''
            prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
        else:
            prompt = query
        logger.debug(f"prompt: {prompt}")
        self.history.append([prompt, ''])
        response = ""
        for new_text in self.stream_generate_answer(
                max_new_tokens=max_length,
                temperature=temperature,
                context_len=context_len,
        ):
            if new_text != stop_str:
                response += new_text
                yield response

    def predict(
            self,
            query: str,
            max_length: int = 512,
            context_len: int = 2048,
            temperature: float = 0.7,
    ):
        """Answer the query with corpus context; return (response, reference_results)."""
        reference_results = []
        if not self.enable_history:
            self.history = []
        if self.sim_model.corpus:
            reference_results = self.get_reference_results(query)
            if reference_results:
                reference_results = self._add_source_numbers(reference_results)
                context_str = '\n'.join(reference_results)
            else:
                context_str = ''
            prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
        else:
            prompt = query
        logger.debug(f"prompt: {prompt}")
        self.history.append([prompt, ''])
        response = ""
        for new_text in self.stream_generate_answer(
                max_new_tokens=max_length,
                temperature=temperature,
                context_len=context_len,
        ):
            response += new_text
        response = response.strip()
        self.history[-1][1] = response
        return response, reference_results

    def query(self, query: str, **kwargs):
        return self.predict(query, **kwargs)

    def save_corpus_emb(self):
        dir_name = self.get_file_hash(self.corpus_files)
        save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
        if hasattr(self.sim_model, 'save_corpus_embeddings'):
            self.sim_model.save_corpus_embeddings(save_dir)
            logger.debug(f"Saved corpus embeddings to {save_dir}")
        return save_dir

    def load_corpus_emb(self, emb_dir: str):
        if hasattr(self.sim_model, 'load_corpus_embeddings'):
            logger.debug(f"Loading corpus embeddings from {emb_dir}")
            self.sim_model.load_corpus_embeddings(emb_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual")
    parser.add_argument("--gen_model_type", type=str, default="auto")
    parser.add_argument("--gen_model_name", type=str, default="Qwen/Qwen2-0.5B-Instruct")
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--rerank_model_name", type=str, default="")
    parser.add_argument("--corpus_files", type=str, default="data/sample.pdf")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--int4", action='store_true', help="use int4 quantization")
    parser.add_argument("--int8", action='store_true', help="use int8 quantization")
    parser.add_argument("--chunk_size", type=int, default=220)
    parser.add_argument("--chunk_overlap", type=int, default=0)
    parser.add_argument("--num_expand_context_chunk", type=int, default=1)
    args = parser.parse_args()
    print(args)
    sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
    m = Rag(
        similarity_model=sim_model,
        generate_model_type=args.gen_model_type,
        generate_model_name_or_path=args.gen_model_name,
        lora_model_name_or_path=args.lora_model,
        device=args.device,
        int4=args.int4,
        int8=args.int8,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        corpus_files=args.corpus_files.split(','),
        num_expand_context_chunk=args.num_expand_context_chunk,
        rerank_model_name_or_path=args.rerank_model_name,
    )
    r, refs = m.predict('自然语言中的非平行迁移是指什么?')
    print(r, refs)
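    # Illustrative follow-ups (a sketch, not exercised above): stream the answer
    # incrementally, and cache the corpus embeddings so later runs can skip
    # re-encoding. Both methods are defined on the Rag class above.
    #   for partial in m.predict_stream('自然语言中的非平行迁移是指什么?'):
    #       print(partial, end='\r')
    #   emb_dir = m.save_corpus_emb()
    #   m.load_corpus_emb(emb_dir)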