# src/agentic_multiwriter/models/llm_client.py
import os
import logging
from typing import Optional

from huggingface_hub import InferenceClient
from langchain_core.messages import SystemMessage, HumanMessage

try:
    # Modern LangChain + OpenAI
    from langchain_openai import ChatOpenAI
except ImportError:  # Fallback for older setups
    try:
        from langchain.chat_models import ChatOpenAI  # type: ignore
    except ImportError:
        ChatOpenAI = None  # type: ignore

try:
    from langchain_ollama import ChatOllama
except ImportError:
    ChatOllama = None  # type: ignore

logger = logging.getLogger(__name__)

class LLMClient:
    """
    Unified LLM client.

    Providers:
      - openai      -> ChatOpenAI (gpt-4o-mini, etc.)
      - ollama      -> local Ollama server (not used on HF Spaces)
      - hf_endpoint -> Hugging Face Inference API (optional backup)

    Defaults (overridable via environment variables):
      AMW_LLM_PROVIDER = "openai"
      AMW_LLM_MODEL    = "gpt-4o-mini"
      AMW_TEMPERATURE  = 0.3

    See the `__main__` sketch at the bottom of this module for a usage example.
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
    ) -> None:
        # ---------- Resolve configuration ----------
        self.provider = (provider or os.getenv("AMW_LLM_PROVIDER", "openai")).lower()
        # Explicit None check so a caller-supplied temperature of 0.0 is
        # honored instead of falling through to the environment default.
        self.temperature = float(
            temperature
            if temperature is not None
            else os.getenv("AMW_TEMPERATURE", "0.3")
        )

        if model is not None:
            self.model = model
        else:
            if self.provider == "openai":
                self.model = os.getenv("AMW_LLM_MODEL", "gpt-4o-mini")
            elif self.provider == "ollama":
                self.model = os.getenv("AMW_LLM_MODEL", "llama3")
            elif self.provider == "hf_endpoint":
                # Only used if you deliberately switch to HF Inference
                self.model = os.getenv("AMW_LLM_MODEL", "gpt2")
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider}")

        logger.info(
            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
            self.provider,
            self.model,
            self.temperature,
        )

        # ---------- Initialize backend client ----------
        if self.provider == "openai":
            self._init_openai_client()
        elif self.provider == "ollama":
            self._init_ollama_client()
        elif self.provider == "hf_endpoint":
            self._init_hf_client()
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------
    def _init_openai_client(self) -> None:
        if ChatOpenAI is None:
            raise RuntimeError(
                "ChatOpenAI could not be imported. Make sure 'langchain-openai' "
                "is installed (e.g., `pip install langchain-openai`)."
            )
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.warning(
                "OPENAI_API_KEY is not set; OpenAI calls will fail until it is configured."
            )
        # ChatOpenAI reads OPENAI_API_KEY from the environment by default,
        # so the key is deliberately not passed in here.
        self._client = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            # api_key=api_key  # (optional, if you prefer to be explicit)
        )

    def _init_ollama_client(self) -> None:
        if ChatOllama is None:
            raise RuntimeError(
                "langchain_ollama is not installed, but provider='ollama' was selected."
            )
        self._client = ChatOllama(
            model=self.model,
            temperature=self.temperature,
        )

    def _init_hf_client(self) -> None:
        """
        Optional: Hugging Face Inference client (unused as long as you stay
        on OpenAI). Reads HUGGINGFACEHUB_API_TOKEN from the environment; in
        your own Space it is injected automatically if you define it as a
        secret.
        """
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            logger.warning(
                "HUGGINGFACEHUB_API_TOKEN is not set. HF Inference calls will fail "
                "unless the environment injects the token (e.g., in a HF Space)."
            )
        self._client = InferenceClient(
            model=self.model,
            token=hf_token,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate text from the configured model.
        """
        if self.provider in ("openai", "ollama"):
            return self._generate_chat_model(system_prompt, user_prompt)
        elif self.provider == "hf_endpoint":
            return self._generate_hf_text(system_prompt, user_prompt)
        else:
            raise ValueError(f"Unsupported provider in generate(): {self.provider}")

    # ------------------------------------------------------------------
    # OpenAI / Ollama (chat-style models via LangChain)
    # ------------------------------------------------------------------
    def _generate_chat_model(self, system_prompt: str, user_prompt: str) -> str:
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        resp = self._client.invoke(messages)  # type: ignore[attr-defined]
        # LangChain chat models return an AIMessage; the text lives in `.content`.
        text = getattr(resp, "content", None)
        if not isinstance(text, str):
            text = str(resp)
        return text

    # ------------------------------------------------------------------
    # Hugging Face Inference (text-generation; optional)
    # ------------------------------------------------------------------
    def _generate_hf_text(self, system_prompt: str, user_prompt: str) -> str:
        """
        Use Hugging Face Inference `text_generation`.
        Only used if AMW_LLM_PROVIDER=hf_endpoint.
        """
        prompt = (
            f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
            f"<<USER>>\n{user_prompt}\n<</USER>>\n\n"
            "Assistant:"
        )
        try:
            text = self._client.text_generation(
                prompt,
                max_new_tokens=512,
                temperature=self.temperature,
                do_sample=True,
                top_p=0.9,
                return_full_text=False,
            )
        except Exception as e:  # noqa: BLE001
            logger.error(
                "Error while calling Hugging Face Inference API for model '%s': %s",
                self.model,
                e,
                exc_info=True,
            )
            raise RuntimeError(
                f"Hugging Face Inference error for model '{self.model}'. "
                f"Ensure the model supports 'text-generation' and that your token "
                f"has Inference permissions."
            ) from e

        if isinstance(text, str):
            return text
        # Some client versions return a dict-like payload; fall back gracefully.
        try:
            return text.get("generated_text", str(text))  # type: ignore[arg-type]
        except Exception:  # noqa: BLE001
            return str(text)
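

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API. It assumes the
    # default OpenAI provider and an OPENAI_API_KEY in the environment; the
    # prompts below are illustrative placeholders.
    logging.basicConfig(level=logging.INFO)
    client = LLMClient()  # resolves provider/model/temperature from env vars
    print(
        client.generate(
            system_prompt="You are a concise technical writer.",
            user_prompt="Explain in one sentence what a unified LLM client does.",
        )
    )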