File size: 4,201 Bytes
e6b0e2f
10418d0
e6b0e2f
10418d0
 
e6b0e2f
 
 
 
21da591
e6b0e2f
3b78637
e6b0e2f
 
 
 
 
 
 
21da591
 
e6b0e2f
 
 
10418d0
e6b0e2f
4ac72af
10418d0
 
 
 
 
e6b0e2f
 
6506d63
e6b0e2f
10418d0
 
 
 
 
 
 
 
e6b0e2f
3dc48b7
 
10418d0
e6b0e2f
10418d0
 
 
 
 
e6b0e2f
4ac72af
 
10418d0
4ac72af
e6b0e2f
 
 
 
 
3b78637
e6b0e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10418d0
 
 
 
 
 
 
 
 
 
 
3b78637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae001d
3b78637
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Voice agent for Layer 2 conversations.

Uses Llama 3.1 8B Instruct to act as the customer support agent during
evaluation. Supports both local model and HF Inference API backends.
"""

from __future__ import annotations

import logging
import os
import time
from typing import Any

try:
    from huggingface_hub import InferenceClient
except ImportError:
    InferenceClient = None  # type: ignore

logger = logging.getLogger(__name__)


class HFAgent:
    """
    Voice agent powered by Llama 3.1 8B.

    Takes a system prompt from Layer 1 and generates responses
    in the customer support conversation.

    Supports two backends:
    - local: loads model in-process via transformers (pass local_model=...)
    - api: uses HF Inference API (pass hf_token=...)
    """

    DEFAULT_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"

    def __init__(
        self,
        model_id: str | None = None,
        hf_token: str | None = None,
        max_tokens: int = 300,
        temperature: float = 0.3,
        local_model: Any = None,
    ):
        """
        Configure the agent backend.

        Args:
            model_id: HF model repo id; defaults to DEFAULT_MODEL.
            hf_token: HF API token; falls back to the HF_TOKEN env var.
                Only consulted when no local_model is given.
            max_tokens: Generation cap passed to the backend.
            temperature: Sampling temperature passed to the backend.
            local_model: In-process model exposing
                ``generate(messages, max_tokens=..., temperature=...)``.
                When provided, the API backend is never initialized.
        """
        self.model_id = model_id or self.DEFAULT_MODEL
        self.max_tokens = max_tokens
        self.temperature = temperature
        self._local_model = local_model
        self._client: Any = None

        if local_model is None:
            # API backend: require a token AND a successful huggingface_hub
            # import (InferenceClient is None when the package is missing).
            self.hf_token = hf_token or os.environ.get("HF_TOKEN")
            if self.hf_token and InferenceClient is not None:
                self._client = InferenceClient(token=self.hf_token)

    @property
    def is_llm_available(self) -> bool:
        """True when either the local model or the API client is usable."""
        return self._local_model is not None or self._client is not None

    def __call__(
        self,
        system_prompt: str,
        conversation_history: list[dict[str, str]],
        observation: dict[str, Any],
        max_retries: int = 4,
    ) -> str:
        """
        Generate an agent response.

        Compatible with ConversationEnvironment.run_episode(agent_fn=...).

        Args:
            system_prompt: System message prepended to the chat.
            conversation_history: Prior turns with roles "customer"/"agent"
                (mapped to OpenAI-style "user"/"assistant"). Other roles
                are silently dropped.
            observation: Environment observation; "customer_message" (if
                non-empty) is appended as the latest user turn.
            max_retries: Extra attempts after the first, for transient
                API errors only (5xx / 429 / timeouts).

        Returns:
            The stripped assistant reply text.

        Raises:
            RuntimeError: No backend configured, or HF credits depleted.
            Exception: Last transient API error once retries are exhausted,
                or any non-retryable API error immediately.
        """
        messages = [{"role": "system", "content": system_prompt}]

        for msg in conversation_history:
            if msg["role"] == "customer":
                messages.append({"role": "user", "content": msg["content"]})
            elif msg["role"] == "agent":
                messages.append({"role": "assistant", "content": msg["content"]})

        # Add the latest customer message from observation
        customer_msg = observation.get("customer_message", "")
        if customer_msg:
            messages.append({"role": "user", "content": customer_msg})

        if self._local_model is not None:
            return self._local_model.generate(
                messages, max_tokens=self.max_tokens, temperature=self.temperature,
            )

        if self._client is None:
            raise RuntimeError(
                "No inference backend available. "
                "Pass local_model=... or set HF_TOKEN for API access."
            )

        last_err = None
        for attempt in range(max_retries + 1):
            try:
                response = self._client.chat_completion(
                    model=self.model_id,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                )
                # content may be None on some API responses; treat as empty.
                return (response.choices[0].message.content or "").strip()
            except Exception as e:
                err_str = str(e)
                if "402" in err_str or "Payment Required" in err_str:
                    # Non-transient: surface a actionable billing error.
                    raise RuntimeError(
                        "HF API credits depleted. "
                        "Get more credits at https://huggingface.co/settings/billing"
                    ) from e
                transient = any(
                    code in err_str
                    for code in ("500", "502", "503", "504", "429", "timeout", "Timeout", "Time-out")
                )
                if not transient:
                    raise
                last_err = e
                if attempt == max_retries:
                    # Exhausted: don't sleep before re-raising below.
                    break
                wait = 2 ** (attempt + 1)  # 2, 4, 8, 16s with default max_retries=4
                logger.warning(
                    "HF API error (attempt %d/%d), retrying in %ds: %s",
                    attempt + 1, max_retries + 1, wait, e,
                )
                time.sleep(wait)
        raise last_err  # type: ignore[misc]