NavyDevilDoc committed on
Commit
8ba0387
·
verified ·
1 Parent(s): 1fd5385

Update src/app.py

Browse files

Updated to add a rudimentary chat component.

Files changed (1) hide show
  1. src/app.py +139 -70
src/app.py CHANGED
@@ -133,13 +133,32 @@ def update_sidebar_metrics():
133
  # Call metrics once on load
134
  update_sidebar_metrics()
135
 
136
- def query_local_model(user_prompt, system_persona, max_tokens, model_name):
137
  if not API_URL_ROOT:
138
  return "Error: API_URL not set.", None
139
 
140
  url = API_URL_ROOT + "/generate"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  payload = {
142
- "text": user_prompt,
143
  "persona": system_persona,
144
  "max_tokens": max_tokens,
145
  "model": model_name
@@ -159,7 +178,7 @@ def query_local_model(user_prompt, system_persona, max_tokens, model_name):
159
  except Exception as e:
160
  return f"Connection Error: {e}", None
161
 
162
- def query_gpt4o(prompt, persona, max_tokens):
163
  if not OPENAI_KEY:
164
  return "Error: OPENAI_API_KEY not set.", None
165
 
@@ -169,10 +188,7 @@ def query_gpt4o(prompt, persona, max_tokens):
169
  response = client.chat.completions.create(
170
  model="gpt-4o",
171
  max_tokens=max_tokens,
172
- messages=[
173
- {"role": "system", "content": persona},
174
- {"role": "user", "content": prompt}
175
- ],
176
  temperature=0.3
177
  )
178
  usage_obj = response.usage
@@ -192,7 +208,7 @@ def clean_text(text):
192
 
193
  def ask_ai(user_prompt, system_persona, max_tokens):
194
  if "GPT-4o" in model_choice:
195
- return query_gpt4o(user_prompt, system_persona, max_tokens)
196
  else:
197
  technical_name = model_map[model_choice]
198
  return query_local_model(user_prompt, system_persona, max_tokens, technical_name)
@@ -268,72 +284,125 @@ with tab1:
268
  with tab2:
269
  st.header("Choose Your Model and Start a Discussion")
270
 
271
- if "chat_response" not in st.session_state:
272
- st.session_state.chat_response = ""
273
-
274
- user_input = st.text_input("Ask a question:")
275
-
276
- c1, c2 = st.columns([1,1])
 
277
  with c1:
278
- use_rag = st.toggle("🔌 Enable Knowledge Base", value=False)
 
279
  with c2:
280
- est_tokens = len(user_input) / 4
281
- st.progress(min(est_tokens / 2000, 1.0), text=f"Input: {int(est_tokens)} tokens")
282
-
283
- if st.button("Send Query"):
284
- if not user_input:
285
- st.warning("Please enter a question.")
286
- else:
287
- final_prompt = user_input
288
- system_persona = "You are a helpful assistant."
289
-
290
- # --- RAG LOGIC ---
291
- if use_rag:
292
- with st.spinner("🧠 Searching Knowledge Base..."):
293
- # 1. Retrieve & Rerank (Now using the fixed function)
294
- retrieved_docs = rag_engine.search_knowledge_base(
295
- user_input,
296
- st.session_state.username
297
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
- if retrieved_docs:
300
- # 2. Format Context
301
- context_text = ""
302
- for i, doc in enumerate(retrieved_docs):
303
- # Add metadata relevance score if available
304
- score = doc.metadata.get('relevance_score', 'N/A')
305
- src = os.path.basename(doc.metadata.get('source', 'Unknown'))
306
- context_text += f"---\nSOURCE: {src} (Rel: {score})\nTEXT: {doc.page_content}\n"
307
-
308
- # 3. Update Prompt
309
- system_persona = (
310
- "You are a Navy Document Analyst. "
311
- "Your task is to answer the user's question using ONLY the Context provided below. "
312
- "Follow these rules strictly:\n"
313
- "1. If the answer is present in the Context, provide it clearly. Do NOT add any disclaimers about missing information if you found the answer.\n"
314
- "2. If the answer is NOT present in the Context, return ONLY this exact phrase: 'I cannot find that information in the provided documents.'\n\n"
315
- f"### CONTEXT:\n{context_text}"
316
- )
317
- st.success(f"Found {len(retrieved_docs)} relevant documents.")
318
- with st.expander("View Context Used"):
319
- st.text(context_text)
320
- else:
321
- st.warning("No relevant documents found. Using general knowledge.")
322
-
323
- # --- GENERATION ---
324
- with st.spinner(f"Thinking with {model_choice}..."):
325
- reply, usage = ask_ai(final_prompt, system_persona, max_len)
326
- st.session_state.chat_response = reply
327
 
328
- if usage:
329
- m_name = "Granite" if "Granite" in model_choice else "GPT-4o"
330
- tracker.log_usage(m_name, usage["input"], usage["output"])
331
- update_sidebar_metrics()
332
-
333
- if st.session_state.chat_response:
334
- st.divider()
335
- st.markdown("**AI Response:**")
336
- st.write(st.session_state.chat_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  # --- TAB 3: PROMPT ARCHITECT ---
339
  with tab3:
 
133
  # Call metrics once on load
134
  update_sidebar_metrics()
135
 
136
+ def query_local_model(messages, max_tokens, model_name):
137
  if not API_URL_ROOT:
138
  return "Error: API_URL not set.", None
139
 
140
  url = API_URL_ROOT + "/generate"
141
+
142
+ # --- FLATTEN MESSAGE HISTORY ---
143
+ # Since the backend expects a single string ("text"), we format the history here.
144
+ # We extract the system persona separately to pass to the 'persona' field.
145
+
146
+ formatted_history = ""
147
+ system_persona = "You are a helpful assistant." # Default
148
+
149
+ for msg in messages:
150
+ if msg['role'] == 'system':
151
+ system_persona = msg['content']
152
+ elif msg['role'] == 'user':
153
+ formatted_history += f"User: {msg['content']}\n"
154
+ elif msg['role'] == 'assistant':
155
+ formatted_history += f"Assistant: {msg['content']}\n"
156
+
157
+ # Append the "Assistant:" prompt at the end to cue the model
158
+ formatted_history += "Assistant: "
159
+
160
  payload = {
161
+ "text": formatted_history, # <--- History goes here
162
  "persona": system_persona,
163
  "max_tokens": max_tokens,
164
  "model": model_name
 
178
  except Exception as e:
179
  return f"Connection Error: {e}", None
180
 
181
+ def query_openai_model(messages, max_tokens):
182
  if not OPENAI_KEY:
183
  return "Error: OPENAI_API_KEY not set.", None
184
 
 
188
  response = client.chat.completions.create(
189
  model="gpt-4o",
190
  max_tokens=max_tokens,
191
+ messages=messages,
 
 
 
192
  temperature=0.3
193
  )
194
  usage_obj = response.usage
 
208
 
209
  def ask_ai(user_prompt, system_persona, max_tokens):
210
  if "GPT-4o" in model_choice:
211
+ return query_local_model(user_prompt, system_persona, max_tokens)
212
  else:
213
  technical_name = model_map[model_choice]
214
  return query_local_model(user_prompt, system_persona, max_tokens, technical_name)
 
with tab2:
    st.header("Choose Your Model and Start a Discussion")

    # --- INITIALIZE CHAT MEMORY (MUST BE DONE FIRST) ---
    # st.session_state.messages holds the full conversation as a list of
    # {"role": ..., "content": ...} dicts, surviving Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # --- CONTROLS AND METRICS ---
    # The controls are kept outside the chat loop. c3 is an intentional
    # spacer column for layout.
    c1, c2, c3 = st.columns([1, 1, 1])
    with c1:
        # Use the global model_choice from the sidebar/tab1 initialization.
        selected_model_name = st.session_state.get('model_choice', 'Granite 4 (IBM)')
    with c2:
        use_rag = st.toggle("🔌 Enable Knowledge Base", value=False)

    # --- DISPLAY CONVERSATION HISTORY ---
    # Re-render every stored turn on each rerun so the transcript persists.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # --- CHAT INPUT HANDLING (replaces st.text_input + st.button) ---
    if prompt := st.chat_input("Ask about Naval Systems..."):
        # 1. Display the user message and save it to history.
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # 2. Initialize the payload with the system persona.
        system_persona = "You are a Navy Document Analyst. Your task is to answer the user's question using ONLY the Context provided below. If the answer is not present in the Context, return ONLY this exact phrase: 'I cannot find that information in the provided documents.' If no context is provided, answer generally."
        messages_payload = [{"role": "system", "content": system_persona}]

        # --- MEMORY LOGIC: SLIDING WINDOW ---
        # Keep the last `history_depth` messages (4 user/assistant exchanges)
        # as context. The slice stops at -1 to exclude the current prompt,
        # which was already appended above and is added (possibly RAG-augmented)
        # at the end of the payload instead.
        history_depth = 8
        recent_history = st.session_state.messages[-(history_depth + 1):-1]
        messages_payload.extend(recent_history)

        # 3. Handle RAG and augmentation of the current prompt.
        final_user_content = prompt
        retrieved_docs = []  # also used by the context expander below

        if use_rag:
            with st.spinner("🧠 Searching Knowledge Base..."):
                retrieved_docs = rag_engine.search_knowledge_base(
                    prompt,
                    st.session_state.username
                )

                context_text = ""
                if retrieved_docs:
                    for doc in retrieved_docs:
                        score = doc.metadata.get('relevance_score', 'N/A')
                        src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                        context_text += f"---\nSOURCE: {src} (Rel: {score})\nTEXT: {doc.page_content}\n"

                    # Augment the FINAL prompt (not the history) with RAG context.
                    final_user_content = (
                        f"User Question: {prompt}\n\n"
                        f"Relevant Context:\n{context_text}\n\n"
                        "Answer the question using the context provided."
                    )
                else:
                    # FIX: empty retrieval previously proceeded silently;
                    # restore the pre-refactor user feedback.
                    st.warning("No relevant documents found. Using general knowledge.")

        # 4. Append the (possibly augmented) user message to the payload.
        messages_payload.append({"role": "user", "content": final_user_content})

        # 5. Generate the response and display it.
        with st.chat_message("assistant"):
            with st.spinner(f"Thinking with {selected_model_name}..."):
                # NOTE(review): max_len is hard-coded here; presumably this
                # should come from the global length control — confirm.
                max_len = 2000

                # --- MODEL MAPPING: display name -> Ollama model id ---
                ollama_map = {
                    "Granite 4 (IBM)": "granite4:latest",
                    "Llama 3.2 (Meta)": "llama3.2:latest",
                    "Gemma 3 (Google)": "gemma3:latest"
                }
                model_id = ""
                for key, val in ollama_map.items():
                    if key in selected_model_name:
                        model_id = val
                        break

                if not model_id and "gpt" in selected_model_name.lower():
                    # OpenAI-hosted model.
                    response, usage = query_openai_model(messages_payload, max_len)
                elif model_id:
                    # Local Ollama-served model.
                    response, usage = query_local_model(messages_payload, max_len, model_id)
                else:
                    response, usage = "Error: Could not determine model to use.", None

            st.markdown(response)

        # 6. Save the assistant response and update metrics.
        st.session_state.messages.append({"role": "assistant", "content": response})

        if usage:
            # BUG FIX: previously every non-Granite model (including Llama and
            # Gemma) was logged as "GPT-4o" in the usage tracker.
            if "Granite" in selected_model_name:
                m_name = "Granite"
            elif "gpt" in selected_model_name.lower():
                m_name = "GPT-4o"
            else:
                m_name = selected_model_name
            tracker.log_usage(m_name, usage["input"], usage["output"])
            update_sidebar_metrics()

        # 7. Show the retrieved context when RAG was used.
        if use_rag and retrieved_docs:
            with st.expander("📚 View Context Used"):
                for i, doc in enumerate(retrieved_docs):
                    score = doc.metadata.get('relevance_score', 'N/A')
                    src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                    st.caption(f"Rank {i+1} (Source: {src}, Rel: {score})")
                    st.text(doc.page_content)
                    st.divider()
 
407
  # --- TAB 3: PROMPT ARCHITECT ---
408
  with tab3: