KheemDH committed on
Commit
62e9f37
·
verified ·
1 Parent(s): 73e8adc
src/agentic_multiwriter/models/llm_client.py CHANGED
@@ -1,44 +1,67 @@
 
 
1
import logging
import os
from dataclasses import dataclass, field
from typing import Optional

from huggingface_hub import InferenceClient
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
10
-
11
- logger = logging.getLogger(__name__)
12
 
 
13
 
14
@dataclass
class LLMSettings:
    """
    Configuration values for the LLM backend, defaulting to the environment.

    Every default is produced by a ``default_factory``, so the environment
    variable is looked up each time an instance is created rather than once
    when the module is imported. Values passed explicitly to the constructor
    always win over the environment.
    """

    # Backend selector: 'ollama', 'openai', 'hf_endpoint'
    provider: str = field(
        default_factory=lambda: os.getenv("AMW_LLM_PROVIDER", "ollama")
    )
    # Model identifier handed to the selected backend
    llm_model: str = field(
        default_factory=lambda: os.getenv("AMW_LLM_MODEL", "llama3")
    )
    # Sampling temperature (AMW_TEMPERATURE is parsed with float())
    temperature: float = field(
        default_factory=lambda: float(os.getenv("AMW_TEMPERATURE", "0.4"))
    )
    # Generation length limit (AMW_MAX_TOKENS is parsed with int())
    max_tokens: int = field(
        default_factory=lambda: int(os.getenv("AMW_MAX_TOKENS", "768"))
    )

    # HF token is optional; if not set, HF will still work for some public models
    hf_api_token: Optional[str] = field(
        default_factory=lambda: os.getenv("HUGGINGFACEHUB_API_TOKEN")
    )

    # OpenAI key is optional unless provider='openai'
    openai_api_key: Optional[str] = field(
        default_factory=lambda: os.getenv("OPENAI_API_KEY")
    )
26
 
27
 
28
class LLMClient:
    """
    Single entry point that hides which LLM backend is in use.

    Supported values of ``settings.provider`` (compared case-insensitively):

    - 'ollama'      -> local Ollama (ChatOllama)
    - 'openai'      -> OpenAI chat models (ChatOpenAI)
    - 'hf_endpoint' -> Hugging Face Inference API (text_generation);
                       'huggingface' and 'hf' are accepted as well
    """

    # Provider spellings that all select the Hugging Face Inference backend.
    _HF_PROVIDERS = frozenset({"hf_endpoint", "huggingface", "hf"})

    def __init__(self, settings: Optional[LLMSettings] = None) -> None:
        """
        Store the settings and build the backend client for the provider.

        Raises:
            ValueError: if the provider is not one of the supported names.
        """
        cfg = settings or LLMSettings()
        self.settings = cfg
        self.provider = cfg.provider.lower()
        self.model = cfg.llm_model
        self.temperature = cfg.temperature

        # NOTE(review): the provider/model arguments were collapsed in the
        # diff view; inferred from the three placeholders -- confirm.
        logger.info(
            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
            self.provider,
            self.model,
            self.temperature,
        )

        self._client = self._build_client()

    def _build_client(self):
        """Return the backend client object matching ``self.provider``."""
        name = self.provider

        if name == "ollama":
            return ChatOllama(model=self.model, temperature=self.temperature)

        if name == "openai":
            key = self.settings.openai_api_key
            if not key:
                logger.warning(
                    "OPENAI_API_KEY not set but provider='openai'. "
                    "Requests will fail until the key is configured."
                )
            return ChatOpenAI(
                model=self.model,
                temperature=self.temperature,
                api_key=key,
            )

        if name in self._HF_PROVIDERS:
            token = self.settings.hf_api_token
            if not token:
                logger.warning(
                    "HUGGINGFACEHUB_API_TOKEN not set. For reliable HF Inference, "
                    "set it as a secret in your Space or local environment."
                )
            else:
                logger.info("Using explicit HUGGINGFACEHUB_API_TOKEN for hf_endpoint.")
            # Binding the client to the model targets that model's own
            # Inference API endpoint (not the generic router).
            return InferenceClient(model=self.model, token=token)

        raise ValueError(f"Unknown LLM provider: {self.provider}")

    # ---------------------------------------------------------------------
    # Unified generate() API
    # ---------------------------------------------------------------------
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Produce one string response from the configured backend.

        Raises:
            RuntimeError: if the Hugging Face call fails, or if the provider
                is not handled by any branch.
        """
        name = self.provider
        if name == "ollama":
            return self._ask_ollama(system_prompt, user_prompt)
        if name == "openai":
            return self._ask_openai(system_prompt, user_prompt)
        if name in self._HF_PROVIDERS:
            return self._ask_hf(system_prompt, user_prompt)

        # Should never reach here
        raise RuntimeError(f"Unsupported provider: {self.provider}")

    def _ask_ollama(self, system_prompt: str, user_prompt: str) -> str:
        """Send the prompts to Ollama as LangChain message objects."""
        reply = self._client.invoke(
            [
                SystemMessage(content=system_prompt),
                HumanMessage(content=user_prompt),
            ]
        )
        return reply.content  # type: ignore[return-value]

    def _ask_openai(self, system_prompt: str, user_prompt: str) -> str:
        """Send the prompts to OpenAI as role/content dictionaries."""
        reply = self._client.invoke(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
        )
        # langchain-openai returns AIMessage
        return reply.content  # type: ignore[return-value]

    def _ask_hf(self, system_prompt: str, user_prompt: str) -> str:
        """Run plain text generation on the Hugging Face Inference API."""
        # Plain text-generation has no roles, so system and user text are
        # merged into a single prompt.
        prompt = (
            f"{system_prompt.strip()}\n\nUser:\n{user_prompt.strip()}\n\nAssistant:"
        )

        try:
            completion = self._client.text_generation(
                prompt,
                max_new_tokens=self.settings.max_tokens,
                temperature=self.temperature,
                do_sample=self.temperature > 0,
                repetition_penalty=1.05,
                return_full_text=False,  # only new tokens
            )
            return completion.strip()
        except Exception as e:
            logger.exception(
                "Error while calling Hugging Face Inference API for model '%s': %s",
                self.model,
                e,
            )
            raise RuntimeError(
                f"Hugging Face Inference error for model '{self.model}'. "
                "Check that the model supports text-generation and that "
                "your token has Inference permissions."
            ) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/agentic_multiwriter/models/llm_client.py
2
+
3
  import os
4
  import logging
 
5
  from typing import Optional
6
 
7
  from huggingface_hub import InferenceClient
 
 
 
 
 
8
 
9
+ from langchain_core.messages import SystemMessage, HumanMessage
10
 
11
+ try:
12
+ # Modern LangChain + OpenAI
13
+ from langchain_openai import ChatOpenAI
14
+ except ImportError: # Fallback for older setups
15
+ try:
16
+ from langchain.chat_models import ChatOpenAI # type: ignore
17
+ except ImportError:
18
+ ChatOpenAI = None # type: ignore
19
 
20
+ try:
21
+ from langchain_ollama import ChatOllama
22
+ except ImportError:
23
+ ChatOllama = None # type: ignore
24
 
25
+ logger = logging.getLogger(__name__)
 
26
 
27
 
28
class LLMClient:
    """
    Unified LLM client.

    Providers:
      - openai      -> ChatOpenAI (gpt-4o-mini, etc.)
      - ollama      -> Local Ollama server (not used on HF Spaces)
      - hf_endpoint -> Hugging Face Inference API (backup / optional);
                       'huggingface' and 'hf' are accepted as aliases

    Defaults (used when neither the constructor argument nor the
    environment variable provides a value):
      AMW_LLM_PROVIDER = "openai"
      AMW_LLM_MODEL    = "gpt-4o-mini" (openai), "llama3" (ollama),
                         "gpt2" (hf_endpoint)
      AMW_TEMPERATURE  = 0.3
    """

    # Alternate provider spellings mapped to their canonical name.
    _PROVIDER_ALIASES = {"huggingface": "hf_endpoint", "hf": "hf_endpoint"}

    # Fallback model per canonical provider; its keys are also the set of
    # supported providers.
    _DEFAULT_MODELS = {
        "openai": "gpt-4o-mini",
        "ollama": "llama3",
        "hf_endpoint": "gpt2",
    }

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
    ) -> None:
        """
        Resolve the configuration and build the backend client.

        Args:
            provider: Backend name; falls back to AMW_LLM_PROVIDER, then "openai".
            model: Model identifier; falls back to AMW_LLM_MODEL, then the
                per-provider default.
            temperature: Sampling temperature; falls back to AMW_TEMPERATURE,
                then 0.3. An explicit 0.0 is honoured.

        Raises:
            ValueError: if the provider is not supported.
            RuntimeError: if the package needed by the provider is missing.
        """
        # ---------- Resolve configuration ----------
        self.provider = self._normalize_provider(
            provider or os.getenv("AMW_LLM_PROVIDER", "openai")
        )
        if self.provider not in self._DEFAULT_MODELS:
            raise ValueError(f"Unknown LLM provider: {self.provider}")

        self.temperature = self._resolve_temperature(temperature)

        if model is not None:
            self.model = model
        else:
            self.model = os.getenv(
                "AMW_LLM_MODEL", self._DEFAULT_MODELS[self.provider]
            )

        # NOTE(review): the provider/model arguments were collapsed in the
        # diff view; inferred from the three placeholders -- confirm.
        logger.info(
            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
            self.provider,
            self.model,
            self.temperature,
        )

        # ---------- Initialize backend client ----------
        if self.provider == "openai":
            self._init_openai_client()
        elif self.provider == "ollama":
            self._init_ollama_client()
        elif self.provider == "hf_endpoint":
            self._init_hf_client()
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    # ------------------------------------------------------------------
    # Configuration helpers
    # ------------------------------------------------------------------

    @classmethod
    def _normalize_provider(cls, provider: str) -> str:
        """Lower-case a provider name and map known aliases to the canonical one."""
        name = provider.lower()
        return cls._PROVIDER_ALIASES.get(name, name)

    @staticmethod
    def _resolve_temperature(temperature: Optional[float]) -> float:
        """
        Return the explicit temperature, else AMW_TEMPERATURE, else 0.3.

        The check is ``is not None`` rather than truthiness so that an
        explicit 0.0 is not mistaken for "not provided".
        """
        if temperature is not None:
            return float(temperature)
        return float(os.getenv("AMW_TEMPERATURE", "0.3"))

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------

    def _init_openai_client(self) -> None:
        """
        Create the ChatOpenAI client.

        Raises:
            RuntimeError: if ChatOpenAI could not be imported.
        """
        if ChatOpenAI is None:
            raise RuntimeError(
                "ChatOpenAI could not be imported. Make sure 'langchain-openai' "
                "is installed (e.g., `pip install langchain-openai`)."
            )

        if not os.getenv("OPENAI_API_KEY"):
            logger.warning(
                "OPENAI_API_KEY is not set; OpenAI calls will fail until it is configured."
            )

        # The key is deliberately not passed: ChatOpenAI reads
        # OPENAI_API_KEY from the environment by default.
        self._client = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
        )

    def _init_ollama_client(self) -> None:
        """
        Create the ChatOllama client.

        Raises:
            RuntimeError: if langchain_ollama could not be imported.
        """
        if ChatOllama is None:
            raise RuntimeError(
                "langchain_ollama is not installed, but provider='ollama' was selected."
            )
        self._client = ChatOllama(
            model=self.model,
            temperature=self.temperature,
        )

    def _init_hf_client(self) -> None:
        """
        Optional: Hugging Face Inference client (not used if you stay on OpenAI).

        Reads HUGGINGFACEHUB_API_TOKEN from the environment and passes it to
        the client; a warning is logged when it is not set.
        """
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            logger.warning(
                "HUGGINGFACEHUB_API_TOKEN is not set. HF Inference calls will fail "
                "unless the environment injects the token (e.g., in a HF Space)."
            )

        self._client = InferenceClient(
            model=self.model,
            token=hf_token,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate text from the configured model.

        Raises:
            ValueError: if ``self.provider`` is not a supported provider.
            RuntimeError: if the Hugging Face Inference call fails.
        """
        if self.provider in ("openai", "ollama"):
            return self._generate_chat_model(system_prompt, user_prompt)
        if self.provider == "hf_endpoint":
            return self._generate_hf_text(system_prompt, user_prompt)
        raise ValueError(f"Unsupported provider in generate(): {self.provider}")

    # ------------------------------------------------------------------
    # OpenAI / Ollama (chat-style models via LangChain)
    # ------------------------------------------------------------------

    def _generate_chat_model(self, system_prompt: str, user_prompt: str) -> str:
        """Invoke the LangChain chat client and return the reply text."""
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]

        resp = self._client.invoke(messages)  # type: ignore[attr-defined]

        # Fall back to the string form of the whole response when it has no
        # string `.content`.
        text = getattr(resp, "content", None)
        if not isinstance(text, str):
            text = str(resp)
        return text

    # ------------------------------------------------------------------
    # Hugging Face Inference (text-generation; optional)
    # ------------------------------------------------------------------

    def _generate_hf_text(self, system_prompt: str, user_prompt: str) -> str:
        """
        Use Hugging Face Inference `text_generation`.

        Only used if AMW_LLM_PROVIDER=hf_endpoint.

        Raises:
            RuntimeError: if the Inference API call fails.
        """
        prompt = (
            f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
            f"<<USER>>\n{user_prompt}\n<</USER>>\n\n"
            "Assistant:"
        )

        # A non-positive temperature means greedy decoding: sampling is
        # switched off and the sampling-only parameters are left out, since
        # some inference backends may reject a temperature of 0.
        sampling = self.temperature > 0
        gen_kwargs = {
            "max_new_tokens": 512,
            "do_sample": sampling,
            "return_full_text": False,
        }
        if sampling:
            gen_kwargs["temperature"] = self.temperature
            gen_kwargs["top_p"] = 0.9

        try:
            result = self._client.text_generation(prompt, **gen_kwargs)
        except Exception as e:  # noqa: BLE001
            logger.error(
                "Error while calling Hugging Face Inference API for model '%s': %s",
                self.model,
                e,
                exc_info=True,
            )
            raise RuntimeError(
                f"Hugging Face Inference error for model '{self.model}'. "
                "Ensure the model supports 'text-generation' and that your token "
                "has Inference permissions."
            ) from e

        if isinstance(result, str):
            return result

        # Structured results may expose the text as an attribute or a key.
        generated = getattr(result, "generated_text", None)
        if generated is None and isinstance(result, dict):
            generated = result.get("generated_text")
        return generated if isinstance(generated, str) else str(result)