uyen13 commited on
Commit
694a2d1
·
verified ·
1 Parent(s): 75d3fac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -34
app.py CHANGED
@@ -5,19 +5,20 @@ from langchain.text_splitter import CharacterTextSplitter
5
  from langchain.embeddings import SentenceTransformerEmbeddings
6
  from langchain.vectorstores import FAISS
7
  from langchain.chains import RetrievalQA
 
8
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
  import os
10
  import torch
11
- # Load FLAN-T5 model
 
12
  @st.cache_resource
13
  def load_llm():
14
  model_name = "google/flan-t5-xl"
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
-
18
  model = AutoModelForSeq2SeqLM.from_pretrained(
19
  model_name,
20
- torch_dtype=torch.float32, # T5 thường dùng float32 hoặc bfloat16 nếu có GPU hỗ trợ
21
  device_map="auto"
22
  )
23
 
@@ -25,72 +26,119 @@ def load_llm():
25
  "text2text-generation",
26
  model=model,
27
  tokenizer=tokenizer,
28
- max_new_tokens=256,
29
- temperature=0.7,
30
- top_p=0.95,
31
- repetition_penalty=1.15,
 
 
 
32
  do_sample=True
33
  )
34
 
35
  return HuggingFacePipeline(pipeline=pipe)
36
 
37
- # Process PDF and create vectorstore
38
  def process_pdf(pdf_path):
39
  loader = PyPDFLoader(pdf_path)
40
  documents = loader.load()
41
 
42
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
 
 
 
 
43
  texts = text_splitter.split_documents(documents)
44
 
45
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
46
  vectorstore = FAISS.from_documents(texts, embeddings)
47
  return vectorstore
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def main():
50
- st.set_page_config(page_title="PDF Chatbot", page_icon="📄")
51
- st.title("PDF Chatbot 📄")
52
- st.markdown("Upload a PDF and ask questions about its content using FLAN-T5!")
53
 
54
- uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
55
 
56
  if uploaded_file is not None:
57
- # Save uploaded file temporarily
58
  with open("temp.pdf", "wb") as f:
59
  f.write(uploaded_file.getbuffer())
60
 
61
- # Process PDF
62
- with st.spinner("Processing PDF..."):
63
  vectorstore = process_pdf("temp.pdf")
64
 
65
- # Load LLM
66
  llm = load_llm()
67
 
68
- # Create QA chain
69
  qa_chain = RetrievalQA.from_chain_type(
70
  llm=llm,
71
  chain_type="stuff",
72
  retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
73
- return_source_documents=True
 
74
  )
75
 
76
- # Query input
77
- query = st.text_input("Ask a question about the PDF:")
78
  if query:
79
- with st.spinner("Generating answer..."):
80
- result = qa_chain({"query": query})
81
- answer = result["result"]
82
- source_docs = result["source_documents"]
83
-
84
- st.markdown("### Answer")
85
- st.write(answer)
86
-
87
- with st.expander("Show Source Documents"):
88
- for i, doc in enumerate(source_docs):
89
- st.markdown(f"**Source {i+1}:**")
90
- st.write(doc.page_content)
 
 
 
 
91
 
92
  else:
93
- st.info("Please upload a PDF file to get started.")
94
 
95
  if __name__ == "__main__":
96
  main()
 
5
  from langchain.embeddings import SentenceTransformerEmbeddings
6
  from langchain.vectorstores import FAISS
7
  from langchain.chains import RetrievalQA
8
+ from langchain.prompts import PromptTemplate
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
10
  import os
11
  import torch
12
+
13
+ # Load FLAN-T5 model với các tham số tối ưu
14
  @st.cache_resource
15
  def load_llm():
16
  model_name = "google/flan-t5-xl"
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
19
  model = AutoModelForSeq2SeqLM.from_pretrained(
20
  model_name,
21
+ torch_dtype=torch.float32,
22
  device_map="auto"
23
  )
24
 
 
26
  "text2text-generation",
27
  model=model,
28
  tokenizer=tokenizer,
29
+ max_new_tokens=512,
30
+ temperature=0.6,
31
+ top_k=50,
32
+ top_p=0.85,
33
+ repetition_penalty=1.2,
34
+ num_beams=3,
35
+ early_stopping=True,
36
  do_sample=True
37
  )
38
 
39
  return HuggingFacePipeline(pipeline=pipe)
40
 
41
+ # Xử PDF tạo vector store
42
  def process_pdf(pdf_path):
43
  loader = PyPDFLoader(pdf_path)
44
  documents = loader.load()
45
 
46
+ text_splitter = CharacterTextSplitter(
47
+ chunk_size=1000,
48
+ chunk_overlap=200,
49
+ separator="\n"
50
+ )
51
  texts = text_splitter.split_documents(documents)
52
 
53
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
54
  vectorstore = FAISS.from_documents(texts, embeddings)
55
  return vectorstore
56
 
57
+ # Xử lý hậu kỳ cho câu trả lời
58
+ def postprocess_answer(answer):
59
+ # Thay thế các cụm từ không tự nhiên
60
+ replacements = {
61
+ "the context": "tài liệu",
62
+ "according to the document": "theo nội dung tài liệu",
63
+ "it is stated that": "trong tài liệu có đề cập rằng",
64
+ "the answer is": "câu trả lời là",
65
+ "based on the information": "dựa trên thông tin được cung cấp"
66
+ }
67
+
68
+ for eng, vi in replacements.items():
69
+ answer = answer.replace(eng, vi)
70
+
71
+ # Chuẩn hóa định dạng
72
+ answer = answer.strip()
73
+ if answer and len(answer) > 0:
74
+ answer = answer[0].upper() + answer[1:]
75
+
76
+ # Kiểm tra câu trả lời ngắn
77
+ if len(answer.split()) < 4:
78
+ answer = "Thông tin này hiện chưa rõ ràng. " + answer
79
+
80
+ return answer
81
+
82
+ # Prompt template tiếng Việt
83
+ template = """Hãy trả lời câu hỏi một cách tự nhiên và mạch lạc như con người.
84
+ Sử dụng ngôn từ dễ hiểu, tránh các thuật ngữ kỹ thuật phức tạp.
85
+ Nếu không có thông tin trong tài liệu, hãy trả lời 'Tôi không tìm thấy thông tin liên quan trong tài liệu'.
86
+
87
+ Câu hỏi: {query}
88
+ Trả lời:"""
89
+
90
+ QA_PROMPT = PromptTemplate.from_template(template)
91
+
92
  def main():
93
+ st.set_page_config(page_title="Trợ lý PDF thông minh", page_icon="📘")
94
+ st.title("Trợ lý PDF thông minh 🤖")
95
+ st.markdown("Tải lên file PDF đặt câu hỏi về nội dung tài liệu!")
96
 
97
+ uploaded_file = st.file_uploader("Chọn file PDF", type="pdf")
98
 
99
  if uploaded_file is not None:
100
+ # Lưu file tạm
101
  with open("temp.pdf", "wb") as f:
102
  f.write(uploaded_file.getbuffer())
103
 
104
+ # Xử PDF
105
+ with st.spinner("Đang phân tích tài liệu..."):
106
  vectorstore = process_pdf("temp.pdf")
107
 
108
+ # Khởi tạo model
109
  llm = load_llm()
110
 
111
+ # Tạo QA chain với prompt template
112
  qa_chain = RetrievalQA.from_chain_type(
113
  llm=llm,
114
  chain_type="stuff",
115
  retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
116
+ return_source_documents=True,
117
+ chain_type_kwargs={"prompt": QA_PROMPT}
118
  )
119
 
120
+ # Giao diện hỏi đáp
121
+ query = st.text_input("Nhập câu hỏi của bạn về tài liệu:")
122
  if query:
123
+ with st.spinner("Đang tổng hợp câu trả lời..."):
124
+ try:
125
+ result = qa_chain({"query": query})
126
+ raw_answer = result["result"]
127
+ answer = postprocess_answer(raw_answer)
128
+
129
+ st.markdown("### Câu trả lời")
130
+ st.success(answer)
131
+
132
+ with st.expander("Xem chi tiết nguồn tham khảo"):
133
+ for i, doc in enumerate(result["source_documents"]):
134
+ st.markdown(f"**Trích dẫn {i+1}:**")
135
+ st.info(doc.page_content[:500] + "...")
136
+
137
+ except Exception as e:
138
+ st.error("Có lỗi xảy ra khi xử lý yêu cầu. Vui lòng thử lại với câu hỏi khác.")
139
 
140
  else:
141
+ st.info("Vui lòng tải lên file PDF để bắt đầu.")
142
 
143
  if __name__ == "__main__":
144
  main()