import cProfile
import logging
import pstats
import time
from typing import Optional

import psutil
from fastapi import Request
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware

from src.config.db import SessionLocal
from src.models.profiling import ProfilingLog

logger = logging.getLogger(__name__)

_BYTES_PER_MB = 1024 * 1024

# Timing buckets reported per request, in the order they are returned.
_TIMING_CATEGORIES = ("preprocessing", "inference", "database", "serialization")
# Function-name keywords that mark a psycopg/SQLAlchemy function as DB time.
_DB_FUNC_KEYWORDS = ("wait", "execute", "flush", "commit")
# Function-name keywords that mark a function as serialization time.
_SERIALIZATION_FUNC_KEYWORDS = ("json", "dumps", "serialize")


def _short_filename(path: str) -> str:
    """Return the last component of *path*, accepting POSIX or Windows separators."""
    return path.replace("\\", "/").rsplit("/", 1)[-1]


def _timing_category(filename: str, func_name: str) -> Optional[str]:
    """Classify a profiled function into one of ``_TIMING_CATEGORIES``, or None.

    Rules are checked in priority order; the first match wins. A psycopg or
    SQLAlchemy function whose name matches no DB keyword is left uncategorized
    instead of falling through to the serialization rule.
    """
    func_name_lower = func_name.lower()
    file_name_lower = filename.lower()
    if "compute_features" in func_name_lower or "features.py" in file_name_lower:
        return "preprocessing"
    if "predict_proba" in func_name_lower:
        return "inference"
    if "psycopg" in file_name_lower or "sqlalchemy" in file_name_lower:
        if any(kw in func_name_lower for kw in _DB_FUNC_KEYWORDS):
            return "database"
        return None
    if any(kw in func_name_lower for kw in _SERIALIZATION_FUNC_KEYWORDS):
        return "serialization"
    return None


class ProfilingMiddleware(BaseHTTPMiddleware):
    """Profile HTTP requests with cProfile and persist a summary row per request.

    When ``enabled`` is true, the downstream app (``call_next``) runs under a
    ``cProfile.Profile``. The middleware then stores a ``ProfilingLog`` row with:
    wall time, process CPU utilisation, RSS delta, the top functions by
    cumulative time, coarse per-category timings, and call counts.

    Only one request is profiled at a time. cProfile cannot safely run two
    profilers concurrently: on Python 3.12+ a second ``enable()`` raises, and
    on older versions it silently replaces the active profiler. Requests that
    arrive while another request is being profiled therefore pass through
    unprofiled and are not logged.

    CPU and memory figures are process-wide, so work done by other concurrent
    requests can be included in them. With ``BaseHTTPMiddleware``, time spent
    streaming the response body after ``call_next`` returns is not measured.
    """

    # Class-level default so the guard exists even if __init__ is bypassed.
    _profiling_active = False

    def __init__(self, app, enabled: bool = False):
        super().__init__(app)
        self.enabled = enabled
        self.process = psutil.Process()
        # True while a request is being profiled. dispatch() runs on the
        # event-loop thread and checks/sets this with no await in between, so
        # no lock is needed.
        self._profiling_active = False

    async def dispatch(self, request: Request, call_next):
        """Run the request, profiling it when enabled and no profile is active.

        Exceptions from ``call_next`` propagate unchanged; the profiler is
        always disabled first. Failures while summarising or saving the
        profile are logged and never affect the returned response.
        """
        if not self.enabled or self._profiling_active:
            return await call_next(request)

        self._profiling_active = True
        try:
            cpu_before = self.process.cpu_times()
            mem_before_mb = self.process.memory_info().rss / _BYTES_PER_MB

            profiler = cProfile.Profile()
            profiler.enable()
            start_time = time.perf_counter()
            try:
                response = await call_next(request)
            finally:
                # Always stop profiling, even if the app raised; otherwise the
                # profiler would stay installed on the event-loop thread.
                total_time_ms = (time.perf_counter() - start_time) * 1000
                profiler.disable()

            cpu_after = self.process.cpu_times()
            mem_after_mb = self.process.memory_info().rss / _BYTES_PER_MB
        finally:
            self._profiling_active = False

        # CPU seconds consumed per wall-clock second, as a percentage of one
        # CPU. Unlike cpu_percent() this keeps no state between calls, so
        # interleaved requests cannot corrupt each other's baseline.
        cpu_seconds = (cpu_after.user - cpu_before.user) + (
            cpu_after.system - cpu_before.system
        )
        cpu_percent = (
            cpu_seconds / (total_time_ms / 1000) * 100 if total_time_ms > 0 else 0.0
        )

        try:
            # pstats processing and the synchronous DB commit are blocking work;
            # keep them off the event loop.
            await run_in_threadpool(
                self._record_profile,
                profiler,
                request.url.path,
                request.method,
                total_time_ms,
                cpu_percent,
                mem_after_mb - mem_before_mb,
            )
        except Exception:
            # Profiling is best-effort: never fail a successful response.
            logger.exception(
                "Failed to record profiling data for %s %s",
                request.method,
                request.url.path,
            )
        return response

    def _record_profile(
        self,
        profiler: cProfile.Profile,
        endpoint: str,
        method: str,
        total_time_ms: float,
        cpu_percent: float,
        memory_mb: float,
    ) -> None:
        """Summarise a finished (disabled) profiler run and persist it."""
        stats = pstats.Stats(profiler)
        ncalls_total, ncalls_pandas, ncalls_db = self._count_calls_by_category(stats)
        self._save_to_database(
            endpoint=endpoint,
            method=method,
            total_time_ms=total_time_ms,
            top_functions=self._extract_top_functions(stats, limit=10),
            timings=self._extract_specific_timings(stats),
            ncalls_total=ncalls_total,
            ncalls_pandas=ncalls_pandas,
            ncalls_database=ncalls_db,
            cpu_percent=cpu_percent,
            memory_mb=memory_mb,
        )

    def _extract_top_functions(self, stats: pstats.Stats, limit: int = 10) -> list:
        """Return up to *limit* functions with the highest cumulative time.

        Each entry is a dict with ``name``, ``file`` (basename only), ``line``,
        ``time_ms`` (cumulative, in milliseconds) and ``calls`` (total calls,
        including recursive ones).
        """
        # pstats.Stats.sort_stats() only orders the internal list used for
        # printing; the stats.stats dict keeps insertion order, so sort here
        # by cumulative time (index 3 of the stats tuple), descending.
        ranked = sorted(stats.stats.items(), key=lambda item: item[1][3], reverse=True)
        return [
            {
                "name": func_name,
                "file": _short_filename(filename),
                "line": line,
                "time_ms": ct * 1000,
                "calls": nc,
            }
            for (filename, line, func_name), (_cc, nc, _tt, ct, _callers) in ranked[:limit]
        ]

    def _extract_specific_timings(self, stats: pstats.Stats) -> dict:
        """Return milliseconds spent in each bucket of ``_TIMING_CATEGORIES``.

        A matching function's cumulative time already includes everything it
        calls. Only time entered from callers *outside* the same bucket is
        counted; otherwise nested matches (e.g. SQLAlchemy ``execute`` calling
        psycopg ``execute``) would be added several times. Functions with no
        recorded caller contribute their full cumulative time.

        NOTE(review): a function called both from outside the profiled span
        and from a profiled caller only gets the profiled-caller share.
        Assumes caller entries are cProfile 4-tuples ``(nc, cc, tt, ct)``.
        """
        timings = dict.fromkeys(_TIMING_CATEGORIES, 0.0)
        categories = {func: _timing_category(func[0], func[2]) for func in stats.stats}

        for func, (_cc, _nc, _tt, ct, callers) in stats.stats.items():
            category = categories[func]
            if category is None:
                continue
            if callers:
                entry_ct = sum(
                    caller_stats[3]
                    for caller, caller_stats in callers.items()
                    if categories.get(caller) != category
                )
            else:
                entry_ct = ct
            timings[category] += entry_ct * 1000
        return timings

    def _count_calls_by_category(self, stats: pstats.Stats) -> tuple[int, int, int]:
        """Return ``(total, pandas, database)`` call counts from *stats*.

        Counts include recursive calls. A function counts toward pandas when
        its file path contains "pandas", otherwise toward the database layer
        when the path contains "sqlalchemy" or "psycopg". These path checks
        are case-sensitive.
        """
        ncalls_total = ncalls_pandas = ncalls_db = 0
        for (filename, _line, _func_name), (_cc, nc, _tt, _ct, _callers) in stats.stats.items():
            ncalls_total += nc
            if "pandas" in filename:
                ncalls_pandas += nc
            elif "sqlalchemy" in filename or "psycopg" in filename:
                ncalls_db += nc
        return ncalls_total, ncalls_pandas, ncalls_db

    def _save_to_database(
        self,
        endpoint: str,
        method: str,
        total_time_ms: float,
        top_functions: list,
        timings: dict,
        ncalls_total: int,
        ncalls_pandas: int,
        ncalls_database: int,
        cpu_percent: float,
        memory_mb: float,
    ):
        """Insert one ``ProfilingLog`` row using a fresh session.

        Best-effort: any failure is logged and the transaction rolled back. No
        exception reaches the caller, and the session is always closed.
        """
        db: Session = SessionLocal()
        try:
            log = ProfilingLog(
                endpoint=endpoint,
                method=method,
                total_time_ms=total_time_ms,
                # Store NULL rather than 0.0 when no function hit a bucket.
                time_preprocessing_ms=timings.get("preprocessing") or None,
                time_inference_ms=timings.get("inference") or None,
                time_database_ms=timings.get("database") or None,
                time_serialization_ms=timings.get("serialization") or None,
                top_functions=top_functions,
                ncalls_total=ncalls_total,
                ncalls_pandas=ncalls_pandas,
                ncalls_database=ncalls_database,
                cpu_percent=cpu_percent,
                memory_mb=memory_mb,
            )
            db.add(log)
            db.commit()
        except Exception:
            logger.exception("Failed to save profiling log for %s %s", method, endpoint)
            db.rollback()
        finally:
            db.close()