AlsuGibadullina commited on
Commit
e1fbc11
·
verified ·
1 Parent(s): f435cf6

Create backends.py

Browse files
Files changed (1) hide show
  1. src/backends.py +130 -0
src/backends.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import time
import json
import requests
from dataclasses import dataclass
from typing import Optional, Dict, Any, Protocol

from huggingface_hub import InferenceClient

# Local backend (optional): transformers/torch may be absent in a slim,
# API-only deployment. Import failures are swallowed here and the names are
# set to None sentinels, so that merely importing this module never fails;
# LocalTransformersBackend checks the sentinels and raises at construction
# time instead.
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
except Exception:
    AutoTokenizer = None
    AutoModelForCausalLM = None
    torch = None
19
+
20
class LLMBackend(Protocol):
    """Structural (duck-typed) interface every text-generation backend satisfies."""

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Return generated text for *prompt*.

        Args:
            prompt: The user prompt to complete.
            system: Optional system instruction (sent as the system role, or
                prepended to the prompt, depending on the backend).
            params: Sampling options; implementations read keys such as
                "temperature", "max_new_tokens", "top_p", "repetition_penalty".
        """
        ...
23
+
24
+
25
@dataclass
class HFInferenceAPIBackend:
    """
    Backend that calls the HF Inference API via huggingface_hub.InferenceClient.

    Works well on Spaces for large models if you provide HF_TOKEN in Secrets.

    Attributes:
        model_id: Hugging Face Hub model repo id to query.
        token: API token; when omitted, falls back to the HF_TOKEN or
            HUGGINGFACEHUB_API_TOKEN environment variables.
        timeout_s: Per-request timeout in seconds.
    """
    model_id: str
    token: Optional[str] = None
    timeout_s: int = 120

    def __post_init__(self):
        # Resolve the token from the environment when not passed explicitly.
        self.token = self.token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.client = InferenceClient(model=self.model_id, token=self.token, timeout=self.timeout_s)

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Generate text for *prompt*, preferring the chat endpoint.

        Tries chat.completions first (for chat-tuned models) and falls back to
        plain text_generation when the chat call fails for any reason.
        NOTE: repetition_penalty is only honored on the text_generation
        fallback path; the chat API does not accept that parameter.
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))

        # Deliberately broad except: any chat failure (model has no chat
        # template, endpoint mismatch, transient error) triggers the
        # text_generation fallback; a genuine outage will surface there.
        try:
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})

            resp = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                top_p=top_p,
            )
            content = resp.choices[0].message.content
            # The API may return None content; honor the declared -> str
            # contract instead of leaking None to callers.
            return content if content is not None else ""
        except Exception:
            # Fallback: raw text generation with the system text prepended.
            out = self.client.text_generation(
                prompt=(f"{system}\n\n{prompt}" if system else prompt),
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                return_full_text=False,
            )
            return out
75
+
76
+
77
@dataclass
class LocalTransformersBackend:
    """
    Loads the model locally in the Space container.

    Use only small models unless you have a GPU Space and enough memory.

    Attributes:
        model_id: Hugging Face Hub model repo id to load.
        device: torch device string; defaults to "cpu".

    Raises:
        RuntimeError: at construction when transformers/torch could not be
            imported (the module-level optional-import sentinels are None).
    """
    model_id: str
    device: str = "cpu"

    def __post_init__(self):
        if AutoTokenizer is None or AutoModelForCausalLM is None:
            raise RuntimeError("transformers/torch not available in this environment.")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        if torch is not None:
            self.model.to(self.device)
        # Inference only: switch off dropout (and any other train-mode
        # behavior) so sampling is driven purely by the generation params.
        self.model.eval()

    def generate(self, prompt: str, *, system: Optional[str], params: Dict[str, Any]) -> str:
        """Generate a completion for *prompt* with the locally loaded model.

        The system text, when given, is prepended to the prompt since a plain
        causal LM has no chat roles.
        """
        temperature = float(params.get("temperature", 0.2))
        max_new_tokens = int(params.get("max_new_tokens", 600))
        top_p = float(params.get("top_p", 0.95))
        repetition_penalty = float(params.get("repetition_penalty", 1.05))

        full_prompt = (f"{system}\n\n{prompt}" if system else prompt)

        inputs = self.tokenizer(full_prompt, return_tensors="pt")
        if torch is not None:
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
            )

        # Decode only the newly generated tokens. Slicing by the input length
        # is more robust than stripping the decoded prompt as a string prefix,
        # which breaks whenever the tokenizer inserts special tokens (e.g. a
        # BOS token) so the decoded output no longer starts with the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        return text.lstrip()
122
+
123
+
124
def make_backend(backend_type: str, model_id: str) -> LLMBackend:
    """Build the backend matching *backend_type*.

    Args:
        backend_type: Either "hf_inference_api" or "local_transformers".
        model_id: Hugging Face Hub model repo id handed to the backend.

    Returns:
        A constructed backend implementing the LLMBackend protocol.

    Raises:
        ValueError: when *backend_type* names no known backend.
    """
    factories = {
        "hf_inference_api": lambda: HFInferenceAPIBackend(model_id=model_id),
        # auto-device for local; keep cpu by default
        "local_transformers": lambda: LocalTransformersBackend(model_id=model_id, device="cpu"),
    }
    factory = factories.get(backend_type)
    if factory is None:
        raise ValueError(f"Unknown backend: {backend_type}")
    return factory()