# ALM_LLM/tools/predict_tool.py
# Last commit: "Update tools/predict_tool.py" by AshenH (91c65e4, verified), 1.31 kB
import os
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer
class PredictTool:
    """Score a DataFrame with a model fetched from a Hugging Face Hub repo.

    The model (``model.pkl``) and its feature metadata
    (``feature_metadata.json``) are downloaded from ``cfg.hf_model_repo`` on
    first use and then kept on the instance for subsequent calls.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer") -> None:
        """Store the configuration and tracer; nothing is downloaded yet.

        Args:
            cfg: Application config; ``cfg.hf_model_repo`` is read when the
                artifacts are loaded.
            tracer: Tracer whose ``trace_event`` is called after each
                prediction run.
        """
        self.cfg = cfg
        self.tracer = tracer
        # Both are populated together by _ensure_loaded().
        self._model = None
        self._feature_meta = None

    def _ensure_loaded(self) -> None:
        """Download and load the model and feature metadata if needed.

        Both artifacts are built in local variables and only assigned to the
        instance once *both* have loaded, so a failure while fetching or
        parsing the metadata cannot leave the tool with a model but no
        metadata (which would make every later ``run`` call fail).
        """
        if self._model is not None and self._feature_meta is not None:
            return

        import json  # local import, as in the original implementation

        token = os.getenv("HF_TOKEN")
        repo_id = self.cfg.hf_model_repo

        model_path = hf_hub_download(
            repo_id=repo_id, filename="model.pkl", token=token
        )
        meta_path = hf_hub_download(
            repo_id=repo_id, filename="feature_metadata.json", token=token
        )

        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point cfg.hf_model_repo at a trusted repo.
        model = joblib.load(model_path)

        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(meta_path, "r", encoding="utf-8") as f:
            feature_meta = json.load(f)

        self._model = model
        self._feature_meta = feature_meta

    def run(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with a prediction column appended.

        The feature columns are taken from the metadata key
        ``"feature_order"`` (default: all columns of ``df``, in order). The
        output column name comes from the metadata key
        ``"prediction_column"`` (default: ``"prediction"``).

        If the model exposes ``predict_proba``, column index 1 of its output
        is used as the prediction; otherwise ``predict`` is used.

        Args:
            df: Input rows. It is not modified.

        Returns:
            A copy of ``df`` with the prediction column added.

        Raises:
            KeyError: If ``df`` lacks any of the required feature columns.
        """
        self._ensure_loaded()

        use_cols = self._feature_meta.get("feature_order", list(df.columns))

        # Fail with an actionable message instead of pandas' generic
        # "not in index" error (same exception type, so callers are unaffected).
        missing = [col for col in use_cols if col not in df.columns]
        if missing:
            raise KeyError(
                f"Input DataFrame is missing required feature columns: {missing}"
            )

        X = df[use_cols].copy()

        if hasattr(self._model, "predict_proba"):
            # Column 1 of predict_proba; presumably the positive class of a
            # binary classifier -- TODO confirm against the trained model.
            preds = self._model.predict_proba(X)[:, 1]
        else:
            preds = self._model.predict(X)

        out = df.copy()
        out[self._feature_meta.get("prediction_column", "prediction")] = preds
        self.tracer.trace_event("predict", {"rows": len(out)})
        return out