maria355 committed on
Commit
2bf8221
·
verified ·
1 Parent(s): 6d6259e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +412 -412
app.py CHANGED
@@ -1,413 +1,413 @@
1
- import os
2
- import fitz # PyMuPDF
3
- import streamlit as st
4
- import tempfile
5
- from sentence_transformers import SentenceTransformer
6
- import faiss
7
- import numpy as np
8
- import tiktoken
9
- import requests
10
- from googletrans import Translator
11
- from gtts import gTTS
12
- import time
13
- # Set page config for better appearance
14
- st.set_page_config(
15
- page_title="RAG Document Assistant",
16
- page_icon="📄",
17
- layout="wide",
18
- initial_sidebar_state="expanded"
19
- )
20
- print('---------------------------------')
21
- # Sidebar profile function
22
- def sidebar_profiles():
23
- st.sidebar.markdown("""<hr>""", unsafe_allow_html=True) # Add line before author name
24
- st.sidebar.markdown("### 🎉Author: Maria Nadeem🌟")
25
- st.sidebar.markdown("### 🔗 Connect With Me")
26
- st.sidebar.markdown("""
27
- <hr>
28
- <div class="profile-links">
29
- <a href="https://github.com/marianadeem755" target="_blank">
30
- <img src="https://cdn-icons-png.flaticon.com/512/25/25231.png" width="20px"> GitHub
31
- </a><br><br>
32
- <a href="https://www.kaggle.com/marianadeem755" target="_blank">
33
- <img src="https://cdn4.iconfinder.com/data/icons/logos-and-brands/512/189_Kaggle_logo_logos-512.png" width="20px"> Kaggle
34
- </a><br><br>
35
- <a href="mailto:marianadeem755@gmail.com">
36
- <img src="https://cdn-icons-png.flaticon.com/512/561/561127.png" width="20px"> Email
37
- </a><br><br>
38
- <a href="https://huggingface.co/maria355" target="_blank">
39
- <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="20px"> Hugging Face
40
- </a>
41
- </div>
42
- <hr>
43
- """, unsafe_allow_html=True)
44
-
45
-
46
- # API Key Management with better error handling
47
- def get_api_key():
48
- # First try to get from environment
49
- api_key = os.getenv("GROQ_API_KEY")
50
-
51
- # If not in environment, try to get from session state or let user input it
52
- if not api_key:
53
- if "GROQ_API_KEY" in st.session_state:
54
- api_key = st.session_state["GROQ_API_KEY"]
55
-
56
- return api_key
57
-
58
- # Initialize session state variables if they don't exist
59
- if "chunks" not in st.session_state:
60
- st.session_state.chunks = []
61
- if "chunk_sources" not in st.session_state:
62
- st.session_state.chunk_sources = []
63
- if "debug_mode" not in st.session_state:
64
- st.session_state.debug_mode = False
65
- if "last_query_time" not in st.session_state:
66
- st.session_state.last_query_time = None
67
- if "last_response" not in st.session_state:
68
- st.session_state.last_response = None
69
-
70
- # Setup
71
- @st.cache_resource
72
- def load_embedder():
73
- return SentenceTransformer("all-MiniLM-L6-v2")
74
-
75
- embedder = load_embedder()
76
- embedding_dim = 384
77
- index = faiss.IndexFlatL2(embedding_dim)
78
- translator = Translator()
79
- tokenizer = tiktoken.get_encoding("cl100k_base")
80
-
81
- # Utilities
82
- def num_tokens_from_string(string: str) -> int:
83
- return len(tokenizer.encode(string))
84
-
85
- def chunk_text(text, max_tokens=250):
86
- sentences = text.split(". ")
87
- current_chunk = []
88
- total_tokens = 0
89
- result_chunks = []
90
- for sentence in sentences:
91
- if not sentence.strip(): # Skip empty sentences
92
- continue
93
- token_len = num_tokens_from_string(sentence)
94
- if total_tokens + token_len > max_tokens:
95
- if current_chunk: # Only add if there's content
96
- result_chunks.append(". ".join(current_chunk) + ("." if not current_chunk[-1].endswith(".") else ""))
97
- current_chunk = [sentence]
98
- total_tokens = token_len
99
- else:
100
- current_chunk.append(sentence)
101
- total_tokens += token_len
102
- if current_chunk:
103
- result_chunks.append(". ".join(current_chunk) + ("." if not current_chunk[-1].endswith(".") else ""))
104
- return result_chunks
105
-
106
- def extract_text_from_pdf(pdf_file):
107
- doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
108
- text = ""
109
- for page in doc:
110
- text += page.get_text()
111
- return text
112
-
113
- def index_uploaded_text(text):
114
- # Reset the index and chunks
115
- global index
116
- index = faiss.IndexFlatL2(embedding_dim)
117
- st.session_state.chunks = []
118
- st.session_state.chunk_sources = []
119
-
120
- # Process text into chunks
121
- chunks_list = chunk_text(text)
122
- st.session_state.chunks = chunks_list
123
-
124
- # Create source references and vectors
125
- for i, chunk in enumerate(chunks_list):
126
- st.session_state.chunk_sources.append(f"Chunk {i+1}: {chunk[:50]}...")
127
- vector = embedder.encode([chunk])[0]
128
- index.add(np.array([vector]).astype('float32'))
129
-
130
- return len(chunks_list)
131
-
132
- def retrieve_chunks(query, top_k=5):
133
- if index.ntotal == 0:
134
- return []
135
- q_vector = embedder.encode([query])
136
- D, I = index.search(np.array(q_vector).astype('float32'), k=min(top_k, index.ntotal))
137
- return [st.session_state.chunks[i] for i in I[0] if i < len(st.session_state.chunks)]
138
-
139
- def build_prompt(system_prompt, context_chunks, question):
140
- context = "\n\n".join(context_chunks)
141
- return f"""{system_prompt}
142
-
143
- Context:
144
- {context}
145
-
146
- Question:
147
- {question}
148
-
149
- Answer: Please provide a comprehensive answer based only on the context provided."""
150
-
151
- def generate_answer(prompt):
152
- api_key = get_api_key()
153
-
154
- if not api_key:
155
- return "API key is missing. Please set the GROQ_API_KEY environment variable or enter it in the sidebar."
156
-
157
- headers = {
158
- "Authorization": f"Bearer {api_key.strip()}", # Strip to remove any whitespace
159
- "Content-Type": "application/json"
160
- }
161
-
162
- # Use the model selected by the user, default to llama3-8b if none selected
163
- selected_model = st.session_state.get("MODEL_CHOICE", "llama3-8b-8192")
164
-
165
- payload = {
166
- "model": selected_model,
167
- "messages": [
168
- {"role": "system", "content": "You are a helpful document assistant that answers questions only using the provided context."},
169
- {"role": "user", "content": prompt}
170
- ],
171
- "temperature": 0.3,
172
- "max_tokens": 1024
173
- }
174
-
175
- try:
176
- start_time = time.time()
177
- with st.spinner("Sending request to Groq API..."):
178
- response = requests.post(
179
- "https://api.groq.com/openai/v1/chat/completions",
180
- json=payload,
181
- headers=headers,
182
- timeout=30
183
- )
184
-
185
- query_time = time.time() - start_time
186
- st.session_state.last_query_time = f"{query_time:.2f} seconds"
187
-
188
- # For debugging - show only status code when debug mode is enabled
189
- if st.session_state.debug_mode:
190
- st.write(f"API Response Status Code: {response.status_code}")
191
- st.write(f"Response time: {query_time:.2f} seconds")
192
-
193
- if response.status_code == 401:
194
- return "Authentication failed: The API key appears to be invalid or expired. Please check your API key."
195
-
196
- if response.status_code == 400:
197
- # Display the detailed error for 400 Bad Request
198
- error_info = response.json().get("error", {})
199
- error_message = error_info.get("message", "Unknown error")
200
- error_type = error_info.get("type", "Unknown type")
201
-
202
- # Try alternate model if model not found
203
- if "model not found" in error_message.lower() or "model_not_found" in error_type.lower():
204
- st.warning("Trying with an alternate model (llama3-8b-8192)...")
205
- payload["model"] = "llama3-8b-8192"
206
-
207
- response = requests.post(
208
- "https://api.groq.com/openai/v1/chat/completions",
209
- json=payload,
210
- headers=headers,
211
- timeout=30
212
- )
213
-
214
- if response.status_code != 200:
215
- return f"Both model attempts failed. Please check the available models for your Groq API key. Error: {error_message}"
216
- else:
217
- return f"API Error: {error_message}"
218
-
219
- response.raise_for_status() # Raises an HTTPError for other bad responses
220
-
221
- response_json = response.json()
222
-
223
- if "choices" not in response_json:
224
- error_msg = f"Unexpected API response format. Response: {response_json}"
225
- if "error" in response_json:
226
- error_msg = f"API Error: {response_json['error'].get('message', 'Unknown error')}"
227
- st.error(error_msg)
228
- return "Sorry, I couldn't retrieve an answer due to an API error."
229
-
230
- if not response_json["choices"]:
231
- return "No answer was generated."
232
-
233
- answer = response_json["choices"][0]["message"]["content"]
234
- st.session_state.last_response = answer
235
- return answer
236
-
237
- except requests.exceptions.RequestException as e:
238
- st.error(f"API request failed: {str(e)}")
239
- return f"Sorry, I couldn't connect to the API service. Error: {str(e)}"
240
- except Exception as e:
241
- st.error(f"Unexpected error: {str(e)}")
242
- return f"Sorry, something went wrong. Error: {str(e)}"
243
-
244
- def translate_text(text, target_language):
245
- try:
246
- with st.spinner(f"Translating to {target_language}..."):
247
- return translator.translate(text, dest=target_language).text
248
- except Exception as e:
249
- st.error(f"Translation failed: {str(e)}")
250
- return text # Return original text if translation fails
251
-
252
- def text_to_speech(text, lang_code):
253
- try:
254
- with st.spinner("Generating audio..."):
255
- tts = gTTS(text=text, lang=lang_code)
256
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
257
- tts.save(temp_file.name)
258
- return temp_file.name
259
- except Exception as e:
260
- st.error(f"Text-to-speech failed: {str(e)}")
261
- return None
262
-
263
- # Streamlit UI
264
- st.title("📄 Task-Specific RAG Assistant")
265
- st.markdown("Upload a document and ask questions to get AI-powered answers with translation capabilities.")
266
-
267
- # Add API key input in sidebar
268
- with st.sidebar:
269
- st.header("API Configuration")
270
- api_key_input = st.text_input(
271
- "Groq API Key",
272
- value=get_api_key() or "",
273
- type="password",
274
- help="Enter your Groq API key here if not set as environment variable"
275
- )
276
-
277
- if api_key_input:
278
- st.session_state["GROQ_API_KEY"] = api_key_input
279
- st.success("API key saved for this session!")
280
-
281
- # Add model selection
282
- st.subheader("Model Selection")
283
- model_choice = st.selectbox(
284
- "Select LLM Model",
285
- [
286
- "llama3-8b-8192", # Changed default to a model known to work
287
- "llama3-70b-8192"
288
- ],
289
- help="Choose the Groq model to use for answering questions"
290
- )
291
-
292
- st.session_state["MODEL_CHOICE"] = model_choice
293
-
294
- # Debug mode toggle
295
- st.subheader("Debug Settings")
296
- st.session_state.debug_mode = st.checkbox("Show Debug Information", value=st.session_state.debug_mode)
297
-
298
- if st.session_state.last_query_time:
299
- st.subheader("Performance")
300
- st.info(f"Last query time: {st.session_state.last_query_time}")
301
-
302
- st.subheader("About")
303
- st.markdown("""
304
- This app uses Retrieval-Augmented Generation (RAG) to answer questions about uploaded documents.
305
- 1. Upload a document
306
- 2. Ask a question
307
- 3. Optionally translate responses to other languages
308
- """)
309
-
310
- # Add the profile section
311
- sidebar_profiles()
312
-
313
- # Main content area
314
- col1, col2 = st.columns([2, 1])
315
-
316
- with col1:
317
- uploaded_file = st.file_uploader("Upload a PDF or TXT file", type=["pdf", "txt"])
318
- if uploaded_file:
319
- with st.spinner("Reading and indexing document..."):
320
- raw_text = ""
321
- if uploaded_file.type == "application/pdf":
322
- raw_text = extract_text_from_pdf(uploaded_file)
323
- elif uploaded_file.type == "text/plain":
324
- raw_text = uploaded_file.read().decode("utf-8")
325
-
326
- total_chunks = index_uploaded_text(raw_text)
327
- st.success(f"Document indexed successfully! Created {total_chunks} chunks.")
328
-
329
- # Display document preview
330
- with st.expander("Document Preview"):
331
- st.text_area("First 1000 characters of document", raw_text[:20000], height=200)
332
-
333
- with col2:
334
- if st.session_state.chunks:
335
- st.info(f"Document chunks: {len(st.session_state.chunks)}")
336
-
337
- # Query and answer section
338
- st.divider()
339
- query = st.text_input("Ask a question about the document")
340
-
341
- col1, col2 = st.columns([1, 1])
342
-
343
- with col1:
344
- enable_translation = st.checkbox("Translate answer", value=False)
345
- use_local = st.checkbox("Use local processing (no API call)", value=False,
346
- help="Use this if you're having API issues")
347
-
348
- with col2:
349
- language = st.selectbox("Language", ["English", "Urdu", "Hindi", "French", "Chinese", "Spanish", "German", "Arabic", "Russian"])
350
- language_codes = {
351
- "English": "en", "Urdu": "ur", "Hindi": "hi", "French": "fr", "Chinese": "zh-cn",
352
- "Spanish": "es", "German": "de", "Arabic": "ar", "Russian": "ru"
353
- }
354
- lang_code = language_codes[language]
355
-
356
- if query:
357
- if index.ntotal == 0:
358
- st.warning("Please upload and index a document first.")
359
- else:
360
- with st.spinner("Generating answer..."):
361
- top_chunks = retrieve_chunks(query)
362
- if not top_chunks:
363
- st.error("No relevant content found.")
364
- else:
365
- system_prompt = "You are a document assistant. Use only the context to answer accurately."
366
- prompt = build_prompt(system_prompt, top_chunks, query)
367
-
368
- # Check API key before making call
369
- if not get_api_key() and not use_local:
370
- st.error("API key is not set. Please add it in the sidebar.")
371
- else:
372
- if use_local:
373
- # Simple local processing that summarizes the chunks without API call
374
- st.warning("Using local processing - limited functionality!")
375
- answer = f"Local processing summary (no LLM used):\n\n"
376
- answer += f"Question: {query}\n\n"
377
- answer += "Here are the most relevant passages found:\n\n"
378
- for i, chunk in enumerate(top_chunks[:3], 1):
379
- answer += f"{i}. {chunk[:200]}...\n\n"
380
- else:
381
- answer = generate_answer(prompt)
382
-
383
- # Display query and context if debug mode is on
384
- if st.session_state.debug_mode:
385
- with st.expander("Query Context", expanded=False):
386
- st.write("Query:", query)
387
- st.write("Top chunks used:")
388
- for i, chunk in enumerate(top_chunks, 1):
389
- st.write(f"{i}. {chunk[:100]}...")
390
-
391
- # Create tabs for original and translated answers
392
- tab1, tab2 = st.tabs(["Original Answer", f"Translated ({language})" if enable_translation else "Translation (disabled)"])
393
-
394
- with tab1:
395
- st.markdown("### Answer:")
396
- st.write(answer)
397
-
398
- with tab2:
399
- if enable_translation and answer:
400
- translated = translate_text(answer, lang_code)
401
- st.markdown(f"### Answer ({language}):")
402
- st.write(translated)
403
-
404
- # Audio generation
405
- audio_path = text_to_speech(translated, lang_code)
406
- if audio_path:
407
- st.audio(audio_path, format="audio/mp3")
408
- else:
409
- st.info("Enable translation to see the answer in your selected language.")
410
-
411
- # Add footer
412
- st.divider()
413
  st.caption("RAG Document Assistant - Powered by Groq & Sentence Transformers")
 
1
+ import os
2
+ import fitz # PyMuPDF
3
+ import streamlit as st
4
+ import tempfile
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+ import numpy as np
8
+ import tiktoken
9
+ import requests
10
+ from deep_translator import GoogleTranslator
11
+ from gtts import gTTS
12
+ import time
13
+ # Set page config for better appearance
14
+ st.set_page_config(
15
+ page_title="RAG Document Assistant",
16
+ page_icon="📄",
17
+ layout="wide",
18
+ initial_sidebar_state="expanded"
19
+ )
20
+ print('---------------------------------')
21
+ # Sidebar profile function
22
def sidebar_profiles():
    """Render the author credits and profile links at the bottom of the sidebar."""
    st.sidebar.markdown("""<hr>""", unsafe_allow_html=True)  # separator before author name
    st.sidebar.markdown("### 🎉Author: Maria Nadeem🌟")
    st.sidebar.markdown("### 🔗 Connect With Me")
    # Raw HTML block with icon links; requires unsafe_allow_html.
    profile_links_html = """
    <hr>
    <div class="profile-links">
        <a href="https://github.com/marianadeem755" target="_blank">
            <img src="https://cdn-icons-png.flaticon.com/512/25/25231.png" width="20px"> GitHub
        </a><br><br>
        <a href="https://www.kaggle.com/marianadeem755" target="_blank">
            <img src="https://cdn4.iconfinder.com/data/icons/logos-and-brands/512/189_Kaggle_logo_logos-512.png" width="20px"> Kaggle
        </a><br><br>
        <a href="mailto:marianadeem755@gmail.com">
            <img src="https://cdn-icons-png.flaticon.com/512/561/561127.png" width="20px"> Email
        </a><br><br>
        <a href="https://huggingface.co/maria355" target="_blank">
            <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="20px"> Hugging Face
        </a>
    </div>
    <hr>
    """
    st.sidebar.markdown(profile_links_html, unsafe_allow_html=True)
44
+
45
+
46
+ # API Key Management with better error handling
47
+ def get_api_key():
48
+ # First try to get from environment
49
+ api_key = os.getenv("GROQ_API_KEY")
50
+
51
+ # If not in environment, try to get from session state or let user input it
52
+ if not api_key:
53
+ if "GROQ_API_KEY" in st.session_state:
54
+ api_key = st.session_state["GROQ_API_KEY"]
55
+
56
+ return api_key
57
+
58
# Initialize session-state variables once per session so later code can
# read them without guarding every access.
_SESSION_DEFAULTS = {
    "chunks": [],             # indexed text chunks of the current document
    "chunk_sources": [],      # short preview labels, one per chunk
    "debug_mode": False,      # toggled from the sidebar
    "last_query_time": None,  # duration string of the last API call
    "last_response": None,    # last answer returned by the LLM
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
69
+
70
# Setup: cached embedding model, FAISS index, and tokenizer.
@st.cache_resource
def load_embedder():
    """Load (and cache across reruns) the sentence-embedding model."""
    return SentenceTransformer("all-MiniLM-L6-v2")

embedder = load_embedder()
embedding_dim = 384  # output dimension of all-MiniLM-L6-v2
index = faiss.IndexFlatL2(embedding_dim)
# FIX: the old `translator = Translator()` raised NameError — this module now
# imports deep_translator.GoogleTranslator, not googletrans.Translator.
# Translation is performed per-call inside translate_text(), so no module-level
# translator object is needed.
tokenizer = tiktoken.get_encoding("cl100k_base")
80
+
81
# Utilities
def num_tokens_from_string(string: str) -> int:
    """Return the number of cl100k_base tokens in *string*."""
    encoded = tokenizer.encode(string)
    return len(encoded)
84
+
85
def chunk_text(text, max_tokens=250):
    """Split *text* into chunks of roughly *max_tokens* tokens each.

    Splitting happens on ". " sentence boundaries; a sentence is never cut
    in half, so a single very long sentence may exceed the budget.
    """
    result_chunks = []
    buffer = []        # sentences accumulated for the current chunk
    buffer_tokens = 0  # token count of the sentences in `buffer`

    def _flush():
        # Re-join the buffered sentences; restore the trailing period that
        # str.split(". ") stripped, unless the last sentence kept its own.
        joined = ". ".join(buffer)
        result_chunks.append(joined + ("" if buffer[-1].endswith(".") else "."))

    for sentence in text.split(". "):
        if not sentence.strip():
            continue  # skip empty fragments
        token_len = num_tokens_from_string(sentence)
        if buffer_tokens + token_len > max_tokens:
            if buffer:  # only emit when there is content
                _flush()
            buffer = [sentence]
            buffer_tokens = token_len
        else:
            buffer.append(sentence)
            buffer_tokens += token_len

    if buffer:
        _flush()
    return result_chunks
105
+
106
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page from an uploaded PDF file object."""
    document = fitz.open(stream=pdf_file.read(), filetype="pdf")
    return "".join(page.get_text() for page in document)
112
+
113
def index_uploaded_text(text):
    """Chunk *text*, embed the chunks, and rebuild the global FAISS index.

    Any previously indexed document is discarded. Returns the number of
    chunks created.
    """
    # Reset the index and chunk bookkeeping.
    global index
    index = faiss.IndexFlatL2(embedding_dim)
    st.session_state.chunks = []
    st.session_state.chunk_sources = []

    # Process text into chunks.
    chunks_list = chunk_text(text)
    st.session_state.chunks = chunks_list
    st.session_state.chunk_sources = [
        f"Chunk {i+1}: {chunk[:50]}..." for i, chunk in enumerate(chunks_list)
    ]

    # Encode all chunks in a single batched call instead of one encode()
    # per chunk — same vectors, far fewer model invocations.
    if chunks_list:
        vectors = embedder.encode(chunks_list)
        index.add(np.array(vectors).astype('float32'))

    return len(chunks_list)
131
+
132
def retrieve_chunks(query, top_k=5):
    """Return up to *top_k* indexed chunks most similar to *query* (L2 distance)."""
    if index.ntotal == 0:
        return []
    query_vec = np.array(embedder.encode([query])).astype('float32')
    _, neighbor_ids = index.search(query_vec, k=min(top_k, index.ntotal))
    chunks = st.session_state.chunks
    # Guard against stale ids if the session's chunk list shrank.
    return [chunks[i] for i in neighbor_ids[0] if i < len(chunks)]
138
+
139
def build_prompt(system_prompt, context_chunks, question):
    """Assemble the LLM prompt from system text, retrieved context, and the question."""
    joined_context = "\n\n".join(context_chunks)
    sections = [
        system_prompt,
        "",
        "Context:",
        joined_context,
        "",
        "Question:",
        question,
        "",
        "Answer: Please provide a comprehensive answer based only on the context provided.",
    ]
    return "\n".join(sections)
150
+
151
def _post_chat_completion(payload, headers):
    """POST *payload* to Groq's OpenAI-compatible chat endpoint (30 s timeout)."""
    return requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        json=payload,
        headers=headers,
        timeout=30
    )


def generate_answer(prompt):
    """Send *prompt* to the Groq chat API and return the model's answer text.

    Never raises: on a missing key, request failure, or malformed response it
    returns a human-readable error string instead. Records timing in
    st.session_state.last_query_time and the answer in last_response.
    """
    api_key = get_api_key()

    if not api_key:
        return "API key is missing. Please set the GROQ_API_KEY environment variable or enter it in the sidebar."

    headers = {
        "Authorization": f"Bearer {api_key.strip()}",  # Strip to remove any whitespace
        "Content-Type": "application/json"
    }

    # Use the model selected by the user, default to llama3-8b if none selected
    selected_model = st.session_state.get("MODEL_CHOICE", "llama3-8b-8192")

    payload = {
        "model": selected_model,
        "messages": [
            {"role": "system", "content": "You are a helpful document assistant that answers questions only using the provided context."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.3,
        "max_tokens": 1024
    }

    try:
        start_time = time.time()
        with st.spinner("Sending request to Groq API..."):
            response = _post_chat_completion(payload, headers)

        query_time = time.time() - start_time
        st.session_state.last_query_time = f"{query_time:.2f} seconds"

        # For debugging - show only status code when debug mode is enabled
        if st.session_state.debug_mode:
            st.write(f"API Response Status Code: {response.status_code}")
            st.write(f"Response time: {query_time:.2f} seconds")

        if response.status_code == 401:
            return "Authentication failed: The API key appears to be invalid or expired. Please check your API key."

        if response.status_code == 400:
            # Surface the detailed error for 400 Bad Request.
            error_info = response.json().get("error", {})
            error_message = error_info.get("message", "Unknown error")
            error_type = error_info.get("type", "Unknown type")

            # Retry once with a known-good model when the chosen one is missing.
            if "model not found" in error_message.lower() or "model_not_found" in error_type.lower():
                st.warning("Trying with an alternate model (llama3-8b-8192)...")
                payload["model"] = "llama3-8b-8192"
                response = _post_chat_completion(payload, headers)

                if response.status_code != 200:
                    return f"Both model attempts failed. Please check the available models for your Groq API key. Error: {error_message}"
            else:
                return f"API Error: {error_message}"

        response.raise_for_status()  # Raises an HTTPError for other bad responses

        response_json = response.json()

        if "choices" not in response_json:
            error_msg = f"Unexpected API response format. Response: {response_json}"
            if "error" in response_json:
                error_msg = f"API Error: {response_json['error'].get('message', 'Unknown error')}"
            st.error(error_msg)
            return "Sorry, I couldn't retrieve an answer due to an API error."

        if not response_json["choices"]:
            return "No answer was generated."

        answer = response_json["choices"][0]["message"]["content"]
        st.session_state.last_response = answer
        return answer

    except requests.exceptions.RequestException as e:
        st.error(f"API request failed: {str(e)}")
        return f"Sorry, I couldn't connect to the API service. Error: {str(e)}"
    except Exception as e:
        st.error(f"Unexpected error: {str(e)}")
        return f"Sorry, something went wrong. Error: {str(e)}"
243
+
244
def translate_text(text, target_language):
    """Translate *text* into *target_language* (a Google Translate language code).

    Returns the original text unchanged if translation fails.
    """
    # FIX: this module imports deep_translator.GoogleTranslator, but the old
    # body still called the removed googletrans API
    # (`translator.translate(text, dest=...).text`), which raised NameError.
    # deep_translator expects e.g. "zh-CN" while the UI passes googletrans-style
    # "zh-cn" — normalize the known mismatch. TODO confirm other codes map 1:1.
    code_fixes = {"zh-cn": "zh-CN"}
    target = code_fixes.get(target_language, target_language)
    try:
        with st.spinner(f"Translating to {target_language}..."):
            return GoogleTranslator(source="auto", target=target).translate(text)
    except Exception as e:
        st.error(f"Translation failed: {str(e)}")
        return text  # Return original text if translation fails
251
+
252
def text_to_speech(text, lang_code):
    """Render *text* as speech and return the path of a temporary MP3 file.

    Returns None when audio generation fails.
    """
    try:
        with st.spinner("Generating audio..."):
            tts = gTTS(text=text, lang=lang_code)
            # delete=False so the file survives for st.audio to read later.
            # FIX: close the handle before gTTS writes to the path — the old
            # code leaked the open file object (and open handles block the
            # write on Windows).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                audio_path = tmp.name
            tts.save(audio_path)
            return audio_path
    except Exception as e:
        st.error(f"Text-to-speech failed: {str(e)}")
        return None
262
+
263
# --- Streamlit UI: header and sidebar ---
st.title("📄 Task-Specific RAG Assistant")
st.markdown("Upload a document and ask questions to get AI-powered answers with translation capabilities.")

# Sidebar: API key entry, model choice, debug toggle, timing info, about text.
with st.sidebar:
    st.header("API Configuration")
    api_key_input = st.text_input(
        "Groq API Key",
        value=get_api_key() or "",
        type="password",
        help="Enter your Groq API key here if not set as environment variable"
    )

    if api_key_input:
        # Persist the typed key for the rest of this session.
        st.session_state["GROQ_API_KEY"] = api_key_input
        st.success("API key saved for this session!")

    st.subheader("Model Selection")
    model_choice = st.selectbox(
        "Select LLM Model",
        [
            "llama3-8b-8192",  # default: known-working model
            "llama3-70b-8192"
        ],
        help="Choose the Groq model to use for answering questions"
    )
    st.session_state["MODEL_CHOICE"] = model_choice

    st.subheader("Debug Settings")
    st.session_state.debug_mode = st.checkbox("Show Debug Information", value=st.session_state.debug_mode)

    if st.session_state.last_query_time:
        st.subheader("Performance")
        st.info(f"Last query time: {st.session_state.last_query_time}")

    st.subheader("About")
    st.markdown("""
This app uses Retrieval-Augmented Generation (RAG) to answer questions about uploaded documents.
1. Upload a document
2. Ask a question
3. Optionally translate responses to other languages
""")

# Author/profile links at the bottom of the sidebar.
sidebar_profiles()
312
+
313
# Main content area: uploader (left) and index statistics (right).
col1, col2 = st.columns([2, 1])

with col1:
    uploaded_file = st.file_uploader("Upload a PDF or TXT file", type=["pdf", "txt"])
    if uploaded_file:
        with st.spinner("Reading and indexing document..."):
            raw_text = ""
            if uploaded_file.type == "application/pdf":
                raw_text = extract_text_from_pdf(uploaded_file)
            elif uploaded_file.type == "text/plain":
                # errors="replace" keeps a non-UTF-8 text file from crashing
                # the app; garbled characters are acceptable for indexing.
                raw_text = uploaded_file.read().decode("utf-8", errors="replace")

            total_chunks = index_uploaded_text(raw_text)
            st.success(f"Document indexed successfully! Created {total_chunks} chunks.")

        # Document preview. FIX: the label claimed "First 1000 characters"
        # while 20,000 characters were actually shown.
        with st.expander("Document Preview"):
            st.text_area("First 20,000 characters of document", raw_text[:20000], height=200)

with col2:
    if st.session_state.chunks:
        st.info(f"Document chunks: {len(st.session_state.chunks)}")

# Query and answer section.
st.divider()
query = st.text_input("Ask a question about the document")

col1, col2 = st.columns([1, 1])

with col1:
    enable_translation = st.checkbox("Translate answer", value=False)
    use_local = st.checkbox("Use local processing (no API call)", value=False,
                            help="Use this if you're having API issues")

with col2:
    language = st.selectbox("Language", ["English", "Urdu", "Hindi", "French", "Chinese", "Spanish", "German", "Arabic", "Russian"])
    # Google Translate / gTTS language codes for the names above.
    language_codes = {
        "English": "en", "Urdu": "ur", "Hindi": "hi", "French": "fr", "Chinese": "zh-cn",
        "Spanish": "es", "German": "de", "Arabic": "ar", "Russian": "ru"
    }
    lang_code = language_codes[language]
355
+
356
# --- Query handling: retrieve, answer, translate, speak ---
if query:
    if index.ntotal == 0:
        st.warning("Please upload and index a document first.")
    else:
        with st.spinner("Generating answer..."):
            top_chunks = retrieve_chunks(query)
            if not top_chunks:
                st.error("No relevant content found.")
            else:
                system_prompt = "You are a document assistant. Use only the context to answer accurately."
                prompt = build_prompt(system_prompt, top_chunks, query)

                # Verify an API key exists before attempting a remote call.
                if not get_api_key() and not use_local:
                    st.error("API key is not set. Please add it in the sidebar.")
                else:
                    if use_local:
                        # Offline fallback: just show the retrieved passages.
                        st.warning("Using local processing - limited functionality!")
                        answer = f"Local processing summary (no LLM used):\n\n"
                        answer += f"Question: {query}\n\n"
                        answer += "Here are the most relevant passages found:\n\n"
                        for i, chunk in enumerate(top_chunks[:3], 1):
                            answer += f"{i}. {chunk[:200]}...\n\n"
                    else:
                        answer = generate_answer(prompt)

                    # Show the retrieval context when debug mode is enabled.
                    if st.session_state.debug_mode:
                        with st.expander("Query Context", expanded=False):
                            st.write("Query:", query)
                            st.write("Top chunks used:")
                            for i, chunk in enumerate(top_chunks, 1):
                                st.write(f"{i}. {chunk[:100]}...")

                    # Tabs: original answer plus (optional) translation.
                    tab1, tab2 = st.tabs(["Original Answer", f"Translated ({language})" if enable_translation else "Translation (disabled)"])

                    with tab1:
                        st.markdown("### Answer:")
                        st.write(answer)

                    with tab2:
                        if enable_translation and answer:
                            translated = translate_text(answer, lang_code)
                            st.markdown(f"### Answer ({language}):")
                            st.write(translated)

                            # Audio playback of the translated answer.
                            audio_path = text_to_speech(translated, lang_code)
                            if audio_path:
                                st.audio(audio_path, format="audio/mp3")
                        else:
                            st.info("Enable translation to see the answer in your selected language.")

# Footer
st.divider()
st.caption("RAG Document Assistant - Powered by Groq & Sentence Transformers")