# ConTeXT-ESCO-Matcher / handler.py
# Source: Hugging Face repo upload by mpalinski (commit f8291b5, verified)
"""
Custom handler for ConTeXT ESCO skill mapping on HF Inference Endpoints.
Implements the ConTeXT-match mechanism (Decorte et al. 2025):
match(x, s) = Σ_j α_j · cos(x_j, s) where α_j = softmax(z_xj · z_s)
"""
from typing import Dict, List, Any
import os
import torch
from sentence_transformers import SentenceTransformer
THRESHOLD = 0.48
MAX_SKILLS = 50
INTERNAL_BATCH = 32
class EndpointHandler:
    """HF Inference Endpoint handler for ConTeXT ESCO skill mapping.

    Implements the ConTeXT-match score (Decorte et al. 2025): for a sentence
    with token embeddings x_j and a skill embedding s,
        match(x, s) = sum_j alpha_j * cos(x_j, s),
    where alpha_j = softmax_j(x_j . s) (softmax over tokens, per skill).
    """

    def __init__(self, path=""):
        # Use the GPU when present; fp16 halves memory and speeds up matmuls.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(path, device=self.device)
        if self.device == "cuda":
            self.model = self.model.half()
        # Precomputed ESCO skill labels/URIs/embeddings shipped with the model.
        # NOTE(review): weights_only=False unpickles arbitrary objects — safe
        # only because esco_data.pt is a trusted, first-party artifact.
        esco = torch.load(os.path.join(path, "esco_data.pt"), map_location="cpu",
                          weights_only=False)
        self.esco_labels = esco["labels"]
        self.esco_uris = esco["uris"]
        self.skill_embs = esco["embeddings"].to(self.device)
        if self.device == "cuda":
            self.skill_embs = self.skill_embs.half()
        # Cache per-skill L2 norms so cosine denominators aren't recomputed.
        self.skill_norms = self.skill_embs.norm(dim=-1)
        print(f"[handler] ConTeXT ready — {len(self.esco_labels)} ESCO skills, "
              f"device={self.device}, batch={INTERNAL_BATCH}")

    def _encode_tokens(self, sentences):
        """Return (token_embeddings (B, T, d), attention_mask (B, T)) for a batch."""
        features = self.model.tokenize(sentences)
        features = {k: v.to(self.device) for k, v in features.items()}
        with torch.no_grad():
            # Run only the first pipeline module (the transformer) — we need
            # per-token embeddings, not the pooled sentence embedding.
            out = self.model[0](features)
        tok_embs = out["token_embeddings"]
        mask = out["attention_mask"]
        if self.device == "cuda":
            tok_embs = tok_embs.half()
        return tok_embs, mask

    def _context_match(self, tok_embs, mask):
        """Compute ConTeXT match scores for every (sentence, skill) pair.

        Args:
            tok_embs: (B, T, d) token embeddings.
            mask: (B, T) attention mask; 1 marks real tokens, 0 padding.

        Returns:
            match_scores: (B, S) attention-weighted cosine similarities.
            dots: (B, T, S) raw token-skill dot products, reused later by
                the redundancy filter.
        """
        # Raw dot products between every token and every skill embedding.
        dots = torch.einsum("btd,sd->bts", tok_embs, self.skill_embs)
        tok_norms = tok_embs.norm(dim=-1, keepdim=True)            # (B, T, 1)
        skill_norms = self.skill_norms.unsqueeze(0).unsqueeze(0)   # (1, 1, S)
        cos_sim = dots / (tok_norms * skill_norms + 1e-8)          # eps: avoid /0
        # Padding positions get the dtype minimum so softmax gives them
        # (numerically) zero attention weight.
        neg_inf = torch.finfo(dots.dtype).min
        attn_mask = mask.unsqueeze(-1).bool()
        dots_masked = dots.masked_fill(~attn_mask, neg_inf)
        alpha = torch.softmax(dots_masked, dim=1)   # softmax over tokens
        match_scores = (alpha * cos_sim).sum(dim=1)
        return match_scores, dots

    def _redundancy_filter(self, candidate_indices, dots_row, mask_row):
        """Drop candidate skills that 'win' no content token.

        Each content token votes for the candidate with the highest dot
        product; only candidates winning at least one token are kept.
        Positions 0 and valid_len-1 are excluded — presumably the
        tokenizer's special tokens (CLS/SEP); TODO confirm for this model.
        """
        valid_len = int(mask_row.sum().item())
        if valid_len <= 2:
            # Only special tokens present — nothing to vote with.
            return candidate_indices
        content_start, content_end = 1, valid_len - 1
        if content_start >= content_end:
            return candidate_indices
        content_dots = dots_row[content_start:content_end, :]
        cand_dots = content_dots[:, candidate_indices]
        winners = cand_dots.argmax(dim=1)
        # Fix: index on the same device; the old `.cpu()` call forced a
        # device-to-host sync for every sentence when running on GPU.
        return candidate_indices[torch.unique(winners)]

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Map each input sentence to a ranked list of ESCO skills.

        Payload: {"inputs": str | list[str],
                  "parameters": {"threshold", "max_skills",
                                 "redundancy_filter"}  (all optional)}.
        Returns one list per input sentence; each entry is
        {"uri": str, "label": str, "score": float rounded to 4 decimals},
        sorted by descending score and capped at max_skills.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]
        # Fix: `or {}` guards against an explicit `"parameters": null` in the
        # JSON payload, which `.get("parameters", {})` would pass through and
        # then crash on `.get()` with an AttributeError.
        params = data.get("parameters") or {}
        threshold = params.get("threshold", THRESHOLD)
        max_skills = params.get("max_skills", MAX_SKILLS)
        do_filter = params.get("redundancy_filter", True)
        all_results = []
        for i in range(0, len(inputs), INTERNAL_BATCH):
            batch = inputs[i : i + INTERNAL_BATCH]
            tok_embs, mask = self._encode_tokens(batch)
            match_scores, dots = self._context_match(tok_embs, mask)
            for j in range(len(batch)):
                scores = match_scores[j]
                above = (scores >= threshold).nonzero(as_tuple=True)[0]
                if len(above) == 0:
                    all_results.append([])
                    continue
                if do_filter and len(above) > 1:
                    above = self._redundancy_filter(above, dots[j], mask[j])
                skill_scores = scores[above]
                # Rank surviving candidates by score, keep the top max_skills.
                order = skill_scores.argsort(descending=True)[:max_skills]
                kept = above[order]
                kept_scores = scores[kept]
                all_results.append([
                    {
                        "uri": self.esco_uris[int(idx)],
                        "label": self.esco_labels[int(idx)],
                        "score": round(float(sc), 4),
                    }
                    for idx, sc in zip(kept.cpu(), kept_scores.cpu())
                ])
        return all_results