Update src/backends.py
Browse files- src/backends.py +27 -5
src/backends.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
from dataclasses import dataclass
|
| 3 |
-
from typing import Optional, Dict, Any,
|
|
|
|
| 4 |
from huggingface_hub import InferenceClient
|
| 5 |
|
| 6 |
try:
|
|
@@ -9,8 +10,17 @@ except Exception:
|
|
| 9 |
Image = None
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
@dataclass
|
| 13 |
class HFInferenceAPIBackend:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
model_id: str
|
| 15 |
token: Optional[str] = None
|
| 16 |
timeout_s: int = 180
|
|
@@ -23,8 +33,9 @@ class HFInferenceAPIBackend:
|
|
| 23 |
temperature = float(params.get("temperature", 0.2))
|
| 24 |
max_new_tokens = int(params.get("max_new_tokens", 600))
|
| 25 |
top_p = float(params.get("top_p", 0.95))
|
|
|
|
| 26 |
|
| 27 |
-
#
|
| 28 |
try:
|
| 29 |
messages = []
|
| 30 |
if system:
|
|
@@ -40,19 +51,30 @@ class HFInferenceAPIBackend:
|
|
| 40 |
)
|
| 41 |
return resp.choices[0].message.content
|
| 42 |
except Exception:
|
|
|
|
| 43 |
out = self.client.text_generation(
|
| 44 |
prompt=(f"{system}\n\n{prompt}" if system else prompt),
|
| 45 |
temperature=temperature,
|
| 46 |
max_new_tokens=max_new_tokens,
|
| 47 |
top_p=top_p,
|
|
|
|
| 48 |
do_sample=True,
|
| 49 |
return_full_text=False,
|
| 50 |
)
|
| 51 |
return out
|
| 52 |
|
| 53 |
-
# --- NEW: image -> text (OCR / caption) ---
|
| 54 |
def image_to_text(self, image: "Image.Image") -> str:
|
| 55 |
"""
|
| 56 |
-
|
| 57 |
"""
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol

from huggingface_hub import InferenceClient
|
| 6 |
|
| 7 |
try:
|
|
|
|
| 10 |
Image = None
|
| 11 |
|
| 12 |
|
| 13 |
+
class LLMBackend(Protocol):
    """
    Structural (duck-typed) interface for text-generation backends.

    Any object with a matching ``generate`` method satisfies this protocol;
    no inheritance is required.
    """

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """
        Return generated text for *prompt*.

        system: optional system/instruction prefix (may be None).
        params: backend-specific sampling options — NOTE(review): observed keys
        elsewhere in this file are temperature / max_new_tokens / top_p /
        repetition_penalty; confirm against each implementation.
        """
        ...
|
| 16 |
+
|
| 17 |
+
|
| 18 |
@dataclass
|
| 19 |
class HFInferenceAPIBackend:
|
| 20 |
+
"""
|
| 21 |
+
Uses HF Inference API via huggingface_hub.InferenceClient.
|
| 22 |
+
Works well on Spaces if you provide HF_TOKEN in Secrets.
|
| 23 |
+
"""
|
| 24 |
model_id: str
|
| 25 |
token: Optional[str] = None
|
| 26 |
timeout_s: int = 180
|
|
|
|
| 33 |
temperature = float(params.get("temperature", 0.2))
|
| 34 |
max_new_tokens = int(params.get("max_new_tokens", 600))
|
| 35 |
top_p = float(params.get("top_p", 0.95))
|
| 36 |
+
repetition_penalty = float(params.get("repetition_penalty", 1.05))
|
| 37 |
|
| 38 |
+
# Prefer chat when supported
|
| 39 |
try:
|
| 40 |
messages = []
|
| 41 |
if system:
|
|
|
|
| 51 |
)
|
| 52 |
return resp.choices[0].message.content
|
| 53 |
except Exception:
|
| 54 |
+
# Fallback: text generation
|
| 55 |
out = self.client.text_generation(
|
| 56 |
prompt=(f"{system}\n\n{prompt}" if system else prompt),
|
| 57 |
temperature=temperature,
|
| 58 |
max_new_tokens=max_new_tokens,
|
| 59 |
top_p=top_p,
|
| 60 |
+
repetition_penalty=repetition_penalty,
|
| 61 |
do_sample=True,
|
| 62 |
return_full_text=False,
|
| 63 |
)
|
| 64 |
return out
|
| 65 |
|
|
|
|
| 66 |
def image_to_text(self, image: "Image.Image") -> str:
|
| 67 |
"""
|
| 68 |
+
HF task 'image-to-text' (captioning / OCR-like depending on model).
|
| 69 |
"""
|
| 70 |
+
if Image is None:
|
| 71 |
+
raise RuntimeError("Pillow not installed")
|
| 72 |
+
res = self.client.image_to_text(image)
|
| 73 |
+
# huggingface_hub returns an object with generated_text
|
| 74 |
+
return getattr(res, "generated_text", str(res))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """
    Factory: build the backend instance named by *backend_type*.

    Parameters:
        backend_type: backend identifier; only "hf_inference_api" is supported.
        model_id: model repository id handed to the backend.

    Returns:
        A configured object satisfying the LLMBackend protocol.

    Raises:
        ValueError: if *backend_type* names no known backend.
    """
    if backend_type != "hf_inference_api":
        raise ValueError(f"Unknown backend: {backend_type}")
    return HFInferenceAPIBackend(model_id=model_id)
|