sourize committed
Commit c95539d · 1 Parent(s): 61d7892

Updated main.py

Files changed (1): app.py +75 -60
app.py CHANGED
```diff
@@ -5,99 +5,114 @@ from sentence_transformers import SentenceTransformer
 import faiss
 from transformers import pipeline
 
-# Caching heavy resources
+# Load and cache models
 @st.cache_resource
-def load_models():
-    # Embedding model (lightweight)
+def load_resources():
     embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-    # Generative QA model
-    qa_gen = pipeline(
+    # Generative chat model
+    chat_gen = pipeline(
         'text2text-generation',
         model='google/flan-t5-base',
         tokenizer='google/flan-t5-base',
-        device=-1  # CPU
+        device=-1,
     )
-    return embedder, qa_gen
+    return embedder, chat_gen
 
 # Extract text from uploaded file
-def extract_text_from_file(uploaded_file):
-    name = uploaded_file.name.lower()
+def extract_text(uploaded):
+    name = uploaded.name.lower()
     if name.endswith('.pdf'):
-        reader = PdfReader(uploaded_file)
-        text = ''.join(page.extract_text() or '' for page in reader.pages)
-    elif name.endswith('.docx'):
-        doc = docx.Document(uploaded_file)
-        text = '\n'.join(para.text for para in doc.paragraphs)
-    else:
-        text = uploaded_file.getvalue().decode('utf-8', errors='ignore')
-    return text
+        reader = PdfReader(uploaded)
+        return ''.join(page.extract_text() or '' for page in reader.pages)
+    if name.endswith('.docx'):
+        doc = docx.Document(uploaded)
+        return '\n'.join(para.text for para in doc.paragraphs)
+    return uploaded.getvalue().decode('utf-8', errors='ignore')
 
-# Split text into chunks
-def chunk_text(text, chunk_size=500, overlap=50):
+# Chunking helper
+def chunk_text(text, size=500, overlap=50):
     words = text.split()
-    chunks = []
-    start = 0
+    chunks, start = [], 0
     while start < len(words):
-        end = min(start + chunk_size, len(words))
+        end = min(start + size, len(words))
         chunks.append(' '.join(words[start:end]))
-        start += chunk_size - overlap
+        start += size - overlap
     return chunks
 
-# Build FAISS index from chunks
+# Build FAISS index
 @st.cache_resource
-def build_faiss_index(chunks, _embedder):  # underscore avoids hashing
-    embeddings = _embedder.encode(chunks, convert_to_numpy=True)
-    dim = embeddings.shape[1]
+def build_index(chunks, _embedder):  # underscore avoids hashing
+    embs = _embedder.encode(chunks, convert_to_numpy=True)
+    dim = embs.shape[1]
     index = faiss.IndexFlatL2(dim)
-    index.add(embeddings)
+    index.add(embs)
     return index
 
-# Main Streamlit app
+# Compose prompt for chat+RAG
+def make_prompt(system_prompt, context, history, question):
+    prompt = system_prompt + "\n\n" + "Document Context:\n" + context + "\n\n"
+    # append conversation history
+    for msg in history:
+        role, text = msg['role'], msg['text']
+        prompt += f"{role}: {text}\n"
+    prompt += f"User: {question}\nAssistant:"
+    return prompt
+
+# Main app
 def main():
-    st.set_page_config(page_title='📄 RAGbot', layout='wide')
+    st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
     st.title('🤖 RagBot')
-    st.sidebar.header('Upload Documents')
+    st.sidebar.header('📂 Upload Document')
 
-    # Initialize chat history
+    # Initialize state
     if 'history' not in st.session_state:
-        st.session_state.history = []
+        st.session_state.history = []  # list of {'role': 'User|Assistant', 'text': ...}
+    if 'chunks' not in st.session_state:
+        st.session_state.chunks = []
+    if 'index' not in st.session_state:
+        st.session_state.index = None
 
-    uploaded = st.sidebar.file_uploader('Upload PDF/DOCX/TXT', type=['pdf', 'docx', 'txt'])
-    if not uploaded:
-        st.info('Please upload a document in the sidebar to begin.')
-        return
-
-    # On first load of a doc, process and index
-    if 'chunks' not in st.session_state or st.session_state.uploaded_name != uploaded.name:
-        text = extract_text_from_file(uploaded)
+    uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf', 'docx', 'txt'])
+    if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
+        # New document: extract, chunk, index, reset
+        text = extract_text(uploaded)
         st.session_state.chunks = chunk_text(text)
-        st.session_state.embedder, st.session_state.qa_gen = load_models()
-        st.session_state.index = build_faiss_index(st.session_state.chunks, st.session_state.embedder)
+        st.session_state.embedder, st.session_state.chat_gen = load_resources()
+        st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
        st.session_state.uploaded_name = uploaded.name
-        st.session_state.history = []  # reset history on new doc
+        st.session_state.history = []
 
-    # Display existing chat history
-    for chat in st.session_state.history:
-        with st.chat_message('user'):
-            st.markdown(f"**You:** {chat['question']}")
-        with st.chat_message('assistant'):
-            st.markdown(f"**RagBot:** {chat['answer']}")
+    # If no doc yet, ask to upload
+    if st.session_state.index is None:
+        st.info('Please upload a document in the sidebar to start.')
+        return
+
+    # Display chat history
+    for msg in st.session_state.history:
+        with st.chat_message('user' if msg['role'] == 'User' else 'assistant'):
+            st.markdown(f"**{msg['role']}:** {msg['text']}")
 
-    # Chat input
-    question = st.chat_input('Ask a question about the document...')
+    # User input
+    question = st.chat_input('Ask anything—general or about the document...')
     if question:
-        # Retrieve top-k relevant chunks
+        # Retrieve relevant context
         q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
-        _, I = st.session_state.index.search(q_emb, k=3)
-        context = '\n\n'.join(st.session_state.chunks[i] for i in I[0])
+        _, idxs = st.session_state.index.search(q_emb, k=3)
+        context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
 
-        # Generate answer
-        prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in detail:"
-        response = st.session_state.qa_gen(prompt, max_new_tokens=200, do_sample=False)
+        # Build and run prompt
+        system_prompt = (
+            "You are RagBot, an AI assistant. "
+            "Use the provided document context to answer specific questions, "
+            "but also leverage your general knowledge for broader queries."
+        )
+        prompt = make_prompt(system_prompt, context, st.session_state.history, question)
+        response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
         answer = response[0]['generated_text'].strip()
 
-        # Save & display new messages
-        st.session_state.history.append({'question': question, 'answer': answer})
+        # Record and display
+        st.session_state.history.append({'role': 'User', 'text': question})
+        st.session_state.history.append({'role': 'Assistant', 'text': answer})
         with st.chat_message('user'):
             st.markdown(f"**You:** {question}")
         with st.chat_message('assistant'):
```