# =====================================
# ZXAI-Backend / router.py
# =====================================
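"""Request router for the ZXAI backend.

Each prompt passes through a TTL response cache, an /image command hook,
a RAG fast path, and intent classification before being dispatched to
LLaMA and/or Gemini.
"""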
import asyncio
import time
import requests
from llm_clients import (
call_llama,
call_gemini,
classify_prompt,
judge_answers
)
from memory import save_message, load_memory
from search_tool import search_web
from rag_engine import rag_response
# =====================================
# CONFIG
# =====================================
IMAGE_SPACE_URL = "https://your-image-space.hf.space/generate"  # placeholder; point at the deployed image Space
CACHE_TTL_SECONDS = 300 # 5 minutes
response_cache = {}
# =====================================
# CACHE HELPERS
# =====================================
def get_cached_response(cache_key):
    """Return the cached response for cache_key, evicting stale entries."""
    entry = response_cache.get(cache_key)
    if not entry:
        return None
    if time.time() > entry["expires_at"]:
        del response_cache[cache_key]
        return None
    return entry["response"]

def set_cache(cache_key, response):
    """Cache a response with an absolute expiry timestamp."""
    response_cache[cache_key] = {
        "response": response,
        "expires_at": time.time() + CACHE_TTL_SECONDS
    }
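# NOTE: response_cache is plain in-process memory, so entries are neither
# shared across workers nor persisted across restarts; a shared store such
# as Redis would be needed for a multi-worker deployment.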
# =====================================
# MESSAGE BUILDER
# =====================================
def build_messages(system_prompt, memory, user_prompt):
    """Assemble the chat transcript: system prompt, saved memory, new user turn."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(memory)
    messages.append({"role": "user", "content": user_prompt})
    return messages
# =====================================
# IMAGE SERVICE (Async Safe)
# =====================================
async def call_image_microservice(prompt):
    """POST the prompt to the image Space in a worker thread so the event loop is not blocked."""
    def _post():
        resp = requests.post(IMAGE_SPACE_URL, json={"prompt": prompt}, timeout=60)
        resp.raise_for_status()  # fail fast on HTTP errors instead of parsing an error body
        return resp.json()
    try:
        return await asyncio.to_thread(_post)
    except Exception:
        # Timeouts, connection errors, and bad JSON all collapse to one payload
        return {"error": "Image service unavailable"}
# =====================================
# ASYNC LLM WRAPPERS
# =====================================
async def async_llama(messages):
    return await asyncio.to_thread(call_llama, messages)

async def async_gemini(messages):
    return await asyncio.to_thread(call_gemini, messages)
# =====================================
# MAIN ROUTER
# =====================================
async def route_request(prompt, user_id):
    cache_key = f"{user_id}:{prompt}"
    # ==========================
    # CACHE CHECK
    # ==========================
    cached = get_cached_response(cache_key)
    if cached:
        return {"response": cached}
    # ==========================
    # IMAGE COMMAND
    # ==========================
    if prompt.startswith("/image"):
        # removeprefix strips only the leading command, not later occurrences
        clean_prompt = prompt.removeprefix("/image").strip()
        return await call_image_microservice(clean_prompt)
    # ==========================
    # RAG QUICK RESPONSE
    # ==========================
    rag_answer = rag_response(prompt)
    if rag_answer:
        set_cache(cache_key, rag_answer)
        return {"response": rag_answer}
    # ==========================
    # LOAD MEMORY
    # ==========================
    memory = load_memory(user_id)
    # ==========================
    # CLASSIFY
    # ==========================
    classification = classify_prompt(prompt)
    intent = classification.get("intent", "chat")
    needs_search = classification.get("needs_search", False)
    system_prompt = "You are ZXAI, an advanced AI assistant."
    # ==========================
    # GREETING FAST PATH
    # ==========================
    if intent == "greeting":
        response = "Hello 👋 I am ZXAI. How can I help you today?"
        save_message(user_id, "user", prompt)
        save_message(user_id, "assistant", response)
        set_cache(cache_key, response)
        return {"response": response}
    # ==========================
    # REASONING → GEMINI
    # ==========================
    if intent == "reasoning":
        messages = build_messages(system_prompt, memory, prompt)
        response = await async_gemini(messages)
        save_message(user_id, "user", prompt)
        save_message(user_id, "assistant", response)
        set_cache(cache_key, response)
        return {"response": response}
    # ==========================
    # LIVE DATA (Parallel LLM)
    # ==========================
    if intent == "live_data" or needs_search:
        web_data = search_web(prompt)
        enriched_prompt = f"""
User Question:
{prompt}
Web Data:
{web_data}
Use web data if helpful.
"""
        messages = build_messages(system_prompt, memory, enriched_prompt)
        # Both tasks are scheduled immediately, so the two models run concurrently
        llama_task = asyncio.create_task(async_llama(messages))
        gemini_task = asyncio.create_task(async_gemini(messages))
        llama_answer = await llama_task
        gemini_answer = await gemini_task
        winner = judge_answers(llama_answer, gemini_answer)
        # judge_answers returns 2 when the second (Gemini) answer wins
        final_answer = gemini_answer if winner == 2 else llama_answer
        save_message(user_id, "user", prompt)
        save_message(user_id, "assistant", final_answer)
        set_cache(cache_key, final_answer)
        return {"response": final_answer}
    # ==========================
    # DEFAULT CHAT → LLAMA
    # ==========================
    messages = build_messages(system_prompt, memory, prompt)
    response = await async_llama(messages)
    save_message(user_id, "user", prompt)
    save_message(user_id, "assistant", response)
    set_cache(cache_key, response)
    return {"response": response}
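
# =====================================
# LOCAL SMOKE TEST
# =====================================
# Minimal sketch for exercising the router by hand. It assumes the imported
# modules (llm_clients, memory, search_tool, rag_engine) are configured and
# reachable; "demo-user" is a hypothetical placeholder id.
if __name__ == "__main__":
    async def _demo():
        result = await route_request("Hello!", "demo-user")
        print(result.get("response", result))

    asyncio.run(_demo())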