Update handler.py
Browse files - handler.py (+19 -20)
handler.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
-
PULSE ECG Handler — Demo Parity
|
| 4 |
- Demo app.py ile aynı üretim ayarları:
|
| 5 |
do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
|
| 6 |
- Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
|
| 7 |
- Görsel tensörü: .half() ve model cihazında
|
| 8 |
- Streamer: TextIteratorStreamer (demo gibi), thread ile generate
|
| 9 |
- Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import os
|
|
@@ -80,6 +81,13 @@ context_len = None
|
|
| 80 |
args = None
|
| 81 |
model_initialized = False
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
# ===================== Utilities =====================
|
| 85 |
|
|
@@ -140,7 +148,6 @@ def _wrap_image_token_if_needed(model_cfg) -> bool:
|
|
| 140 |
except Exception:
|
| 141 |
return False
|
| 142 |
|
| 143 |
-
|
| 144 |
# ====== Güvenli Stop Kriteri (demo eşleniği) ======
|
| 145 |
class SafeKeywordsStoppingCriteria(StoppingCriteria):
|
| 146 |
"""
|
|
@@ -153,7 +160,6 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
|
|
| 153 |
self.kw_ids = tok # shape: (n,)
|
| 154 |
|
| 155 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 156 |
-
# input_ids: (bsz, seq_len)
|
| 157 |
if input_ids is None or input_ids.shape[0] == 0:
|
| 158 |
return False
|
| 159 |
out = input_ids[0] # assume bsz=1
|
|
@@ -161,18 +167,15 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
|
|
| 161 |
if out.shape[0] < n:
|
| 162 |
return False
|
| 163 |
tail = out[-n:]
|
| 164 |
-
# cihaz hizası
|
| 165 |
kw = self.kw_ids.to(tail.device)
|
| 166 |
return torch.equal(tail, kw)
|
| 167 |
|
| 168 |
-
|
| 169 |
# ===================== Core Generation =====================
|
| 170 |
|
| 171 |
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
|
| 172 |
# demo gibi: <image> + text (IM_START/END gerekiyorsa sar)
|
| 173 |
use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
|
| 174 |
if use_wrap:
|
| 175 |
-
# <im_start><image><im_end>\n + user text
|
| 176 |
inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
|
| 177 |
else:
|
| 178 |
inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
|
|
@@ -195,7 +198,6 @@ def generate_response(
|
|
| 195 |
max_new_tokens: Optional[int] = None,
|
| 196 |
conv_mode_override: Optional[str] = None,
|
| 197 |
repetition_penalty: Optional[float] = None, # demo'da yok; verilirse 1.0 yaparız
|
| 198 |
-
# NOT: no_repeat_ngram_size / min_new_tokens / custom_stop KULLANMIYORUZ → demo-parite
|
| 199 |
det_seed: Optional[int] = None, # seed gönderilmezse stokastik (demo gibi)
|
| 200 |
):
|
| 201 |
if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
|
|
@@ -237,30 +239,31 @@ def generate_response(
|
|
| 237 |
|
| 238 |
# Cihaz/dtype
|
| 239 |
device = next(chatbot.model.parameters()).device
|
| 240 |
-
# demo
|
| 241 |
-
dtype = torch.float16
|
| 242 |
|
| 243 |
# Görüntü ön-işleme → tensör
|
| 244 |
try:
|
| 245 |
processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
|
| 246 |
-
# LLaVA genelde list döndürür
|
| 247 |
if isinstance(processed, (list, tuple)) and len(processed) > 0:
|
| 248 |
image_tensor = processed[0]
|
| 249 |
elif isinstance(processed, torch.Tensor):
|
| 250 |
-
image_tensor = processed[0] if processed.ndim == 4 else processed
|
| 251 |
else:
|
| 252 |
return {"error": "Image processing returned empty"}
|
| 253 |
if image_tensor.ndim == 3:
|
| 254 |
image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
|
| 255 |
-
# demo: half + device
|
| 256 |
-
image_tensor = image_tensor.to(device=device, dtype=dtype)
|
| 257 |
except Exception as e:
|
| 258 |
return {"error": f"Image processing failed: {e}"}
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
# Prompt & input ids
|
| 261 |
_, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
|
| 262 |
|
| 263 |
-
# Stop string
|
| 264 |
stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
|
| 265 |
stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
|
| 266 |
|
|
@@ -324,11 +327,10 @@ def generate_response(
|
|
| 324 |
|
| 325 |
return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
|
| 326 |
|
| 327 |
-
|
| 328 |
# ===================== Public API =====================
|
| 329 |
|
| 330 |
def query(payload: dict):
|
| 331 |
-
"""HF Endpoint entry (demo parity)."""
|
| 332 |
global model_initialized, tokenizer, model, image_processor, context_len, args
|
| 333 |
if not model_initialized:
|
| 334 |
if not initialize_model():
|
|
@@ -384,7 +386,6 @@ def get_model_info():
|
|
| 384 |
"device": str(next(model.parameters()).device) if model else "Unknown",
|
| 385 |
}
|
| 386 |
|
| 387 |
-
|
| 388 |
# ===================== Init & Session =====================
|
| 389 |
|
| 390 |
class _Args:
|
|
@@ -448,7 +449,6 @@ def initialize_model():
|
|
| 448 |
model_ = model_.to(torch.device("cuda"))
|
| 449 |
model_.eval()
|
| 450 |
|
| 451 |
-
# assign globals
|
| 452 |
globals()["tokenizer"] = tokenizer_
|
| 453 |
globals()["model"] = model_
|
| 454 |
globals()["image_processor"] = image_processor_
|
|
@@ -461,7 +461,6 @@ def initialize_model():
|
|
| 461 |
print(f"[init] failed: {e}")
|
| 462 |
return False
|
| 463 |
|
| 464 |
-
|
| 465 |
# ===================== HF EndpointHandler =====================
|
| 466 |
|
| 467 |
class EndpointHandler:
|
|
@@ -479,4 +478,4 @@ class EndpointHandler:
|
|
| 479 |
return get_model_info()
|
| 480 |
|
| 481 |
if __name__ == "__main__":
|
| 482 |
-
print("Handler ready (Demo Parity). Use `EndpointHandler` or `query`.")
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
+
PULSE ECG Handler — Demo Parity + Style Hint
|
| 4 |
- Demo app.py ile aynı üretim ayarları:
|
| 5 |
do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
|
| 6 |
- Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
|
| 7 |
- Görsel tensörü: .half() ve model cihazında
|
| 8 |
- Streamer: TextIteratorStreamer (demo gibi), thread ile generate
|
| 9 |
- Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
|
| 10 |
+
- STYLE_HINT: demo üslubuna (narratif + sonda tek satır structured impression) yaklaşmak için
|
| 11 |
"""
|
| 12 |
|
| 13 |
import os
|
|
|
|
| 81 |
args = None
|
| 82 |
model_initialized = False
|
| 83 |
|
| 84 |
+
# ====== Demo üslubuna yönlendiren stil ipucu ======
|
| 85 |
+
STYLE_HINT = (
|
| 86 |
+
"Write a concise diagnostic narrative as in a cardiology read: "
|
| 87 |
+
"use 2–3 short paragraphs describing rhythm, rate, axis, chamber enlargement, conduction, QRS, ST–T, QT; "
|
| 88 |
+
"then finish with a single final line starting exactly with 'Structured clinical impression:'. "
|
| 89 |
+
"Do not include recommendations, prognosis, follow-up, or risk counseling. No emojis or bullet points."
|
| 90 |
+
)
|
| 91 |
|
| 92 |
# ===================== Utilities =====================
|
| 93 |
|
|
|
|
| 148 |
except Exception:
|
| 149 |
return False
|
| 150 |
|
|
|
|
| 151 |
# ====== Güvenli Stop Kriteri (demo eşleniği) ======
|
| 152 |
class SafeKeywordsStoppingCriteria(StoppingCriteria):
|
| 153 |
"""
|
|
|
|
| 160 |
self.kw_ids = tok # shape: (n,)
|
| 161 |
|
| 162 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
|
| 163 |
if input_ids is None or input_ids.shape[0] == 0:
|
| 164 |
return False
|
| 165 |
out = input_ids[0] # assume bsz=1
|
|
|
|
| 167 |
if out.shape[0] < n:
|
| 168 |
return False
|
| 169 |
tail = out[-n:]
|
|
|
|
| 170 |
kw = self.kw_ids.to(tail.device)
|
| 171 |
return torch.equal(tail, kw)
|
| 172 |
|
|
|
|
| 173 |
# ===================== Core Generation =====================
|
| 174 |
|
| 175 |
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
|
| 176 |
# demo gibi: <image> + text (IM_START/END gerekiyorsa sar)
|
| 177 |
use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
|
| 178 |
if use_wrap:
|
|
|
|
| 179 |
inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
|
| 180 |
else:
|
| 181 |
inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
|
|
|
|
| 198 |
max_new_tokens: Optional[int] = None,
|
| 199 |
conv_mode_override: Optional[str] = None,
|
| 200 |
repetition_penalty: Optional[float] = None, # demo'da yok; verilirse 1.0 yaparız
|
|
|
|
| 201 |
det_seed: Optional[int] = None, # seed gönderilmezse stokastik (demo gibi)
|
| 202 |
):
|
| 203 |
if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
|
|
|
|
| 239 |
|
| 240 |
# Cihaz/dtype
|
| 241 |
device = next(chatbot.model.parameters()).device
|
| 242 |
+
dtype = torch.float16 # demo: half
|
|
|
|
| 243 |
|
| 244 |
# Görüntü ön-işleme → tensör
|
| 245 |
try:
|
| 246 |
processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
|
|
|
|
| 247 |
if isinstance(processed, (list, tuple)) and len(processed) > 0:
|
| 248 |
image_tensor = processed[0]
|
| 249 |
elif isinstance(processed, torch.Tensor):
|
| 250 |
+
image_tensor = processed[0] if processed.ndim == 4 else processed
|
| 251 |
else:
|
| 252 |
return {"error": "Image processing returned empty"}
|
| 253 |
if image_tensor.ndim == 3:
|
| 254 |
image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
|
| 255 |
+
image_tensor = image_tensor.to(device=device, dtype=dtype) # demo: half + device
|
|
|
|
| 256 |
except Exception as e:
|
| 257 |
return {"error": f"Image processing failed: {e}"}
|
| 258 |
|
| 259 |
+
# --------- STIL İPUCU EKLEME ---------
|
| 260 |
+
message_text = (message_text or "").strip() + "\n\n" + STYLE_HINT
|
| 261 |
+
# -------------------------------------
|
| 262 |
+
|
| 263 |
# Prompt & input ids
|
| 264 |
_, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
|
| 265 |
|
| 266 |
+
# Stop string (conv separator) → güvenli kriter
|
| 267 |
stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
|
| 268 |
stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
|
| 269 |
|
|
|
|
| 327 |
|
| 328 |
return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
|
| 329 |
|
|
|
|
| 330 |
# ===================== Public API =====================
|
| 331 |
|
| 332 |
def query(payload: dict):
|
| 333 |
+
"""HF Endpoint entry (demo parity + style hint)."""
|
| 334 |
global model_initialized, tokenizer, model, image_processor, context_len, args
|
| 335 |
if not model_initialized:
|
| 336 |
if not initialize_model():
|
|
|
|
| 386 |
"device": str(next(model.parameters()).device) if model else "Unknown",
|
| 387 |
}
|
| 388 |
|
|
|
|
| 389 |
# ===================== Init & Session =====================
|
| 390 |
|
| 391 |
class _Args:
|
|
|
|
| 449 |
model_ = model_.to(torch.device("cuda"))
|
| 450 |
model_.eval()
|
| 451 |
|
|
|
|
| 452 |
globals()["tokenizer"] = tokenizer_
|
| 453 |
globals()["model"] = model_
|
| 454 |
globals()["image_processor"] = image_processor_
|
|
|
|
| 461 |
print(f"[init] failed: {e}")
|
| 462 |
return False
|
| 463 |
|
|
|
|
| 464 |
# ===================== HF EndpointHandler =====================
|
| 465 |
|
| 466 |
class EndpointHandler:
|
|
|
|
| 478 |
return get_model_info()
|
| 479 |
|
| 480 |
if __name__ == "__main__":
|
| 481 |
+
print("Handler ready (Demo Parity + Style Hint). Use `EndpointHandler` or `query`.")
|