# -*- coding: utf-8 -*-
"""
PULSE ECG Handler (demo-like streaming)

- TextIteratorStreamer + skip_prompt=True (no manual slicing of the output;
  the prompt-skipping first step is preserved)
- do_sample=True (demo behaviour) with temperature/top_p taken from the
  payload; a temperature <= 0 falls back to greedy decoding
- Optional knobs: no_stop, custom_stop, no_repeat_ngram_size, min_new_tokens
- IM_START/IM_END wrapping is automatic; 3D/4D/5D image tensors are accepted;
  the image tensor is moved to the model's device/dtype
"""
import os
import json
import base64
import hashlib
import datetime
from io import BytesIO
from threading import Thread
from typing import Optional, List, Union

import torch
from PIL import Image
import requests

# --- LLaVA / Transformers ---
try:
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )
    from llava.utils import disable_torch_init

    LLAVA_AVAILABLE = True
except Exception as e:
    LLAVA_AVAILABLE = False
    print(f"[WARN] LLaVA modules not available: {e}")

try:
    from transformers import TextIteratorStreamer

    TRANSFORMERS_AVAILABLE = True
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"[WARN] transformers not available: {e}")

# --- HF Hub (optional log upload) ---
try:
    from huggingface_hub import HfApi, login

    HF_HUB_AVAILABLE = True
except Exception:
    HF_HUB_AVAILABLE = False

api = None
repo_name = ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        # `write_permission` is deprecated/ignored in huggingface_hub and
        # slated for removal; passing it would eventually raise TypeError
        # here and silently disable log uploading.
        login(token=os.environ["HF_TOKEN"])
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"[HF Hub] init failed: {e}")
        api = None
        repo_name = ""

LOGDIR = "./logs"
os.makedirs(LOGDIR, exist_ok=True)

# --- Global Model State ---
tokenizer = None
model = None
image_processor = None
context_len = None
args = None
model_initialized = False

# Case-insensitive spellings recognised by _as_bool.
_TRUE_STRINGS = frozenset({"1", "true", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"", "0", "false", "no", "n", "off"})


# ----------------- Utilities -----------------
def _as_bool(value, default: bool = False) -> bool:
    """Interpret a JSON payload / environment value as a boolean.

    Unlike ``bool(value)``, strings such as ``"false"``, ``"0"``, ``"no"`` and
    ``"off"`` are False. Other numeric strings keep ``bool(int(...))``-like
    semantics (non-zero is True). ``None`` and unrecognised strings yield
    *default*; every other type goes through ``bool()``.
    """
    if value is None:
        return default
    if isinstance(value, str):
        text = value.strip().lower()
        if text in _TRUE_STRINGS:
            return True
        if text in _FALSE_STRINGS:
            return False
        try:
            return float(text) != 0.0
        except ValueError:
            return default
    return bool(value)


def _safe_upload(path: str):
    """Best-effort upload of a log file to the HF dataset repo ``repo_name``.

    No-op unless the Hub client was initialised, a repo is configured and
    *path* is an existing file. Upload failures are printed, never raised.
    """
    if not (api and repo_name and path and os.path.isfile(path)):
        return
    try:
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=path.replace("./logs/", ""),
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"[upload] failed for {path}: {e}")


def _conv_log_path():
    """Return today's conversation-log path under LOGDIR, creating its directory."""
    t = datetime.datetime.now()
    p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
    os.makedirs(os.path.dirname(p), exist_ok=True)
    return p


def load_image_any(image_input):
    """Load an image from one of several encodings and return it as RGB.

    Supported inputs:
      - URL (http/https), fetched with ``requests`` (5 s connect / 15 s read)
      - local file path
      - base64 string, optionally with a ``data:image/...;base64,`` prefix
      - raw ``bytes`` / ``bytearray`` of an encoded image
      - a dict with an ``"image"`` key holding any of the above

    Raises:
        ValueError: for unsupported input types.
        Network, base64 and PIL decoding errors propagate unchanged.
    """
    if isinstance(image_input, (bytes, bytearray)):
        return Image.open(BytesIO(image_input)).convert("RGB")
    if isinstance(image_input, str):
        s = image_input.strip()
        if s.startswith(("http://", "https://")):
            r = requests.get(s, timeout=(5, 15))
            r.raise_for_status()
            return Image.open(BytesIO(r.content)).convert("RGB")
        if os.path.exists(s):
            return Image.open(s).convert("RGB")
        # Otherwise treat the string as base64 (strip a data-URL header first).
        if s.startswith("data:image"):
            s = s.split(",", 1)[1]
        raw = base64.b64decode(s)
        return Image.open(BytesIO(raw)).convert("RGB")
    if isinstance(image_input, dict) and "image" in image_input:
        return load_image_any(image_input["image"])
    raise ValueError("Unsupported image input format")


def _guess_conv_mode(model_path: str) -> str:
    """Pick a LLaVA conversation-template name from the model name.

    Substring checks run in order and the first match wins (e.g. a name
    containing both "v1" and "qwen" maps to "llava_v1").
    """
    name = get_model_name_from_path(model_path).lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "v1" in name or "pulse" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    if "qwen" in name:
        return "qwen_1_5"
    return "llava_v0"


def _wrap_image_token_if_needed(model_cfg) -> bool:
    """Return the config's ``mm_use_im_start_end`` flag (False if absent/unreadable)."""
    try:
        return bool(getattr(model_cfg, "mm_use_im_start_end", False))
    except Exception:
        return False


def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Append the user turn (image token + text) and an empty assistant turn.

    Returns ``(prompt, input_ids)`` where ``input_ids`` has a leading batch
    dimension of 1 and lives on *device*.
    """
    if _wrap_image_token_if_needed(chatbot.model.config):
        # <im_start><image><im_end> + newline + user text
        inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
    else:
        # <image> + newline + user text
        inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
    chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
    chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
    prompt = chatbot.conversation.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    return prompt, input_ids


def _stopping_keywords(
    chatbot, input_ids, extra: Optional[Union[str, List[str]]] = None
):
    """Build a KeywordsStoppingCriteria from the template separator plus *extra*.

    *extra* may be a single string or a list/tuple of strings; blank and
    non-string entries are ignored. (A bare string used to be iterated
    character by character.) Returns ``None`` when no keyword remains.
    """
    conv = chatbot.conversation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    if isinstance(extra, str):
        extra = [extra]
    elif not isinstance(extra, (list, tuple)):
        extra = []
    keys = [stop_str] if stop_str else []
    keys.extend(k for k in extra if isinstance(k, str) and k.strip())
    if not keys:
        return None
    return KeywordsStoppingCriteria(keys, chatbot.tokenizer, input_ids)


# ----------------- Core Generation -----------------
class _ImageShapeError(ValueError):
    """The processed image cannot be turned into a batch tensor for the model."""


def _save_image_for_log(pil_img):
    """Best-effort save of the request image under LOGDIR/serve_images/<date>/.

    Returns ``(img_hash, img_path)``: the MD5 of the JPEG re-encoding (used as
    the file name) and the target path. They stay ``("NA", None)`` if hashing
    fails. Errors are printed, never raised.
    """
    img_hash, img_path = "NA", None
    try:
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        img_hash = hashlib.md5(buf.getvalue()).hexdigest()
        t = datetime.datetime.now()
        img_path = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year:04d}-{t.month:02d}-{t.day:02d}",
            f"{img_hash}.jpg",
        )
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        if not os.path.isfile(img_path):
            pil_img.save(img_path)
    except Exception as e:
        print(f"[log] saving image failed: {e}")
    return img_hash, img_path


def _prepare_image_tensor(chatbot, pil_img, device, dtype):
    """Run ``process_images`` and normalise the result to a 4D batch tensor.

    A 3D tensor gets a batch dim, 4D is used as-is, and 5D ``(B, T, C, H, W)``
    is flattened to ``(B*T, C, H, W)``. A non-empty list/tuple uses its first
    element (3D tensors get a batch dim). The result is moved to *device*/*dtype*.

    Raises:
        _ImageShapeError: unexpected tensor rank or empty processor output.
    """
    processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
    if isinstance(processed, torch.Tensor):
        if processed.ndim == 3:
            image_tensor = processed.unsqueeze(0)
        elif processed.ndim == 4:
            image_tensor = processed
        elif processed.ndim == 5:
            b, t, c, h, w = processed.shape
            image_tensor = processed.reshape(b * t, c, h, w)
        else:
            raise _ImageShapeError(
                f"Unexpected image tensor shape: {tuple(processed.shape)}"
            )
    elif isinstance(processed, (list, tuple)) and len(processed) > 0:
        first = processed[0]
        if isinstance(first, torch.Tensor) and first.ndim == 3:
            image_tensor = first.unsqueeze(0)
        else:
            image_tensor = first
    else:
        raise _ImageShapeError("Image processing returned empty")
    return image_tensor.to(device=device, dtype=dtype)


def _stream_generate(gen_model, gen_kwargs, streamer) -> str:
    """Run ``gen_model.generate(**gen_kwargs)`` in a thread and collect the stream.

    If ``generate`` raises in the worker, the exception is re-raised here.
    Without this, the streamer never receives its end-of-stream signal and
    the consuming loop below would block forever.
    """
    failures = []

    def _worker():
        try:
            gen_model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # Unblock the consumer: end() enqueues the streamer's stop signal.
            streamer.end()

    worker = Thread(target=_worker)
    worker.start()
    text = "".join(streamer)
    worker.join()
    if failures:
        raise failures[0]
    return text


def _append_conv_log(message_text, text, img_hash, img_path):
    """Append one JSON line for this exchange to today's log and upload it (best effort)."""
    try:
        row = {
            "time": datetime.datetime.now().isoformat(),
            "type": "chat",
            "model": "PULSE-7B",
            "state": [(message_text, text)],
            "image_hash": img_hash,
            "image_path": img_path or "",
        }
        # Resolve once so the write and the upload target the same file
        # even across midnight.
        log_path = _conv_log_path()
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
        _safe_upload(log_path)
        _safe_upload(img_path or "")
    except Exception as e:
        print(f"[log] failed: {e}")


def generate_response(
    message_text: str,
    image_input,
    *,
    max_new_tokens: int = 1800,
    min_new_tokens: Optional[int] = None,
    temperature: float = 0.20,
    top_p: float = 0.95,
    repetition_penalty: float = 1.20,
    no_repeat_ngram_size: Optional[int] = 6,
    conv_mode_override: Optional[str] = None,
    det_seed: Optional[int] = None,
    no_stop: bool = False,
    custom_stop: Optional[Union[str, List[str]]] = None,
):
    """Answer *message_text* about *image_input* with the loaded LLaVA model.

    Each call starts a fresh conversation from the template
    (*conv_mode_override* if it names a known template, else the chatbot's
    mode).

    Sampling follows the demo (``do_sample=True``). A *temperature* <= 0
    switches to greedy decoding, because transformers rejects non-positive
    temperatures when sampling. With *no_stop*, neither EOS nor stop
    keywords end generation. *custom_stop* adds extra stop strings.

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}`` or
        ``{"error": str}``.
    """
    if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
        return {"error": "Required libraries not available (llava/transformers)"}
    if not message_text or image_input is None:
        return {"error": "Both 'message' and 'image' are required"}

    # Chat session (fresh conversation each call, demo-like)
    chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
    if conv_mode_override and conv_mode_override in conv_templates:
        chatbot.conversation = conv_templates[conv_mode_override].copy()
    else:
        chatbot.conversation = conv_templates[chatbot.conv_mode].copy()

    try:
        pil_img = load_image_any(image_input)
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    img_hash, img_path = _save_image_for_log(pil_img)

    # The image tensor must match the model's device/dtype.
    first_param = next(chatbot.model.parameters())
    device, dtype = first_param.device, first_param.dtype

    try:
        image_tensor = _prepare_image_tensor(chatbot, pil_img, device, dtype)
    except _ImageShapeError as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"Image processing failed: {e}"}

    _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)

    stopping = None if no_stop else _stopping_keywords(chatbot, input_ids, custom_stop)
    eos_id = chatbot.tokenizer.eos_token_id
    pad_id = chatbot.tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_id if eos_id is not None else 0
    eos_for_gen = None if no_stop else eos_id

    # Deterministic sampling (optional, best effort)
    if det_seed is not None:
        try:
            seed = int(det_seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
        except Exception as e:
            print(f"[seed] ignoring det_seed={det_seed!r}: {e}")

    # Streamer (demo-like, avoids manual slicing of the prompt)
    streamer = TextIteratorStreamer(
        chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    temperature = float(temperature)
    do_sample = temperature > 0
    gen_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        do_sample=do_sample,
        repetition_penalty=float(repetition_penalty),
        max_new_tokens=int(max_new_tokens),
        use_cache=False,
        pad_token_id=pad_id,
        eos_token_id=eos_for_gen,
        length_penalty=1.0,
        early_stopping=False,
        stopping_criteria=[stopping] if stopping is not None else None,
    )
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)

    if no_repeat_ngram_size:
        try:
            n = int(no_repeat_ngram_size)
            if n > 0:
                gen_kwargs["no_repeat_ngram_size"] = n
        except Exception:
            pass  # invalid value: keep the model's default behaviour

    if min_new_tokens is not None:
        try:
            mn = int(min_new_tokens)
            if 1 <= mn <= int(max_new_tokens):
                gen_kwargs["min_new_tokens"] = mn
        except Exception:
            pass  # invalid value: no minimum length is enforced

    try:
        text = _stream_generate(chatbot.model, gen_kwargs, streamer)
        chatbot.conversation.messages[-1][-1] = text
    except Exception as e:
        return {"error": f"Generation failed: {e}"}

    _append_conv_log(message_text, text, img_hash, img_path)

    return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}


# ----------------- Public API -----------------
def query(payload: dict):
    """HF Endpoint entry (demo-like streaming).

    Initialises the model on first use, then reads the request from
    *payload*:

    - text: ``message`` / ``query`` / ``prompt`` / ``istem`` (first truthy wins)
    - image: ``image`` / ``image_url`` / ``img``
    - optional knobs: ``max_output_tokens`` / ``max_new_tokens`` /
      ``max_tokens``, ``min_new_tokens``, ``temperature``, ``top_p``,
      ``repetition_penalty``, ``no_repeat_ngram_size``, ``conv_mode``,
      ``det_seed``, ``no_stop``, ``custom_stop``

    Returns the dict produced by :func:`generate_response` or ``{"error": ...}``.
    """
    global model_initialized
    if not model_initialized:
        if not initialize_model():
            return {"error": "Model initialization failed"}
        model_initialized = True

    try:
        message = (
            payload.get("message")
            or payload.get("query")
            or payload.get("prompt")
            or payload.get("istem")
            or ""
        )
        image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
        if not isinstance(message, str) or not message.strip():
            return {"error": "Missing 'message' text"}
        if image is None:
            return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}

        # Demo-like knobs
        max_new_tokens = int(
            payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1800)))
        )
        min_new_tokens = payload.get("min_new_tokens", None)
        if min_new_tokens is not None:
            try:
                min_new_tokens = int(min_new_tokens)
            except Exception:
                min_new_tokens = None
        temperature = float(payload.get("temperature", 0.20))
        top_p = float(payload.get("top_p", 0.95))
        repetition_penalty = float(payload.get("repetition_penalty", 1.20))
        no_repeat_ngram = payload.get("no_repeat_ngram_size", 6)
        try:
            no_repeat_ngram = int(no_repeat_ngram) if no_repeat_ngram is not None else None
        except Exception:
            no_repeat_ngram = None
        conv_mode_override = payload.get("conv_mode", None)
        det_seed = payload.get("det_seed", None)
        if det_seed is not None:
            try:
                det_seed = int(det_seed)
            except Exception:
                det_seed = None
        # bool("false") is True, so parse string spellings explicitly.
        no_stop = _as_bool(payload.get("no_stop", False))
        custom_stop = payload.get("custom_stop", None)

        return generate_response(
            message_text=message,
            image_input=image,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram,
            conv_mode_override=conv_mode_override,
            det_seed=det_seed,
            no_stop=no_stop,
            custom_stop=custom_stop,
        )
    except Exception as e:
        return {"error": f"Query failed: {e}"}


def health_check():
    """Report library availability, CUDA availability and init state."""
    return {
        "status": "healthy",
        "model_initialized": model_initialized,
        "cuda_available": torch.cuda.is_available(),
        "llava_available": LLAVA_AVAILABLE,
        "transformers_available": TRANSFORMERS_AVAILABLE,
    }


def get_model_info():
    """Return model path, context length and device, or an error before init."""
    if not model_initialized:
        return {"error": "Model not initialized"}
    return {
        "model_path": args.model_path if args else "Unknown",
        "context_len": context_len,
        "device": str(next(model.parameters()).device) if model else "Unknown",
    }


# ----------------- Init & Session -----------------
class _Args:
    """Runtime settings read from environment variables."""

    def __init__(self):
        self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
        self.model_base = None
        self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
        self.conv_mode = None
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "1800"))
        self.num_frames = 16
        # Accepts 0/1 as before, plus spellings like "true"/"false".
        self.load_8bit = _as_bool(os.getenv("LOAD_8BIT", "0"))
        self.load_4bit = _as_bool(os.getenv("LOAD_4BIT", "0"))
        self.debug = _as_bool(os.getenv("DEBUG", "0"))


class InferenceDemo:
    """Holds the loaded tokenizer/model/processor and the current conversation."""

    def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA modules not available")
        disable_torch_init()
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer_, model_, image_processor_, context_len_
        )
        conv_mode_auto = _guess_conv_mode(model_path)
        # An explicit args.conv_mode wins; the chosen mode is written back to args.
        self.conv_mode = args.conv_mode if args.conv_mode else conv_mode_auto
        args.conv_mode = self.conv_mode
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = args.num_frames


class ChatSessionManager:
    """Lazily creates one shared InferenceDemo and returns it on every call.

    NOTE(review): the chatbot (and its ``conversation``) is shared by all
    requests; concurrent requests would race on it. Confirm the endpoint
    serialises calls.
    """

    def __init__(self):
        self.chatbot = None
        self.args = None
        self.model_path = None

    def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Create the chatbot on first call; later calls are no-ops."""
        if self.chatbot is None:
            self.args = args
            self.model_path = model_path
            self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the shared chatbot, creating it if necessary."""
        self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot


chat_manager = ChatSessionManager()


def initialize_model():
    """Load model/tokenizer/image processor into module globals.

    Returns True on success, False (after printing the reason) on failure.
    """
    global tokenizer, model, image_processor, context_len, args
    if not LLAVA_AVAILABLE:
        print("[init] LLaVA not available; cannot init.")
        return False
    try:
        args = _Args()
        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )
        try:
            _ = next(model.parameters()).device
        except Exception:
            # Only reached if the model exposes no parameters.
            if torch.cuda.is_available():
                model = model.to(torch.device("cuda"))
        model.eval()
        chat_manager.init_if_needed(args, args.model_path, tokenizer, model, image_processor, context_len)
        print("[init] model/tokenizer/image_processor loaded.")
        return True
    except Exception as e:
        print(f"[init] failed: {e}")
        return False


# ----------------- HF EndpointHandler -----------------
class EndpointHandler:
    """Hugging Face Inference Endpoint compatible handler."""

    def __init__(self, model_dir):
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Dispatch one request to :func:`query`.

        Accepted payload shapes:
          - flat: ``{"message": ..., "image": ..., ...}``
          - ``{"inputs": {...}, "parameters": {...}}``: keys in ``inputs``
            take precedence over ``parameters``
          - ``{"inputs": "<prompt text>", "image": ..., "parameters": {...}}``
        """
        if not isinstance(payload, dict) or "inputs" not in payload:
            return query(payload)
        inputs = payload["inputs"]
        params = payload.get("parameters")
        merged = dict(params) if isinstance(params, dict) else {}
        if isinstance(inputs, dict):
            merged.update(inputs)
        elif isinstance(inputs, str):
            merged.update(
                {k: v for k, v in payload.items() if k not in ("inputs", "parameters")}
            )
            merged["message"] = inputs
        else:
            # Unknown shape: forward unchanged; query() reports the error.
            return query(inputs)
        return query(merged)

    def health_check(self):
        return health_check()

    def get_model_info(self):
        return get_model_info()


if __name__ == "__main__":
    print("Handler ready. Use `EndpointHandler` or `query` for HF Inference Endpoints.")