from __future__ import annotations import logging import sys import time from dataclasses import replace from pathlib import Path from langchain_core.messages import AIMessage, HumanMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from api.settings import ApiConfig from api.guardrails.pipeline import GuardrailPipeline from api.session_store import SessionMemoryStore from api.observability.inference_metrics import compute_tokens_per_sec from api.observability.metrics import METRICS_STORE, RequestMetric, estimate_cost_usd, now_iso from api.observability.telemetry import get_tracer from api.schemas import ChatResponse from api.agent_tools.router import route_tools from assistants.response_utils import clean_model_response, fix_wrong_identity from config import AppConfig from llm.factory import build_oss_llm from llm.streaming_invoke import count_input_tokens_for_prompt, stream_llm_response from llm.tokenizer_utils import count_text_tokens, get_llm_tokenizer logger = logging.getLogger(__name__) class ChatService: def __init__(self, api_config: ApiConfig, app_config: AppConfig) -> None: self.api_config = api_config self.app_config = app_config self.guardrails = GuardrailPipeline(api_config.guardrails) self.memory = SessionMemoryStore(max_turns=api_config.max_history_turns) self._llm = None self._prompt = None self._parser = StrOutputParser() def startup(self) -> None: if self._llm is not None: return self._llm = build_oss_llm(self.app_config.oss) system_prompt = ( f"{self.app_config.oss_system_prompt}\n\n" "You may receive tool results in the user message. Use them when helpful. " "Keep answers concise." ) self._prompt = ChatPromptTemplate.from_messages( [ ("system", system_prompt), MessagesPlaceholder("history"), ("human", "{input}"), ] ) logger.warning("OSS chat engine loaded model=%s", self.app_config.oss.model_id) @property def model_id(self) -> str: return self.app_config.oss.model_id def reset_session(self, session_id: str) -> None: self.memory.clear(session_id) def _record_inference_span_attrs(self, span, inference) -> None: span.set_attribute("inference.latency_ms", inference.latency_ms) span.set_attribute("inference.input_tokens", inference.input_tokens) span.set_attribute("inference.output_tokens", inference.output_tokens) if inference.ttft_ms is not None: span.set_attribute("inference.ttft_ms", inference.ttft_ms) if inference.tbt_ms is not None: span.set_attribute("inference.tbt_ms", inference.tbt_ms) if inference.tokens_per_sec is not None: span.set_attribute("inference.tokens_per_sec", inference.tokens_per_sec) span.set_attribute("inference.stream_chunks", inference.stream_chunks) def chat(self, message: str, session_id: str) -> ChatResponse: self.startup() assert self._llm is not None assert self._prompt is not None tracer = get_tracer("ollive.api.chat") start = time.perf_counter() guardrail_blocks: list[str] = [] with tracer.start_as_current_span("chat.request") as span: span.set_attribute("session.id", session_id) span.set_attribute("input.length", len(message)) input_decision = self.guardrails.check_input(message) if not input_decision.allowed: guardrail_blocks.extend(input_decision.blocked_layers) span.set_attribute("guardrail.blocks", ",".join(guardrail_blocks)) span.set_attribute("guardrail.blocked", True) latency_ms = (time.perf_counter() - start) * 1000 tokens, cost = estimate_cost_usd( latency_ms, len(message), self.api_config.cost.cpu_hour_usd, self.api_config.cost.estimated_tokens_per_char, ) METRICS_STORE.record( RequestMetric( timestamp=now_iso(), session_id=session_id, latency_ms=latency_ms, input_chars=len(message), output_chars=len(input_decision.response_text), estimated_tokens=tokens, estimated_cost_usd=cost, guardrail_blocks=guardrail_blocks, ) ) return ChatResponse( session_id=session_id, response=input_decision.response_text, model=self.model_id, latency_ms=latency_ms, estimated_tokens=tokens, estimated_cost_usd=cost, tools_used=[], guardrail_blocks=guardrail_blocks, ) tool_route = route_tools(message) model_input = f"{tool_route.context_prefix}{message}" if tool_route.context_prefix else message history = self.memory.get_history(session_id) input_tokens = count_input_tokens_for_prompt( self._llm, self._prompt, model_input=model_input, history=history.messages, ) prompt_messages = self._prompt.format_messages( input=model_input, history=history.messages, ) raw, inference = stream_llm_response( self._llm, prompt_messages, input_tokens=input_tokens, ) text = clean_model_response(self._parser.invoke(raw)) if not text and raw: text = str(raw).strip() text = fix_wrong_identity(text) tokenizer = get_llm_tokenizer(self._llm) final_output_tokens = count_text_tokens(tokenizer, text) if final_output_tokens != inference.output_tokens: inference = replace( inference, output_tokens=final_output_tokens, tokens_per_sec=compute_tokens_per_sec( final_output_tokens, inference.latency_ms, inference.ttft_ms, ), ) output_decision = self.guardrails.check_output(text) if not output_decision.allowed: guardrail_blocks.extend(output_decision.blocked_layers) text = output_decision.response_text history.add_message(HumanMessage(content=message)) history.add_message(AIMessage(content=text)) latency_ms = (time.perf_counter() - start) * 1000 total_chars = len(message) + len(text) tokens, cost = estimate_cost_usd( latency_ms, total_chars, self.api_config.cost.cpu_hour_usd, self.api_config.cost.estimated_tokens_per_char, ) METRICS_STORE.record( RequestMetric( timestamp=now_iso(), session_id=session_id, latency_ms=latency_ms, input_chars=len(message), output_chars=len(text), estimated_tokens=tokens, estimated_cost_usd=cost, guardrail_blocks=guardrail_blocks, tools_used=tool_route.tools_used, inference_latency_ms=inference.latency_ms, ttft_ms=inference.ttft_ms, tbt_ms=inference.tbt_ms, input_tokens=inference.input_tokens, output_tokens=inference.output_tokens, tokens_per_sec=inference.tokens_per_sec, ) ) span.set_attribute("latency_ms", latency_ms) span.set_attribute("tools.count", len(tool_route.tools_used)) span.set_attribute("output.length", len(text)) self._record_inference_span_attrs(span, inference) if guardrail_blocks: span.set_attribute("guardrail.blocks", ",".join(guardrail_blocks)) span.set_attribute("guardrail.blocked", True) if tool_route.tools_used: span.set_attribute("tools.used", ",".join(tool_route.tools_used)) return ChatResponse( session_id=session_id, response=text, model=self.model_id, latency_ms=round(latency_ms, 1), estimated_tokens=tokens, estimated_cost_usd=cost, tools_used=tool_route.tools_used, guardrail_blocks=guardrail_blocks, )