Update LLM Client
src/agentic_multiwriter/models/llm_client.py
CHANGED
@@ -1,77 +1,118 @@
+# src/agentic_multiwriter/models/llm_client.py
 from __future__ import annotations

-…
+import os
+from dataclasses import dataclass
+from typing import Literal

-from …
-from langchain_openai import ChatOpenAI
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_community.chat_models import ChatOllama
 from langchain_community.llms import HuggingFaceEndpoint
+from langchain_openai import ChatOpenAI

-from …
-from ..tools import get_logger
+from agentic_multiwriter.tools import get_logger

 logger = get_logger()


+@dataclass
+class LLMSettings:
+    """Configuration for the LLM backend."""
+    llm_provider: str = os.getenv("AMW_LLM_PROVIDER", "ollama").lower()
+    llm_model: str = os.getenv("AMW_LLM_MODEL", "llama3")
+    temperature: float = float(os.getenv("AMW_TEMPERATURE", "0.4"))
+
+
 class LLMClient:
     """
-…
-    - …
-    - hf_endpoint …
+    Small wrapper around different LLM backends:
+
+    - provider="ollama"      -> local ChatOllama
+    - provider="hf_endpoint" -> Hugging Face Inference endpoint
+    - provider="openai"      -> OpenAI Chat model
     """

-    def __init__(
-        self,
-        provider …
-        model …
-        temperature …
-    ) -> None:
-        self.provider = (provider or settings.llm_provider).lower()
-        self.model = model or settings.llm_model
-        self.temperature = temperature if temperature is not None else settings.temperature
-
-        settings.validate()
-        self._init_client()
-        logger.info(
-            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
-            self.provider,
-            self.model,
-            self.temperature,
-        )
+    def __init__(self, settings: LLMSettings | None = None) -> None:
+        self.settings = settings or LLMSettings()
+        provider = self.settings.llm_provider
+        model = self.settings.llm_model
+        temperature = self.settings.temperature

-…
+        if provider == "ollama":
+            self._mode: Literal["ollama", "hf_endpoint", "openai"] = "ollama"
             self._client = ChatOllama(
-            model=…
-…
-        elif …
-            # Uses …
-…
+                model=model,
+                temperature=temperature,
+            )
+
+        elif provider == "hf_endpoint":
+            # Uses HuggingFace Hosted Inference API (text-generation)
+            token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+            if not token:
+                raise RuntimeError(
+                    "HUGGINGFACEHUB_API_TOKEN is required when AMW_LLM_PROVIDER=hf_endpoint"
+                )
+
+            self._mode = "hf_endpoint"
+            # IMPORTANT: do NOT pass our own client object here; let
+            # HuggingFaceEndpoint build the correct internal client.
             self._client = HuggingFaceEndpoint(
-            repo_id=…
-…
+                repo_id=model,
+                huggingfacehub_api_token=token,
+                temperature=temperature,
+                max_new_tokens=800,
             )
+
+        elif provider == "openai":
+            api_key = os.getenv("OPENAI_API_KEY")
+            if not api_key:
+                raise RuntimeError(
+                    "OPENAI_API_KEY is required when AMW_LLM_PROVIDER=openai"
+                )
+
+            self._mode = "openai"
+            self._client = ChatOpenAI(
+                model=model,
+                temperature=temperature,
+                api_key=api_key,
+            )
+
         else:
-            raise ValueError(f"…
+            raise ValueError(f"Unknown AMW_LLM_PROVIDER='{provider}'")
+
+        logger.info(
+            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
+            provider,
+            model,
+            temperature,
+        )

-    def generate(self, system_prompt: str, user_prompt: str) -> str:
-…
+    def generate(self, *, system_prompt: str, user_prompt: str) -> str:
+        """
+        Unified generate() interface for all providers.
+        """
+        system_prompt = system_prompt.strip()
+        user_prompt = user_prompt.strip()
+
+        if self._mode in ("ollama", "openai"):
+            # Chat-style models (ChatOllama / ChatOpenAI)
             messages = [
-…
+                ("system", system_prompt),
+                ("user", user_prompt),
             ]
             response = self._client.invoke(messages)
-        return …
+            # Both ChatOllama and ChatOpenAI return an object with `.content`
+            return getattr(response, "content", str(response))

-…
-        # HuggingFaceEndpoint …
-…
+        elif self._mode == "hf_endpoint":
+            # HuggingFaceEndpoint expects a single text prompt
+            prompt = (
+                f"{system_prompt}\n\n"
+                f"User:\n{user_prompt}\n\n"
+                f"Assistant:"
+            )
+            text = self._client.invoke(prompt)
+            # HuggingFaceEndpoint typically returns raw text
+            return text.strip() if isinstance(text, str) else str(text)

-…
+        else:
+            raise RuntimeError("Unsupported LLM provider mode")
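For reference, a minimal usage sketch of the updated client (not part of the commit). It assumes the package is importable as agentic_multiwriter and that the chosen backend is reachable: a local Ollama server for "ollama", or valid credentials for "hf_endpoint"/"openai". The OpenAI model name below is illustrative only.

import os

# LLMSettings evaluates os.getenv() once, when the class body is executed,
# so the AMW_* variables must be set before the module is first imported.
os.environ.setdefault("AMW_LLM_PROVIDER", "ollama")
os.environ.setdefault("AMW_LLM_MODEL", "llama3")

from agentic_multiwriter.models.llm_client import LLMClient, LLMSettings

# Option 1: environment-driven defaults, as a deployed Space would configure it.
client = LLMClient()

# Option 2: an explicit LLMSettings object, bypassing the environment.
# "gpt-4o-mini" is an illustrative model name; OPENAI_API_KEY must be set.
client = LLMClient(
    settings=LLMSettings(
        llm_provider="openai",
        llm_model="gpt-4o-mini",
        temperature=0.2,
    )
)

# generate() takes keyword-only arguments and returns plain text for all providers.
text = client.generate(
    system_prompt="You are a concise technical writer.",
    user_prompt="Explain in one paragraph what this client wrapper does.",
)
print(text)

Keeping the environment-variable defaults on the dataclass lets a deployment switch providers without code changes, while the optional settings argument keeps tests and direct callers explicit.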