# -*- coding: utf-8 -*-
"""
PULSE ECG Handler — Demo Parity + Style Hint

- Same generation settings as the demo app.py:
  do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
- Stopping: safe token-matching criterion on the conversation separator
  (conv.sep / conv.sep2)
- Image tensor: cast to half precision and moved to the model's device
- Streamer: TextIteratorStreamer (as in the demo); generate runs in a thread
- Seed/determinism is OFF unless a seed is sent; stochastic like the demo
- STYLE_HINT: nudges the output toward the demo's style (narrative plus a
  single-line structured impression at the end)
"""

import os
import json
import base64
import hashlib
import datetime
from io import BytesIO
from threading import Thread
from typing import Optional, Union

import torch
from PIL import Image
import requests

# ====== LLaVA & Transformers ======
try:
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
    )
    from llava.utils import disable_torch_init
    LLAVA_AVAILABLE = True
except Exception as e:
    LLAVA_AVAILABLE = False
    print(f"[WARN] LLaVA not available: {e}")

try:
    from transformers import TextIteratorStreamer, StoppingCriteria
    TRANSFORMERS_AVAILABLE = True
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"[WARN] transformers not available: {e}")

    # Fallbacks so the module still imports (and health_check() can report
    # transformers_available=False) instead of dying with a NameError at the
    # SafeKeywordsStoppingCriteria class statement below.
    TextIteratorStreamer = None

    class StoppingCriteria:  # type: ignore[no-redef]
        """Placeholder base class used only when transformers is missing."""

# ====== HF Hub logging (optional) ======
try:
    from huggingface_hub import HfApi, login
    HF_HUB_AVAILABLE = True
except Exception:
    HF_HUB_AVAILABLE = False

api = None
repo_name = ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        login(token=os.environ["HF_TOKEN"], write_permission=True)
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"[HF Hub] init failed: {e}")
        api = None
        repo_name = ""

LOGDIR = "./logs"
os.makedirs(LOGDIR, exist_ok=True)

# ====== Global State ======
tokenizer = None
model = None
image_processor = None
context_len = None
args = None
model_initialized = False

# ====== Style hint steering the output toward the demo's style ======
STYLE_HINT = (
    "Write a concise diagnostic narrative as in a cardiology read: "
    "use 2–3 short paragraphs describing rhythm, rate, axis, chamber enlargement, conduction, QRS, ST–T, QT; "
    "then finish with a single final line starting exactly with 'Structured clinical impression:'. "
    "Do not include recommendations, prognosis, follow-up, or risk counseling. No emojis or bullet points."
)


# ===================== Utilities =====================
def _safe_upload(path: str):
    """Best-effort upload of a log file to the HF dataset repo.

    Does nothing unless the Hub client (`api`), `repo_name`, and an existing
    file at `path` are all available. Upload failures are printed, never
    raised.
    """
    if api and repo_name and path and os.path.isfile(path):
        try:
            api.upload_file(
                path_or_fileobj=path,
                # Store the file under its path relative to the log directory.
                path_in_repo=path.replace("./logs/", ""),
                repo_id=repo_name,
                repo_type="dataset",
            )
        except Exception as e:
            print(f"[upload] failed for {path}: {e}")


def _conv_log_path() -> str:
    """Return today's conversation log path inside LOGDIR (one file per day)."""
    t = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")


def load_image_any(image_input: Union[str, dict]) -> Image.Image:
    """Load an image from any supported input form and return it as RGB.

    Supported inputs:
    - URL (http/https)
    - local file path
    - base64 string (optionally with a ``data:image...`` data-URL prefix)
    - ``{"image": <any of the above>}``

    Raises:
        ValueError: if `image_input` is neither a string nor a dict with an
            ``"image"`` key. Download, decoding, and PIL errors propagate
            to the caller unchanged.
    """
    # NOTE(review): when image_input comes from a request payload, the URL and
    # local-path branches let the caller make this process fetch arbitrary
    # URLs or read arbitrary local files. Consider restricting these in
    # deployments that accept untrusted input.
    if isinstance(image_input, str):
        s = image_input.strip()
        if s.startswith(("http://", "https://")):
            r = requests.get(s, timeout=(5, 20))
            r.raise_for_status()
            return Image.open(BytesIO(r.content)).convert("RGB")
        if os.path.isfile(s):
            # convert() loads the pixel data, so the file handle can be
            # closed as soon as the RGB copy exists.
            with Image.open(s) as img:
                return img.convert("RGB")
        # base64 (may be a data URL)
        if s.startswith("data:image"):
            s = s.split(",", 1)[1]
        raw = base64.b64decode(s)
        return Image.open(BytesIO(raw)).convert("RGB")
    if isinstance(image_input, dict) and "image" in image_input:
        return load_image_any(image_input["image"])
    raise ValueError("Unsupported image input format")


def _guess_conv_mode(model_path: str) -> str:
    """Pick a conversation-template name from substrings of the model name."""
    name = get_model_name_from_path(model_path).lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "v1" in name or "pulse" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    if "qwen" in name:
        return "qwen_1_5"
    return "llava_v0"


def _wrap_image_token_if_needed(model_cfg) -> bool:
    """Return the config's ``mm_use_im_start_end`` flag as a bool (default False)."""
    try:
        return bool(getattr(model_cfg, "mm_use_im_start_end", False))
    except Exception:
        return False


# ====== Safe stop criterion (demo equivalent) ======
class SafeKeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the output ends with a keyword's token ids.

    Counterpart to LLaVA's KeywordsStoppingCriteria: a token-level match of
    the keyword (separator) sequence that always returns a plain bool,
    avoiding the tensor → bool error.
    """

    def __init__(self, keyword: str, tokenizer):
        self.tokenizer = tokenizer
        if keyword:
            tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
        else:
            # Nothing to match on; __call__ will then never request a stop.
            tok = torch.empty(0, dtype=torch.long)
        self.kw_ids = tok  # shape: (n,)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids is None or input_ids.shape[0] == 0:
            return False
        out = input_ids[0]  # assume bsz=1
        n = self.kw_ids.shape[0]
        if n == 0 or out.shape[0] < n:
            return False
        tail = out[-n:]
        kw = self.kw_ids.to(tail.device)
        return torch.equal(tail, kw)


# ===================== Core Generation =====================
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Append the user turn to the conversation and tokenize the full prompt.

    The user message is the image token followed by `user_text`; the image
    token is wrapped in IM_START/IM_END when the model config asks for it.
    An empty assistant turn is appended so the prompt ends where the model
    should start writing.

    Returns:
        ``(prompt, input_ids)``; `input_ids` gets a leading batch dimension
        via ``unsqueeze(0)`` and is moved to `device`.
    """
    # As in the demo: <image> + text (wrapped in IM_START/END if required)
    use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
    if use_wrap:
        inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
    else:
        inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"

    chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
    chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
    prompt = chatbot.conversation.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    return prompt, input_ids


def generate_response(
    message_text: str,
    image_input,
    *,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    conv_mode_override: Optional[str] = None,
    repetition_penalty: Optional[float] = None,  # not in the demo; None → 1.0 (no effect)
    det_seed: Optional[int] = None,  # no seed → stochastic (like the demo)
):
    """Run one image + text generation and return the result as a dict.

    Args:
        message_text: User prompt; STYLE_HINT is appended before generation.
        image_input: Anything accepted by `load_image_any`.
        temperature: Sampling temperature; None → 0.05.
        top_p: Nucleus sampling threshold; None → 1.0.
        max_new_tokens: Generation length cap; None → 4096.
        conv_mode_override: Conversation template name; used only if it is a
            key of `conv_templates`, otherwise the chatbot's own mode is used.
        repetition_penalty: None → 1.0.
        det_seed: If given, seeds torch (CPU and CUDA) before generating.

    Returns:
        ``{"status": "success", "response": ..., "conversation_id": ...}`` on
        success, or ``{"error": ...}`` on failure. Failures are reported in
        the returned dict rather than raised.
    """
    if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
        return {"error": "Required libraries not available (llava/transformers)"}
    if not message_text or image_input is None:
        return {"error": "Both 'message' and 'image' are required"}
    if args is None or model is None or tokenizer is None:
        # initialize_model() has not run (or failed); without this guard the
        # args.model_path access below would raise AttributeError.
        return {"error": "Model not initialized"}

    # Defaults → demo
    if temperature is None:
        temperature = 0.05
    if top_p is None:
        top_p = 1.0
    if max_new_tokens is None:
        max_new_tokens = 4096
    if repetition_penalty is None:
        repetition_penalty = 1.0  # no effect

    # Chat session: fresh template on every call
    chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
    if conv_mode_override and conv_mode_override in conv_templates:
        chatbot.conversation = conv_templates[conv_mode_override].copy()
    else:
        chatbot.conversation = conv_templates[chatbot.conv_mode].copy()

    # Load the image
    try:
        pil_img = load_image_any(image_input)
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    # Hash + path for logging
    img_hash, img_path = "NA", None
    try:
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        raw = buf.getvalue()
        img_hash = hashlib.md5(raw).hexdigest()
        now = datetime.datetime.now()
        img_path = os.path.join(
            LOGDIR,
            "serve_images",
            f"{now.year:04d}-{now.month:02d}-{now.day:02d}",
            f"{img_hash}.jpg",
        )
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        if not os.path.isfile(img_path):
            pil_img.save(img_path)
    except Exception as e:
        print(f"[log] save image failed: {e}")

    # Device/dtype
    device = next(chatbot.model.parameters()).device
    dtype = torch.float16  # demo: half

    # Image preprocessing → tensor
    try:
        processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
        if isinstance(processed, (list, tuple)) and len(processed) > 0:
            image_tensor = processed[0]
        elif isinstance(processed, torch.Tensor):
            image_tensor = processed[0] if processed.ndim == 4 else processed
        else:
            return {"error": "Image processing returned empty"}

        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)  # (1,C,H,W)
        image_tensor = image_tensor.to(device=device, dtype=dtype)  # demo: half + device
    except Exception as e:
        return {"error": f"Image processing failed: {e}"}

    # --------- APPEND STYLE HINT ---------
    message_text = (message_text or "").strip() + "\n\n" + STYLE_HINT
    # -------------------------------------

    # Prompt & input ids
    _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)

    # Stop string (conversation separator) → safe criterion
    conv = chatbot.conversation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)

    # Seed (if none was sent: stochastic → like the demo)
    if det_seed is not None:
        try:
            s = int(det_seed)
            torch.manual_seed(s)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(s)
                torch.cuda.manual_seed_all(s)
        except Exception as e:
            # Seeding is best-effort; continue unseeded but leave a trace.
            print(f"[seed] ignored det_seed={det_seed!r}: {e}")

    # Streamer (as in the demo)
    streamer = TextIteratorStreamer(
        chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Generate kwargs — demo settings
    gen_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        do_sample=True,  # DEMO
        temperature=float(temperature),  # DEMO default 0.05
        top_p=float(top_p),  # DEMO default 1.0
        max_new_tokens=int(max_new_tokens),  # DEMO slider
        repetition_penalty=float(repetition_penalty),  # default 1.0 → no effect
        use_cache=False,
        stopping_criteria=[stopping],  # DEMO-like stopping
    )

    gen_errors = []

    def _run_generate():
        """Run generate(); on failure, record the error and unblock the reader."""
        try:
            chatbot.model.generate(**gen_kwargs)
        except Exception as exc:
            gen_errors.append(exc)
            # The streamer only receives its end-of-stream signal when
            # generate() finishes normally. Send it here so the consuming
            # loop below does not block forever on a failed generation.
            try:
                streamer.end()
            except Exception:
                streamer.text_queue.put(streamer.stop_signal)

    # Generation (background thread) + collect the stream
    try:
        worker = Thread(target=_run_generate)
        worker.start()
        chunks = []
        for piece in streamer:
            chunks.append(piece)
        worker.join()
        if gen_errors:
            return {"error": f"Generation failed: {gen_errors[0]}"}
        text = "".join(chunks)
        chatbot.conversation.messages[-1][-1] = text
    except Exception as e:
        return {"error": f"Generation failed: {e}"}

    # Log
    try:
        row = {
            "time": datetime.datetime.now().isoformat(),
            "type": "chat",
            "model": "PULSE-7B",
            "state": [(message_text, text)],
            "image_hash": img_hash,
            "image_path": img_path or "",
        }
        with open(_conv_log_path(), "a", encoding="utf-8") as f:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
        _safe_upload(_conv_log_path())
        _safe_upload(img_path or "")
    except Exception as e:
        print(f"[log] failed: {e}")

    return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}


# ===================== Public API =====================
def _first_value(payload: dict, keys, cast):
    """Return ``cast(payload[k])`` for the first key whose value is not None.

    Returns None when none of `keys` carries a value, which lets
    generate_response() apply its own (demo) defaults.
    """
    for key in keys:
        value = payload.get(key)
        if value is not None:
            return cast(value)
    return None


def query(payload: dict):
    """HF Endpoint entry (demo parity + style hint).

    Lazily initializes the model on first use, reads the prompt, image, and
    optional generation settings from `payload`, and delegates to
    generate_response().

    Accepted keys:
        text: ``message`` | ``query`` | ``prompt`` | ``istem``
        image: ``image`` | ``image_url`` | ``img``
        settings: ``temperature``, ``top_p``, ``max_output_tokens`` |
            ``max_new_tokens`` | ``max_tokens``, ``repetition_penalty``,
            ``conv_mode``, ``det_seed``

    Returns:
        The dict produced by generate_response(), or ``{"error": ...}``.
    """
    global model_initialized, tokenizer, model, image_processor, context_len, args

    if not model_initialized:
        if not initialize_model():
            return {"error": "Model initialization failed"}
        model_initialized = True

    try:
        if not isinstance(payload, dict):
            return {"error": "Payload must be a JSON object"}

        message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
        image = payload.get("image") or payload.get("image_url") or payload.get("img") or None

        if not message.strip():
            return {"error": "Missing 'message' text"}
        if image is None:
            return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}

        # Demo defaults — the payload may override them. Missing or null
        # values are passed on as None so generate_response() fills in the
        # defaults (0.05 / 1.0 / 4096 / 1.0) in one place.
        temperature = _first_value(payload, ("temperature",), float)
        top_p = _first_value(payload, ("top_p",), float)
        max_new_tokens = _first_value(
            payload, ("max_output_tokens", "max_new_tokens", "max_tokens"), int
        )
        repetition_penalty = _first_value(payload, ("repetition_penalty",), float)
        conv_mode_override = payload.get("conv_mode", None)

        det_seed = payload.get("det_seed", None)
        if det_seed is not None:
            try:
                det_seed = int(det_seed)
            except Exception:
                det_seed = None

        return generate_response(
            message_text=message,
            image_input=image,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            conv_mode_override=conv_mode_override,
            repetition_penalty=repetition_penalty,
            det_seed=det_seed,
        )
    except Exception as e:
        return {"error": f"Query failed: {e}"}


def health_check():
    """Report module status flags: init state, CUDA, and optional libraries."""
    return {
        "status": "healthy",
        "model_initialized": model_initialized,
        "cuda_available": torch.cuda.is_available(),
        "llava_available": LLAVA_AVAILABLE,
        "transformers_available": TRANSFORMERS_AVAILABLE,
    }


def get_model_info():
    """Return model path, context length, and device, or an error if not initialized."""
    if not model_initialized:
        return {"error": "Model not initialized"}
    return {
        "model_path": args.model_path if args else "Unknown",
        "context_len": context_len,
        "device": str(next(model.parameters()).device) if model else "Unknown",
    }


# ===================== Init & Session =====================
class _Args:
    """Runtime configuration, read from environment variables with defaults."""

    def __init__(self):
        self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
        self.model_base = None
        self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
        self.conv_mode = None
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
        self.num_frames = 16
        # Boolean flags are given as "0" / "1".
        self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
        self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
        self.debug = bool(int(os.getenv("DEBUG", "0")))


class InferenceDemo:
    """Bundle of tokenizer, model, image processor, and conversation template."""

    def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA not available")
        disable_torch_init()
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer_, model_, image_processor_, context_len_
        )
        # An explicit args.conv_mode wins; otherwise guess from the model name.
        auto = _guess_conv_mode(model_path)
        self.conv_mode = args.conv_mode if args.conv_mode else auto
        # The resolved mode is written back onto the shared args object.
        args.conv_mode = self.conv_mode
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = args.num_frames


class ChatSessionManager:
    """Lazily creates and then reuses a single InferenceDemo instance."""

    def __init__(self):
        self.chatbot = None
        self.args = None
        self.model_path = None

    def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Create the chatbot on first call; later calls are no-ops."""
        if self.chatbot is None:
            self.args = args
            self.model_path = model_path
            self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the shared chatbot, creating it first if necessary."""
        self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot


chat_manager = ChatSessionManager()


def initialize_model():
    """Load the model, tokenizer, and image processor into the module globals.

    Returns:
        True on success; False if LLaVA is unavailable or loading fails (the
        failure is printed, not raised).
    """
    global tokenizer, model, image_processor, context_len, args
    if not LLAVA_AVAILABLE:
        print("[init] LLaVA not available; cannot init.")
        return False
    try:
        args = _Args()
        model_name = get_model_name_from_path(args.model_path)
        tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )
        # demo: the model normally runs on cuda. The explicit move below only
        # happens when the parameter device cannot be read at all.
        try:
            _ = next(model_.parameters()).device
        except Exception:
            if torch.cuda.is_available():
                model_ = model_.to(torch.device("cuda"))
        model_.eval()

        tokenizer = tokenizer_
        model = model_
        image_processor = image_processor_
        context_len = context_len_

        chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
        print("[init] model/tokenizer/image_processor loaded.")
        return True
    except Exception as e:
        print(f"[init] failed: {e}")
        return False


# ===================== HF EndpointHandler =====================
class EndpointHandler:
    """Hugging Face Endpoint-compatible handler class."""

    def __init__(self, model_dir):
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Dispatch a request; unwrap the ``"inputs"`` envelope when present."""
        if isinstance(payload, dict) and "inputs" in payload:
            return query(payload["inputs"])
        return query(payload)

    def health_check(self):
        """Delegate to the module-level health_check()."""
        return health_check()

    def get_model_info(self):
        """Delegate to the module-level get_model_info()."""
        return get_model_info()


if __name__ == "__main__":
    print("Handler ready (Demo Parity + Style Hint). Use `EndpointHandler` or `query`.")