| """HuggingFace Inference Providers client with LangChain integration. | |
| This module provides a LangChain-compatible wrapper around HuggingFace's | |
| Inference Providers API (router.huggingface.co), which provides unified | |
| access to 22+ LLM providers including Groq, Together AI, Replicate, and more. | |
| Features: | |
| - Routing policies (:fastest, :cheapest, explicit provider selection) | |
| - Token usage tracking from response metadata | |
| - Cost tracking integration via callbacks | |
| - Streaming and non-streaming inference | |
| - Response caching for identical requests (LRU cache) | |
| """ | |

import hashlib
import json
import logging
import os
from typing import Any, Dict, Iterator, List, Optional

from huggingface_hub import InferenceClient
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

logger = logging.getLogger(__name__)

# Response cache configuration
CACHE_ENABLED = os.getenv("HF_CACHE_ENABLED", "true").lower() == "true"
CACHE_MAX_SIZE = int(os.getenv("HF_CACHE_MAX_SIZE", "128"))

# Global response cache (shared across all instances)
_response_cache: Dict[str, Any] = {}
_cache_stats = {"hits": 0, "misses": 0, "size": 0}
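
# NOTE: The cache is process-local and unsynchronized. Individual dict
# operations are safe under CPython's GIL, but the eviction and stats logic
# can race when the client is shared across threads; wrap cache access in a
# threading.Lock for multi-threaded use.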


class HFProvidersLLM(BaseChatModel):
    """LangChain-compatible wrapper for HuggingFace Inference Providers.

    Provides unified access to 22+ LLM providers through HuggingFace's
    router.huggingface.co endpoint with intelligent routing policies.

    Attributes:
        model: Model identifier or routing policy (:fastest, :cheapest)
        client: HuggingFace InferenceClient instance
        temperature: Sampling temperature (0.0-1.0)
        max_tokens: Maximum tokens to generate
        token: HuggingFace API token (reads from HF_TOKEN env var if not provided)

    Example:
        ```python
        # Use :cheapest routing policy
        llm = HFProvidersLLM(model=":cheapest", temperature=0.7)

        # Use a specific model
        llm = HFProvidersLLM(model="meta-llama/Llama-3.3-70B-Instruct", temperature=0.7)

        # Invoke with messages
        response = llm.invoke([HumanMessage(content="What is 2+2?")])
        ```
    """

    model: str = "meta-llama/Llama-3.3-70B-Instruct"
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    token: Optional[str] = None
    # Declared as Any so pydantic does not need a schema for InferenceClient
    client: Any = None
    enable_cache: bool = True

    def __init__(
        self,
        model: str = "meta-llama/Llama-3.3-70B-Instruct",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        token: Optional[str] = None,
        enable_cache: bool = True,
        **kwargs: Any,
    ):
        """Initialize HuggingFace Inference Providers client.

        Args:
            model: Model ID or routing policy (:fastest, :cheapest, or a specific model ID)
            temperature: Sampling temperature (0.0-1.0)
            max_tokens: Maximum tokens to generate (None = model default)
            token: HuggingFace API token (defaults to HF_TOKEN or HUGGINGFACE_API_KEY env var)
            enable_cache: Enable response caching for identical requests (default: True)
            **kwargs: Additional parameters passed to BaseChatModel

        Raises:
            ValueError: If token is missing or temperature is out of range
        """
        super().__init__(**kwargs)
        self.model = model
        self.temperature = self._validate_temperature(temperature)
        self.max_tokens = max_tokens
        # Fall back to HUGGINGFACE_API_KEY, as suggested in the auth error message
        self.token = token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
        self.enable_cache = enable_cache and CACHE_ENABLED

        if not self.token:
            raise ValueError(self._get_auth_error_message())

        # Initialize InferenceClient with router endpoint
        self.client = InferenceClient(token=self.token)

        cache_status = "enabled" if self.enable_cache else "disabled"
        logger.info(
            f"Initialized HFProvidersLLM with model={model}, "
            f"temperature={temperature}, max_tokens={max_tokens}, cache={cache_status}"
        )

    @staticmethod
    def _validate_temperature(temperature: float) -> float:
        """Validate that temperature is in the valid range.

        Args:
            temperature: Sampling temperature

        Returns:
            Validated temperature

        Raises:
            ValueError: If temperature is out of range
        """
        if not 0.0 <= temperature <= 1.0:
            raise ValueError(
                f"Temperature must be between 0.0 and 1.0, got {temperature}"
            )
        return temperature

    @staticmethod
    def _get_auth_error_message() -> str:
        """Get authentication error message.

        Returns:
            Formatted error message with setup instructions
        """
        return (
            "HuggingFace token required for Inference Providers.\n\n"
            "To fix this:\n"
            "1. Get your token from: https://huggingface.co/settings/tokens\n"
            "2. Set environment variable: export HF_TOKEN=your_token_here\n"
            "3. Or set HUGGINGFACE_API_KEY as an alternative\n\n"
            "Free tier available - no credit card required!"
        )

    @property
    def _llm_type(self) -> str:
        """Return identifier for this LLM type."""
        return "huggingface_inference_providers"

    def _convert_messages_to_prompt(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, str]]:
        """Convert LangChain messages to HuggingFace chat format.

        Args:
            messages: List of LangChain messages

        Returns:
            List of message dicts with 'role' and 'content' keys
        """
        hf_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                # Default to user for unknown message types
                role = "user"
            hf_messages.append({"role": role, "content": message.content})
        return hf_messages
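
    # Example of the conversion above (illustrative):
    #   [SystemMessage(content="Be brief."), HumanMessage(content="Hi")]
    #   -> [{"role": "system", "content": "Be brief."},
    #       {"role": "user", "content": "Hi"}]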

    def _generate_cache_key(
        self,
        messages: List[BaseMessage],
        temperature: float,
        max_tokens: Optional[int],
    ) -> str:
        """Generate a cache key for the request.

        Args:
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Hash string as cache key
        """
        # Create a deterministic representation of the request
        cache_data = {
            "model": self.model,
            "messages": [
                {"role": msg.type, "content": msg.content} for msg in messages
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        # Generate SHA-256 hash of the canonical JSON form
        cache_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(cache_str.encode()).hexdigest()

    def _get_cached_response(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Get cached response if available.

        Args:
            cache_key: Cache key for the request

        Returns:
            Cached response dictionary or None if not found
        """
        global _response_cache, _cache_stats
        if cache_key in _response_cache:
            # Re-insert the entry so dict insertion order tracks recency (LRU)
            _response_cache[cache_key] = _response_cache.pop(cache_key)
            _cache_stats["hits"] += 1
            logger.info(
                f"✓ Cache HIT: {cache_key[:12]}... "
                f"(hits: {_cache_stats['hits']}, misses: {_cache_stats['misses']}, "
                f"hit_rate: {_cache_stats['hits'] / (_cache_stats['hits'] + _cache_stats['misses']) * 100:.1f}%)"
            )
            return _response_cache[cache_key]

        _cache_stats["misses"] += 1
        logger.debug(
            f"Cache MISS: {cache_key[:12]}... "
            f"(hits: {_cache_stats['hits']}, misses: {_cache_stats['misses']})"
        )
        return None

    def _store_cached_response(
        self, cache_key: str, response_data: Dict[str, Any]
    ) -> None:
        """Store response in cache.

        Args:
            cache_key: Cache key for the request
            response_data: Response data to cache
        """
        global _response_cache, _cache_stats
        # If the cache is full, evict the least-recently-used entry. The first
        # key in insertion order is the LRU entry because cache hits re-insert
        # their key at the end.
        if len(_response_cache) >= CACHE_MAX_SIZE and cache_key not in _response_cache:
            oldest_key = next(iter(_response_cache))
            del _response_cache[oldest_key]
            logger.debug(f"Cache eviction: removed {oldest_key[:12]}...")

        _response_cache[cache_key] = response_data
        _cache_stats["size"] = len(_response_cache)
        logger.debug(
            f"Cached response: {cache_key[:12]}... (cache size: {_cache_stats['size']})"
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate chat completion using HuggingFace Inference Providers.

        Args:
            messages: List of chat messages
            stop: Stop sequences (not supported by all providers)
            run_manager: Callback manager for tracking
            **kwargs: Additional generation parameters

        Returns:
            ChatResult with generated message and metadata
        """
        # Resolve generation parameters (per-call kwargs override defaults)
        temperature = kwargs.get("temperature", self.temperature)
        max_tokens = kwargs.get("max_tokens", self.max_tokens)

        # Check cache if enabled (requests with stop sequences are not cached)
        cache_key = None
        if self.enable_cache and not stop:
            cache_key = self._generate_cache_key(messages, temperature, max_tokens)
            cached_response = self._get_cached_response(cache_key)
            if cached_response:
                # Rebuild the ChatResult from the cached payload
                message = AIMessage(content=cached_response["content"])
                generation = ChatGeneration(
                    message=message,
                    generation_info=cached_response["generation_info"],
                )
                return ChatResult(
                    generations=[generation],
                    llm_output=cached_response["llm_output"],
                )

        # Convert messages to HuggingFace format
        hf_messages = self._convert_messages_to_prompt(messages)

        # Prepare generation parameters
        gen_kwargs = {
            "model": self.model,
            "messages": hf_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        # Add stop sequences if provided
        if stop:
            gen_kwargs["stop"] = stop

        try:
            # Log the generation request
            logger.info(
                f"🚀 HF Inference Providers request: model={self.model}, "
                f"temp={gen_kwargs.get('temperature')}, "
                f"messages={len(hf_messages)}"
            )

            # Call HuggingFace Inference Providers
            response = self.client.chat_completion(**gen_kwargs)

            # Extract generated text
            if hasattr(response, "choices") and len(response.choices) > 0:
                content = response.choices[0].message.content
            else:
                raise ValueError(f"Unexpected response format: {response}")

            # Extract token usage from response metadata
            token_usage = {}
            if hasattr(response, "usage"):
                token_usage = {
                    "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(
                        response.usage, "completion_tokens", 0
                    ),
                    "total_tokens": getattr(response.usage, "total_tokens", 0),
                }

            # Extract model used (may differ from requested if using a routing policy)
            model_used = self.model
            if hasattr(response, "model"):
                model_used = response.model

            # Log routing decision if the model differs (routing policy was used)
            if model_used != self.model and self.model.startswith(":"):
                logger.info(
                    f"🎯 Routing policy '{self.model}' selected model: {model_used} "
                    f"(tokens: {token_usage.get('total_tokens', 0)})"
                )
            elif model_used != self.model:
                logger.warning(
                    f"⚠️ Model mismatch: requested={self.model}, used={model_used}"
                )

            # Create LangChain AIMessage
            message = AIMessage(content=content)

            # Create ChatGeneration with metadata; response.choices is known
            # to be non-empty at this point
            generation = ChatGeneration(
                message=message,
                generation_info={
                    "model": model_used,
                    "token_usage": token_usage,
                    "finish_reason": getattr(
                        response.choices[0], "finish_reason", "stop"
                    ),
                },
            )

            # Create ChatResult with usage metadata
            llm_output = {
                "model": model_used,
                "token_usage": token_usage,
            }
            result = ChatResult(generations=[generation], llm_output=llm_output)

            # Store in cache if enabled and a cache key was generated
            if self.enable_cache and cache_key:
                cached_data = {
                    "content": content,
                    "generation_info": generation.generation_info,
                    "llm_output": llm_output,
                }
                self._store_cached_response(cache_key, cached_data)

            return result

        except PermissionError as e:
            error_msg = self._get_permission_error_message(e)
            logger.error(error_msg)
            raise PermissionError(error_msg) from e
        except ValueError as e:
            if "does not exist" in str(e).lower():
                error_msg = self._get_model_not_found_message()
                logger.error(error_msg)
                raise ValueError(error_msg) from e
            raise
        except Exception as e:
            error_msg = self._get_generic_error_message(e)
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _get_permission_error_message(self, error: Exception) -> str:
        """Get formatted permission error message.

        Args:
            error: Original permission error

        Returns:
            Formatted error message with troubleshooting steps
        """
        return (
            f"HuggingFace authentication failed: {error}\n\n"
            "Please check your HF_TOKEN:\n"
            "1. Verify token at: https://huggingface.co/settings/tokens\n"
            "2. Ensure token has 'read' permissions\n"
            "3. Set environment variable: export HF_TOKEN=your_token_here"
        )

    def _get_model_not_found_message(self) -> str:
        """Get formatted model-not-found error message.

        Returns:
            Formatted error message with alternatives
        """
        return (
            f"Model '{self.model}' not available via HuggingFace Inference Providers.\n\n"
            "Try these alternatives:\n"
            "- Use routing policy: ':fastest' or ':cheapest'\n"
            "- Use specific model: 'meta-llama/Llama-3.3-70B-Instruct'\n"
            "- Check available models at: https://huggingface.co/docs/api-inference/supported-models"
        )

    def _get_generic_error_message(self, error: Exception) -> str:
        """Get formatted generic error message.

        Args:
            error: Original error

        Returns:
            Formatted error message with troubleshooting steps
        """
        return (
            f"HuggingFace Inference Providers error (model={self.model}): {error}\n\n"
            "Troubleshooting:\n"
            "1. Check HF_TOKEN is set correctly\n"
            "2. Verify internet connectivity\n"
            "3. Try a different routing policy (:fastest, :cheapest)\n"
            "4. Check HuggingFace status: https://status.huggingface.co/"
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat completion using HuggingFace Inference Providers.

        Args:
            messages: List of chat messages
            stop: Stop sequences (not supported by all providers)
            run_manager: Callback manager for tracking
            **kwargs: Additional generation parameters

        Yields:
            ChatGenerationChunk objects as tokens are generated
        """
        # Convert messages to HuggingFace format
        hf_messages = self._convert_messages_to_prompt(messages)

        # Prepare generation parameters
        gen_kwargs = {
            "model": self.model,
            "messages": hf_messages,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "stream": True,
        }

        # Add stop sequences if provided
        if stop:
            gen_kwargs["stop"] = stop

        try:
            # Log the streaming request
            logger.info(
                f"🚀 HF Inference Providers streaming request: model={self.model}, "
                f"temp={gen_kwargs.get('temperature')}, "
                f"messages={len(hf_messages)}"
            )

            # Call HuggingFace Inference Providers with streaming
            stream = self.client.chat_completion(**gen_kwargs)

            for chunk in stream:
                if hasattr(chunk, "choices") and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, "content") and delta.content:
                        # BaseChatModel._stream must yield ChatGenerationChunk
                        # (wrapping AIMessageChunk), not ChatGeneration
                        generation_chunk = ChatGenerationChunk(
                            message=AIMessageChunk(content=delta.content),
                            generation_info={
                                "finish_reason": getattr(
                                    chunk.choices[0], "finish_reason", None
                                )
                            },
                        )
                        # Trigger callback for streaming
                        if run_manager:
                            run_manager.on_llm_new_token(
                                delta.content, chunk=generation_chunk
                            )
                        yield generation_chunk

        except Exception as e:
            logger.error(
                f"HuggingFace Inference Providers streaming error (model={self.model}): {e}"
            )
            raise


def get_cache_stats() -> Dict[str, Any]:
    """Get current cache statistics.

    Returns:
        Dictionary with cache statistics (hits, misses, hit_rate, size)
    """
    total_requests = _cache_stats["hits"] + _cache_stats["misses"]
    hit_rate = (
        (_cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0.0
    )
    return {
        "hits": _cache_stats["hits"],
        "misses": _cache_stats["misses"],
        "total_requests": total_requests,
        "hit_rate_percent": round(hit_rate, 2),
        "cache_size": _cache_stats["size"],
        "cache_max_size": CACHE_MAX_SIZE,
        "cache_enabled": CACHE_ENABLED,
    }


def clear_cache() -> None:
    """Clear the response cache and reset statistics."""
    global _cache_stats
    _response_cache.clear()
    _cache_stats = {"hits": 0, "misses": 0, "size": 0}
    logger.info("Response cache cleared")