File size: 7,497 Bytes
62e9f37
 
cafd75d
73e8adc
 
32f259e
96c62f3
32f259e
62e9f37
32f259e
62e9f37
 
 
 
 
 
 
 
6948c1c
62e9f37
 
 
 
6948c1c
62e9f37
cafd75d
 
32f259e
 
62e9f37
cafd75d
62e9f37
 
 
 
 
 
 
 
 
32f259e
 
62e9f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32f259e
6948c1c
 
73e8adc
 
 
6948c1c
 
62e9f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cafd75d
 
62e9f37
 
 
 
32f259e
cafd75d
62e9f37
 
 
 
 
 
 
 
 
 
 
 
73e8adc
62e9f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32f259e
6948c1c
cafd75d
62e9f37
cafd75d
62e9f37
 
 
 
 
 
96c62f3
62e9f37
 
 
6948c1c
62e9f37
 
 
 
 
 
 
6948c1c
62e9f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# src/agentic_multiwriter/models/llm_client.py

import os
import logging
from typing import Optional

from huggingface_hub import InferenceClient

from langchain_core.messages import SystemMessage, HumanMessage

try:
    # Modern LangChain + OpenAI
    from langchain_openai import ChatOpenAI
except ImportError:  # Fallback for older setups
    try:
        from langchain.chat_models import ChatOpenAI  # type: ignore
    except ImportError:
        ChatOpenAI = None  # type: ignore

try:
    from langchain_ollama import ChatOllama
except ImportError:
    ChatOllama = None  # type: ignore

logger = logging.getLogger(__name__)


class LLMClient:
    """
    Unified LLM client.

    Providers:
      - openai      -> ChatOpenAI (gpt-4o-mini, etc.)
      - ollama      -> Local Ollama server (not used on HF Spaces)
      - hf_endpoint -> Hugging Face Inference API (backup / optional)

    Defaults (overridable via environment):
      AMW_LLM_PROVIDER = "openai"
      AMW_LLM_MODEL    = provider-specific (see _DEFAULT_MODELS)
      AMW_TEMPERATURE  = 0.3
    """

    # Per-provider fallback model names, used when neither the `model`
    # argument nor AMW_LLM_MODEL is supplied.
    _DEFAULT_MODELS = {
        "openai": "gpt-4o-mini",
        "ollama": "llama3",
        # Only used if you deliberately switch to HF Inference.
        "hf_endpoint": "gpt2",
    }

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
    ) -> None:
        """
        Resolve configuration (explicit args > environment > defaults) and
        initialize the backend client for the chosen provider.

        Raises:
            ValueError: if the resolved provider is not one of the supported
                providers listed in the class docstring.
        """
        # ---------- Resolve configuration ----------
        self.provider = (provider or os.getenv("AMW_LLM_PROVIDER", "openai")).lower()

        # NOTE: use an explicit `is not None` check rather than `or`, so that
        # an explicit temperature of 0.0 (greedy decoding) is honored instead
        # of silently falling back to the environment/default value.
        if temperature is not None:
            self.temperature = float(temperature)
        else:
            self.temperature = float(os.getenv("AMW_TEMPERATURE", "0.3"))

        if model is not None:
            self.model = model
        else:
            try:
                default_model = self._DEFAULT_MODELS[self.provider]
            except KeyError:
                raise ValueError(f"Unknown LLM provider: {self.provider}") from None
            self.model = os.getenv("AMW_LLM_MODEL", default_model)

        logger.info(
            "LLMClient initialized with provider='%s', model='%s', temperature=%.2f",
            self.provider,
            self.model,
            self.temperature,
        )

        # ---------- Initialize backend client ----------
        if self.provider == "openai":
            self._init_openai_client()
        elif self.provider == "ollama":
            self._init_ollama_client()
        elif self.provider == "hf_endpoint":
            self._init_hf_client()
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------

    def _init_openai_client(self) -> None:
        """Create a ChatOpenAI client; warn early if the API key is missing."""
        if ChatOpenAI is None:
            raise RuntimeError(
                "ChatOpenAI could not be imported. Make sure 'langchain-openai' "
                "is installed (e.g., `pip install langchain-openai`)."
            )

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            # Don't fail here: the key may be injected later (e.g., by the
            # hosting platform), but surface the problem in the logs.
            logger.warning(
                "OPENAI_API_KEY is not set; OpenAI calls will fail until it is configured."
            )

        # ChatOpenAI reads OPENAI_API_KEY from the environment by default.
        self._client = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            # Do NOT pass the key explicitly – let it read from env
            # api_key=api_key  # (optional if you want to be explicit)
        )

    def _init_ollama_client(self) -> None:
        """Create a ChatOllama client pointed at a local Ollama server."""
        if ChatOllama is None:
            raise RuntimeError(
                "langchain_ollama is not installed, but provider='ollama' was selected."
            )
        self._client = ChatOllama(
            model=self.model,
            temperature=self.temperature,
        )

    def _init_hf_client(self) -> None:
        """
        Optional: Hugging Face Inference client (not used if you stay on OpenAI).

        Uses HUGGINGFACEHUB_API_TOKEN from env, which is automatically set
        inside your own Space if you define it as a secret.
        """
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            logger.warning(
                "HUGGINGFACEHUB_API_TOKEN is not set. HF Inference calls will fail "
                "unless the environment injects the token (e.g., in a HF Space)."
            )

        self._client = InferenceClient(
            model=self.model,
            token=hf_token,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate text from the configured model.

        Args:
            system_prompt: Instructions placed in the system role.
            user_prompt: The user's request.

        Returns:
            The model's text response.

        Raises:
            ValueError: if the provider is not supported (should not happen
                after a successful __init__).
        """
        if self.provider in ("openai", "ollama"):
            return self._generate_chat_model(system_prompt, user_prompt)
        elif self.provider == "hf_endpoint":
            return self._generate_hf_text(system_prompt, user_prompt)
        else:
            raise ValueError(f"Unsupported provider in generate(): {self.provider}")

    # ------------------------------------------------------------------
    # OpenAI / Ollama (chat-style models via LangChain)
    # ------------------------------------------------------------------

    def _generate_chat_model(self, system_prompt: str, user_prompt: str) -> str:
        """Invoke a LangChain chat model and extract its text content."""
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]

        resp = self._client.invoke(messages)  # type: ignore[attr-defined]

        # LangChain ChatModels usually return a message object with `.content`;
        # fall back to str() for anything unexpected rather than crashing.
        text = getattr(resp, "content", None)
        if not isinstance(text, str):
            text = str(resp)
        return text

    # ------------------------------------------------------------------
    # Hugging Face Inference (text-generation; optional)
    # ------------------------------------------------------------------

    def _generate_hf_text(self, system_prompt: str, user_prompt: str) -> str:
        """
        Use Hugging Face Inference `text_generation`.

        Only used if AMW_LLM_PROVIDER=hf_endpoint. The system/user prompts are
        flattened into a single tagged prompt since plain text-generation
        endpoints have no chat roles.
        """
        prompt = (
            f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
            f"<<USER>>\n{user_prompt}\n<</USER>>\n\n"
            "Assistant:"
        )

        try:
            text = self._client.text_generation(
                prompt,
                max_new_tokens=512,
                temperature=self.temperature,
                do_sample=True,
                top_p=0.9,
                return_full_text=False,
            )
        except Exception as e:  # noqa: BLE001
            logger.error(
                "Error while calling Hugging Face Inference API for model '%s': %s",
                self.model,
                e,
                exc_info=True,
            )
            raise RuntimeError(
                f"Hugging Face Inference error for model '{self.model}'. "
                f"Ensure the model supports 'text-generation' and that your token "
                f"has Inference permissions."
            ) from e

        # Normalize the response: newer clients return a plain string; older
        # ones may return a dict-like payload with 'generated_text'.
        if isinstance(text, str):
            return text
        if isinstance(text, dict):
            return text.get("generated_text", str(text))
        return str(text)