Spaces:
Sleeping
Sleeping
Update src/app.py
Browse files- src/app.py +20 -3
src/app.py
CHANGED
|
@@ -207,11 +207,22 @@ def clean_text(text):
|
|
| 207 |
return text.strip()
|
| 208 |
|
| 209 |
def ask_ai(user_prompt, system_persona, max_tokens):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
if "GPT-4o" in model_choice:
|
| 211 |
-
|
|
|
|
| 212 |
else:
|
|
|
|
| 213 |
technical_name = model_map[model_choice]
|
| 214 |
-
|
|
|
|
| 215 |
|
| 216 |
# --- MAIN UI ---
|
| 217 |
st.title("AI Toolkit")
|
|
@@ -272,7 +283,13 @@ with tab1:
|
|
| 272 |
st.session_state.email_draft = reply
|
| 273 |
|
| 274 |
if usage:
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
tracker.log_usage(m_name, usage["input"], usage["output"])
|
| 277 |
update_sidebar_metrics() # Force update
|
| 278 |
|
|
|
|
| 207 |
return text.strip()
|
| 208 |
|
| 209 |
def ask_ai(user_prompt, system_persona, max_tokens):
    """Route a prompt to the currently selected model backend.

    Wraps the persona and prompt in the chat message-list format expected
    by the memory-aware backend helpers, then dispatches either to the
    hosted OpenAI model or to a locally served Ollama model based on the
    UI-selected ``model_choice``.

    Args:
        user_prompt: The user's question / request text.
        system_persona: System-role instructions for the model.
        max_tokens: Completion token budget forwarded to the backend.

    Returns:
        Whatever the chosen backend query function returns.
    """
    # Standardize input: the backends consume role/content message lists.
    payload = [
        {"role": "system", "content": system_persona},
        {"role": "user", "content": user_prompt},
    ]

    # Routing: anything that is not GPT-4o is served locally via Ollama.
    if "GPT-4o" not in model_choice:
        # Resolve the display name to the technical Ollama model name.
        # NOTE(review): a stale/unknown selection would raise KeyError here.
        local_name = model_map[model_choice]
        return query_local_model(payload, max_tokens, local_name)

    return query_openai_model(payload, max_tokens)
|
| 226 |
|
| 227 |
# --- MAIN UI ---
|
| 228 |
st.title("AI Toolkit")
|
|
|
|
| 283 |
st.session_state.email_draft = reply
|
| 284 |
|
| 285 |
if usage:
    # Derive a short, human-friendly model name for the usage log.
    # GPT-4o gets a fixed label; local models use the first word of their
    # display name (e.g. "Llama", "Gemma", "Granite").
    m_name = "GPT-4o" if "GPT-4o" in model_choice else model_choice.split(" ")[0]

    # Record the token counts and refresh the sidebar right away.
    tracker.log_usage(m_name, usage["input"], usage["output"])
    update_sidebar_metrics()  # Force update
|
| 295 |
|