llm-annotation-platform / llm_client.py
GitHub Actions
Sync from GitHub
0be7900
"""
LLM client using HuggingFace Inference API.
Uses InferenceClient which supports streaming chat completions.
"""
from typing import Optional, Generator
from utils.dataset_io import get_hf_token
from utils.constants import DEFAULT_CHAT_MODEL
def chat_completion(
    messages: list[dict],
    system_prompt: Optional[str] = None,
    model: str = DEFAULT_CHAT_MODEL,
    max_tokens: int = 512,
    temperature: float = 0.7,
    stream: bool = False,
) -> str | Generator:
    """
    Run a single chat-completion request through the HF Inference API.

    Problems with the request are reported as text rather than raised: if
    ``huggingface_hub`` cannot be imported, or anything fails while creating
    the client or issuing the call, an error message is returned in the same
    "shape" the caller asked for (a string, or a generator yielding that one
    string when ``stream`` is true). Exceptions from the token lookup are
    not intercepted.

    Args:
        messages: conversation turns as ``{role, content}`` dicts.
        system_prompt: when non-empty, sent first as a ``system`` message.
        model: HF model ID handed to ``InferenceClient``.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature.
        stream: request a streaming response instead of a finished string.

    Returns:
        ``stream=False``: the first choice's message content ("" if empty).
        ``stream=True``: the streaming object returned by
        ``InferenceClient.chat_completion``.
        On failure: the error text, wrapped as described above.
    """

    def _report(problem: str):
        # Keep the failure value iterable when the caller expects a stream.
        if stream:
            return (part for part in [problem])
        return problem

    # Imported lazily so a missing dependency becomes a readable message.
    try:
        from huggingface_hub import InferenceClient  # noqa: PLC0415
    except ImportError:
        return _report(
            "ERROR: huggingface_hub not installed. Run: pip install huggingface_hub"
        )

    hf_token = get_hf_token()

    # Optional system message first, then the caller's turns in order.
    conversation: list[dict] = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    conversation.extend(messages)

    # Arguments shared by the streaming and non-streaming requests.
    request_args = {
        "messages": conversation,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }

    try:
        hf_client = InferenceClient(model=model, token=hf_token)
        if not stream:
            reply = hf_client.chat_completion(**request_args, stream=False)
            return reply.choices[0].message.content or ""
        return hf_client.chat_completion(**request_args, stream=True)
    except Exception as exc:
        return _report(f"LLM API error: {exc}")
def stream_to_string(stream_gen) -> Generator[str, None, None]:
    """
    Yield incremental text from a streaming chat_completion response.

    Each yielded value is the full accumulated string so far (for Streamlit's
    write), not just the newest fragment.

    Args:
        stream_gen: iterable of stream chunks. Two chunk kinds are understood:
            objects exposing ``choices[0].delta.content``, and plain strings,
            which is what ``chat_completion(stream=True)`` yields when it
            reports an error. Any other chunk is ignored.

    Yields:
        The accumulated text after every chunk that contributed non-empty
        text. If iterating ``stream_gen`` raises, a final value with
        ``[Stream error: ...]`` appended is yielded instead of propagating
        the exception.
    """
    full_text = ""
    try:
        for chunk in stream_gen:
            if isinstance(chunk, str):
                # Error messages arrive as bare strings; without this branch
                # they would be dropped and the caller would see nothing.
                piece = chunk
            else:
                choices = getattr(chunk, "choices", None)
                if not choices:
                    continue
                piece = getattr(choices[0].delta, "content", None)
            if piece:
                full_text += piece
                yield full_text
    except Exception as e:
        # Best-effort: surface the failure in the rendered text rather than
        # aborting the consumer mid-stream.
        full_text += f"\n\n[Stream error: {e}]"
        yield full_text