# -*- coding: utf-8 -*-
"""FastAPI service entry point (app.py).

- Loads the model at startup (keeps it warm).
- /infer runs predictions; /health and /model_info provide status checks.
- handler.py must live in the same directory.
"""
import os
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional

from fastapi import FastAPI, Body, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import handler as pulse_handler  # SAME DIRECTORY

# ---- Settings
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))

# Default HF model id (overridable via the environment).
os.environ.setdefault("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")

# Singleton EndpointHandler and thread pool for blocking inference calls.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
endpoint = None
_init_lock = threading.Lock()  # guards one-time model initialization

app = FastAPI(title="Rapid ECG Inference API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    # Strip whitespace so "a.com, b.com" parses into valid origins.
    allow_origins=[
        o.strip() for o in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---- Schemas
class InferenceRequest(BaseModel):
    """Inference request body.

    HF-compatible: parameters may arrive either inside the ``inputs``
    dict or as flat top-level fields; flat fields take precedence.
    """

    inputs: Optional[Dict[str, Any]] = None
    message: Optional[str] = None
    image: Optional[Any] = None
    image_url: Optional[str] = None
    img: Optional[Any] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    conv_mode: Optional[str] = None
    det_seed: Optional[int] = None


# Flat request fields merged over the HF-style "inputs" dict.
_PASSTHROUGH_FIELDS = (
    "message", "image", "image_url", "img",
    "temperature", "top_p", "max_new_tokens",
    "repetition_penalty", "conv_mode", "det_seed",
)


def _ensure_initialized() -> None:
    """Load the model (once) and prepare the EndpointHandler.

    Thread-safe: a lock plus a re-check prevents two concurrent callers
    from both initializing.

    Raises:
        RuntimeError: if model initialization fails.
    """
    global endpoint
    # Fast path: already initialized, no lock needed.
    if pulse_handler.model_initialized and endpoint is not None:
        return
    with _init_lock:
        # Re-check under the lock — another thread may have won the race.
        if pulse_handler.model_initialized and endpoint is not None:
            return
        if not pulse_handler.initialize_model():
            raise RuntimeError("Model initialization failed")
        endpoint = pulse_handler.EndpointHandler(
            model_dir=os.getenv("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
        )


def _merge_payload(req: InferenceRequest) -> Dict[str, Any]:
    """Merge the HF ``inputs`` dict with the flat request fields.

    Flat fields that are not None override keys already present in
    ``inputs``.
    """
    payload = dict(req.inputs or {})
    for field in _PASSTHROUGH_FIELDS:
        value = getattr(req, field)
        if value is not None:
            payload[field] = value
    return payload


async def _run_inference(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the blocking handler call in the thread pool."""
    loop = asyncio.get_running_loop()

    def _call():
        return endpoint({"inputs": payload})

    return await loop.run_in_executor(executor, _call)


# ---- Lifecycle
@app.on_event("startup")
async def on_startup():
    """Warm up the model before serving traffic."""
    _ensure_initialized()


@app.on_event("shutdown")
async def on_shutdown():
    """Release the worker threads when the service stops."""
    # wait=False: don't block shutdown on an in-flight inference.
    executor.shutdown(wait=False)


# ---- Routes
@app.get("/health")
async def health():
    """Liveness probe; deliberately does not force model loading."""
    return pulse_handler.health_check()


@app.get("/model_info")
async def model_info():
    """Return model metadata, initializing the model if needed."""
    _ensure_initialized()
    return pulse_handler.get_model_info()


@app.post("/infer")
async def infer(req: InferenceRequest = Body(...)):
    """Validate the request, run inference, and surface handler errors.

    Raises:
        HTTPException: 400 on missing message/image, 500 on handler error.
    """
    _ensure_initialized()
    payload = _merge_payload(req)
    if not payload.get("message"):
        raise HTTPException(400, "Missing 'message'")
    if not (payload.get("image") or payload.get("image_url") or payload.get("img")):
        raise HTTPException(400, "Missing 'image' / 'image_url' / 'img'")
    result = await _run_inference(payload)
    # The handler signals failure via an "error" key rather than raising.
    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(500, result["error"])
    return result


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host=HOST,
        port=PORT,
        reload=bool(int(os.getenv("RELOAD", "0"))),
    )