NavyDevilDoc commited on
Commit
ac7456f
·
verified ·
1 Parent(s): 172135e

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +20 -3
src/app.py CHANGED
@@ -207,11 +207,22 @@ def clean_text(text):
207
  return text.strip()
208
 
209
  def ask_ai(user_prompt, system_persona, max_tokens):
 
 
 
 
 
 
 
 
210
  if "GPT-4o" in model_choice:
211
- return query_local_model(user_prompt, system_persona, max_tokens)
 
212
  else:
 
213
  technical_name = model_map[model_choice]
214
- return query_local_model(user_prompt, system_persona, max_tokens, technical_name)
 
215
 
216
  # --- MAIN UI ---
217
  st.title("AI Toolkit")
@@ -272,7 +283,13 @@ with tab1:
272
  st.session_state.email_draft = reply
273
 
274
  if usage:
275
- m_name = "Granite" if "Granite" in model_choice else "GPT-4o"
 
 
 
 
 
 
276
  tracker.log_usage(m_name, usage["input"], usage["output"])
277
  update_sidebar_metrics() # Force update
278
 
 
207
  return text.strip()
208
 
209
  def ask_ai(user_prompt, system_persona, max_tokens):
210
+ # 1. Standardize Input: Convert the strings into the Message List format
211
+ # This ensures compatibility with our new memory-aware backend functions
212
+ messages_payload = [
213
+ {"role": "system", "content": system_persona},
214
+ {"role": "user", "content": user_prompt}
215
+ ]
216
+
217
+ # 2. Routing Logic
218
  if "GPT-4o" in model_choice:
219
+ # CORRECTED: Now calls the OpenAI function
220
+ return query_openai_model(messages_payload, max_tokens)
221
  else:
222
+ # Lookup the technical name for Ollama
223
  technical_name = model_map[model_choice]
224
+ # Calls the Local function
225
+ return query_local_model(messages_payload, max_tokens, technical_name)
226
 
227
  # --- MAIN UI ---
228
  st.title("AI Toolkit")
 
283
  st.session_state.email_draft = reply
284
 
285
  if usage:
286
+ # 1. Determine a clean name for the log
287
+ if "GPT-4o" in model_choice:
288
+ m_name = "GPT-4o"
289
+ else:
290
+ # Use the first word of the model choice (e.g., "Llama", "Gemma", "Granite")
291
+ m_name = model_choice.split(" ")[0]
292
+ # 2. Log it
293
  tracker.log_usage(m_name, usage["input"], usage["output"])
294
  update_sidebar_metrics() # Force update
295