File size: 5,146 Bytes
94337ad
b4fadea
a472415
79b961c
31460c4
 
 
 
 
79b961c
b4fadea
79b961c
31460c4
 
b4fadea
94337ad
31460c4
 
 
 
 
b4fadea
31460c4
 
 
 
 
 
 
27236c5
94337ad
31460c4
 
94337ad
31460c4
 
 
 
 
 
94337ad
31460c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4fadea
31460c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94337ad
31460c4
94337ad
 
 
 
31460c4
 
79b961c
 
a472415
79b961c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# app/main.py
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import asyncio
import os
import pandas as pd
import random
import json
from datetime import datetime

from app.api.routes import router
from app.api.dashboard_data import router as dashboard_data_router
from app.inference.predictor import Predictor
from app.monitoring.drift import run_drift_check
from app.core.logging import init_db

# ---- Constants ----
# CSV log of scored rows: appended to by traffic_loop, read (and trimmed) by drift_loop.
PROD_LOG_PATH = "data/production/predictions_log.csv"
# Reference dataset that drift_loop compares the production log against.
REFERENCE_PATH = "models/v1/reference_data.csv"
# Dashboard payload written by drift_loop; it lives under "reports", which is
# mounted as a static directory at the bottom of this module.
DASHBOARD_JSON = "reports/evidently/drift_report.json"
# Rows that traffic_loop samples from to generate simulated traffic.
SOURCE_DATA = "data/processed/current_data.csv"

# ---- Config ----
STARTUP_DELAY = 5  # seconds traffic_loop waits before it starts sampling
MIN_SLEEP = 2  # lower bound (seconds) of the random pause between batches
MAX_SLEEP = 8  # upper bound (seconds) of the random pause between batches
MIN_BATCH = 1  # smallest number of rows sampled per batch
MAX_BATCH = 5  # largest number of rows sampled per batch
MAX_DRIFT_ROWS = 9000  # drift_loop trims the prediction log to its last N rows
MAX_DISPLAY = 100  # last N predictions for dashboard

# Single Predictor instance used by both background loops in this module.
predictor = Predictor()
# Import-time side effect: make sure the directory for DASHBOARD_JSON exists
# before drift_loop tries to write into it.
os.makedirs(os.path.dirname(DASHBOARD_JSON), exist_ok=True)

# ---- Traffic daemon in-process (no HTTP call) ----
async def traffic_loop():
    """Simulate production traffic by scoring random source rows in-process.

    After an initial ``STARTUP_DELAY`` pause the source CSV is loaded once.
    The coroutine then loops forever: it samples a random batch, scores it
    with the module-level ``predictor``, and appends the rows plus the
    prediction columns to ``PROD_LOG_PATH``. It sleeps a random
    ``MIN_SLEEP``..``MAX_SLEEP`` seconds between batches.

    The daemon disables itself (returns) when the source file is missing,
    unreadable, or empty. Errors raised while handling a single batch are
    printed and the loop carries on with the next batch.
    """
    await asyncio.sleep(STARTUP_DELAY)
    if not os.path.exists(SOURCE_DATA):
        print("Traffic daemon: source data not found, disabled.")
        return

    try:
        df_source = pd.read_csv(SOURCE_DATA)
    except (OSError, ValueError) as e:
        # pandas' parser/empty-file errors derive from ValueError. Without
        # this the task would die with an exception nobody retrieves.
        print("Traffic daemon: could not load source data, disabled:", e)
        return

    if df_source.empty:
        # Sampling from an empty frame would raise on every single iteration.
        print("Traffic daemon: source data is empty, disabled.")
        return

    # to_csv does not create parent directories; without this every append
    # below fails and is swallowed by the broad except.
    log_dir = os.path.dirname(PROD_LOG_PATH)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    print("Traffic daemon started (in-process).")

    while True:
        try:
            # Sampling without replacement raises when asked for more rows
            # than the frame holds, so cap the batch at the source size.
            batch_size = min(random.randint(MIN_BATCH, MAX_BATCH), len(df_source))
            sample = df_source.sample(batch_size)
            # In-process prediction instead of requests.post
            preds, probas = predictor.predict(sample)
            df_log = sample.copy()
            df_log["model_prediction"] = preds
            df_log["model_probability"] = probas
            df_log["model_risk_level"] = [
                "High" if p >= 0.75 else "Medium" if p >= 0.5 else "Low"
                for p in probas
            ]
            df_log["model_version"] = predictor.model_version
            # Timezone-aware UTC timestamp; replaces the deprecated Timestamp.utcnow().
            df_log["timestamp"] = pd.Timestamp.now(tz="UTC")
            # Write the header only when the log file is being created.
            df_log.to_csv(
                PROD_LOG_PATH,
                mode="a",
                header=not os.path.exists(PROD_LOG_PATH),
                index=False,
            )

        except Exception as e:
            # Best-effort daemon: report the failure and keep generating traffic.
            print("Traffic daemon error:", e)

        await asyncio.sleep(random.uniform(MIN_SLEEP, MAX_SLEEP))


# ---- Drift loop ----
def _finite_or_none(value):
    """Return ``value`` as a float, or ``None`` when it is NaN or infinite.

    ``json.dump`` writes non-finite floats as bare ``NaN``/``Infinity``
    tokens, which strict JSON parsers reject; ``None`` serialises to ``null``.
    """
    import math  # local import: math is not imported at the top of this module

    number = float(value)
    return number if math.isfinite(number) else None


def _build_dashboard_payload(prod_df, drift_dict):
    """Assemble the JSON-serialisable dashboard payload for one drift run.

    Args:
        prod_df: Production log rows that survived the feature NaN filter.
        drift_dict: Mapping of column name to drift score, as produced by
            ``run_drift_check``.

    Returns:
        A dict with ``n_rows``, ``results`` (the last ``MAX_DISPLAY``
        predictions, or an empty list when the prediction columns are
        absent) and ``drift`` (one ``{"column", "score"}`` entry per column).
    """
    results = []
    log_cols = ["model_prediction", "model_probability", "model_risk_level"]
    if all(c in prod_df.columns for c in log_cols):
        for i, row in prod_df.tail(MAX_DISPLAY).iterrows():
            probability = _finite_or_none(row["model_probability"])
            risk_level = row["model_risk_level"]
            results.append({
                "row": i,
                "prediction": "Default" if row["model_prediction"] == 1 else "No Default",
                "probability": None if probability is None else round(probability, 4),
                # The column is known to exist here, so the fallback has to
                # cover missing *values* rather than a missing key.
                "risk_level": "Unknown" if pd.isna(risk_level) else risk_level,
            })

    return {
        "n_rows": len(prod_df),
        "results": results,
        "drift": [
            {"column": col, "score": _finite_or_none(score)}
            for col, score in drift_dict.items()
        ],
    }


def _refresh_dashboard():
    """Run one drift check and rewrite ``DASHBOARD_JSON``.

    Returns without writing anything when the prediction log does not exist
    yet, lacks one of ``predictor.features``, or has no row with all
    features present.
    """
    if not os.path.exists(PROD_LOG_PATH):
        return

    prod_df = pd.read_csv(PROD_LOG_PATH)
    if len(prod_df) > MAX_DRIFT_ROWS:
        # Keep the log bounded: persist only the most recent rows.
        prod_df = prod_df.tail(MAX_DRIFT_ROWS)
        prod_df.to_csv(PROD_LOG_PATH, index=False)

    if set(predictor.features) - set(prod_df.columns):
        return

    prod_df = prod_df.dropna(subset=predictor.features)
    if prod_df.empty:
        return

    reference_df = pd.read_csv(REFERENCE_PATH)
    _, drift_dict = run_drift_check(
        prod_df[predictor.features],
        reference_df[predictor.features],
        model_version="v1"
    )

    dashboard_payload = _build_dashboard_payload(prod_df, drift_dict)

    # Write to a sibling temp file and swap it in, so a reader never sees a
    # half-written report.
    tmp_path = DASHBOARD_JSON + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(dashboard_payload, f, indent=2)
    os.replace(tmp_path, DASHBOARD_JSON)


async def drift_loop(interval_seconds: int = 10):
    """Periodically recompute drift and refresh the dashboard JSON.

    Args:
        interval_seconds: Pause between two refresh attempts, in seconds.

    Runs forever. Any error raised by a refresh is printed and the loop
    retries after the usual pause; every iteration sleeps exactly once,
    whether the refresh succeeded, was skipped, or failed.
    """
    while True:
        try:
            _refresh_dashboard()
        except Exception as e:
            print("Drift loop error:", e)

        await asyncio.sleep(interval_seconds)


# ---- HF-compatible lifespan ----
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the background loops for the lifetime of the application.

    On startup ``traffic_loop`` and ``drift_loop`` are scheduled as asyncio
    tasks. On shutdown both are cancelled and awaited, and any error a task
    had already failed with is printed instead of aborting the shutdown.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    tasks = [
        asyncio.create_task(traffic_loop()),
        asyncio.create_task(drift_loop(10)),
    ]
    try:
        yield
    finally:
        # Cancel everything first so no task keeps running while another one
        # is still being awaited.
        for task in tasks:
            task.cancel()
        # return_exceptions=True: a task that already died with an error must
        # not re-raise here and prevent the remaining tasks from being reaped.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for outcome in outcomes:
            # Cancellation is the expected outcome; only report real failures.
            if isinstance(outcome, Exception) and not isinstance(
                outcome, asyncio.CancelledError
            ):
                print("Background task error:", outcome)


# ---- FastAPI app ----
# Application instance; `lifespan` starts and stops the background loops.
app = FastAPI(title="ML Inference Service", lifespan=lifespan)
# Serve files from app/static under /static.
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# Serve the reports directory (which holds DASHBOARD_JSON) under /reports.
app.mount("/reports", StaticFiles(directory="reports"), name="reports")
# Register the main API routes, then the dashboard data routes.
app.include_router(router)
app.include_router(dashboard_data_router)