wayne0603 committed on
Commit
59ea9db
·
verified ·
1 Parent(s): d4748d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel, pipeline
3
+ import faiss
4
+ import numpy as np
5
+ import torch
6
+ import os
7
+ from PyPDF2 import PdfReader
8
+
9
# ===== Embedding model =====
# Both the tokenizer and the encoder are loaded from the same checkpoint id,
# so the id is spelled once and shared by the two calls below.
_EMBED_CHECKPOINT = "BAAI/bge-small-zh"

embed_tokenizer = AutoTokenizer.from_pretrained(
    _EMBED_CHECKPOINT,
    trust_remote_code=True,
)
embed_model = AutoModel.from_pretrained(
    _EMBED_CHECKPOINT,
    trust_remote_code=True,
)
16
+
17
def embed_text(text):
    """Encode ``text`` into a single embedding vector.

    The text is tokenized (truncated to at most 512 tokens), run through
    ``embed_model`` with gradient tracking disabled, and the hidden state at
    token position 0 of the first sequence is returned.

    Args:
        text: The string to embed.

    Returns:
        A 1-D NumPy array holding the position-0 hidden state.
    """
    encoded = embed_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        hidden_states = embed_model(**encoded).last_hidden_state
    # Sequence 0, token 0 (presumably the [CLS] token for this tokenizer),
    # all hidden dimensions.
    first_token_vector = hidden_states[0, 0, :]
    return first_token_vector.numpy()
22
+
23
# ===== Generation model (lightweight LLM) =====
# The Transformers `pipeline` needs a checkpoint in Transformers format
# (config + weights). The previous id carried a "-GGML" suffix, which names a
# llama.cpp-style quantized export that `pipeline` cannot load; the plain chat
# checkpoint is used instead.
generator = pipeline(
    "text-generation",
    model="Qwen/Qwen1.5-1.8B-Chat",
    device=-1,  # -1 selects the CPU
)
29
+
30
# ===== Global state: the vector index and the documents behind it =====
# `docs` holds one {"text": ..., "source": ...} record per chunk; `index` is
# the FAISS index built over those chunks and stays None until a file has
# been loaded.
docs = []
index = None
33
+
34
# ===== File parsing =====
def _extract_text(path):
    """Return the text of the PDF/TXT file at ``path``, or None if unsupported."""
    suffix = os.path.splitext(path)[1].lower()
    if suffix == ".pdf":
        reader = PdfReader(path)
        # extract_text() may give None/"" for pages without a text layer.
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
    if suffix == ".txt":
        with open(path, encoding="utf-8") as handle:
            return handle.read()
    return None


def load_file(file_obj):
    """Build the global knowledge base from an uploaded PDF or TXT file.

    The file's text is split into fixed 500-character chunks, every chunk is
    embedded with ``embed_text`` and the vectors are stored in a fresh
    ``faiss.IndexFlatL2``. The module-level ``index`` and ``docs`` are only
    replaced once the new index has been built completely, so a rejected or
    empty upload leaves the previously loaded knowledge base untouched and
    the two globals always stay consistent with each other.

    Args:
        file_obj: The uploaded file. Either a filesystem path (``str``) or an
            object exposing the path as ``.name``; ``None`` when nothing was
            uploaded.

    Returns:
        A ``(status_message, None)`` tuple. The message reports the number of
        loaded chunks, or explains why nothing was loaded.

    Raises:
        UnicodeDecodeError: If a ``.txt`` file is not valid UTF-8.
    """
    global index, docs

    if file_obj is None:
        return "请先上传文件", None

    # Read by path instead of calling .read() on the upload object: the path
    # is available whether the caller hands over a str or a file wrapper.
    path = file_obj if isinstance(file_obj, str) else file_obj.name
    text_data = _extract_text(path)
    if text_data is None:
        return "仅支持 PDF 或 TXT 文件", None
    if not text_data.strip():
        # Nothing to embed; an empty matrix would have no second dimension
        # to size the index with.
        return "未能从文件中提取到文本", None

    # Chunking
    chunk_size = 500
    new_docs = [
        {"text": text_data[start:start + chunk_size], "source": f"chunk_{number}"}
        for number, start in enumerate(range(0, len(text_data), chunk_size))
    ]

    # Embed the chunks and build the index (FAISS expects float32 input).
    doc_embeddings = np.ascontiguousarray(
        [embed_text(doc["text"]) for doc in new_docs], dtype="float32"
    )
    new_index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    new_index.add(doc_embeddings)

    # Publish both globals together.
    index, docs = new_index, new_docs
    return f"已加载 {len(docs)} 个文本块", None
59
+
60
# ===== RAG query =====
def rag_query(query):
    """Answer ``query`` using the chunks most similar to it.

    The query is embedded with ``embed_text``, up to three nearest chunks are
    looked up in the global FAISS ``index``, and a prompt containing those
    chunks (each prefixed with its source label) plus the question is passed
    to ``generator``.

    Args:
        query: The user's question.

    Returns:
        The generated answer, or a hint message when no knowledge base has
        been built yet or the query is blank.
    """
    if index is None or not docs:
        return "请先上传文件构建知识库"
    if not query or not query.strip():
        return "请输入问题"

    q_emb = np.asarray(embed_text(query), dtype="float32").reshape(1, -1)

    # Never ask for more neighbours than there are chunks: FAISS pads missing
    # results with -1, which would silently index the last chunk.
    top_k = min(3, len(docs))
    _, neighbor_ids = index.search(q_emb, top_k)
    hits = [docs[i] for i in neighbor_ids[0] if 0 <= i < len(docs)]

    # Label each chunk so the model has something to cite, as the prompt asks.
    context = "\n".join(f"[{hit['source']}] {hit['text']}" for hit in hits)
    prompt = f"已知信息:\n{context}\n\n问题:{query}\n请基于已知信息回答,并引用来源。"

    # max_new_tokens bounds only the answer; max_length would also count the
    # prompt, which is already longer than 200 tokens for typical contexts.
    # return_full_text=False keeps the prompt out of the returned answer.
    result = generator(
        prompt,
        max_new_tokens=200,
        do_sample=False,
        return_full_text=False,
    )
    return result[0]["generated_text"]
71
+
72
# ===== Gradio UI =====
def _load_status(file_obj):
    """Run ``load_file`` and return only its status message.

    ``load_file`` returns a ``(message, None)`` tuple, but the click handler
    below is wired to a single output component, so only the message is
    forwarded.
    """
    outcome = load_file(file_obj)
    return outcome[0] if isinstance(outcome, tuple) else outcome


with gr.Blocks() as demo:
    gr.Markdown("## 📚 轻量 RAG 原型(上传 PDF/TXT)")
    # NOTE(review): the source's indentation was lost; placing only the file
    # picker and the build button inside the Row is a reconstruction - confirm.
    with gr.Row():
        file_input = gr.File(label="上传 PDF 或 TXT 文件")
        load_btn = gr.Button("构建知识库")
    status = gr.Textbox(label="状态")
    query_input = gr.Textbox(label="输入你的问题")
    answer_output = gr.Textbox(label="回答")

    # Build the knowledge base on click; answer when Enter is pressed in the
    # question box.
    load_btn.click(_load_status, inputs=file_input, outputs=status)
    query_input.submit(rag_query, inputs=query_input, outputs=answer_output)

if __name__ == "__main__":
    demo.launch()