|
|
|
|
|
""" |
|
|
PULSE ECG Handler — Demo Parity + Style Hint + Robust Fallbacks + Debug

- Same generation settings as the demo app.py:
  do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
- Stopping: safe token-matching criterion on the conversation separator (conv.sep/sep2)
- Image tensor: .half() and placed on the model's device
- Streamer: TextIteratorStreamer (as in the demo); generate runs in a thread
- Seeding/determinism is OFF unless a seed is sent; stochastic like the demo
- STYLE_HINT: steers output toward the demo style (narrative + a single-line structured impression at the end)
- Post-process: whitespace/formatting cleanup only
- Extras:
  * DEBUG helpers (ENV: DEBUG=1)
  * image_processor fallback (AutoProcessor → CLIPImageProcessor)
  * process_images fallback (torchvision + CLIP norm)
  * FastAPI wrapper: /health, /info, /query, /debug
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import base64 |
|
|
import hashlib |
|
|
import datetime |
|
|
from io import BytesIO |
|
|
from threading import Thread |
|
|
from typing import Optional, Union, Any, Dict |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
import requests |
|
|
|
|
|
|
|
|
def _env_bool(name: str, default: bool = False) -> bool:
    """Interpret the environment variable ``name`` as a boolean flag.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``True`` when the trimmed, lower-cased value is one of
        ``1``, ``true``, ``yes``, ``y`` or ``on``; ``False`` for any other
        value that is set; ``default`` when the variable is unset.
    """
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "y", "on")
    return default
|
|
|
|
|
# Debug tracing is opt-in through the DEBUG environment variable.
DEBUG = _env_bool("DEBUG", False)


def dbg(*args, **kwargs):
    """Print ``args`` prefixed with ``[DEBUG]``; a no-op unless ``DEBUG`` is true.

    Positional and keyword arguments are forwarded to ``print`` unchanged.
    """
    if not DEBUG:
        return
    print("[DEBUG]", *args, **kwargs)
|
|
|
|
|
def warn(*args, **kwargs):
    """Print ``args`` prefixed with ``[WARN]``, forwarding ``kwargs`` to ``print``."""
    tag = "[WARN]"
    print(tag, *args, **kwargs)
|
|
|
|
|
|
|
|
# LLaVA is optional at import time: LLAVA_AVAILABLE records whether its
# symbols could be imported, and callers check the flag before using them.
try:
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
    from llava.utils import disable_torch_init
except Exception as exc:
    LLAVA_AVAILABLE = False
    warn(f"LLaVA not available: {exc}")
else:
    LLAVA_AVAILABLE = True
|
|
|
|
|
# transformers is optional at import time: TRANSFORMERS_AVAILABLE records
# whether the streamer / stopping-criteria classes could be imported.
try:
    from transformers import TextIteratorStreamer, StoppingCriteria
    TRANSFORMERS_AVAILABLE = True
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    # Bind placeholders so that module-level code which names these symbols
    # (e.g. a class statement that subclasses StoppingCriteria) can still be
    # imported. Without them the module dies with NameError before any
    # TRANSFORMERS_AVAILABLE check has a chance to report a clean error.
    TextIteratorStreamer = None
    StoppingCriteria = object
    warn(f"transformers not available: {e}")
|
|
|
|
|
|
|
|
# Optional Hugging Face Hub client, used only for uploading log files.
try:
    from huggingface_hub import HfApi, login
except Exception:
    HF_HUB_AVAILABLE = False
else:
    HF_HUB_AVAILABLE = True

# Upload stays disabled (api is None / repo_name is empty) unless the hub
# package imported and HF_TOKEN is present in the environment.
api, repo_name = None, ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        login(token=os.environ["HF_TOKEN"], write_permission=True)
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as exc:
        warn(f"[HF Hub] init failed: {exc}")
        # Roll back to the disabled state on any login/client failure.
        api, repo_name = None, ""
|
|
|
|
|
# Directory for conversation logs and served images; created at import time.
LOGDIR = "./logs"
os.makedirs(LOGDIR, exist_ok=True)

# Module-level model state, populated by initialize_model().
tokenizer = model = image_processor = context_len = args = None
model_initialized = False

# Instruction appended to every user message to steer the output format.
STYLE_HINT = " ".join(
    (
        "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis,",
        "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc.",
        "Use neutral, factual cardiology language. Avoid headings and bullet points.",
        "Finish with a single final line starting exactly with 'Structured clinical impression:'",
        "followed by a succinct, comma-separated summary of the key diagnoses.",
    )
)
|
|
|
|
|
|
|
|
def _safe_upload(path: str):
    """Best-effort upload of ``path`` to the configured Hub dataset repo.

    Does nothing unless the Hub client (``api``) and ``repo_name`` are set and
    ``path`` names an existing file. The in-repo path is ``path`` with every
    ``"./logs/"`` occurrence removed. Upload failures are reported through
    ``warn`` and never raised.
    """
    upload_enabled = bool(api and repo_name)
    if not (upload_enabled and path and os.path.isfile(path)):
        return
    try:
        upload_kwargs = {
            "path_or_fileobj": path,
            "path_in_repo": path.replace("./logs/", ""),
            "repo_id": repo_name,
            "repo_type": "dataset",
        }
        api.upload_file(**upload_kwargs)
    except Exception as exc:
        warn(f"[upload] failed for {path}: {exc}")
|
|
|
|
|
def _conv_log_path() -> str:
    """Return today's conversation log path: ``LOGDIR/YYYY-MM-DD-user_conv.json``.

    The date is taken from the local clock (``datetime.datetime.now()``).
    """
    day_stamp = datetime.datetime.now().date().isoformat()
    filename = f"{day_stamp}-user_conv.json"
    return os.path.join(LOGDIR, filename)
|
|
|
|
|
def load_image_any(image_input: Union[str, dict]) -> Image.Image:
    """Load an image from one of several input forms and return it as RGB.

    Supported inputs:
      - URL (http/https), fetched with a (5 s connect, 20 s read) timeout
      - local file path
      - base64 string, optionally with a ``data:image...`` URL prefix
      - a dict with an ``"image"`` key holding any of the string forms above

    Args:
        image_input: The image reference or payload.

    Returns:
        A PIL image converted to ``"RGB"``.

    Raises:
        ValueError: Unsupported input type, empty string, or a data URL
            without a ``,`` separating header and payload.
        Exception: Errors from ``requests``, ``base64`` or PIL propagate
            unchanged (HTTP failure, invalid base64, unreadable image).
    """
    # NOTE(review): when image_input comes from an untrusted request, the URL
    # and local-path branches let the caller make this process fetch arbitrary
    # URLs or open arbitrary local files -- confirm that this is acceptable
    # for the deployment.
    if isinstance(image_input, dict) and "image" in image_input:
        return load_image_any(image_input["image"])

    if not isinstance(image_input, str):
        raise ValueError("Unsupported image input format")

    s = image_input.strip()
    if not s:
        raise ValueError("Unsupported image input format: empty string")

    if s.startswith(("http://", "https://")):
        r = requests.get(s, timeout=(5, 20))
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB")

    if os.path.exists(s):
        # convert() returns a new, fully loaded image, so the handle on the
        # source file can be closed right away instead of waiting for GC.
        with Image.open(s) as src:
            return src.convert("RGB")

    if s.startswith("data:image"):
        _header, sep, s = s.partition(",")
        if not sep:
            raise ValueError("Unsupported image input format: malformed data URL")

    raw = base64.b64decode(s)
    return Image.open(BytesIO(raw)).convert("RGB")
|
|
|
|
|
def _normalize_whitespace(text: str) -> str:
    """Tidy whitespace without touching the wording.

    Steps: convert ``\\r\\n`` and ``\\r`` to ``\\n``; strip every line and
    collapse runs of spaces/tabs inside it to one space; strip the whole
    text; finally reduce three or more consecutive newlines to two.
    """
    squeeze_blanks = re.compile(r"[ \t]+").sub
    unified = text.replace("\r\n", "\n").replace("\r", "\n")
    tidy_lines = (squeeze_blanks(" ", line.strip()) for line in unified.split("\n"))
    body = "\n".join(tidy_lines).strip()
    return re.sub(r"\n{3,}", "\n\n", body)
|
|
|
|
|
def _postprocess_min(text: str) -> str:
    """Minimal post-processing of model output: whitespace normalization only."""
    cleaned = _normalize_whitespace(text)
    return cleaned
|
|
|
|
|
|
|
|
class SafeKeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with the keyword's token ids.

    The keyword is tokenized once (without special tokens) at construction;
    each call compares those ids against the tail of the first sequence in
    the batch.
    """

    def __init__(self, keyword: str, tokenizer):
        self.tokenizer = tokenizer
        encoded = tokenizer(keyword, add_special_tokens=False, return_tensors="pt")
        # 1-D tensor holding the keyword's token ids.
        self.kw_ids = encoded.input_ids[0]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when row 0 of ``input_ids`` ends with the keyword ids."""
        if input_ids is None or input_ids.shape[0] == 0:
            return False
        sequence = input_ids[0]
        kw_len = self.kw_ids.shape[0]
        if sequence.shape[0] >= kw_len:
            # Compare on the sequence's device to avoid a cross-device error.
            return torch.equal(sequence[-kw_len:], self.kw_ids.to(sequence.device))
        return False
|
|
|
|
|
|
|
|
class InferenceDemo:
    """Bundle of tokenizer, model, image processor and conversation state."""

    def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
        """Store the loaded components and start a fresh ``llava_v1`` conversation.

        Raises:
            ImportError: when LLaVA could not be imported.
        """
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA not available")
        disable_torch_init()
        self.tokenizer = tokenizer_
        self.model = model_
        self.image_processor = image_processor_
        self.context_len = context_len_
        self.conv_mode = "llava_v1"
        template = conv_templates[self.conv_mode]
        self.conversation = template.copy()
        # Falls back to 16 when args carries no num_frames attribute.
        self.num_frames = getattr(args, "num_frames", 16)
|
|
|
|
|
class ChatSessionManager:
    """Lazily creates one ``InferenceDemo`` and hands it out for each request."""

    def __init__(self):
        self.chatbot = self.args = self.model_path = None

    def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Build the chatbot on first use; later calls leave it untouched."""
        if self.chatbot is not None:
            return
        self.args = args
        self.model_path = model_path
        self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the chatbot with its conversation reset to a fresh template copy."""
        self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
        bot = self.chatbot
        bot.conversation = conv_templates[bot.conv_mode].copy()
        return bot


# Module-wide manager shared by every request handled in this process.
chat_manager = ChatSessionManager()
|
|
|
|
|
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Append the user turn to the conversation and tokenize the full prompt.

    The user message is the image placeholder token, a newline, then
    ``user_text``; an empty assistant turn is appended after it.

    Returns:
        ``(prompt, input_ids)`` where ``input_ids`` has a leading batch
        dimension added and has been moved to ``device``.
    """
    conv = chatbot.conversation
    conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{user_text}")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    token_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )
    input_ids = token_ids.unsqueeze(0).to(device)
    return prompt, input_ids
|
|
|
|
|
def generate_response(
    message_text: str,
    image_input,
    *,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    conv_mode_override: Optional[str] = None,
    repetition_penalty: Optional[float] = None,
    det_seed: Optional[int] = None,
):
    """Run one image + text turn through the loaded model.

    Args:
        message_text: User prompt; ``STYLE_HINT`` is appended before generation.
        image_input: Anything accepted by ``load_image_any``.
        temperature: Sampling temperature; ``None`` selects 0.05.
        top_p: Nucleus sampling threshold; ``None`` selects 1.0.
        max_new_tokens: Generation budget; ``None`` selects 4096.
        conv_mode_override: Key into ``conv_templates``; ignored when falsy
            or unknown.
        repetition_penalty: ``None`` selects 1.0.
        det_seed: When given, seeds torch (CPU and CUDA) before sampling.

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}`` on
        success, otherwise ``{"error": str}``. Failures are reported through
        the returned dict rather than raised.
    """
    if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
        return {"error": "Required libraries not available (llava/transformers)"}
    if not message_text or image_input is None:
        return {"error": "Both 'message' and 'image' are required"}
    # Without this guard an un-initialized module would fail below with an
    # AttributeError on ``args.model_path``.
    if args is None or model is None or tokenizer is None:
        return {"error": "Model not initialized"}

    if temperature is None:
        temperature = 0.05
    if top_p is None:
        top_p = 1.0
    if max_new_tokens is None:
        max_new_tokens = 4096
    if repetition_penalty is None:
        repetition_penalty = 1.0

    dbg(f"[gen] temperature={temperature} top_p={top_p} max_new_tokens={max_new_tokens} rep={repetition_penalty} seed={det_seed}")

    # get_chatbot() resets the conversation, so every call is single-turn.
    chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
    if conv_mode_override and conv_mode_override in conv_templates:
        chatbot.conversation = conv_templates[conv_mode_override].copy()

    try:
        pil_img = load_image_any(image_input)
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    # Best-effort: keep a JPEG copy of the input, named by the MD5 of its
    # JPEG bytes, under LOGDIR/serve_images/<date>/ for the log row below.
    img_hash, img_path = "NA", None
    try:
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        raw = buf.getvalue()
        img_hash = hashlib.md5(raw).hexdigest()
        now = datetime.datetime.now()
        img_path = os.path.join(
            LOGDIR,
            "serve_images",
            f"{now.year:04d}-{now.month:02d}-{now.day:02d}",
            f"{img_hash}.jpg",
        )
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        if not os.path.isfile(img_path):
            pil_img.save(img_path)
    except Exception as e:
        warn(f"[log] save image failed: {e}")

    device = next(chatbot.model.parameters()).device
    dtype = torch.float16

    # Image -> tensor. Primary path is LLaVA's process_images; if that fails,
    # fall back to a manual torchvision pipeline with CLIP normalization.
    try:
        dbg(f"[pre] PIL image size={pil_img.size}, mode={pil_img.mode}, processor={type(chatbot.image_processor)}")
        processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
        dbg("[pre] process_images ok")

        if isinstance(processed, (list, tuple)) and len(processed) > 0:
            image_tensor = processed[0]
        elif isinstance(processed, torch.Tensor):
            image_tensor = processed[0] if processed.ndim == 4 else processed
        else:
            raise ValueError("Image processing returned empty")

        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(device=device, dtype=dtype)
        dbg(f"[pre] tensor shape={tuple(image_tensor.shape)} dtype={image_tensor.dtype} device={image_tensor.device}")
    except Exception as e:
        warn(f"[pre] process_images failed: {e} → manual CLIP preprocess fallback kullanılacak.")
        try:
            from torchvision import transforms
            from torchvision.transforms import InterpolationMode

            # NOTE(review): 224 is hard-coded here; confirm it matches the
            # input resolution the model's vision tower expects.
            preprocess = transforms.Compose([
                transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ])
            image_tensor = preprocess(pil_img).unsqueeze(0).to(device=device, dtype=dtype)
            dbg("[pre] manual CLIP preprocess fallback ok → tensor shape=" + str(tuple(image_tensor.shape)))
        except Exception as ee:
            return {"error": f"Image processing failed (and fallback failed): {ee}"}

    msg = (message_text or "").strip()
    msg = f"{msg}\n\n{STYLE_HINT}"
    dbg(f"[prompt] conv_sep_style={chatbot.conversation.sep_style} sep_len={len(chatbot.conversation.sep)}")
    _, input_ids = _build_prompt_and_ids(chatbot, msg, device)

    # Stop on the conversation separator: sep2 for the TWO style, sep otherwise.
    conv = chatbot.conversation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)

    if det_seed is not None:
        try:
            s = int(det_seed)
            torch.manual_seed(s)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(s)
                torch.cuda.manual_seed_all(s)
        except Exception as e:
            # Seeding is best-effort, but a failure should be visible.
            warn(f"[gen] seeding failed (seed={det_seed!r}): {e}")

    streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        use_cache=False,
        stopping_criteria=[stopping],
    )

    gen_errors = []

    def _run_generate():
        # Runs in the worker thread. If generate() raises, the streamer never
        # receives its end-of-stream signal and the consumer loop below would
        # block forever; record the error and close the stream explicitly.
        try:
            chatbot.model.generate(**gen_kwargs)
        except Exception as exc:
            gen_errors.append(exc)
            streamer.end()

    try:
        worker = Thread(target=_run_generate)
        worker.start()
        chunks = list(streamer)
        worker.join()
        if gen_errors:
            raise gen_errors[0]
        text = _postprocess_min("".join(chunks))
        chatbot.conversation.messages[-1][-1] = text
    except Exception as e:
        return {"error": f"Generation failed: {e}"}

    # Best-effort logging: append one JSON line and try to upload log + image.
    try:
        row = {
            "time": datetime.datetime.now().isoformat(),
            "type": "chat",
            "model": "PULSE-7B",
            "state": [(message_text, text)],
            "image_hash": img_hash,
            "image_path": img_path or "",
        }
        log_path = _conv_log_path()
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
        _safe_upload(log_path)
        _safe_upload(img_path or "")
    except Exception as e:
        warn(f"[log] failed: {e}")

    return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
|
|
|
|
|
|
|
|
def query(payload: dict):
    """HF Endpoint entry (demo-like).

    Lazily initializes the model on first use, pulls the prompt, image and
    sampling options out of ``payload`` and delegates to
    ``generate_response``.

    Recognized keys:
        - prompt text: ``message`` / ``query`` / ``prompt`` / ``istem``
          (first non-empty wins)
        - image: ``image`` / ``image_url`` / ``img`` (first truthy wins)
        - ``temperature``, ``top_p``, ``repetition_penalty``
        - token budget: ``max_output_tokens`` / ``max_new_tokens`` / ``max_tokens``
        - ``conv_mode``, ``det_seed``

    A key that is missing or explicitly ``None`` falls back to its default.

    Returns:
        The dict produced by ``generate_response``, or ``{"error": str}``.
    """
    global model_initialized
    if not model_initialized:
        if not initialize_model():
            return {"error": "Model initialization failed"}
        model_initialized = True

    def _first_set(keys, default):
        # First value among ``keys`` that is present and not None. Treating
        # an explicit null like a missing key avoids float(None)/int(None).
        for key in keys:
            value = payload.get(key)
            if value is not None:
                return value
        return default

    try:
        message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
        image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
        if not message.strip():
            return {"error": "Missing 'message' text"}
        if image is None:
            return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}

        temperature = float(_first_set(("temperature",), 0.05))
        top_p = float(_first_set(("top_p",), 1.0))
        max_new_tokens = int(
            _first_set(("max_output_tokens", "max_new_tokens", "max_tokens"), 4096)
        )
        repetition_penalty = float(_first_set(("repetition_penalty",), 1.0))

        conv_mode_override = payload.get("conv_mode", None)
        det_seed = payload.get("det_seed", None)
        if det_seed is not None:
            try:
                det_seed = int(det_seed)
            except Exception:
                # An unparsable seed is dropped rather than failing the request.
                det_seed = None

        return generate_response(
            message_text=message,
            image_input=image,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            conv_mode_override=conv_mode_override,
            repetition_penalty=repetition_penalty,
            det_seed=det_seed,
        )
    except Exception as e:
        return {"error": f"Query failed: {e}"}
|
|
|
|
|
def health_check():
    """Return a status dict: init state, CUDA availability, library flags."""
    report = {"status": "healthy"}
    report["model_initialized"] = model_initialized
    report["cuda_available"] = torch.cuda.is_available()
    report["llava_available"] = LLAVA_AVAILABLE
    report["transformers_available"] = TRANSFORMERS_AVAILABLE
    return report
|
|
|
|
|
def get_model_info():
    """Describe the loaded model, or return an error dict when not initialized.

    Returns:
        ``{"model_path", "context_len", "device"}``; ``"Unknown"`` is used
        for the path when ``args`` is unset and for the device when
        ``model`` is unset.
    """
    if not model_initialized:
        return {"error": "Model not initialized"}

    info = {}
    info["model_path"] = args.model_path if args else "Unknown"
    info["context_len"] = context_len
    if model:
        info["device"] = str(next(model.parameters()).device)
    else:
        info["device"] = "Unknown"
    return info
|
|
|
|
|
|
|
|
class _Args:
    """Model-loading settings read from environment variables.

    Environment variables and defaults:
        HF_MODEL_ID     -> model_path      ("PULSE-ECG/PULSE-7B")
        NUM_GPUS        -> num_gpus        (1)
        MAX_NEW_TOKENS  -> max_new_tokens  (4096)
        LOAD_8BIT       -> load_8bit       (False)
        LOAD_4BIT       -> load_4bit       (False)
        DEBUG           -> debug           (False)

    ``model_base`` is always ``None``, ``conv_mode`` is ``"llava_v1"`` and
    ``num_frames`` is 16.
    """

    # Word spellings accepted for boolean flags, in addition to integers.
    # The true-words mirror the ones the module-level DEBUG switch accepts,
    # so DEBUG=true no longer raises ValueError here.
    _TRUE_WORDS = frozenset({"true", "yes", "y", "on"})
    _FALSE_WORDS = frozenset({"", "false", "no", "n", "off"})

    @classmethod
    def _flag(cls, name: str) -> bool:
        """Parse env var ``name`` as a flag: an integer or a yes/no word.

        Raises:
            ValueError: when the value is neither a known word nor an integer.
        """
        raw = os.getenv(name, "0").strip().lower()
        if raw in cls._TRUE_WORDS:
            return True
        if raw in cls._FALSE_WORDS:
            return False
        # Integer form keeps the original semantics: any non-zero is True.
        return bool(int(raw))

    def __init__(self):
        self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
        self.model_base = None
        self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
        self.conv_mode = "llava_v1"
        self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
        self.num_frames = 16
        self.load_8bit = self._flag("LOAD_8BIT")
        self.load_4bit = self._flag("LOAD_4BIT")
        self.debug = self._flag("DEBUG")
|
|
|
|
|
def initialize_model():
    """Load tokenizer, model and image processor into the module globals.

    Reads settings via ``_Args``, loads the checkpoint with
    ``load_pretrained_model``, switches the model to eval mode, tries
    fallbacks when no image processor was returned (AutoProcessor first,
    then CLIPImageProcessor), publishes the results in the module globals
    and primes ``chat_manager``.

    Returns:
        ``True`` on success; ``False`` when LLaVA is unavailable or any step
        raised (the error is reported through ``warn``).
    """
    global tokenizer, model, image_processor, context_len, args

    if not LLAVA_AVAILABLE:
        warn("[init] LLaVA not available; cannot init.")
        return False

    try:
        args = _Args()
        dbg(f"[init] HF_MODEL_ID={args.model_path} | LOAD_8BIT={args.load_8bit} | LOAD_4BIT={args.load_4bit}")

        tok, net, img_proc, ctx_len = load_pretrained_model(
            args.model_path,
            args.model_base,
            get_model_name_from_path(args.model_path),
            args.load_8bit,
            args.load_4bit,
        )
        dbg(f"[init] load_pretrained_model ok | tokenizer={type(tok)} | model={type(net)} | image_processor={type(img_proc)} | context_len={ctx_len}")

        # Only if the device probe itself fails is the model moved to CUDA.
        try:
            next(net.parameters()).device
        except Exception:
            if torch.cuda.is_available():
                net = net.to(torch.device("cuda"))
        net.eval()
        dbg(f"[init] device={next(net.parameters()).device}, cuda_available={torch.cuda.is_available()}")

        # Image-processor fallbacks, attempted only when the loader gave none.
        try:
            if img_proc is None:
                dbg("[init] image_processor None → AutoProcessor fallback deneniyor…")
                try:
                    from transformers import AutoProcessor
                    img_proc = AutoProcessor.from_pretrained(args.model_path)
                    dbg("[init] image_processor: AutoProcessor.from_pretrained(model_path) ile yüklendi.")
                except Exception as auto_err:
                    dbg(f"[init] AutoProcessor failed: {auto_err} → CLIPImageProcessor fallback deneniyor…")
                    from transformers import CLIPImageProcessor
                    img_proc = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
                    warn("[init] image_processor: CLIPImageProcessor(openai/clip-vit-large-patch14) fallback kullanılıyor.")
        except Exception as fallback_err:
            warn(f"[init] image_processor fallback failed: {fallback_err}")

        # Diagnostics only: report the processor's crop/size settings.
        try:
            if img_proc is None:
                warn("[init] image_processor yine None (fallback da başarısız).")
            else:
                crop_attr = getattr(img_proc, "crop_size", None)
                size_attr = getattr(img_proc, "size", None)
                crop_sz = getattr(crop_attr, "height", None) or crop_attr
                size_sz = getattr(size_attr, "height", None) or size_attr
                dbg(f"[init] image_processor crop_size={crop_sz} size={size_sz} class={img_proc.__class__.__name__}")
        except Exception as inspect_err:
            warn(f"[init] image_processor inspect error: {inspect_err}")

        # Publish into the module globals declared above.
        tokenizer = tok
        model = net
        image_processor = img_proc
        context_len = ctx_len

        chat_manager.init_if_needed(args, args.model_path, tok, net, img_proc, ctx_len)
        print("[init] model/tokenizer/image_processor loaded.")
        return True
    except Exception as init_err:
        warn(f"[init] failed: {init_err}")
        return False
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Endpoint-compatible handler class."""

    def __init__(self, model_dir):
        """Remember ``model_dir`` and announce construction on stdout."""
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Dispatch to ``query``, unwrapping ``payload["inputs"]`` when present."""
        request = payload["inputs"] if "inputs" in payload else payload
        return query(request)

    def health_check(self):
        """Delegate to the module-level ``health_check``."""
        return health_check()

    def get_model_info(self):
        """Delegate to the module-level ``get_model_info``."""
        return get_model_info()
|
|
|
|
|
# Running the file directly only prints a readiness banner.
if __name__ == "__main__":
    print("Handler ready (Demo Parity + Style Hint + whitespace post-process + fallbacks + debug). Use `EndpointHandler` or `query`.")


# FastAPI / pydantic are optional; FASTAPI_AVAILABLE records whether both
# could be imported.
try:
    from fastapi import FastAPI
    from pydantic import BaseModel
except Exception as exc:
    FASTAPI_AVAILABLE = False
    warn(f"fastapi/pydantic not available: {exc}")
else:
    FASTAPI_AVAILABLE = True
|
|
|
|
|
# HTTP wrapper. ``app`` is a FastAPI application when fastapi/pydantic were
# importable, otherwise ``None``.
if FASTAPI_AVAILABLE:
    app = FastAPI(title="PULSE ECG Handler API", version="1.0.0")

    class QueryIn(BaseModel):
        """Request body for POST /query; every field is optional.

        The alias groups (message/query/prompt/istem, image/image_url/img,
        max_output_tokens/max_new_tokens/max_tokens) are resolved by the
        module-level ``query`` function.
        """

        message: str | None = None
        query: str | None = None
        prompt: str | None = None
        istem: str | None = None
        image: str | Dict[str, Any] | None = None
        image_url: str | None = None
        img: str | None = None
        temperature: float | None = None
        top_p: float | None = None
        max_output_tokens: int | None = None
        max_new_tokens: int | None = None
        max_tokens: int | None = None
        repetition_penalty: float | None = None
        conv_mode: str | None = None
        det_seed: int | None = None

    @app.on_event("startup")
    async def _startup():
        """Load the model once when the server starts."""
        global model_initialized
        if not model_initialized:
            model_initialized = initialize_model()
        print(f"[startup] model_initialized={model_initialized}")

    @app.get("/health")
    async def _health():
        """Liveness / dependency status."""
        return health_check()

    @app.get("/info")
    async def _info():
        """Model path, context length and device."""
        return get_model_info()

    @app.get("/debug")
    async def _debug():
        """Diagnostic snapshot of device and image-processor settings."""
        try:
            dev = str(next(model.parameters()).device) if model else "Unknown"
        except Exception:
            dev = "Unknown"

        try:
            ip = image_processor
            ip_cls = ip.__class__.__name__ if ip else None
            crop_sz = getattr(getattr(ip, "crop_size", None), "height", None) or getattr(ip, "crop_size", None)
            size_sz = getattr(getattr(ip, "size", None), "height", None) or getattr(ip, "size", None)
        except Exception:
            ip_cls, crop_sz, size_sz = None, None, None

        return {
            "debug": bool(DEBUG),
            "llava_available": LLAVA_AVAILABLE,
            "transformers_available": TRANSFORMERS_AVAILABLE,
            "device": dev,
            "context_len": context_len,
            "image_processor_class": ip_cls,
            "image_processor_crop_size": crop_sz,
            "image_processor_size": size_sz,
            "model_path": args.model_path if args else None,
        }

    @app.post("/query")
    async def _query(payload: QueryIn):
        """Run one generation request; unset (None) fields are dropped."""
        # NOTE(review): query() is a blocking call inside an async handler,
        # so the event loop is busy for the whole generation -- confirm this
        # serialization of requests is intended.
        # pydantic v2 deprecates .dict() in favour of .model_dump(); use the
        # new name when the installed version provides it.
        to_dict = getattr(payload, "model_dump", None) or payload.dict
        return query({k: v for k, v in to_dict().items() if v is not None})
else:
    app = None
|
|
|