uyen13 committed
Commit f3d30d1 · verified · 1 Parent(s): 64148cf

Update app.py

Files changed (1):
  1. app.py +40 -45
app.py CHANGED
@@ -6,39 +6,44 @@ from langchain.embeddings import SentenceTransformerEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 
-# --- Load the language model ---
+# --- 1. Load the TinyLlama or Mistral model ---
 @st.cache_resource
 def load_llm():
-    model_name = "google/flan-ul2"  # can be replaced with google/flan-ul2 or mistralai/Mistral-7B-Instruct-v0.2 if a GPU is available
+    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # replace with "mistralai/Mistral-7B-Instruct-v0.2" if a GPU is available
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+    model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float32,  # float32 is the safer choice on CPU
         device_map="auto"
     )
 
     pipe = pipeline(
-        "text2text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_new_tokens=512,
-        temperature=0.75,         # raise creativity (more natural output)
-        top_p=0.92,               # works together with temperature
-        top_k=40,
-        repetition_penalty=1.25,  # avoid repeated wording
-        num_beams=4,              # smoother sentences when not sampling
-        early_stopping=True,
-        do_sample=True            # enable sampling so answers are less rigid
-    )
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        top_k=50,
+        repetition_penalty=1.2,
+        do_sample=True,
+        eos_token_id=tokenizer.eos_token_id,
+        truncation=True,
+        return_full_text=False
+    )
 
     return HuggingFacePipeline(pipeline=pipe)
 
 
-# --- Process the PDF file ---
+# --- 2. Process the PDF file ---
 def process_pdf(pdf_path):
     loader = PyPDFLoader(pdf_path)
     documents = loader.load()
@@ -55,8 +60,8 @@ def process_pdf(pdf_path):
     return vectorstore
 
 
-# --- Prompt pre-processing and answer post-processing ---
-template = """あなたは親しみやすく丁寧なアシスタントです。以下の文書情報をもとに、質問に自然で分かりやすい日本語で回答してください。
+# --- 3. Japanese prompt template (natural style) ---
+template = """<s>[INST]あなたは親しみやすく丁寧なアシスタントです。以下の文書情報をもとに、質問に自然で分かりやすい日本語で回答してください。
 
 - 回答はできるだけ口語的で柔らかい表現を使ってください。
 - 理由や例を交えて説明すると良いでしょう。
@@ -66,37 +71,31 @@ template = """あなたは親しみやすく丁寧なアシスタントです。
 {context}
 
 質問: {question}
-回答:"""
+回答: [/INST]"""
 
-QA_PROMPT = PromptTemplate(
-    template=template,
-    input_variables=["context", "question"]
-)
+QA_PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])
 
 
+# --- 4. Post-process the answer ---
 def postprocess_answer(answer):
     answer = answer.strip()
 
-    # Remove unwanted phrases
     for phrase in ["Answer:", "答え:", "回答:", "The answer is", "Based on the context"]:
         answer = answer.replace(phrase, "").strip()
 
-    # Capitalize the first letter (if ever needed in Japanese)
-    # if answer and len(answer) > 0:
-    #     answer = answer[0].upper() + answer[1:]
+    if answer and len(answer) > 0:
+        answer = answer[0].upper() + answer[1:]
 
-    # # Add a sentence-final period if missing
-    # if answer and answer[-1] not in "。.?!":
-    #     answer += "。"
+    if answer and answer[-1] not in "。.?!":
+        answer += "。"
 
-    # # If the answer is too short or meaningless, fall back gently
-    # if len(answer.split()) < 4:
-    #     answer = "資料にはその件についての詳細な記載が見受けられませんが、以下のように推測されます:" + answer
+    if len(answer.split()) < 4:
+        answer = "資料にはその件についての詳細な記載が見受けられませんが、以下のように推測されます:" + answer
 
     return answer
 
 
-# --- Main application UI ---
+# --- 5. Main application UI ---
 def main():
     st.set_page_config(page_title="PDFアシスタント", page_icon="📘")
     st.title("PDFアシスタント 🤖")
@@ -112,9 +111,6 @@ def main():
         vectorstore = process_pdf("temp.pdf")
 
         llm = load_llm()
-        response = llm("東京の人口はどのくらいですか?")
-        st.success(response)
-        print("LLM response:", response)
 
         qa_chain = RetrievalQA.from_chain_type(
             llm=llm,
@@ -133,9 +129,8 @@
         with st.spinner("回答を生成中..."):
            try:
                 result = qa_chain({"question": query})
-                answer = result["result"]
-                # answer = postprocess_answer(raw_answer)
-
+                raw_answer = result["result"]
+                answer = postprocess_answer(raw_answer)
 
                 st.markdown("### 回答")
                 st.success(answer)
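A note on the new pad-token handling: add_special_tokens({'pad_token': '[PAD]'}) grows the tokenizer vocabulary by one id, but the diff never resizes the model's embedding matrix, so any actual use of that id can index out of bounds. A minimal sketch of the two usual remedies, assuming the TinyLlama checkpoint above (neither call is part of this commit):

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

if tokenizer.pad_token is None:
    # Remedy 1: reuse EOS as PAD -- no new id, no resize required
    tokenizer.pad_token = tokenizer.eos_token
    # Remedy 2: keep the commit's distinct [PAD] token, but grow the
    # embedding matrix so the new id is valid:
    # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # model.resize_token_embeddings(len(tokenizer))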
 
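Also worth flagging: the <s>[INST] ... [/INST] markers added to the template follow the Mistral-Instruct format, while TinyLlama/TinyLlama-1.1B-Chat-v1.0 was tuned on a Zephyr-style <|system|>/<|user|>/<|assistant|> format. If answers degrade on TinyLlama, one option is to let the tokenizer build the prompt; a sketch, where the system and user strings are illustrative stand-ins rather than the commit's exact template:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

context = "(retrieved document text)"   # stand-in for {context}
question = "(user question)"            # stand-in for {question}

messages = [
    {"role": "system", "content": "あなたは親しみやすく丁寧なアシスタントです。"},
    {"role": "user", "content": f"{context}\n\n質問: {question}"},
]
# Emits the <|system|>/<|user|>/<|assistant|> markers this checkpoint was
# actually trained on, ending with an open assistant turn for generation.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)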
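One behavioral caveat in the re-enabled post-processing: len(answer.split()) counts whitespace-separated tokens, and Japanese is written without spaces, so most answers will register as fewer than four "words" and get the fallback prefix prepended. A character-based threshold fits Japanese better; a sketch, where the cutoff of 10 characters is an illustrative assumption, not part of the commit:

def postprocess_answer(answer):
    answer = answer.strip()
    for phrase in ["Answer:", "答え:", "回答:", "The answer is", "Based on the context"]:
        answer = answer.replace(phrase, "").strip()
    if answer and answer[-1] not in "。.?!":
        answer += "。"
    # Japanese has no word spaces, so measure length in characters
    # rather than split() tokens; 10 is an illustrative cutoff.
    if len(answer) < 10:
        answer = "資料にはその件についての詳細な記載が見受けられませんが、以下のように推測されます:" + answer
    return answer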
 
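Finally, the chain is invoked as qa_chain({"question": query}), while RetrievalQA's default input key is "query"; the from_chain_type call is truncated in this diff, so it may already set input_key. For reference, a configuration consistent with that call shape, reusing llm, vectorstore, and QA_PROMPT from app.py (the chain_type and retriever settings are assumptions):

from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,                               # from load_llm() above
    chain_type="stuff",                    # assumption; not visible in the diff
    retriever=vectorstore.as_retriever(),  # from process_pdf() above
    chain_type_kwargs={"prompt": QA_PROMPT},
    input_key="question",                  # matches qa_chain({"question": query})
)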