AlsuGibadullina commited on
Commit
71139d5
·
verified ·
1 Parent(s): 3e503c9

Update src/backends.py

Browse files
Files changed (1) hide show
  1. src/backends.py +141 -95
src/backends.py CHANGED
@@ -1,130 +1,176 @@
1
  import os
2
- import time
3
- import json
4
- import requests
5
  from dataclasses import dataclass
6
- from typing import Optional, Dict, Any, Protocol
 
 
7
 
8
- from huggingface_hub import InferenceClient
 
9
 
10
- # Local backend (optional)
11
- try:
12
- from transformers import AutoTokenizer, AutoModelForCausalLM
13
- import torch
14
- except Exception:
15
- AutoTokenizer = None
16
- AutoModelForCausalLM = None
17
- torch = None
18
 
19
 
20
class LLMBackend(Protocol):
    """Structural typing contract for pluggable text-generation backends."""

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Produce a completion for *prompt*, optionally steered by *system*."""
        ...
23
 
24
 
 
 
 
 
 
 
 
 
25
@dataclass
class HFInferenceAPIBackend:
    """Backend that talks to the HF Inference API through InferenceClient.

    Works well on Spaces for large models when an HF_TOKEN secret is set.
    """
    model_id: str
    token: Optional[str] = None
    timeout_s: int = 120

    def __post_init__(self):
        # Fall back to the standard HF token environment variables.
        self.token = self.token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.client = InferenceClient(model=self.model_id, token=self.token, timeout=self.timeout_s)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Try the chat endpoint first; fall back to plain text generation.

        InferenceClient adapts per model capabilities, so chat-tuned models
        take the chat path and anything else lands in text_generation.
        """
        sampling = {
            "temperature": float(params.get("temperature", 0.2)),
            "top_p": float(params.get("top_p", 0.95)),
        }
        max_new_tokens = int(params.get("max_new_tokens", 600))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))

        try:
            dialog = [{"role": "user", "content": prompt}]
            if system:
                dialog.insert(0, {"role": "system", "content": system})
            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=dialog,
                max_tokens=max_new_tokens,
                **sampling,
            )
            return resp.choices[0].message.content
        except Exception:
            # Fallback: raw text generation with the system text prepended.
            return self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                return_full_text=False,
                **sampling,
            )
75
 
76
 
77
@dataclass
class LocalTransformersBackend:
    """Backend that loads the model locally in the Space container.

    Use only small models unless you have a GPU Space and enough memory.
    """
    model_id: str
    device: str = "cpu"

    def __post_init__(self):
        # transformers/torch are optional imports at module level; fail loudly
        # here rather than with an obscure AttributeError later.
        if AutoTokenizer is None or AutoModelForCausalLM is None:
            raise RuntimeError("transformers/torch not available in this environment.")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        if torch is not None:
            self.model.to(self.device)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Sample a completion locally and strip the echoed prompt prefix."""
        combined = f"{system}\n\n{prompt}" if system else prompt

        encoded = self.tokenizer(combined, return_tensors="pt")
        if torch is not None:
            encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                do_sample=True,
                temperature=float(params.get("temperature", 0.2)),
                top_p=float(params.get("top_p", 0.95)),
                repetition_penalty=float(params.get("repetition_penalty", 1.05)),
                max_new_tokens=int(params.get("max_new_tokens", 600)),
            )

        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        # Heuristic: causal LMs echo the prompt; drop that prefix when present.
        if decoded.startswith(combined):
            decoded = decoded[len(combined):].lstrip()
        return decoded
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """Instantiate the backend selected by *backend_type* for *model_id*.

    Raises ValueError for an unrecognized backend type.
    """
    if backend_type == "local_transformers":
        # Auto-device could be added later; keep CPU by default for Spaces.
        return LocalTransformersBackend(model_id=model_id, device="cpu")
    if backend_type == "hf_inference_api":
        return HFInferenceAPIBackend(model_id=model_id)
    raise ValueError(f"Unknown backend: {backend_type}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import base64
3
+ import mimetypes
 
4
  from dataclasses import dataclass
5
+ from typing import Optional, Dict, Any, Protocol, Tuple
6
+
7
+ import requests
8
 
9
+ # OpenAI
10
+ from openai import OpenAI
11
 
12
+ # Gemini
13
+ from google import genai
14
+ from google.genai import types
 
 
 
 
 
15
 
16
 
17
class LLMBackend(Protocol):
    """Structural interface every provider backend must satisfy."""

    def generate(
        self,
        prompt: str,
        *,
        system: Optional[str],
        params: Dict[str, Any],
        image_path: Optional[str] = None,
    ) -> str:
        """Produce a completion; *image_path*, when given, attaches an image."""
        ...
27
 
28
 
29
def _file_to_data_url(path: str) -> Tuple[str, str]:
    """Encode the file at *path* as a data: URL; return (url, mime_type).

    Extensions that mimetypes cannot resolve default to image/png.
    """
    detected, _ = mimetypes.guess_type(path)
    detected = detected or "image/png"
    with open(path, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode("utf-8")
    return f"data:{detected};base64,{encoded}", detected
35
+
36
+
37
@dataclass
class OpenAIBackend:
    """OpenAI backend built on the Responses API.

    Falls back to the OPENAI_API_KEY environment variable when no key is
    passed explicitly; raises at construction time if neither is present.
    """
    model_id: str
    api_key: Optional[str] = None

    def __post_init__(self):
        self.api_key = self.api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise RuntimeError("OPENAI_API_KEY is not set (Spaces → Settings → Secrets).")
        self.client = OpenAI(api_key=self.api_key)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any], image_path: Optional[str] = None) -> str:
        """Run one generation call; an image (if any) is sent as a data URL."""
        sampling = {
            "temperature": float(params.get("temperature", 0.2)),
            "top_p": float(params.get("top_p", 0.95)),
            "max_output_tokens": int(params.get("max_new_tokens", 800)),
        }

        # The Responses API accepts image inputs as input_image content items.
        user_parts = [{"type": "input_text", "text": prompt}]
        if image_path:
            data_url, _ = _file_to_data_url(image_path)
            user_parts.append({"type": "input_image", "image_url": data_url})

        conversation = []
        if system:
            # System-style instructions travel under the "developer" role.
            conversation.append(
                {"role": "developer", "content": [{"type": "input_text", "text": system}]}
            )
        conversation.append({"role": "user", "content": user_parts})

        resp = self.client.responses.create(
            model=self.model_id,
            input=conversation,
            **sampling,
        )
        return resp.output_text
 
 
 
 
 
 
75
 
76
 
77
@dataclass
class GeminiBackend:
    """Google Gemini backend using the google-genai client.

    Falls back to the GEMINI_API_KEY environment variable when no key is
    passed explicitly; raises at construction time if neither is present.
    """
    model_id: str
    api_key: Optional[str] = None

    def __post_init__(self):
        self.api_key = self.api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise RuntimeError("GEMINI_API_KEY is not set (Spaces → Settings → Secrets).")
        self.client = genai.Client(api_key=self.api_key)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any], image_path: Optional[str] = None) -> str:
        """Run one generation call; image bytes (if any) are sent inline."""
        gen_config = types.GenerateContentConfig(
            temperature=float(params.get("temperature", 0.2)),
            top_p=float(params.get("top_p", 0.95)),
            max_output_tokens=int(params.get("max_new_tokens", 800)),
        )

        contents = []

        # Gemini takes inline image bytes via Part.from_bytes (official example).
        if image_path:
            guessed, _ = mimetypes.guess_type(image_path)
            with open(image_path, "rb") as handle:
                raw = handle.read()
            contents.append(types.Part.from_bytes(data=raw, mime_type=guessed or "image/png"))

        # No separate system slot is used here; prepend system text to the prompt.
        contents.append(f"{system}\n\n{prompt}" if system else prompt)

        resp = self.client.models.generate_content(
            model=self.model_id,
            contents=contents,
            config=gen_config,
        )
        return resp.text or ""
116
+
117
 
118
@dataclass
class DeepSeekBackend:
    """Text-only DeepSeek backend over the OpenAI-compatible HTTP API.

    Falls back to the DEEPSEEK_API_KEY environment variable when no key is
    passed explicitly; raises at construction time if neither is present.
    """
    model_id: str
    api_key: Optional[str] = None
    base_url: str = "https://api.deepseek.com"

    def __post_init__(self):
        self.api_key = self.api_key or os.getenv("DEEPSEEK_API_KEY")
        if not self.api_key:
            raise RuntimeError("DEEPSEEK_API_KEY is not set (Spaces → Settings → Secrets).")

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any], image_path: Optional[str] = None) -> str:
        """POST one chat completion; images are not supported by this provider.

        DeepSeek's official docs cover text chat completions only, so when an
        image is attached we warn the model inside the prompt instead.
        """
        if image_path:
            prompt = (
                "ВАЖНО: Пользователь приложил изображение (диаграмму), "
                "но этот провайдер в текущей конфигурации работает только с текстом. "
                "Попроси пользователя описать диаграмму текстом, либо продолжи только по тексту.\n\n"
                + prompt
            )

        chat = [{"role": "user", "content": prompt}]
        if system:
            chat.insert(0, {"role": "system", "content": system})

        payload = {
            "model": self.model_id,
            "messages": chat,
            "temperature": float(params.get("temperature", 0.2)),
            "top_p": float(params.get("top_p", 0.95)),
            "max_tokens": int(params.get("max_new_tokens", 800)),
            "stream": False,
        }
        resp = requests.post(
            f"{self.base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]
167
+
168
+
169
def make_backend(provider: str, model_id: str) -> LLMBackend:
    """Instantiate the backend matching *provider* for *model_id*.

    Raises ValueError for an unrecognized provider name.
    """
    if provider == "deepseek":
        return DeepSeekBackend(model_id=model_id)
    if provider == "gemini":
        return GeminiBackend(model_id=model_id)
    if provider == "openai":
        return OpenAIBackend(model_id=model_id)
    raise ValueError(f"Unknown provider: {provider}")