uyen13 commited on
Commit
c3f97cb
·
verified ·
1 Parent(s): c647597

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -52
app.py CHANGED
@@ -1,62 +1,87 @@
1
  import streamlit as st
2
- from langchain_community.document_loaders import PyPDFLoader
 
3
  from langchain.text_splitter import CharacterTextSplitter
4
- from langchain_community.embeddings import HuggingFaceEmbeddings
5
- from langchain_community.vectorstores import FAISS
6
  from langchain.chains import RetrievalQA
7
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
8
- from langchain_community.llms import HuggingFacePipeline
9
 
10
- # Khởi tạo mô hình và tokenizer
11
- model_name = "google/flan-t5-base"
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
14
-
15
- # Tạo pipeline cho HuggingFace
16
- pipe = pipeline(
17
- "text2text-generation",
18
- model=model,
19
- tokenizer=tokenizer,
20
- max_length=512,
21
- temperature=0,
22
- repetition_penalty=1.15
23
- )
24
-
25
- llm = HuggingFacePipeline(pipeline=pipe)
26
-
27
- # Cấu hình Streamlit
28
- st.title("PDF Chatbot with Flan-T5")
29
- uploaded_file = st.file_uploader("Upload PDF", type="pdf")
30
 
31
- if uploaded_file:
32
- # Lưu file tạm và load nội dung
33
- with open("temp.pdf", "wb") as f:
34
- f.write(uploaded_file.getbuffer())
35
-
36
- loader = PyPDFLoader("temp.pdf")
37
  documents = loader.load()
38
 
39
- # Chia nhỏ văn bản
40
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
41
  texts = text_splitter.split_documents(documents)
42
 
43
- # Tạo embeddings vector store
44
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
45
- db = FAISS.from_documents(texts, embeddings)
46
-
47
- # Tạo retrieval chain
48
- qa_chain = RetrievalQA.from_chain_type(
49
- llm=llm,
50
- chain_type="stuff",
51
- retriever=db.as_retriever(search_kwargs={"k": 3}),
52
- return_source_documents=True
53
- )
54
 
55
- # Xử chat
56
- question = st.text_input("Ask your question:")
57
- if question:
58
- result = qa_chain({"query": question})
59
- st.write("Answer:", result["result"])
60
- st.write("Sources:")
61
- for doc in result['source_documents']:
62
- st.write(f"- Page {doc.metadata['page']}: {doc.page_content[:200]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from langchain.llms import HuggingFacePipeline
3
+ from langchain.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import CharacterTextSplitter
5
+ from langchain.embeddings import SentenceTransformerEmbeddings
6
+ from langchain.vectorstores import FAISS
7
  from langchain.chains import RetrievalQA
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
+ import os
10
 
11
# Load FLAN-T5 model
@st.cache_resource
def load_llm():
    """Build the FLAN-T5 text2text pipeline and wrap it for LangChain.

    Decorated with ``st.cache_resource`` so the model/tokenizer are
    downloaded and loaded only once per Streamlit server process, not on
    every script rerun.

    Returns:
        HuggingFacePipeline: LangChain LLM wrapper around the pipeline.
    """
    model_name = "google/flan-t5-base"  # Adjust model size if needed
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        # Fix: without do_sample=True the pipeline decodes greedily and
        # temperature/top_p are silently ignored (transformers emits a
        # warning); enable sampling so the settings below take effect.
        do_sample=True,
        temperature=0.7,  # Adjust for creativity
        top_p=0.95,
        repetition_penalty=1.15,
    )
    return HuggingFacePipeline(pipeline=pipe)
 
 
 
 
27
 
28
# Process PDF and create vectorstore
def process_pdf(pdf_path):
    """Index a PDF for retrieval.

    Loads every page of the PDF at ``pdf_path``, splits the text into
    overlapping character chunks, embeds each chunk with MiniLM sentence
    embeddings, and returns a FAISS vectorstore over those chunks.
    """
    pages = PyPDFLoader(pdf_path).load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(pages)
    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    return FAISS.from_documents(chunks, embedder)
39
+
40
def main():
    """Streamlit entry point: upload a PDF and chat over its content.

    Renders the uploader, indexes the PDF (once per uploaded file),
    builds a RetrievalQA chain over the FAISS index, and answers
    free-text questions with the cached FLAN-T5 LLM.
    """
    st.set_page_config(page_title="PDF Chatbot", page_icon="📄")
    st.title("PDF Chatbot 📄")
    st.markdown("Upload a PDF and ask questions about its content using FLAN-T5!")

    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

    if uploaded_file is not None:
        # Streamlit reruns this whole script on every widget interaction
        # (including each question typed below). Cache the vectorstore in
        # session_state so the PDF is only re-saved and re-embedded when a
        # different file is uploaded, not on every rerun.
        file_key = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("pdf_key") != file_key:
            # Save uploaded file temporarily
            with open("temp.pdf", "wb") as f:
                f.write(uploaded_file.getbuffer())
            try:
                # Process PDF
                with st.spinner("Processing PDF..."):
                    st.session_state["vectorstore"] = process_pdf("temp.pdf")
                st.session_state["pdf_key"] = file_key
            finally:
                # Fix: the scratch file was previously never deleted;
                # remove it once the index is built (or on failure).
                if os.path.exists("temp.pdf"):
                    os.remove("temp.pdf")
        vectorstore = st.session_state["vectorstore"]

        # Load LLM (model load itself is cached by @st.cache_resource)
        llm = load_llm()

        # Create QA chain; "stuff" packs the top-k retrieved chunks
        # into a single prompt for the LLM.
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
            return_source_documents=True,
        )

        # Query input
        query = st.text_input("Ask a question about the PDF:")
        if query:
            with st.spinner("Generating answer..."):
                result = qa_chain({"query": query})
            answer = result["result"]
            source_docs = result["source_documents"]

            st.markdown("### Answer")
            st.write(answer)

            with st.expander("Show Source Documents"):
                for i, doc in enumerate(source_docs):
                    st.markdown(f"**Source {i+1}:**")
                    st.write(doc.page_content)

    else:
        st.info("Please upload a PDF file to get started.")
85
+
86
# Run the Streamlit app when this file is executed as a script.
if __name__ == "__main__":
    main()