|
|
|
|
|
""" |
|
|
FastAPI servis giriş noktası (app.py) |
|
|
- Startup'ta modeli yükler (sıcak bekletir). |
|
|
- /infer ile tahmin, /health ve /model_info ile kontrol sağlar. |
|
|
- handler.py dosyası aynı klasörde olmalıdır. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import asyncio |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from fastapi import FastAPI, Body, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
import handler as pulse_handler |
|
|
|
|
|
|
|
|
# --- Runtime configuration (all overridable via environment variables) ---
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
# Number of threads available for the blocking handler calls.
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))

# Publish a default model id without overriding a value the operator set.
# NOTE(review): presumably handler.py also reads HF_MODEL_ID — verify.
os.environ.setdefault("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")

# Shared pool used by _run_inference to keep blocking calls off the event loop.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# EndpointHandler instance; assigned by _ensure_initialized().
endpoint = None

app = FastAPI(title="Rapid ECG Inference API", version="1.0.0")

# Comma-separated list of allowed origins. Whitespace around entries and empty
# entries (e.g. from a trailing comma) are ignored, so "a.com, b.com" works.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    # With a "*" origin list, enabling credentials makes Starlette mirror back
    # whatever Origin sent a cookie-bearing request, i.e. every site would get
    # credentialed access. Credentials are therefore only allowed when the
    # operator configured an explicit origin list.
    allow_credentials="*" not in CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class InferenceRequest(BaseModel):
    """Request body accepted by ``POST /infer``.

    Every field is optional at the schema level. Fields may be sent at the top
    level or nested inside ``inputs``; when both are present, the top-level
    value wins. The presence checks for ``message`` and an image field are
    performed by the endpoint, not by this model.
    """

    # HF-style envelope that may itself carry any of the fields below.
    inputs: Optional[Dict[str, Any]] = Field(default=None)

    # Prompt text plus one of the three accepted image fields.
    message: Optional[str] = Field(default=None)
    image: Optional[Any] = Field(default=None)
    image_url: Optional[str] = Field(default=None)
    img: Optional[Any] = Field(default=None)

    # Optional tuning parameters, passed through to the handler only when not None.
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    max_new_tokens: Optional[int] = Field(default=None)
    repetition_penalty: Optional[float] = Field(default=None)
    conv_mode: Optional[str] = Field(default=None)
    det_seed: Optional[int] = Field(default=None)
|
|
|
|
|
def _ensure_initialized():
    """Load the model (once) and build the module-level EndpointHandler.

    Does nothing when ``pulse_handler.model_initialized`` is truthy and
    ``endpoint`` is already set.

    Raises:
        RuntimeError: if ``pulse_handler.initialize_model()`` returns a falsy value.
    """
    global endpoint

    already_ready = pulse_handler.model_initialized and endpoint is not None
    if already_ready:
        return

    if not pulse_handler.initialize_model():
        raise RuntimeError("Model initialization failed")

    # Same fallback as the module-level setdefault, in case the env was cleared.
    model_id = os.getenv("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
    endpoint = pulse_handler.EndpointHandler(model_dir=model_id)
|
|
|
|
|
def _merge_payload(req: InferenceRequest) -> Dict[str, Any]:
    """Flatten a request into the payload dict handed to the handler.

    Starts from a shallow copy of ``req.inputs`` (empty when absent), then
    overlays each recognised top-level field whose value is not ``None``.
    Top-level values therefore override same-named keys inside ``inputs``.
    """
    passthrough = (
        "message", "image", "image_url", "img",
        "temperature", "top_p", "max_new_tokens",
        "repetition_penalty", "conv_mode", "det_seed",
    )
    merged: Dict[str, Any] = dict(req.inputs or {})
    top_level = {name: getattr(req, name) for name in passthrough}
    merged.update({name: value for name, value in top_level.items() if value is not None})
    return merged
|
|
|
|
|
async def _run_inference(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the synchronous ``endpoint`` call on ``executor`` and await its result.

    The handler receives the payload wrapped as ``{"inputs": payload}``.
    """
    running_loop = asyncio.get_running_loop()
    # The lambda looks up the global ``endpoint`` when the worker thread runs it.
    future = running_loop.run_in_executor(
        executor, lambda: endpoint({"inputs": payload})
    )
    return await future
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def on_startup():
    """Warm the model at server start so the first request avoids the load cost.

    A RuntimeError from a failed load propagates out of this startup hook.
    """
    _ensure_initialized()
|
|
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Return ``pulse_handler.health_check()`` unchanged; never triggers model loading."""
    status = pulse_handler.health_check()
    return status
|
|
|
|
|
@app.get("/model_info")
async def model_info():
    """Return ``pulse_handler.get_model_info()``, initializing the model first if needed.

    Raises:
        HTTPException: 503 if the model could not be initialized.
    """
    try:
        _ensure_initialized()
    except RuntimeError as exc:
        # Report an unavailable model explicitly instead of an unhandled 500.
        raise HTTPException(503, f"Model unavailable: {exc}") from exc
    return pulse_handler.get_model_info()
|
|
|
|
|
@app.post("/infer")
async def infer(req: InferenceRequest = Body(...)):
    """Run one inference request through the handler.

    Top-level fields are merged over ``inputs`` by ``_merge_payload``, and the
    merged payload is validated before the model is touched.

    Raises:
        HTTPException: 400 if ``message`` is missing/empty, or if none of
            ``image`` / ``image_url`` / ``img`` is provided.
        HTTPException: 503 if the model could not be initialized.
        HTTPException: 500 if the handler returns a dict with a truthy ``error``.
    """
    payload = _merge_payload(req)

    # Validate first: malformed requests get a 400 without triggering a
    # (potentially slow) model initialization attempt.
    if not payload.get("message"):
        raise HTTPException(400, "Missing 'message'")
    if not (payload.get("image") or payload.get("image_url") or payload.get("img")):
        raise HTTPException(400, "Missing 'image' / 'image_url' / 'img'")

    try:
        _ensure_initialized()
    except RuntimeError as exc:
        # Report an unavailable model explicitly instead of an unhandled 500.
        raise HTTPException(503, f"Model unavailable: {exc}") from exc

    result = await _run_inference(payload)

    # The handler signals failures in-band via an "error" key.
    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(500, result["error"])

    return result
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # RELOAD accepts integers as before ("0" off, any non-zero on). It also
    # accepts the common words true/yes/on; anything else (including an
    # empty value) means off instead of crashing at startup.
    raw_reload = os.getenv("RELOAD", "0").strip().lower()
    try:
        reload_enabled = bool(int(raw_reload))
    except ValueError:
        reload_enabled = raw_reload in {"true", "yes", "on"}

    # An import string (not the app object) is required for uvicorn's reload mode.
    uvicorn.run("app:app", host=HOST, port=PORT, reload=reload_enabled)
|
|
|