Huseyin committed
Commit 4ffecfe · verified · 1 parent: 548a299

Update handler.py

Files changed (1): handler.py (+218, −56)
handler.py CHANGED
@@ -1,68 +1,230 @@
-from typing import Dict, List, Any
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer


 class EndpointHandler:
     def __init__(self, path=""):
         """
-        Load the Qwen2 model
         """
-        # Load the tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-
-        # Load the model - optimize GPU memory usage
-        self.model = AutoModelForCausalLM.from_pretrained(
-            path,
-            torch_dtype=torch.bfloat16,  # for memory savings
-            device_map="auto",  # use the GPU automatically
-            trust_remote_code=True  # required for Qwen2
-        )
-
-        # Add a pad token if none is set
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token

     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
-        Inference endpoint
         """
-        # Read the inputs
-        inputs = data.pop("inputs", "")
-
-        # Read the parameters (with defaults)
-        parameters = data.pop("parameters", {})
-        max_new_tokens = parameters.get("max_new_tokens", 256)
-        temperature = parameters.get("temperature", 0.7)
-        top_p = parameters.get("top_p", 0.95)
-        do_sample = parameters.get("do_sample", True)
-        repetition_penalty = parameters.get("repetition_penalty", 1.1)
-
-        # Tokenize
-        input_ids = self.tokenizer(
-            inputs,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=2048
-        ).input_ids.to(self.model.device)
-
-        # Generate
-        with torch.no_grad():
-            outputs = self.model.generate(
-                input_ids,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                do_sample=do_sample,
-                repetition_penalty=repetition_penalty,
-                pad_token_id=self.tokenizer.pad_token_id,
-                eos_token_id=self.tokenizer.eos_token_id
-            )
-
-        # Keep only the newly generated tokens
-        generated_ids = outputs[0][input_ids.shape[-1]:]
-
-        # Decode
-        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-        return [{"generated_text": result}]
+from typing import Dict, List, Any, Optional
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import json
+import logging
+
+# Logging setup
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class EndpointHandler:
     def __init__(self, path=""):
         """
+        Load the Qwen2 7.6B model with memory and speed optimizations
         """
+        try:
+            logger.info(f"Loading model: {path}")
+
+            # Load the tokenizer - trust_remote_code is required for Qwen2
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                path,
+                trust_remote_code=True,
+                use_fast=True  # use the fast tokenizer
+            )
+
+            # Model configuration
+            model_kwargs = {
+                "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
+                "device_map": "auto",
+                "trust_remote_code": True,
+                "low_cpu_mem_usage": True,  # memory optimization
+            }
+
+            # Load the model
+            self.model = AutoModelForCausalLM.from_pretrained(
+                path,
+                **model_kwargs
+            )
+
+            # Put the model in eval mode
+            self.model.eval()
+
+            # Tokenizer settings
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            if self.tokenizer.pad_token_id is None:
+                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+            # Check whether the tokenizer ships a chat template
+            self.has_chat_template = hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template is not None

+            logger.info(f"Model loaded successfully. Chat template: {self.has_chat_template}")
+            logger.info(f"Device: {next(self.model.parameters()).device}")
+            logger.info(f"Dtype: {next(self.model.parameters()).dtype}")
+
+        except Exception as e:
+            logger.error(f"Model loading failed: {str(e)}")
+            raise RuntimeError(f"Model initialization failed: {str(e)}")
+
+    def format_chat_input(self, messages: List[Dict[str, str]]) -> str:
+        """
+        Turn chat-style messages into a single prompt string
+        """
+        if self.has_chat_template:
+            return self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+        else:
+            # Fallback: simple role-prefixed format
+            formatted = ""
+            for message in messages:
+                role = message.get("role", "user")
+                content = message.get("content", "")
+                if role == "system":
+                    formatted += f"System: {content}\n"
+                elif role == "user":
+                    formatted += f"User: {content}\n"
+                elif role == "assistant":
+                    formatted += f"Assistant: {content}\n"
+            formatted += "Assistant: "
+            return formatted
+
+    @torch.inference_mode()
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
+        Inference endpoint - supports both plain-text and chat inputs
         """
+        try:
+            # Determine the input type
+            inputs = data.pop("inputs", None)
+            messages = data.pop("messages", None)
+
+            # Validate the input
+            if not inputs and not messages:
+                return [{"error": "Either 'inputs' or 'messages' must be provided"}]
+
+            # Prefer the chat format when messages are given
+            if messages:
+                text_input = self.format_chat_input(messages)
+            else:
+                text_input = inputs
+
+            # Read the parameters
+            parameters = data.pop("parameters", {})
+
+            # Generation parameters
+            max_new_tokens = parameters.get("max_new_tokens", 256)
+            temperature = parameters.get("temperature", 0.7)
+            top_p = parameters.get("top_p", 0.9)
+            top_k = parameters.get("top_k", 50)
+            do_sample = parameters.get("do_sample", True)
+            repetition_penalty = parameters.get("repetition_penalty", 1.1)
+            num_return_sequences = parameters.get("num_return_sequences", 1)
+            stop_sequences = parameters.get("stop_sequences", None)
+
+            logger.info(f"Processing input (length: {len(text_input)})")
+
+            # Tokenize
+            inputs_encoded = self.tokenizer(
+                text_input,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=min(2048, self.model.config.max_position_embeddings),
+                return_attention_mask=True
+            )
+
+            # Move tensors to the model's device
+            input_ids = inputs_encoded["input_ids"].to(self.model.device)
+            attention_mask = inputs_encoded["attention_mask"].to(self.model.device)
+
+            # Collect stop-token ids from any requested stop sequences
+            stop_token_ids = []
+            if stop_sequences:
+                for seq in stop_sequences:
+                    tokens = self.tokenizer.encode(seq, add_special_tokens=False)
+                    stop_token_ids.extend(tokens)
+
+            # Build the generation kwargs
+            generation_kwargs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature if do_sample else 1.0,
+                "top_p": top_p if do_sample else 1.0,
+                "top_k": top_k if do_sample else None,
+                "do_sample": do_sample,
+                "repetition_penalty": repetition_penalty,
+                "num_return_sequences": num_return_sequences,
+                "pad_token_id": self.tokenizer.pad_token_id,
+                "eos_token_id": self.tokenizer.eos_token_id,
+                "use_cache": True,
+            }
+
+            # Add stop tokens, keeping the default EOS so generation can still end normally
+            if stop_token_ids:
+                generation_kwargs["eos_token_id"] = stop_token_ids + [self.tokenizer.eos_token_id]
+
+            # Generate
+            outputs = self.model.generate(**generation_kwargs)
+
+            # Decode
+            results = []
+            for output in outputs:
+                # Strip the prompt portion
+                generated_ids = output[input_ids.shape[-1]:]
+                generated_text = self.tokenizer.decode(
+                    generated_ids,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=True
+                )
+
+                results.append({
+                    "generated_text": generated_text,
+                    "details": {
+                        "finish_reason": "length" if len(generated_ids) >= max_new_tokens else "stop",
+                        "generated_tokens": len(generated_ids),
+                        "input_tokens": input_ids.shape[-1]
+                    }
+                })
+
+            logger.info(f"Generation completed. Generated {len(results)} sequences")
+
+            # Return the list directly when a single sequence was requested
+            if num_return_sequences == 1:
+                return results
+            else:
+                return [{"results": results}]
+
+        except torch.cuda.OutOfMemoryError:
+            logger.error("GPU out of memory!")
+            return [{
+                "error": "GPU out of memory. Try reducing max_new_tokens or input length",
+                "type": "memory_error"
+            }]
+        except Exception as e:
+            logger.error(f"Inference error: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+            return [{
+                "error": str(e),
+                "type": "inference_error",
+                "traceback": traceback.format_exc()
+            }]
+
+    def health_check(self) -> Dict[str, Any]:
+        """
+        Endpoint health check
+        """
+        try:
+            test_input = "Test"
+            inputs = self.tokenizer(test_input, return_tensors="pt")
+            with torch.no_grad():
+                _ = self.model.generate(
+                    inputs.input_ids.to(self.model.device),
+                    max_new_tokens=5
+                )
+            return {
+                "status": "healthy",
+                "model": "Qwen2-7.6B",
+                "device": str(next(self.model.parameters()).device),
+                "dtype": str(next(self.model.parameters()).dtype)
+            }
+        except Exception as e:
+            return {
+                "status": "unhealthy",
+                "error": str(e)
+            }
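
For reference, a minimal local smoke test of the updated handler might look like the sketch below. The checkpoint path "./qwen2-checkpoint" is a placeholder assumption, not part of this commit; any Qwen2 checkpoint directory (or Hub repo id) the endpoint is deployed with would take its place.

# Sketch: exercising the updated EndpointHandler locally (not part of the commit).
# "./qwen2-checkpoint" is a hypothetical path.
from handler import EndpointHandler

handler = EndpointHandler(path="./qwen2-checkpoint")

# Plain-text request, as the previous handler already accepted
print(handler({
    "inputs": "Write a haiku about autumn.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.7}
})[0]["generated_text"])

# Chat-style request, newly supported via the "messages" field
response = handler({
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain what a pad token is in one sentence."}
    ],
    "parameters": {"max_new_tokens": 128, "stop_sequences": ["\nUser:"]}
})
print(response[0]["generated_text"])
print(response[0]["details"])  # finish_reason, generated_tokens, input_tokens

# Readiness probe added in this commit
print(handler.health_check())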