"""Custom handler for ConTeXT ESCO skill mapping on HF Inference Endpoints.

Implements the ConTeXT-match mechanism (Decorte et al. 2025):

    match(x, s) = Σ_j α_j · cos(x_j, s)   where α_j = softmax_j(z_xj · z_s)

i.e. a skill's score for a sentence is the cosine similarity between the skill
embedding and each token embedding, attention-weighted by the raw token–skill
dot products.
"""
import os
from typing import Any, Dict, List

import torch
from sentence_transformers import SentenceTransformer

# Default decision threshold on the ConTeXT match score.
THRESHOLD = 0.48
# Default cap on the number of skills returned per input sentence.
MAX_SKILLS = 50
# Number of input sentences encoded per forward pass.
INTERNAL_BATCH = 32


class EndpointHandler:
    """Serves ESCO skill-extraction requests using ConTeXT-match scoring."""

    def __init__(self, path: str = ""):
        """Load the sentence-transformer and precomputed ESCO skill data.

        Args:
            path: Root of the deployed model repository. Must contain the
                transformer weights and ``esco_data.pt`` with keys
                ``labels``, ``uris`` and ``embeddings``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(path, device=self.device)
        if self.device == "cuda":
            # fp16 halves memory and speeds inference on GPU.
            self.model = self.model.half()

        # NOTE(review): weights_only=False unpickles arbitrary objects; this
        # is only acceptable because esco_data.pt ships inside the model repo
        # itself. Never point this at untrusted files.
        esco = torch.load(
            os.path.join(path, "esco_data.pt"),
            map_location="cpu",
            weights_only=False,
        )
        self.esco_labels = esco["labels"]
        self.esco_uris = esco["uris"]
        self.skill_embs = esco["embeddings"].to(self.device)
        if self.device == "cuda":
            self.skill_embs = self.skill_embs.half()
        # Precompute skill-embedding norms once; reused for every cosine.
        self.skill_norms = self.skill_embs.norm(dim=-1)

        print(f"[handler] ConTeXT ready — {len(self.esco_labels)} ESCO skills, "
              f"device={self.device}, batch={INTERNAL_BATCH}")

    def _encode_tokens(self, sentences):
        """Tokenize and encode sentences into per-token embeddings.

        Args:
            sentences: List of input strings.

        Returns:
            Tuple ``(tok_embs, mask)``: token embeddings of shape (B, T, d)
            and the attention mask of shape (B, T) (1 = real token, 0 = pad).
        """
        features = self.model.tokenize(sentences)
        features = {k: v.to(self.device) for k, v in features.items()}
        with torch.no_grad():
            # model[0] is the transformer module; it returns token-level
            # outputs before any pooling layer.
            out = self.model[0](features)
        tok_embs = out["token_embeddings"]
        mask = out["attention_mask"]
        if self.device == "cuda":
            tok_embs = tok_embs.half()
        return tok_embs, mask

    def _context_match(self, tok_embs, mask):
        """Score every ESCO skill against every sentence via ConTeXT-match.

        Args:
            tok_embs: (B, T, d) token embeddings.
            mask: (B, T) attention mask (1 = real token, 0 = padding).

        Returns:
            match_scores: (B, S) attention-weighted cosine match per skill.
            dots: (B, T, S) raw token–skill dot products, kept for the
                downstream redundancy filter.
        """
        # Raw affinities z_xj · z_s: (B, T, S).
        dots = torch.einsum("btd,sd->bts", tok_embs, self.skill_embs)

        tok_norms = tok_embs.norm(dim=-1, keepdim=True)           # (B, T, 1)
        skill_norms = self.skill_norms.unsqueeze(0).unsqueeze(0)  # (1, 1, S)
        cos_sim = dots / (tok_norms * skill_norms + 1e-8)

        attn_mask = mask.unsqueeze(-1).bool()                     # (B, T, 1)

        # Bug fix: padded positions may have near-zero-norm embeddings, so
        # cos_sim can be inf/NaN there (in fp16 the 1e-8 epsilon underflows
        # to 0 and guards nothing). alpha ≈ 0 at those positions does NOT
        # neutralize this, because 0 * inf = NaN and the NaN propagates
        # through the sum. Zero the masked cosines explicitly.
        cos_sim = cos_sim.masked_fill(~attn_mask, 0.0)

        # Softmax the raw dots over the token axis, with padding pushed to
        # the dtype minimum so it receives (numerically) zero attention.
        neg_inf = torch.finfo(dots.dtype).min
        dots_masked = dots.masked_fill(~attn_mask, neg_inf)
        alpha = torch.softmax(dots_masked, dim=1)                 # (B, T, S)

        match_scores = (alpha * cos_sim).sum(dim=1)               # (B, S)
        return match_scores, dots

    def _redundancy_filter(self, candidate_indices, dots_row, mask_row):
        """Keep only candidate skills that 'win' at least one content token.

        For each content token (excluding position 0, presumably [CLS], and
        the last valid position, presumably [SEP] — TODO confirm tokenizer
        layout), the candidate with the highest raw dot product wins that
        token. Candidates that win nothing are dropped; this de-duplicates
        near-synonymous ESCO entries latching onto the same span.

        Args:
            candidate_indices: 1-D tensor of skill indices above threshold.
            dots_row: (T, S) raw token–skill dot products for one sentence.
            mask_row: (T,) attention mask for that sentence.

        Returns:
            Filtered subset of ``candidate_indices`` (possibly unchanged).
        """
        valid_len = int(mask_row.sum().item())
        if valid_len <= 2:
            # Only special tokens present — no content to arbitrate with.
            return candidate_indices
        content_start, content_end = 1, valid_len - 1
        if content_start >= content_end:
            return candidate_indices
        content_dots = dots_row[content_start:content_end, :]
        cand_dots = content_dots[:, candidate_indices]  # (T_content, C)
        winners = cand_dots.argmax(dim=1)               # best candidate per token
        unique_winners = torch.unique(winners)
        return candidate_indices[unique_winners.cpu()]

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Handle an inference request.

        Args:
            data: Request payload. ``inputs`` (or legacy ``input``) is a
                string or list of strings; optional ``parameters`` may carry
                ``threshold``, ``max_skills`` and ``redundancy_filter``.

        Returns:
            One list per input sentence, each holding dicts with ``uri``,
            ``label`` and ``score`` (rounded to 4 decimals), sorted by
            descending score and capped at ``max_skills``.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        params = data.get("parameters", {})
        threshold = params.get("threshold", THRESHOLD)
        max_skills = params.get("max_skills", MAX_SKILLS)
        do_filter = params.get("redundancy_filter", True)

        all_results = []
        for i in range(0, len(inputs), INTERNAL_BATCH):
            batch = inputs[i : i + INTERNAL_BATCH]
            tok_embs, mask = self._encode_tokens(batch)
            match_scores, dots = self._context_match(tok_embs, mask)

            for j in range(len(batch)):
                scores = match_scores[j]
                above = (scores >= threshold).nonzero(as_tuple=True)[0]
                if len(above) == 0:
                    all_results.append([])
                    continue
                if do_filter and len(above) > 1:
                    above = self._redundancy_filter(above, dots[j], mask[j])

                skill_scores = scores[above]
                order = skill_scores.argsort(descending=True)[:max_skills]
                kept = above[order]
                kept_scores = scores[kept]
                all_results.append([
                    {
                        "uri": self.esco_uris[int(idx)],
                        "label": self.esco_labels[int(idx)],
                        "score": round(float(sc), 4),
                    }
                    for idx, sc in zip(kept.cpu(), kept_scores.cpu())
                ])
        return all_results