KheemDH committed
Commit 73e8adc · verified · 1 Parent(s): 6948c1c
src/agentic_multiwriter/models/llm_client.py CHANGED
@@ -1,156 +1,145 @@
-# src/agentic_multiwriter/models/llm_client.py
-
-from __future__ import annotations
-
-import logging
 import os
 from dataclasses import dataclass

 from huggingface_hub import InferenceClient
 from langchain_ollama import ChatOllama
 from langchain_openai import ChatOpenAI
-from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage

 logger = logging.getLogger(__name__)


-# ──────────────────────────────────────────────────────────────────────────────
-# Settings
-# ──────────────────────────────────────────────────────────────────────────────
-
-
 @dataclass
 class LLMSettings:
-    """Configuration for the LLM client."""
-
-    llm_provider: str = os.getenv("AMW_LLM_PROVIDER", "ollama").strip() or "ollama"
-    llm_model: str = os.getenv("AMW_LLM_MODEL", "llama3").strip() or "llama3"
     temperature: float = float(os.getenv("AMW_TEMPERATURE", "0.4"))
-    max_tokens: int = int(os.getenv("AMW_MAX_TOKENS", "1024"))


-# ──────────────────────────────────────────────────────────────────────────────
-# Client wrapper
-# ──────────────────────────────────────────────────────────────────────────────


 class LLMClient:
     """
-    Thin wrapper that normalizes access to different LLM providers:

-    - openai      → ChatOpenAI (OpenAI API)
-    - ollama      → ChatOllama (local server)
-    - hf_endpoint → huggingface_hub.InferenceClient (Hugging Face endpoints / Spaces)
     """

-    def __init__(self, settings: LLMSettings | None = None) -> None:
         self.settings = settings or LLMSettings()
-        self._client = self._build_client()
-
-    # --------------------------------------------------------------------- #
-    # Client factory
-    # --------------------------------------------------------------------- #
-
-    def _build_client(self):
-        provider = self.settings.llm_provider.lower()
-        model = self.settings.llm_model

         logger.info(
             "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
-            provider,
-            model,
-            self.settings.temperature,
         )

-        if provider == "openai":
-            # Uses OPENAI_API_KEY from env
-            return ChatOpenAI(
-                model=model,
-                temperature=self.settings.temperature,
-                max_tokens=self.settings.max_tokens,
             )

-        if provider == "ollama":
-            # Expects a local Ollama server (not available on HF Spaces)
-            return ChatOllama(
-                model=model,
-                temperature=self.settings.temperature,
             )

-        if provider == "hf_endpoint":
-            # Hugging Face Inference API / Spaces
-            # HUGGINGFACEHUB_API_TOKEN is recommended, but on Spaces the
-            # internal token usually works without specifying it.
-            hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-            if hf_token:
-                client = InferenceClient(token=hf_token)
                 logger.info("Using explicit HUGGINGFACEHUB_API_TOKEN for hf_endpoint.")
             else:
-                client = InferenceClient()
-                logger.info(
-                    "No HUGGINGFACEHUB_API_TOKEN set; relying on default HF Space token."
                 )
-            return client
-
-        raise ValueError(f"Unsupported LLM provider: {provider}")

-    # --------------------------------------------------------------------- #
-    # Public API
-    # --------------------------------------------------------------------- #

     def generate(self, system_prompt: str, user_prompt: str) -> str:
         """
-        Generate text using the configured provider, given a system and user prompt.
-        Returns only the text content (stripped).
         """
-        provider = self.settings.llm_provider.lower()
-
-        if provider == "openai":
-            messages: list[BaseMessage] = [
-                SystemMessage(content=system_prompt),
-                HumanMessage(content=user_prompt),
-            ]
-            response = self._client.invoke(messages)
-            text = response.content

-        elif provider == "ollama":
-            messages: list[BaseMessage] = [
                 SystemMessage(content=system_prompt),
                 HumanMessage(content=user_prompt),
             ]
             response = self._client.invoke(messages)
-            text = response.content

-        elif provider == "hf_endpoint":
-            # IMPORTANT: Zephyr and many HF models in Spaces are exposed as
-            # conversational / chat models. We therefore use `chat_completion`
-            # instead of `text_generation`, which fixes the "task text-generation
-            # not supported, use conversational" error you saw.
             messages = [
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
             ]
-
-            response = self._client.chat_completion(
-                model=self.settings.llm_model,
-                messages=messages,
-                max_tokens=self.settings.max_tokens,
-                temperature=self.settings.temperature,
-                top_p=0.95,
             )

-            # huggingface_hub ChatCompletionOutput is a small dataclass-like
-            # object. Its first choice contains the assistant message.
-            choice = response.choices[0]
-            message = choice.message
-
-            # Handle both dict-style and attribute-style message objects
-            if isinstance(message, dict):
-                text = message.get("content", "")
-            else:
-                text = getattr(message, "content", "")
-
-        else:
-            raise ValueError(f"Unsupported LLM provider: {provider}")
-
-        return (text or "").strip()
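For reference, the `chat_completion` path removed above can be exercised on its own. Below is a minimal standalone sketch (not part of the commit), assuming `huggingface_hub` is installed and a token with Inference permissions is available; the model id `HuggingFaceH4/zephyr-7b-beta` is an assumption based on the Zephyr mention in the removed comments. The replacement implementation follows after the sketch.

# Standalone sketch of the removed chat_completion path (illustrative only).
# The model id is an assumption; any chat-capable model on the HF Inference
# API should behave the same way.
import os

from huggingface_hub import InferenceClient

client = InferenceClient(token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
response = client.chat_completion(
    model="HuggingFaceH4/zephyr-7b-beta",  # assumed chat-capable model
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in five words."},
    ],
    max_tokens=64,
    temperature=0.4,
)
# The first choice holds the assistant message.
print(response.choices[0].message.content)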
 import os
+import logging
 from dataclasses import dataclass
+from typing import Optional

 from huggingface_hub import InferenceClient
 from langchain_ollama import ChatOllama
 from langchain_openai import ChatOpenAI
+from langchain_core.messages import SystemMessage, HumanMessage

 logger = logging.getLogger(__name__)


 @dataclass
 class LLMSettings:
+    provider: str = os.getenv("AMW_LLM_PROVIDER", "ollama")  # 'ollama', 'openai', 'hf_endpoint'
+    llm_model: str = os.getenv("AMW_LLM_MODEL", "llama3")
     temperature: float = float(os.getenv("AMW_TEMPERATURE", "0.4"))
+    max_tokens: int = int(os.getenv("AMW_MAX_TOKENS", "768"))

+    # HF token is optional; if not set, HF will still work for some public models
+    hf_api_token: Optional[str] = os.getenv("HUGGINGFACEHUB_API_TOKEN", None)

+    # OpenAI key is optional unless provider='openai'
+    openai_api_key: Optional[str] = os.getenv("OPENAI_API_KEY", None)


 class LLMClient:
     """
+    Thin wrapper over different backends:

+    - provider='ollama'      -> local Ollama (ChatOllama)
+    - provider='openai'      -> OpenAI ChatCompletion models
+    - provider='hf_endpoint' -> Hugging Face Inference API (text_generation)
     """

+    def __init__(self, settings: Optional[LLMSettings] = None) -> None:
         self.settings = settings or LLMSettings()
+        self.provider = self.settings.provider.lower()
+        self.model = self.settings.llm_model
+        self.temperature = self.settings.temperature

         logger.info(
             "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
+            self.provider,
+            self.model,
+            self.temperature,
         )

+        if self.provider == "ollama":
+            self._client = ChatOllama(
+                model=self.model,
+                temperature=self.temperature,
             )

+        elif self.provider == "openai":
+            if not self.settings.openai_api_key:
+                logger.warning(
+                    "OPENAI_API_KEY not set but provider='openai'. "
+                    "Requests will fail until the key is configured."
+                )
+            self._client = ChatOpenAI(
+                model=self.model,
+                temperature=self.temperature,
+                api_key=self.settings.openai_api_key,
             )

+        elif self.provider in {"hf_endpoint", "huggingface", "hf"}:
+            # Bind the client directly to the model so we use the model's
+            # Inference API endpoint (not the generic router).
+            if self.settings.hf_api_token:
                 logger.info("Using explicit HUGGINGFACEHUB_API_TOKEN for hf_endpoint.")
             else:
+                logger.warning(
+                    "HUGGINGFACEHUB_API_TOKEN not set. For reliable HF Inference, "
+                    "set it as a secret in your Space or local environment."
                 )

+            self._client = InferenceClient(
+                model=self.model,
+                token=self.settings.hf_api_token,
+            )
+        else:
+            raise ValueError(f"Unknown LLM provider: {self.provider}")

+    # ---------------------------------------------------------------------
+    # Unified generate() API
+    # ---------------------------------------------------------------------
     def generate(self, system_prompt: str, user_prompt: str) -> str:
         """
+        Generates a single string response from the configured backend.
         """

+        if self.provider == "ollama":
+            messages = [
                 SystemMessage(content=system_prompt),
                 HumanMessage(content=user_prompt),
             ]
             response = self._client.invoke(messages)
+            return response.content  # type: ignore[return-value]

+        if self.provider == "openai":
             messages = [
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
             ]
+            response = self._client.invoke(messages)
+            # langchain-openai returns AIMessage
+            return response.content  # type: ignore[return-value]
+
+        if self.provider in {"hf_endpoint", "huggingface", "hf"}:
+            # For HF Inference we use plain text-generation.
+            # We concatenate system + user into a single prompt.
+            prompt = (
+                system_prompt.strip()
+                + "\n\nUser:\n"
+                + user_prompt.strip()
+                + "\n\nAssistant:"
             )

+            try:
+                text = self._client.text_generation(
+                    prompt,
+                    max_new_tokens=self.settings.max_tokens,
+                    temperature=self.temperature,
+                    do_sample=self.temperature > 0,
+                    repetition_penalty=1.05,
+                    return_full_text=False,  # only new tokens
+                )
+                # text_generation returns a plain string when return_full_text=False
+                return text.strip()
+            except Exception as e:
+                logger.exception(
+                    "Error while calling Hugging Face Inference API for model '%s': %s",
+                    self.model,
+                    e,
+                )
+                raise RuntimeError(
+                    f"Hugging Face Inference error for model '{self.model}'. "
+                    f"Check that the model supports text-generation and that "
+                    f"your token has Inference permissions."
+                ) from e
+
+        # Should never reach here
+        raise RuntimeError(f"Unsupported provider: {self.provider}")
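A minimal usage sketch of the new implementation (not part of the commit), assuming the package is importable as `agentic_multiwriter` and a local Ollama server with the `llama3` model pulled. Note that the `LLMSettings` defaults call `os.getenv` when the class body executes, so environment variables must be set before the module is first imported.

# Usage sketch: configure via environment, then import and call generate().
# The dataclass defaults read os.environ at import time, so set variables first.
import os

os.environ.setdefault("AMW_LLM_PROVIDER", "ollama")
os.environ.setdefault("AMW_LLM_MODEL", "llama3")
os.environ.setdefault("AMW_MAX_TOKENS", "768")

from agentic_multiwriter.models.llm_client import LLMClient, LLMSettings

client = LLMClient(LLMSettings())
reply = client.generate(
    system_prompt="You are a concise technical writer.",
    user_prompt="Explain in two sentences what a provider abstraction buys us.",
)
print(reply)

Switching backends only requires changing AMW_LLM_PROVIDER and AMW_LLM_MODEL (plus HUGGINGFACEHUB_API_TOKEN or OPENAI_API_KEY as appropriate); generate() keeps the same signature across all three providers.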