# app/champ/service.py

import logging
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple

from langchain_community.vectorstores import FAISS as LCFAISS
from langchain_core.messages import HumanMessage

from champ.qwen_agent import QwenAgent

from .agent import build_champ_agent
from .triage import safety_triage

logger = logging.getLogger("uvicorn")


class ChampService:
    vector_store: Optional[LCFAISS] = None
    agent = None
    lang = None
    context_store = None

    def __init__(
        self,
        vector_store: LCFAISS,
        lang: Literal["en", "fr"],
        model_type: str = "champ",
        prompt_template: str | None = None,
    ):
        self.vector_store = vector_store
        self.lang = lang
        self.model_type = model_type
        if model_type == "champ":
            self.agent, self.context_store = build_champ_agent(
                self.vector_store, lang, prompt_template=prompt_template
            )
        elif model_type == "qwen":
            self.agent = QwenAgent(self.vector_store, lang)
        else:
            # Fail fast on an unknown model type instead of leaving the agent unset.
            raise ValueError(f"Unsupported model type: {model_type}")

    def invoke(
        self, lc_messages: Sequence
    ) -> Tuple[str, Dict[str, Any], List[str], int]:
        """Invokes the agent.

        Args:
            lc_messages (Sequence): Sequence of LangChain messages

        Raises:
            RuntimeError: Raised when the function is called before CHAMP is initialized

        Returns:
            Tuple[str, Dict[str, Any], List[str], int]: The reply, the triage_triggered object,
                the retrieved passages, and the number of output tokens
        """
        if self.agent is None:
            logger.error("CHAMP invoked before initialization")
            raise RuntimeError("CHAMP is not initialized yet.")

        # --- Safety triage micro-layer (before LLM) ---
        last_user_text = None
        for m in reversed(lc_messages):
            if isinstance(m, HumanMessage):
                last_user_text = m.content
                break
        if last_user_text:
            triggered, override_reply, reason = safety_triage(last_user_text)
            if triggered and override_reply is not None:
                return (
                    override_reply,
                    {
                        "triage_triggered": True,
                        "triage_reason": reason,
                    },
                    [],  # No retrieved documents
                    0,
                )

        if self.model_type == "champ":
            result = self.agent.invoke({"messages": list(lc_messages)})  # type: ignore
            retrieved_passages = (
                self.context_store["last_retrieved_docs"]
                if self.context_store is not None
                else []
            )
            output_message = result["messages"][-1]  # pyright: ignore[reportCallIssue, reportArgumentType]
            return (
                output_message.text.strip(),
                {
                    "triage_triggered": False,
                },
                retrieved_passages,
                # output_message.usage_metadata["output_tokens"] is inaccurate because
                # CHAMP is an agent; we use tiktoken instead to estimate the number of
                # output tokens.
                0,
            )
        elif self.model_type == "qwen":
            chat_response, retrieved_passages, output_tokens = self.agent.invoke(
                list(lc_messages)  # type: ignore
            )
            return (
                chat_response,
                {
                    "triage_triggered": False,
                },
                retrieved_passages,
                output_tokens,
            )  # pyright: ignore[reportReturnType]

        raise ValueError(f"Invalid model type (should never happen): {self.model_type}")
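
# Example usage (illustrative sketch only). Assumes a FAISS index built elsewhere,
# e.g. with an embeddings model and `LCFAISS.from_texts(...)`, and a CHAMP agent
# wired up by `build_champ_agent`; the variable names below are hypothetical.
#
#     service = ChampService(vector_store, lang="en", model_type="champ")
#     reply, triage_info, passages, n_output_tokens = service.invoke(
#         [HumanMessage(content="Hello, can you help me?")]
#     )
#     if triage_info["triage_triggered"]:
#         # Safety triage short-circuited the agent; `reply` is the override text.
#         ...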