harikrishna1985 commited on
Commit
fcb06ad
·
verified ·
1 Parent(s): 41b88b9

Upload src/predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/predict.py +35 -0
src/predict.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ import joblib
3
+ import pandas as pd
4
+ import yaml
5
+
6
+
7
+ def load_config(config_path: str = "config/config.yaml") -> dict:
8
+ with open(config_path, "r", encoding="utf-8") as f:
9
+ return yaml.safe_load(f)
10
+
11
+
12
+ def load_model():
13
+ config = load_config()
14
+ repo_id = config["model"]["repo_id"]
15
+ filename = config["model"]["filename"]
16
+
17
+ model_path = hf_hub_download(
18
+ repo_id=repo_id,
19
+ filename=filename
20
+ )
21
+ model = joblib.load(model_path)
22
+ return model
23
+
24
+
25
+ def predict_input(input_df: pd.DataFrame):
26
+ model = load_model()
27
+ prediction = model.predict(input_df)
28
+
29
+ result = {"prediction": prediction[0]}
30
+
31
+ if hasattr(model, "predict_proba"):
32
+ proba = model.predict_proba(input_df)
33
+ result["probabilities"] = proba[0].tolist()
34
+
35
+ return result