Spaces:
Sleeping
Sleeping
Update src/app.py
Browse files — updated to add a rudimentary chat component
- src/app.py +139 -70
src/app.py
CHANGED
|
@@ -133,13 +133,32 @@ def update_sidebar_metrics():
|
|
| 133 |
# Call metrics once on load
|
| 134 |
update_sidebar_metrics()
|
| 135 |
|
| 136 |
-
def query_local_model(
|
| 137 |
if not API_URL_ROOT:
|
| 138 |
return "Error: API_URL not set.", None
|
| 139 |
|
| 140 |
url = API_URL_ROOT + "/generate"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
payload = {
|
| 142 |
-
"text":
|
| 143 |
"persona": system_persona,
|
| 144 |
"max_tokens": max_tokens,
|
| 145 |
"model": model_name
|
|
@@ -159,7 +178,7 @@ def query_local_model(user_prompt, system_persona, max_tokens, model_name):
|
|
| 159 |
except Exception as e:
|
| 160 |
return f"Connection Error: {e}", None
|
| 161 |
|
| 162 |
-
def
|
| 163 |
if not OPENAI_KEY:
|
| 164 |
return "Error: OPENAI_API_KEY not set.", None
|
| 165 |
|
|
@@ -169,10 +188,7 @@ def query_gpt4o(prompt, persona, max_tokens):
|
|
| 169 |
response = client.chat.completions.create(
|
| 170 |
model="gpt-4o",
|
| 171 |
max_tokens=max_tokens,
|
| 172 |
-
messages=
|
| 173 |
-
{"role": "system", "content": persona},
|
| 174 |
-
{"role": "user", "content": prompt}
|
| 175 |
-
],
|
| 176 |
temperature=0.3
|
| 177 |
)
|
| 178 |
usage_obj = response.usage
|
|
@@ -192,7 +208,7 @@ def clean_text(text):
|
|
| 192 |
|
| 193 |
def ask_ai(user_prompt, system_persona, max_tokens):
|
| 194 |
if "GPT-4o" in model_choice:
|
| 195 |
-
return
|
| 196 |
else:
|
| 197 |
technical_name = model_map[model_choice]
|
| 198 |
return query_local_model(user_prompt, system_persona, max_tokens, technical_name)
|
|
@@ -268,72 +284,125 @@ with tab1:
|
|
| 268 |
with tab2:
|
| 269 |
st.header("Choose Your Model and Start a Discussion")
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
| 277 |
with c1:
|
| 278 |
-
|
|
|
|
| 279 |
with c2:
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
f"### CONTEXT:\n{context_text}"
|
| 316 |
-
)
|
| 317 |
-
st.success(f"Found {len(retrieved_docs)} relevant documents.")
|
| 318 |
-
with st.expander("View Context Used"):
|
| 319 |
-
st.text(context_text)
|
| 320 |
-
else:
|
| 321 |
-
st.warning("No relevant documents found. Using general knowledge.")
|
| 322 |
-
|
| 323 |
-
# --- GENERATION ---
|
| 324 |
-
with st.spinner(f"Thinking with {model_choice}..."):
|
| 325 |
-
reply, usage = ask_ai(final_prompt, system_persona, max_len)
|
| 326 |
-
st.session_state.chat_response = reply
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
# --- TAB 3: PROMPT ARCHITECT ---
|
| 339 |
with tab3:
|
|
|
|
| 133 |
# Call metrics once on load
|
| 134 |
update_sidebar_metrics()
|
| 135 |
|
| 136 |
+
def query_local_model(messages, max_tokens, model_name):
|
| 137 |
if not API_URL_ROOT:
|
| 138 |
return "Error: API_URL not set.", None
|
| 139 |
|
| 140 |
url = API_URL_ROOT + "/generate"
|
| 141 |
+
|
| 142 |
+
# --- FLATTEN MESSAGE HISTORY ---
|
| 143 |
+
# Since the backend expects a single string ("text"), we format the history here.
|
| 144 |
+
# We extract the system persona separately to pass to the 'persona' field.
|
| 145 |
+
|
| 146 |
+
formatted_history = ""
|
| 147 |
+
system_persona = "You are a helpful assistant." # Default
|
| 148 |
+
|
| 149 |
+
for msg in messages:
|
| 150 |
+
if msg['role'] == 'system':
|
| 151 |
+
system_persona = msg['content']
|
| 152 |
+
elif msg['role'] == 'user':
|
| 153 |
+
formatted_history += f"User: {msg['content']}\n"
|
| 154 |
+
elif msg['role'] == 'assistant':
|
| 155 |
+
formatted_history += f"Assistant: {msg['content']}\n"
|
| 156 |
+
|
| 157 |
+
# Append the "Assistant:" prompt at the end to cue the model
|
| 158 |
+
formatted_history += "Assistant: "
|
| 159 |
+
|
| 160 |
payload = {
|
| 161 |
+
"text": formatted_history, # <--- History goes here
|
| 162 |
"persona": system_persona,
|
| 163 |
"max_tokens": max_tokens,
|
| 164 |
"model": model_name
|
|
|
|
| 178 |
except Exception as e:
|
| 179 |
return f"Connection Error: {e}", None
|
| 180 |
|
| 181 |
+
def query_openai_model(messages, max_tokens):
|
| 182 |
if not OPENAI_KEY:
|
| 183 |
return "Error: OPENAI_API_KEY not set.", None
|
| 184 |
|
|
|
|
| 188 |
response = client.chat.completions.create(
|
| 189 |
model="gpt-4o",
|
| 190 |
max_tokens=max_tokens,
|
| 191 |
+
messages=messages,
|
|
|
|
|
|
|
|
|
|
| 192 |
temperature=0.3
|
| 193 |
)
|
| 194 |
usage_obj = response.usage
|
|
|
|
| 208 |
|
| 209 |
def ask_ai(user_prompt, system_persona, max_tokens):
    """Send a single-turn prompt to the currently selected model.

    The backend helpers take a chat-style ``messages`` list
    (``query_openai_model(messages, max_tokens)`` and
    ``query_local_model(messages, max_tokens, model_name)``), so the persona
    and prompt are wrapped into that shape here. Callers keep the original
    three-argument interface.

    Args:
        user_prompt: Text of the user's request.
        system_persona: System instruction sent as the ``system`` message.
        max_tokens: Generation limit forwarded to the backend.

    Returns:
        Whatever the selected backend helper returns (a ``(reply, usage)``
        pair in the visible error paths of both helpers).

    Raises:
        KeyError: If ``model_choice`` is not a GPT-4o choice and is missing
            from ``model_map``.
    """
    messages = [
        {"role": "system", "content": system_persona},
        {"role": "user", "content": user_prompt},
    ]

    # GPT-4o must go to the OpenAI helper; the previous code routed it to the
    # local backend with an argument list that no longer matched its signature.
    if "GPT-4o" in model_choice:
        return query_openai_model(messages, max_tokens)

    technical_name = model_map[model_choice]
    return query_local_model(messages, max_tokens, technical_name)
|
|
|
|
| 284 |
with tab2:
    # Chat tab: keeps a running conversation in st.session_state.messages,
    # optionally augments the newest question with knowledge-base context,
    # and sends a sliding window of the history to the selected model.
    st.header("Choose Your Model and Start a Discussion")

    # Chat memory must exist before anything below reads it.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # --- CONTROLS ---
    # Three equal columns are kept so the layout is unchanged; the third one
    # is currently empty.
    c1, c2, _ = st.columns([1, 1, 1])
    with c1:
        # NOTE(review): this reads 'model_choice' from session state, while
        # ask_ai() uses a module-level `model_choice`. Confirm the selector
        # widget is registered with key='model_choice'; otherwise the default
        # below is always used.
        selected_model_name = st.session_state.get('model_choice', 'Granite 4 (IBM)')
    with c2:
        use_rag = st.toggle("🔌 Enable Knowledge Base", value=False)

    # --- DISPLAY CONVERSATION HISTORY ---
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # --- CHAT INPUT HANDLING ---
    if prompt := st.chat_input("Ask about Naval Systems..."):
        # 1. Record and echo the user's message.
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # 2. Start the payload with the system persona.
        system_persona = (
            "You are a Navy Document Analyst. "
            "Your task is to answer the user's question using ONLY the Context provided below. "
            "If the answer is not present in the Context, return ONLY this exact phrase: "
            "'I cannot find that information in the provided documents.' "
            "If no context is provided, answer generally."
        )
        messages_payload = [{"role": "system", "content": system_persona}]

        # Sliding window: up to `history_depth` earlier messages (4 user /
        # assistant exchanges). The slice stops at -1 to leave out the prompt
        # appended above, which is re-added below (possibly RAG-augmented).
        history_depth = 8
        recent_history = st.session_state.messages[-(history_depth + 1):-1]
        messages_payload.extend(recent_history)

        # 3. Optionally augment the current prompt with retrieved context.
        final_user_content = prompt
        retrieved_docs = []  # also drives the context expander at the end

        if use_rag:
            with st.spinner("🧠 Searching Knowledge Base..."):
                retrieved_docs = rag_engine.search_knowledge_base(
                    prompt,
                    st.session_state.username
                )

            if retrieved_docs:
                context_parts = []
                for doc in retrieved_docs:
                    score = doc.metadata.get('relevance_score', 'N/A')
                    src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                    context_parts.append(
                        f"---\nSOURCE: {src} (Rel: {score})\nTEXT: {doc.page_content}\n"
                    )
                context_text = "".join(context_parts)

                final_user_content = (
                    f"User Question: {prompt}\n\n"
                    f"Relevant Context:\n{context_text}\n\n"
                    "Answer the question using the context provided."
                )

        # 4. Add the final (possibly augmented) user message to the payload.
        messages_payload.append({"role": "user", "content": final_user_content})

        # 5. Generate and display the response.
        with st.chat_message("assistant"):
            with st.spinner(f"Thinking with {selected_model_name}..."):
                max_len = 2000

                # Display-name fragment -> backend model id for local models.
                ollama_map = {
                    "Granite 4 (IBM)": "granite4:latest",
                    "Llama 3.2 (Meta)": "llama3.2:latest",
                    "Gemma 3 (Google)": "gemma3:latest",
                }
                model_id = next(
                    (val for key, val in ollama_map.items() if key in selected_model_name),
                    "",
                )
                is_gpt = "gpt" in selected_model_name.lower()

                if not model_id and is_gpt:
                    response, usage = query_openai_model(messages_payload, max_len)
                elif model_id:
                    response, usage = query_local_model(messages_payload, max_len, model_id)
                else:
                    response, usage = "Error: Could not determine model to use.", None

            st.markdown(response)

        # 6. Save the assistant response and update metrics.
        st.session_state.messages.append({"role": "assistant", "content": response})

        if usage:
            # Previously every non-Granite model, including Llama and Gemma,
            # was logged as "GPT-4o". Only real GPT selections get that label
            # now; other local models are logged under their display name.
            # NOTE(review): confirm tracker.log_usage accepts labels other
            # than "Granite" and "GPT-4o".
            if "Granite" in selected_model_name:
                m_name = "Granite"
            elif is_gpt:
                m_name = "GPT-4o"
            else:
                m_name = selected_model_name
            tracker.log_usage(m_name, usage["input"], usage["output"])
            update_sidebar_metrics()

        # 7. Show the retrieved context when RAG contributed documents.
        if use_rag and retrieved_docs:
            with st.expander("📚 View Context Used"):
                for rank, doc in enumerate(retrieved_docs, start=1):
                    score = doc.metadata.get('relevance_score', 'N/A')
                    src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                    st.caption(f"Rank {rank} (Source: {src}, Rel: {score})")
                    st.text(doc.page_content)
                    st.divider()
|
| 406 |
|
| 407 |
# --- TAB 3: PROMPT ARCHITECT ---
|
| 408 |
with tab3:
|