Files changed (1)
  1. handler.py +487 -313
handler.py CHANGED
@@ -1,341 +1,515 @@
- # -*- coding: utf-8 -*-
- # handler.py — Rapid_ECG / PULSE-7B — startup-load, stable version with DEBUG logging
- # - The model is loaded as soon as the server starts (cold start happens only once)
- # - Follows the HF Endpoint contract (EndpointHandler.load().__call__)
- # - Load order: local (HF_MODEL_DIR) → Hub (HF_MODEL_ID)
- # - Images are processed only via .preprocess() (no process_images)
- # - Vision tower check: mm_vision_tower or vision_tower
- # - Uses IMAGE_TOKEN_INDEX and comprehensive [DEBUG] logging
-
  import os
- import io
- import sys
- import base64
- import subprocess
- from typing import Any, Dict, Optional
-
  import torch
- from PIL import Image
  import requests

-
-
- import os
- os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
-
-
- # ===== Ensure the LLaVA library is available =====
- def _ensure_llava(tag: str = "v1.2.0"):
-     try:
-         import llava  # noqa
-         print("[DEBUG] LLaVA already available.")
-         return
-     except ImportError:
-         print(f"[DEBUG] LLaVA not found; installing (tag={tag}) ...")
-         subprocess.check_call([
-             sys.executable, "-m", "pip", "install",
-             f"git+https://github.com/haotian-liu/LLaVA@{tag}#egg=llava"
-         ])
-         print("[DEBUG] LLaVA installed.")
-
- _ensure_llava("v1.2.0")
-
- # ===== LLaVA imports =====
- from llava.conversation import conv_templates
  from llava.constants import (
      DEFAULT_IMAGE_TOKEN,
      DEFAULT_IM_START_TOKEN,
      DEFAULT_IM_END_TOKEN,
-     IMAGE_TOKEN_INDEX,
  )
  from llava.model.builder import load_pretrained_model
- from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
-
-
- # ---------- helpers ----------
- def _get_env(name: str, default: Optional[str] = None) -> Optional[str]:
-     v = os.getenv(name)
-     return v if v not in (None, "") else default
-
- def _pick_device() -> torch.device:
-     if torch.cuda.is_available():
-         dev = torch.device("cuda")
-     elif torch.backends.mps.is_available():
-         dev = torch.device("mps")
-     else:
-         dev = torch.device("cpu")
-     print(f"[DEBUG] pick_device -> {dev}")
-     return dev
-
- def _pick_dtype(device: torch.device):
-     if device.type == "cuda":
-         dt = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-     else:
-         dt = torch.float32
-     print(f"[DEBUG] pick_dtype({device}) -> {dt}")
-     return dt

- def _is_probably_base64(s: str) -> bool:
-     s = s.strip()
-     if s.startswith("data:image"):
-         return True
-     allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\n\r")
-     return len(s) % 4 == 0 and all(c in allowed for c in s)
-
- def _load_image_from_any(image_input: Any) -> Image.Image:
-     print(f"[DEBUG] _load_image_from_any type={type(image_input)}")
-     if isinstance(image_input, Image.Image):
-         return image_input.convert("RGB")
-     if isinstance(image_input, (bytes, bytearray)):
-         return Image.open(io.BytesIO(image_input)).convert("RGB")
-     if hasattr(image_input, "read"):
-         return Image.open(image_input).convert("RGB")
-     if isinstance(image_input, str):
-         s = image_input.strip()
-         if s.startswith("data:image"):
-             try:
-                 _, b64 = s.split(",", 1)
-                 data = base64.b64decode(b64)
-                 return Image.open(io.BytesIO(data)).convert("RGB")
-             except Exception as e:
-                 raise ValueError(f"Bad data URL: {e}")
-         if _is_probably_base64(s) and not s.startswith(("http://", "https://")):
-             try:
-                 data = base64.b64decode(s)
-                 return Image.open(io.BytesIO(data)).convert("RGB")
-             except Exception as e:
-                 raise ValueError(f"Bad base64 image: {e}")
-         if s.startswith(("http://", "https://")):
-             resp = requests.get(s, timeout=20)
-             resp.raise_for_status()
-             return Image.open(io.BytesIO(resp.content)).convert("RGB")
-         # local path
-         return Image.open(s).convert("RGB")
-     raise ValueError(f"Unsupported image input type: {type(image_input)}")
-
- def _get_conv_mode(model_name: str) -> str:
-     name = (model_name or "").lower()
-     if "llama-2" in name:
-         return "llava_llama_2"
-     if "mistral" in name:
-         return "mistral_instruct"
-     if "v1.6-34b" in name:
-         return "chatml_direct"
-     if "v1" in name or "pulse" in name:
-         return "llava_v1"
-     if "mpt" in name:
-         return "mpt"
-     return "llava_v0"
-
- def _build_prompt_with_image(prompt: str, model_cfg) -> str:
-     # If the user already added an image token, do not add it again
-     if DEFAULT_IMAGE_TOKEN in prompt or DEFAULT_IM_START_TOKEN in prompt:
-         return prompt
-     if getattr(model_cfg, "mm_use_im_start_end", False):
-         token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
-         return f"{token}\n{prompt}"
-     return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
-
- def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repository") -> str:
-     # Priority: HF_MODEL_DIR (local) -> model_dir_hint from the ctor -> default_dir
-     p = _get_env("HF_MODEL_DIR") or model_dir_hint or default_dir
-     p = os.path.abspath(p)
-     print(f"[DEBUG] resolved model path: {p}")
-     return p
-
-
- # ---------- Endpoint Handler ----------
- class EndpointHandler:
-     def __init__(self, model_dir: Optional[str] = None):
-         # DEBUG banner
-         print("🚀 Starting up PULSE-7B handler (startup load)...")
-         print("📝 Enhanced by Ubden® Team")
-         print(f"🔧 Python: {sys.version}")
-         print(f"🔧 PyTorch: {torch.__version__}")
          try:
-             import transformers
-             print(f"🔧 Transformers: {transformers.__version__}")
          except Exception as e:
-             print(f"[DEBUG] transformers import failed: {e}")
-
-         self.model_dir = model_dir
-         self.device = _pick_device()
-         self.dtype = _pick_dtype(self.device)

-         # Environment hints (flash attention; harmless if unsupported)
-         os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
-         os.environ.setdefault("FLASH_ATTENTION", "1")
-         print(f"[DEBUG] ATTN_IMPLEMENTATION={os.getenv('ATTN_IMPLEMENTATION')} FLASH_ATTENTION={os.getenv('FLASH_ATTENTION')}")

-         # Containers for the model/tokenizer/image processor
-         self.model = None
-         self.tokenizer = None
-         self.image_processor = None
-         self.context_len = None
-         self.model_name = None

-         # ---- Load the model here (at startup) ----
-         try:
-             self._startup_load_model()
-             print("✅ Model loaded & ready in __init__")
-         except Exception as e:
-             print(f"💥 CRITICAL: model startup load failed: {e}")
-             raise
-
-     def _startup_load_model(self):
-         # Use the local directory if present, otherwise the Hub
-         local_path = _resolve_model_path(self.model_dir)
-         use_local = os.path.isdir(local_path) and any(
-             os.path.exists(os.path.join(local_path, f))
-             for f in ("config.json", "tokenizer_config.json")
          )
-         model_base = _get_env("HF_MODEL_BASE", None)
-
-         if use_local:
-             model_path = local_path
-             print(f"[DEBUG] loading model LOCALLY from: {model_path}")
-         else:
-             model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
-             print(f"[DEBUG] loading model from HUB: {model_path} (HF_MODEL_BASE={model_base})")

-         # ⬇️ FIX: the LLaVA v1.2.0 signature requires a model_name parameter
          model_name = get_model_name_from_path(model_path)
-         print(f"[DEBUG] resolved model_name: {model_name}")
-
-         print("[DEBUG] calling load_pretrained_model ...")
-         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-             model_path=model_path,
-             model_base=model_base,
-             model_name=model_name,  # <-- required parameter
-             load_8bit=False,
-             load_4bit=False,
-             device_map="auto",
-             device=self.device,
-         )
-         self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
-         print(f"[DEBUG] model loaded: name={self.model_name}")

-         # Vision tower check (new/old field names)
-         vt = (
-             getattr(self.model.config, "mm_vision_tower", None)
-             or getattr(self.model.config, "vision_tower", None)
-         )
-         print(f"[DEBUG] vision tower: {vt}")
-         if self.image_processor is None or vt is None:
-             raise RuntimeError(
-                 "[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower None). "
-                 "For local loading, HF_MODEL_DIR must point to the correct folder; "
-                 "for the Hub, HF_MODEL_ID must be a PULSE/LLaVA-based model (e.g. 'PULSE-ECG/PULSE-7B')."
            )
-
-         # Tokenizer safety
-         try:
-             self.tokenizer.padding_side = "left"
-             if getattr(self.tokenizer, "pad_token_id", None) is None:
-                 self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-         except Exception as e:
-             print(f"[DEBUG] tokenizer safety patch failed: {e}")
-
-         self.model.eval()
-
-     # The HF inference toolkit calls load() anyway, so this is a no-op
-     def load(self):
-         print("[DEBUG] load(): model is already initialized in __init__")
-         return True
-
-     @torch.inference_mode()
-     def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-         print(f"[DEBUG] __call__ inputs keys={list(inputs.keys()) if hasattr(inputs,'keys') else 'N/A'}")
-         # Unwrap the HF {"inputs": {...}} envelope
-         if "inputs" in inputs and isinstance(inputs["inputs"], dict):
-             inputs = inputs["inputs"]
-
-         prompt = inputs.get("query") or inputs.get("prompt") or inputs.get("istem") or ""
-         image_in = inputs.get("image") or inputs.get("image_url") or inputs.get("img")
-         if not image_in:
-             return {"error": "Missing 'image' in payload"}
-         if not isinstance(prompt, str) or not prompt.strip():
-             return {"error": "Missing 'query'/'prompt' text"}
-
-         # Generation parameters
-         temperature = float(inputs.get("temperature", 0.0))
-         top_p = float(inputs.get("top_p", 0.9))
-         max_new = int(inputs.get("max_new_tokens", inputs.get("max_tokens", 512)))
-         repetition_penalty = float(inputs.get("repetition_penalty", 1.0))
-         conv_mode_override = inputs.get("conv_mode") or _get_env("CONV_MODE", None)
-
-         # ---- Load + preprocess the image
-         try:
-             image = _load_image_from_any(image_in)
-             print(f"[DEBUG] loaded image size={image.size}")
-         except Exception as e:
-             return {"error": f"Failed to load image: {e}"}
-
-         if self.image_processor is None:
-             return {"error": "image_processor is None; model not initialized properly (no vision tower)"}
-
-         try:
-             out = self.image_processor.preprocess(image, return_tensors="pt")
-             images_tensor = out["pixel_values"].to(self.device, dtype=self.dtype)
-             image_sizes = [image.size]
-             print(f"[DEBUG] preprocess OK; images_tensor.shape={images_tensor.shape}")
-         except Exception as e:
-             return {"error": f"Image preprocessing failed: {e}"}
-
-         # ---- Conversation + prompt
-         mode = conv_mode_override or _get_conv_mode(self.model_name)
-         conv = (conv_templates.get(mode) or conv_templates[list(conv_templates.keys())[0]]).copy()
-         conv.append_message(conv.roles[0], _build_prompt_with_image(prompt.strip(), self.model.config))
-         conv.append_message(conv.roles[1], None)
-         full_prompt = conv.get_prompt()
-         print(f"[DEBUG] conv_mode={mode}; full_prompt_len={len(full_prompt)}")
-
-         # ---- Tokenization (with IMAGE_TOKEN_INDEX)
          try:
-             input_ids = tokenizer_image_token(
-                 full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
-             ).unsqueeze(0).to(self.device)
-             print(f"[DEBUG] tokenizer_image_token OK; input_ids.shape={input_ids.shape}")
          except Exception as e:
-             print(f"[DEBUG] tokenizer_image_token failed: {e}; fallback to plain tokenizer")
-             try:
-                 toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
-                 input_ids = toks["input_ids"].to(self.device)
-                 print(f"[DEBUG] plain tokenizer OK; input_ids.shape={input_ids.shape}")
-             except Exception as e2:
-                 return {"error": f"Tokenization failed: {e} / {e2}"}
-
-         attention_mask = torch.ones_like(input_ids, device=self.device)
-
-         # ---- Generate
-         try:
-             print(f"[DEBUG] generate(max_new_tokens={max_new}, temp={temperature}, top_p={top_p}, rep={repetition_penalty})")
-             gen_ids = self.model.generate(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 images=images_tensor,
-                 image_sizes=image_sizes,
-                 do_sample=(temperature > 0),
                  temperature=temperature,
                  top_p=top_p,
-                 max_new_tokens=max_new,
-                 repetition_penalty=repetition_penalty,
-                 use_cache=True,
              )
-             print(f"[DEBUG] generate OK; gen_ids.shape={gen_ids.shape}")
-         except Exception as e:
-             return {"error": f"Generation failed: {e}"}
-
-         # ---- Decode (new tokens only)
-         try:
-             new_tokens = gen_ids[0, input_ids.shape[1]:]
-             text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
-             print(f"[DEBUG] decoded_text_len={len(text)}")
-         except Exception as e:
-             return {"error": f"Decode failed: {e}"}
-
          return {
-             "generated_text": text,
-             "model": self.model_name,
-             "conv_mode": mode,
          }

  import os
+ import cv2
+ import datetime
  import torch
+ import numpy as np
+ import hashlib
+ import PIL
+ import base64
+ import json
  import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer, TextIteratorStreamer
+ from threading import Thread

+ from llava import conversation as conversation_lib
+ from llava.constants import DEFAULT_IMAGE_TOKEN
  from llava.constants import (
+     IMAGE_TOKEN_INDEX,
      DEFAULT_IMAGE_TOKEN,
      DEFAULT_IM_START_TOKEN,
      DEFAULT_IM_END_TOKEN,
  )
+ from llava.conversation import conv_templates, SeparatorStyle
  from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import (
+     tokenizer_image_token,
+     process_images,
+     get_model_name_from_path,
+     KeywordsStoppingCriteria,
+ )

+ import spaces
+ from huggingface_hub import HfApi
+ from huggingface_hub import login
+ from huggingface_hub import revision_exists
+
+ # Initialize Hugging Face API
+ if "HF_TOKEN" in os.environ:
+     login(token=os.environ["HF_TOKEN"], write_permission=True)
+     api = HfApi()
+     repo_name = os.environ.get("LOG_REPO", "")
+ else:
+     api = None
+     repo_name = ""
+
+ external_log_dir = "./logs"
+ LOGDIR = external_log_dir
+ VOTEDIR = "./votes"
+
+ # Global variables for model and tokenizer
+ tokenizer = None
+ model = None
+ image_processor = None
+ context_len = None
+ args = None
+ model_path = None  # set by initialize_model()
+
+ # Gradio is no longer used - not needed for a Hugging Face endpoint
+
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
+     return name
+
+ def get_conv_vote_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
+     if not os.path.isfile(name):
+         os.makedirs(os.path.dirname(name), exist_ok=True)
+     return name
+
+ def vote_last_response(state, vote_type, model_selector):
+     if api and repo_name:
+         with open(get_conv_vote_filename(), "a") as fout:
+             data = {
+                 "type": vote_type,
+                 "model": model_selector,
+                 "state": state,
+             }
+             fout.write(json.dumps(data) + "\n")
          try:
+             api.upload_file(
+                 path_or_fileobj=get_conv_vote_filename(),
+                 path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
+                 repo_id=repo_name,
+                 repo_type="dataset")
          except Exception as e:
+             print(f"Failed to upload vote file: {e}")
+
+ def is_valid_video_filename(name):
+     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
+     ext = name.split(".")[-1].lower()
+     return ext in video_extensions
+
+ def is_valid_image_filename(name):
+     image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
+     ext = name.split(".")[-1].lower()
+     return ext in image_extensions
+
+ def sample_frames(video_file, num_frames):
+     video = cv2.VideoCapture(video_file)
+     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+     # Guard against clips shorter than num_frames, which would make the interval 0
+     interval = max(total_frames // num_frames, 1)
+     frames = []
+     for i in range(total_frames):
+         ret, frame = video.read()
+         if not ret:
+             continue
+         if i % interval == 0:
+             pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             frames.append(pil_img)
+     video.release()
+     return frames
+
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         if response.status_code == 200:
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+         else:
+             raise ValueError("Failed to load image from URL")
+     else:
+         print("Load image from local file")
+         print(image_file)
+         image = Image.open(image_file).convert("RGB")
+     return image

+ def process_base64_image(base64_string):
+     """Process base64 encoded image string"""
+     try:
+         # Remove data URL prefix if present
+         if base64_string.startswith('data:image'):
+             base64_string = base64_string.split(',')[1]
+
+         # Decode base64 to bytes
+         image_data = base64.b64decode(base64_string)
+
+         # Convert to PIL Image
+         image = Image.open(BytesIO(image_data)).convert("RGB")
+         return image
+     except Exception as e:
+         raise ValueError(f"Failed to process base64 image: {e}")
+
+ def process_image_input(image_input):
+     """Process different types of image input (file path, URL, or base64)"""
+     if isinstance(image_input, str):
+         if image_input.startswith("http"):
+             return load_image(image_input)
+         elif os.path.exists(image_input):
+             return load_image(image_input)
+         else:
+             # Try to process as base64
+             return process_base64_image(image_input)
+     elif isinstance(image_input, dict) and "image" in image_input:
+         # Handle base64 image from dict
+         return process_base64_image(image_input["image"])
+     else:
+         raise ValueError("Unsupported image input format")

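Note: `process_image_input` accepts a URL, an existing local path, or a base64 string (raw or `data:image/...` URL), so all of the payload styles below resolve to the same RGB `PIL.Image`. A minimal sketch; the file path and URL are placeholders, not files from this repo:

```python
import base64

# All three input styles are routed by process_image_input:
img_a = process_image_input("https://example.com/ecg.png")    # URL, fetched via requests
img_b = process_image_input("/data/samples/ecg_12lead.jpg")   # existing local file

with open("/data/samples/ecg_12lead.jpg", "rb") as f:
    img_c = process_image_input(base64.b64encode(f.read()).decode("ascii"))  # raw base64
```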
+ class InferenceDemo(object):
+     def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
+         disable_torch_init()

+         self.tokenizer, self.model, self.image_processor, self.context_len = (
+             tokenizer,
+             model,
+             image_processor,
+             context_len,
          )

          model_name = get_model_name_from_path(model_path)
+         if "llama-2" in model_name.lower():
+             conv_mode = "llava_llama_2"
+         elif "v1" in model_name.lower() or "pulse" in model_name.lower():
+             conv_mode = "llava_v1"
+         elif "mpt" in model_name.lower():
+             conv_mode = "mpt"
+         elif "qwen" in model_name.lower():
+             conv_mode = "qwen_1_5"
+         else:
+             conv_mode = "llava_v0"

+         if args.conv_mode is not None and conv_mode != args.conv_mode:
+             print(
+                 "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+                     conv_mode, args.conv_mode, args.conv_mode
+                 )
              )
+             conv_mode = args.conv_mode  # honor the explicit override, as the warning states
+         else:
+             args.conv_mode = conv_mode
+         self.conv_mode = conv_mode
+         self.conversation = conv_templates[args.conv_mode].copy()
+         self.num_frames = args.num_frames
+
+ class ChatSessionManager:
+     def __init__(self):
+         self.chatbot_instance = None
+
+     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
+         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
+
+     def reset_chatbot(self):
+         self.chatbot_instance = None
+
+     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+         if self.chatbot_instance is None:
+             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+         return self.chatbot_instance
+
+ chat_manager = ChatSessionManager()
+
+ def clear_history():
+     """Clear conversation history"""
+     chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+     chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
+     return {"status": "success", "message": "Conversation history cleared"}
+
+ def add_message(message_text, image_input=None):
+     """Add a message to the conversation"""
+     # Track the image count on the function object itself (no module-level global needed)
+     if not hasattr(add_message, 'chat_image_num'):
+         add_message.chat_image_num = 0
+
+     if image_input:
+         add_message.chat_image_num += 1
+         if add_message.chat_image_num > 1:
+             # A new image starts a fresh session
+             chat_manager.reset_chatbot()
+             add_message.chat_image_num = 1
+
+     return {"status": "success", "message": "Message added"}
+
+ @spaces.GPU
+ def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096):
+     """Generate response for the given message and image"""
+     try:
+         if not message_text or not image_input:
+             return {"error": "Both message text and image are required"}
+
+         our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+
+         # Process image input
          try:
+             image = process_image_input(image_input)
          except Exception as e:
+             return {"error": f"Failed to process image: {str(e)}"}
+
+         # Save image for logging
+         all_image_hash = []
+         all_image_path = []
+
+         # Generate hash for the image
+         img_byte_arr = BytesIO()
+         image.save(img_byte_arr, format='JPEG')
+         img_byte_arr = img_byte_arr.getvalue()
+         image_hash = hashlib.md5(img_byte_arr).hexdigest()
+         all_image_hash.append(image_hash)
+
+         # Save image to logs
+         t = datetime.datetime.now()
+         filename = os.path.join(
+             LOGDIR,
+             "serve_images",
+             f"{t.year}-{t.month:02d}-{t.day:02d}",
+             f"{image_hash}.jpg",
+         )
+         all_image_path.append(filename)
+         if not os.path.isfile(filename):
+             os.makedirs(os.path.dirname(filename), exist_ok=True)
+             print("image saved to", filename)
+             image.save(filename)
+
+         # Process image for model
+         image_tensor = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)[0]
+         image_tensor = image_tensor.half().to(our_chatbot.model.device)
+         image_tensor = image_tensor.unsqueeze(0)
+
+         # Prepare conversation
+         inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
+         our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
+         our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
+         prompt = our_chatbot.conversation.get_prompt()
+
+         # Tokenize input
+         input_ids = tokenizer_image_token(
+             prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+         ).unsqueeze(0).to(our_chatbot.model.device)
+
+         # Set up stopping criteria
+         stop_str = (
+             our_chatbot.conversation.sep
+             if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
+             else our_chatbot.conversation.sep2
+         )
+         keywords = [stop_str]
+         stopping_criteria = KeywordsStoppingCriteria(
+             keywords, our_chatbot.tokenizer, input_ids
+         )
+
+         # Generate response
+         with torch.no_grad():
+             outputs = our_chatbot.model.generate(
+                 inputs=input_ids,
+                 images=image_tensor,
+                 do_sample=True,
                  temperature=temperature,
                  top_p=top_p,
+                 max_new_tokens=max_output_tokens,
+                 use_cache=False,
+                 stopping_criteria=[stopping_criteria],
              )
+
+         # Decode response
+         response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
+         our_chatbot.conversation.messages[-1][-1] = response
+
+         # Log conversation
+         history = [(message_text, response)]
+         with open(get_conv_log_filename(), "a") as fout:
+             data = {
+                 "type": "chat",
+                 "model": "PULSE-7b",
+                 "state": history,
+                 "images": all_image_hash,
+                 "images_path": all_image_path
+             }
+             print("#### conv log", data)
+             fout.write(json.dumps(data) + "\n")
+
+         # Upload files to Hugging Face if configured
+         if api and repo_name:
+             try:
+                 for upload_img in all_image_path:
+                     api.upload_file(
+                         path_or_fileobj=upload_img,
+                         path_in_repo=upload_img.replace("./logs/", ""),
+                         repo_id=repo_name,
+                         repo_type="dataset",
+                     )
+
+                 # Upload conversation log
+                 api.upload_file(
+                     path_or_fileobj=get_conv_log_filename(),
+                     path_in_repo=get_conv_log_filename().replace("./logs/", ""),
+                     repo_id=repo_name,
+                     repo_type="dataset")
+             except Exception as e:
+                 print(f"Failed to upload files: {e}")
+
          return {
+             "status": "success",
+             "response": response,
+             "conversation_id": id(our_chatbot.conversation)
          }
+
+     except Exception as e:
+         return {"error": f"Generation failed: {str(e)}"}
+
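Note: `generate_response` always returns a plain dict, so callers can branch on the `status`/`error` keys, and the `conversation_id` it returns is what the vote endpoints below expect. A minimal sketch; the prompt and image path are placeholders:

```python
result = generate_response(
    message_text="What rhythm does this ECG show?",
    image_input="/data/samples/ecg_12lead.jpg",  # a URL or base64 string also works
    temperature=0.05,
    top_p=1.0,
    max_output_tokens=1024,
)

if "error" in result:
    print("Generation failed:", result["error"])
else:
    print(result["response"])                        # the model's answer
    upvote_last_response(result["conversation_id"])  # optional feedback
```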
+ def upvote_last_response(conversation_id):
+     """Upvote the last response"""
+     try:
+         vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
+         return {"status": "success", "message": "Thank you for voting!"}
+     except Exception as e:
+         return {"error": f"Failed to upvote: {str(e)}"}
+
+ def downvote_last_response(conversation_id):
+     """Downvote the last response"""
+     try:
+         vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
+         return {"status": "success", "message": "Thank you for voting!"}
+     except Exception as e:
+         return {"error": f"Failed to downvote: {str(e)}"}
+
+ def flag_response(conversation_id):
+     """Flag the last response"""
+     try:
+         vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
+         return {"status": "success", "message": "Response flagged successfully"}
+     except Exception as e:
+         return {"error": f"Failed to flag response: {str(e)}"}
+
+ # Initialize model when module is imported
+ def initialize_model():
+     """Initialize the model and tokenizer"""
+     global tokenizer, model, image_processor, context_len, args, model_path
+
+     try:
+         # Set default arguments
+         class Args:
+             def __init__(self):
+                 self.model_path = "PULSE-ECG/PULSE-7B"
+                 self.model_base = None
+                 self.num_gpus = 1
+                 self.conv_mode = None
+                 self.temperature = 0.05
+                 self.max_new_tokens = 1024
+                 self.num_frames = 16
+                 self.load_8bit = False
+                 self.load_4bit = False
+                 self.debug = False
+
+         args = Args()
+
+         # Load model (model_path is kept global; generate_response and clear_history read it)
+         model_path = args.model_path
+         model_name = get_model_name_from_path(args.model_path)
+         tokenizer, model, image_processor, context_len = load_pretrained_model(
+             args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
+         )
+
+         print("### image_processor", image_processor)
+         print("### tokenizer", tokenizer)
+
+         # Move model to GPU if available
+         if torch.cuda.is_available():
+             model = model.to(torch.device('cuda'))
+             print("Model moved to CUDA")
+         else:
+             print("CUDA not available, using CPU")
+
+         return True
+
+     except Exception as e:
+         print(f"Failed to initialize model: {e}")
+         return False
+
+ # Initialize model on import
+ model_initialized = initialize_model()
+
+ # Main endpoint function for Hugging Face
+ def query(payload):
+     """Main endpoint function for the Hugging Face inference API"""
+     if not model_initialized:
+         return {"error": "Model not initialized"}
+
+     try:
+         # Extract parameters from payload
+         message_text = payload.get("message", "")
+         image_input = payload.get("image", None)
+         temperature = payload.get("temperature", 0.05)
+         top_p = payload.get("top_p", 1.0)
+         max_output_tokens = payload.get("max_output_tokens", 4096)
+
+         if not message_text or not image_input:
+             return {"error": "Both 'message' and 'image' are required in the payload"}
+
+         # Generate response
+         result = generate_response(
+             message_text=message_text,
+             image_input=image_input,
+             temperature=temperature,
+             top_p=top_p,
+             max_output_tokens=max_output_tokens
+         )
+
+         return result
+
+     except Exception as e:
+         return {"error": f"Query failed: {str(e)}"}
+
+ # Additional utility endpoints
+ def health_check():
+     """Health check endpoint"""
+     return {
+         "status": "healthy",
+         "model_initialized": model_initialized,
+         "cuda_available": torch.cuda.is_available()
+     }
+
+ def get_model_info():
+     """Get model information"""
+     if not model_initialized:
+         return {"error": "Model not initialized"}
+
+     return {
+         "model_path": args.model_path if args else "Unknown",
+         "model_type": "PULSE-7B",
+         "cuda_available": torch.cuda.is_available(),
+         "device": str(model.device) if model else "Unknown"
+     }
+
+ # For backward compatibility and testing
+ if __name__ == "__main__":
+     import argparse
+
+     argparser = argparse.ArgumentParser()
+     argparser.add_argument("--server_name", default="0.0.0.0", type=str)
+     argparser.add_argument("--port", default="6123", type=str)
+     argparser.add_argument("--model_path", default="PULSE-ECG/PULSE-7B", type=str)
+     argparser.add_argument("--model-base", type=str, default=None)
+     argparser.add_argument("--num-gpus", type=int, default=1)
+     argparser.add_argument("--conv-mode", type=str, default=None)
+     argparser.add_argument("--temperature", type=float, default=0.05)
+     argparser.add_argument("--max-new-tokens", type=int, default=1024)
+     argparser.add_argument("--num_frames", type=int, default=16)
+     argparser.add_argument("--load-8bit", action="store_true")
+     argparser.add_argument("--load-4bit", action="store_true")
+     argparser.add_argument("--debug", action="store_true")
+
+     args = argparser.parse_args()
+
+     model_path = args.model_path
+     filt_invalid = "cut"
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(
+         args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
+     )
+     print("### image_processor", image_processor)
+     print("### tokenizer", tokenizer)
+     model = model.to(torch.device('cuda'))
+
+     print("Model initialized successfully!")
+     print("This handler is now ready for Hugging Face endpoints.")
+     print("Use the 'query' function as the main endpoint.")