import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol

from huggingface_hub import InferenceClient

try:
    from PIL import Image
except Exception:
    # Pillow is optional; image_to_text() guards on this sentinel.
    Image = None


class LLMBackend(Protocol):
    """Structural interface any text-generation backend must satisfy."""

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str: ...


@dataclass
class HFInferenceAPIBackend:
    """
    Uses HF Inference API via huggingface_hub.InferenceClient.
    Works well on Spaces if you provide HF_TOKEN in Secrets.
    """

    model_id: str              # HF Hub model repo id, e.g. "org/model"
    token: Optional[str] = None  # API token; falls back to env vars in __post_init__
    timeout_s: int = 180       # per-request timeout passed to InferenceClient

    def __post_init__(self) -> None:
        # Fall back to the standard HF env vars when no explicit token is given.
        self.token = (
            self.token
            or os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        )
        self.client = InferenceClient(
            model=self.model_id, token=self.token, timeout=self.timeout_s
        )

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Generate text for *prompt*, preferring the chat endpoint.

        Args:
            prompt: User message / raw prompt.
            system: Optional system instruction (prepended for text-generation
                fallback, sent as a system message for chat).
            params: Sampling knobs; recognized keys: temperature,
                max_new_tokens, top_p, repetition_penalty.

        Returns:
            The generated text (empty string if the chat endpoint returned
            an empty completion).
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))

        # Prefer the chat-completions endpoint when the model supports it.
        try:
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                top_p=top_p,
            )
            # message.content may be None for an empty completion; coerce to
            # "" so the declared `-> str` contract always holds.
            return resp.choices[0].message.content or ""
        except Exception:
            # Fallback: raw text-generation for models without a chat endpoint.
            # NOTE(review): this broad except also swallows auth/network
            # errors from the chat attempt; the fallback call will surface
            # them if they persist.
            out = self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                return_full_text=False,
            )
            return out

    def image_to_text(self, image: "Image.Image") -> str:
        """
        HF task 'image-to-text' (captioning / OCR-like depending on model).

        Raises:
            RuntimeError: if Pillow is not installed.
        """
        if Image is None:
            raise RuntimeError("Pillow not installed")
        res = self.client.image_to_text(image)
        # huggingface_hub returns an object with generated_text
        return getattr(res, "generated_text", str(res))


def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """Factory mapping a backend-type string to a concrete LLMBackend.

    Raises:
        ValueError: for an unrecognized *backend_type*.
    """
    if backend_type == "hf_inference_api":
        return HFInferenceAPIBackend(model_id=model_id)
    raise ValueError(f"Unknown backend: {backend_type}")