CanerDedeoglu commited on
Commit
1fd6be7
·
verified ·
1 Parent(s): 579d7f4

Upload handler (5).py

Browse files
Files changed (1) hide show
  1. handler (5).py +477 -0
handler (5).py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ PULSE ECG Handler — Demo Parity + Style Hint
4
+ - Demo app.py ile aynı üretim ayarları:
5
+ do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
+ - Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
7
+ - Görsel tensörü: .half() ve model cihazında
8
+ - Streamer: TextIteratorStreamer (demo gibi), thread ile generate
9
+ - Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
10
+ - STYLE_HINT: demo üslubuna (narratif + sonda tek satır structured impression) yaklaşmak için
11
+ - Post-process: YALNIZCA whitespace/biçim normalizasyonu (yönetim/öneri cümleleri korunur)
12
+ """
13
+
14
+ import os
15
+ import re
16
+ import json
17
+ import base64
18
+ import hashlib
19
+ import datetime
20
+ from io import BytesIO
21
+ from threading import Thread
22
+ from typing import Optional, Union
23
+
24
+ import torch
25
+ from PIL import Image
26
+ import requests
27
+
28
+ # ====== LLaVA & Transformers ======
29
+ try:
30
+ from llava.constants import (
31
+ IMAGE_TOKEN_INDEX,
32
+ DEFAULT_IMAGE_TOKEN,
33
+ )
34
+ from llava.conversation import conv_templates, SeparatorStyle
35
+ from llava.model.builder import load_pretrained_model
36
+ from llava.mm_utils import (
37
+ tokenizer_image_token,
38
+ process_images,
39
+ get_model_name_from_path,
40
+ )
41
+ from llava.utils import disable_torch_init
42
+ LLAVA_AVAILABLE = True
43
+ except Exception as e:
44
+ LLAVA_AVAILABLE = False
45
+ print(f"[WARN] LLaVA not available: {e}")
46
+
47
+ try:
48
+ from transformers import TextIteratorStreamer, StoppingCriteria
49
+ TRANSFORMERS_AVAILABLE = True
50
+ except Exception as e:
51
+ TRANSFORMERS_AVAILABLE = False
52
+ print(f"[WARN] transformers not available: {e}")
53
+
54
+ # ====== HF Hub logging (opsiyonel) ======
55
+ try:
56
+ from huggingface_hub import HfApi, login
57
+ HF_HUB_AVAILABLE = True
58
+ except Exception:
59
+ HF_HUB_AVAILABLE = False
60
+
61
+ api = None
62
+ repo_name = ""
63
+ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
64
+ try:
65
+ login(token=os.environ["HF_TOKEN"], write_permission=True)
66
+ api = HfApi()
67
+ repo_name = os.environ.get("LOG_REPO", "")
68
+ except Exception as e:
69
+ print(f"[HF Hub] init failed: {e}")
70
+ api = None
71
+ repo_name = ""
72
+
73
+ LOGDIR = "./logs"
74
+ os.makedirs(LOGDIR, exist_ok=True)
75
+
76
+ # ====== Global State ======
77
+ tokenizer = None
78
+ model = None
79
+ image_processor = None
80
+ context_len = None
81
+ args = None
82
+ model_initialized = False
83
+
84
+ # ====== Style Hint (demo benzeri üslup) ======
85
+ STYLE_HINT = (
86
+ "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
87
+ "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
88
+ "Use neutral, factual cardiology language. Avoid headings and bullet points. "
89
+ "Finish with a single final line starting exactly with 'Structured clinical impression:' "
90
+ "followed by a succinct, comma-separated summary of the key diagnoses."
91
+ )
92
+
93
+ # ===================== Utilities =====================
94
+
95
+ def _safe_upload(path: str):
96
+ if api and repo_name and path and os.path.isfile(path):
97
+ try:
98
+ api.upload_file(
99
+ path_or_fileobj=path,
100
+ path_in_repo=path.replace("./logs/", ""),
101
+ repo_id=repo_name,
102
+ repo_type="dataset",
103
+ )
104
+ except Exception as e:
105
+ print(f"[upload] failed for {path}: {e}")
106
+
107
+ def _conv_log_path() -> str:
108
+ t = datetime.datetime.now()
109
+ return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
110
+
111
+ def load_image_any(image_input: Union[str, dict]) -> Image.Image:
112
+ """
113
+ Desteklenen:
114
+ - URL (http/https)
115
+ - yerel dosya yolu
116
+ - base64 (opsiyonel data URL prefix ile)
117
+ - {"image": <base64|dataurl>}
118
+ """
119
+ if isinstance(image_input, str):
120
+ s = image_input.strip()
121
+ if s.startswith(("http://", "https://")):
122
+ r = requests.get(s, timeout=(5, 20))
123
+ r.raise_for_status()
124
+ return Image.open(BytesIO(r.content)).convert("RGB")
125
+ if os.path.exists(s):
126
+ return Image.open(s).convert("RGB")
127
+ # base64 (dataurl olabilir)
128
+ if s.startswith("data:image"):
129
+ s = s.split(",", 1)[1]
130
+ raw = base64.b64decode(s)
131
+ return Image.open(BytesIO(raw)).convert("RGB")
132
+
133
+ if isinstance(image_input, dict) and "image" in image_input:
134
+ return load_image_any(image_input["image"])
135
+
136
+ raise ValueError("Unsupported image input format")
137
+
138
+ def _normalize_whitespace(text: str) -> str:
139
+ """
140
+ Gereksiz boşluk/boş satırları toparlar:
141
+ - Satır başı/sonu boşluklarını siler
142
+ - Birden çok boşluğu tek boşluğa indirger
143
+ - 3+ boş satırı 1 boş satıra indirger
144
+ """
145
+ text = text.replace("\r\n", "\n").replace("\r", "\n")
146
+ lines = [re.sub(r"[ \t]+", " ", ln.strip()) for ln in text.split("\n")]
147
+ text = "\n".join(lines).strip()
148
+ text = re.sub(r"\n{3,}", "\n\n", text)
149
+ return text
150
+
151
+ def _postprocess_min(text: str) -> str:
152
+ # Yalnızca whitespace/biçim temizliği
153
+ return _normalize_whitespace(text)
154
+
155
+ # ====== Güvenli Stop Kriteri (conv separator) ======
156
+ class SafeKeywordsStoppingCriteria(StoppingCriteria):
157
+ """
158
+ conv.sep/sep2 bazlı token eşleşmesi; tensör → bool hatası yok.
159
+ """
160
+ def __init__(self, keyword: str, tokenizer):
161
+ self.tokenizer = tokenizer
162
+ tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
163
+ self.kw_ids = tok # shape: (n,)
164
+
165
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
166
+ if input_ids is None or input_ids.shape[0] == 0:
167
+ return False
168
+ out = input_ids[0] # assume bsz=1
169
+ n = self.kw_ids.shape[0]
170
+ if out.shape[0] < n:
171
+ return False
172
+ tail = out[-n:]
173
+ kw = self.kw_ids.to(tail.device)
174
+ return torch.equal(tail, kw)
175
+
176
+ # ===================== Core Generation =====================
177
+
178
+ class InferenceDemo:
179
+ def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
180
+ if not LLAVA_AVAILABLE:
181
+ raise ImportError("LLaVA not available")
182
+ disable_torch_init()
183
+ self.tokenizer, self.model, self.image_processor, self.context_len = (
184
+ tokenizer_, model_, image_processor_, context_len_
185
+ )
186
+ # Parite için sabit şablon
187
+ self.conv_mode = "llava_v1"
188
+ self.conversation = conv_templates[self.conv_mode].copy()
189
+ self.num_frames = getattr(args, "num_frames", 16)
190
+
191
+ class ChatSessionManager:
192
+ def __init__(self):
193
+ self.chatbot = None
194
+ self.args = None
195
+ self.model_path = None
196
+ def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
197
+ if self.chatbot is None:
198
+ self.args = args
199
+ self.model_path = model_path
200
+ self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
201
+ def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
202
+ self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
203
+ # Her çağrıda taze template (demo gibi yeni tur)
204
+ self.chatbot.conversation = conv_templates[self.chatbot.conv_mode].copy()
205
+ return self.chatbot
206
+
207
+ chat_manager = ChatSessionManager()
208
+
209
+ def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
210
+ # DEMO PARİTE: sarım yok, tek görüntü için tek image token
211
+ inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
212
+ chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
213
+ chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
214
+ prompt = chatbot.conversation.get_prompt()
215
+
216
+ input_ids = tokenizer_image_token(
217
+ prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
218
+ ).unsqueeze(0).to(device)
219
+ return prompt, input_ids
220
+
221
+ def generate_response(
222
+ message_text: str,
223
+ image_input,
224
+ *,
225
+ temperature: Optional[float] = None,
226
+ top_p: Optional[float] = None,
227
+ max_new_tokens: Optional[int] = None,
228
+ conv_mode_override: Optional[str] = None,
229
+ repetition_penalty: Optional[float] = None,
230
+ det_seed: Optional[int] = None, # None → stokastik (demo gibi)
231
+ ):
232
+ if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
233
+ return {"error": "Required libraries not available (llava/transformers)"}
234
+ if not message_text or image_input is None:
235
+ return {"error": "Both 'message' and 'image' are required"}
236
+
237
+ # Varsayılanlar → demo
238
+ if temperature is None: temperature = 0.05
239
+ if top_p is None: top_p = 1.0
240
+ if max_new_tokens is None: max_new_tokens = 4096
241
+ if repetition_penalty is None: repetition_penalty = 1.0 # etkisiz
242
+
243
+ # Chat session
244
+ chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
245
+ if conv_mode_override and conv_mode_override in conv_templates:
246
+ chatbot.conversation = conv_templates[conv_mode_override].copy()
247
+
248
+ # Görüntü yükle
249
+ try:
250
+ pil_img = load_image_any(image_input)
251
+ except Exception as e:
252
+ return {"error": f"Failed to load image: {e}"}
253
+
254
+ # Log için hash+path
255
+ img_hash, img_path = "NA", None
256
+ try:
257
+ buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
258
+ img_hash = hashlib.md5(raw).hexdigest()
259
+ t = datetime.datetime.now()
260
+ img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
261
+ os.makedirs(os.path.dirname(img_path), exist_ok=True)
262
+ if not os.path.isfile(img_path):
263
+ pil_img.save(img_path)
264
+ except Exception as e:
265
+ print(f"[log] save image failed: {e}")
266
+
267
+ # Cihaz/dtype
268
+ device = next(chatbot.model.parameters()).device
269
+ dtype = torch.float16 # demo: half
270
+
271
+ # Görüntü ön-işleme → tensör
272
+ try:
273
+ processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
274
+ if isinstance(processed, (list, tuple)) and len(processed) > 0:
275
+ image_tensor = processed[0]
276
+ elif isinstance(processed, torch.Tensor):
277
+ image_tensor = processed[0] if processed.ndim == 4 else processed
278
+ else:
279
+ return {"error": "Image processing returned empty"}
280
+ if image_tensor.ndim == 3:
281
+ image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
282
+ image_tensor = image_tensor.to(device=device, dtype=dtype) # demo: half + device
283
+ except Exception as e:
284
+ return {"error": f"Image processing failed: {e}"}
285
+
286
+ # STYLE_HINT ekle ve prompt hazırla
287
+ msg = (message_text or "").strip()
288
+ msg = f"{msg}\n\n{STYLE_HINT}"
289
+ _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
290
+
291
+ # Stop string (conv separator) → güvenli kriter
292
+ stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
293
+ stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
294
+
295
+ # Seed (gönderilmediyse stokastik → demo gibi)
296
+ if det_seed is not None:
297
+ try:
298
+ s = int(det_seed)
299
+ torch.manual_seed(s)
300
+ if torch.cuda.is_available():
301
+ torch.cuda.manual_seed(s)
302
+ torch.cuda.manual_seed_all(s)
303
+ except Exception:
304
+ pass
305
+
306
+ # Streamer (demo gibi)
307
+ streamer = TextIteratorStreamer(
308
+ chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
309
+ )
310
+
311
+ # Generate kwargs — demo ayarları
312
+ gen_kwargs = dict(
313
+ inputs=input_ids,
314
+ images=image_tensor,
315
+ streamer=streamer,
316
+ do_sample=True, # DEMO
317
+ temperature=float(temperature), # DEMO default 0.05
318
+ top_p=float(top_p), # DEMO default 1.0
319
+ max_new_tokens=int(max_new_tokens), # DEMO slider
320
+ repetition_penalty=float(repetition_penalty), # default 1.0 → etkisiz
321
+ use_cache=False,
322
+ stopping_criteria=[stopping], # DEMO-benzeri durdurma
323
+ )
324
+
325
+ # Üretim (arka thread) + akışı topla
326
+ try:
327
+ t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
328
+ t.start()
329
+ chunks = []
330
+ for piece in streamer:
331
+ chunks.append(piece)
332
+ text = "".join(chunks)
333
+ text = _postprocess_min(text) # yalnızca whitespace/format temizliği
334
+ chatbot.conversation.messages[-1][-1] = text
335
+ except Exception as e:
336
+ return {"error": f"Generation failed: {e}"}
337
+
338
+ # Log
339
+ try:
340
+ row = {
341
+ "time": datetime.datetime.now().isoformat(),
342
+ "type": "chat",
343
+ "model": "PULSE-7B",
344
+ "state": [(message_text, text)],
345
+ "image_hash": img_hash,
346
+ "image_path": img_path or "",
347
+ }
348
+ with open(_conv_log_path(), "a", encoding="utf-8") as f:
349
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
350
+ _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
351
+ except Exception as e:
352
+ print(f"[log] failed: {e}")
353
+
354
+ return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
355
+
356
+ # ===================== Public API =====================
357
+
358
+ def query(payload: dict):
359
+ """HF Endpoint entry (demo-like)."""
360
+ global model_initialized, tokenizer, model, image_processor, context_len, args
361
+ if not model_initialized:
362
+ if not initialize_model():
363
+ return {"error": "Model initialization failed"}
364
+ model_initialized = True
365
+
366
+ try:
367
+ message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
368
+ image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
369
+ if not message.strip(): return {"error": "Missing 'message' text"}
370
+ if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
371
+
372
+ # Demo varsayılanları — payload override edebilir
373
+ temperature = float(payload.get("temperature", 0.05))
374
+ top_p = float(payload.get("top_p", 1.0))
375
+ max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
376
+ repetition_penalty = float(payload.get("repetition_penalty", 1.0)) # etkisiz default
377
+
378
+ conv_mode_override = payload.get("conv_mode", None)
379
+ det_seed = payload.get("det_seed", None)
380
+ if det_seed is not None:
381
+ try: det_seed = int(det_seed)
382
+ except Exception: det_seed = None
383
+
384
+ return generate_response(
385
+ message_text=message,
386
+ image_input=image,
387
+ temperature=temperature,
388
+ top_p=top_p,
389
+ max_new_tokens=max_new_tokens,
390
+ conv_mode_override=conv_mode_override,
391
+ repetition_penalty=repetition_penalty,
392
+ det_seed=det_seed,
393
+ )
394
+ except Exception as e:
395
+ return {"error": f"Query failed: {e}"}
396
+
397
+ def health_check():
398
+ return {
399
+ "status": "healthy",
400
+ "model_initialized": model_initialized,
401
+ "cuda_available": torch.cuda.is_available(),
402
+ "llava_available": LLAVA_AVAILABLE,
403
+ "transformers_available": TRANSFORMERS_AVAILABLE,
404
+ }
405
+
406
+ def get_model_info():
407
+ if not model_initialized:
408
+ return {"error": "Model not initialized"}
409
+ return {
410
+ "model_path": args.model_path if args else "Unknown",
411
+ "context_len": context_len,
412
+ "device": str(next(model.parameters()).device) if model else "Unknown",
413
+ }
414
+
415
+ # ===================== Init & Session =====================
416
+
417
+ class _Args:
418
+ def __init__(self):
419
+ self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
420
+ self.model_base = None
421
+ self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
422
+ self.conv_mode = "llava_v1" # Parite için sabit
423
+ self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
424
+ self.num_frames = 16
425
+ self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
426
+ self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
427
+ self.debug = bool(int(os.getenv("DEBUG", "0")))
428
+
429
+ def initialize_model():
430
+ global tokenizer, model, image_processor, context_len, args
431
+ if not LLAVA_AVAILABLE:
432
+ print("[init] LLaVA not available; cannot init.")
433
+ return False
434
+ try:
435
+ args = _Args()
436
+ model_name = get_model_name_from_path(args.model_path)
437
+ tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
438
+ args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
439
+ )
440
+ # demo: model'ı genelde cuda’da çalıştırır
441
+ try:
442
+ _ = next(model_.parameters()).device
443
+ except Exception:
444
+ if torch.cuda.is_available():
445
+ model_ = model_.to(torch.device("cuda"))
446
+ model_.eval()
447
+
448
+ globals()["tokenizer"] = tokenizer_
449
+ globals()["model"] = model_
450
+ globals()["image_processor"] = image_processor_
451
+ globals()["context_len"] = context_len_
452
+
453
+ chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
454
+ print("[init] model/tokenizer/image_processor loaded.")
455
+ return True
456
+ except Exception as e:
457
+ print(f"[init] failed: {e}")
458
+ return False
459
+
460
+ # ===================== HF EndpointHandler =====================
461
+
462
+ class EndpointHandler:
463
+ """Hugging Face Endpoint uyumlu sınıf"""
464
+ def __init__(self, model_dir):
465
+ self.model_dir = model_dir
466
+ print(f"EndpointHandler initialized with model_dir: {model_dir}")
467
+ def __call__(self, payload):
468
+ if "inputs" in payload:
469
+ return query(payload["inputs"])
470
+ return query(payload)
471
+ def health_check(self):
472
+ return health_check()
473
+ def get_model_info(self):
474
+ return get_model_info()
475
+
476
+ if __name__ == "__main__":
477
+ print("Handler ready (Demo Parity + Style Hint + whitespace post-process). Use `EndpointHandler` or `query`.")