Files changed (1) hide show
  1. handler.py +35 -22
handler.py CHANGED
@@ -1,5 +1,10 @@
1
- import base64, io, os, logging
2
- import requests, torch, transformers
 
 
 
 
 
3
  from PIL import Image
4
  from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor
5
 
@@ -8,7 +13,7 @@ class EndpointHandler:
8
  logging.warning(f"[INIT] Transformers version: {transformers.__version__}")
9
  self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")
10
 
11
- # 1) Normal yol: pipeline
12
  try:
13
  self.pipe = pipeline(
14
  task="image-text-to-text",
@@ -43,25 +48,33 @@ class EndpointHandler:
43
  except Exception as e:
44
  logging.warning(f"[INIT] override failed: {e}")
45
 
46
- # 3) Fallback: AutoProcessor + AutoModel
47
- logging.warning("[INIT] Fallback: AutoProcessor/AutoModel")
48
- proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
49
- mdl = AutoModelForCausalLM.from_pretrained(
50
- self.model_id,
51
- device_map="auto",
52
- torch_dtype="auto",
53
- trust_remote_code=True
54
- )
 
 
 
 
 
55
 
56
- def _mini_pipe(msgs, **params):
57
- inputs = proc(msgs, return_tensors="pt").to(mdl.device)
58
- gen_kwargs = {"max_new_tokens": 512, **params}
59
- with torch.inference_mode():
60
- out_ids = mdl.generate(**inputs, **gen_kwargs)
61
- return proc.tokenizer.batch_decode(out_ids, skip_special_tokens=True)
62
 
63
- self.pipe = _mini_pipe
64
- logging.warning("[INIT] Fallback loaded")
 
 
 
65
 
66
  # ---- helpers ----
67
  def _ensure_pad_token(self):
@@ -74,7 +87,7 @@ class EndpointHandler:
74
  pass
75
 
76
  def _normalize_inputs(self, data: dict):
77
- # Basit şema
78
  if "image_url" in data or "text" in data:
79
  image_url = data.get("image_url")
80
  text = data.get("text", "Interpret this ECG image.")
@@ -87,7 +100,7 @@ class EndpointHandler:
87
  ]}
88
  ], data.get("parameters", {})
89
 
90
- # Multimodal chat şeması
91
  if "inputs" in data:
92
  return data.get("inputs", []), data.get("parameters", {})
93
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ import logging
5
+ import requests
6
+ import torch
7
+ import transformers
8
  from PIL import Image
9
  from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoProcessor
10
 
 
13
  logging.warning(f"[INIT] Transformers version: {transformers.__version__}")
14
  self.model_id = os.getenv("PULSE_MODEL_ID", "PULSE-ECG/PULSE-7B")
15
 
16
+ # 1) Normal path: attempt pipeline directly
17
  try:
18
  self.pipe = pipeline(
19
  task="image-text-to-text",
 
48
  except Exception as e:
49
  logging.warning(f"[INIT] override failed: {e}")
50
 
51
+ # 3) Fallback: AutoProcessor + AutoModel with config override check
52
+ try:
53
+ cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
54
+ if getattr(cfg, "model_type", None) == "llava_llama":
55
+ logging.warning("[INIT] Fallback override: llava_llama -> llava")
56
+ cfg.model_type = "llava"
57
+ proc = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True, config=cfg)
58
+ mdl = AutoModelForCausalLM.from_pretrained(
59
+ self.model_id,
60
+ device_map="auto",
61
+ torch_dtype="auto",
62
+ trust_remote_code=True,
63
+ config=cfg,
64
+ )
65
 
66
def _mini_pipe(msgs, **params):
    """Minimal stand-in for the transformers pipeline.

    Encodes the chat messages with the processor, generates with the
    fallback model, and decodes the output ids back to text. Returns a
    list of decoded strings (one per batch element).

    NOTE(review): batch_decode is applied to the full generated ids, so
    the echoed prompt is presumably included in the output — confirm
    that downstream consumers expect that.
    """
    # Encode messages and move every tensor onto the model's device.
    batch = proc(msgs, return_tensors="pt").to(mdl.device)
    # Caller-supplied generation params take precedence over the default cap.
    params.setdefault("max_new_tokens", 512)
    with torch.inference_mode():
        generated = mdl.generate(**batch, **params)
    return proc.tokenizer.batch_decode(generated, skip_special_tokens=True)
72
 
73
+ self.pipe = _mini_pipe
74
+ logging.warning("[INIT] Fallback loaded")
75
+ except Exception as e:
76
+ logging.error(f"[INIT] Fallback failed: {e}")
77
+ raise
78
 
79
  # ---- helpers ----
80
  def _ensure_pad_token(self):
 
87
  pass
88
 
89
  def _normalize_inputs(self, data: dict):
90
+ # Simple schema
91
  if "image_url" in data or "text" in data:
92
  image_url = data.get("image_url")
93
  text = data.get("text", "Interpret this ECG image.")
 
100
  ]}
101
  ], data.get("parameters", {})
102
 
103
+ # Multimodal chat schema
104
  if "inputs" in data:
105
  return data.get("inputs", []), data.get("parameters", {})
106