CanerDedeoglu commited on
Commit
9ed768e
·
verified ·
1 Parent(s): 69f64da

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -29
handler.py CHANGED
@@ -1,49 +1,59 @@
1
- # /repository/handler.py
2
- import base64, io, os, json
3
  from typing import Any, Dict, List
 
4
  from PIL import Image
 
5
 
6
- # (Gerekiyorsa: from transformers import ... # model yükleme burada olur)
 
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = "") -> None:
10
- # Burada model/processor/tokenizer'ı yükleyin
11
- # ör: self.model = ...
12
- # self.processor = ...
13
- pass
 
 
 
 
 
 
 
 
14
 
15
  def _load_image(self, img_field: str) -> Image.Image:
16
  if img_field.startswith("data:image"):
17
- # data URL -> bytes
18
- header, b64data = img_field.split(",", 1)
19
- img_bytes = base64.b64decode(b64data)
20
- return Image.open(io.BytesIO(img_bytes)).convert("RGB")
21
  elif img_field.startswith("http://") or img_field.startswith("https://"):
22
  import requests
23
- resp = requests.get(img_field, timeout=20)
24
- resp.raise_for_status()
25
- return Image.open(io.BytesIO(resp.content)).convert("RGB")
26
  else:
27
- # Yerel yol (container içinden)
28
  return Image.open(img_field).convert("RGB")
29
 
30
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
31
- """
32
- Hugging Face Inference Toolkit burayı çağırır.
33
- Beklenen dönüş genelde: [{"generated_text": "..."}]
34
- """
35
  inputs = data.get("inputs") or {}
36
  params = data.get("parameters") or {}
37
  query = inputs.get("query", "")
38
- img_field = inputs.get("image", "")
 
39
 
40
- # Görseli hazırla (opsiyonel modeliniz görsel kullanıyorsa)
41
- image = None
42
- if img_field:
43
- image = self._load_image(img_field)
 
 
 
 
 
44
 
45
- # Burada kendi inference kodunuzu çağırın:
46
- # out_text = run_model(self.model, self.processor, query, image, **params)
47
- out_text = f"(demo) prompt='{query[:50]}...' image={'yes' if image else 'no'}"
48
-
49
- return [{"generated_text": out_text}]
 
1
+ # handler.py (örnek iskelet)
2
+ import base64, io, os
3
  from typing import Any, Dict, List
4
+ import torch
5
  from PIL import Image
6
+ from transformers import AutoTokenizer, AutoProcessor, AutoModelForVision2Seq # model tipinize göre
7
 
8
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B") # ağırlıkların olduğu repo id
9
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+ DT = torch.bfloat16 if torch.cuda.is_available() else torch.float32 # bfloat16 GPU varsa
11
 
12
  class EndpointHandler:
13
  def __init__(self, path: str = "") -> None:
14
+ # path: /repository (bu repo klasörü)
15
+ # NOT: Ağırlıkları bu repodan değil, HF Hub’dan alıyoruz
16
+ self.tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, use_fast=True, trust_remote_code=True)
17
+ self.processor = AutoProcessor.from_pretrained(HF_MODEL_ID, trust_remote_code=True)
18
+ self.model = AutoModelForVision2Seq.from_pretrained(
19
+ HF_MODEL_ID,
20
+ torch_dtype=DT,
21
+ device_map="auto", # GPU varsa otomatik yerleşim
22
+ trust_remote_code=True,
23
+ low_cpu_mem_usage=True,
24
+ # attn_implementation="sdpa", # flash-attn yoksa güvenlisi SDPA
25
+ )
26
 
27
  def _load_image(self, img_field: str) -> Image.Image:
28
  if img_field.startswith("data:image"):
29
+ head, b64 = img_field.split(",", 1)
30
+ return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
 
 
31
  elif img_field.startswith("http://") or img_field.startswith("https://"):
32
  import requests
33
+ r = requests.get(img_field, timeout=20)
34
+ r.raise_for_status()
35
+ return Image.open(io.BytesIO(r.content)).convert("RGB")
36
  else:
 
37
  return Image.open(img_field).convert("RGB")
38
 
39
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
40
  inputs = data.get("inputs") or {}
41
  params = data.get("parameters") or {}
42
  query = inputs.get("query", "")
43
+ img_f = inputs.get("image", "")
44
+ image = self._load_image(img_f) if img_f else None
45
 
46
+ # Model türüne göre preprocessing (örnek akış)
47
+ model_inputs = self.processor(images=image, text=query, return_tensors="pt").to(self.model.device)
48
+ gen_kwargs = {
49
+ "max_new_tokens": int(params.get("max_new_tokens", 256)),
50
+ "temperature": float(params.get("temperature", 0.0)),
51
+ "do_sample": bool(params.get("do_sample", params.get("temperature", 0.0) > 0)),
52
+ "top_p": float(params.get("top_p", 1.0)),
53
+ "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
54
+ }
55
 
56
+ with torch.no_grad():
57
+ out_ids = self.model.generate(**model_inputs, **gen_kwargs)
58
+ text = self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
59
+ return [{"generated_text": text}]