KheemDH committed · verified
Commit cafd75d · 1 Parent(s): ffe08d3

Update LLM Client

src/agentic_multiwriter/models/llm_client.py CHANGED
@@ -1,77 +1,118 @@
+# src/agentic_multiwriter/models/llm_client.py
 from __future__ import annotations
 
-from typing import Optional
+import os
+from dataclasses import dataclass
+from typing import Literal
 
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_community.chat_models import ChatOllama
 from langchain_community.llms import HuggingFaceEndpoint
+from langchain_openai import ChatOpenAI
 
-from ..config import settings
-from ..tools import get_logger
+from agentic_multiwriter.tools import get_logger
 
 logger = get_logger()
 
 
+@dataclass
+class LLMSettings:
+    """Configuration for the LLM backend."""
+    llm_provider: str = os.getenv("AMW_LLM_PROVIDER", "ollama").lower()
+    llm_model: str = os.getenv("AMW_LLM_MODEL", "llama3")
+    temperature: float = float(os.getenv("AMW_TEMPERATURE", "0.4"))
+
+
 class LLMClient:
     """
-    Wrapper to abstract over multiple LLM providers:
-    - ollama: local Ollama server
-    - openai: OpenAI Chat API
-    - hf_endpoint: Hugging Face Inference Endpoint / hosted model
+    Small wrapper around different LLM backends:
+
+    - provider="ollama" -> local ChatOllama
+    - provider="hf_endpoint" -> Hugging Face Inference endpoint
+    - provider="openai" -> OpenAI Chat model
     """
 
-    def __init__(
-        self,
-        provider: Optional[str] = None,
-        model: Optional[str] = None,
-        temperature: Optional[float] = None,
-    ) -> None:
-        self.provider = (provider or settings.llm_provider).lower()
-        self.model = model or settings.llm_model
-        self.temperature = temperature if temperature is not None else settings.temperature
-
-        settings.validate()
-        self._init_client()
-        logger.info(
-            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
-            self.provider,
-            self.model,
-            self.temperature,
-        )
+    def __init__(self, settings: LLMSettings | None = None) -> None:
+        self.settings = settings or LLMSettings()
+        provider = self.settings.llm_provider
+        model = self.settings.llm_model
+        temperature = self.settings.temperature
 
-    def _init_client(self) -> None:
-        if self.provider == "ollama":
+        if provider == "ollama":
+            self._mode: Literal["ollama", "hf_endpoint", "openai"] = "ollama"
             self._client = ChatOllama(
-                model=self.model, temperature=self.temperature)
-        elif self.provider == "openai":
-            self._client = ChatOpenAI(
-                model=self.model, temperature=self.temperature)
-        elif self.provider == "hf_endpoint":
-            # Uses Hugging Face hosted model via Inference API
-            # Model id should be a HF repo id, e.g. "HuggingFaceH4/zephyr-7b-beta"
+                model=model,
+                temperature=temperature,
+            )
+
+        elif provider == "hf_endpoint":
+            # Uses the Hugging Face hosted Inference API (text-generation)
+            token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+            if not token:
+                raise RuntimeError(
+                    "HUGGINGFACEHUB_API_TOKEN is required when AMW_LLM_PROVIDER=hf_endpoint"
+                )
+
+            self._mode = "hf_endpoint"
+            # IMPORTANT: do NOT pass our own client object here; let
+            # HuggingFaceEndpoint build the correct internal client.
             self._client = HuggingFaceEndpoint(
-                repo_id=self.model,
-                temperature=self.temperature,
+                repo_id=model,
+                huggingfacehub_api_token=token,
+                temperature=temperature,
+                max_new_tokens=800,
             )
+
+        elif provider == "openai":
+            api_key = os.getenv("OPENAI_API_KEY")
+            if not api_key:
+                raise RuntimeError(
+                    "OPENAI_API_KEY is required when AMW_LLM_PROVIDER=openai"
+                )
+
+            self._mode = "openai"
+            self._client = ChatOpenAI(
+                model=model,
+                temperature=temperature,
+                api_key=api_key,
+            )
+
         else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
+            raise ValueError(f"Unknown AMW_LLM_PROVIDER='{provider}'")
+
+        logger.info(
+            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
+            provider,
+            model,
+            temperature,
+        )
 
-    def generate(self, system_prompt: str, user_prompt: str) -> str:
-        if self.provider in {"ollama", "openai"}:
+    def generate(self, *, system_prompt: str, user_prompt: str) -> str:
+        """
+        Unified generate() interface for all providers.
+        """
+        system_prompt = system_prompt.strip()
+        user_prompt = user_prompt.strip()
+
+        if self._mode in ("ollama", "openai"):
+            # Chat-style models (ChatOllama / ChatOpenAI)
             messages = [
-                SystemMessage(content=system_prompt),
-                HumanMessage(content=user_prompt),
+                ("system", system_prompt),
+                ("user", user_prompt),
             ]
             response = self._client.invoke(messages)
-            return response.content or ""
+            # Both ChatOllama and ChatOpenAI return an object with `.content`
+            return getattr(response, "content", str(response))
 
-        if self.provider == "hf_endpoint":
-            # HuggingFaceEndpoint is a text-completion-style model,
-            # so we concatenate system + user into a single prompt.
-            prompt = f"{system_prompt.strip()}\n\n{user_prompt.strip()}"
-            response = self._client.invoke(prompt)
-            return response or ""
+        elif self._mode == "hf_endpoint":
+            # HuggingFaceEndpoint expects a single text prompt
+            prompt = (
+                f"{system_prompt}\n\n"
+                f"User:\n{user_prompt}\n\n"
+                f"Assistant:"
+            )
+            text = self._client.invoke(prompt)
+            # HuggingFaceEndpoint typically returns raw text
+            return text.strip() if isinstance(text, str) else str(text)
 
-        raise ValueError(
-            f"Unsupported provider at generate(): {self.provider}")
+        else:
+            raise RuntimeError("Unsupported LLM provider mode")
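
For quick verification, a minimal usage sketch of the client as of this commit. This is not part of the patch: it assumes the package is importable as agentic_multiwriter, that langchain-community is installed, and that a local Ollama server has the default "llama3" model pulled. The env var names and defaults come straight from the new LLMSettings.

# Usage sketch (assumptions as noted above, not part of the patch).
# NOTE: LLMSettings reads the AMW_* variables at import time (dataclass
# defaults are evaluated when the class body runs), so set them before
# importing the module, or pass an explicit LLMSettings(...) instead.
import os

os.environ.setdefault("AMW_LLM_PROVIDER", "ollama")
os.environ.setdefault("AMW_LLM_MODEL", "llama3")
os.environ.setdefault("AMW_TEMPERATURE", "0.4")

from agentic_multiwriter.models.llm_client import LLMClient, LLMSettings

settings = LLMSettings()  # picks up the AMW_* env vars set above
client = LLMClient(settings)
text = client.generate(
    system_prompt="You are a concise technical writer.",
    user_prompt="Explain in one sentence what this client wraps.",
)
print(text)

Because generate() normalizes both chat-style and text-completion backends to a plain string, callers never need to branch on the provider.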