"""RAG helpers: embed text with an axengine-hosted model, index it with FAISS, and query a local LLM API."""

import os
import faiss
import pickle
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torch.nn import functional as F
import torch
from config import INDEX_FILE, EMBEDDINGS_FILE, LLM_API_URL, EMBED_AX_MODEL, EMBED_HF_MODEL
import numpy as np
import requests
import json
import re
import chardet  # used to detect the encoding of text files
from ml_dtypes import bfloat16
from utils.infer_func import InferManager

device = "cuda" if torch.cuda.is_available() else "cpu"

# (connect, read) timeout for LLM requests: fail fast when the server is
# unreachable, but never cut off a slow generation once connected.
_LLM_REQUEST_TIMEOUT = (10, None)

# ========== Load the embedding tokenizer via Transformers ==========
# Left padding puts every sequence's final token in the last position, which
# is the fast path taken by last_token_pool().
tokenizer = AutoTokenizer.from_pretrained(EMBED_HF_MODEL, padding_side="left")

# ========== axengine: load the embedding model ==========
embeds = np.load(os.path.join(EMBED_AX_MODEL, "model.embed_tokens.weight.npy"))
cfg = AutoConfig.from_pretrained(EMBED_HF_MODEL)
# When running on axcl, device_id may name any accessible card other than 0.
imer = InferManager(cfg, EMBED_AX_MODEL, device_id=0)

# Alternative: load the embedding model with torch instead of axengine.
# model = AutoModel.from_pretrained(EMBED_HF_MODEL).to(device)
# model.eval()
# embedder = model


def last_token_pool(last_hidden_states, attention_mask):
    """Return the hidden state of the last non-padding token of each sequence.

    Args:
        last_hidden_states: Tensor indexed as [batch, position, ...].
        attention_mask: Tensor indexed as [batch, position]; summed to count
            the attended positions of each row.

    Returns:
        Tensor with one pooled vector per batch row.
    """
    # If every row attends to its final position, the batch is left-padded
    # and the last position is the last real token for all rows.
    left_padding = bool(attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]

    # Right padding: pick position (length - 1) per row. Index tensors are
    # moved to the hidden-state device so mixed-device inputs do not fail.
    target_device = last_hidden_states.device
    sequence_lengths = (attention_mask.sum(dim=1) - 1).to(target_device)
    batch_indices = torch.arange(last_hidden_states.shape[0], device=target_device)
    return last_hidden_states[batch_indices, sequence_lengths]


def encode_texts(texts):
    """Embed texts with the axengine model and return L2-normalized vectors.

    Every text is wrapped in the same retrieval instruction prompt before
    tokenization.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        A float32 numpy array with one row per input text; an empty
        (0, dim) array when no texts are given.
    """
    # NOTE(review): the instruction prefix is applied to indexed passages as
    # well as to queries; confirm this matches the embedding model's
    # recommended usage.
    task_desc = "Given a web search query, retrieve relevant passages that answer the query"
    inputs = [f"Instruct: {task_desc}\nQuery: {t}" for t in texts]
    if not inputs:
        # The tokenizer cannot build a batch from nothing.
        return np.zeros((0, embeds.shape[1]), dtype=np.float32)

    inputs_tokenized = tokenizer(
        inputs,
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    )

    # torch alternative (requires the torch model above to be loaded):
    # inputs_tokenized = {k: v.to(device) for k, v in inputs_tokenized.items()}
    # with torch.no_grad():
    #     outputs = model(**inputs_tokenized)
    # embeddings = last_token_pool(outputs.last_hidden_state, inputs_tokenized["attention_mask"])
    # embeddings = F.normalize(embeddings, p=2, dim=1)

    # axengine path: everything below runs on numpy / CPU tensors, so the
    # tokenizer output is deliberately left on the CPU.
    attention_mask = inputs_tokenized["attention_mask"]
    input_ids = inputs_tokenized["input_ids"].numpy()

    # Look up the token embeddings on the host and feed them to the engine.
    prefill_data = np.take(embeds, input_ids, axis=0).astype(bfloat16)
    last_hidden_state = np.zeros(prefill_data.shape, dtype=bfloat16)
    for batch_idx, (row_ids, row_embeds) in enumerate(zip(input_ids, prefill_data)):
        # Pass each row's own token ids (previously row 0's ids were reused
        # for every row of the batch).
        last_hidden_state[batch_idx] = imer.prefill(
            tokenizer,
            row_ids.tolist(),
            row_embeds,
            slice_len=128,
            return_last_hidden_state=True,
        )

    embeddings = last_token_pool(
        torch.from_numpy(last_hidden_state.astype(np.float32)),
        attention_mask,
    )
    # Normalize so that inner product equals cosine similarity.
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy()


def _split_text(text, chunk_size, chunk_overlap):
    """Split text into windows of chunk_size characters that overlap by chunk_overlap."""
    step = chunk_size - chunk_overlap
    if chunk_size <= 0 or step <= 0:
        # A non-positive step would never advance through the text.
        raise ValueError(
            f"chunk_size must be positive and larger than chunk_overlap "
            f"(got chunk_size={chunk_size}, chunk_overlap={chunk_overlap})"
        )
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]


def load_pdf_chunks(pdf_path, chunk_size=500, chunk_overlap=100):
    """Read a PDF and split its text into overlapping character chunks.

    Args:
        pdf_path: Path of the PDF file.
        chunk_size: Maximum number of characters per chunk.
        chunk_overlap: Number of characters shared by consecutive chunks.

    Returns:
        List of text chunks.

    Raises:
        ValueError: If chunk_size is not larger than chunk_overlap.
    """
    reader = PdfReader(pdf_path)
    # extract_text() may yield None for pages without a text layer.
    all_text = "".join((page.extract_text() or "") + "\n" for page in reader.pages)
    return _split_text(all_text, chunk_size, chunk_overlap)


def load_txt_chunks(txt_path, chunk_size=20, chunk_overlap=5):
    """Read a text file, collapse its whitespace and split it into chunks.

    The encoding is detected with chardet; if decoding with the detected
    encoding fails, gbk and then latin-1 are tried.

    Args:
        txt_path: Path of the text file.
        chunk_size: Maximum number of characters per chunk.
        chunk_overlap: Number of characters shared by consecutive chunks.

    Returns:
        List of text chunks.

    Raises:
        ValueError: If chunk_size is not larger than chunk_overlap.
    """
    with open(txt_path, 'rb') as f:
        raw_data = f.read()

    detected = chardet.detect(raw_data)['encoding'] or 'utf-8'
    all_text = ""
    # latin-1 maps every byte to a character, so the last candidate always succeeds.
    for encoding in (detected, 'gbk', 'latin-1'):
        try:
            all_text = raw_data.decode(encoding)
            break
        except (UnicodeDecodeError, LookupError):
            # LookupError: chardet reported a codec Python does not know.
            continue

    all_text = re.sub(r'\s+', ' ', all_text).strip()
    return _split_text(all_text, chunk_size, chunk_overlap)


def build_index(file_path):
    """Build the vector index for a .pdf or .txt file and save it to disk.

    Writes the FAISS index to INDEX_FILE and the pickled chunk list to
    EMBEDDINGS_FILE.

    Args:
        file_path: Path of the document to index.

    Returns:
        A status message containing the number of indexed chunks.

    Raises:
        ValueError: If the file type is unsupported or no text was found.
    """
    # Pick the loader from the file extension.
    lowered = file_path.lower()
    if lowered.endswith('.pdf'):
        chunks = load_pdf_chunks(file_path)
    elif lowered.endswith('.txt'):
        chunks = load_txt_chunks(file_path)
    else:
        raise ValueError(f"不支持的文件类型: {file_path}")

    if not chunks:
        raise ValueError(f"文件中没有可索引的文本: {file_path}")

    # FAISS expects a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(encode_texts(chunks), dtype=np.float32)
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    # Persist the index and the chunk texts it refers to.
    faiss.write_index(index, INDEX_FILE)
    with open(EMBEDDINGS_FILE, "wb") as f:
        pickle.dump(chunks, f)

    return f"✅ 成功构建索引: {len(chunks)}个片段"


def index_exists():
    """Return True when both the index file and the chunk file exist."""
    return os.path.exists(INDEX_FILE) and os.path.exists(EMBEDDINGS_FILE)


def get_top_k(query, k=3):
    """Return up to k stored chunks most similar to the query.

    Args:
        query: Query text.
        k: Maximum number of chunks to return.

    Returns:
        List of chunk texts, best match first; empty if no index exists.
    """
    if not index_exists():
        return []

    index = faiss.read_index(INDEX_FILE)
    # SECURITY: pickle.load executes arbitrary code from the file; this is
    # only safe because EMBEDDINGS_FILE is written by build_index().
    with open(EMBEDDINGS_FILE, "rb") as f:
        texts = pickle.load(f)

    query_vec = np.ascontiguousarray(encode_texts([query]), dtype=np.float32)
    _, indices = index.search(query_vec, k)
    # FAISS pads the result with -1 when fewer than k vectors are indexed;
    # indexing texts with -1 would silently return the last chunk.
    return [texts[i] for i in indices[0] if 0 <= i < len(texts)]


def _build_prompt(query):
    """Assemble the LLM prompt from the retrieved context and the user query."""
    context = "\n".join(get_top_k(query))
    prompt = f"""上下文内容是你可以参考的资料, 用户问题才是你需要回答的内容.
    [上下文内容]:
    - {context}\n
    [用户问题]:
    - {query}\n
    [简洁的输出回答]:
    """
    print("DEBUG: prompt is \n", prompt)
    return prompt


def ask_question(query):
    """Answer a query in a single blocking request to the local LLM API.

    Args:
        query: User question.

    Returns:
        The stripped "text" field of the JSON response, or a fallback
        message when that field is missing or null.
    """
    prompt = _build_prompt(query)

    # Send the request to the local LLM API.
    response = requests.post(
        LLM_API_URL,
        json={"prompt": prompt, "max_tokens": 1024},
        timeout=_LLM_REQUEST_TIMEOUT,
    )
    text = response.json().get("text")
    if text is None:
        text = "❌ LLM 接口未响应"
    return str(text).strip()


def stream_answer(query):
    """Stream the answer from the LLM API, yielding it token by token.

    Args:
        query: User question.

    Yields:
        Answer tokens as they arrive, or a single warning message when the
        request fails.
    """
    prompt = _build_prompt(query)

    data = {
        "prompt": prompt,
        "max_tokens": 1024,
        "temperature": 0.6,
        "top_p": 0.9
    }

    try:
        # Send the streaming request.
        with requests.post(
            LLM_API_URL,
            json=data,
            stream=True,
            timeout=_LLM_REQUEST_TIMEOUT,
        ) as response:
            # Check the response status.
            if response.status_code != 200:
                yield f"⚠️ 请求错误:{response.status_code}"
                return

            for chunk in response.iter_lines():
                # Skip heartbeats and empty lines.
                if not chunk:
                    continue
                line = chunk.decode('utf-8').strip()
                if not line.startswith('data:'):
                    continue
                # Strip only the leading field name so that "data:" inside
                # the payload itself is preserved.
                json_data = line[len('data:'):]

                try:
                    event = json.loads(json_data)
                except json.JSONDecodeError:
                    # The backend sent plain text instead of JSON.
                    yield json_data
                    continue

                if not isinstance(event, dict):
                    # Valid JSON but not an event object (e.g. a bare number
                    # or string): treat it as text rather than failing.
                    yield event if isinstance(event, str) else json_data
                    continue

                if 'token' in event:
                    yield event['token']
                elif event.get('end') or event.get('finish_reason'):
                    return

    except Exception as e:
        # Top-level boundary of the generator: report instead of raising.
        yield f"⚠️ 连接错误:{str(e)}"


# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description="构建 PDF 索引并回答问题")
#     parser.add_argument("--pdf", type=str, required=True, help="PDF 文件路径")
#     args = parser.parse_args()
#     build_index(args.pdf)
#     print("🚗🚗🌲🌲 索引构建完成!")