CanerDedeoglu committed on
Commit
bd86c44
·
verified ·
1 Parent(s): d7ecbe4

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -51
handler.py CHANGED
@@ -1,10 +1,10 @@
1
  # -*- coding: utf-8 -*-
2
- # handler.py — Rapid_ECG / PULSE-7B — Lokal model yükleme + DEBUG
3
- # - YEREL model dizininden yükleme (HF Hub yok)
4
- # - HF Endpoint sözleşmesi (EndpointHandler(model_dir).load().__call__(inputs))
5
  # - {"inputs": {...}} sarmalaması destekli
6
- # - Sadece preprocess kullanır (process_images yok)
7
- # - Her kritik adımda [DEBUG] çıktı
8
 
9
  import os
10
  import io
@@ -17,7 +17,7 @@ import torch
17
  from PIL import Image
18
  import requests
19
 
20
- # ===== LLaVA kurulumu (kütüphane) =====
21
  def _ensure_llava(tag: str = "v1.2.0"):
22
  try:
23
  import llava # noqa
@@ -39,6 +39,7 @@ from llava.constants import (
39
  DEFAULT_IMAGE_TOKEN,
40
  DEFAULT_IM_START_TOKEN,
41
  DEFAULT_IM_END_TOKEN,
 
42
  )
43
  from llava.model.builder import load_pretrained_model
44
  from llava.mm_utils import tokenizer_image_token
@@ -119,11 +120,6 @@ def _build_prompt_with_image(prompt: str, model_cfg) -> str:
119
  return f"{token}\n{prompt}"
120
  return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
121
 
122
- def _require_files(dir_path: str, fnames: list):
123
- missing = [f for f in fnames if not os.path.exists(os.path.join(dir_path, f))]
124
- if missing:
125
- raise FileNotFoundError(f"[ERROR] Missing files in {dir_path}: {missing}")
126
-
127
  # ---------- Endpoint Handler ----------
128
  class EndpointHandler:
129
  def __init__(self, model_dir: Optional[str] = None):
@@ -137,53 +133,38 @@ class EndpointHandler:
137
  self.dtype = _pick_dtype(self.device)
138
  self.model_name = None
139
 
140
- def _resolve_local_model_dir(self) -> str:
141
- # Öncelik: HF_MODEL_DIR env → self.model_dir → /repository
142
- local = _get_env("HF_MODEL_DIR", None) or self.model_dir or "/repository"
143
- local = os.path.abspath(local)
144
- print(f"[DEBUG] resolved local model dir: {local}")
145
- if not os.path.isdir(local):
146
- raise FileNotFoundError(f"[ERROR] Local model directory not found: {local}")
147
- return local
148
-
149
  def load(self):
150
- # Yerel dizini çöz
151
- model_path = self._resolve_local_model_dir()
152
- # Bu dosyalar repo kökünde olmalı (sende var):
153
- # - model.safetensors.index.json + shard'lar
154
- # - tokenizer.model / tokenizer_config.json / config.json
155
- _require_files(model_path, [
156
- "config.json",
157
- "tokenizer_config.json",
158
- "tokenizer.model",
159
- "model.safetensors.index.json",
160
- ])
161
 
162
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
163
  os.environ.setdefault("FLASH_ATTENTION", "1")
164
 
165
- print(f"[DEBUG] calling load_pretrained_model from local path: {model_path}")
166
- try:
167
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
168
- model_path=model_path,
169
- model_base=None, # tam birleştirilmiş ağırlıklar için base yok
170
- load_8bit=False,
171
- load_4bit=False,
172
- device_map="auto",
173
- device=self.device,
174
- )
175
- except Exception as e:
176
- raise RuntimeError(f"[ERROR] load_pretrained_model failed at {model_path}: {e}")
177
-
178
  self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
179
  print(f"[DEBUG] model loaded: name={self.model_name}")
180
 
181
- # Vision tower kontrolü (yanlış model ise burada yakalar)
182
- vt = getattr(self.model.config, "vision_tower", None)
 
 
 
 
183
  if self.image_processor is None or vt is None:
184
  raise RuntimeError(
185
- f"[ERROR] Vision tower not loaded (vision_tower={vt}). "
186
- f"Bu dizin multimodal (LLaVA/PULSE) bir model içermiyor gibi görünüyor: {model_path}"
 
187
  )
188
 
189
  # tokenizer güvenliği
@@ -242,10 +223,10 @@ class EndpointHandler:
242
  conv.append_message(conv.roles[1], None)
243
  full_prompt = conv.get_prompt()
244
 
245
- # ---- tokenization
246
  try:
247
  input_ids = tokenizer_image_token(
248
- full_prompt, self.tokenizer, image_token_index=-200, return_tensors="pt"
249
  ).unsqueeze(0).to(self.device)
250
  except Exception:
251
  toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
@@ -270,7 +251,7 @@ class EndpointHandler:
270
  except Exception as e:
271
  return {"error": f"Generation failed: {e}"}
272
 
273
- # ---- decode
274
  new_tokens = gen_ids[0, input_ids.shape[1]:]
275
  text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
276
 
 
1
  # -*- coding: utf-8 -*-
2
+ # handler.py — Rapid_ECG / PULSE-7B — Stabil ve DEBUG'li sürüm (vision tower fix)
3
+ # - HuggingFace Endpoint uyumlu
4
+ # - Görsel sadece .preprocess() ile işlenir (process_images yok)
5
  # - {"inputs": {...}} sarmalaması destekli
6
+ # - Vision tower kontrolü: mm_vision_tower veya vision_tower
7
+ # - IMAGE_TOKEN_INDEX kullanımı
8
 
9
  import os
10
  import io
 
17
  from PIL import Image
18
  import requests
19
 
20
+ # ===== LLaVA kurulumu =====
21
  def _ensure_llava(tag: str = "v1.2.0"):
22
  try:
23
  import llava # noqa
 
39
  DEFAULT_IMAGE_TOKEN,
40
  DEFAULT_IM_START_TOKEN,
41
  DEFAULT_IM_END_TOKEN,
42
+ IMAGE_TOKEN_INDEX, # <-- sabit index
43
  )
44
  from llava.model.builder import load_pretrained_model
45
  from llava.mm_utils import tokenizer_image_token
 
120
  return f"{token}\n{prompt}"
121
  return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
122
 
 
 
 
 
 
123
  # ---------- Endpoint Handler ----------
124
  class EndpointHandler:
125
  def __init__(self, model_dir: Optional[str] = None):
 
133
  self.dtype = _pick_dtype(self.device)
134
  self.model_name = None
135
 
 
 
 
 
 
 
 
 
 
136
  def load(self):
137
+ # Model kaynağı HF_MODEL_ID ortam değişkeninden okunur (varsayılan: PULSE-ECG/PULSE-7B); yerel dizinden yükleme (ör. HF_MODEL_DIR) için burada ayrı bir mantık yoktur, gerekirse eklenmelidir
138
+ model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
139
+ model_base = _get_env("HF_MODEL_BASE", None)
140
+ print(f"[DEBUG] load(): HF_MODEL_ID={model_path}, HF_MODEL_BASE={model_base}")
 
 
 
 
 
 
 
141
 
142
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
143
  os.environ.setdefault("FLASH_ATTENTION", "1")
144
 
145
+ print("[DEBUG] calling load_pretrained_model ...")
146
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
147
+ model_path=model_path,
148
+ model_base=model_base,
149
+ load_8bit=False,
150
+ load_4bit=False,
151
+ device_map="auto",
152
+ device=self.device,
153
+ )
 
 
 
 
154
  self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
155
  print(f"[DEBUG] model loaded: name={self.model_name}")
156
 
157
+ # ---- Vision tower kontrolü: mm_vision_tower veya vision_tower
158
+ vt = (
159
+ getattr(self.model.config, "mm_vision_tower", None)
160
+ or getattr(self.model.config, "vision_tower", None)
161
+ )
162
+ print(f"[DEBUG] vision tower: {vt}")
163
  if self.image_processor is None or vt is None:
164
  raise RuntimeError(
165
+ "[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower None). "
166
+ "Bu model multimodal değil veya yanlış checkpoint yüklendi. "
167
+ "HF_MODEL_ID olarak PULSE/LLaVA tabanlı bir model verin (örn: 'PULSE-ECG/PULSE-7B')."
168
  )
169
 
170
  # tokenizer güvenliği
 
223
  conv.append_message(conv.roles[1], None)
224
  full_prompt = conv.get_prompt()
225
 
226
+ # ---- tokenization (IMAGE_TOKEN_INDEX ile)
227
  try:
228
  input_ids = tokenizer_image_token(
229
+ full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
230
  ).unsqueeze(0).to(self.device)
231
  except Exception:
232
  toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
 
251
  except Exception as e:
252
  return {"error": f"Generation failed: {e}"}
253
 
254
+ # ---- decode (sadece yeni tokenlar)
255
  new_tokens = gen_ids[0, input_ids.shape[1]:]
256
  text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
257