wayne0603 commited on
Commit
101fcf4
·
verified ·
1 Parent(s): 176a539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -25
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModel, pipeline
3
- import faiss
4
- import numpy as np
5
- import torch
6
  import os
 
 
 
 
7
  from PyPDF2 import PdfReader
 
8
 
9
  # ===== 嵌入模型 =====
10
  embed_model = AutoModel.from_pretrained(
@@ -20,26 +20,23 @@ def embed_text(text):
20
  embeddings = embed_model(**inputs).last_hidden_state[:, 0, :]
21
  return embeddings[0].numpy()
22
 
23
- # ===== 生成模型(轻量LLM) =====
24
  generator = pipeline(
25
  "text-generation",
26
- model="Qwen/Qwen1.5-1.8B-Chat", # 改成可用的公开模型
27
  device=-1
28
  )
29
 
30
- # ===== 全局变量存储索引和文档 =====
31
  index = None
32
  docs = []
33
 
34
- # ===== 文件解析函数 =====
35
-
36
-
37
  def load_file(file_obj):
38
  global index, docs
39
  docs = []
40
  text_data = ""
41
 
42
- # 获取文件路径
43
  file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
44
  ext = os.path.splitext(file_path)[1].lower()
45
 
@@ -61,38 +58,59 @@ def load_file(file_obj):
61
  if not text_data.strip():
62
  return "未能从文件中提取到文本", None
63
 
64
- # 切块
65
- chunks = [text_data[i:i+500] for i in range(0, len(text_data), 500)]
 
 
 
 
 
 
 
 
66
  docs = [{"text": chunk, "source": f"chunk_{i}"} for i, chunk in enumerate(chunks)]
67
 
68
- # 向量化并建索引
69
  doc_embeddings = np.array([embed_text(d["text"]) for d in docs])
70
  index = faiss.IndexFlatL2(doc_embeddings.shape[1])
71
  index.add(doc_embeddings)
72
 
73
  return f"已加载 {len(docs)} 个文本块", None
74
 
75
- # ===== RAG 查询函数 =====
76
  def rag_query(query):
77
- if index is None:
78
- return "请先上传文件构建知识库"
79
  q_emb = embed_text(query).reshape(1, -1)
80
- D, I = index.search(q_emb, k=3)
81
  retrieved = [docs[i]["text"] for i in I[0]]
82
- context = "\n".join(retrieved)
83
- prompt = f"已知信息:\n{context}\n\n问题:{query}\n请基于已知信息回答,并引用来源。"
84
- result = generator(prompt, max_length=200, do_sample=False)
85
- return result[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # ===== Gradio 界面 =====
88
  with gr.Blocks() as demo:
89
- gr.Markdown("## 📚 轻量 RAG 原型(上传 PDF/TXT)")
90
  with gr.Row():
91
  file_input = gr.File(label="上传 PDF 或 TXT 文件")
92
  load_btn = gr.Button("构建知识库")
93
  status = gr.Textbox(label="状态")
94
  query_input = gr.Textbox(label="输入你的问题")
95
- answer_output = gr.Textbox(label="回答")
96
  load_btn.click(load_file, inputs=file_input, outputs=status)
97
  query_input.submit(rag_query, inputs=query_input, outputs=answer_output)
98
 
 
 
 
 
 
 
1
  import os
2
+ import torch
3
+ import numpy as np
4
+ import faiss
5
+ import gradio as gr
6
  from PyPDF2 import PdfReader
7
+ from transformers import AutoTokenizer, AutoModel, pipeline
8
 
9
  # ===== 嵌入模型 =====
10
  embed_model = AutoModel.from_pretrained(
 
20
  embeddings = embed_model(**inputs).last_hidden_state[:, 0, :]
21
  return embeddings[0].numpy()
22
 
23
+ # ===== 生成模型(Qwen 1.8B) =====
24
  generator = pipeline(
25
  "text-generation",
26
+ model="Qwen/Qwen1.5-1.8B-Chat",
27
  device=-1
28
  )
29
 
30
+ # ===== 全局变量 =====
31
  index = None
32
  docs = []
33
 
34
+ # ===== 文件解析 =====
 
 
35
  def load_file(file_obj):
36
  global index, docs
37
  docs = []
38
  text_data = ""
39
 
 
40
  file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
41
  ext = os.path.splitext(file_path)[1].lower()
42
 
 
58
  if not text_data.strip():
59
  return "未能从文件中提取到文本", None
60
 
61
+ # 分块(350字 + 100字重叠)
62
+ chunk_size = 350
63
+ overlap = 100
64
+ start = 0
65
+ chunks = []
66
+ while start < len(text_data):
67
+ end = min(start + chunk_size, len(text_data))
68
+ chunks.append(text_data[start:end])
69
+ start += chunk_size - overlap
70
+
71
  docs = [{"text": chunk, "source": f"chunk_{i}"} for i, chunk in enumerate(chunks)]
72
 
73
+ # 向量化 & 建索引
74
  doc_embeddings = np.array([embed_text(d["text"]) for d in docs])
75
  index = faiss.IndexFlatL2(doc_embeddings.shape[1])
76
  index.add(doc_embeddings)
77
 
78
  return f"已加载 {len(docs)} 个文本块", None
79
 
80
+ # ===== RAG 查询 =====
81
  def rag_query(query):
82
+ if index is None or not docs:
83
+ return "请先上传文件并构建知识库"
84
  q_emb = embed_text(query).reshape(1, -1)
85
+ D, I = index.search(q_emb, k=5) # Top-K=5
86
  retrieved = [docs[i]["text"] for i in I[0]]
87
+ context = "\n".join([f"[{idx+1}] {txt}" for idx, txt in enumerate(retrieved)])
88
+
89
+ prompt = f"""已知信息:
90
+ {context}
91
+
92
+ 问题:{query}
93
+
94
+ 要求:
95
+ 1. 仅依据已知信息回答
96
+ 2. 无法回答时直接说“我不知道”
97
+ 3. 在回答中标注引用的片段编号
98
+ """
99
+
100
+ result = generator(prompt, max_length=300, do_sample=False)
101
+ answer = result[0]["generated_text"]
102
+
103
+ return f"回答:\n{answer}\n\n参考片段:\n{context}"
104
 
105
  # ===== Gradio 界面 =====
106
  with gr.Blocks() as demo:
107
+ gr.Markdown("## 📚 加强版 RAG(Qwen 1.8B + 引用显示)")
108
  with gr.Row():
109
  file_input = gr.File(label="上传 PDF 或 TXT 文件")
110
  load_btn = gr.Button("构建知识库")
111
  status = gr.Textbox(label="状态")
112
  query_input = gr.Textbox(label="输入你的问题")
113
+ answer_output = gr.Textbox(label="回答", lines=10)
114
  load_btn.click(load_file, inputs=file_input, outputs=status)
115
  query_input.submit(rag_query, inputs=query_input, outputs=answer_output)
116