AlsuGibadullina committed on
Commit
8f36f3e
·
verified ·
1 Parent(s): c763c96

Update src/backends.py

Browse files
Files changed (1) hide show
  1. src/backends.py +40 -158
src/backends.py CHANGED
@@ -1,176 +1,58 @@
1
  import os
2
- import base64
3
- import mimetypes
4
  from dataclasses import dataclass
5
- from typing import Optional, Dict, Any, Protocol, Tuple
 
6
 
7
- import requests
8
-
9
- # OpenAI
10
- from openai import OpenAI
11
-
12
- # Gemini
13
- from google import genai
14
- from google.genai import types
15
-
16
-
17
class LLMBackend(Protocol):
    """Structural interface shared by every chat backend in this module."""

    def generate(
        self,
        prompt: str,
        *,
        system: Optional[str],
        params: Dict[str, Any],
        image_path: Optional[str] = None,
    ) -> str:
        """Produce a completion for *prompt*; *image_path* optionally attaches an image."""
        ...
27
-
28
-
29
- def _file_to_data_url(path: str) -> Tuple[str, str]:
30
- mime, _ = mimetypes.guess_type(path)
31
- mime = mime or "image/png"
32
- with open(path, "rb") as f:
33
- b64 = base64.b64encode(f.read()).decode("utf-8")
34
- return f"data:{mime};base64,{b64}", mime
35
-
36
-
37
@dataclass
class OpenAIBackend:
    """Backend that talks to the OpenAI Responses API (supports image input)."""

    model_id: str
    api_key: Optional[str] = None

    def __post_init__(self):
        # Fall back to the environment when no key was passed explicitly.
        self.api_key = self.api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise RuntimeError("OPENAI_API_KEY is not set (Spaces → Settings → Secrets).")
        self.client = OpenAI(api_key=self.api_key)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any], image_path: Optional[str] = None) -> str:
        """Run one completion; *image_path*, when given, is attached as an inline image."""
        sampling = {
            "temperature": float(params.get("temperature", 0.2)),
            "top_p": float(params.get("top_p", 0.95)),
            "max_output_tokens": int(params.get("max_new_tokens", 800)),
        }

        user_content = [{"type": "input_text", "text": prompt}]
        if image_path:
            # Responses API accepts images as data-URL `input_image` items.
            data_url, _ = _file_to_data_url(image_path)
            user_content.append({"type": "input_image", "image_url": data_url})

        conversation = []
        if system:
            # The Responses API uses the "developer" role for system-style steering.
            conversation.append(
                {"role": "developer", "content": [{"type": "input_text", "text": system}]}
            )
        conversation.append({"role": "user", "content": user_content})

        resp = self.client.responses.create(
            model=self.model_id,
            input=conversation,
            **sampling,
        )
        return resp.output_text
75
 
76
 
77
@dataclass
class GeminiBackend:
    """Backend that talks to Google Gemini via the google-genai SDK."""

    model_id: str
    api_key: Optional[str] = None

    def __post_init__(self):
        # Fall back to the environment when no key was passed explicitly.
        self.api_key = self.api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise RuntimeError("GEMINI_API_KEY is not set (Spaces → Settings → Secrets).")
        self.client = genai.Client(api_key=self.api_key)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any], image_path: Optional[str] = None) -> str:
        """Return Gemini's text reply; an optional image is sent as inline bytes."""
        contents = []

        # Inline image bytes via Part.from_bytes (the official SDK example pattern).
        if image_path:
            guessed, _ = mimetypes.guess_type(image_path)
            with open(image_path, "rb") as fh:
                contents.append(
                    types.Part.from_bytes(data=fh.read(), mime_type=guessed or "image/png")
                )

        # System text is simply prepended to the prompt rather than using the
        # SDK's system_instruction option — preserves the original behavior.
        contents.append(f"{system}\n\n{prompt}" if system else prompt)

        resp = self.client.models.generate_content(
            model=self.model_id,
            contents=contents,
            config=types.GenerateContentConfig(
                temperature=float(params.get("temperature", 0.2)),
                top_p=float(params.get("top_p", 0.95)),
                max_output_tokens=int(params.get("max_new_tokens", 800)),
            ),
        )
        # resp.text can be None (e.g. blocked responses); callers always get str.
        return resp.text or ""
116
-
117
-
118
@dataclass
class DeepSeekBackend:
    """Text-only backend that calls DeepSeek's OpenAI-compatible REST endpoint."""

    model_id: str
    api_key: Optional[str] = None
    base_url: str = "https://api.deepseek.com"

    def __post_init__(self):
        # Fall back to the environment when no key was passed explicitly.
        self.api_key = self.api_key or os.getenv("DEEPSEEK_API_KEY")
        if not self.api_key:
            raise RuntimeError("DEEPSEEK_API_KEY is not set (Spaces → Settings → Secrets).")

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any], image_path: Optional[str] = None) -> str:
        """POST a chat completion; images are unsupported, so the prompt is annotated instead."""
        if image_path:
            # Provider is text-only: prepend a (Russian) user-facing notice
            # rather than silently dropping the attachment.
            prompt = (
                "ВАЖНО: Пользователь приложил изображение (диаграмму), "
                "но этот провайдер в текущей конфигурации работает только с текстом. "
                "Попроси пользователя описать диаграмму текстом, либо продолжи только по тексту.\n\n"
                + prompt
            )

        messages = ([{"role": "system", "content": system}] if system else []) + [
            {"role": "user", "content": prompt}
        ]

        payload = {
            "model": self.model_id,
            "messages": messages,
            "temperature": float(params.get("temperature", 0.2)),
            "top_p": float(params.get("top_p", 0.95)),
            "max_tokens": int(params.get("max_new_tokens", 800)),
            "stream": False,
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        r = requests.post(
            f"{self.base_url}/chat/completions",
            headers=headers,
            json=payload,
            timeout=120,
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
167
-
168
-
169
def make_backend(provider: str, model_id: str) -> LLMBackend:
    """Factory mapping a provider key ('openai' / 'gemini' / 'deepseek') to a backend.

    Raises:
        ValueError: if *provider* is not one of the known keys.
    """
    if provider == "openai":
        backend_cls = OpenAIBackend
    elif provider == "gemini":
        backend_cls = GeminiBackend
    elif provider == "deepseek":
        backend_cls = DeepSeekBackend
    else:
        raise ValueError(f"Unknown provider: {provider}")
    return backend_cls(model_id=model_id)
 
1
  import os
 
 
2
  from dataclasses import dataclass
3
+ from typing import Optional, Dict, Any, Union
4
+ from huggingface_hub import InferenceClient
5
 
6
+ try:
7
+ from PIL import Image
8
+ except Exception:
9
+ Image = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
@dataclass
class HFInferenceAPIBackend:
    """Backend backed by the Hugging Face serverless Inference API.

    Tries the OpenAI-compatible chat-completion route first and falls back
    to raw text-generation for models that do not expose a chat endpoint.
    """

    model_id: str  # repo id of the model to query
    token: Optional[str] = None  # HF access token; resolved from env if omitted
    timeout_s: int = 180  # per-request timeout, seconds

    def __post_init__(self):
        # Resolve the token from the standard HF environment variables when
        # none was passed explicitly; a None token means anonymous access.
        self.token = self.token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.client = InferenceClient(model=self.model_id, token=self.token, timeout=self.timeout_s)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Return the model's completion for *prompt*.

        Args:
            prompt: user text to complete.
            system: optional system instruction (sent as a system message, or
                prepended to the prompt on the text-generation fallback path).
            params: may carry "temperature", "max_new_tokens" and "top_p";
                sensible defaults apply otherwise.

        Returns:
            The generated text; never ``None``.
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))

        # Build messages outside the try so the fallback only triggers on
        # actual API failures, not on message-assembly bugs.
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        # Prefer the chat endpoint; not every model serves it, so any failure
        # falls through to plain text-generation (deliberate best-effort).
        try:
            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                top_p=top_p,
            )
            # message.content may legitimately be None; keep the declared
            # `-> str` contract by never propagating None to callers.
            return resp.choices[0].message.content or ""
        except Exception:
            out = self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                do_sample=True,
                return_full_text=False,
            )
            return out

    # --- image -> text (OCR / caption) ---
    def image_to_text(self, image: "Image.Image") -> str:
        """Run the HF 'image-to-text' task (e.g. TrOCR, BLIP captioning) on *image*."""
        return self.client.image_to_text(image).generated_text