Rapid_ECG / handler.py
CanerDedeoglu's picture
Update handler.py
05ae2ff verified
raw
history blame
18 kB
# -*- coding: utf-8 -*-
"""
PULSE ECG Handler (demo-like streaming)
- TextIteratorStreamer + skip_prompt=True (no manual slicing of the output; Step 1 is preserved)
- do_sample=True (demo behavior); temperature/top_p are taken from the payload
- Optional: no_stop, custom_stop, no_repeat_ngram_size, min_new_tokens
- IM_START/END handled automatically; 3D/4D/5D image tensors supported; device/dtype matching
"""
import os
import json
import base64
import hashlib
import datetime
from io import BytesIO
from threading import Thread
from typing import Optional, List
import torch
from PIL import Image
import requests
# --- LLaVA / Transformers ---
# LLaVA is treated as an optional dependency: any failure while importing it
# is recorded in LLAVA_AVAILABLE (and printed) instead of aborting the import
# of this module.
try:
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )
    from llava.utils import disable_torch_init
    LLAVA_AVAILABLE = True
except Exception as e:
    LLAVA_AVAILABLE = False
    print(f"[WARN] LLaVA modules not available: {e}")
# The streamer class from transformers is guarded the same way; its
# availability is tracked separately in TRANSFORMERS_AVAILABLE.
try:
    from transformers import TextIteratorStreamer
    TRANSFORMERS_AVAILABLE = True
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"[WARN] transformers not available: {e}")
# --- HF Hub (optional logging) ---
try:
    from huggingface_hub import HfApi, login
    HF_HUB_AVAILABLE = True
except Exception:
    HF_HUB_AVAILABLE = False
# Upload client and target repo id. Both keep these "disabled" values unless
# HF_TOKEN is set and the login below succeeds; _safe_upload checks them
# before uploading anything.
api = None
repo_name = ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        # NOTE(review): `write_permission` may be deprecated/removed in newer
        # huggingface_hub releases -- confirm against the pinned version. A
        # failure here is caught below and simply disables log uploads.
        login(token=os.environ["HF_TOKEN"], write_permission=True)
        api = HfApi()
        # An empty LOG_REPO leaves uploads disabled even after a good login.
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"[HF Hub] init failed: {e}")
        api = None
        repo_name = ""
# Local directory for conversation logs and saved images; created at import
# time so later writes do not have to check for it.
LOGDIR = "./logs"
os.makedirs(LOGDIR, exist_ok=True)
# --- Global Model State ---
# These are filled in by initialize_model(); model_initialized is flipped to
# True by query() after a successful initialization.
tokenizer = None
model = None
image_processor = None
context_len = None
args = None
model_initialized = False
# ----------------- Utilities -----------------
def _safe_upload(path: str):
    """Best-effort upload of a local log file to the configured dataset repo.

    Does nothing when the Hub client (``api``) or ``repo_name`` is not
    configured, when ``path`` is empty, or when ``path`` is not an existing
    file. Upload errors are printed and swallowed so that logging can never
    break a request.

    Args:
        path: Local path of the file to upload (normally located under LOGDIR).
    """
    if not (api and repo_name and path and os.path.isfile(path)):
        return
    try:
        # Derive the in-repo path from LOGDIR instead of a hard-coded
        # "./logs/" prefix, so the two cannot drift apart and only the leading
        # directory is stripped.
        rel = os.path.relpath(path, LOGDIR)
        if rel == os.pardir or rel.startswith(os.pardir + os.sep):
            # File lives outside LOGDIR: never emit "../" into the repo path.
            rel = os.path.basename(path)
        # Repo paths always use forward slashes, whatever the local OS uses.
        rel = rel.replace(os.sep, "/")
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=rel,
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"[upload] failed for {path}: {e}")
def _conv_log_path():
    """Return the path of today's conversation log file under LOGDIR.

    The file name is ``YYYY-MM-DD-user_conv.json`` based on the local date.
    The parent directory is created if it does not exist yet.
    """
    today = datetime.datetime.now()
    file_name = "{:04d}-{:02d}-{:02d}-user_conv.json".format(
        today.year, today.month, today.day
    )
    log_path = os.path.join(LOGDIR, file_name)
    parent_dir = os.path.dirname(log_path)
    os.makedirs(parent_dir, exist_ok=True)
    return log_path
def load_image_any(image_input):
    """Load an image from one of several input forms and return it as RGB.

    Supported inputs:
      - URL (http/https), fetched with ``timeout=(5, 15)``
      - local file path
      - base64 string (optionally with a ``data:image...`` URL prefix)
      - dict with an ``"image"`` key holding any of the above

    Raises:
        ValueError: if the input is neither a string nor a dict with "image".
        Other exceptions from requests / PIL / base64 propagate unchanged.
    """
    # NOTE(review): the string is used as a URL or a local path as-is; if it
    # comes from an untrusted request this allows fetching arbitrary URLs or
    # reading local files -- confirm that is acceptable for the deployment.
    if isinstance(image_input, dict) and "image" in image_input:
        return load_image_any(image_input["image"])
    if not isinstance(image_input, str):
        raise ValueError("Unsupported image input format")

    source = image_input.strip()
    if source.startswith(("http://", "https://")):
        response = requests.get(source, timeout=(5, 15))
        response.raise_for_status()
        encoded_bytes = response.content
    elif os.path.exists(source):
        return Image.open(source).convert("RGB")
    else:
        # Anything else is treated as base64; drop a data-URL header if present.
        if source.startswith("data:image"):
            source = source.split(",", 1)[1]
        encoded_bytes = base64.b64decode(source)
    return Image.open(BytesIO(encoded_bytes)).convert("RGB")
def _guess_conv_mode(model_path: str) -> str:
    """Pick a conversation-template key from the (lower-cased) model name.

    Rules are checked in order and the first match wins; ``"llava_v0"`` is
    returned when nothing matches.
    """
    model_name = get_model_name_from_path(model_path).lower()
    ordered_rules = (
        (("llama-2",), "llava_llama_2"),
        (("v1", "pulse"), "llava_v1"),
        (("mpt",), "mpt"),
        (("qwen",), "qwen_1_5"),
    )
    for fragments, template_key in ordered_rules:
        for fragment in fragments:
            if fragment in model_name:
                return template_key
    return "llava_v0"
def _wrap_image_token_if_needed(model_cfg) -> bool:
try:
return bool(getattr(model_cfg, "mm_use_im_start_end", False))
except Exception:
return False
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Append the user turn to the conversation and tokenize the full prompt.

    The user turn is the image placeholder (wrapped in the IM_START/IM_END
    tokens when the model config asks for it), a newline, then ``user_text``.
    An empty assistant turn is appended so the prompt ends where generation
    should begin.

    Returns:
        ``(prompt, input_ids)`` where ``input_ids`` has a leading batch
        dimension added via ``unsqueeze(0)`` and has been moved to ``device``.
    """
    conversation = chatbot.conversation

    # Plain "<image>" placeholder, or "<im_start><image><im_end>" when wrapped.
    image_marker = DEFAULT_IMAGE_TOKEN
    if _wrap_image_token_if_needed(chatbot.model.config):
        image_marker = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}"
    user_turn = f"{image_marker}\n{user_text}"

    conversation.append_message(conversation.roles[0], user_turn)
    conversation.append_message(conversation.roles[1], None)
    prompt = conversation.get_prompt()

    token_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )
    input_ids = token_ids.unsqueeze(0).to(device)
    return prompt, input_ids
def _stopping_keywords(chatbot, input_ids, extra: Optional[List[str]] = None):
    """Build the keyword stopping criteria for the current conversation.

    The conversation separator is always a stop keyword: ``sep2`` when the
    separator style is ``SeparatorStyle.TWO``, otherwise ``sep``.

    Args:
        chatbot: Object exposing ``conversation`` and ``tokenizer``.
        input_ids: Prompt token ids, forwarded to KeywordsStoppingCriteria.
        extra: Additional stop strings. A single string is accepted and
            treated as one keyword; non-string or blank entries are dropped.

    Returns:
        A ``KeywordsStoppingCriteria`` instance.
    """
    conv = chatbot.conversation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keys = [stop_str]
    # A bare string must not be iterated character by character: "###" would
    # otherwise become the one-character keyword "#" (three times) and cut
    # generation at the first "#".
    if isinstance(extra, str):
        extra = [extra]
    if extra:
        keys.extend(k for k in extra if isinstance(k, str) and k.strip())
    return KeywordsStoppingCriteria(keys, chatbot.tokenizer, input_ids)
# ----------------- Core Generation -----------------
def generate_response(
    message_text: str,
    image_input,
    *,
    max_new_tokens: int = 1800,
    min_new_tokens: Optional[int] = None,
    temperature: float = 0.20,
    top_p: float = 0.95,
    repetition_penalty: float = 1.20,
    no_repeat_ngram_size: Optional[int] = 6,
    conv_mode_override: Optional[str] = None,
    det_seed: Optional[int] = None,
    no_stop: bool = False,
    custom_stop: Optional[List[str]] = None,
):
    """Run a single image + text turn through the model and return the answer.

    The conversation is reset on every call. Generation runs in a background
    thread and its output is collected through a ``TextIteratorStreamer``.

    Args:
        message_text: User prompt; must be non-empty.
        image_input: Anything accepted by ``load_image_any``.
        max_new_tokens: Upper bound on generated tokens.
        min_new_tokens: Optional lower bound; only applied when it lies in
            ``[1, max_new_tokens]``.
        temperature: Sampling temperature. Values ``<= 0`` switch to greedy
            decoding instead of sampling.
        top_p: Nucleus sampling threshold (only used when sampling).
        repetition_penalty: Passed through to ``generate``.
        no_repeat_ngram_size: Passed through when it is a positive integer.
        conv_mode_override: Conversation template key; ignored unless it is
            present in ``conv_templates``.
        det_seed: Optional RNG seed applied to torch (CPU and CUDA).
        no_stop: When true, no keyword stopping criteria are installed and
            ``eos_token_id=None`` is passed to ``generate``.
        custom_stop: Extra stop strings handed to ``_stopping_keywords``.

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}`` on
        success, otherwise ``{"error": str}``.
    """
    if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
        return {"error": "Required libraries not available (llava/transformers)"}
    if not message_text or image_input is None:
        return {"error": "Both 'message' and 'image' are required"}

    # Chat session: the chatbot object is shared, the conversation is fresh on
    # every call (demo-like).
    chatbot = chat_manager.get_chatbot(
        args, args.model_path, tokenizer, model, image_processor, context_len
    )
    if conv_mode_override and conv_mode_override in conv_templates:
        chatbot.conversation = conv_templates[conv_mode_override].copy()
    else:
        chatbot.conversation = conv_templates[chatbot.conv_mode].copy()

    # Load image
    try:
        pil_img = load_image_any(image_input)
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    # Save image to logs (best effort; failures only print a message).
    img_hash, img_path = "NA", None
    try:
        jpeg_buf = BytesIO()
        pil_img.save(jpeg_buf, format="JPEG")
        # The MD5 of the JPEG bytes is used as a content-derived file name,
        # not for any security purpose.
        img_hash = hashlib.md5(jpeg_buf.getvalue()).hexdigest()
        now = datetime.datetime.now()
        img_path = os.path.join(
            LOGDIR,
            "serve_images",
            f"{now.year:04d}-{now.month:02d}-{now.day:02d}",
            f"{img_hash}.jpg",
        )
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        if not os.path.isfile(img_path):
            pil_img.save(img_path)
    except Exception as e:
        print(f"[log] saving image failed: {e}")

    # Target device/dtype are taken from the model's first parameter.
    first_param = next(chatbot.model.parameters())
    device, dtype = first_param.device, first_param.dtype

    # Preprocess image -> tensor (3D/4D/5D tensors and lists are supported).
    try:
        processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
        if isinstance(processed, torch.Tensor):
            if processed.ndim == 3:
                image_tensor = processed.unsqueeze(0)
            elif processed.ndim == 4:
                image_tensor = processed
            elif processed.ndim == 5:
                # Fold the two leading dimensions into one batch dimension.
                dim0, dim1, chans, height, width = processed.shape
                image_tensor = processed.reshape(dim0 * dim1, chans, height, width)
            else:
                return {"error": f"Unexpected image tensor shape: {tuple(processed.shape)}"}
        elif isinstance(processed, (list, tuple)) and len(processed) > 0:
            first = processed[0]
            if isinstance(first, torch.Tensor) and first.ndim == 3:
                image_tensor = first.unsqueeze(0)
            else:
                image_tensor = first
        else:
            return {"error": "Image processing returned empty"}
        image_tensor = image_tensor.to(device=device, dtype=dtype)
    except Exception as e:
        return {"error": f"Image processing failed: {e}"}

    # Prompt & ids
    _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)

    # Stopping criteria and special token ids
    stopping = None if no_stop else _stopping_keywords(chatbot, input_ids, custom_stop)
    eos_id = chatbot.tokenizer.eos_token_id
    pad_id = chatbot.tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_id if eos_id is not None else 0
    # NOTE(review): passing eos_token_id=None may make generate() fall back to
    # the model's generation config rather than disabling EOS -- confirm that
    # no_stop behaves as intended.
    eos_for_gen = None if no_stop else eos_id

    # Deterministic sampling (optional); a bad seed is reported, not fatal.
    if det_seed is not None:
        try:
            seed = int(det_seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
        except Exception as e:
            print(f"[gen] ignoring det_seed {det_seed!r}: {e}")

    # Streamer (demo-like, avoids manually slicing the prompt off the output)
    streamer = TextIteratorStreamer(
        chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Sampling needs a strictly positive temperature; with temperature <= 0
    # fall back to greedy decoding instead of letting generate() reject the
    # configuration inside the worker thread.
    temperature = float(temperature)
    use_sampling = temperature > 0.0

    gen_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        do_sample=use_sampling,
        repetition_penalty=float(repetition_penalty),
        max_new_tokens=int(max_new_tokens),
        # NOTE(review): use_cache=False disables the KV cache, which normally
        # makes long generations much slower -- confirm this is intentional.
        use_cache=False,
        pad_token_id=pad_id,
        eos_token_id=eos_for_gen,
        length_penalty=1.0,
        early_stopping=False,
        stopping_criteria=[stopping] if stopping is not None else None,
    )
    if use_sampling:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)
    if no_repeat_ngram_size:
        try:
            ngram = int(no_repeat_ngram_size)
            if ngram > 0:
                gen_kwargs["no_repeat_ngram_size"] = ngram
        except (TypeError, ValueError) as e:
            print(f"[gen] ignoring no_repeat_ngram_size {no_repeat_ngram_size!r}: {e}")
    if min_new_tokens is not None:
        try:
            min_tokens = int(min_new_tokens)
            if 1 <= min_tokens <= int(max_new_tokens):
                gen_kwargs["min_new_tokens"] = min_tokens
        except (TypeError, ValueError) as e:
            print(f"[gen] ignoring min_new_tokens {min_new_tokens!r}: {e}")

    # Generate in a background thread; collect streamed text in this thread.
    worker_errors = []

    def _run_generation():
        # If generate() raises, the streamer would never receive its stop
        # signal and the consumer loop below would block forever. Record the
        # error and close the stream explicitly.
        try:
            chatbot.model.generate(**gen_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    try:
        worker = Thread(target=_run_generation)
        worker.start()
        text = "".join(streamer)
        worker.join()
    except Exception as e:
        return {"error": f"Generation failed: {e}"}
    if worker_errors:
        return {"error": f"Generation failed: {worker_errors[0]}"}
    chatbot.conversation.messages[-1][-1] = text

    # Log (best effort)
    try:
        row = {
            "time": datetime.datetime.now().isoformat(),
            "type": "chat",
            "model": "PULSE-7B",
            "state": [(message_text, text)],
            "image_hash": img_hash,
            "image_path": img_path or "",
        }
        log_path = _conv_log_path()
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
        _safe_upload(log_path)
        _safe_upload(img_path or "")
    except Exception as e:
        print(f"[log] failed: {e}")

    return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
# ----------------- Public API -----------------
def _coerce_flag(value) -> bool:
    """Interpret a payload value as a boolean.

    Strings such as "false", "0", "no", "off" (any case) count as False, so a
    form-encoded ``no_stop=false`` is not turned into True by ``bool("false")``.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off", "none", "null")
    return bool(value)


def query(payload: dict):
    """HF Endpoint entry point (demo-like streaming).

    Lazily initializes the model on first use, extracts the prompt, image and
    generation knobs from ``payload`` and delegates to ``generate_response``.

    Recognized keys:
        - text: ``message`` | ``query`` | ``prompt`` | ``istem``
        - image: ``image`` | ``image_url`` | ``img``
        - ``max_output_tokens`` | ``max_new_tokens`` | ``max_tokens`` (default 1800)
        - ``min_new_tokens``, ``temperature``, ``top_p``, ``repetition_penalty``,
          ``no_repeat_ngram_size``, ``conv_mode``, ``det_seed``, ``no_stop``,
          ``custom_stop``

    Returns:
        The dict produced by ``generate_response``, or ``{"error": str}``.
    """
    global model_initialized

    # Reject malformed payloads before paying for model initialization.
    if not isinstance(payload, dict):
        return {"error": "Payload must be a JSON object (dict)"}

    if not model_initialized:
        if not initialize_model():
            return {"error": "Model initialization failed"}
        model_initialized = True

    try:
        message = (
            payload.get("message")
            or payload.get("query")
            or payload.get("prompt")
            or payload.get("istem")
            or ""
        )
        image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
        if not message.strip():
            return {"error": "Missing 'message' text"}
        if image is None:
            return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}

        # Demo-like knobs
        max_new_tokens = int(
            payload.get(
                "max_output_tokens",
                payload.get("max_new_tokens", payload.get("max_tokens", 1800)),
            )
        )

        min_new_tokens = payload.get("min_new_tokens", None)
        if min_new_tokens is not None:
            try:
                min_new_tokens = int(min_new_tokens)
            except Exception:
                min_new_tokens = None

        temperature = float(payload.get("temperature", 0.20))
        top_p = float(payload.get("top_p", 0.95))
        repetition_penalty = float(payload.get("repetition_penalty", 1.20))

        no_repeat_ngram = payload.get("no_repeat_ngram_size", 6)
        try:
            no_repeat_ngram = int(no_repeat_ngram) if no_repeat_ngram is not None else None
        except Exception:
            no_repeat_ngram = None

        conv_mode_override = payload.get("conv_mode", None)

        det_seed = payload.get("det_seed", None)
        if det_seed is not None:
            try:
                det_seed = int(det_seed)
            except Exception:
                det_seed = None

        # bool("false") is True, so string flags need explicit parsing.
        no_stop = _coerce_flag(payload.get("no_stop", False))

        # Accept a single stop string as well as a list of them.
        custom_stop = payload.get("custom_stop", None)
        if isinstance(custom_stop, str):
            custom_stop = [custom_stop]

        return generate_response(
            message_text=message,
            image_input=image,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram,
            conv_mode_override=conv_mode_override,
            det_seed=det_seed,
            no_stop=no_stop,
            custom_stop=custom_stop,
        )
    except Exception as e:
        return {"error": f"Query failed: {e}"}
def health_check():
    """Return a status dict describing library, CUDA and model availability.

    The ``"status"`` field is always ``"healthy"``; the remaining fields report
    the current module-level flags and ``torch.cuda.is_available()``.
    """
    report = {"status": "healthy"}
    report["model_initialized"] = model_initialized
    report["cuda_available"] = torch.cuda.is_available()
    report["llava_available"] = LLAVA_AVAILABLE
    report["transformers_available"] = TRANSFORMERS_AVAILABLE
    return report
def get_model_info():
    """Describe the loaded model, or return an error dict if none is loaded.

    Returns:
        ``{"model_path", "context_len", "device"}``; ``"Unknown"`` is used for
        the path or device when ``args`` / ``model`` is falsy.
    """
    if model_initialized:
        info = {}
        info["model_path"] = args.model_path if args else "Unknown"
        info["context_len"] = context_len
        if model:
            # Device of the first parameter is reported as the model device.
            info["device"] = str(next(model.parameters()).device)
        else:
            info["device"] = "Unknown"
        return info
    return {"error": "Model not initialized"}
# ----------------- Init & Session -----------------
class _Args:
def __init__(self):
self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
self.model_base = None
self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
self.conv_mode = None
self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "1800"))
self.num_frames = 16
self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
self.debug = bool(int(os.getenv("DEBUG", "0")))
class InferenceDemo:
    """Bundle of tokenizer, model, image processor and a conversation template.

    The conversation mode comes from ``args.conv_mode`` when set, otherwise it
    is guessed from ``model_path``; the chosen mode is written back to
    ``args.conv_mode``.

    Raises:
        ImportError: if the LLaVA modules could not be imported.
    """

    def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA modules not available")
        disable_torch_init()

        self.tokenizer = tokenizer_
        self.model = model_
        self.image_processor = image_processor_
        self.context_len = context_len_

        # The guess is always computed, but only used as a fallback.
        guessed_mode = _guess_conv_mode(model_path)
        chosen_mode = args.conv_mode or guessed_mode
        self.conv_mode = chosen_mode
        args.conv_mode = chosen_mode

        self.conversation = conv_templates[chosen_mode].copy()
        self.num_frames = args.num_frames
class ChatSessionManager:
    """Lazily creates and caches a single ``InferenceDemo`` instance."""

    def __init__(self):
        self.chatbot = self.args = self.model_path = None

    def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Create the chatbot on the first call; later calls are no-ops."""
        if self.chatbot is not None:
            return
        self.args, self.model_path = args, model_path
        self.chatbot = InferenceDemo(
            args, model_path, tokenizer, model, image_processor, context_len
        )

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the cached chatbot, creating it first if necessary."""
        self.init_if_needed(
            args, model_path, tokenizer, model, image_processor, context_len
        )
        return self.chatbot


# Module-wide session manager shared by all requests.
chat_manager = ChatSessionManager()
def initialize_model():
    """Load tokenizer, model and image processor into the module globals.

    Also registers the loaded objects with ``chat_manager``.

    Returns:
        True on success; False if LLaVA is unavailable or loading raised (the
        error is printed rather than propagated).
    """
    global tokenizer, model, image_processor, context_len, args

    if not LLAVA_AVAILABLE:
        print("[init] LLaVA not available; cannot init.")
        return False

    try:
        args = _Args()
        resolved_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path,
            args.model_base,
            resolved_name,
            args.load_8bit,
            args.load_4bit,
        )

        # Probe the model's device; only if that probe fails is the model
        # moved to CUDA (when available).
        try:
            next(model.parameters()).device
        except Exception:
            if torch.cuda.is_available():
                model = model.to(torch.device("cuda"))

        model.eval()
        chat_manager.init_if_needed(
            args, args.model_path, tokenizer, model, image_processor, context_len
        )
        print("[init] model/tokenizer/image_processor loaded.")
    except Exception as e:
        print(f"[init] failed: {e}")
        return False
    return True
# ----------------- HF EndpointHandler -----------------
class EndpointHandler:
    """Hugging Face Endpoint compatible handler class.

    ``model_dir`` is stored and printed but not used for loading here; the
    model is loaded lazily by ``query``.
    """

    def __init__(self, model_dir):
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Dispatch a request payload to ``query``.

        Accepted shapes:
            - ``{...}`` without "inputs": passed to ``query`` as-is.
            - ``{"inputs": {...}}``: the inner dict is passed to ``query``.
            - ``{"inputs": "<text>", "parameters": {...}, ...}``: the text
              becomes ``message``; the remaining top-level keys and the
              ``parameters`` dict are merged into the query payload.

        Returns:
            The dict returned by ``query``, or ``{"error": str}`` when the
            payload is not a dict.
        """
        # A None or list payload used to raise (or fail obscurely) here.
        if not isinstance(payload, dict):
            return {"error": "Payload must be a JSON object (dict)"}
        if "inputs" not in payload:
            return query(payload)

        inputs = payload["inputs"]
        if isinstance(inputs, str):
            # Text-style request: image and knobs live next to "inputs"
            # and/or under "parameters".
            merged = {
                key: value
                for key, value in payload.items()
                if key not in ("inputs", "parameters")
            }
            parameters = payload.get("parameters")
            if isinstance(parameters, dict):
                merged.update(parameters)
            merged["message"] = inputs
            return query(merged)
        return query(inputs)

    def health_check(self):
        """Return the module-level health report."""
        return health_check()

    def get_model_info(self):
        """Return the module-level model description."""
        return get_model_info()
# Script entry point: running the file directly only prints a usage hint.
if __name__ == "__main__":
    print("Handler ready. Use `EndpointHandler` or `query` for HF Inference Endpoints.")