sourize committed on
Commit
13f2322
·
1 Parent(s): c95539d

Updated main.py

Browse files
Files changed (1) hide show
  1. app.py +27 -25
app.py CHANGED
@@ -9,7 +9,6 @@ from transformers import pipeline
9
  @st.cache_resource
10
  def load_resources():
11
  embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
12
- # Generative chat model
13
  chat_gen = pipeline(
14
  'text2text-generation',
15
  model='google/flan-t5-base',
@@ -48,13 +47,14 @@ def build_index(chunks, _embedder): # underscore avoids hashing
48
  index.add(embs)
49
  return index
50
 
51
- # Compose prompt for chat+RAG
52
  def make_prompt(system_prompt, context, history, question):
53
- prompt = system_prompt + "\n\n" + "Document Context:\n" + context + "\n\n"
54
- # append conversation history
 
55
  for msg in history:
56
- role, text = msg['role'], msg['text']
57
- prompt += f"{role}: {text}\n"
58
  prompt += f"User: {question}\nAssistant:"
59
  return prompt
60
 
@@ -62,55 +62,57 @@ def make_prompt(system_prompt, context, history, question):
62
  def main():
63
  st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
64
  st.title('🤖 RagBot')
65
- st.sidebar.header('📂 Upload Document')
66
 
67
  # Initialize state
68
  if 'history' not in st.session_state:
69
- st.session_state.history = [] # list of {'role': 'User|Assistant', 'text': ...}
70
  if 'chunks' not in st.session_state:
71
  st.session_state.chunks = []
72
  if 'index' not in st.session_state:
73
  st.session_state.index = None
74
 
 
75
  uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
76
  if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
77
- # New document: extract, chunk, index, reset
78
  text = extract_text(uploaded)
79
  st.session_state.chunks = chunk_text(text)
80
  st.session_state.embedder, st.session_state.chat_gen = load_resources()
81
  st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
82
  st.session_state.uploaded_name = uploaded.name
83
- st.session_state.history = []
84
-
85
- # If no doc yet, ask to upload
86
- if st.session_state.index is None:
87
- st.info('Please upload a document in the sidebar to start.')
88
- return
89
 
90
  # Display chat history
91
  for msg in st.session_state.history:
92
  with st.chat_message('user' if msg['role']=='User' else 'assistant'):
93
  st.markdown(f"**{msg['role']}:** {msg['text']}")
94
 
95
- # User input
96
- question = st.chat_input('Ask anything—general or about the document...')
97
  if question:
98
- # Retrieve relevant context
99
- q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
100
- _, idxs = st.session_state.index.search(q_emb, k=3)
101
- context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
 
 
102
 
103
- # Build and run prompt
104
  system_prompt = (
105
  "You are RagBot, an AI assistant. "
106
- "Use the provided document context to answer specific questions, "
107
- "but also leverage your general knowledge for broader queries."
108
  )
109
  prompt = make_prompt(system_prompt, context, st.session_state.history, question)
 
 
110
  response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
111
  answer = response[0]['generated_text'].strip()
112
 
113
- # Record and display
114
  st.session_state.history.append({'role':'User','text':question})
115
  st.session_state.history.append({'role':'Assistant','text':answer})
116
  with st.chat_message('user'):
 
9
  @st.cache_resource
10
  def load_resources():
11
  embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
12
  chat_gen = pipeline(
13
  'text2text-generation',
14
  model='google/flan-t5-base',
 
47
  index.add(embs)
48
  return index
49
 
50
# Compose prompt
def make_prompt(system_prompt, context, history, question):
    """Assemble the full generation prompt.

    Layout: system text, then an optional "Document Context:" section
    (only when *context* is non-empty), then each prior turn as
    "Role: text" lines, and finally the new user question followed by
    an "Assistant:" cue for the model to complete.
    """
    parts = [system_prompt, "\n\n"]
    if context:
        parts.append(f"Document Context:\n{context}\n\n")
    for turn in history:
        # Any role other than 'User' is rendered as 'Assistant'.
        speaker = 'User' if turn['role'] == 'User' else 'Assistant'
        parts.append(f"{speaker}: {turn['text']}\n")
    parts.append(f"User: {question}\nAssistant:")
    return "".join(parts)
60
 
 
62
  def main():
63
  st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
64
  st.title('🤖 RagBot')
65
+ st.sidebar.header('📂 Optional: Upload Document')
66
 
67
  # Initialize state
68
  if 'history' not in st.session_state:
69
+ st.session_state.history = [] # list of {'role': 'User'|'Assistant', 'text': ...}
70
  if 'chunks' not in st.session_state:
71
  st.session_state.chunks = []
72
  if 'index' not in st.session_state:
73
  st.session_state.index = None
74
 
75
+ # Document upload
76
  uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
77
  if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
 
78
  text = extract_text(uploaded)
79
  st.session_state.chunks = chunk_text(text)
80
  st.session_state.embedder, st.session_state.chat_gen = load_resources()
81
  st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
82
  st.session_state.uploaded_name = uploaded.name
83
+ st.session_state.history = [] # reset conversation
84
+ # Load models if not loaded
85
+ if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
86
+ st.session_state.embedder, st.session_state.chat_gen = load_resources()
 
 
87
 
88
  # Display chat history
89
  for msg in st.session_state.history:
90
  with st.chat_message('user' if msg['role']=='User' else 'assistant'):
91
  st.markdown(f"**{msg['role']}:** {msg['text']}")
92
 
93
+ # Chat input always available
94
+ question = st.chat_input('Ask a question—general or document-specific...')
95
  if question:
96
+ # Retrieve context if index exists
97
+ context = ''
98
+ if st.session_state.index is not None:
99
+ q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
100
+ _, idxs = st.session_state.index.search(q_emb, k=3)
101
+ context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
102
 
103
+ # Build prompt
104
  system_prompt = (
105
  "You are RagBot, an AI assistant. "
106
+ "Use the provided document context for specific questions, "
107
+ "and your general knowledge for everything else."
108
  )
109
  prompt = make_prompt(system_prompt, context, st.session_state.history, question)
110
+
111
+ # Generate answer
112
  response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
113
  answer = response[0]['generated_text'].strip()
114
 
115
+ # Save & display
116
  st.session_state.history.append({'role':'User','text':question})
117
  st.session_state.history.append({'role':'Assistant','text':answer})
118
  with st.chat_message('user'):