CanerDedeoglu committed on
Commit
254422d
·
verified ·
1 Parent(s): 1fd6be7

Rename handler (5).py to app.py

Browse files
Files changed (2) hide show
  1. app.py +118 -0
  2. handler (5).py +0 -477
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ FastAPI servis giriş noktası (app.py)
4
+ - Startup'ta modeli yükler (sıcak bekletir).
5
+ - /infer ile tahmin, /health ve /model_info ile kontrol sağlar.
6
+ - handler.py dosyası aynı klasörde olmalıdır.
7
+ """
8
+
9
+ import os
10
+ import asyncio
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Any, Dict, Optional
13
+
14
+ from fastapi import FastAPI, Body, HTTPException
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from pydantic import BaseModel, Field
17
+
18
+ import handler as pulse_handler # AYNI KLASÖR
19
+
20
# ---- Settings
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))

# Default Hugging Face model id; a value already present in the environment wins.
os.environ.setdefault("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")

# Single shared EndpointHandler (created lazily) and the thread pool that
# runs the blocking inference calls.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
endpoint = None

# CORS_ALLOW_ORIGINS is a comma-separated list. Strip whitespace and drop
# empty entries so "a.com, b.com" or a trailing comma still matches; an
# empty result falls back to the wildcard default.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app = FastAPI(title="Rapid ECG Inference API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin, so
    # credentials are only allowed when explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
40
+
41
+ # ---- Şemalar
42
class InferenceRequest(BaseModel):
    # Request body for POST /infer. Every field is optional at the schema
    # level; the /infer route checks for a message and an image after the
    # fields have been merged into one payload.

    # HF compatibility: values may arrive nested under "inputs" or as the
    # direct fields below.
    inputs: Optional[Dict[str, Any]] = None

    # Prompt text and the accepted image keys.
    message: Optional[str] = None
    image: Optional[Any] = None
    image_url: Optional[str] = None
    img: Optional[Any] = None

    # Optional settings; forwarded to the handler only when not None.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    conv_mode: Optional[str] = None
    det_seed: Optional[int] = None
57
+
58
def _ensure_initialized():
    """Load the model (once) and prepare the shared EndpointHandler.

    The model is loaded only while ``pulse_handler.model_initialized`` is
    false, and the handler object is created only while ``endpoint`` is
    ``None``; the two steps are independent so a missing handler never
    forces a model reload.

    Raises:
        RuntimeError: if ``pulse_handler.initialize_model()`` reports failure.
    """
    global endpoint
    if not pulse_handler.model_initialized:
        if not pulse_handler.initialize_model():
            raise RuntimeError("Model initialization failed")
        # Record success explicitly: this function keys off the flag, and if
        # initialize_model() does not set it itself every later call would
        # load the model again.
        pulse_handler.model_initialized = True
    if endpoint is None:
        endpoint = pulse_handler.EndpointHandler(
            model_dir=os.getenv("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
        )
69
+
70
def _merge_payload(req: InferenceRequest) -> Dict[str, Any]:
    """Combine the HF-style ``inputs`` dict with the direct request fields.

    Starts from a copy of ``req.inputs`` (or an empty dict) and lets every
    direct field that is not ``None`` override the same key.
    """
    direct_fields = (
        "message", "image", "image_url", "img",
        "temperature", "top_p", "max_new_tokens",
        "repetition_penalty", "conv_mode", "det_seed",
    )
    merged = dict(req.inputs or {})
    overrides = ((name, getattr(req, name)) for name in direct_fields)
    merged.update({name: value for name, value in overrides if value is not None})
    return merged
80
+
81
async def _run_inference(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the blocking handler call on the shared thread pool.

    The payload is wrapped as ``{"inputs": payload}`` and passed to the
    module-level ``endpoint`` from a worker thread, so the event loop is
    not blocked while the call runs.
    """
    running_loop = asyncio.get_running_loop()
    # `endpoint` is looked up when the worker runs, not when it is scheduled.
    return await running_loop.run_in_executor(
        executor, lambda: endpoint({"inputs": payload})
    )
87
+
88
+ # ---- Lifecycle
89
@app.on_event("startup")
async def on_startup():
    """Load the model when the application starts so it is kept warm."""
    _ensure_initialized()
92
+
93
+ # ---- Routes
94
@app.get("/health")
async def health():
    """Return the handler module's health report.

    Does not call ``_ensure_initialized()``, so it never starts a model load.
    """
    return pulse_handler.health_check()
97
+
98
@app.get("/model_info")
async def model_info():
    """Ensure the model is loaded, then return the handler's model info."""
    _ensure_initialized()
    return pulse_handler.get_model_info()
102
+
103
@app.post("/infer")
async def infer(req: InferenceRequest = Body(...)):
    """Run one inference request.

    Raises:
        HTTPException 503: the model could not be initialised.
        HTTPException 400: the message is missing, not a string or blank,
            or no image key ('image' / 'image_url' / 'img') is present.
        HTTPException 500: the handler returned an ``error`` entry.
    """
    try:
        _ensure_initialized()
    except RuntimeError as exc:
        # Report an unavailable model explicitly instead of an opaque 500.
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    payload = _merge_payload(req)

    # "message" may come from the free-form "inputs" dict, so it can be any
    # type. Reject non-strings and whitespace-only text here as client
    # errors instead of letting them fail later as a 500.
    message = payload.get("message")
    if not isinstance(message, str) or not message.strip():
        raise HTTPException(400, "Missing 'message'")
    if not (payload.get("image") or payload.get("image_url") or payload.get("img")):
        raise HTTPException(400, "Missing 'image' / 'image_url' / 'img'")

    result = await _run_inference(payload)
    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(500, result["error"])
    return result
115
+
116
+ if __name__ == "__main__":
117
+ import uvicorn
118
+ uvicorn.run("app:app", host=HOST, port=PORT, reload=bool(int(os.getenv("RELOAD","0"))))
handler (5).py DELETED
@@ -1,477 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- PULSE ECG Handler — Demo Parity + Style Hint
4
- - Demo app.py ile aynı üretim ayarları:
5
- do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
- - Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
7
- - Görsel tensörü: .half() ve model cihazında
8
- - Streamer: TextIteratorStreamer (demo gibi), thread ile generate
9
- - Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
10
- - STYLE_HINT: demo üslubuna (narratif + sonda tek satır structured impression) yaklaşmak için
11
- - Post-process: YALNIZCA whitespace/biçim normalizasyonu (yönetim/öneri cümleleri korunur)
12
- """
13
-
14
- import os
15
- import re
16
- import json
17
- import base64
18
- import hashlib
19
- import datetime
20
- from io import BytesIO
21
- from threading import Thread
22
- from typing import Optional, Union
23
-
24
- import torch
25
- from PIL import Image
26
- import requests
27
-
28
- # ====== LLaVA & Transformers ======
29
- try:
30
- from llava.constants import (
31
- IMAGE_TOKEN_INDEX,
32
- DEFAULT_IMAGE_TOKEN,
33
- )
34
- from llava.conversation import conv_templates, SeparatorStyle
35
- from llava.model.builder import load_pretrained_model
36
- from llava.mm_utils import (
37
- tokenizer_image_token,
38
- process_images,
39
- get_model_name_from_path,
40
- )
41
- from llava.utils import disable_torch_init
42
- LLAVA_AVAILABLE = True
43
- except Exception as e:
44
- LLAVA_AVAILABLE = False
45
- print(f"[WARN] LLaVA not available: {e}")
46
-
47
- try:
48
- from transformers import TextIteratorStreamer, StoppingCriteria
49
- TRANSFORMERS_AVAILABLE = True
50
- except Exception as e:
51
- TRANSFORMERS_AVAILABLE = False
52
- print(f"[WARN] transformers not available: {e}")
53
-
54
- # ====== HF Hub logging (opsiyonel) ======
55
- try:
56
- from huggingface_hub import HfApi, login
57
- HF_HUB_AVAILABLE = True
58
- except Exception:
59
- HF_HUB_AVAILABLE = False
60
-
61
- api = None
62
- repo_name = ""
63
- if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
64
- try:
65
- login(token=os.environ["HF_TOKEN"], write_permission=True)
66
- api = HfApi()
67
- repo_name = os.environ.get("LOG_REPO", "")
68
- except Exception as e:
69
- print(f"[HF Hub] init failed: {e}")
70
- api = None
71
- repo_name = ""
72
-
73
- LOGDIR = "./logs"
74
- os.makedirs(LOGDIR, exist_ok=True)
75
-
76
- # ====== Global State ======
77
- tokenizer = None
78
- model = None
79
- image_processor = None
80
- context_len = None
81
- args = None
82
- model_initialized = False
83
-
84
- # ====== Style Hint (demo benzeri üslup) ======
85
- STYLE_HINT = (
86
- "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
87
- "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
88
- "Use neutral, factual cardiology language. Avoid headings and bullet points. "
89
- "Finish with a single final line starting exactly with 'Structured clinical impression:' "
90
- "followed by a succinct, comma-separated summary of the key diagnoses."
91
- )
92
-
93
- # ===================== Utilities =====================
94
-
95
def _safe_upload(path: str):
    """Best-effort upload of one file to the configured dataset repo.

    Does nothing unless the hub client and repo name are set and ``path``
    is an existing file. Upload errors are printed, never raised.
    """
    if not (api and repo_name and path and os.path.isfile(path)):
        return
    # Drop the "./logs/" prefix for the in-repo destination path.
    destination = path.replace("./logs/", "")
    try:
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=destination,
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"[upload] failed for {path}: {e}")
106
-
107
def _conv_log_path() -> str:
    """Return today's conversation log path: ``LOGDIR/YYYY-MM-DD-user_conv.json``."""
    today = datetime.datetime.now().date()
    # date.isoformat() yields the same zero-padded YYYY-MM-DD form.
    filename = today.isoformat() + "-user_conv.json"
    return os.path.join(LOGDIR, filename)
110
-
111
def load_image_any(image_input: Union[str, dict]) -> Image.Image:
    """Load an image from any supported input and return it in RGB mode.

    Supported inputs:
      - URL (http/https)
      - local file path
      - base64 string (optionally with a ``data:image...`` data-URL prefix)
      - ``{"image": <any of the above>}``

    Raises:
        ValueError: if the input is neither a string nor a dict with "image".
        Other exceptions from requests, base64 or PIL propagate to the caller.
    """
    if isinstance(image_input, dict) and "image" in image_input:
        return load_image_any(image_input["image"])
    if not isinstance(image_input, str):
        raise ValueError("Unsupported image input format")

    s = image_input.strip()

    # SECURITY NOTE(review): this string can come straight from a request
    # payload. The URL branch fetches any address and the path branch opens
    # any readable local file; restrict both if the service is exposed to
    # untrusted clients.
    if s.startswith(("http://", "https://")):
        r = requests.get(s, timeout=(5, 20))
        r.raise_for_status()
        return Image.open(BytesIO(r.content)).convert("RGB")

    if os.path.exists(s):
        # Close the file promptly; convert() returns a fully loaded copy.
        with Image.open(s) as img:
            return img.convert("RGB")

    # Otherwise treat the string as base64, possibly a data URL.
    if s.startswith("data:image"):
        s = s.split(",", 1)[1]
    raw = base64.b64decode(s)
    return Image.open(BytesIO(raw)).convert("RGB")
137
-
138
def _normalize_whitespace(text: str) -> str:
    """Tidy redundant spaces and blank lines.

    - Converts CRLF / CR line endings to LF
    - Strips leading/trailing whitespace on every line and on the whole text
    - Collapses runs of spaces/tabs to a single space
    - Collapses three or more consecutive newlines to exactly two
    """
    unified = text.replace("\r\n", "\n").replace("\r", "\n")
    tidy_lines = (
        re.sub(r"[ \t]+", " ", line.strip()) for line in unified.split("\n")
    )
    joined = "\n".join(tidy_lines).strip()
    return re.sub(r"\n{3,}", "\n\n", joined)
150
-
151
def _postprocess_min(text: str) -> str:
    """Minimal post-processing: whitespace/format cleanup only."""
    # Delegates entirely to _normalize_whitespace.
    return _normalize_whitespace(text)
154
-
155
- # ====== Güvenli Stop Kriteri (conv separator) ======
156
class SafeKeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with the keyword's token ids.

    The check is a token-id comparison that returns a plain ``bool``. Only
    the first row of ``input_ids`` is inspected (batch size 1 assumed).
    """

    def __init__(self, keyword: str, tokenizer):
        self.tokenizer = tokenizer
        encoded = tokenizer(keyword, add_special_tokens=False, return_tensors="pt")
        self.kw_ids = encoded.input_ids[0]  # 1-D tensor of keyword token ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids is None or input_ids.shape[0] == 0:
            return False
        sequence = input_ids[0]  # assume bsz=1
        kw_len = self.kw_ids.shape[0]
        if sequence.shape[0] < kw_len:
            return False
        suffix = sequence[-kw_len:]
        return torch.equal(suffix, self.kw_ids.to(suffix.device))
175
-
176
- # ===================== Core Generation =====================
177
-
178
class InferenceDemo:
    """Holds the loaded tokenizer/model/image processor and a conversation."""

    def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA not available")
        disable_torch_init()

        self.tokenizer = tokenizer_
        self.model = model_
        self.image_processor = image_processor_
        self.context_len = context_len_

        # Fixed template for parity with the demo
        self.conv_mode = "llava_v1"
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = getattr(args, "num_frames", 16)
190
-
191
class ChatSessionManager:
    """Creates the shared InferenceDemo on first use and hands it out."""

    def __init__(self):
        self.chatbot = None
        self.args = None
        self.model_path = None

    def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Build the chatbot only if none exists yet; later calls are no-ops."""
        if self.chatbot is not None:
            return
        self.args, self.model_path = args, model_path
        self.chatbot = InferenceDemo(
            args, model_path, tokenizer, model, image_processor, context_len
        )

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the chatbot with its conversation reset to a fresh template."""
        self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
        bot = self.chatbot
        # Fresh template on every call (a new turn, as in the demo)
        bot.conversation = conv_templates[bot.conv_mode].copy()
        return bot


chat_manager = ChatSessionManager()
208
-
209
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Append the user turn to the conversation and tokenize the prompt.

    Returns ``(prompt, input_ids)``; ``input_ids`` gets a leading batch
    dimension via ``unsqueeze(0)`` and is moved to ``device``.
    """
    conv = chatbot.conversation
    # Demo parity: no wrapping, a single image token for the single image
    conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{user_text}")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    token_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )
    return prompt, token_ids.unsqueeze(0).to(device)
220
-
221
def generate_response(
    message_text: str,
    image_input,
    *,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    conv_mode_override: Optional[str] = None,
    repetition_penalty: Optional[float] = None,
    det_seed: Optional[int] = None,  # None -> stochastic sampling (as in the demo)
):
    """Generate a reply for one text + image request.

    Unset sampling parameters fall back to the demo defaults
    (temperature 0.05, top_p 1.0, max_new_tokens 4096,
    repetition_penalty 1.0).

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}``
        on success, or ``{"error": str}`` on a handled failure.
    """
    if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
        return {"error": "Required libraries not available (llava/transformers)"}
    if not message_text or image_input is None:
        return {"error": "Both 'message' and 'image' are required"}

    # Defaults -> demo settings
    if temperature is None:
        temperature = 0.05
    if top_p is None:
        top_p = 1.0
    if max_new_tokens is None:
        max_new_tokens = 4096
    if repetition_penalty is None:
        repetition_penalty = 1.0  # no effect

    # Chat session
    chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
    if conv_mode_override and conv_mode_override in conv_templates:
        chatbot.conversation = conv_templates[conv_mode_override].copy()

    # Load the image
    try:
        pil_img = load_image_any(image_input)
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    # Hash + path for logging (best effort)
    img_hash, img_path = "NA", None
    try:
        buf = BytesIO()
        pil_img.save(buf, format="JPEG")
        raw = buf.getvalue()
        img_hash = hashlib.md5(raw).hexdigest()
        t = datetime.datetime.now()
        img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        if not os.path.isfile(img_path):
            pil_img.save(img_path)
    except Exception as e:
        print(f"[log] save image failed: {e}")

    # Device / dtype
    device = next(chatbot.model.parameters()).device
    dtype = torch.float16  # demo: half

    # Image preprocessing -> tensor
    try:
        processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
        if isinstance(processed, (list, tuple)) and len(processed) > 0:
            image_tensor = processed[0]
        elif isinstance(processed, torch.Tensor):
            image_tensor = processed[0] if processed.ndim == 4 else processed
        else:
            return {"error": "Image processing returned empty"}
        if image_tensor.ndim == 3:
            image_tensor = image_tensor.unsqueeze(0)  # add a leading batch dimension
        image_tensor = image_tensor.to(device=device, dtype=dtype)  # demo: half + device
    except Exception as e:
        return {"error": f"Image processing failed: {e}"}

    # Append STYLE_HINT and build the prompt
    msg = (message_text or "").strip()
    msg = f"{msg}\n\n{STYLE_HINT}"
    _, input_ids = _build_prompt_and_ids(chatbot, msg, device)

    # Stop string (conversation separator) -> safe stopping criterion
    stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
    stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)

    # Seed (stochastic, like the demo, when not supplied)
    if det_seed is not None:
        try:
            s = int(det_seed)
            torch.manual_seed(s)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(s)
                torch.cuda.manual_seed_all(s)
        except Exception as e:
            # Seeding is best effort, but say so instead of hiding it.
            print(f"[seed] ignoring det_seed={det_seed!r}: {e}")

    # Streamer (as in the demo)
    streamer = TextIteratorStreamer(
        chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Generate kwargs: demo settings
    # NOTE(review): use_cache=False is kept from the original; confirm it is intended.
    gen_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        do_sample=True,                                # demo
        temperature=float(temperature),                # demo default 0.05
        top_p=float(top_p),                            # demo default 1.0
        max_new_tokens=int(max_new_tokens),            # demo slider
        repetition_penalty=float(repetition_penalty),  # default 1.0 -> no effect
        use_cache=False,
        stopping_criteria=[stopping],                  # demo-like stopping
    )

    generation_errors = []

    def _generate():
        """Run generate(); on failure record the error and end the stream."""
        try:
            chatbot.model.generate(**gen_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            # generate() ends the stream only on success. Without this the
            # consumer loop below would wait forever on a failed generation.
            streamer.end()

    # Generate on a background thread and collect the stream
    try:
        worker = Thread(target=_generate)
        worker.start()
        text = "".join(streamer)
        worker.join()
        if generation_errors:
            raise generation_errors[0]
        text = _postprocess_min(text)  # whitespace/format cleanup only
        chatbot.conversation.messages[-1][-1] = text
    except Exception as e:
        return {"error": f"Generation failed: {e}"}

    # Log (best effort)
    try:
        row = {
            "time": datetime.datetime.now().isoformat(),
            "type": "chat",
            "model": "PULSE-7B",
            "state": [(message_text, text)],
            "image_hash": img_hash,
            "image_path": img_path or "",
        }
        # Resolve the path once so the write and the upload use the same
        # file even when the date changes in between.
        log_path = _conv_log_path()
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
        _safe_upload(log_path)
        _safe_upload(img_path or "")
    except Exception as e:
        print(f"[log] failed: {e}")

    return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
355
-
356
- # ===================== Public API =====================
357
-
358
def query(payload: dict):
    """HF Endpoint entry point (demo-like).

    Initialises the model on first use, reads the request from ``payload``
    and delegates to ``generate_response``.

    Accepted keys:
      - text: "message", "query", "prompt" or "istem" (first non-empty wins)
      - image: "image", "image_url" or "img" (first non-empty wins)
      - optional: "temperature", "top_p", "repetition_penalty", "conv_mode",
        "det_seed", and "max_output_tokens" / "max_new_tokens" /
        "max_tokens" (checked in that order)

    Returns the dict from ``generate_response`` or ``{"error": ...}``.
    """
    global model_initialized
    if not model_initialized:
        if not initialize_model():
            return {"error": "Model initialization failed"}
        model_initialized = True

    def _first_set(default, *keys):
        """Return the first non-None payload value among keys, else default."""
        for key in keys:
            value = payload.get(key)
            if value is not None:
                return value
        return default

    try:
        message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
        image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
        # A non-string message has no usable text; report it as missing
        # rather than failing on .strip().
        if not isinstance(message, str) or not message.strip():
            return {"error": "Missing 'message' text"}
        if image is None:
            return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}

        # Demo defaults; the payload may override them. A key sent with an
        # explicit None counts as absent.
        temperature = float(_first_set(0.05, "temperature"))
        top_p = float(_first_set(1.0, "top_p"))
        max_new_tokens = int(_first_set(4096, "max_output_tokens", "max_new_tokens", "max_tokens"))
        repetition_penalty = float(_first_set(1.0, "repetition_penalty"))  # 1.0 -> no effect

        conv_mode_override = payload.get("conv_mode", None)
        det_seed = payload.get("det_seed", None)
        if det_seed is not None:
            try:
                det_seed = int(det_seed)
            except (TypeError, ValueError):
                det_seed = None  # unusable seed -> fall back to stochastic

        return generate_response(
            message_text=message,
            image_input=image,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            conv_mode_override=conv_mode_override,
            repetition_penalty=repetition_penalty,
            det_seed=det_seed,
        )
    except Exception as e:
        return {"error": f"Query failed: {e}"}
396
-
397
def health_check():
    """Report the module's status flags.

    "status" is always "healthy"; the other entries give the current
    initialisation flag, CUDA availability and optional-import results.
    """
    return {
        "status": "healthy",
        "model_initialized": model_initialized,
        "cuda_available": torch.cuda.is_available(),
        "llava_available": LLAVA_AVAILABLE,
        "transformers_available": TRANSFORMERS_AVAILABLE,
    }
405
-
406
def get_model_info():
    """Return model path, context length and device.

    Returns ``{"error": "Model not initialized"}`` while the
    initialisation flag is unset.
    """
    if not model_initialized:
        return {"error": "Model not initialized"}

    info = {}
    info["model_path"] = args.model_path if args else "Unknown"
    info["context_len"] = context_len
    # Device of the first model parameter
    info["device"] = str(next(model.parameters()).device) if model else "Unknown"
    return info
414
-
415
- # ===================== Init & Session =====================
416
-
417
class _Args:
    """Runtime settings read from environment variables at construction."""

    def __init__(self):
        env = os.getenv

        # Model location
        self.model_path = env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
        self.model_base = None

        # Numeric settings
        self.num_gpus = int(env("NUM_GPUS", "1"))
        self.max_new_tokens = int(env("MAX_NEW_TOKENS", "4096"))
        self.num_frames = 16

        # Fixed for parity with the demo
        self.conv_mode = "llava_v1"

        # Integer-valued on/off flags ("0" / "1")
        self.load_8bit = bool(int(env("LOAD_8BIT", "0")))
        self.load_4bit = bool(int(env("LOAD_4BIT", "0")))
        self.debug = bool(int(env("DEBUG", "0")))
428
-
429
def initialize_model():
    """Load tokenizer, model and image processor into the module globals.

    On success the globals ``tokenizer``, ``model``, ``image_processor``,
    ``context_len`` and ``args`` are set, the chat manager is primed, and
    ``model_initialized`` becomes True so callers that check the flag do
    not load the model again.

    Returns:
        True on success; False if LLaVA is unavailable or loading failed
        (the reason is printed).
    """
    global tokenizer, model, image_processor, context_len, args, model_initialized
    if not LLAVA_AVAILABLE:
        print("[init] LLaVA not available; cannot init.")
        return False
    try:
        args = _Args()
        model_name = get_model_name_from_path(args.model_path)
        tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )
        # demo: the model usually runs on CUDA. Move it only when the device
        # of its parameters cannot be read and CUDA is available.
        try:
            _ = next(model_.parameters()).device
        except Exception:
            if torch.cuda.is_available():
                model_ = model_.to(torch.device("cuda"))
        model_.eval()

        # Publish through the names declared global above.
        tokenizer = tokenizer_
        model = model_
        image_processor = image_processor_
        context_len = context_len_

        chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
        # Mark success here so every caller sees the flag, not only query().
        model_initialized = True
        print("[init] model/tokenizer/image_processor loaded.")
        return True
    except Exception as e:
        print(f"[init] failed: {e}")
        return False
459
-
460
- # ===================== HF EndpointHandler =====================
461
-
462
class EndpointHandler:
    """Hugging Face Endpoint-compatible wrapper around the module functions."""

    def __init__(self, model_dir):
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        # Accept both {"inputs": {...}} and a bare request dict.
        request = payload["inputs"] if "inputs" in payload else payload
        return query(request)

    def health_check(self):
        return health_check()

    def get_model_info(self):
        return get_model_info()
475
-
476
- if __name__ == "__main__":
477
- print("Handler ready (Demo Parity + Style Hint + whitespace post-process). Use `EndpointHandler` or `query`.")