# trading-tools/utils/llm/hf_providers_client.py
"""HuggingFace Inference Providers client with LangChain integration.
This module provides a LangChain-compatible wrapper around HuggingFace's
Inference Providers API (router.huggingface.co), which provides unified
access to 22+ LLM providers including Groq, Together AI, Replicate, and more.
Features:
- Routing policies (:fastest, :cheapest, explicit provider selection)
- Token usage tracking from response metadata
- Cost tracking integration via callbacks
- Streaming and non-streaming inference
- Response caching for identical requests (LRU cache)
"""
import hashlib
import json
import logging
import os
from typing import Any, Dict, Iterator, List, Optional
from huggingface_hub import InferenceClient
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
logger = logging.getLogger(__name__)
# Response cache configuration
CACHE_ENABLED = os.getenv("HF_CACHE_ENABLED", "true").lower() == "true"
CACHE_MAX_SIZE = int(os.getenv("HF_CACHE_MAX_SIZE", "128"))
# Global response cache (shared across all instances)
_response_cache: Dict[str, Any] = {}
_cache_stats = {"hits": 0, "misses": 0, "size": 0}
class HFProvidersLLM(BaseChatModel):
"""LangChain-compatible wrapper for HuggingFace Inference Providers.
Provides unified access to 22+ LLM providers through HuggingFace's
router.huggingface.co endpoint with intelligent routing policies.
Attributes:
model: Model identifier or routing policy (:fastest, :cheapest)
client: HuggingFace InferenceClient instance
temperature: Sampling temperature (0.0-1.0)
max_tokens: Maximum tokens to generate
token: HuggingFace API token (reads from HF_TOKEN env var if not provided)
Example:
```python
# Use :cheapest routing policy
llm = HFProvidersLLM(model=":cheapest", temperature=0.7)
# Use specific provider
llm = HFProvidersLLM(model="meta-llama/Llama-3.3-70B-Instruct", temperature=0.7)
# Invoke with messages
response = llm.invoke([HumanMessage(content="What is 2+2?")])
```
"""
model: str = "meta-llama/Llama-3.3-70B-Instruct"
temperature: float = 0.7
max_tokens: Optional[int] = None
token: Optional[str] = None
client: Optional[InferenceClient] = None
enable_cache: bool = True
def __init__(
self,
model: str = "meta-llama/Llama-3.3-70B-Instruct",
temperature: float = 0.7,
max_tokens: Optional[int] = None,
token: Optional[str] = None,
enable_cache: bool = True,
**kwargs: Any,
):
"""Initialize HuggingFace Inference Providers client.
Args:
model: Model ID or routing policy (:fastest, :cheapest, or specific provider)
temperature: Sampling temperature (0.0-1.0)
max_tokens: Maximum tokens to generate (None = model default)
token: HuggingFace API token (defaults to HF_TOKEN env var)
enable_cache: Enable response caching for identical requests (default: True)
**kwargs: Additional parameters passed to BaseChatModel
Raises:
ValueError: If token is missing or temperature is out of range
"""
super().__init__(**kwargs)
self.model = model
self.temperature = self._validate_temperature(temperature)
self.max_tokens = max_tokens
        self.token = token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
self.enable_cache = enable_cache and CACHE_ENABLED
if not self.token:
raise ValueError(self._get_auth_error_message())
# Initialize InferenceClient with router endpoint
self.client = InferenceClient(token=self.token)
cache_status = "enabled" if self.enable_cache else "disabled"
logger.info(
f"Initialized HFProvidersLLM with model={model}, "
f"temperature={temperature}, max_tokens={max_tokens}, cache={cache_status}"
)
@staticmethod
def _validate_temperature(temperature: float) -> float:
"""Validate temperature is in valid range.
Args:
temperature: Sampling temperature
Returns:
Validated temperature
Raises:
ValueError: If temperature is out of range
"""
if not 0.0 <= temperature <= 1.0:
raise ValueError(
f"Temperature must be between 0.0 and 1.0, got {temperature}"
)
return temperature
@staticmethod
def _get_auth_error_message() -> str:
"""Get authentication error message.
Returns:
Formatted error message with setup instructions
"""
return (
"HuggingFace token required for Inference Providers.\n\n"
"To fix this:\n"
"1. Get your token from: https://huggingface.co/settings/tokens\n"
"2. Set environment variable: export HF_TOKEN=your_token_here\n"
"3. Or set HUGGINGFACE_API_KEY as alternative\n\n"
"Free tier available - no credit card required!"
)
@property
def _llm_type(self) -> str:
"""Return identifier for this LLM type."""
return "huggingface_inference_providers"
def _convert_messages_to_prompt(
self, messages: List[BaseMessage]
) -> List[Dict[str, str]]:
"""Convert LangChain messages to HuggingFace chat format.
Args:
messages: List of LangChain messages
Returns:
List of message dicts with 'role' and 'content' keys
"""
hf_messages = []
for message in messages:
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
else:
# Default to user for unknown message types
role = "user"
hf_messages.append({"role": role, "content": message.content})
return hf_messages
def _generate_cache_key(
self,
messages: List[BaseMessage],
temperature: float,
max_tokens: Optional[int],
) -> str:
"""Generate a cache key for the request.
Args:
messages: List of chat messages
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
Returns:
Hash string as cache key
"""
# Create a deterministic representation of the request
cache_data = {
"model": self.model,
"messages": [
{"role": msg.type, "content": msg.content} for msg in messages
],
"temperature": temperature,
"max_tokens": max_tokens,
}
# Generate SHA256 hash
cache_str = json.dumps(cache_data, sort_keys=True)
cache_key = hashlib.sha256(cache_str.encode()).hexdigest()
return cache_key
def _get_cached_response(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Get cached response if available.
Args:
cache_key: Cache key for the request
Returns:
Cached response dictionary or None if not found
"""
        global _response_cache, _cache_stats
        if cache_key in _response_cache:
            # Re-insert the entry so dict order reflects recency (true LRU behaviour)
            _response_cache[cache_key] = _response_cache.pop(cache_key)
            _cache_stats["hits"] += 1
            logger.info(
                f"✓ Cache HIT: {cache_key[:12]}... "
                f"(hits: {_cache_stats['hits']}, misses: {_cache_stats['misses']}, "
                f"hit_rate: {_cache_stats['hits'] / (_cache_stats['hits'] + _cache_stats['misses']) * 100:.1f}%)"
            )
            return _response_cache[cache_key]
_cache_stats["misses"] += 1
logger.debug(
f"Cache MISS: {cache_key[:12]}... "
f"(hits: {_cache_stats['hits']}, misses: {_cache_stats['misses']})"
)
return None
def _store_cached_response(
self, cache_key: str, response_data: Dict[str, Any]
) -> None:
"""Store response in cache.
Args:
cache_key: Cache key for the request
response_data: Response data to cache
"""
global _response_cache, _cache_stats
        # Evict when full: because cache hits re-insert entries at the end of the
        # dict, the first key is the least recently used one
        if len(_response_cache) >= CACHE_MAX_SIZE and cache_key not in _response_cache:
            oldest_key = next(iter(_response_cache))
            del _response_cache[oldest_key]
            logger.debug(f"Cache eviction: removed {oldest_key[:12]}...")
_response_cache[cache_key] = response_data
_cache_stats["size"] = len(_response_cache)
logger.debug(
f"Cached response: {cache_key[:12]}... (cache size: {_cache_stats['size']})"
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate chat completion using HuggingFace Inference Providers.
Args:
messages: List of chat messages
stop: Stop sequences (not supported by all providers)
run_manager: Callback manager for tracking
**kwargs: Additional generation parameters
Returns:
ChatResult with generated message and metadata
"""
# Prepare generation parameters
temperature = kwargs.get("temperature", self.temperature)
max_tokens = kwargs.get("max_tokens", self.max_tokens)
# Check cache if enabled (only for non-streaming, no stop sequences)
cache_key = None
if self.enable_cache and not stop:
cache_key = self._generate_cache_key(messages, temperature, max_tokens)
cached_response = self._get_cached_response(cache_key)
if cached_response:
# Return cached response
message = AIMessage(content=cached_response["content"])
generation = ChatGeneration(
message=message,
generation_info=cached_response["generation_info"],
)
return ChatResult(
generations=[generation],
llm_output=cached_response["llm_output"],
)
# Convert messages to HuggingFace format
hf_messages = self._convert_messages_to_prompt(messages)
# Prepare generation parameters
gen_kwargs = {
"model": self.model,
"messages": hf_messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
# Add stop sequences if provided
if stop:
gen_kwargs["stop"] = stop
try:
# Log the generation request
logger.info(
f"🚀 HF Inference Providers request: model={self.model}, "
f"temp={gen_kwargs.get('temperature')}, "
f"messages={len(hf_messages)}"
)
# Call HuggingFace Inference Providers
response = self.client.chat_completion(**gen_kwargs)
# Extract generated text
if hasattr(response, "choices") and len(response.choices) > 0:
content = response.choices[0].message.content
else:
raise ValueError(f"Unexpected response format: {response}")
# Extract token usage from response metadata
token_usage = {}
if hasattr(response, "usage"):
token_usage = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"completion_tokens": getattr(
response.usage, "completion_tokens", 0
),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
# Extract model used (may differ from requested if using routing policy)
model_used = self.model
if hasattr(response, "model"):
model_used = response.model
# Log routing decision if model differs (routing policy was used)
if model_used != self.model and self.model.startswith(":"):
logger.info(
f"🎯 Routing policy '{self.model}' selected model: {model_used} "
f"(tokens: {token_usage.get('total_tokens', 0)})"
)
elif model_used != self.model:
logger.warning(
f"⚠️ Model mismatch: requested={self.model}, used={model_used}"
)
# Create LangChain AIMessage
message = AIMessage(content=content)
# Create ChatGeneration with metadata
            generation = ChatGeneration(
                message=message,
                generation_info={
                    "model": model_used,
                    "token_usage": token_usage,
                    # response.choices was validated above, so read finish_reason directly
                    "finish_reason": getattr(
                        response.choices[0], "finish_reason", "stop"
                    ),
                },
            )
# Create ChatResult with usage metadata
llm_output = {
"model": model_used,
"token_usage": token_usage,
}
result = ChatResult(generations=[generation], llm_output=llm_output)
# Store in cache if enabled and cache_key was generated
if self.enable_cache and cache_key:
cached_data = {
"content": content,
"generation_info": generation.generation_info,
"llm_output": llm_output,
}
self._store_cached_response(cache_key, cached_data)
return result
except PermissionError as e:
error_msg = self._get_permission_error_message(e)
logger.error(error_msg)
raise PermissionError(error_msg) from e
except ValueError as e:
if "does not exist" in str(e).lower():
error_msg = self._get_model_not_found_message()
logger.error(error_msg)
raise ValueError(error_msg) from e
raise
except Exception as e:
error_msg = self._get_generic_error_message(e)
logger.error(error_msg)
raise RuntimeError(error_msg) from e
def _get_permission_error_message(self, error: Exception) -> str:
"""Get formatted permission error message.
Args:
error: Original permission error
Returns:
Formatted error message with troubleshooting steps
"""
return (
f"HuggingFace authentication failed: {error}\n\n"
"Please check your HF_TOKEN:\n"
"1. Verify token at: https://huggingface.co/settings/tokens\n"
"2. Ensure token has 'read' permissions\n"
"3. Set environment variable: export HF_TOKEN=your_token_here"
)
def _get_model_not_found_message(self) -> str:
"""Get formatted model not found error message.
Returns:
Formatted error message with alternatives
"""
return (
f"Model '{self.model}' not available via HuggingFace Inference Providers.\n\n"
"Try these alternatives:\n"
"- Use routing policy: ':fastest' or ':cheapest'\n"
"- Use specific model: 'meta-llama/Llama-3.3-70B-Instruct'\n"
"- Check available models at: https://huggingface.co/docs/api-inference/supported-models"
)
def _get_generic_error_message(self, error: Exception) -> str:
"""Get formatted generic error message.
Args:
error: Original error
Returns:
Formatted error message with troubleshooting steps
"""
return (
f"HuggingFace Inference Providers error (model={self.model}): {error}\n\n"
"Troubleshooting:\n"
"1. Check HF_TOKEN is set correctly\n"
"2. Verify internet connectivity\n"
"3. Try a different routing policy (:fastest, :cheapest)\n"
"4. Check HuggingFace status: https://status.huggingface.co/"
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
"""Stream chat completion using HuggingFace Inference Providers.
Args:
messages: List of chat messages
stop: Stop sequences (not supported by all providers)
run_manager: Callback manager for tracking
**kwargs: Additional generation parameters
Yields:
            ChatGenerationChunk objects as tokens are generated
"""
# Convert messages to HuggingFace format
hf_messages = self._convert_messages_to_prompt(messages)
# Prepare generation parameters
gen_kwargs = {
"model": self.model,
"messages": hf_messages,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"stream": True,
}
# Add stop sequences if provided
if stop:
gen_kwargs["stop"] = stop
try:
# Log the streaming request
logger.info(
f"🚀 HF Inference Providers streaming request: model={self.model}, "
f"temp={gen_kwargs.get('temperature')}, "
f"messages={len(hf_messages)}"
)
# Call HuggingFace Inference Providers with streaming
stream = self.client.chat_completion(**gen_kwargs)
            for chunk in stream:
                if hasattr(chunk, "choices") and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, "content") and delta.content:
                        # LangChain streaming expects chunk types, not full messages
                        gen_chunk = ChatGenerationChunk(
                            message=AIMessageChunk(content=delta.content),
                            generation_info={
                                "finish_reason": getattr(
                                    chunk.choices[0], "finish_reason", None
                                )
                            },
                        )
                        yield gen_chunk
                        # Trigger callback for streaming consumers
                        if run_manager:
                            run_manager.on_llm_new_token(
                                delta.content, chunk=gen_chunk
                            )
except Exception as e:
logger.error(
f"HuggingFace Inference Providers streaming error (model={self.model}): {e}"
)
raise
def get_cache_stats() -> Dict[str, Any]:
"""Get current cache statistics.
Returns:
Dictionary with cache statistics (hits, misses, hit_rate, size)
"""
global _cache_stats
total_requests = _cache_stats["hits"] + _cache_stats["misses"]
hit_rate = (
(_cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0.0
)
return {
"hits": _cache_stats["hits"],
"misses": _cache_stats["misses"],
"total_requests": total_requests,
"hit_rate_percent": round(hit_rate, 2),
"cache_size": _cache_stats["size"],
"cache_max_size": CACHE_MAX_SIZE,
"cache_enabled": CACHE_ENABLED,
}
def clear_cache() -> None:
"""Clear the response cache and reset statistics."""
global _response_cache, _cache_stats
_response_cache.clear()
_cache_stats = {"hits": 0, "misses": 0, "size": 0}
logger.info("Response cache cleared")