# Provenance (HF Hub scrape residue, kept as a comment so the file stays valid Python):
# uploaded by AlsuGibadullina — "Update src/backends.py", commit 16ecc2a (verified)
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol
from huggingface_hub import InferenceClient
# Pillow is optional: it is only required by HFInferenceAPIBackend.image_to_text().
try:
    from PIL import Image
except ImportError:
    # Narrowed from `except Exception`: only a missing/unimportable Pillow should
    # disable image support; any other error during import is a real bug and
    # must propagate instead of being silently masked.
    Image = None
class LLMBackend(Protocol):
    """Structural (duck-typed) interface that all LLM backends must satisfy."""
    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Return the model's text completion for *prompt*.

        Args:
            prompt: The user prompt to complete.
            system: Optional system instruction prepended to the conversation.
            params: Free-form generation parameters (e.g. temperature, max_new_tokens).
        """
        ...
@dataclass
class HFInferenceAPIBackend:
    """
    Uses HF Inference API via huggingface_hub.InferenceClient.
    Works well on Spaces if you provide HF_TOKEN in Secrets.

    Attributes:
        model_id: Hub repository id of the model to query.
        token: API token; if None, read from HF_TOKEN or
            HUGGINGFACEHUB_API_TOKEN environment variables.
        timeout_s: Request timeout in seconds passed to InferenceClient.
    """
    model_id: str
    token: Optional[str] = None
    timeout_s: int = 180
    def __post_init__(self):
        # Resolve the token from env vars when not given explicitly
        # (on Spaces these come from the Secrets panel).
        self.token = self.token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.client = InferenceClient(model=self.model_id, token=self.token, timeout=self.timeout_s)
    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Generate a completion, preferring the chat endpoint with a
        text-generation fallback.

        Args:
            prompt: User prompt.
            system: Optional system instruction (sent as a system message, or
                prepended to the prompt in the text-generation fallback).
            params: Optional overrides: temperature, max_new_tokens, top_p,
                repetition_penalty (the latter only applies to the fallback).

        Returns:
            The generated text (always a str, possibly empty).
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))
        # Prefer chat when supported
        try:
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                top_p=top_p,
            )
            # BUGFIX: message.content is Optional[str] in the chat-completion
            # schema; normalize None to "" so we honor the declared -> str.
            return resp.choices[0].message.content or ""
        except Exception:
            # Fallback: text generation for models without a chat template.
            # NOTE(review): this broad except also swallows auth/quota errors
            # on the chat path — the fallback will then fail too; consider
            # logging the first exception.
            out = self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                return_full_text=False,
            )
            return out
    def image_to_text(self, image: "Image.Image") -> str:
        """
        HF task 'image-to-text' (captioning / OCR-like depending on model).

        Raises:
            RuntimeError: If Pillow is not installed.
        """
        if Image is None:
            raise RuntimeError("Pillow not installed")
        res = self.client.image_to_text(image)
        # huggingface_hub returns an object with generated_text
        return getattr(res, "generated_text", str(res))
def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """Factory: construct the backend registered under *backend_type*.

    Args:
        backend_type: Backend registry key (currently only "hf_inference_api").
        model_id: Hub model id handed to the backend.

    Raises:
        ValueError: If *backend_type* is not a known backend.
    """
    builders = {
        "hf_inference_api": lambda: HFInferenceAPIBackend(model_id=model_id),
    }
    build = builders.get(backend_type)
    if build is None:
        raise ValueError(f"Unknown backend: {backend_type}")
    return build()