File size: 1,345 Bytes
0f166dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
from ..utils.config import AppConfig
from ..utils.tracing import Tracer

class PredictTool:
    """Score a DataFrame with a model fetched from a Hugging Face Hub repo.

    On first use the tool downloads ``model.pkl`` and
    ``feature_metadata.json`` from ``cfg.hf_model_repo`` (authenticating with
    the ``HF_TOKEN`` environment variable when it is set) and keeps both on
    the instance for later calls.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store the config and tracer; nothing is downloaded here.

        Args:
            cfg: Application config; ``cfg.hf_model_repo`` names the Hub repo.
            tracer: Receives a ``"predict"`` event after every ``run`` call.
        """
        self.cfg = cfg
        self.tracer = tracer
        # Both are filled in lazily (and together) by _ensure_loaded().
        self._model = None
        self._feature_meta = None

    def _ensure_loaded(self) -> None:
        """Download and cache the model and its feature metadata if needed.

        Raises:
            ValueError: If ``feature_metadata.json`` is not a JSON object.
        """
        if self._model is not None and self._feature_meta is not None:
            return

        import json

        repo_id = self.cfg.hf_model_repo
        token = os.getenv("HF_TOKEN")
        model_path = hf_hub_download(
            repo_id=repo_id, filename="model.pkl", token=token
        )
        meta_path = hf_hub_download(
            repo_id=repo_id, filename="feature_metadata.json", token=token
        )

        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; only point cfg.hf_model_repo at a trusted repo.
        model = joblib.load(model_path)
        with open(meta_path, "r", encoding="utf-8") as f:
            feature_meta = json.load(f)
        if not isinstance(feature_meta, dict):
            raise ValueError(
                "feature_metadata.json must contain a JSON object, got "
                f"{type(feature_meta).__name__}"
            )

        # Publish both only after everything succeeded, so a failed metadata
        # download cannot leave a model cached without its metadata.
        self._model = model
        self._feature_meta = feature_meta

    def run(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with a prediction column appended.

        Features are selected in the order given by the metadata key
        ``"feature_order"`` (default: all columns of ``df``). The output
        column is named by ``"prediction_column"`` (default ``"prediction"``).
        A ``"predict"`` trace event carrying the row count is emitted.

        Args:
            df: Input rows; it is not modified.

        Returns:
            A copy of ``df`` plus the prediction column.

        Raises:
            KeyError: If ``df`` lacks a column listed in ``"feature_order"``.
        """
        self._ensure_loaded()
        meta = self._feature_meta

        use_cols = meta.get("feature_order", list(df.columns))
        missing = [col for col in use_cols if col not in df.columns]
        if missing:
            raise KeyError(
                f"Input DataFrame is missing feature columns: {missing}"
            )
        X = df[use_cols].copy()

        if hasattr(self._model, "predict_proba"):
            # NOTE(review): column 1 is used as the score, which presumably
            # assumes a binary classifier whose second class is the positive
            # one -- confirm against the trained model.
            preds = self._model.predict_proba(X)[:, 1]
        else:
            preds = self._model.predict(X)

        out = df.copy()
        out[meta.get("prediction_column", "prediction")] = preds
        self.tracer.trace_event("predict", {"rows": len(out)})
        return out