ollive-api / llm /streaming_invoke.py
Karthik Namboori
Deploy ollive FastAPI Docker Space
7b4b748
from __future__ import annotations
import time
from typing import Any
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage
from api.observability.inference_metrics import InferenceMetrics, compute_tokens_per_sec
from llm.tokenizer_utils import count_prompt_tokens, count_text_tokens, get_llm_tokenizer
def _chunk_text(chunk: Any) -> str:
    """Return the plain text carried by a streamed chunk or a full message.

    Args:
        chunk: A message-like object exposing ``content`` (for example an
            ``AIMessageChunk``), or any other value.

    Returns:
        * ``content`` unchanged when it is a string;
        * the concatenated text of a list-valued ``content`` (plain string
          items plus ``{"type": "text", "text": ...}`` blocks; other block
          kinds are ignored);
        * ``""`` for any other falsy ``content``, else ``str(content)``;
        * ``str(chunk)`` when the object has no ``content`` attribute.
    """
    if not hasattr(chunk, "content"):
        return str(chunk)

    content = chunk.content
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        # Content-block form: calling str() on the list would leak its Python
        # repr into the response text, so keep only the textual pieces.
        pieces: list[str] = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                block_text = block.get("text")
                if isinstance(block_text, str):
                    pieces.append(block_text)
        return "".join(pieces)

    return str(content or "")
def stream_llm_response(
    llm: BaseChatModel,
    messages: list[BaseMessage],
    *,
    input_tokens: int,
) -> tuple[str, InferenceMetrics]:
    """Stream model output and compute TTFT/TBT from real chunk arrival times.

    The model is streamed first. Every chunk that yields non-empty text is
    timestamped on arrival; TTFT is the delay from the start of the call to
    the first such chunk and TBT is the mean gap between consecutive ones.
    If streaming cannot be started, or it produces no text at all, the model
    is called once through ``invoke`` instead; in that case TTFT equals the
    latency of the ``invoke`` call and TBT is ``None``.

    Args:
        llm: Chat model to query.
        messages: Conversation passed to the model.
        input_tokens: Pre-computed prompt token count, copied into the metrics.

    Returns:
        A ``(response_text, metrics)`` tuple. All durations are in
        milliseconds, rounded to two decimals.

    Raises:
        Whatever the model raises mid-stream or from ``invoke``. Only
        ``NotImplementedError``/``TypeError``/``AttributeError`` raised while
        starting the stream (up to and including the first chunk) are
        swallowed in favour of the ``invoke`` fallback.
    """
    inference_start = time.perf_counter()
    chunk_times: list[float] = []
    parts: list[str] = []

    # ``stream()`` may be a lazy generator: calling it then never raises, and
    # an "unsupported" error only surfaces on the first ``next()``. Pull the
    # first chunk inside the guard so the fallback is actually reachable.
    exhausted = object()
    stream: Any = None
    pending: Any = exhausted
    try:
        stream = iter(llm.stream(messages))
        pending = next(stream, exhausted)
    except (NotImplementedError, TypeError, AttributeError):
        pending = exhausted

    # Errors after the first chunk are real failures and propagate unchanged.
    while pending is not exhausted:
        text = _chunk_text(pending)
        if text:
            chunk_times.append(time.perf_counter())
            parts.append(text)
        pending = next(stream, exhausted)

    if not parts:
        # Streaming unavailable or empty: time a single blocking call instead.
        batch_start = time.perf_counter()
        batch = llm.invoke(messages)
        response_text = _chunk_text(batch)
        inference_end = time.perf_counter()
        latency_rounded = round((inference_end - batch_start) * 1000, 2)
        # The whole answer arrives at once, so first-token time == latency.
        ttft_rounded = latency_rounded
        tokenizer = get_llm_tokenizer(llm)
        output_tokens = count_text_tokens(tokenizer, response_text)
        return response_text, InferenceMetrics(
            latency_ms=latency_rounded,
            ttft_ms=ttft_rounded,
            tbt_ms=None,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            stream_chunks=1 if response_text else 0,
            tokens_per_sec=compute_tokens_per_sec(output_tokens, latency_rounded, ttft_rounded),
        )

    inference_end = time.perf_counter()
    response_text = "".join(parts)
    latency_rounded = round((inference_end - inference_start) * 1000, 2)
    # ``parts`` is non-empty here, so at least one arrival time was recorded.
    ttft_rounded = round((chunk_times[0] - inference_start) * 1000, 2)

    tbt_rounded: float | None = None
    if len(chunk_times) >= 2:
        gaps_ms = [
            (later - earlier) * 1000
            for earlier, later in zip(chunk_times, chunk_times[1:])
        ]
        tbt_rounded = round(sum(gaps_ms) / len(gaps_ms), 2)

    tokenizer = get_llm_tokenizer(llm)
    output_tokens = count_text_tokens(tokenizer, response_text)
    return response_text, InferenceMetrics(
        latency_ms=latency_rounded,
        ttft_ms=ttft_rounded,
        tbt_ms=tbt_rounded,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        stream_chunks=len(chunk_times),
        tokens_per_sec=compute_tokens_per_sec(output_tokens, latency_rounded, ttft_rounded),
    )
def count_input_tokens_for_prompt(
    llm: BaseChatModel,
    prompt,
    *,
    model_input: str,
    history: list[BaseMessage],
) -> int:
    """Return the prompt token count computed by ``count_prompt_tokens``.

    Thin pass-through: every argument is forwarded unchanged.

    Args:
        llm: Chat model forwarded as the first positional argument.
        prompt: Prompt object forwarded as the second positional argument.
        model_input: Forwarded as the ``model_input`` keyword.
        history: Forwarded as the ``history`` keyword.

    Returns:
        The value returned by ``count_prompt_tokens``.
    """
    token_total = count_prompt_tokens(
        llm,
        prompt,
        model_input=model_input,
        history=history,
    )
    return token_total