CanerDedeoglu commited on
Commit
77b64f3
·
verified ·
1 Parent(s): 7250ba8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +96 -63
handler.py CHANGED
@@ -1,135 +1,169 @@
1
  # /repository/handler.py
2
  import os, io, base64
3
  from typing import Any, Dict, List, Optional
 
4
  import torch
5
  from PIL import Image
6
 
7
- # ---- LLaVA demodaki modüller ----
8
- from llava.model.builder import load_pretrained_model
9
  from llava.mm_utils import tokenizer_image_token, process_images
10
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
11
- from llava.conversation import conv_templates, SeparatorStyle
 
 
 
 
 
12
  from llava.utils import disable_torch_init
13
- from llava.model.builder import get_model_name_from_path
14
 
15
- # Ortam değişkenleri (modeli nereden alacağımız)
16
- # 1) Yerel klasörden yüklemek istersen HF_MODEL_LOCAL_DIR kullan
17
- # 2) HF Hub repo id ile yüklemek istersen HF_MODEL_ID kullan
 
 
18
  HF_MODEL_LOCAL_DIR = os.getenv("HF_MODEL_LOCAL_DIR", "").strip()
19
- HF_MODEL_ID = os.getenv("HF_MODEL_ID", "").strip() # ör: "your-org/your-llava-model"
 
20
 
21
- DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2") # demo: llava_v2
22
- MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "4096"))
 
 
23
 
24
- # Flash-Attention yoksa SDPA güvenli yoldur
25
  os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
26
 
 
27
  class EndpointHandler:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def __init__(self, path: str = "") -> None:
29
- """
30
- path: /repository (endpoint bu klasörü model_dir olarak geçer)
31
- """
32
  disable_torch_init()
33
 
34
- # Model yolunu belirle
35
  if HF_MODEL_LOCAL_DIR:
36
  model_path = HF_MODEL_LOCAL_DIR
37
  elif HF_MODEL_ID:
38
  model_path = HF_MODEL_ID
39
  else:
40
- # Eğer ağırlık/konfig bu repository içindeyse path= "/repository"
41
  model_path = path
42
 
43
- # model adı (LLaVA utils)
44
  self.model_name = get_model_name_from_path(model_path)
45
 
46
- # LLaVA yükleme demo ile aynı giriş:
47
- # Dönüş: tokenizer, model, image_processor, context_len
48
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
49
  model_path=model_path,
50
- model_base=None, # LoRA vb. yoksa None
51
  model_name=self.model_name,
52
- torch_dtype="auto", # ortam GPU'ya göre seçsin
53
  attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
54
- device_map="auto"
55
  )
 
56
 
57
- # Görüntü başlangıç/bitiş tokenları (model sürümüne göre aktif)
58
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
59
  self.image_token = DEFAULT_IMAGE_TOKEN
60
  self.im_start = DEFAULT_IM_START_TOKEN
61
  self.im_end = DEFAULT_IM_END_TOKEN
62
 
63
- # ---- Yardımcılar ----
 
 
64
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
 
65
  if not img_field:
66
  return None
67
- if img_field.startswith("data:image"):
68
- head, b64 = img_field.split(",", 1)
69
- return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
70
- elif img_field.startswith("http://") or img_field.startswith("https://"):
71
- import requests
72
- r = requests.get(img_field, timeout=20)
73
- r.raise_for_status()
74
- return Image.open(io.BytesIO(r.content)).convert("RGB")
75
- else:
76
- # container içinden dosya okunacaksa
77
- return Image.open(img_field).convert("RGB")
 
 
 
78
 
79
  def _build_prompt(self, user_text: str, conv_mode: str) -> str:
80
- # Demo: conv_templates ile diyalog kur
 
 
 
 
81
  conv = conv_templates[conv_mode].copy()
82
  if self.use_im_start_end:
83
- # <im_start> <image> <im_end> + kullanıcı metni
84
  content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
85
  else:
86
  content = f"{self.image_token}\n{user_text}"
87
- conv.append_message(conv.roles[0], content)
88
- conv.append_message(conv.roles[1], None) # assistant boş, model dolduracak
 
89
  return conv.get_prompt()
90
 
91
- # ---- Inference giriş noktası ----
 
 
92
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
93
- """
94
- Beklenen giriş (demo ile uyumlu):
95
- {
96
- "inputs": { "query": "...", "image": "<url|dataurl|path>" },
97
- "parameters": { "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0, ... },
98
- "conv_mode": "llava_v2" # opsiyonel; yoksa varsayılanı kullanırız
99
- }
100
- """
101
  inputs = data.get("inputs") or {}
102
  params = data.get("parameters") or {}
103
- conv_mode = data.get("conv_mode") or DEFAULT_CONV_MODE
 
104
 
105
  query_text = inputs.get("query", "")
106
- image_f = inputs.get("image", "")
107
- pil_img = self._load_image(image_f)
108
 
109
- # 1) Prompt (conversation şablonu)
110
  prompt = self._build_prompt(query_text, conv_mode)
111
 
112
- # 2) Görsel tensörü (demo: process_images)
113
  image_tensors = None
114
  if pil_img is not None:
115
  image_tensors = process_images([pil_img], self.image_processor, self.model.config)
116
 
117
- # 3) Tokenize (görüntü tokenını metne göm)
118
  input_ids = tokenizer_image_token(
119
  prompt,
120
  self.tokenizer,
121
  IMAGE_TOKEN_INDEX,
122
- return_tensors="pt"
123
  )
124
-
125
- # 4) Cihaza taşı
126
  input_ids = input_ids.to(self.model.device, non_blocking=True)
 
 
 
 
 
 
 
 
127
  if image_tensors is not None:
128
  image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
129
 
130
- # 5) Generate (demo parametreleri)
131
  gen_kwargs = {
132
- "max_new_tokens": int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF)),
133
  "temperature": float(params.get("temperature", 0.0)),
134
  "top_p": float(params.get("top_p", 1.0)),
135
  "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
@@ -137,13 +171,12 @@ class EndpointHandler:
137
  "use_cache": bool(params.get("use_cache", True)),
138
  }
139
 
140
- with torch.no_grad():
141
  output_ids = self.model.generate(
142
  input_ids,
143
  images=image_tensors,
144
- **gen_kwargs
145
  )
146
 
147
- # 6) Decode (assistant’ın cevabı)
148
  outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
149
  return [{"generated_text": outputs}]
 
1
  # /repository/handler.py
2
  import os, io, base64
3
  from typing import Any, Dict, List, Optional
4
+
5
  import torch
6
  from PIL import Image
7
 
8
+ # --- LLaVA (demo) parçaları ---
9
+ from llava.model.builder import load_pretrained_model, get_model_name_from_path
10
  from llava.mm_utils import tokenizer_image_token, process_images
11
+ from llava.constants import (
12
+ IMAGE_TOKEN_INDEX,
13
+ DEFAULT_IMAGE_TOKEN,
14
+ DEFAULT_IM_START_TOKEN,
15
+ DEFAULT_IM_END_TOKEN,
16
+ )
17
+ from llava.conversation import conv_templates
18
  from llava.utils import disable_torch_init
 
19
 
20
+
21
+ # =========================
22
+ # Ortam / Varsayılanlar
23
+ # =========================
24
+ # 1) Yerelden yüklemek için (bu repository içi): boş bırakın veya HF_MODEL_LOCAL_DIR=/repository
25
  HF_MODEL_LOCAL_DIR = os.getenv("HF_MODEL_LOCAL_DIR", "").strip()
26
+ # 2) Hub'dan yüklemek isterseniz: HF_MODEL_ID=org/name
27
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "").strip()
28
 
29
+ # Demo ile aynı conv_mode
30
+ DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
31
+ # Güvenli varsayılan (çok büyük tutmayalım)
32
+ MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
33
 
34
+ # Flash-Attention zorunluluğunu kaldır, SDPA kullan
35
  os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
36
 
37
+
38
  class EndpointHandler:
39
+ """
40
+ Hugging Face Inference Toolkit tarafından çağrılan handler.
41
+ Girdi şeması (demo ile uyumlu):
42
+ {
43
+ "inputs": { "query": "...", "image": "<url|dataurl|path>" },
44
+ "parameters": {
45
+ "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
46
+ "repetition_penalty": 1.0, "do_sample": false, "use_cache": true
47
+ },
48
+ "conv_mode": "llava_v2" # opsiyonel
49
+ }
50
+ Dönüş:
51
+ [ { "generated_text": "..." } ]
52
+ """
53
+
54
  def __init__(self, path: str = "") -> None:
55
+ # path -> /repository
 
 
56
  disable_torch_init()
57
 
58
+ # Modelin yüklenme yolu seçimi
59
  if HF_MODEL_LOCAL_DIR:
60
  model_path = HF_MODEL_LOCAL_DIR
61
  elif HF_MODEL_ID:
62
  model_path = HF_MODEL_ID
63
  else:
64
+ # Ağırlıklar bu repoda ise
65
  model_path = path
66
 
67
+ # Model adı (LLaVA yardımcı)
68
  self.model_name = get_model_name_from_path(model_path)
69
 
70
+ # LLaVA yüklemesi (demo ile aynı giriş noktası)
 
71
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
72
  model_path=model_path,
73
+ model_base=None, # LoRA yoksa None
74
  model_name=self.model_name,
75
+ torch_dtype="auto",
76
  attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
77
+ device_map="auto",
78
  )
79
+ self.model.eval()
80
 
81
+ # Görsel token işaretleri (model config'ine bağlı)
82
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
83
  self.image_token = DEFAULT_IMAGE_TOKEN
84
  self.im_start = DEFAULT_IM_START_TOKEN
85
  self.im_end = DEFAULT_IM_END_TOKEN
86
 
87
+ # ---------------------------
88
+ # Yardımcılar
89
+ # ---------------------------
90
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
91
+ """URL / data URL / yerel path -> PIL.Image"""
92
  if not img_field:
93
  return None
94
+ try:
95
+ if img_field.startswith("data:image"):
96
+ head, b64 = img_field.split(",", 1)
97
+ return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
98
+ elif img_field.startswith("http://") or img_field.startswith("https://"):
99
+ import requests
100
+ r = requests.get(img_field, timeout=20)
101
+ r.raise_for_status()
102
+ return Image.open(io.BytesIO(r.content)).convert("RGB")
103
+ else:
104
+ return Image.open(img_field).convert("RGB")
105
+ except Exception as e:
106
+ # Görsel okunamadıysa açıklayıcı hata bırak
107
+ raise RuntimeError(f"Image load failed: {e}") from e
108
 
109
  def _build_prompt(self, user_text: str, conv_mode: str) -> str:
110
+ """Demodaki gibi conv_templates ile diyalog şablonu kur."""
111
+ # Yanlış conv_mode gelirse default'a düş
112
+ if conv_mode not in conv_templates:
113
+ conv_mode = DEFAULT_CONV_MODE
114
+
115
  conv = conv_templates[conv_mode].copy()
116
  if self.use_im_start_end:
 
117
  content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
118
  else:
119
  content = f"{self.image_token}\n{user_text}"
120
+
121
+ conv.append_message(conv.roles[0], content) # user
122
+ conv.append_message(conv.roles[1], None) # assistant (boş)
123
  return conv.get_prompt()
124
 
125
+ # ---------------------------
126
+ # Inference giriş noktası
127
+ # ---------------------------
128
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
129
  inputs = data.get("inputs") or {}
130
  params = data.get("parameters") or {}
131
+ conv_mode_req = data.get("conv_mode")
132
+ conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE
133
 
134
  query_text = inputs.get("query", "")
135
+ image_f = inputs.get("image", "")
136
+ pil_img = self._load_image(image_f) if image_f else None
137
 
138
+ # 1) Prompt hazırla
139
  prompt = self._build_prompt(query_text, conv_mode)
140
 
141
+ # 2) Görsel tensörü
142
  image_tensors = None
143
  if pil_img is not None:
144
  image_tensors = process_images([pil_img], self.image_processor, self.model.config)
145
 
146
+ # 3) Tokenize (görüntü tokenını göm)
147
  input_ids = tokenizer_image_token(
148
  prompt,
149
  self.tokenizer,
150
  IMAGE_TOKEN_INDEX,
151
+ return_tensors="pt",
152
  )
 
 
153
  input_ids = input_ids.to(self.model.device, non_blocking=True)
154
+
155
+ # 4) context_len'e göre güvenli max_new_tokens
156
+ requested_max_new = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
157
+ # ufak tampon ile aşımı engelle
158
+ avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
159
+ max_new_tokens = max(1, min(requested_max_new, avail))
160
+
161
+ # Görseli cihaza taşı
162
  if image_tensors is not None:
163
  image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
164
 
 
165
  gen_kwargs = {
166
+ "max_new_tokens": max_new_tokens,
167
  "temperature": float(params.get("temperature", 0.0)),
168
  "top_p": float(params.get("top_p", 1.0)),
169
  "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
 
171
  "use_cache": bool(params.get("use_cache", True)),
172
  }
173
 
174
+ with torch.inference_mode():
175
  output_ids = self.model.generate(
176
  input_ids,
177
  images=image_tensors,
178
+ **gen_kwargs,
179
  )
180
 
 
181
  outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
182
  return [{"generated_text": outputs}]