CanerDedeoglu committed on
Commit
2783652
·
verified ·
1 Parent(s): 0bdd613

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +71 -91
handler.py CHANGED
@@ -1,11 +1,32 @@
1
-
2
- import os, io, base64
3
  from typing import Any, Dict, List, Optional
4
 
5
  import torch
6
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # --- LLaVA (demo) parçaları ---
 
 
9
  from llava.model.builder import load_pretrained_model, get_model_name_from_path
10
  from llava.mm_utils import tokenizer_image_token, process_images
11
  from llava.constants import (
@@ -17,60 +38,39 @@ from llava.constants import (
17
  from llava.conversation import conv_templates
18
  from llava.utils import disable_torch_init
19
 
20
-
21
- # =========================
22
- # Ortam / Varsayılanlar
23
- # =========================
24
- # 1) Yerelden yüklemek için (bu repository içi): boş bırakın veya HF_MODEL_LOCAL_DIR=/repository
25
- HF_MODEL_LOCAL_DIR = os.getenv("HF_MODEL_LOCAL_DIR", "").strip()
26
- # 2) Hub'dan yüklemek isterseniz: HF_MODEL_ID=org/name
27
- HF_MODEL_ID = os.getenv("HF_MODEL_ID", "").strip()
28
-
29
- # Demo ile aynı conv_mode
30
- DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
31
- # Güvenli varsayılan (çok büyük tutmayalım)
32
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
33
-
34
- # Flash-Attention zorunluluğunu kaldır, SDPA kullan
35
- os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
36
-
37
 
38
  class EndpointHandler:
39
  """
40
- Hugging Face Inference Toolkit tarafından çağrılan handler.
41
- Girdi şeması (demo ile uyumlu):
42
  {
43
  "inputs": { "query": "...", "image": "<url|dataurl|path>" },
44
- "parameters": {
45
- "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
46
- "repetition_penalty": 1.0, "do_sample": false, "use_cache": true
47
- },
48
  "conv_mode": "llava_v2" # opsiyonel
49
  }
50
- Dönüş:
51
- [ { "generated_text": "..." } ]
52
  """
53
-
54
  def __init__(self, path: str = "") -> None:
55
- # path -> /repository
56
  disable_torch_init()
57
 
58
- # Modelin yüklenme yolu seçimi
59
- if HF_MODEL_LOCAL_DIR:
60
- model_path = HF_MODEL_LOCAL_DIR
61
- elif HF_MODEL_ID:
62
- model_path = HF_MODEL_ID
63
  else:
64
- # Ağırlıklar bu repoda ise
65
- model_path = path
66
 
67
- # Model adı (LLaVA yardımcı)
68
  self.model_name = get_model_name_from_path(model_path)
69
 
70
- # LLaVA yüklemesi (demo ile aynı giriş noktası)
71
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
72
  model_path=model_path,
73
- model_base=None, # LoRA yoksa None
74
  model_name=self.model_name,
75
  torch_dtype="auto",
76
  attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
@@ -78,89 +78,73 @@ class EndpointHandler:
78
  )
79
  self.model.eval()
80
 
81
- # Görsel token işaretleri (model config'ine bağlı)
82
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
83
  self.image_token = DEFAULT_IMAGE_TOKEN
84
  self.im_start = DEFAULT_IM_START_TOKEN
85
  self.im_end = DEFAULT_IM_END_TOKEN
86
 
87
- # ---------------------------
88
- # Yardımcılar
89
- # ---------------------------
90
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
91
- """URL / data URL / yerel path -> PIL.Image"""
92
  if not img_field:
93
  return None
94
  try:
95
  if img_field.startswith("data:image"):
96
- head, b64 = img_field.split(",", 1)
97
  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
98
- elif img_field.startswith("http://") or img_field.startswith("https://"):
99
- import requests
100
  r = requests.get(img_field, timeout=20)
101
  r.raise_for_status()
102
  return Image.open(io.BytesIO(r.content)).convert("RGB")
103
- else:
104
- return Image.open(img_field).convert("RGB")
105
  except Exception as e:
106
- # Görsel okunamadıysa açıklayıcı hata bırak
107
- raise RuntimeError(f"Image load failed: {e}") from e
 
108
 
109
  def _build_prompt(self, user_text: str, conv_mode: str) -> str:
110
- """Demodaki gibi conv_templates ile diyalog şablonu kur."""
111
- # Yanlış conv_mode gelirse default'a düş
112
  if conv_mode not in conv_templates:
113
  conv_mode = DEFAULT_CONV_MODE
114
-
115
  conv = conv_templates[conv_mode].copy()
116
  if self.use_im_start_end:
117
  content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
118
  else:
119
  content = f"{self.image_token}\n{user_text}"
120
-
121
- conv.append_message(conv.roles[0], content) # user
122
- conv.append_message(conv.roles[1], None) # assistant (boş)
123
  return conv.get_prompt()
124
 
125
- # ---------------------------
126
- # Inference giriş noktası
127
- # ---------------------------
128
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
129
  inputs = data.get("inputs") or {}
130
  params = data.get("parameters") or {}
131
  conv_mode_req = data.get("conv_mode")
132
- conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE
133
 
134
- query_text = inputs.get("query", "")
135
- image_f = inputs.get("image", "")
136
- pil_img = self._load_image(image_f) if image_f else None
137
 
138
- # 1) Prompt hazırla
139
  prompt = self._build_prompt(query_text, conv_mode)
140
 
141
- # 2) Görsel tensörü
142
  image_tensors = None
143
- if pil_img is not None:
144
- image_tensors = process_images([pil_img], self.image_processor, self.model.config)
 
 
 
145
 
146
- # 3) Tokenize (görüntü tokenını göm)
147
  input_ids = tokenizer_image_token(
148
- prompt,
149
- self.tokenizer,
150
- IMAGE_TOKEN_INDEX,
151
- return_tensors="pt",
152
- )
153
- input_ids = input_ids.to(self.model.device, non_blocking=True)
154
 
155
- # 4) context_len'e göre güvenli max_new_tokens
156
- requested_max_new = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
157
- # ufak tampon ile aşımı engelle
158
  avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
159
- max_new_tokens = max(1, min(requested_max_new, avail))
160
-
161
- # Görseli cihaza taşı
162
- if image_tensors is not None:
163
- image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
164
 
165
  gen_kwargs = {
166
  "max_new_tokens": max_new_tokens,
@@ -172,11 +156,7 @@ class EndpointHandler:
172
  }
173
 
174
  with torch.inference_mode():
175
- output_ids = self.model.generate(
176
- input_ids,
177
- images=image_tensors,
178
- **gen_kwargs,
179
- )
180
-
181
- outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
182
- return [{"generated_text": outputs}]
 
1
+ # -*- coding: utf-8 -*-
2
+ import os, io, sys, subprocess, base64
3
  from typing import Any, Dict, List, Optional
4
 
5
  import torch
6
  from PIL import Image
7
+ import requests
8
+
9
# ===== HF model id to use =====
MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")

# ===== Fetch the LLaVA source at runtime (no pip install) =====
LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")  # known-good, stable tag
LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/LLaVA")

def _ensure_llava() -> None:
    """Clone the LLaVA repo once and put it on sys.path so ``import llava`` works.

    Raises:
        RuntimeError: if the git clone fails. The original version piped git's
            stdout/stderr and discarded them, making clone failures opaque in
            the endpoint logs; we now surface git's own error text. A failed
            partial clone is also removed so the ``isdir`` check above does not
            mistake it for a good checkout on the next attempt.
    """
    if not os.path.isdir(LLAVA_SRC_DIR):
        os.makedirs(os.path.dirname(LLAVA_SRC_DIR), exist_ok=True)
        try:
            subprocess.run(
                ["git", "clone", "--depth", "1", "--branch", LLAVA_GIT_REF, LLAVA_GIT_URL, LLAVA_SRC_DIR],
                check=True, capture_output=True, text=True,
            )
        except subprocess.CalledProcessError as e:
            import shutil
            shutil.rmtree(LLAVA_SRC_DIR, ignore_errors=True)
            raise RuntimeError(f"git clone of LLaVA failed: {e.stderr}") from e
    if LLAVA_SRC_DIR not in sys.path:
        sys.path.insert(0, LLAVA_SRC_DIR)

_ensure_llava()
28
+
29
+ # ---- LLaVA parçaları (demo akışı) ----
30
  from llava.model.builder import load_pretrained_model, get_model_name_from_path
31
  from llava.mm_utils import tokenizer_image_token, process_images
32
  from llava.constants import (
 
38
  from llava.conversation import conv_templates
39
  from llava.utils import disable_torch_init
40
 
41
# Defaults (overridable via environment).
# NOTE(review): the fallback conv template name "llava_v2" must actually exist
# in llava.conversation.conv_templates for the pinned LLAVA_GIT_REF — verify;
# upstream LLaVA historically ships "llava_v1".
DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
# setdefault alone suffices: it only writes when the variable is unset, so the
# original's inner os.getenv("ATTN_IMPLEMENTATION", "sdpa") round-trip was
# redundant (it re-read the very variable setdefault was about to default).
os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
 
 
 
45
 
46
  class EndpointHandler:
47
  """
48
+ Girdi:
 
49
  {
50
  "inputs": { "query": "...", "image": "<url|dataurl|path>" },
51
+ "parameters": { "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
52
+ "repetition_penalty": 1.0, "do_sample": false, "use_cache": true },
 
 
53
  "conv_mode": "llava_v2" # opsiyonel
54
  }
55
+ Çıktı: [ { "generated_text": "..." } ]
 
56
  """
 
57
  def __init__(self, path: str = "") -> None:
 
58
  disable_torch_init()
59
 
60
+ # PULSE-7B HF’den/yerelden nereden yükleniyorsa yolu belirle
61
+ if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
62
+ model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
63
+ elif os.getenv("HF_MODEL_ID", "").strip():
64
+ model_path = os.getenv("HF_MODEL_ID").strip()
65
  else:
66
+ model_path = MODEL_ID # default: HF Hub PULSE-7B
 
67
 
 
68
  self.model_name = get_model_name_from_path(model_path)
69
 
70
+ # PULSE, LLaVA tabanlı olduğundan LLaVA loader ile yüklenir
71
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
72
  model_path=model_path,
73
+ model_base=None,
74
  model_name=self.model_name,
75
  torch_dtype="auto",
76
  attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
 
78
  )
79
  self.model.eval()
80
 
81
+ # Görsel token işaretleri (LLaVA config)
82
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
83
  self.image_token = DEFAULT_IMAGE_TOKEN
84
  self.im_start = DEFAULT_IM_START_TOKEN
85
  self.im_end = DEFAULT_IM_END_TOKEN
86
 
87
+ # ---- yardımcılar ----
 
 
88
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
89
+ """URL / base64 / path -> PIL.Image"""
90
  if not img_field:
91
  return None
92
  try:
93
  if img_field.startswith("data:image"):
94
+ _, b64 = img_field.split(",", 1)
95
  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
96
+ if img_field.startswith(("http://", "https://")):
 
97
  r = requests.get(img_field, timeout=20)
98
  r.raise_for_status()
99
  return Image.open(io.BytesIO(r.content)).convert("RGB")
100
+ return Image.open(img_field).convert("RGB")
 
101
  except Exception as e:
102
+ # Görsel opsiyoneldir; okunamazsa kullanıcıya hata dönmek yerine None bırakabiliriz.
103
+ print(f"[warn] image load failed: {e}")
104
+ return None
105
 
106
  def _build_prompt(self, user_text: str, conv_mode: str) -> str:
 
 
107
  if conv_mode not in conv_templates:
108
  conv_mode = DEFAULT_CONV_MODE
 
109
  conv = conv_templates[conv_mode].copy()
110
  if self.use_im_start_end:
111
  content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
112
  else:
113
  content = f"{self.image_token}\n{user_text}"
114
+ conv.append_message(conv.roles[0], content)
115
+ conv.append_message(conv.roles[1], None)
 
116
  return conv.get_prompt()
117
 
118
+ # ---- inference ----
 
 
119
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
120
  inputs = data.get("inputs") or {}
121
  params = data.get("parameters") or {}
122
  conv_mode_req = data.get("conv_mode")
 
123
 
124
+ conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE
125
+ query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
126
+ image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")
127
 
128
+ # 1) prompt
129
  prompt = self._build_prompt(query_text, conv_mode)
130
 
131
+ # 2) image -> tensor (opsiyonel)
132
  image_tensors = None
133
+ if image_f:
134
+ pil = self._load_image(image_f)
135
+ if pil is not None:
136
+ image_tensors = process_images([pil], self.image_processor, self.model.config)
137
+ image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
138
 
139
+ # 3) tokenize (image token’ı gömülü)
140
  input_ids = tokenizer_image_token(
141
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
142
+ ).to(self.model.device, non_blocking=True)
 
 
 
 
143
 
144
+ # 4) güvenli max_new_tokens
145
+ requested = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
 
146
  avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
147
+ max_new_tokens = max(1, min(requested, avail))
 
 
 
 
148
 
149
  gen_kwargs = {
150
  "max_new_tokens": max_new_tokens,
 
156
  }
157
 
158
  with torch.inference_mode():
159
+ output_ids = self.model.generate(input_ids, images=image_tensors, **gen_kwargs)
160
+
161
+ text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
162
+ return [{"generated_text": text}]