# src/agentic_multiwriter/models/llm_client.py
import os
import logging
from typing import Optional

from huggingface_hub import InferenceClient
from langchain_core.messages import SystemMessage, HumanMessage

try:
    # Modern LangChain + OpenAI
    from langchain_openai import ChatOpenAI
except ImportError:
    # Fallback for older setups
    try:
        from langchain.chat_models import ChatOpenAI  # type: ignore
    except ImportError:
        ChatOpenAI = None  # type: ignore

try:
    from langchain_ollama import ChatOllama
except ImportError:
    ChatOllama = None  # type: ignore

logger = logging.getLogger(__name__)


class LLMClient:
    """
    Unified LLM client.

    Providers:
      - openai      -> ChatOpenAI (gpt-4o-mini, etc.)
      - ollama      -> Local Ollama server (not used on HF Spaces)
      - hf_endpoint -> Hugging Face Inference API (backup / optional)

    Defaults:
      AMW_LLM_PROVIDER = "openai"
      AMW_LLM_MODEL    = "gpt-4o-mini"
      AMW_TEMPERATURE  = 0.3
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
    ) -> None:
        # ---------- Resolve configuration ----------
        self.provider = (provider or os.getenv("AMW_LLM_PROVIDER", "openai")).lower()
        # Use an explicit None check so that temperature=0.0 is honored
        # instead of falling through to the environment default.
        self.temperature = (
            temperature
            if temperature is not None
            else float(os.getenv("AMW_TEMPERATURE", "0.3"))
        )

        if model is not None:
            self.model = model
        else:
            if self.provider == "openai":
                self.model = os.getenv("AMW_LLM_MODEL", "gpt-4o-mini")
            elif self.provider == "ollama":
                self.model = os.getenv("AMW_LLM_MODEL", "llama3")
            elif self.provider == "hf_endpoint":
                # Only used if you deliberately switch to HF Inference
                self.model = os.getenv("AMW_LLM_MODEL", "gpt2")
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider}")

        logger.info(
            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
            self.provider,
            self.model,
            self.temperature,
        )

        # ---------- Initialize backend client ----------
        if self.provider == "openai":
            self._init_openai_client()
        elif self.provider == "ollama":
            self._init_ollama_client()
        elif self.provider == "hf_endpoint":
            self._init_hf_client()
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------
    def _init_openai_client(self) -> None:
        if ChatOpenAI is None:
            raise RuntimeError(
                "ChatOpenAI could not be imported. Make sure 'langchain-openai' "
                "is installed (e.g., `pip install langchain-openai`)."
            )

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.warning(
                "OPENAI_API_KEY is not set; OpenAI calls will fail until it is configured."
            )

        # ChatOpenAI reads OPENAI_API_KEY from the environment by default.
        self._client = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            # Do NOT pass the key explicitly; let ChatOpenAI read it from the env.
            # api_key=api_key  # (optional if you want to be explicit)
        )

    def _init_ollama_client(self) -> None:
        if ChatOllama is None:
            raise RuntimeError(
                "langchain_ollama is not installed, but provider='ollama' was selected."
            )

        self._client = ChatOllama(
            model=self.model,
            temperature=self.temperature,
        )

    def _init_hf_client(self) -> None:
        """
        Optional: Hugging Face Inference client (not used if you stay on OpenAI).

        Uses HUGGINGFACEHUB_API_TOKEN from the environment, which is
        automatically set inside your own Space if you define it as a secret.
        """
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            logger.warning(
                "HUGGINGFACEHUB_API_TOKEN is not set. HF Inference calls will fail "
                "unless the environment injects the token (e.g., in a HF Space)."
            )
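        # NOTE: token=None is acceptable here; huggingface_hub then falls back
        # to any locally cached credentials (e.g., from `huggingface-cli login`),
        # which is why a missing env token is only a warning, not a hard error.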
        self._client = InferenceClient(
            model=self.model,
            token=hf_token,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate text from the configured model.
        """
        if self.provider in ("openai", "ollama"):
            return self._generate_chat_model(system_prompt, user_prompt)
        elif self.provider == "hf_endpoint":
            return self._generate_hf_text(system_prompt, user_prompt)
        else:
            raise ValueError(f"Unsupported provider in generate(): {self.provider}")

    # ------------------------------------------------------------------
    # OpenAI / Ollama (chat-style models via LangChain)
    # ------------------------------------------------------------------
    def _generate_chat_model(self, system_prompt: str, user_prompt: str) -> str:
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        resp = self._client.invoke(messages)  # type: ignore[attr-defined]

        # LangChain chat models return an AIMessage whose `.content` holds the text.
        text = getattr(resp, "content", None)
        if not isinstance(text, str):
            text = str(resp)
        return text

    # ------------------------------------------------------------------
    # Hugging Face Inference (text-generation; optional)
    # ------------------------------------------------------------------
    def _generate_hf_text(self, system_prompt: str, user_prompt: str) -> str:
        """
        Use Hugging Face Inference `text_generation`.
        Only used if AMW_LLM_PROVIDER=hf_endpoint.
        """
        prompt = (
            f"<system>\n{system_prompt}\n</system>\n\n"
            f"<user>\n{user_prompt}\n</user>\n\n"
            "Assistant:"
        )

        try:
            text = self._client.text_generation(
                prompt,
                max_new_tokens=512,
                temperature=self.temperature,
                do_sample=True,
                top_p=0.9,
                return_full_text=False,
            )
        except Exception as e:  # noqa: BLE001
            logger.error(
                "Error while calling Hugging Face Inference API for model '%s': %s",
                self.model,
                e,
                exc_info=True,
            )
            raise RuntimeError(
                f"Hugging Face Inference error for model '{self.model}'. "
                f"Ensure the model supports 'text-generation' and that your token "
                f"has Inference permissions."
            ) from e

        if isinstance(text, str):
            return text

        # Defensive fallback in case the client returns a structured object.
        try:
            return text.get("generated_text", str(text))  # type: ignore[arg-type]
        except Exception:  # noqa: BLE001
            return str(text)
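
# ----------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): exercises the
# default OpenAI path. Assumes OPENAI_API_KEY is set in the environment;
# the prompts below are placeholders.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    client = LLMClient()  # provider/model/temperature resolved from env vars
    reply = client.generate(
        system_prompt="You are a concise technical assistant.",
        user_prompt="In one sentence, what does a unified LLM client do?",
    )
    print(reply)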