KheemDH committed
Commit 96c62f3 · verified · 1 Parent(s): cafd75d

Updated LLMs Client

src/agentic_multiwriter/models/llm_client.py CHANGED
@@ -5,8 +5,8 @@ import os
 from dataclasses import dataclass
 from typing import Literal
 
+from huggingface_hub import InferenceClient
 from langchain_community.chat_models import ChatOllama
-from langchain_community.llms import HuggingFaceEndpoint
 from langchain_openai import ChatOpenAI
 
 from agentic_multiwriter.tools import get_logger
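For orientation: the hunk above swaps LangChain's community wrapper for the client that ships with `huggingface_hub` itself. A minimal smoke test of the new dependency, with a placeholder model id that this commit does not pin:

```python
# Placeholder model id; any hosted text-generation model works here.
from huggingface_hub import InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
print(client.text_generation("Say hello in five words:", max_new_tokens=16))
```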
@@ -27,7 +27,7 @@ class LLMClient:
     Small wrapper around different LLM backends:
 
     - provider="ollama" -> local ChatOllama
-    - provider="hf_endpoint" -> Hugging Face Inference endpoint
+    - provider="hf_endpoint" -> Hugging Face Inference API (InferenceClient)
     - provider="openai" -> OpenAI Chat model
     """
 
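A hypothetical construction of the wrapper for each backend listed above, assuming the constructor keeps the `provider`/`model`/`temperature` parameters the surrounding hunks imply; the real signature is not part of this diff:

```python
from agentic_multiwriter.models.llm_client import LLMClient

# Model names are illustrative, not taken from this repository.
local = LLMClient(provider="ollama", model="llama3", temperature=0.7)
hosted = LLMClient(provider="hf_endpoint", model="HuggingFaceH4/zephyr-7b-beta", temperature=0.7)
remote = LLMClient(provider="openai", model="gpt-4o-mini", temperature=0.7)
```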
@@ -45,7 +45,6 @@ class LLMClient:
             )
 
         elif provider == "hf_endpoint":
-            # Uses HuggingFace-hosted Inference API (text-generation)
             token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
             if not token:
                 raise RuntimeError(
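The check above turns a missing token into a hard error instead of letting requests fall through as anonymous, rate-limited calls. For local experiments the variable can be injected before the wrapper is built; a sketch with a placeholder token:

```python
import os

# Illustrative only; in production export the token in the shell instead.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "hf_xxx")  # placeholder

token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
assert token, "HUGGINGFACEHUB_API_TOKEN must be set for provider='hf_endpoint'"
```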
@@ -53,13 +52,10 @@ class LLMClient:
                 )
 
             self._mode = "hf_endpoint"
-            # IMPORTANT: do NOT pass our own client object here; let
-            # HuggingFaceEndpoint build the correct internal client.
-            self._client = HuggingFaceEndpoint(
-                repo_id=model,
-                huggingfacehub_api_token=token,
-                temperature=temperature,
-                max_new_tokens=800,
+            # Direct low-level client; no LangChain wrapper here.
+            self._client = InferenceClient(
+                model=model,
+                token=token,
             )
 
         elif provider == "openai":
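Note the shift this hunk makes: `HuggingFaceEndpoint` took `temperature` and `max_new_tokens` at construction time, while `InferenceClient` takes generation parameters per request, which is why they resurface in the `text_generation` call in the next hunk. A minimal sketch of the new construction, with placeholder model id and env-sourced token:

```python
import os

from huggingface_hub import InferenceClient

# Mirrors the hunk: the client only needs the model and the token;
# sampling parameters are supplied later, at call time.
client = InferenceClient(
    model="HuggingFaceH4/zephyr-7b-beta",        # placeholder model id
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)
```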
@@ -100,18 +96,24 @@ class LLMClient:
                 ("user", user_prompt),
             ]
             response = self._client.invoke(messages)
-            # Both ChatOllama and ChatOpenAI return an object with `.content`
             return getattr(response, "content", str(response))
 
         elif self._mode == "hf_endpoint":
-            # HuggingFaceEndpoint expects a single text prompt
+            # Use HF InferenceClient directly (text generation)
             prompt = (
                 f"{system_prompt}\n\n"
                 f"User:\n{user_prompt}\n\n"
                 f"Assistant:"
             )
-            text = self._client.invoke(prompt)
-            # HuggingFaceEndpoint typically returns raw text
+
+            text = self._client.text_generation(
+                prompt,
+                max_new_tokens=800,
+                temperature=self.settings.temperature,
+                top_p=0.95,
+                do_sample=True,
+            )
+            # InferenceClient.text_generation returns a plain string
             return text.strip() if isinstance(text, str) else str(text)
 
         else:
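Taken together, the new hf_endpoint path flattens the system/user pair into a single prompt and calls `text_generation`, which returns a plain string. A self-contained sketch of that flow; model id, token handling, and prompts are placeholders:

```python
import os

from huggingface_hub import InferenceClient

client = InferenceClient(
    model="HuggingFaceH4/zephyr-7b-beta",        # placeholder
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)

system_prompt = "You are a concise assistant."
user_prompt = "List three uses for a paperclip."

# Same prompt framing as the hunk above.
prompt = f"{system_prompt}\n\nUser:\n{user_prompt}\n\nAssistant:"

text = client.text_generation(
    prompt,
    max_new_tokens=800,
    temperature=0.7,
    top_p=0.95,
    do_sample=True,
)
print(text.strip())
```

Newer `huggingface_hub` releases also expose `client.chat_completion(...)`, which accepts role-tagged messages directly and would make the hand-rolled `User:`/`Assistant:` framing unnecessary; the commit keeps the plain text-generation route.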