maria355 committed on
Commit
f4bbb80
·
verified ·
1 Parent(s): 82e7c99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -85
app.py CHANGED
@@ -10,17 +10,16 @@ import requests
10
  from deep_translator import GoogleTranslator
11
  from gtts import gTTS
12
  import time
13
- # Set page config for better appearance
14
  st.set_page_config(
15
  page_title="RAG Document Assistant",
16
  page_icon="📄",
17
  layout="wide",
18
  initial_sidebar_state="expanded"
19
  )
20
- print('---------------------------------')
21
- # Sidebar profile function
22
  def sidebar_profiles():
23
- st.sidebar.markdown("""<hr>""", unsafe_allow_html=True) # Add line before author name
24
  st.sidebar.markdown("### 🎉Author: Maria Nadeem🌟")
25
  st.sidebar.markdown("### 🔗 Connect With Me")
26
  st.sidebar.markdown("""
@@ -42,32 +41,24 @@ def sidebar_profiles():
42
  <hr>
43
  """, unsafe_allow_html=True)
44
 
45
-
46
- # API Key Management with better error handling
47
  def get_api_key():
48
- # First try to get from environment
49
  api_key = os.getenv("GROQ_API_KEY")
50
-
51
- # If not in environment, try to get from session state or let user input it
52
  if not api_key:
53
  if "GROQ_API_KEY" in st.session_state:
54
  api_key = st.session_state["GROQ_API_KEY"]
55
-
56
  return api_key
57
 
58
- # Initialize session state variables if they don't exist
59
- if "chunks" not in st.session_state:
60
- st.session_state.chunks = []
61
- if "chunk_sources" not in st.session_state:
62
- st.session_state.chunk_sources = []
63
- if "debug_mode" not in st.session_state:
64
- st.session_state.debug_mode = False
65
- if "last_query_time" not in st.session_state:
66
- st.session_state.last_query_time = None
67
- if "last_response" not in st.session_state:
68
- st.session_state.last_response = None
69
-
70
- # Setup
71
  @st.cache_resource
72
  def load_embedder():
73
  return SentenceTransformer("all-MiniLM-L6-v2")
@@ -75,11 +66,8 @@ def load_embedder():
75
  embedder = load_embedder()
76
  embedding_dim = 384
77
  index = faiss.IndexFlatL2(embedding_dim)
78
- translated_text = GoogleTranslator(source='auto', target='fr').translate(text)
79
- result = GoogleTranslator(source='auto', target='fr').translate(text)
80
  tokenizer = tiktoken.get_encoding("cl100k_base")
81
 
82
- # Utilities
83
  def num_tokens_from_string(string: str) -> int:
84
  return len(tokenizer.encode(string))
85
 
@@ -89,11 +77,11 @@ def chunk_text(text, max_tokens=250):
89
  total_tokens = 0
90
  result_chunks = []
91
  for sentence in sentences:
92
- if not sentence.strip(): # Skip empty sentences
93
  continue
94
  token_len = num_tokens_from_string(sentence)
95
  if total_tokens + token_len > max_tokens:
96
- if current_chunk: # Only add if there's content
97
  result_chunks.append(". ".join(current_chunk) + ("." if not current_chunk[-1].endswith(".") else ""))
98
  current_chunk = [sentence]
99
  total_tokens = token_len
@@ -112,22 +100,19 @@ def extract_text_from_pdf(pdf_file):
112
  return text
113
 
114
  def index_uploaded_text(text):
115
- # Reset the index and chunks
116
  global index
117
  index = faiss.IndexFlatL2(embedding_dim)
118
  st.session_state.chunks = []
119
  st.session_state.chunk_sources = []
120
-
121
- # Process text into chunks
122
  chunks_list = chunk_text(text)
123
  st.session_state.chunks = chunks_list
124
-
125
- # Create source references and vectors
126
  for i, chunk in enumerate(chunks_list):
127
  st.session_state.chunk_sources.append(f"Chunk {i+1}: {chunk[:50]}...")
128
  vector = embedder.encode([chunk])[0]
129
  index.add(np.array([vector]).astype('float32'))
130
-
131
  return len(chunks_list)
132
 
133
  def retrieve_chunks(query, top_k=5):
@@ -151,18 +136,13 @@ Answer: Please provide a comprehensive answer based only on the context provided
151
 
152
  def generate_answer(prompt):
153
  api_key = get_api_key()
154
-
155
  if not api_key:
156
  return "API key is missing. Please set the GROQ_API_KEY environment variable or enter it in the sidebar."
157
-
158
  headers = {
159
- "Authorization": f"Bearer {api_key.strip()}", # Strip to remove any whitespace
160
  "Content-Type": "application/json"
161
  }
162
-
163
- # Use the model selected by the user, default to llama3-8b if none selected
164
  selected_model = st.session_state.get("MODEL_CHOICE", "llama3-8b-8192")
165
-
166
  payload = {
167
  "model": selected_model,
168
  "messages": [
@@ -172,7 +152,6 @@ def generate_answer(prompt):
172
  "temperature": 0.3,
173
  "max_tokens": 1024
174
  }
175
-
176
  try:
177
  start_time = time.time()
178
  with st.spinner("Sending request to Groq API..."):
@@ -182,65 +161,33 @@ def generate_answer(prompt):
182
  headers=headers,
183
  timeout=30
184
  )
185
-
186
  query_time = time.time() - start_time
187
  st.session_state.last_query_time = f"{query_time:.2f} seconds"
188
-
189
- # For debugging - show only status code when debug mode is enabled
190
- if st.session_state.debug_mode:
191
- st.write(f"API Response Status Code: {response.status_code}")
192
- st.write(f"Response time: {query_time:.2f} seconds")
193
-
194
  if response.status_code == 401:
195
- return "Authentication failed: The API key appears to be invalid or expired. Please check your API key."
196
-
197
  if response.status_code == 400:
198
- # Display the detailed error for 400 Bad Request
199
  error_info = response.json().get("error", {})
200
  error_message = error_info.get("message", "Unknown error")
201
- error_type = error_info.get("type", "Unknown type")
202
-
203
- # Try alternate model if model not found
204
- if "model not found" in error_message.lower() or "model_not_found" in error_type.lower():
205
- st.warning("Trying with an alternate model (llama3-8b-8192)...")
206
  payload["model"] = "llama3-8b-8192"
207
-
208
- response = requests.post(
209
- "https://api.groq.com/openai/v1/chat/completions",
210
- json=payload,
211
- headers=headers,
212
- timeout=30
213
- )
214
-
215
  if response.status_code != 200:
216
- return f"Both model attempts failed. Please check the available models for your Groq API key. Error: {error_message}"
217
  else:
218
  return f"API Error: {error_message}"
219
-
220
- response.raise_for_status() # Raises an HTTPError for other bad responses
221
-
222
  response_json = response.json()
223
-
224
- if "choices" not in response_json:
225
- error_msg = f"Unexpected API response format. Response: {response_json}"
226
- if "error" in response_json:
227
- error_msg = f"API Error: {response_json['error'].get('message', 'Unknown error')}"
228
- st.error(error_msg)
229
- return "Sorry, I couldn't retrieve an answer due to an API error."
230
-
231
- if not response_json["choices"]:
232
  return "No answer was generated."
233
-
234
  answer = response_json["choices"][0]["message"]["content"]
235
  st.session_state.last_response = answer
236
  return answer
237
-
238
  except requests.exceptions.RequestException as e:
239
- st.error(f"API request failed: {str(e)}")
240
- return f"Sorry, I couldn't connect to the API service. Error: {str(e)}"
241
  except Exception as e:
242
- st.error(f"Unexpected error: {str(e)}")
243
- return f"Sorry, something went wrong. Error: {str(e)}"
244
 
245
  def translate_text(text, target_language):
246
  try:
@@ -248,7 +195,7 @@ def translate_text(text, target_language):
248
  return GoogleTranslator(source='auto', target=target_language).translate(text)
249
  except Exception as e:
250
  st.error(f"Translation failed: {str(e)}")
251
- return text # Return original text if translation fails
252
 
253
  def text_to_speech(text, lang_code):
254
  try:
@@ -260,7 +207,6 @@ def text_to_speech(text, lang_code):
260
  except Exception as e:
261
  st.error(f"Text-to-speech failed: {str(e)}")
262
  return None
263
-
264
  # Streamlit UI
265
  st.title("📄 Task-Specific RAG Assistant")
266
  st.markdown("Upload a document and ask questions to get AI-powered answers with translation capabilities.")
 
10
  from deep_translator import GoogleTranslator
11
  from gtts import gTTS
12
  import time
13
+
14
  st.set_page_config(
15
  page_title="RAG Document Assistant",
16
  page_icon="📄",
17
  layout="wide",
18
  initial_sidebar_state="expanded"
19
  )
20
+
 
21
  def sidebar_profiles():
22
+ st.sidebar.markdown("""<hr>""", unsafe_allow_html=True)
23
  st.sidebar.markdown("### 🎉Author: Maria Nadeem🌟")
24
  st.sidebar.markdown("### 🔗 Connect With Me")
25
  st.sidebar.markdown("""
 
41
  <hr>
42
  """, unsafe_allow_html=True)
43
 
 
 
44
  def get_api_key():
 
45
  api_key = os.getenv("GROQ_API_KEY")
 
 
46
  if not api_key:
47
  if "GROQ_API_KEY" in st.session_state:
48
  api_key = st.session_state["GROQ_API_KEY"]
 
49
  return api_key
50
 
51
+ # Session state initialization
52
+ for key, default in {
53
+ "chunks": [],
54
+ "chunk_sources": [],
55
+ "debug_mode": False,
56
+ "last_query_time": None,
57
+ "last_response": None
58
+ }.items():
59
+ if key not in st.session_state:
60
+ st.session_state[key] = default
61
+
 
 
62
  @st.cache_resource
63
  def load_embedder():
64
  return SentenceTransformer("all-MiniLM-L6-v2")
 
66
  embedder = load_embedder()
67
  embedding_dim = 384
68
  index = faiss.IndexFlatL2(embedding_dim)
 
 
69
  tokenizer = tiktoken.get_encoding("cl100k_base")
70
 
 
71
  def num_tokens_from_string(string: str) -> int:
72
  return len(tokenizer.encode(string))
73
 
 
77
  total_tokens = 0
78
  result_chunks = []
79
  for sentence in sentences:
80
+ if not sentence.strip():
81
  continue
82
  token_len = num_tokens_from_string(sentence)
83
  if total_tokens + token_len > max_tokens:
84
+ if current_chunk:
85
  result_chunks.append(". ".join(current_chunk) + ("." if not current_chunk[-1].endswith(".") else ""))
86
  current_chunk = [sentence]
87
  total_tokens = token_len
 
100
  return text
101
 
102
  def index_uploaded_text(text):
 
103
  global index
104
  index = faiss.IndexFlatL2(embedding_dim)
105
  st.session_state.chunks = []
106
  st.session_state.chunk_sources = []
107
+
 
108
  chunks_list = chunk_text(text)
109
  st.session_state.chunks = chunks_list
110
+
 
111
  for i, chunk in enumerate(chunks_list):
112
  st.session_state.chunk_sources.append(f"Chunk {i+1}: {chunk[:50]}...")
113
  vector = embedder.encode([chunk])[0]
114
  index.add(np.array([vector]).astype('float32'))
115
+
116
  return len(chunks_list)
117
 
118
  def retrieve_chunks(query, top_k=5):
 
136
 
137
  def generate_answer(prompt):
138
  api_key = get_api_key()
 
139
  if not api_key:
140
  return "API key is missing. Please set the GROQ_API_KEY environment variable or enter it in the sidebar."
 
141
  headers = {
142
+ "Authorization": f"Bearer {api_key.strip()}",
143
  "Content-Type": "application/json"
144
  }
 
 
145
  selected_model = st.session_state.get("MODEL_CHOICE", "llama3-8b-8192")
 
146
  payload = {
147
  "model": selected_model,
148
  "messages": [
 
152
  "temperature": 0.3,
153
  "max_tokens": 1024
154
  }
 
155
  try:
156
  start_time = time.time()
157
  with st.spinner("Sending request to Groq API..."):
 
161
  headers=headers,
162
  timeout=30
163
  )
 
164
  query_time = time.time() - start_time
165
  st.session_state.last_query_time = f"{query_time:.2f} seconds"
166
+
 
 
 
 
 
167
  if response.status_code == 401:
168
+ return "Authentication failed: Invalid or expired API key."
 
169
  if response.status_code == 400:
 
170
  error_info = response.json().get("error", {})
171
  error_message = error_info.get("message", "Unknown error")
172
+ if "model not found" in error_message.lower():
173
+ st.warning("Trying with alternate model...")
 
 
 
174
  payload["model"] = "llama3-8b-8192"
175
+ response = requests.post("https://api.groq.com/openai/v1/chat/completions", json=payload, headers=headers)
 
 
 
 
 
 
 
176
  if response.status_code != 200:
177
+ return f"Both model attempts failed. Error: {error_message}"
178
  else:
179
  return f"API Error: {error_message}"
180
+ response.raise_for_status()
 
 
181
  response_json = response.json()
182
+ if "choices" not in response_json or not response_json["choices"]:
 
 
 
 
 
 
 
 
183
  return "No answer was generated."
 
184
  answer = response_json["choices"][0]["message"]["content"]
185
  st.session_state.last_response = answer
186
  return answer
 
187
  except requests.exceptions.RequestException as e:
188
+ return f"API request failed: {str(e)}"
 
189
  except Exception as e:
190
+ return f"Unexpected error: {str(e)}"
 
191
 
192
  def translate_text(text, target_language):
193
  try:
 
195
  return GoogleTranslator(source='auto', target=target_language).translate(text)
196
  except Exception as e:
197
  st.error(f"Translation failed: {str(e)}")
198
+ return text
199
 
200
  def text_to_speech(text, lang_code):
201
  try:
 
207
  except Exception as e:
208
  st.error(f"Text-to-speech failed: {str(e)}")
209
  return None
 
210
  # Streamlit UI
211
  st.title("📄 Task-Specific RAG Assistant")
212
  st.markdown("Upload a document and ask questions to get AI-powered answers with translation capabilities.")