# -*- coding: utf-8 -*-
"""
PULSE ECG Handler — deterministic version.

- Generation: do_sample=False (greedy decoding, consistent output);
  temperature/top_p are accepted for API compatibility but have no effect.
- Stopping: token-matching criterion on the conversation separator
  (conv.sep / conv.sep2).
- Image tensor: cast to float16 (.half()) and moved to the model's device.
- Streaming: TextIteratorStreamer (as in the demo) with generate() running on a
  worker thread; generation errors are propagated instead of hanging.
- Seeding is optional: with do_sample=False greedy decoding is already
  deterministic.
- STYLE_HINT: appended to the user prompt to approximate the demo's style.
- Post-processing: whitespace/format normalization and removal of a trailing
  conversation separator only; model text is otherwise left untouched.
"""

import os
import re
import json
import base64
import hashlib
import datetime
from io import BytesIO
from threading import Thread
from typing import Optional, Union

import torch
from PIL import Image
import requests

# ====== LLaVA & Transformers ======
try:
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
    )
    from llava.utils import disable_torch_init

    LLAVA_AVAILABLE = True
except Exception as e:
    LLAVA_AVAILABLE = False
    print(f"[WARN] LLaVA not available: {e}")

try:
    from transformers import TextIteratorStreamer, StoppingCriteria

    TRANSFORMERS_AVAILABLE = True
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"[WARN] transformers not available: {e}")
    # Placeholder base class so SafeKeywordsStoppingCriteria below can still be
    # defined. Without it the module import itself fails with a NameError and
    # the TRANSFORMERS_AVAILABLE flag is never reached.
    StoppingCriteria = object

# ====== HF Hub logging (optional) ======
try:
    from huggingface_hub import HfApi, login

    HF_HUB_AVAILABLE = True
except Exception:
    HF_HUB_AVAILABLE = False

api = None
repo_name = ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        # `write_permission` is a deprecated, ignored argument of `login`; the
        # token's own scopes decide whether uploads are allowed.
        login(token=os.environ["HF_TOKEN"])
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"[HF Hub] init failed: {e}")
        api = None
        repo_name = ""

LOGDIR = "./logs"
os.makedirs(LOGDIR, exist_ok=True)

# ====== Global State ======
# Populated by initialize_model(); read by generate_response()/get_model_info().
tokenizer = None
model = None
image_processor = None
context_len = None
args = None
model_initialized = False

# ====== Style Hint (demo-like wording) ======
# Appended to every user message in generate_response().
STYLE_HINT = (
    "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
    "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
    "Use neutral, factual cardiology language. Avoid headings and bullet points. "
    "Finish with a single final line starting exactly with 'Structured clinical impression:' "
    "followed by a succinct, comma-separated summary of the key diagnoses."
)


# ===================== Utilities =====================
def _safe_upload(path: str):
    """Best-effort upload of a log file to the HF dataset repo.

    Does nothing unless HF Hub logging was initialized (``api`` and
    ``repo_name`` set) and ``path`` names an existing file. Upload failures are
    printed, never raised, so logging cannot break an inference request.
    """
    if not (api and repo_name and path and os.path.isfile(path)):
        return
    try:
        api.upload_file(
            path_or_fileobj=path,
            # Mirror the layout under LOGDIR inside the repo (POSIX separators).
            path_in_repo=os.path.relpath(path, LOGDIR).replace(os.sep, "/"),
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"[upload] failed for {path}: {e}")


def _conv_log_path() -> str:
    """Return today's JSON-lines conversation log path inside LOGDIR."""
    t = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")


def load_image_any(image_input: Union[str, bytes, dict]) -> Image.Image:
    """Load an image from several input forms and return it as RGB.

    Supported inputs:
      - URL string (http/https), fetched with requests
      - local file path string
      - base64 string, optionally prefixed with a ``data:image...,`` data URL header
      - raw image bytes / bytearray
      - ``{"image": <any of the above>}``

    Raises:
        ValueError: for unsupported input types.
        requests.HTTPError / PIL / binascii errors: propagated from loading.

    NOTE(review): URL fetches and local-path reads are driven by the request
    payload. If the endpoint is reachable by untrusted callers, consider
    restricting them (SSRF, probing of local files, oversized downloads).
    """
    if isinstance(image_input, (bytes, bytearray)):
        return Image.open(BytesIO(image_input)).convert("RGB")
    if isinstance(image_input, str):
        s = image_input.strip()
        if s.startswith(("http://", "https://")):
            r = requests.get(s, timeout=(5, 20))
            r.raise_for_status()
            return Image.open(BytesIO(r.content)).convert("RGB")
        if os.path.exists(s):
            return Image.open(s).convert("RGB")
        # Otherwise treat it as base64, possibly wrapped in a data URL.
        if s.startswith("data:image"):
            s = s.split(",", 1)[1]
        raw = base64.b64decode(s)
        return Image.open(BytesIO(raw)).convert("RGB")
    if isinstance(image_input, dict) and "image" in image_input:
        return load_image_any(image_input["image"])
    raise ValueError("Unsupported image input format")


def _normalize_whitespace(text: str) -> str:
    """Tidy superfluous whitespace and blank lines.

    - Normalizes CRLF/CR line endings to LF.
    - Strips leading/trailing whitespace on each line.
    - Collapses runs of spaces/tabs within a line into a single space.
    - Collapses 3+ consecutive newlines (2+ blank lines) into one blank line.
    """
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    lines = [re.sub(r"[ \t]+", " ", ln.strip()) for ln in text.split("\n")]
    text = "\n".join(lines).strip()
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text


def _postprocess_min(text: str, stop_str: Optional[str] = None) -> str:
    """Minimal post-processing: trailing-separator removal + whitespace cleanup.

    Args:
        text: raw decoded model output (``None`` is treated as empty).
        stop_str: conversation separator used as the stop string. When the
            output ends with it (e.g. ``###`` for single-separator templates,
            which ``skip_special_tokens`` does not remove), it is dropped.
            Whitespace-only separators are ignored.
    """
    text = (text or "").strip()
    marker = (stop_str or "").strip()
    if marker and text.endswith(marker):
        text = text[: -len(marker)]
    return _normalize_whitespace(text)


# ====== Safe stop criterion (conversation separator) ======
class SafeKeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation when the sequence ends with the separator's token ids.

    Compares token ids with ``torch.equal`` and returns a plain Python bool, so
    no ambiguous tensor-to-bool conversion happens. Only the first sequence in
    the batch is inspected (batch size 1 is assumed).
    """

    def __init__(self, keyword: str, tokenizer):
        self.tokenizer = tokenizer
        tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
        self.kw_ids = tok  # 1-D tensor of the keyword's token ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids is None or input_ids.shape[0] == 0:
            return False
        out = input_ids[0]  # assumes batch size 1
        n = self.kw_ids.shape[0]
        if out.shape[0] < n:
            return False
        tail = out[-n:]
        kw = self.kw_ids.to(tail.device)
        return torch.equal(tail, kw)


# ===================== Core Generation =====================
class InferenceDemo:
    """Holds the loaded model components plus the current conversation."""

    def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA not available")
        disable_torch_init()
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer_, model_, image_processor_, context_len_
        )
        # Fixed template for parity with the demo.
        self.conv_mode = "llava_v1"
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = getattr(args, "num_frames", 16)


class ChatSessionManager:
    """Lazily creates a single shared InferenceDemo and resets its conversation per call."""

    def __init__(self):
        self.chatbot = None
        self.args = None
        self.model_path = None

    def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Create the InferenceDemo on first use; later calls are no-ops."""
        if self.chatbot is None:
            self.args = args
            self.model_path = model_path
            self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the shared chatbot with a fresh conversation template (new turn, as in the demo)."""
        self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
        self.chatbot.conversation = conv_templates[self.chatbot.conv_mode].copy()
        return self.chatbot


chat_manager = ChatSessionManager()


def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Append the user turn (with one image token) and an empty assistant turn.

    Returns ``(prompt, input_ids)``, where ``input_ids`` has a leading batch
    dimension of 1 and lives on ``device``.
    """
    # Demo parity: no extra wrapping, a single image token for a single image.
    inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
    chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
    chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
    prompt = chatbot.conversation.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    return prompt, input_ids


def _log_image(pil_img):
    """Store the request image under LOGDIR/serve_images/<date>/<md5>.jpg.

    Best-effort: returns ``(img_hash, img_path)``. On failure the error is
    printed and whatever was computed so far is returned (defaults:
    ``"NA"``, ``None``).
    """
    img_hash, img_path = "NA", None
    try:
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        raw = buf.getvalue()
        img_hash = hashlib.md5(raw).hexdigest()
        t = datetime.datetime.now()
        img_path = os.path.join(
            LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg"
        )
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        if not os.path.isfile(img_path):
            # Write the exact bytes that were hashed instead of re-encoding.
            with open(img_path, "wb") as fh:
                fh.write(raw)
    except Exception as e:
        print(f"[log] save image failed: {e}")
    return img_hash, img_path


def _run_generate_streaming(model_, gen_kwargs: dict, streamer) -> str:
    """Run ``model_.generate(**gen_kwargs)`` on a worker thread and collect the streamed text.

    If ``generate`` raises, the streamer is closed so the consuming loop ends
    instead of blocking forever, and the exception is re-raised here.
    """
    errors = []

    def _worker():
        try:
            model_.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the calling thread below
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_worker)
    worker.start()
    text = "".join(piece for piece in streamer)
    worker.join()
    if errors:
        raise errors[0]
    return text


def generate_response(
    message_text: str,
    image_input,
    *,
    temperature: Optional[float] = None,  # ignored in deterministic mode
    top_p: Optional[float] = None,  # ignored in deterministic mode
    max_new_tokens: Optional[int] = None,
    conv_mode_override: Optional[str] = None,
    repetition_penalty: Optional[float] = None,
    det_seed: Optional[int] = None,  # optional; greedy decoding is already deterministic
):
    """Generate a greedy (do_sample=False) answer for one message + one image.

    Returns ``{"status": "success", "response": str, "conversation_id": int}``
    on success, or ``{"error": str}`` when libraries are missing, inputs are
    missing, the model is not initialized, or image loading/processing or
    generation fails.

    NOTE(review): the conversation is stored on the single shared chatbot
    object, so concurrent calls would overwrite each other's prompt state.
    This assumes requests are handled one at a time; confirm for your deployment.
    """
    if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
        return {"error": "Required libraries not available (llava/transformers)"}
    if not message_text or image_input is None:
        return {"error": "Both 'message' and 'image' are required"}
    if model is None or args is None:
        return {"error": "Model not initialized"}

    # Defaults
    if max_new_tokens is None:
        max_new_tokens = 4096
    if repetition_penalty is None:
        repetition_penalty = 1.0  # 1.0 == no penalty

    # Chat session with a fresh template for this request.
    chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
    if conv_mode_override and conv_mode_override in conv_templates:
        chatbot.conversation = conv_templates[conv_mode_override].copy()

    # Load the image.
    try:
        pil_img = load_image_any(image_input)
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    # Hash + path, for logging only.
    img_hash, img_path = _log_image(pil_img)

    # Device / dtype
    device = next(chatbot.model.parameters()).device
    dtype = torch.float16  # demo: half precision

    # Image preprocessing -> tensor
    try:
        processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
        if isinstance(processed, (list, tuple)) and len(processed) > 0:
            image_tensor = processed[0]
        elif isinstance(processed, torch.Tensor):
            image_tensor = processed[0] if processed.ndim == 4 else processed
        else:
            return {"error": "Image processing returned empty"}
        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)  # (1, C, H, W)
        image_tensor = image_tensor.to(device=device, dtype=dtype)
    except Exception as e:
        return {"error": f"Image processing failed: {e}"}

    # Append STYLE_HINT and build the prompt.
    msg = f"{(message_text or '').strip()}\n\n{STYLE_HINT}"
    _, input_ids = _build_prompt_and_ids(chatbot, msg, device)

    # Stop string = conversation separator (sep2 for two-separator templates).
    conv = chatbot.conversation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)

    # Optional seeding; irrelevant for greedy decoding but kept for API parity.
    if det_seed is not None:
        try:
            s = int(det_seed)
            torch.manual_seed(s)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(s)
                torch.cuda.manual_seed_all(s)
        except Exception:
            pass  # seeding is best-effort by design

    # Streamer (as in the demo)
    streamer = TextIteratorStreamer(
        chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Deterministic generation settings; temperature/top_p are intentionally
    # not forwarded because they have no effect with do_sample=False.
    gen_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        do_sample=False,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        # NOTE(review): disabling the KV cache makes long generations much
        # slower; kept as-is for parity with the original settings.
        use_cache=False,
        stopping_criteria=[stopping],
    )

    # Generation on a worker thread, collecting the stream.
    try:
        text = _run_generate_streaming(chatbot.model, gen_kwargs, streamer)
        text = _postprocess_min(text, stop_str)
        chatbot.conversation.messages[-1][-1] = text
    except Exception as e:
        return {"error": f"Generation failed: {e}"}

    # Log (best-effort)
    try:
        row = {
            "time": datetime.datetime.now().isoformat(),
            "type": "chat",
            "model": "PULSE-7B",
            "state": [(message_text, text)],
            "image_hash": img_hash,
            "image_path": img_path or "",
        }
        # Compute the path once so the write and the upload target the same file.
        log_path = _conv_log_path()
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
        _safe_upload(log_path)
        _safe_upload(img_path or "")
    except Exception as e:
        print(f"[log] failed: {e}")

    return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}


# ===================== Public API =====================
def query(payload: dict):
    """HF Endpoint entry point (demo-like).

    Recognized payload keys:
      - message text: "message" | "query" | "prompt" | "istem"
      - image: "image" | "image_url" | "img" (see load_image_any)
      - "max_output_tokens" | "max_new_tokens" | "max_tokens" (default 4096)
      - "repetition_penalty" (default 1.0), "conv_mode", "det_seed"
      - "temperature" / "top_p": accepted for API compatibility, ignored
      - {"health_check": true} or {"message": "health_check"}: health shortcut

    Returns a result dict from generate_response/health_check, or {"error": ...}.
    """
    global model_initialized

    if not isinstance(payload, dict):
        return {"error": "Payload must be a JSON object"}

    # Health-check shortcut: both {"health_check": true} and {"message": "health_check"}.
    if payload.get("health_check") or payload.get("message") == "health_check":
        return health_check()

    if not model_initialized:
        if not initialize_model():
            return {"error": "Model initialization failed"}
        model_initialized = True

    try:
        message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
        image = payload.get("image") or payload.get("image_url") or payload.get("img") or None

        if not isinstance(message, str) or not message.strip():
            return {"error": "Missing 'message' text"}
        if image is None:
            return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}

        # Ignored in deterministic mode, but parsed for API compatibility.
        temperature = float(payload.get("temperature", 0.0))
        top_p = float(payload.get("top_p", 1.0))
        max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
        if max_new_tokens < 1:
            return {"error": "'max_new_tokens' must be a positive integer"}
        repetition_penalty = float(payload.get("repetition_penalty", 1.0))
        conv_mode_override = payload.get("conv_mode", None)
        det_seed = payload.get("det_seed", None)
        if det_seed is not None:
            try:
                det_seed = int(det_seed)
            except Exception:
                det_seed = None

        return generate_response(
            message_text=message,
            image_input=image,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            conv_mode_override=conv_mode_override,
            repetition_penalty=repetition_penalty,
            det_seed=det_seed,
        )
    except Exception as e:
        return {"error": f"Query failed: {e}"}


def health_check():
    """Report library availability, init state, and (if present) CUDA memory stats."""
    info = {
        "status": "healthy",
        "model_initialized": model_initialized,
        "llava_available": LLAVA_AVAILABLE,
        "transformers_available": TRANSFORMERS_AVAILABLE,
        "cuda_available": torch.cuda.is_available(),
    }
    if torch.cuda.is_available():
        try:
            device_index = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(device_index)
            gib = 1024 ** 3
            info.update({
                "cuda_device_index": device_index,
                "cuda_name": props.name,
                "cuda_compute_capability": f"{props.major}.{props.minor}",
                "cuda_total_vram_gb": round(props.total_memory / gib, 2),
                "cuda_used_vram_gb": round(torch.cuda.memory_allocated(device_index) / gib, 2),
                "cuda_reserved_vram_gb": round(torch.cuda.memory_reserved(device_index) / gib, 2),
                "torch_version": torch.__version__,
                "cuda_runtime_version": torch.version.cuda,
            })
        except Exception as e:
            info["cuda_error"] = str(e)
    return info


def get_model_info():
    """Return model path, context length and device, or an error if not initialized."""
    if not model_initialized:
        return {"error": "Model not initialized"}
    return {
        "model_path": args.model_path if args else "Unknown",
        "context_len": context_len,
        "device": str(next(model.parameters()).device) if model else "Unknown",
    }


# ===================== Init & Session =====================
def _env_flag(name: str, default: str = "0") -> bool:
    """Parse a boolean environment variable.

    Integer strings keep their previous meaning (``bool(int(value))``).
    "true"/"yes"/"on" (any case) are also accepted as True; any other
    non-integer value is False.
    """
    raw = os.getenv(name, default).strip().lower()
    try:
        return bool(int(raw))
    except ValueError:
        return raw in ("true", "yes", "on")


class _Args:
    """Runtime configuration read from environment variables."""

    def __init__(self):
        self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
        self.model_base = None
        self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
        self.conv_mode = "llava_v1"  # fixed for parity
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
        self.num_frames = 16
        self.load_8bit = _env_flag("LOAD_8BIT")
        self.load_4bit = _env_flag("LOAD_4BIT")
        self.debug = _env_flag("DEBUG")


def initialize_model():
    """Load tokenizer/model/image processor into module globals.

    Returns True on success and False on failure (the error is printed). On
    success, ``model_initialized`` is set and the shared chat session is created.
    """
    global tokenizer, model, image_processor, context_len, args, model_initialized
    if not LLAVA_AVAILABLE:
        print("[init] LLaVA not available; cannot init.")
        return False
    try:
        args = _Args()
        model_name = get_model_name_from_path(args.model_path)
        tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )
        # Fallback CUDA move.
        # NOTE(review): this only triggers when the model exposes no parameters;
        # load_pretrained_model is assumed to have placed the model already. Confirm.
        try:
            _ = next(model_.parameters()).device
        except Exception:
            if torch.cuda.is_available():
                model_ = model_.to(torch.device("cuda"))
        model_.eval()

        tokenizer = tokenizer_
        model = model_
        image_processor = image_processor_
        context_len = context_len_

        chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
        model_initialized = True
        print("[init] model/tokenizer/image_processor loaded.")
        return True
    except Exception as e:
        print(f"[init] failed: {e}")
        return False


# ===================== HF EndpointHandler =====================
class EndpointHandler:
    """Hugging Face Inference Endpoint–compatible handler (model loads lazily on first query)."""

    def __init__(self, model_dir):
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        # Unwrap the standard {"inputs": ...} envelope; guarded so a plain
        # string payload is not subjected to a substring test.
        if isinstance(payload, dict) and "inputs" in payload:
            return query(payload["inputs"])
        return query(payload)

    def health_check(self):
        return health_check()

    def get_model_info(self):
        return get_model_info()


if __name__ == "__main__":
    print("Handler ready (Deterministik Mode: do_sample=False). Use `EndpointHandler` or `query`.")