theerasin commited on
Commit
3201029
·
verified ·
1 Parent(s): d55790e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -10,7 +10,8 @@ import io
10
  from langchain_text_splitters import RecursiveCharacterTextSplitter
11
  from langchain_community.vectorstores import FAISS
12
  from langchain_core.documents import Document as LCDocument
13
- from langchain.embeddings import HuggingFaceEmbeddings
 
14
  import time
15
 
16
  # === Load summarization model ===
@@ -20,9 +21,20 @@ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
20
  # === Load QA pipeline ===
21
  qa_pipeline = pipeline("question-answering", model="facebook/bart-large-cnn", tokenizer=tokenizer)
22
 
23
- # === Setup BGE Embedding model ===
24
- embedding_model_name = "BAAI/bge-large-en-v1.5"
25
- embedding_function = HuggingFaceEmbeddings(model_name=embedding_model_name)
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # === Data models ===
28
  class KeyPoint(BaseModel):
@@ -82,7 +94,7 @@ def create_word_report(analysis):
82
  # === Streamlit UI ===
83
  st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
84
  st.title("📄 Chat With PDF")
85
- st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Embeddings")
86
 
87
  for key in ["current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
88
  if key not in st.session_state:
@@ -154,6 +166,6 @@ if st.session_state.vectorstore is not None:
154
 
155
  if st.session_state.analysis_time is not None:
156
  st.markdown(
157
- f'<div style="text-align:center; margin-top:2rem; color:gray;">Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE Large v1.5</div>',
158
  unsafe_allow_html=True
159
  )
 
10
  from langchain_text_splitters import RecursiveCharacterTextSplitter
11
  from langchain_community.vectorstores import FAISS
12
  from langchain_core.documents import Document as LCDocument
13
+ from langchain_core.embeddings import Embeddings
14
+ from sentence_transformers import SentenceTransformer
15
  import time
16
 
17
  # === Load summarization model ===
 
21
  # === Load QA pipeline ===
22
  qa_pipeline = pipeline("question-answering", model="facebook/bart-large-cnn", tokenizer=tokenizer)
23
 
24
+ # === Load SentenceTransformer embedding model ===
25
+ embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5")
26
+
27
+ class CustomSentenceTransformer(Embeddings):
28
+ def __init__(self, model):
29
+ self.model = model
30
+
31
+ def embed_documents(self, texts):
32
+ return self.model.encode(texts, show_progress_bar=False).tolist()
33
+
34
+ def embed_query(self, text):
35
+ return self.model.encode(text, show_progress_bar=False).tolist()
36
+
37
+ embedding_function = CustomSentenceTransformer(embedding_model)
38
 
39
  # === Data models ===
40
  class KeyPoint(BaseModel):
 
94
  # === Streamlit UI ===
95
  st.set_page_config(page_title="Chat With PDF (BART + BGE)", page_icon="📄")
96
  st.title("📄 Chat With PDF")
97
+ st.caption("Summarize and Chat with Documents using facebook/bart-large-cnn + BGE Small Embeddings")
98
 
99
  for key in ["current_file", "pdf_summary", "analysis_time", "pdf_report", "word_report", "vectorstore", "messages"]:
100
  if key not in st.session_state:
 
166
 
167
  if st.session_state.analysis_time is not None:
168
  st.markdown(
169
+ f'<div style="text-align:center; margin-top:2rem; color:gray;">Analysis Time: {st.session_state.analysis_time:.1f}s | Embedding: BGE Small v1.5</div>',
170
  unsafe_allow_html=True
171
  )