Spaces:
Running
Running
| from __future__ import annotations | |
| import time | |
| from typing import Any | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langchain_core.messages import AIMessageChunk, BaseMessage | |
| from api.observability.inference_metrics import InferenceMetrics, compute_tokens_per_sec | |
| from llm.tokenizer_utils import count_prompt_tokens, count_text_tokens, get_llm_tokenizer | |
def _chunk_text(chunk: Any) -> str:
    """Return the text carried by a streamed chunk or a complete model response.

    Args:
        chunk: Any object. When it exposes a ``content`` attribute that value
            is used; otherwise the object itself is stringified.

    Returns:
        ``content`` unchanged when it is a string. When it is a list, the
        concatenation of its string items and of the ``"text"`` value of
        dict items whose ``"type"`` is ``"text"``. Any other content is
        passed through ``str()``, with falsy values becoming ``""``.
    """
    if not hasattr(chunk, "content"):
        return str(chunk)

    content = chunk.content
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        # Message content may be a list of blocks rather than a plain string.
        # Calling str() on that list would leak its Python repr into the
        # response text, so only the textual pieces are extracted.
        pieces: list[str] = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type") == "text"
                and isinstance(block.get("text"), str)
            ):
                pieces.append(block["text"])
        return "".join(pieces)

    return str(content or "")
def _build_inference_metrics(
    llm: BaseChatModel,
    response_text: str,
    *,
    latency_ms: float,
    ttft_ms: float | None,
    tbt_ms: float | None,
    input_tokens: int,
    stream_chunks: int,
) -> InferenceMetrics:
    """Count output tokens and package timings (rounded to 2 decimals) into metrics."""
    output_tokens = count_text_tokens(get_llm_tokenizer(llm), response_text)
    latency_rounded = round(latency_ms, 2)
    ttft_rounded = round(ttft_ms, 2) if ttft_ms is not None else None
    tbt_rounded = round(tbt_ms, 2) if tbt_ms is not None else None
    return InferenceMetrics(
        latency_ms=latency_rounded,
        ttft_ms=ttft_rounded,
        tbt_ms=tbt_rounded,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        stream_chunks=stream_chunks,
        tokens_per_sec=compute_tokens_per_sec(output_tokens, latency_rounded, ttft_rounded),
    )


def stream_llm_response(
    llm: BaseChatModel,
    messages: list[BaseMessage],
    *,
    input_tokens: int,
) -> tuple[str, InferenceMetrics]:
    """Stream model output and compute TTFT/TBT from real chunk arrival times.

    Chunks that carry no text are skipped and do not count towards the
    timings. When streaming is unavailable, or yields no text at all, the
    model is called once through ``llm.invoke``; in that case the duration of
    that single call is reported as both latency and TTFT, and TBT is ``None``.

    Args:
        llm: Chat model to query.
        messages: Messages passed unchanged to ``llm.stream`` / ``llm.invoke``.
        input_tokens: Prompt token count supplied by the caller; copied into
            the returned metrics as-is.

    Returns:
        A ``(response_text, metrics)`` tuple. Timings are in milliseconds.

    Raises:
        NotImplementedError: If the stream raises it after some text has
            already been received. Any other exception raised while
            iterating the stream, or by ``llm.invoke``, propagates unchanged.
    """
    inference_start = time.perf_counter()
    chunk_times: list[float] = []
    parts: list[str] = []

    try:
        stream = llm.stream(messages)
    except (NotImplementedError, TypeError, AttributeError):
        stream = None

    if stream is not None:
        try:
            for chunk in stream:
                text = _chunk_text(chunk)
                if not text:
                    continue
                chunk_times.append(time.perf_counter())
                parts.append(text)
        except NotImplementedError:
            # A generator-based ``stream`` runs no code until it is iterated,
            # so "streaming unsupported" surfaces here rather than at the
            # call above. Fall back to invoke only if nothing was received;
            # a failure after partial output is a real error.
            if parts:
                raise

    if not parts:
        # Non-streaming fallback: timed from its own start, so time spent on
        # the unsuccessful streaming attempt is not counted.
        batch_start = time.perf_counter()
        response_text = _chunk_text(llm.invoke(messages))
        latency_ms = (time.perf_counter() - batch_start) * 1000
        return response_text, _build_inference_metrics(
            llm,
            response_text,
            latency_ms=latency_ms,
            ttft_ms=latency_ms,
            tbt_ms=None,
            input_tokens=input_tokens,
            stream_chunks=1 if response_text else 0,
        )

    inference_end = time.perf_counter()
    response_text = "".join(parts)
    latency_ms = (inference_end - inference_start) * 1000

    # ``parts`` and ``chunk_times`` grow together, so a first timestamp exists here.
    ttft_ms = (chunk_times[0] - inference_start) * 1000

    tbt_ms: float | None = None
    if len(chunk_times) >= 2:
        # Mean gap between consecutive text-bearing chunks.
        gaps_ms = [
            (later - earlier) * 1000
            for earlier, later in zip(chunk_times, chunk_times[1:])
        ]
        tbt_ms = sum(gaps_ms) / len(gaps_ms)

    return response_text, _build_inference_metrics(
        llm,
        response_text,
        latency_ms=latency_ms,
        ttft_ms=ttft_ms,
        tbt_ms=tbt_ms,
        input_tokens=input_tokens,
        stream_chunks=len(chunk_times),
    )
def count_input_tokens_for_prompt(
    llm: BaseChatModel,
    prompt,
    *,
    model_input: str,
    history: list[BaseMessage],
) -> int:
    """Return the input token count for a prompt.

    Thin wrapper that forwards every argument unchanged to
    ``count_prompt_tokens`` and returns its result.

    Args:
        llm: Chat model, passed through as the first positional argument.
        prompt: Prompt object, passed through as the second positional argument.
        model_input: Forwarded as the ``model_input`` keyword.
        history: Forwarded as the ``history`` keyword.
    """
    token_total = count_prompt_tokens(
        llm,
        prompt,
        history=history,
        model_input=model_input,
    )
    return token_total