CanerDedeoglu committed on
Commit
74861f0
·
verified ·
1 Parent(s): 514d76c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +119 -29
handler.py CHANGED
@@ -1,30 +1,69 @@
1
- # handler.py (örnek iskelet)
2
- import base64, io, os
3
- from typing import Any, Dict, List
4
  import torch
5
  from PIL import Image
6
- from transformers import AutoTokenizer, AutoProcessor, AutoModelForVision2Seq # model tipinize göre
7
 
8
- HF_MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B") # ağırlıkların olduğu repo id
9
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
- DT = torch.bfloat16 if torch.cuda.is_available() else torch.float32 # bfloat16 GPU varsa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class EndpointHandler:
13
  def __init__(self, path: str = "") -> None:
14
- # path: /repository (bu repo klasörü)
15
- # NOT: Ağırlıkları bu repodan değil, HF Hub’dan alıyoruz
16
- self.tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, use_fast=True, trust_remote_code=True)
17
- self.processor = AutoProcessor.from_pretrained(HF_MODEL_ID, trust_remote_code=True)
18
- self.model = AutoModelForVision2Seq.from_pretrained(
19
- HF_MODEL_ID,
20
- torch_dtype=DT,
21
- device_map="auto", # GPU varsa otomatik yerleşim
22
- trust_remote_code=True,
23
- low_cpu_mem_usage=True,
24
- # attn_implementation="sdpa", # flash-attn yoksa güvenlisi SDPA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
 
27
- def _load_image(self, img_field: str) -> Image.Image:
 
 
 
 
 
 
 
 
 
28
  if img_field.startswith("data:image"):
29
  head, b64 = img_field.split(",", 1)
30
  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
@@ -34,26 +73,77 @@ class EndpointHandler:
34
  r.raise_for_status()
35
  return Image.open(io.BytesIO(r.content)).convert("RGB")
36
  else:
 
37
  return Image.open(img_field).convert("RGB")
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
40
  inputs = data.get("inputs") or {}
41
  params = data.get("parameters") or {}
42
- query = inputs.get("query", "")
43
- img_f = inputs.get("image", "")
44
- image = self._load_image(img_f) if img_f else None
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Model türüne göre preprocessing (örnek akış)
47
- model_inputs = self.processor(images=image, text=query, return_tensors="pt").to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
 
48
  gen_kwargs = {
49
- "max_new_tokens": int(params.get("max_new_tokens", 4096)),
50
  "temperature": float(params.get("temperature", 0.0)),
51
- "do_sample": bool(params.get("do_sample", params.get("temperature", 0.0) > 0)),
52
  "top_p": float(params.get("top_p", 1.0)),
53
  "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
 
 
54
  }
55
 
56
  with torch.no_grad():
57
- out_ids = self.model.generate(**model_inputs, **gen_kwargs)
58
- text = self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
59
- return [{"generated_text": text}]
 
 
 
 
 
 
 
1
+ # /repository/handler.py
2
+ import os, io, base64
3
+ from typing import Any, Dict, List, Optional
4
  import torch
5
  from PIL import Image
 
6
 
7
+ # ---- LLaVA demodaki modüller ----
8
+ from llava.model.builder import load_pretrained_model
9
+ from llava.mm_utils import tokenizer_image_token, process_images
10
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
11
+ from llava.conversation import conv_templates, SeparatorStyle
12
+ from llava.utils import disable_torch_init
13
+ from llava.model.builder import get_model_name_from_path
14
+
15
+ # Ortam değişkenleri (modeli nereden alacağımız)
16
+ # 1) Yerel klasörden yüklemek istersen HF_MODEL_LOCAL_DIR kullan
17
+ # 2) HF Hub repo id ile yüklemek istersen HF_MODEL_ID kullan
18
+ HF_MODEL_LOCAL_DIR = os.getenv("HF_MODEL_LOCAL_DIR", "").strip()
19
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "").strip() # ör: "your-org/your-llava-model"
20
+
21
+ DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2") # demo: llava_v2
22
+ MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "4096"))
23
+
24
+ # Flash-Attention yoksa SDPA güvenli yoldur
25
+ os.environ.setdefault("ATTN_IMPLEMENTATION", "sdpa")
26
 
27
  class EndpointHandler:
28
  def __init__(self, path: str = "") -> None:
29
+ """
30
+ path: /repository (endpoint bu klasörü model_dir olarak geçer)
31
+ """
32
+ disable_torch_init()
33
+
34
+ # Model yolunu belirle
35
+ if HF_MODEL_LOCAL_DIR:
36
+ model_path = HF_MODEL_LOCAL_DIR
37
+ elif HF_MODEL_ID:
38
+ model_path = HF_MODEL_ID
39
+ else:
40
+ # Eğer ağırlık/konfig bu repository içindeyse path= "/repository"
41
+ model_path = path
42
+
43
+ # model adı (LLaVA utils)
44
+ self.model_name = get_model_name_from_path(model_path)
45
+
46
+ # LLaVA yükleme — demo ile aynı giriş:
47
+ # Dönüş: tokenizer, model, image_processor, context_len
48
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
49
+ model_path=model_path,
50
+ model_base=None, # LoRA vb. yoksa None
51
+ model_name=self.model_name,
52
+ torch_dtype="auto", # ortam GPU'ya göre seçsin
53
+ attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
54
+ device_map="auto"
55
  )
56
 
57
+ # Görüntü başlangıç/bitiş tokenları (model sürümüne göre aktif)
58
+ self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
59
+ self.image_token = DEFAULT_IMAGE_TOKEN
60
+ self.im_start = DEFAULT_IM_START_TOKEN
61
+ self.im_end = DEFAULT_IM_END_TOKEN
62
+
63
+ # ---- Yardımcılar ----
64
+ def _load_image(self, img_field: str) -> Optional[Image.Image]:
65
+ if not img_field:
66
+ return None
67
  if img_field.startswith("data:image"):
68
  head, b64 = img_field.split(",", 1)
69
  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
 
73
  r.raise_for_status()
74
  return Image.open(io.BytesIO(r.content)).convert("RGB")
75
  else:
76
+ # container içinden dosya okunacaksa
77
  return Image.open(img_field).convert("RGB")
78
 
79
+ def _build_prompt(self, user_text: str, conv_mode: str) -> str:
80
+ # Demo: conv_templates ile diyalog kur
81
+ conv = conv_templates[conv_mode].copy()
82
+ if self.use_im_start_end:
83
+ # <im_start> <image> <im_end> + kullanıcı metni
84
+ content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
85
+ else:
86
+ content = f"{self.image_token}\n{user_text}"
87
+ conv.append_message(conv.roles[0], content)
88
+ conv.append_message(conv.roles[1], None) # assistant boş, model dolduracak
89
+ return conv.get_prompt()
90
+
91
+ # ---- Inference giriş noktası ----
92
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
93
+ """
94
+ Beklenen giriş (demo ile uyumlu):
95
+ {
96
+ "inputs": { "query": "...", "image": "<url|dataurl|path>" },
97
+ "parameters": { "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0, ... },
98
+ "conv_mode": "llava_v2" # opsiyonel; yoksa varsayılanı kullanırız
99
+ }
100
+ """
101
  inputs = data.get("inputs") or {}
102
  params = data.get("parameters") or {}
103
+ conv_mode = data.get("conv_mode") or DEFAULT_CONV_MODE
104
+
105
+ query_text = inputs.get("query", "")
106
+ image_f = inputs.get("image", "")
107
+ pil_img = self._load_image(image_f)
108
+
109
+ # 1) Prompt (conversation şablonu)
110
+ prompt = self._build_prompt(query_text, conv_mode)
111
+
112
+ # 2) Görsel tensörü (demo: process_images)
113
+ image_tensors = None
114
+ if pil_img is not None:
115
+ image_tensors = process_images([pil_img], self.image_processor, self.model.config)
116
 
117
+ # 3) Tokenize (görüntü tokenını metne göm)
118
+ input_ids = tokenizer_image_token(
119
+ prompt,
120
+ self.tokenizer,
121
+ IMAGE_TOKEN_INDEX,
122
+ return_tensors="pt"
123
+ )
124
+
125
+ # 4) Cihaza taşı
126
+ input_ids = input_ids.to(self.model.device, non_blocking=True)
127
+ if image_tensors is not None:
128
+ image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
129
+
130
+ # 5) Generate (demo parametreleri)
131
  gen_kwargs = {
132
+ "max_new_tokens": int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF)),
133
  "temperature": float(params.get("temperature", 0.0)),
 
134
  "top_p": float(params.get("top_p", 1.0)),
135
  "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
136
+ "do_sample": bool(params.get("do_sample", float(params.get("temperature", 0.0)) > 0)),
137
+ "use_cache": bool(params.get("use_cache", True)),
138
  }
139
 
140
  with torch.no_grad():
141
+ output_ids = self.model.generate(
142
+ input_ids,
143
+ images=image_tensors,
144
+ **gen_kwargs
145
+ )
146
+
147
+ # 6) Decode (assistant’ın cevabı)
148
+ outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
149
+ return [{"generated_text": outputs}]