LogAI-Engine / root /api.py
NOT-OMEGA's picture
Rename api.py to root/api.py
2aee6fd verified
raw
history blame
11.8 kB
"""
api.py — Async FastAPI Inference Service
Endpoints:
POST /classify — Single log
POST /classify/batch — Batch of logs (up to MAX_BATCH_SIZE, default 512)
GET /health — Liveness check
GET /ready — Readiness check (model loaded?)
GET /metrics — Request counts, throughput, latency stats
Features:
- Async request handling (non-blocking)
- Worker pool via asyncio semaphore (bounded concurrency)
- Structured JSON logs with request_id
- Rate limiting (configurable)
- Request ID tracing
- Batch queue aggregation for small requests
Run:
uvicorn api:app --host 0.0.0.0 --port 8000 --workers 1
Example:
curl -X POST http://localhost:8000/classify \
-H "Content-Type: application/json" \
-d '{"source": "ModernCRM", "log_message": "User User123 logged in."}'
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import statistics
import time
import uuid
from collections import deque
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
# ── Logging setup ─────────────────────────────────────────────────────────────
class JsonLogFormatter(logging.Formatter):
    """Render each log record as a single, valid JSON object.

    A plain ``%``-style format string cannot escape its values, so a message
    containing a double quote, a backslash or a newline would produce a line
    that is not parseable JSON. Serialising through ``json.dumps`` keeps the
    same keys (``time``, ``level``, ``logger``, ``msg``) in the same order
    while guaranteeing well-formed output.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialised as compact JSON.

        Exception and stack information, when attached to the record, are
        emitted under the extra ``exc_info`` / ``stack_info`` keys so the
        whole event stays on one line.
        """
        payload = {
            "time": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            payload["stack_info"] = self.formatStack(record.stack_info)
        # Compact separators mirror the previous hand-written layout;
        # ensure_ascii=False keeps non-ASCII text readable in the log.
        return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))


# StreamHandler() writes to stderr, the same default stream basicConfig uses.
_log_handler = logging.StreamHandler()
_log_handler.setFormatter(JsonLogFormatter())
logging.basicConfig(level=logging.INFO, handlers=[_log_handler])
logger = logging.getLogger("log-classifier-api")
# ── Config ─────────────────────────────────────────────────────────────────────
# Tunables are read from the environment once, at import time; each falls
# back to the default shown when its variable is unset.
MAX_BATCH_SIZE = int(os.environ.get("MAX_BATCH_SIZE", "512"))  # max logs per batch request
MAX_CONCURRENT = int(os.environ.get("MAX_CONCURRENT", "4"))  # concurrency cap
RATE_LIMIT_PER_MIN = int(os.environ.get("RATE_LIMIT_PER_MIN", "1000"))  # requests per minute
LOG_MAX_CHARS = 2048  # truncate huge logs before classify
# ── Global state ───────────────────────────────────────────────────────────────
# Both of these are assigned for real by lifespan() during startup.
_semaphore: asyncio.Semaphore = None  # type: ignore  # placeholder until startup
_model_ready: bool = False

# Running totals and the process start time, reported by /metrics and /health.
_request_count = _error_count = 0
_start_time = time.time()

# Bounded buffers: the most recent 1000 latency samples (ms) for /metrics,
# and the request timestamps consulted by the per-process rate limiter.
_latencies_ms: deque = deque([], 1000)
_rate_window: deque = deque([], RATE_LIMIT_PER_MIN)
# ── Lifespan: load models on startup ──────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare shared state before serving requests.

    On startup this creates the concurrency semaphore (sized by
    ``MAX_CONCURRENT``) and runs ``_load_models_blocking`` in the default
    executor so the event loop is not blocked while it runs. If loading
    raises, the error is logged and the service still starts; ``_model_ready``
    stays ``False`` so ``/ready`` keeps answering 503.

    Args:
        app: The FastAPI application (unused; required by the lifespan API).
    """
    global _semaphore, _model_ready
    logger.info("Starting up — loading models…")
    _semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, _load_models_blocking)
    except Exception as e:
        # Deliberately swallowed: the service starts, /ready reports 503.
        logger.exception("❌ Model load failed: %s", e)
    else:
        _model_ready = True
        logger.info("✅ Models loaded — API ready")

    yield

    logger.info("Shutting down")
def _load_models_blocking():
    """Import the BERT classifier module, then log that it is loaded.

    Blocking call, intended to be run in an executor. The import is done for
    its side effects only; presumably ``processor_bert`` loads its model at
    import time (confirm against that module).
    """
    from processor_bert import classify_batch as _warmup  # noqa: F401

    logger.info("BERT model loaded")
# ── App factory ────────────────────────────────────────────────────────────────
# The application object; `lifespan` runs the startup/shutdown logic.
app = FastAPI(
    title="Log Classification API",
    description="3-tier hybrid pipeline: Regex → BERT → LLM",
    version="3.0.0",
    lifespan=lifespan,
)

# NOTE(review): CORS is fully open (any origin, method and header).
# Confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Request / Response schemas ─────────────────────────────────────────────────
class LogRequest(BaseModel):
    # Request body for classifying one log line.
    #   source      - name of the system that produced the log.
    #   log_message - the raw log text; must be non-empty and is clipped to
    #                 LOG_MAX_CHARS characters by the validator below.
    #
    # `examples=[...]` is the Pydantic v2 spelling; the bare `example=` keyword
    # is a deprecated extra argument to Field.
    source: str = Field(..., examples=["ModernCRM"])
    log_message: str = Field(
        ...,
        examples=["User User123 logged in."],
        min_length=1,
    )

    @field_validator("log_message")
    @classmethod
    def truncate_long_logs(cls, v: str) -> str:
        """Return *v* clipped to at most ``LOG_MAX_CHARS`` characters."""
        return v[:LOG_MAX_CHARS]
class LogResponse(BaseModel):
    # Classification result for a single log line.
    request_id: str  # request ID assigned by the logging middleware (or generated in the handler)
    label: str  # taken from the classifier result's "label" key
    tier: str  # taken from the classifier result's "tier" key
    confidence: Optional[float]  # classifier result's "confidence" key, None when absent
    latency_ms: float  # classification time in milliseconds, rounded to 2 decimals
    cached: bool = False  # never set by the handlers in this file; keeps its default
class BatchRequest(BaseModel):
    # Request body for /classify/batch: a list of at most MAX_BATCH_SIZE logs.
    logs: list[LogRequest] = Field(..., max_length=MAX_BATCH_SIZE)
class BatchResponse(BaseModel):
    # Result of a batch classification request.
    request_id: str  # request ID shared by the batch and every item in `results`
    total: int  # number of logs submitted in the batch
    elapsed_ms: float  # time spent classifying the whole batch, in milliseconds
    throughput: float  # logs classified per second (total / elapsed seconds)
    results: list[LogResponse]  # one entry per result returned by the classifier
class HealthResponse(BaseModel):
    # Body returned by GET /health.
    status: str  # the handler always reports "ok"
    uptime_s: float  # seconds since this module was imported, rounded to 1 decimal
class MetricsResponse(BaseModel):
    # Body returned by GET /metrics.
    total_requests: int  # classify requests counted since process start
    total_errors: int  # classify requests whose classifier call raised
    uptime_s: float  # seconds since this module was imported
    requests_per_min: float  # total_requests / uptime in minutes (divisor is at least 1)
    # Percentiles over the buffered latency samples; None when no samples exist.
    latency_p50_ms: Optional[float]
    latency_p95_ms: Optional[float]
    latency_p99_ms: Optional[float]
# ── Rate limiter ───────────────────────────────────────────────────────────────
def _check_rate_limit() -> None:
    """Enforce the per-process sliding-window rate limit.

    At most ``RATE_LIMIT_PER_MIN`` requests are admitted in any 60-second
    window. Timestamps of admitted requests are kept in ``_rate_window``;
    expired ones are evicted from the left before the check.

    The window deque is created with ``maxlen=RATE_LIMIT_PER_MIN``, so it can
    never hold *more* than the limit. The check therefore has to be "window
    already full" (``>=``) performed *before* appending; comparing with ``>``
    after appending can never trigger.

    Raises:
        HTTPException: 429 when the limit for the current window is reached.
    """
    if RATE_LIMIT_PER_MIN <= 0:
        # A zero limit yields a zero-length window; treat it as "limiter
        # disabled" rather than rejecting every request.
        return

    now = time.time()
    cutoff = now - 60  # window = last 60 seconds

    # Stamps are appended in arrival order, so the oldest sit at the left.
    while _rate_window and _rate_window[0] <= cutoff:
        _rate_window.popleft()

    if len(_rate_window) >= RATE_LIMIT_PER_MIN:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded: {RATE_LIMIT_PER_MIN} req/min",
        )

    _rate_window.append(now)
# ── Middleware: request logging ────────────────────────────────────────────────
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Tag each request with an ID, time it, and log one summary line.

    The ID is taken from the incoming ``X-Request-ID`` header when present,
    otherwise the first 8 characters of a fresh UUID4. It is stored on
    ``request.state.request_id`` for the handlers and echoed back in the
    response's ``X-Request-ID`` header.
    """
    incoming_id = request.headers.get("X-Request-ID")
    rid = str(uuid.uuid4())[:8] if incoming_id is None else incoming_id
    request.state.request_id = rid

    started = time.perf_counter()
    response = await call_next(request)
    finished = time.perf_counter()

    elapsed = (finished - started) * 1000  # milliseconds
    logger.info(
        f"method={request.method} path={request.url.path} "
        f"status={response.status_code} latency={elapsed:.1f}ms rid={rid}"
    )

    response.headers["X-Request-ID"] = rid
    return response
# ── Health & readiness ─────────────────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse, tags=["ops"])
async def health():
    """Liveness probe: report "ok" and the seconds elapsed since startup."""
    uptime_s = round(time.time() - _start_time, 1)
    return dict(status="ok", uptime_s=uptime_s)
@app.get("/ready", tags=["ops"])
async def ready():
    """Readiness probe: 200 once startup model loading succeeded, else 503."""
    if _model_ready:
        return {"status": "ready"}
    raise HTTPException(status_code=503, detail="Models not yet loaded")
# ── Metrics ────────────────────────────────────────────────────────────────────
@app.get("/metrics", response_model=MetricsResponse, tags=["ops"])
async def metrics():
    """Report request totals, request rate and latency percentiles.

    Percentiles are read from the buffered latency samples and are ``None``
    when the buffer is empty. The request rate divides by uptime in minutes,
    clamped to at least one minute.
    """
    uptime = time.time() - _start_time
    ordered = sorted(_latencies_ms)
    count = len(ordered)

    def pct(p):
        # Index at fraction p of the sorted samples, clamped to the last one.
        if not count:
            return None
        position = min(int(count * p), count - 1)
        return round(ordered[position], 2)

    minutes_up = max(uptime / 60, 1)
    return {
        "total_requests": _request_count,
        "total_errors": _error_count,
        "uptime_s": round(uptime, 1),
        "requests_per_min": round(_request_count / minutes_up, 1),
        "latency_p50_ms": pct(0.50),
        "latency_p95_ms": pct(0.95),
        "latency_p99_ms": pct(0.99),
    }
# ── Classify single ────────────────────────────────────────────────────────────
@app.post("/classify", response_model=LogResponse, tags=["inference"])
async def classify_single(req: LogRequest, request: Request):
    """Classify one log line.

    The request is rate-limited and counted, then the blocking classifier is
    run in the default executor while holding the concurrency semaphore. The
    measured latency covers only the classifier call, not time spent waiting
    for the semaphore.

    Args:
        req: Validated request body (source and log message).
        request: Incoming request; supplies the middleware-assigned ID.

    Returns:
        LogResponse with the label, tier, optional confidence and latency.

    Raises:
        HTTPException: 429 from the rate limiter, or 500 if the classifier
            raises (the original exception is chained as the cause).
    """
    global _request_count, _error_count
    _check_rate_limit()
    _request_count += 1
    rid = getattr(request.state, "request_id", str(uuid.uuid4())[:8])

    async with _semaphore:
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        t0 = time.perf_counter()
        try:
            result = await loop.run_in_executor(
                None, _classify_blocking, req.source, req.log_message
            )
        except Exception as e:
            _error_count += 1
            logger.exception("rid=%s classify error: %s", rid, e)
            # NOTE(review): the raw exception text is returned to the client.
            raise HTTPException(status_code=500, detail=str(e)) from e
        latency = (time.perf_counter() - t0) * 1000

    _latencies_ms.append(latency)
    return LogResponse(
        request_id=rid,
        label=result["label"],
        tier=result["tier"],
        confidence=result.get("confidence"),
        latency_ms=round(latency, 2),
    )
def _classify_blocking(source: str, log_message: str) -> dict:
    """Synchronously classify one log via ``classify.classify_log``.

    The import is deferred to call time; this function is meant to be run in
    an executor rather than on the event loop.
    """
    from classify import classify_log as run_classifier

    return run_classifier(source, log_message)
# ── Classify batch ─────────────────────────────────────────────────────────────
@app.post("/classify/batch", response_model=BatchResponse, tags=["inference"])
async def classify_batch_endpoint(req: BatchRequest, request: Request):
    """Classify a batch of logs in a single classifier call.

    The request is rate-limited and counted once, then the blocking batch
    classifier runs in the default executor while holding the concurrency
    semaphore. Each item's ``latency_ms`` is the batch time divided evenly
    across the submitted logs.

    An empty ``logs`` list returns an empty result immediately; without this
    guard the per-item latency computation divides by zero.

    Args:
        req: Validated request body holding the logs to classify.
        request: Incoming request; supplies the middleware-assigned ID.

    Returns:
        BatchResponse with totals, throughput (logs/second) and per-log results.

    Raises:
        HTTPException: 429 from the rate limiter, or 500 if the classifier
            raises (the original exception is chained as the cause).
    """
    global _request_count, _error_count
    _check_rate_limit()
    _request_count += 1
    rid = getattr(request.state, "request_id", str(uuid.uuid4())[:8])

    log_pairs = [(r.source, r.log_message) for r in req.logs]
    total = len(log_pairs)
    if total == 0:
        # Nothing to classify: skip the executor and the divisions below.
        return BatchResponse(
            request_id=rid,
            total=0,
            elapsed_ms=0.0,
            throughput=0.0,
            results=[],
        )

    async with _semaphore:
        # get_running_loop() is the supported call inside a coroutine.
        loop = asyncio.get_running_loop()
        t0 = time.perf_counter()
        try:
            results = await loop.run_in_executor(
                None, _classify_batch_blocking, log_pairs
            )
        except Exception as e:
            _error_count += 1
            logger.exception("rid=%s batch error: %s", rid, e)
            # NOTE(review): the raw exception text is returned to the client.
            raise HTTPException(status_code=500, detail=str(e)) from e
        elapsed_ms = (time.perf_counter() - t0) * 1000

    # Guard against a zero reading from a coarse timer.
    throughput = round(total / (elapsed_ms / 1000), 1) if elapsed_ms > 0 else 0.0
    per_log_ms = elapsed_ms / total
    _latencies_ms.extend([per_log_ms] * total)

    return BatchResponse(
        request_id=rid,
        total=total,
        elapsed_ms=round(elapsed_ms, 2),
        throughput=throughput,
        results=[
            LogResponse(
                request_id=rid,
                label=r["label"],
                tier=r["tier"],
                confidence=r.get("confidence"),
                latency_ms=round(per_log_ms, 2),
            )
            for r in results
        ],
    )
def _classify_batch_blocking(log_pairs: list[tuple[str, str]]) -> list[dict]:
    """Synchronously classify ``(source, log_message)`` pairs.

    Delegates to ``classify.classify_logs``; the import is deferred to call
    time and the function is meant to be run in an executor.
    """
    from classify import classify_logs as run_batch_classifier

    return run_batch_classifier(log_pairs)
# ── Dev runner ──────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Running the file directly starts a single-worker server on port 8000.
    import uvicorn

    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        workers=1,
        reload=False,
    )