uyen13 committed · verified
Commit c1244e1 · 1 Parent(s): dc7d59a

Update app.py

Files changed (1)
app.py +28 -25
app.py CHANGED
@@ -1,57 +1,60 @@
 # app.py
-from langchain.document_loaders import PyPDFLoader
+import streamlit as st
+import tempfile
+
+from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import CharacterTextSplitter
+from langchain_community.vectorstores import FAISS
 from langchain.embeddings import SentenceTransformerEmbeddings
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
-from langchain.llms import HuggingFacePipeline
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from langchain_huggingface import HuggingFacePipeline
 
-import streamlit as st
-import tempfile
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
-# Load FLAN-T5 model
+# Declare the HuggingFace LLM model
 model_name = "google/flan-t5-base"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
-pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
-llm = HuggingFacePipeline(pipeline=pipe)
+text2text_gen = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
+llm = HuggingFacePipeline(pipeline=text2text_gen)
 
-# Streamlit UI
-st.title("Chat with PDF (FLAN-T5, no OpenAI)")
+st.title("Chat with PDF (LangChain + HuggingFace + FAISS)")
 
-uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
+uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
 if uploaded_file:
     with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
         tmp_file.write(uploaded_file.read())
         pdf_path = tmp_file.name
 
-    # Load PDF
+    # Load text from the PDF
     loader = PyPDFLoader(pdf_path)
     documents = loader.load()
 
-    # Split text
+    # Split the text into chunks
     splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
     docs = splitter.split_documents(documents)
 
-    # Embed & Store
-    embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
-    db = FAISS.from_documents(docs, embedding)
-    retriever = db.as_retriever()
+    # Embed and index with FAISS
+    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+    vectorstore = FAISS.from_documents(docs, embeddings)
+    retriever = vectorstore.as_retriever()
 
-    # RetrievalQA chain
-    qa_chain = RetrievalQA.from_chain_type(
+    # Create the RetrievalQA chain
+    qa = RetrievalQA.from_chain_type(
         llm=llm,
         chain_type="stuff",
         retriever=retriever,
         return_source_documents=True
     )
 
-    # Chat input
-    query = st.text_input("Ask a question about the PDF:")
+    # Q&A
+    query = st.text_input("Enter a question about the PDF:")
     if query:
-        result = qa_chain(query)
-        st.write("### Answer:")
+        result = qa.invoke({"query": query})
+        st.markdown("### Answer:")
         st.write(result["result"])
+
+        with st.expander("📄 Source references"):
+            for doc in result["source_documents"]:
+                st.markdown(doc.page_content[:1000] + ("..." if len(doc.page_content) > 1000 else ""))
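Note: the commit touches only app.py and pins no dependency file, so package names can only be inferred from the imports: the new version suggests streamlit, langchain, langchain-community, langchain-huggingface, transformers, sentence-transformers, faiss-cpu, and pypdf, after which the app starts with streamlit run app.py.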
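A quick way to sanity-check the updated flow is to run it headless, outside Streamlit. Below is a minimal sketch that mirrors the committed code; the file name sample.pdf and the test question are hypothetical stand-ins, and everything else (models, chunking settings, chain setup) is taken directly from the new app.py.

# smoke_test.py — headless check of the flow committed above.
# "sample.pdf" and the question are made-up stand-ins.
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Same FLAN-T5 pipeline as app.py, wrapped for LangChain.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
llm = HuggingFacePipeline(
    pipeline=pipeline("text2text-generation", model=model,
                      tokenizer=tokenizer, max_length=512)
)

# Same chunking and embedding settings as app.py.
documents = PyPDFLoader("sample.pdf").load()
docs = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(documents)
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
retriever = FAISS.from_documents(docs, embeddings).as_retriever()

qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
)

# invoke() replaces the deprecated qa_chain(query) call style: it takes a
# dict keyed by the chain's input key ("query") and returns "result" plus
# "source_documents" because return_source_documents=True.
result = qa.invoke({"query": "What is this document about?"})
print(result["result"])
for doc in result["source_documents"]:
    print(doc.page_content[:200])

This is also the reason for the result-handling change in the diff: the old qa_chain(query) call style is deprecated in recent LangChain releases, while invoke() returns the same dict shape the expander then iterates over.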