# soulprint_API / app.py
# Last commit by mjpsm: "Update app.py" (revision 732258d, verified)
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
import xgboost as xgb
# -----------------------------
# Helper: Load XGBoost Booster (.json)
# -----------------------------
def load_xgb_model(repo_id: str, filename: str):
    """Fetch an XGBoost model file from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository holding the model, e.g. "owner/name".
        filename: Name of the saved model file inside that repository.

    Returns:
        An ``xgb.Booster`` populated from the downloaded file.
    """
    # hf_hub_download returns the local path of the downloaded file.
    local_file = hf_hub_download(repo_id=repo_id, filename=filename)
    model = xgb.Booster()
    model.load_model(local_file)
    return model
# -----------------------------
# Load Soulprint models (all JSON now)
# -----------------------------
# One trained booster per archetype, keyed by archetype name. Each entry is
# (archetype, Hub repo_id, model filename); every booster is downloaded and
# loaded once, in table order, when this module is imported.
available_models = {
    archetype: load_xgb_model(repo_id, filename)
    for archetype, repo_id, filename in (
        ("Griot", "mjpsm/Griot-xgb-model", "Griot_xgb_model.json"),
        ("Kinara", "mjpsm/Kinara-xgb-model", "Kinara_xgb_model.json"),
        ("Ubuntu", "mjpsm/Ubuntu-xgb-model", "Ubuntu_xgb_model.json"),
        ("Jali", "mjpsm/Jali-xgb-model", "Jali_xgb_model.json"),
        ("Kuumba", "mjpsm/Kuumba-xgb-model", "Kuumba_xgb_model.json"),
        ("Sankofa", "mjpsm/Sankofa-xgb-model", "Sankofa_xgb_model.json"),
        ("Imani", "mjpsm/Imani-xgb-model", "Imani_xgb_model.json"),
        ("Maji", "mjpsm/Maji-xgb-model", "Maji_xgb_model.json"),
        ("Nzinga", "mjpsm/Nzinga-xgb-model", "Nzinga_xgb_model.json"),
        ("Bisa", "mjpsm/Bisa-xgb-model", "Bisa_xgb_model.json"),
        ("Zamani", "mjpsm/Zamani-xgb-model", "Zamani_xgb_model.json"),
        ("Tamu", "mjpsm/Tamu-xgb-model", "Tamu_xgb_model.json"),
        ("Shujaa", "mjpsm/Shujaa-xgb-model", "Shujaa_xgb_model.json"),
        ("Ayo", "mjpsm/Ayo-xgb-model", "Ayo_xgb_model.json"),
        ("Ujamaa", "mjpsm/Ujamaa-xgb-model", "Ujamaa_xgb_model.json"),
    )
}
# The 15 archetypes reported by the snapshot endpoint, in response order.
# Any name without an entry in available_models is reported as 0.0.
all_archetypes = [
    "Griot",
    "Kinara",
    "Ubuntu",
    "Jali",
    "Sankofa",
    "Imani",
    "Maji",
    "Nzinga",
    "Bisa",
    "Zamani",
    "Tamu",
    "Shujaa",
    "Ayo",
    "Ujamaa",
    "Kuumba",
]
# Sentence-embedding model shared by all archetype boosters: request text is
# encoded with it once, and that vector is what each booster scores.
embedder = SentenceTransformer(model_name_or_path="all-mpnet-base-v2")

# FastAPI application; routes are registered on it with @app.post(...).
app = FastAPI()
class TextInput(BaseModel):
    # Request body model for POST /soulprint_snapshot: a single string field
    # "text" holding the passage to score.
    # NOTE: documented with # comments rather than a docstring on purpose —
    # Pydantic copies a class docstring into the model's JSON schema, which
    # would change the generated OpenAPI description.
    text: str
@app.post("/soulprint_snapshot")
def soulprint_snapshot(input: TextInput):
    """Score a piece of text against every Soulprint archetype.

    The text is embedded once with the shared ``embedder``. That single
    embedding is then scored by each archetype's XGBoost booster in
    ``available_models``.

    Args:
        input: Request body carrying the text to score. The name shadows the
            ``input`` builtin but is kept so keyword callers keep working.

    Returns:
        ``{"soulprint_snapshot": {archetype: score, ...}}``. There is one
        float entry per name in ``all_archetypes``, in that list's order.
        Archetypes with no loaded model get ``0.0``.
    """
    # Encode a one-element batch, then force it into a single 2-D row
    # (1, dim). This is the shape the DMatrix below is built from.
    embedding = embedder.encode([input.text]).reshape(1, -1)
    # The DMatrix depends only on the embedding, not on the archetype.
    # Build it once here instead of once per model inside the loop.
    dmatrix = xgb.DMatrix(embedding)

    snapshot = {}
    for name in all_archetypes:
        booster = available_models.get(name)
        if booster is None:
            snapshot[name] = 0.0  # placeholder until a model is trained
            continue
        # predict() yields one prediction per input row; take the only row.
        snapshot[name] = float(booster.predict(dmatrix)[0])
    return {"soulprint_snapshot": snapshot}