Updated LLMs Client
src/agentic_multiwriter/models/llm_client.py
CHANGED
@@ -5,8 +5,8 @@ import os
 from dataclasses import dataclass
 from typing import Literal
 
+from huggingface_hub import InferenceClient
 from langchain_community.chat_models import ChatOllama
-from langchain_community.llms import HuggingFaceEndpoint
 from langchain_openai import ChatOpenAI
 
 from agentic_multiwriter.tools import get_logger
@@ -27,7 +27,7 @@ class LLMClient:
     Small wrapper around different LLM backends:
 
     - provider="ollama" -> local ChatOllama
-    - provider="hf_endpoint" -> Hugging Face Inference
+    - provider="hf_endpoint" -> Hugging Face Inference API (InferenceClient)
     - provider="openai" -> OpenAI Chat model
     """
 
@@ -45,7 +45,6 @@ class LLMClient:
             )
 
         elif provider == "hf_endpoint":
-            # Uses HuggingFaceHosted Inference API (text-generation)
             token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
             if not token:
                 raise RuntimeError(
@@ -53,13 +52,10 @@ class LLMClient:
                 )
 
             self._mode = "hf_endpoint"
-            #
-
-            self._client = HuggingFaceEndpoint(
-                repo_id=model,
-                huggingfacehub_api_token=token,
-                temperature=temperature,
-                max_new_tokens=800,
+            # Direct low-level client; no LangChain wrapper here.
+            self._client = InferenceClient(
+                model=model,
+                token=token,
             )
 
         elif provider == "openai":
@@ -100,18 +96,24 @@
                 ("user", user_prompt),
             ]
             response = self._client.invoke(messages)
-            # Both ChatOllama and ChatOpenAI return an object with `.content`
             return getattr(response, "content", str(response))
 
         elif self._mode == "hf_endpoint":
-            #
+            # Use HF InferenceClient directly (text generation)
             prompt = (
                 f"{system_prompt}\n\n"
                 f"User:\n{user_prompt}\n\n"
                 f"Assistant:"
             )
-
-            text = self._client.invoke(prompt)
+
+            text = self._client.text_generation(
+                prompt,
+                max_new_tokens=800,
+                temperature=self.settings.temperature,
+                top_p=0.95,
+                do_sample=True,
+            )
+            # InferenceClient.text_generation returns a plain string
            return text.strip() if isinstance(text, str) else str(text)
 
         else:
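
For reference, the new hf_endpoint path can be exercised on its own. A minimal sketch, assuming a valid HUGGINGFACEHUB_API_TOKEN in the environment; the model id and prompt below are placeholders, not taken from this repo:

import os

from huggingface_hub import InferenceClient

# Placeholder model id; any text-generation model served by the
# HF Inference API would work here.
client = InferenceClient(
    model="HuggingFaceH4/zephyr-7b-beta",
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)

# Same call shape as the new generate() branch: text_generation takes
# the flattened prompt and returns a plain string.
text = client.text_generation(
    "You are a helpful assistant.\n\nUser:\nSay hello.\n\nAssistant:",
    max_new_tokens=800,
    temperature=0.7,
    top_p=0.95,
    do_sample=True,
)
print(text.strip())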
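
The chat branches are untouched by this commit: ChatOllama and ChatOpenAI still go through LangChain's invoke, which returns a message object rather than a string, hence the getattr(response, "content", ...) guard in generate(). A sketch of that contrast, assuming a local Ollama server with a pulled model (the model name is a placeholder):

from langchain_community.chat_models import ChatOllama

# Assumes an Ollama server running locally with this model available.
chat = ChatOllama(model="llama3", temperature=0.7)

response = chat.invoke([
    ("system", "You are a helpful assistant."),
    ("user", "Say hello."),
])

# Chat models return a message object; the text lives on .content.
print(getattr(response, "content", str(response)))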