# src/agentic_multiwriter/models/llm_client.py
import os
import logging
from typing import Optional

from huggingface_hub import InferenceClient
from langchain_core.messages import SystemMessage, HumanMessage

try:
    # Modern LangChain + OpenAI
    from langchain_openai import ChatOpenAI
except ImportError:  # Fallback for older setups
    try:
        from langchain.chat_models import ChatOpenAI  # type: ignore
    except ImportError:
        ChatOpenAI = None  # type: ignore

try:
    from langchain_ollama import ChatOllama
except ImportError:
    ChatOllama = None  # type: ignore
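
# The import ladder above keeps this module importable even when only one
# provider's integration package is installed; a missing backend then fails
# with a clear error at client-init time instead of at import time.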

logger = logging.getLogger(__name__)


class LLMClient:
    """
    Unified LLM client.

    Providers:
        - openai      -> ChatOpenAI (gpt-4o-mini, etc.)
        - ollama      -> Local Ollama server (not used on HF Spaces)
        - hf_endpoint -> Hugging Face Inference API (backup / optional)

    Defaults:
        AMW_LLM_PROVIDER = "openai"
        AMW_LLM_MODEL    = "gpt-4o-mini"
        AMW_TEMPERATURE  = 0.3
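
    Example (illustrative; assumes OPENAI_API_KEY is set):
        client = LLMClient()  # defaults to openai / gpt-4o-mini
        text = client.generate("You are concise.", "Explain RAG in one sentence.")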
| """ | |
| def __init__( | |
| self, | |
| provider: Optional[str] = None, | |
| model: Optional[str] = None, | |
| temperature: Optional[float] = None, | |
| ) -> None: | |
| # ---------- Resolve configuration ---------- | |
        self.provider = (provider or os.getenv("AMW_LLM_PROVIDER", "openai")).lower()
        # Explicit None check so a caller-supplied temperature of 0.0 is
        # honored instead of silently falling back to the env default.
        self.temperature = float(
            temperature if temperature is not None else os.getenv("AMW_TEMPERATURE", "0.3")
        )
        if model is not None:
            self.model = model
        else:
            if self.provider == "openai":
                self.model = os.getenv("AMW_LLM_MODEL", "gpt-4o-mini")
            elif self.provider == "ollama":
                self.model = os.getenv("AMW_LLM_MODEL", "llama3")
            elif self.provider == "hf_endpoint":
                # Only used if you deliberately switch to HF Inference
                self.model = os.getenv("AMW_LLM_MODEL", "gpt2")
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider}")

        logger.info(
            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
            self.provider,
            self.model,
            self.temperature,
        )

        # ---------- Initialize backend client ----------
        if self.provider == "openai":
            self._init_openai_client()
        elif self.provider == "ollama":
            self._init_ollama_client()
        elif self.provider == "hf_endpoint":
            self._init_hf_client()
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------
    def _init_openai_client(self) -> None:
        if ChatOpenAI is None:
            raise RuntimeError(
                "ChatOpenAI could not be imported. Make sure 'langchain-openai' "
                "is installed (e.g., `pip install langchain-openai`)."
            )
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.warning(
                "OPENAI_API_KEY is not set; OpenAI calls will fail until it is configured."
            )
        # ChatOpenAI reads OPENAI_API_KEY from the environment by default.
        self._client = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            # Do NOT pass the key explicitly; let it read from env.
            # api_key=api_key  # (optional if you want to be explicit)
        )

    def _init_ollama_client(self) -> None:
        if ChatOllama is None:
            raise RuntimeError(
                "langchain_ollama is not installed, but provider='ollama' was selected."
            )
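        # ChatOllama targets a local Ollama server (http://localhost:11434 by
        # default); pass base_url= if your server runs elsewhere.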
        self._client = ChatOllama(
            model=self.model,
            temperature=self.temperature,
        )

    def _init_hf_client(self) -> None:
        """
        Optional: Hugging Face Inference client (not used if you stay on OpenAI).
        Uses HUGGINGFACEHUB_API_TOKEN from env, which is automatically set
        inside your own Space if you define it as a secret.
        """
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            logger.warning(
                "HUGGINGFACEHUB_API_TOKEN is not set. HF Inference calls will fail "
                "unless the environment injects the token (e.g., in a HF Space)."
            )
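        # `model` may be a Hub model id (served via the serverless Inference
        # API) or the URL of a dedicated Inference Endpoint.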
        self._client = InferenceClient(
            model=self.model,
            token=hf_token,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate text from the configured model.
        """
        if self.provider in ("openai", "ollama"):
            return self._generate_chat_model(system_prompt, user_prompt)
        elif self.provider == "hf_endpoint":
            return self._generate_hf_text(system_prompt, user_prompt)
        else:
            raise ValueError(f"Unsupported provider in generate(): {self.provider}")

    # ------------------------------------------------------------------
    # OpenAI / Ollama (chat-style models via LangChain)
    # ------------------------------------------------------------------
    def _generate_chat_model(self, system_prompt: str, user_prompt: str) -> str:
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]
        resp = self._client.invoke(messages)  # type: ignore[attr-defined]
        # LangChain chat models return an AIMessage; its `.content` is usually
        # a plain string, so fall back to str(resp) for anything else.
        text = getattr(resp, "content", None)
        if not isinstance(text, str):
            text = str(resp)
        return text
    # ------------------------------------------------------------------
    # Hugging Face Inference (text-generation; optional)
    # ------------------------------------------------------------------
    def _generate_hf_text(self, system_prompt: str, user_prompt: str) -> str:
        """
        Use Hugging Face Inference `text_generation`.
        Only used if AMW_LLM_PROVIDER=hf_endpoint.
        """
        prompt = (
            f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
            f"<<USER>>\n{user_prompt}\n<</USER>>\n\n"
            "Assistant:"
        )
        try:
            text = self._client.text_generation(
                prompt,
                max_new_tokens=512,
                temperature=self.temperature,
                do_sample=True,
                top_p=0.9,
                return_full_text=False,
            )
        except Exception as e:  # noqa: BLE001
            logger.error(
                "Error while calling Hugging Face Inference API for model '%s': %s",
                self.model,
                e,
                exc_info=True,
            )
            raise RuntimeError(
                f"Hugging Face Inference error for model '{self.model}'. "
                f"Ensure the model supports 'text-generation' and that your token "
                f"has Inference permissions."
            ) from e
        if isinstance(text, str):
            return text
        # Some client versions / parameter combinations (e.g., details=True)
        # return an object or dict rather than a plain string; pull out the
        # generated text defensively.
        generated = getattr(text, "generated_text", None)
        if isinstance(generated, str):
            return generated
        try:
            return text.get("generated_text", str(text))  # type: ignore[arg-type]
        except Exception:  # noqa: BLE001
            return str(text)
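

# ----------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the library API).
# Assumes OPENAI_API_KEY is exported; adjust provider/model as needed.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = LLMClient(provider="openai", model="gpt-4o-mini", temperature=0.3)
    print(client.generate("You are a terse assistant.", "Say hello in five words."))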