| """ |
| Custom handler for ConTeXT ESCO skill mapping on HF Inference Endpoints. |
| Implements the ConTeXT-match mechanism (Decorte et al. 2025): |
| match(x, s) = Σ_j α_j · cos(x_j, s) where α_j = softmax(z_xj · z_s) |
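
Example request (an illustrative sketch: the payload schema follows the
handler below; the sentence and parameter values are made up):

    {"inputs": ["Built ETL pipelines in Python and SQL"],
     "parameters": {"threshold": 0.5, "max_skills": 10}}

Each input sentence yields a list of {"uri", "label", "score"} dicts.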
| """ |
import os
from typing import Any, Dict, List

import torch
from sentence_transformers import SentenceTransformer

# Default decision threshold on the match score; callers may override it
# per request via parameters.threshold.
THRESHOLD = 0.48
# Hard cap on the number of skills returned per sentence.
MAX_SKILLS = 50
# Number of sentences encoded per forward pass.
INTERNAL_BATCH = 32


class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
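
        # Bi-encoder that produces the token embeddings; fp16 on GPU halves
        # memory and speeds up the dense token-skill similarities below.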
        self.model = SentenceTransformer(path, device=self.device)
        if self.device == "cuda":
            self.model = self.model.half()
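
        # esco_data.pt bundles the ESCO label strings, their URIs, and the
        # precomputed skill embeddings (num_skills x hidden_dim).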
        esco = torch.load(os.path.join(path, "esco_data.pt"), map_location="cpu",
                          weights_only=False)
        self.esco_labels = esco["labels"]
        self.esco_uris = esco["uris"]
        self.skill_embs = esco["embeddings"].to(self.device)
        if self.device == "cuda":
            self.skill_embs = self.skill_embs.half()
        # Precompute skill-embedding norms once; they are reused for every
        # cosine similarity in _context_match.
        self.skill_norms = self.skill_embs.norm(dim=-1)
        print(f"[handler] ConTeXT ready — {len(self.esco_labels)} ESCO skills, "
              f"device={self.device}, batch={INTERNAL_BATCH}")

    def _encode_tokens(self, sentences):
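        # Tokenize, then run only the transformer module (self.model[0]) so
        # we get per-token embeddings rather than the pooled sentence vector.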
        features = self.model.tokenize(sentences)
        features = {k: v.to(self.device) for k, v in features.items()}
        with torch.no_grad():
            out = self.model[0](features)
        tok_embs = out["token_embeddings"]
        mask = out["attention_mask"]
        if self.device == "cuda":
            tok_embs = tok_embs.half()
        return tok_embs, mask

    def _context_match(self, tok_embs, mask):
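        # Implements the docstring formula: α is a softmax over token-skill
        # dot products, and the match score is the α-weighted sum of
        # token-skill cosine similarities.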
        # dots[b, t, s] = x_t · s for every token-skill pair.
        dots = torch.einsum("btd,sd->bts", tok_embs, self.skill_embs)
        tok_norms = tok_embs.norm(dim=-1, keepdim=True)
        skill_norms = self.skill_norms.unsqueeze(0).unsqueeze(0)
        cos_sim = dots / (tok_norms * skill_norms + 1e-8)
        # Padding tokens must receive zero attention: fill their logits with
        # the most negative representable value before the softmax over tokens.
        neg_inf = torch.finfo(dots.dtype).min
        attn_mask = mask.unsqueeze(-1).bool()
        dots_masked = dots.masked_fill(~attn_mask, neg_inf)
        alpha = torch.softmax(dots_masked, dim=1)
        match_scores = (alpha * cos_sim).sum(dim=1)
        # The raw dots are returned too; _redundancy_filter reuses them.
        return match_scores, dots

    def _redundancy_filter(self, candidate_indices, dots_row, mask_row):
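        # Keep only candidates that "win" at least one content token, i.e.
        # that are some token's strongest skill by raw dot product; the
        # special tokens ([CLS]/[SEP]) are excluded from the vote.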
        valid_len = int(mask_row.sum().item())
        if valid_len <= 2:
            return candidate_indices
        content_start, content_end = 1, valid_len - 1
        if content_start >= content_end:
            return candidate_indices
        content_dots = dots_row[content_start:content_end, :]
        cand_dots = content_dots[:, candidate_indices]
        winners = cand_dots.argmax(dim=1)
        unique_winners = torch.unique(winners)
        # unique_winners already lives on the same device as the candidates,
        # so no CPU round-trip is needed before indexing.
        return candidate_indices[unique_winners]

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
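        # HF Inference Endpoints entry point: "inputs" is a string or a list
        # of strings; optional "parameters" override the module defaults.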
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        params = data.get("parameters", {})
        threshold = params.get("threshold", THRESHOLD)
        max_skills = params.get("max_skills", MAX_SKILLS)
        do_filter = params.get("redundancy_filter", True)

        all_results = []
        for i in range(0, len(inputs), INTERNAL_BATCH):
            batch = inputs[i : i + INTERNAL_BATCH]
            tok_embs, mask = self._encode_tokens(batch)
            match_scores, dots = self._context_match(tok_embs, mask)
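
            # Candidate selection runs per sentence so that each result list
            # lines up with its input.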
            for j in range(len(batch)):
                scores = match_scores[j]
                above = (scores >= threshold).nonzero(as_tuple=True)[0]

                if len(above) == 0:
                    all_results.append([])
                    continue

                if do_filter and len(above) > 1:
                    above = self._redundancy_filter(above, dots[j], mask[j])

                # Rank the surviving candidates by score, keep the top ones.
                skill_scores = scores[above]
                order = skill_scores.argsort(descending=True)[:max_skills]
                kept = above[order]
                kept_scores = scores[kept]

                all_results.append([
                    {
                        "uri": self.esco_uris[int(idx)],
                        "label": self.esco_labels[int(idx)],
                        "score": round(float(sc), 4),
                    }
                    for idx, sc in zip(kept.cpu(), kept_scores.cpu())
                ])

        return all_results
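

# Minimal local smoke test (a sketch, not part of the endpoint contract).
# ASSUMPTION: "./model" holds the sentence-transformers checkpoint together
# with esco_data.pt; adjust the path to your artifact layout.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    demo = {"inputs": ["Managed a team of nurses in an intensive care unit"]}
    for sentence_skills in handler(demo):
        for skill in sentence_skills:
            print(skill["score"], skill["label"], skill["uri"])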