mjpsm commited on
Commit
7bafbe7
·
verified ·
1 Parent(s): 2d770d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -22
app.py CHANGED
@@ -1,32 +1,31 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- import joblib
4
  from sentence_transformers import SentenceTransformer
5
  from huggingface_hub import hf_hub_download
 
6
 
7
  # -----------------------------
8
- # Download models from Hugging Face Hub
9
- # (replace repo_ids with your actual repos)
10
  # -----------------------------
11
- griot_path = hf_hub_download(repo_id="mjpsm/Griot-xgb-model", filename="griot_xgb_regression_model_updated_parameters.pkl")
12
- kinara_path = hf_hub_download(repo_id="mjpsm/Kinara-xgb-model", filename="Kinara_xgb_model.pkl")
13
- ubuntu_path = hf_hub_download(repo_id="mjpsm/Ubuntu-xgb-model", filename="Ubuntu_xgb_model.pkl")
14
- jali_path = hf_hub_download(repo_id="mjpsm/Jali-xgb-model", filename="Jali_xgb_model.pkl")
15
- kuumba_path = hf_hub_download(repo_id="mjpsm/Kuumba-xgb-model", filename="Kuumba_xgb_model.pkl")
16
-
17
 
 
 
 
18
  available_models = {
19
- "Griot": joblib.load(griot_path),
20
- "Kinara": joblib.load(kinara_path),
21
- "Ubuntu": joblib.load(ubuntu_path),
22
- "Jali": joblib.load(jali_path),
23
- "Kuumba": joblib.load(kuumba_path)
24
-
25
  }
26
 
27
-
28
-
29
- # Archetype list (15 total, 12 fillers for now)
30
  all_archetypes = [
31
  "Griot", "Kinara", "Ubuntu", "Jali", "Sankofa", "Imani", "Maji",
32
  "Nzinga", "Bisa", "Zamani", "Tamu", "Shujaa", "Ayo", "Ujamaa", "Kuumba"
@@ -43,17 +42,17 @@ class TextInput(BaseModel):
43
 
44
  @app.post("/soulprint_snapshot")
45
  def soulprint_snapshot(input: TextInput):
46
- # Convert text into embedding
47
  embedding = embedder.encode([input.text]).reshape(1, -1)
48
 
49
- # Build snapshot
50
  snapshot = {}
51
  for name in all_archetypes:
52
  if name in available_models:
53
- score = available_models[name].predict(embedding)[0]
 
54
  snapshot[name] = float(score)
55
  else:
56
- snapshot[name] = 0.0 # filler until model is ready
57
 
58
  return {"soulprint_snapshot": snapshot}
59
 
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
  from sentence_transformers import SentenceTransformer
4
  from huggingface_hub import hf_hub_download
5
+ import xgboost as xgb
6
 
7
  # -----------------------------
8
+ # Helper: Load XGBoost Booster (.json)
 
9
  # -----------------------------
10
+ def load_xgb_model(repo_id: str, filename: str):
11
+ path = hf_hub_download(repo_id=repo_id, filename=filename)
12
+ booster = xgb.Booster()
13
+ booster.load_model(path)
14
+ return booster
 
15
 
16
+ # -----------------------------
17
+ # Load Soulprint models (all JSON now)
18
+ # -----------------------------
19
  available_models = {
20
+ "Griot": load_xgb_model("mjpsm/Griot-xgb-model", "Griot_xgb_model.json"),
21
+ "Kinara": load_xgb_model("mjpsm/Kinara-xgb-model", "Kinara_xgb_model.json"),
22
+ "Ubuntu": load_xgb_model("mjpsm/Ubuntu-xgb-model", "Ubuntu_xgb_model.json"),
23
+ "Jali": load_xgb_model("mjpsm/Jali-xgb-model", "Jali_xgb_model.json"),
24
+ "Kuumba": load_xgb_model("mjpsm/Kuumba-xgb-model", "Kuumba_xgb_model.json"),
25
+ "Sankofa": load_xgb_model("mjpsm/Sankofa-xgb-model", "Sankofa_xgb_model.json"),
26
  }
27
 
28
+ # Archetype list (15 total, placeholders for now)
 
 
29
  all_archetypes = [
30
  "Griot", "Kinara", "Ubuntu", "Jali", "Sankofa", "Imani", "Maji",
31
  "Nzinga", "Bisa", "Zamani", "Tamu", "Shujaa", "Ayo", "Ujamaa", "Kuumba"
 
42
 
43
  @app.post("/soulprint_snapshot")
44
  def soulprint_snapshot(input: TextInput):
 
45
  embedding = embedder.encode([input.text]).reshape(1, -1)
46
 
 
47
  snapshot = {}
48
  for name in all_archetypes:
49
  if name in available_models:
50
+ dmatrix = xgb.DMatrix(embedding)
51
+ score = available_models[name].predict(dmatrix)[0]
52
  snapshot[name] = float(score)
53
  else:
54
+ snapshot[name] = 0.0 # placeholder until model is trained
55
 
56
  return {"soulprint_snapshot": snapshot}
57
 
58
+