# -*- coding: utf-8 -*-
"""
FastAPI servis giriş noktası (app.py)
- Startup'ta modeli yükler (sıcak bekletir).
- /infer ile tahmin, /health ve /model_info ile kontrol sağlar.
- handler.py dosyası aynı klasörde olmalıdır.
"""

import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional

from fastapi import FastAPI, Body, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

import handler as pulse_handler  # must live in the same directory

# ---- Settings
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))

# Default HF model id (can be overridden via the HF_MODEL_ID environment variable)
os.environ.setdefault("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")

# Single EndpointHandler instance and the thread pool for blocking inference calls
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
endpoint = None

app = FastAPI(title="Rapid ECG Inference API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=os.getenv("CORS_ALLOW_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
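
# Illustrative only: restrict CORS to specific origins via the environment, e.g.
#   CORS_ALLOW_ORIGINS="https://app.example.com,https://admin.example.com"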

# ---- Schemas
class InferenceRequest(BaseModel):
    # HF compatibility: accept either an "inputs" dict or the fields directly
    inputs: Optional[Dict[str, Any]] = None

    message: Optional[str] = None
    image: Optional[Any] = None
    image_url: Optional[str] = None
    img: Optional[Any] = None

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    conv_mode: Optional[str] = None
    det_seed: Optional[int] = None

def _ensure_initialized():
    """Modeli (bir kere) yükle ve EndpointHandler hazırla."""
    global endpoint
    if pulse_handler.model_initialized and endpoint is not None:
        return
    ok = pulse_handler.initialize_model()
    if not ok:
        raise RuntimeError("Model initialization failed")
    endpoint = pulse_handler.EndpointHandler(
        model_dir=os.getenv("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
    )

def _merge_payload(req: InferenceRequest) -> Dict[str, Any]:
    """HF 'inputs' ile diğer alanları birleştirir."""
    payload = dict(req.inputs or {})
    for k in ["message", "image", "image_url", "img",
              "temperature", "top_p", "max_new_tokens",
              "repetition_penalty", "conv_mode", "det_seed"]:
        v = getattr(req, k)
        if v is not None:
            payload[k] = v
    return payload
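
# Example of the merge behaviour above (illustrative values): a request body of
#   {"inputs": {"message": "Read this ECG"}, "temperature": 0.2}
# becomes the payload
#   {"message": "Read this ECG", "temperature": 0.2}
# Top-level fields override same-named keys supplied inside "inputs".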

async def _run_inference(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Blocking handler çağrısını thread pool'da çalıştır."""
    loop = asyncio.get_running_loop()
    def _call():
        return endpoint({"inputs": payload})
    return await loop.run_in_executor(executor, _call)

# ---- Lifecycle
@app.on_event("startup")
async def on_startup():
    _ensure_initialized()

# ---- Routes
@app.get("/health")
async def health():
    return pulse_handler.health_check()

@app.get("/model_info")
async def model_info():
    _ensure_initialized()
    return pulse_handler.get_model_info()

@app.post("/infer")
async def infer(req: InferenceRequest = Body(...)):
    _ensure_initialized()
    payload = _merge_payload(req)
    if not payload.get("message"):
        raise HTTPException(400, "Missing 'message'")
    if not (payload.get("image") or payload.get("image_url") or payload.get("img")):
        raise HTTPException(400, "Missing 'image' / 'image_url' / 'img'")
    result = await _run_inference(payload)
    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(500, result["error"])
    return result

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host=HOST, port=PORT, reload=bool(int(os.getenv("RELOAD","0"))))
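
# A typical production launch (illustrative; adjust to your deployment):
#   uvicorn app:app --host 0.0.0.0 --port 8000 --workers 1
# Each uvicorn worker process loads its own copy of the model, so keep the worker
# count low for GPU-backed inference; concurrency inside a worker is bounded by
# the MAX_WORKERS thread pool configured above.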