sourize committed on
Commit 61d7892 · 1 Parent(s): 2a82939

Updated main.py

Files changed (1)
  1. app.py +46 -38
app.py CHANGED
@@ -10,9 +10,14 @@ from transformers import pipeline
 def load_models():
     # Embedding model (lightweight)
     embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-    # QA model (distilled SQuAD)
-    qa = pipeline('question-answering', model='distilbert-base-uncased-distilled-squad')
-    return embedder, qa
+    # Generative QA model
+    qa_gen = pipeline(
+        'text2text-generation',
+        model='google/flan-t5-base',
+        tokenizer='google/flan-t5-base',
+        device=-1  # CPU
+    )
+    return embedder, qa_gen
 
 # Extract text from uploaded file
 def extract_text_from_file(uploaded_file):
@@ -40,8 +45,8 @@ def chunk_text(text, chunk_size=500, overlap=50):
 
 # Build FAISS index from chunks
 @st.cache_resource
-def build_faiss_index(chunks, _embedder):  # underscore to avoid hashing issues
-    embeddings = _embedder.encode(chunks)
+def build_faiss_index(chunks, _embedder):  # underscore avoids hashing
+    embeddings = _embedder.encode(chunks, convert_to_numpy=True)
     dim = embeddings.shape[1]
     index = faiss.IndexFlatL2(dim)
     index.add(embeddings)
@@ -52,48 +57,51 @@ def main():
     st.set_page_config(page_title='📄 RAGbot', layout='wide')
     st.title('🤖 RagBot')
     st.sidebar.header('Upload Documents')
-
-    # Initialize chat history in session state
+
+    # Initialize chat history
     if 'history' not in st.session_state:
         st.session_state.history = []
 
     uploaded = st.sidebar.file_uploader('Upload PDF/DOCX/TXT', type=['pdf', 'docx', 'txt'])
-    if uploaded:
-        text = extract_text_from_file(uploaded)
-        chunks = chunk_text(text)
-        embedder, qa = load_models()
-        index = build_faiss_index(chunks, embedder)
-
-        # Display existing chat history
-        for chat in st.session_state.history:
-            with st.chat_message('user'):
-                st.markdown(f"**You:** {chat['question']}")
-            with st.chat_message('assistant'):
-                st.markdown(f"**RagBot:** {chat['answer']}")
-
-        # Chat input
-        question = st.chat_input('Ask a question about the document...')
-        if question:
-            # Retrieve top-k relevant chunks
-            q_emb = embedder.encode([question])
-            D, I = index.search(q_emb, k=3)
-            context = '\n\n'.join(chunks[i] for i in I[0])
-
-            # Get answer
-            result = qa({'question': question, 'context': context})
-            answer = result.get('answer', 'Sorry, could not find an answer.')
-
-            # Save to history
-            st.session_state.history.append({'question': question, 'answer': answer})
-
-            # Display new messages
-            with st.chat_message('user'):
-                st.markdown(f"**You:** {question}")
-            with st.chat_message('assistant'):
-                st.markdown(f"**RagBot:** {answer}")
-
-    else:
-        st.info('Please upload a document in the sidebar to begin.')
+    if not uploaded:
+        st.info('Please upload a document in the sidebar to begin.')
+        return
+
+    # On first load of a doc, process and index
+    if 'chunks' not in st.session_state or st.session_state.uploaded_name != uploaded.name:
+        text = extract_text_from_file(uploaded)
+        st.session_state.chunks = chunk_text(text)
+        st.session_state.embedder, st.session_state.qa_gen = load_models()
+        st.session_state.index = build_faiss_index(st.session_state.chunks, st.session_state.embedder)
+        st.session_state.uploaded_name = uploaded.name
+        st.session_state.history = []  # reset history on new doc
+
+    # Display existing chat history
+    for chat in st.session_state.history:
+        with st.chat_message('user'):
+            st.markdown(f"**You:** {chat['question']}")
+        with st.chat_message('assistant'):
+            st.markdown(f"**RagBot:** {chat['answer']}")
+
+    # Chat input
+    question = st.chat_input('Ask a question about the document...')
+    if question:
+        # Retrieve top-k relevant chunks
+        q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
+        _, I = st.session_state.index.search(q_emb, k=3)
+        context = '\n\n'.join(st.session_state.chunks[i] for i in I[0])
+
+        # Generate answer
+        prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in detail:"
+        response = st.session_state.qa_gen(prompt, max_new_tokens=200, do_sample=False)
+        answer = response[0]['generated_text'].strip()
+
+        # Save & display new messages
+        st.session_state.history.append({'question': question, 'answer': answer})
+        with st.chat_message('user'):
+            st.markdown(f"**You:** {question}")
+        with st.chat_message('assistant'):
+            st.markdown(f"**RagBot:** {answer}")
 
 if __name__ == '__main__':
     main()
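
For reference, the retrieval-plus-generation path this commit switches to can be exercised outside Streamlit. A minimal sketch, assuming the same model names as the diff; the sample chunks and question below are placeholders, not taken from the repo:

# Minimal sketch of the new retrieval + flan-t5 generation path, without Streamlit.
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Placeholder chunks standing in for chunk_text() output
chunks = [
    'RAGbot indexes uploaded documents with FAISS.',
    'Answers are generated with google/flan-t5-base.',
]

embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
qa_gen = pipeline('text2text-generation', model='google/flan-t5-base', device=-1)

# Same flat L2 index as build_faiss_index()
embeddings = embedder.encode(chunks, convert_to_numpy=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

question = 'Which model generates the answers?'
q_emb = embedder.encode([question], convert_to_numpy=True)
_, I = index.search(q_emb, min(3, len(chunks)))
context = '\n\n'.join(chunks[i] for i in I[0])

# Same prompt shape as in main()
prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in detail:"
response = qa_gen(prompt, max_new_tokens=200, do_sample=False)
print(response[0]['generated_text'].strip())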