| """ |
| LLM client using HuggingFace Inference API. |
| Uses InferenceClient which supports streaming chat completions. |
| """ |
|
|
| from typing import Optional, Generator |
| from utils.dataset_io import get_hf_token |
| from utils.constants import DEFAULT_CHAT_MODEL |
|
|
|
|
def chat_completion(
    messages: list[dict],
    system_prompt: Optional[str] = None,
    model: str = DEFAULT_CHAT_MODEL,
    max_tokens: int = 512,
    temperature: float = 0.7,
    stream: bool = False,
) -> str | Generator:
    """
    Run one chat completion against the HF Inference API.

    A missing ``huggingface_hub`` package and any exception raised while
    creating the client or performing the request are reported through the
    return value (as error text) rather than raised.

    Args:
        messages: conversation turns as {role, content} dicts
        system_prompt: when truthy, inserted as a leading system message
        model: HF model ID handed to InferenceClient
        max_tokens: cap on the number of generated tokens
        temperature: sampling temperature
        stream: truthy to request a streamed response

    Returns:
        stream falsy: the reply text ("" when the reply content is empty),
        or the error text on failure.
        stream truthy: whatever the client returns for a streamed request,
        or a one-item generator yielding the error text on failure.
    """

    def _failure(text: str):
        # Keep the result shape aligned with the requested mode:
        # something iterable when streaming, a plain string otherwise.
        if stream:
            return (item for item in [text])
        return text

    # Imported lazily so a missing dependency becomes a readable message.
    try:
        from huggingface_hub import InferenceClient
    except ImportError:
        return _failure(
            "ERROR: huggingface_hub not installed. Run: pip install huggingface_hub"
        )

    hf_token = get_hf_token()

    # Build a fresh list so the caller's `messages` is never mutated.
    system_turn = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    conversation: list[dict] = [*system_turn, *messages]

    try:
        client = InferenceClient(model=model, token=hf_token)
        request = {
            "messages": conversation,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if not stream:
            reply = client.chat_completion(stream=False, **request)
            return reply.choices[0].message.content or ""

        return client.chat_completion(stream=True, **request)
    except Exception as exc:
        return _failure(f"LLM API error: {exc}")
|
|
|
|
def stream_to_string(stream_gen) -> Generator[str, None, None]:
    """
    Yield incremental text from a streaming chat_completion response.

    Each yielded value is the full accumulated string so far (for Streamlit's
    write), not just the newest fragment.

    Args:
        stream_gen: iterable of stream chunks. A chunk is either an object
            exposing ``choices[0].delta.content`` or a plain string. Plain
            strings are how chat_completion(stream=True) reports failures.

    Yields:
        The accumulated text after every chunk that contributed new text.
        Chunks without text (no choices, no content, empty content) yield
        nothing. If iterating ``stream_gen`` raises, a final value with a
        "[Stream error: ...]" suffix is yielded instead of propagating the
        exception.
    """
    full_text = ""
    try:
        for chunk in stream_gen:
            if isinstance(chunk, str):
                # chat_completion returns its error message as a generator of
                # plain strings when streaming; without this branch that
                # message would be dropped and the caller would see nothing.
                piece = chunk
            else:
                choices = getattr(chunk, "choices", None)
                if not choices:
                    continue
                piece = getattr(choices[0].delta, "content", None)

            if piece:
                full_text += piece
                yield full_text
    except Exception as e:
        # Surface mid-stream failures inline, keeping the text received so far.
        full_text += f"\n\n[Stream error: {e}]"
        yield full_text
|
|