import os
from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol
from huggingface_hub import InferenceClient
try:
from PIL import Image
except Exception:
Image = None
class LLMBackend(Protocol):
    """Structural (duck-typed) interface every text-generation backend satisfies."""

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Return a completion for *prompt*; *system* optionally supplies a system message."""
        ...
@dataclass
class HFInferenceAPIBackend:
    """
    Uses HF Inference API via huggingface_hub.InferenceClient.
    Works well on Spaces if you provide HF_TOKEN in Secrets.
    """
    model_id: str                 # HF Hub repo id of the model to query
    token: Optional[str] = None   # API token; falls back to env vars in __post_init__
    timeout_s: int = 180          # per-request timeout (seconds) passed to InferenceClient

    def __post_init__(self):
        # Resolve the token from the environment so Spaces Secrets work
        # without explicit wiring (HF_TOKEN preferred, legacy name second).
        self.token = self.token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.client = InferenceClient(model=self.model_id, token=self.token, timeout=self.timeout_s)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """
        Generate a completion for ``prompt``.

        Tries the OpenAI-compatible chat endpoint first; on any failure
        (e.g. the model/provider does not expose a chat route) it degrades
        to the raw text-generation task with ``system`` prepended.

        ``params`` keys (all optional): ``temperature`` (default 0.2),
        ``max_new_tokens`` (600), ``top_p`` (0.95),
        ``repetition_penalty`` (1.05; text-generation fallback only —
        the chat endpoint does not accept it).
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))
        # Prefer chat when supported
        try:
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                top_p=top_p,
            )
            return resp.choices[0].message.content
        except Exception:
            # Deliberate best-effort fallback: text generation works for
            # models without a chat template / chat route.
            out = self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                return_full_text=False,
            )
            return out

    def image_to_text(self, image: "Image.Image") -> str:
        """
        HF task 'image-to-text' (captioning / OCR-like depending on model).

        Raises RuntimeError when Pillow is not installed.
        """
        if Image is None:
            raise RuntimeError("Pillow not installed")
        # InferenceClient.image_to_text accepts a path/URL/bytes/file-like,
        # not a PIL Image object -- serialize to PNG bytes first.
        import io
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        res = self.client.image_to_text(buf.getvalue())
        # huggingface_hub returns an object with generated_text
        return getattr(res, "generated_text", str(res))
def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """Factory: build the concrete LLMBackend named by ``backend_type``."""
    # Guard clause: reject unknown backend names up front.
    if backend_type != "hf_inference_api":
        raise ValueError(f"Unknown backend: {backend_type}")
    return HFInferenceAPIBackend(model_id=model_id)