# Rapid_ECG / app.py
# (Hugging Face Space header: uploaded by CanerDedeoglu,
#  "Rename handler (5).py to app.py", commit 254422d verified)
# -*- coding: utf-8 -*-
"""
FastAPI service entry point (app.py)
- Loads the model at startup (keeps it warm).
- /infer serves predictions; /health and /model_info provide checks.
- handler.py must live in the same directory.
"""
import asyncio
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional

from fastapi import FastAPI, Body, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

import handler as pulse_handler  # must be in the same directory
# ---- Settings (overridable via environment variables)
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))  # thread pool size for blocking inference calls
# Default HF model id (only set when not already provided by the environment)
os.environ.setdefault("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
# Singleton EndpointHandler and thread pool
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
endpoint = None  # populated lazily by _ensure_initialized()
app = FastAPI(title="Rapid ECG Inference API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    # Comma-separated origin list; the default "*" allows all origins.
    allow_origins=os.getenv("CORS_ALLOW_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---- Schemas
class InferenceRequest(BaseModel):
    """Request body for /infer.

    Accepts either an HF-style ``{"inputs": {...}}`` envelope or the same
    fields at top level; non-None top-level fields override keys inside
    ``inputs`` (see ``_merge_payload``).
    """
    # HF compatibility: "inputs" envelope or direct fields
    inputs: Optional[Dict[str, Any]] = None
    # Prompt text — required downstream by /infer validation.
    message: Optional[str] = None
    # Image in one of three forms; /infer requires at least one to be set.
    image: Optional[Any] = None
    image_url: Optional[str] = None
    img: Optional[Any] = None
    # Optional generation parameters, forwarded to the handler when set.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    conv_mode: Optional[str] = None
    det_seed: Optional[int] = None
# Guards lazy initialization: /infer and /model_info may be hit concurrently
# before (or instead of) the startup hook, and unguarded init could run twice.
_init_lock = threading.Lock()


def _ensure_initialized() -> None:
    """Load the model once and build the EndpointHandler (idempotent).

    Thread-safe via double-checked locking: the cheap check runs lock-free on
    the hot path; the lock only serializes the first initialization.

    Raises:
        RuntimeError: if ``handler.initialize_model()`` reports failure.
    """
    global endpoint
    # Fast path: already initialized — no lock needed.
    if pulse_handler.model_initialized and endpoint is not None:
        return
    with _init_lock:
        # Re-check under the lock: another thread may have just finished.
        if pulse_handler.model_initialized and endpoint is not None:
            return
        if not pulse_handler.initialize_model():
            raise RuntimeError("Model initialization failed")
        endpoint = pulse_handler.EndpointHandler(
            model_dir=os.getenv("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
        )
def _merge_payload(req: InferenceRequest) -> Dict[str, Any]:
    """Combine the HF 'inputs' envelope with top-level request fields.

    Starts from ``req.inputs`` (if given); every explicitly-set (non-None)
    top-level field then overrides the corresponding key from 'inputs'.
    """
    merged: Dict[str, Any] = dict(req.inputs or {})
    passthrough = (
        "message", "image", "image_url", "img",
        "temperature", "top_p", "max_new_tokens",
        "repetition_penalty", "conv_mode", "det_seed",
    )
    merged.update({
        name: value
        for name in passthrough
        if (value := getattr(req, name)) is not None
    })
    return merged
async def _run_inference(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the blocking handler call on the thread pool without blocking the loop."""
    loop = asyncio.get_running_loop()
    # The handler expects the HF-style {"inputs": ...} envelope.
    return await loop.run_in_executor(
        executor, lambda: endpoint({"inputs": payload})
    )
# ---- Lifecycle
@app.on_event("startup")
async def on_startup():
    """Eagerly load the model at startup so the first request is not slow.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI versions
    in favor of lifespan handlers — still functional, but consider migrating.
    """
    _ensure_initialized()
# ---- Routes
@app.get("/health")
async def health():
return pulse_handler.health_check()
@app.get("/model_info")
async def model_info():
_ensure_initialized()
return pulse_handler.get_model_info()
@app.post("/infer")
async def infer(req: InferenceRequest = Body(...)):
_ensure_initialized()
payload = _merge_payload(req)
if not payload.get("message"):
raise HTTPException(400, "Missing 'message'")
if not (payload.get("image") or payload.get("image_url") or payload.get("img")):
raise HTTPException(400, "Missing 'image' / 'image_url' / 'img'")
result = await _run_inference(payload)
if isinstance(result, dict) and result.get("error"):
raise HTTPException(500, result["error"])
return result
if __name__ == "__main__":
    import uvicorn

    # RELOAD is parsed as an integer: any non-zero value enables auto-reload.
    dev_reload = int(os.getenv("RELOAD", "0")) != 0
    uvicorn.run("app:app", host=HOST, port=PORT, reload=dev_reload)