# champ-chatbot / champ/service.py
# (Hugging Face page metadata — kept as comments so the module parses)
# qyle's picture
# deployment for load testing
# e82f783 verified
# app/champ/service.py
import logging
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple
from langchain_community.vectorstores import FAISS as LCFAISS
from langchain_core.messages import HumanMessage
from champ.qwen_agent import QwenAgent
from .agent import build_champ_agent
from .triage import safety_triage
# Reuse uvicorn's logger so service messages appear in the ASGI server's log stream.
logger = logging.getLogger("uvicorn")
class ChampService:
    """Routes chat requests to one of two backends: the CHAMP LangGraph agent
    or the Qwen agent, both backed by a shared FAISS vector store.

    A safety-triage micro-layer runs before any LLM call and can short-circuit
    the request with an override reply.
    """

    # Class-level defaults so attribute access is safe before __init__ runs.
    vector_store: Optional[LCFAISS] = None
    agent = None
    lang = None
    context_store = None

    def __init__(
        self,
        vector_store: LCFAISS,
        lang: Literal["en", "fr"],
        model_type: str = "champ",
        prompt_template: str | None = None,
    ):
        """Builds the requested agent over the given vector store.

        Args:
            vector_store: FAISS store used for retrieval.
            lang: Conversation language, "en" or "fr".
            model_type: Which backend to build, "champ" or "qwen".
            prompt_template: Optional prompt override (CHAMP backend only).

        Raises:
            ValueError: If ``model_type`` is not "champ" or "qwen".
        """
        self.vector_store = vector_store
        # Fix: previously `lang` was never stored even though the class
        # declares a `lang` attribute; keep it for introspection/debugging.
        self.lang = lang
        self.model_type = model_type
        if model_type == "champ":
            self.agent, self.context_store = build_champ_agent(
                self.vector_store, lang, prompt_template=prompt_template
            )
        elif model_type == "qwen":
            self.agent = QwenAgent(self.vector_store, lang)
        else:
            # Fail fast: previously an unknown model_type silently left
            # self.agent as None, which only surfaced later in invoke() as a
            # misleading "CHAMP is not initialized yet" RuntimeError.
            raise ValueError(f"Invalid model type: {model_type!r}")

    def invoke(
        self, lc_messages: Sequence
    ) -> Tuple[str, Dict[str, Any], List[str], int]:
        """Invokes the agent.

        Args:
            lc_messages (Sequence): Sequence of LangChain messages

        Raises:
            RuntimeError: Raised when the function is called before CHAMP is initialized

        Returns:
            Tuple[str, Dict[str, Any], List[str], int]: The reply, the triage_triggered object,
            the retrieved passages, and the number of output tokens
        """
        if self.agent is None:
            logger.error("CHAMP invoked before initialization")
            raise RuntimeError("CHAMP is not initialized yet.")

        # --- Safety triage micro-layer (before LLM) ---
        # Triage only looks at the most recent human message.
        last_user_text = None
        for m in reversed(lc_messages):
            if isinstance(m, HumanMessage):
                last_user_text = m.content
                break
        if last_user_text:
            triggered, override_reply, reason = safety_triage(last_user_text)
            if triggered and override_reply is not None:
                return (
                    override_reply,
                    {
                        "triage_triggered": True,
                        "triage_reason": reason,
                    },
                    [],  # No retrieved documents
                    0,
                )

        if self.model_type == "champ":
            result = self.agent.invoke({"messages": list(lc_messages)})  # type: ignore
            # The CHAMP agent publishes its last retrieval into context_store
            # as a side channel; fall back to [] if the store was never built.
            retrieved_passages = (
                self.context_store["last_retrieved_docs"]
                if self.context_store is not None
                else []
            )
            output_message = result["messages"][-1]  # pyright: ignore[reportCallIssue, reportArgumentType]
            return (
                output_message.text.strip(),
                {
                    "triage_triggered": False,
                },
                retrieved_passages,
                # output_message.usage_metadata["output_tokens"], This value is inaccurate because Champ is an agent. We use tiktoken instead to estimate the number of output tokens.
                0,
            )
        elif self.model_type == "qwen":
            chat_response, retrieved_passages, output_tokens = self.agent.invoke(
                list(lc_messages)  # type: ignore
            )
            return (
                chat_response,
                {
                    "triage_triggered": False,
                },
                retrieved_passages,
                output_tokens,
            )  # pyright: ignore[reportReturnType]
        # Unreachable after the __init__ validation, kept as a safety net.
        raise ValueError(f"Invalid model type (should never happen): {self.model_type}")