"""
Voice agent for Layer 2 conversations.
Uses Llama 3.1 8B Instruct to act as the customer support agent during
evaluation. Supports both local model and HF Inference API backends.
"""
from __future__ import annotations
import logging
import os
import time
from typing import Any
try:
from huggingface_hub import InferenceClient
except ImportError:
InferenceClient = None # type: ignore
logger = logging.getLogger(__name__)
class HFAgent:
    """
    Voice agent powered by Llama 3.1 8B.

    Takes a system prompt from Layer 1 and generates responses
    in the customer support conversation.

    Supports two backends:
    - local: loads model in-process via transformers (pass local_model=...)
    - api: uses HF Inference API (pass hf_token=...)
    """

    DEFAULT_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"

    # Substrings that mark a transient API failure worth retrying.
    _TRANSIENT_MARKERS = ("500", "502", "503", "504", "429", "timeout", "Timeout", "Time-out")

    def __init__(
        self,
        model_id: str | None = None,
        hf_token: str | None = None,
        max_tokens: int = 300,
        temperature: float = 0.3,
        local_model: Any = None,
    ):
        """
        Args:
            model_id: HF model repo id; defaults to DEFAULT_MODEL.
            hf_token: HF API token; falls back to the HF_TOKEN env var.
            max_tokens: per-response generation cap.
            temperature: sampling temperature.
            local_model: in-process model exposing .generate(messages, ...);
                when provided, the API client is never initialized.
        """
        self.model_id = model_id or self.DEFAULT_MODEL
        self.max_tokens = max_tokens
        self.temperature = temperature
        self._local_model = local_model
        self._client: Any = None
        if local_model is None:
            self.hf_token = hf_token or os.environ.get("HF_TOKEN")
            # Only build an API client when both a token and the hub library exist.
            if self.hf_token and InferenceClient is not None:
                self._client = InferenceClient(token=self.hf_token)

    @property
    def is_llm_available(self) -> bool:
        """True when at least one inference backend (local or API) is ready."""
        return self._local_model is not None or self._client is not None

    def __call__(
        self,
        system_prompt: str,
        conversation_history: list[dict[str, str]],
        observation: dict[str, Any],
        max_retries: int = 4,
    ) -> str:
        """
        Generate an agent response.

        Compatible with ConversationEnvironment.run_episode(agent_fn=...).

        Args:
            system_prompt: persona/instructions for the support agent.
            conversation_history: prior turns with roles "customer"/"agent";
                other roles are ignored.
            observation: env observation; a non-empty "customer_message" is
                appended as the latest user turn.
            max_retries: extra attempts after the first on transient API errors.

        Returns:
            The agent's reply text (empty string if the API returns no content).

        Raises:
            RuntimeError: when no backend is configured, or API credits run out.
        """
        messages = self._build_messages(system_prompt, conversation_history, observation)
        if self._local_model is not None:
            return self._local_model.generate(
                messages, max_tokens=self.max_tokens, temperature=self.temperature,
            )
        if self._client is None:
            raise RuntimeError(
                "No inference backend available. "
                "Pass local_model=... or set HF_TOKEN for API access."
            )
        return self._chat_with_retries(messages, max_retries)

    @staticmethod
    def _build_messages(
        system_prompt: str,
        conversation_history: list[dict[str, str]],
        observation: dict[str, Any],
    ) -> list[dict[str, str]]:
        """Map env-role history onto OpenAI-style chat messages."""
        messages = [{"role": "system", "content": system_prompt}]
        for msg in conversation_history:
            if msg["role"] == "customer":
                messages.append({"role": "user", "content": msg["content"]})
            elif msg["role"] == "agent":
                messages.append({"role": "assistant", "content": msg["content"]})
        # Add the latest customer message from observation, if any.
        customer_msg = observation.get("customer_message", "")
        if customer_msg:
            messages.append({"role": "user", "content": customer_msg})
        return messages

    def _chat_with_retries(self, messages: list[dict[str, str]], max_retries: int) -> str:
        """Call the HF Inference API, retrying transient failures with backoff.

        Raises RuntimeError on credit exhaustion (HTTP 402), re-raises
        non-transient errors immediately, and re-raises the last transient
        error once all attempts are exhausted.
        """
        last_err: Exception | None = None
        for attempt in range(max_retries + 1):
            try:
                response = self._client.chat_completion(
                    model=self.model_id,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                )
                # content may be None for an empty completion; normalize to "".
                return (response.choices[0].message.content or "").strip()
            except Exception as e:
                err_str = str(e)
                if "402" in err_str or "Payment Required" in err_str:
                    raise RuntimeError(
                        "HF API credits depleted. "
                        "Get more credits at https://huggingface.co/settings/billing"
                    ) from e
                if not any(code in err_str for code in self._TRANSIENT_MARKERS):
                    # Not a transient failure: surface it immediately.
                    raise
                last_err = e
                # BUGFIX: only back off when another attempt remains; the old
                # code slept (up to 32s) after the final attempt before raising.
                if attempt < max_retries:
                    wait = 2 ** (attempt + 1)  # 2, 4, 8, 16s
                    logger.warning(
                        "HF API error (attempt %d/%d), retrying in %ds: %s",
                        attempt + 1, max_retries + 1, wait, e,
                    )
                    time.sleep(wait)
        raise last_err  # type: ignore[misc]