AlsuGibadullina committed on
Commit
16ecc2a
·
verified ·
1 Parent(s): 70d4442

Update src/backends.py

Browse files
Files changed (1) hide show
  1. src/backends.py +27 -5
src/backends.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from dataclasses import dataclass
3
- from typing import Optional, Dict, Any, Union
 
4
  from huggingface_hub import InferenceClient
5
 
6
  try:
@@ -9,8 +10,17 @@ except Exception:
9
  Image = None
10
 
11
 
 
 
 
 
 
12
  @dataclass
13
  class HFInferenceAPIBackend:
 
 
 
 
14
  model_id: str
15
  token: Optional[str] = None
16
  timeout_s: int = 180
@@ -23,8 +33,9 @@ class HFInferenceAPIBackend:
23
  temperature = float(params.get("temperature", 0.2))
24
  max_new_tokens = int(params.get("max_new_tokens", 600))
25
  top_p = float(params.get("top_p", 0.95))
 
26
 
27
- # Chat when possible
28
  try:
29
  messages = []
30
  if system:
@@ -40,19 +51,30 @@ class HFInferenceAPIBackend:
40
  )
41
  return resp.choices[0].message.content
42
  except Exception:
 
43
  out = self.client.text_generation(
44
  prompt=(f"{system}\n\n{prompt}" if system else prompt),
45
  temperature=temperature,
46
  max_new_tokens=max_new_tokens,
47
  top_p=top_p,
 
48
  do_sample=True,
49
  return_full_text=False,
50
  )
51
  return out
52
 
53
- # --- NEW: image -> text (OCR / caption) ---
54
  def image_to_text(self, image: "Image.Image") -> str:
55
  """
56
- Uses HF task 'image-to-text' for models like TrOCR or BLIP-caption.
57
  """
58
- return self.client.image_to_text(image).generated_text
 
 
 
 
 
 
 
 
 
 
 
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol, runtime_checkable

from huggingface_hub import InferenceClient
6
 
7
  try:
 
10
  Image = None
11
 
12
 
13
@runtime_checkable
class LLMBackend(Protocol):
    """
    Structural interface every text-generation backend must satisfy.

    Any object exposing a matching ``generate`` method conforms (PEP 544
    structural typing).  ``@runtime_checkable`` additionally permits
    ``isinstance(obj, LLMBackend)`` checks at runtime (method presence
    only — signatures are not verified).
    """

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """
        Return the model's completion for *prompt*.

        Args:
            prompt: The user prompt text.
            system: Optional system instruction, or ``None`` for none.
            params: Backend-specific generation options (e.g. temperature,
                max_new_tokens, top_p — see HFInferenceAPIBackend).

        Returns:
            The generated text.
        """
        ...
17
+
18
  @dataclass
19
  class HFInferenceAPIBackend:
20
+ """
21
+ Uses HF Inference API via huggingface_hub.InferenceClient.
22
+ Works well on Spaces if you provide HF_TOKEN in Secrets.
23
+ """
24
  model_id: str
25
  token: Optional[str] = None
26
  timeout_s: int = 180
 
33
  temperature = float(params.get("temperature", 0.2))
34
  max_new_tokens = int(params.get("max_new_tokens", 600))
35
  top_p = float(params.get("top_p", 0.95))
36
+ repetition_penalty = float(params.get("repetition_penalty", 1.05))
37
 
38
+ # Prefer chat when supported
39
  try:
40
  messages = []
41
  if system:
 
51
  )
52
  return resp.choices[0].message.content
53
  except Exception:
54
+ # Fallback: text generation
55
  out = self.client.text_generation(
56
  prompt=(f"{system}\n\n{prompt}" if system else prompt),
57
  temperature=temperature,
58
  max_new_tokens=max_new_tokens,
59
  top_p=top_p,
60
+ repetition_penalty=repetition_penalty,
61
  do_sample=True,
62
  return_full_text=False,
63
  )
64
  return out
65
 
 
66
  def image_to_text(self, image: "Image.Image") -> str:
67
  """
68
+ HF task 'image-to-text' (captioning / OCR-like depending on model).
69
  """
70
+ if Image is None:
71
+ raise RuntimeError("Pillow not installed")
72
+ res = self.client.image_to_text(image)
73
+ # huggingface_hub returns an object with generated_text
74
+ return getattr(res, "generated_text", str(res))
75
+
76
+
77
def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """
    Factory: construct the LLMBackend implementation named by *backend_type*.

    Args:
        backend_type: Backend identifier; currently only "hf_inference_api".
        model_id: Hub model id passed through to the backend.

    Raises:
        ValueError: If *backend_type* is not a known backend identifier.
    """
    # Guard clause: reject unknown identifiers up front.
    if backend_type != "hf_inference_api":
        raise ValueError(f"Unknown backend: {backend_type}")
    return HFInferenceAPIBackend(model_id=model_id)