naics_embeddings / api /predict.py
Joseph Warth
updated to flat embedding
a6067aa
import torch
from load_artifacts import load_artifacts
def predict_one(company_description: str, *, artifacts=None) -> dict:
    """Predict the y6 code (and its prefixes) for a single company description.

    The description is embedded with ``artifacts["embedder"]``, scored by
    ``artifacts["model"]``, and the softmax over classes is ranked once. The
    top-ranked class is the prediction; the head of the same ranking forms
    ``pred_top5_y6``, so the two always agree.

    Args:
        company_description: Free-text description of one company.
        artifacts: Optional mapping providing ``"embedder"``, ``"model"``,
            ``"label_maps"``, ``"device"`` and ``"y6_title_lookup"``. When
            omitted, ``load_artifacts()`` is called (the original behavior);
            passing it lets callers that already hold the artifacts skip that.

    Returns:
        A dict with, in this order:
          - ``pred_y2`` .. ``pred_y5``: the first 2..5 characters of
            ``pred_y6`` (presumably the coarser NAICS levels — confirm that
            codes are digit strings).
          - ``pred_y6``: label of the highest-probability class.
          - ``pred_y6_title``: its title from ``y6_title_lookup``, or ``""``.
          - ``pred_prob_y6``: its softmax probability.
          - ``pred_top5_y6``: up to 5 ``{"code", "title", "prob"}`` dicts in
            descending probability, omitting classes with probability < 1e-6.
    """
    if artifacts is None:
        artifacts = load_artifacts()
    embedder = artifacts["embedder"]
    model = artifacts["model"]
    to_value = artifacts["label_maps"]["y6"]["to_value"]
    titles = artifacts["y6_title_lookup"]
    device = artifacts["device"]

    emb = embedder.encode(
        [company_description],
        convert_to_tensor=False,
        normalize_embeddings=True,
    )
    x = torch.tensor(emb, dtype=torch.float32).to(device)

    # NOTE(review): assumes load_artifacts() already put the model in eval()
    # mode; otherwise dropout/batch-norm would make predictions vary — confirm.
    with torch.no_grad():
        logits = model(x)
    # Batch of one: keep only the first (and only) row of probabilities.
    probs_np = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # Rank once, descending, breaking ties toward the lowest class index
    # (stable sort of the negated probabilities, same tie rule as argmax).
    # Deriving both the top-1 prediction and the top-5 list from this single
    # ordering guarantees pred_y6 == pred_top5_y6[0]["code"].
    order = (-probs_np).argsort(kind="stable")

    pred_idx = int(order[0])
    pred_y6 = to_value[pred_idx]
    pred_prob_y6 = float(probs_np[pred_idx])

    # The ranking is descending, so once one probability falls below the
    # threshold every later one does too — stop instead of scanning all classes.
    pred_top5_y6 = []
    for i in order[:5]:
        prob = float(probs_np[i])
        if prob < 1e-6:
            break
        code = to_value[int(i)]
        pred_top5_y6.append({
            "code": code,
            "title": titles.get(code, ""),
            "prob": prob,
        })

    return {
        "pred_y2": pred_y6[:2],
        "pred_y3": pred_y6[:3],
        "pred_y4": pred_y6[:4],
        "pred_y5": pred_y6[:5],
        "pred_y6": pred_y6,
        "pred_y6_title": titles.get(pred_y6, ""),
        "pred_prob_y6": pred_prob_y6,
        "pred_top5_y6": pred_top5_y6,
    }