File size: 2,810 Bytes
e1fbc11
 
16ecc2a
 
8f36f3e
71139d5
8f36f3e
 
 
 
e1fbc11
 
16ecc2a
 
 
 
 
e1fbc11
8f36f3e
16ecc2a
 
 
 
e1fbc11
8f36f3e
 
e1fbc11
 
8f36f3e
 
e1fbc11
8f36f3e
e1fbc11
8f36f3e
e1fbc11
16ecc2a
e1fbc11
16ecc2a
8f36f3e
 
 
 
 
71139d5
8f36f3e
 
 
e1fbc11
8f36f3e
e1fbc11
 
8f36f3e
 
16ecc2a
8f36f3e
 
 
 
 
16ecc2a
8f36f3e
 
71139d5
8f36f3e
 
 
 
16ecc2a
8f36f3e
16ecc2a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import io
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol

from huggingface_hub import InferenceClient

try:
    from PIL import Image
except Exception:
    Image = None


class LLMBackend(Protocol):
    """Structural (duck-typed) interface every LLM backend must satisfy."""

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Return generated text for *prompt*.

        Args:
            prompt: The user prompt to complete.
            system: Optional system instruction, or None.
            params: Sampling options (e.g. temperature, max_new_tokens).
        """
        ...


@dataclass
class HFInferenceAPIBackend:
    """
    LLM backend that calls the HF Inference API via huggingface_hub.InferenceClient.
    Works well on Spaces if you provide HF_TOKEN in Secrets.

    Attributes:
        model_id: Hub repo id of the model to query.
        token: API token; falls back to the HF_TOKEN / HUGGINGFACEHUB_API_TOKEN env vars.
        timeout_s: Request timeout in seconds passed to InferenceClient.
    """
    model_id: str
    token: Optional[str] = None
    timeout_s: int = 180

    def __post_init__(self):
        # Resolve the token from the environment so Spaces Secrets work
        # without explicit wiring.
        self.token = self.token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.client = InferenceClient(model=self.model_id, token=self.token, timeout=self.timeout_s)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """
        Generate text for *prompt*, preferring the chat-completions endpoint
        and falling back to plain text generation if chat fails.

        Args:
            prompt: User prompt.
            system: Optional system message; prepended to the prompt on the
                fallback path.
            params: Sampling knobs: temperature, max_new_tokens, top_p,
                repetition_penalty (the last is only honored on the
                text-generation fallback path).

        Returns:
            The generated text ("" if the chat endpoint returned no content).
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))

        # Prefer chat when supported by the deployed model.
        try:
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})

            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                top_p=top_p,
            )
            # content may legitimately be None (e.g. empty/tool-call replies);
            # never leak None through a `-> str` interface.
            return resp.choices[0].message.content or ""
        except Exception:
            # Fallback: raw text generation (models without a chat template).
            return self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                return_full_text=False,
            )

    def image_to_text(self, image: "Image.Image") -> str:
        """
        HF task 'image-to-text' (captioning / OCR-like depending on model).

        Args:
            image: A PIL image to caption.

        Raises:
            RuntimeError: If Pillow is not installed.
        """
        if Image is None:
            raise RuntimeError("Pillow not installed")
        # InferenceClient.image_to_text expects bytes / a path / a URL / a
        # file-like object — not a PIL Image. Serialize to PNG bytes first.
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        res = self.client.image_to_text(buf.getvalue())
        # huggingface_hub returns an object with generated_text
        return getattr(res, "generated_text", str(res))


def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """Factory: build the LLM backend named by *backend_type* for *model_id*.

    Raises:
        ValueError: If *backend_type* is not a known backend name.
    """
    # Guard clause: reject unknown backend names up front.
    if backend_type != "hf_inference_api":
        raise ValueError(f"Unknown backend: {backend_type}")
    return HFInferenceAPIBackend(model_id=model_id)