"""
Hugging Face Inference API Wrapper
This module provides a robust wrapper around the Hugging Face Inference API
with rate limiting, error handling, and support for various model types.
"""
import asyncio
import base64
import io
import logging
import time
from typing import BinaryIO, Dict, List, Optional, Union
from huggingface_hub import AsyncInferenceClient, InferenceClient
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class RateLimiter:
"""Simple rate limiter for API calls."""
    def __init__(self, max_calls: int = 60, time_window: int = 60):
        self.max_calls = max_calls
        self.time_window = time_window
        self.calls = []
        # Serialize acquire() so concurrent coroutines don't race on self.calls
        self._lock = asyncio.Lock()
    async def acquire(self):
        """Wait if the rate limit would otherwise be exceeded."""
        async with self._lock:
            now = time.time()
            # Drop calls that have aged out of the time window
            self.calls = [
                call_time
                for call_time in self.calls
                if now - call_time < self.time_window
            ]
            if len(self.calls) >= self.max_calls:
                # Wait until the oldest call leaves the window
                oldest_call = min(self.calls)
                wait_time = self.time_window - (now - oldest_call)
                if wait_time > 0:
                    logger.info(f"Rate limit reached, waiting {wait_time:.2f} seconds")
                    await asyncio.sleep(wait_time)
            # Record the time the call is actually made (after any sleep),
            # not the stale `now` captured before waiting
            self.calls.append(time.time())
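

# Hedged usage sketch for RateLimiter (names here are hypothetical): wrap any
# awaitable API call with acquire(); with max_calls=2 and time_window=1, at
# most two calls complete per second.
#
#     limiter = RateLimiter(max_calls=2, time_window=1)
#
#     async def fetch(i: int) -> int:
#         await limiter.acquire()
#         return i
#
#     async def main() -> list:
#         return await asyncio.gather(*(fetch(i) for i in range(5)))
#
#     asyncio.run(main())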
class HFInferenceWrapper:
"""
Wrapper for Hugging Face Inference API with rate limiting and error handling.
"""
    def __init__(self, api_key: Optional[str] = None, max_calls_per_minute: int = 60):
        # Keep the key around so the sync fallback clients below can be
        # constructed without reaching into the async client's internals
        self.api_key = api_key
        self.client = AsyncInferenceClient(token=api_key)
        self.rate_limiter = RateLimiter(max_calls=max_calls_per_minute, time_window=60)
async def text_generation(
self,
model: str,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
**kwargs,
) -> str:
"""Generate text using a language model.
Notes:
- Uses AsyncInferenceClient by default.
- Works around a known issue where `AsyncInferenceClient.text_generation`
may raise `StopIteration` ("coroutine raised StopIteration") by
falling back to the synchronous `InferenceClient` inside a thread.
- Automatically detects if a model supports conversational tasks and
uses chat_completion instead of text_generation.
- Always normalizes the result to a plain string, extracting
`generated_text` when the client returns a `TextGenerationOutput`
object.
"""
await self.rate_limiter.acquire()
try:
# Check if this is a conversational model that doesn't support text_generation
if self._is_conversational_model(model):
logger.info(f"Using chat_completion for conversational model: {model}")
return await self._chat_completion_fallback(
model, prompt, max_new_tokens, temperature, **kwargs
)
# Primary path: async client with text_generation
response = await self.client.text_generation(
prompt=prompt,
model=model,
max_new_tokens=max_new_tokens,
temperature=temperature,
**kwargs,
)
except Exception as e:
# Check if this is a model capability issue
if "not supported for task text-generation" in str(e):
logger.info(f"Falling back to chat_completion for model: {model}")
return await self._chat_completion_fallback(
model, prompt, max_new_tokens, temperature, **kwargs
)
# Newer versions of `huggingface_hub` sometimes surface a
# `RuntimeError` with message "coroutine raised StopIteration" from
# the async client. Detect that pattern (or a raw StopIteration)
# and fall back to the sync client in a background thread.
is_stop_iteration_like = isinstance(
e, StopIteration
) or "StopIteration" in str(e)
if is_stop_iteration_like: # pragma: no cover - defensive against HF bug
logger.warning(
"Async text_generation raised/contained StopIteration for "
"model %s; falling back to sync InferenceClient: %s",
model,
e,
)
def _call_sync() -> str:
"""Synchronous text-generation call for asyncio.to_thread."""
                    sync_client = InferenceClient(token=self.api_key)
# Check if this is a conversational model
if self._is_conversational_model(model):
messages = [{"role": "user", "content": prompt}]
chat_response = sync_client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_new_tokens,
temperature=temperature,
**kwargs,
)
return chat_response.choices[0].message.content
else:
return sync_client.text_generation(
prompt=prompt,
model=model,
max_new_tokens=max_new_tokens,
temperature=temperature,
**kwargs,
)
response = await asyncio.to_thread(_call_sync)
else:
logger.error(f"Text generation failed with model {model}: {e}")
raise
        # Normalize the various possible return types to a plain string
        if isinstance(response, str):
            return response
        # TextGenerationOutput (and similar dataclasses) expose `generated_text`;
        # getattr avoids importing huggingface_hub's private generated-types module
        generated_text = getattr(response, "generated_text", None)
        if generated_text is not None:
            return generated_text
        # Fallback: best-effort stringification
        return str(response)
def _is_conversational_model(self, model: str) -> bool:
"""Check if a model is primarily conversational (doesn't support text_generation)."""
conversational_models = [
"zai-org/GLM-4.6",
# Add other known conversational-only models here
]
return model in conversational_models
async def _chat_completion_fallback(
self,
model: str,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
**kwargs,
) -> str:
"""Fallback method using chat.completions for conversational models."""
messages = [{"role": "user", "content": prompt}]
try:
# Try async first
response = await self.client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_new_tokens,
temperature=temperature,
**kwargs,
)
return response.choices[0].message.content
except Exception as e:
logger.warning(f"Async chat_completion failed, falling back to sync: {e}")
# Fall back to sync if async fails
def _sync_chat_completion():
                sync_client = InferenceClient(token=self.api_key)
response = sync_client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_new_tokens,
temperature=temperature,
**kwargs,
)
return response.choices[0].message.content
return await asyncio.to_thread(_sync_chat_completion)
async def conversation(
self,
model: str,
messages: List[Dict[str, str]],
max_tokens: int = 512,
temperature: float = 0.7,
**kwargs,
) -> str:
"""Generate response in a conversation format."""
await self.rate_limiter.acquire()
try:
response = await self.client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
**kwargs,
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"Conversation failed with model {model}: {e}")
raise
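
    # Hedged usage sketch for conversation(): messages use the OpenAI-style
    # role/content dicts accepted by chat_completion endpoints. The model name
    # and prompts below are illustrative.
    #
    #     messages = [
    #         {"role": "system", "content": "You are terse."},
    #         {"role": "user", "content": "Explain rate limiting."},
    #     ]
    #     reply = await wrapper.conversation("zai-org/GLM-4.6", messages)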
async def image_generation(
self,
model: str,
prompt: str,
negative_prompt: Optional[str] = None,
width: int = 1024,
height: int = 1024,
**kwargs,
) -> bytes:
"""Generate an image and return as bytes."""
await self.rate_limiter.acquire()
try:
image_bytes = await self.client.text_to_image(
model=model,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
**kwargs,
)
return image_bytes
except Exception as e:
logger.error(f"Image generation failed with model {model}: {e}")
raise
async def text_to_speech(
self, model: str, text: str, voice: Optional[str] = None, **kwargs
) -> bytes:
"""Convert text to speech and return audio bytes.
Note: The voice parameter is kept for backwards compatibility but is not used
as the HuggingFace API doesn't support it.
"""
await self.rate_limiter.acquire()
try:
# HuggingFace text_to_speech API: text as first arg, model as kwarg
audio_bytes = await self.client.text_to_speech(text, model=model)
return audio_bytes
except Exception as e:
logger.error(f"TTS failed with model {model}: {e}")
raise
async def vision_analysis(
self, model: str, image: Union[bytes, BinaryIO], text: str, **kwargs
) -> str:
"""Analyze an image with a vision model."""
await self.rate_limiter.acquire()
try:
response = await self.client.image_to_text(
model=model, image=image, text=text, **kwargs
)
return response
except Exception as e:
logger.error(f"Vision analysis failed with model {model}: {e}")
raise
    async def save_audio_to_file(self, audio_bytes: bytes, output_path: str) -> bool:
        """Save audio bytes to a file, running the blocking write in a thread."""

        def _write() -> None:
            with open(output_path, "wb") as f:
                f.write(audio_bytes)

        try:
            await asyncio.to_thread(_write)
            logger.info(f"Audio saved to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to save audio to {output_path}: {e}")
            return False
def audio_bytes_to_base64(self, audio_bytes: bytes) -> str:
"""Convert audio bytes to base64 string for transmission."""
return base64.b64encode(audio_bytes).decode("utf-8")
def base64_to_audio_bytes(self, base64_str: str) -> bytes:
"""Convert base64 string back to audio bytes."""
return base64.b64decode(base64_str.encode("utf-8"))
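
    # Hedged round-trip sketch for the base64 helpers (payload is made up):
    #
    #     encoded = wrapper.audio_bytes_to_base64(b"\x00\x01audio")
    #     assert wrapper.base64_to_audio_bytes(encoded) == b"\x00\x01audio"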
class ModelConfig(BaseModel):
"""Configuration for different model types."""
text_models: List[str] = Field(
default_factory=lambda: [
# Primary general/text models
"zai-org/GLM-4.6",
"mistralai/Mistral-Nemo-Instruct-2407",
"Qwen/Qwen2.5-7B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
]
)
code_models: List[str] = Field(
default_factory=lambda: [
# Primary code-capable models
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
"zai-org/GLM-4.6",
"meta-llama/CodeLlama-70b-Instruct-hf",
# Kept last because it has caused auth issues in practice
"ZhipuAI/glm-4-9b-chat",
]
)
vision_models: List[str] = Field(
default_factory=lambda: [
"llava-hf/llava-v1.6-mistral-7b-hf",
"Salesforce/blip2-flan-t5-xxl",
"google/paligemma-3b-mix-448",
]
)
tts_models: List[str] = Field(
default_factory=lambda: [
"ResembleAI/chatterbox",
"suno/bark",
"facebook/mms-tts-all",
]
)
image_models: List[str] = Field(
default_factory=lambda: [
"stabilityai/stable-diffusion-3-medium",
"black-forest-labs/FLUX.1-dev",
"prompthero/openjourney",
]
)
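

# Hedged usage sketch for ModelConfig: every field has a default, so it can be
# instantiated bare and overridden per deployment.
#
#     config = ModelConfig()
#     primary = config.text_models[0]  # "zai-org/GLM-4.6"
#     custom = ModelConfig(text_models=["Qwen/Qwen2.5-7B-Instruct"])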
# Global instance factory
def get_hf_wrapper(api_key: Optional[str] = None) -> HFInferenceWrapper:
"""Get a configured HFInferenceWrapper instance."""
return HFInferenceWrapper(api_key=api_key)
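

# Minimal smoke test, assuming a valid token in the HF_TOKEN environment
# variable and that the first configured text model is reachable through the
# Inference API; both are assumptions, not guarantees made by this module.
if __name__ == "__main__":
    import os

    async def _demo() -> None:
        wrapper = get_hf_wrapper(api_key=os.environ.get("HF_TOKEN"))
        reply = await wrapper.text_generation(
            model=ModelConfig().text_models[0],
            prompt="Say hello in one short sentence.",
            max_new_tokens=32,
        )
        print(reply)

    asyncio.run(_demo())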