File size: 4,468 Bytes
f8291b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Custom handler for ConTeXT ESCO skill mapping on HF Inference Endpoints.
Implements the ConTeXT-match mechanism (Decorte et al. 2025):
  match(x, s) = Σ_j α_j · cos(x_j, s)   where α_j = softmax(z_xj · z_s)
"""
from typing import Dict, List, Any
import os
import torch
from sentence_transformers import SentenceTransformer


# Default minimum ConTeXT match score for a skill to be returned
# (overridable per request via parameters["threshold"]).
THRESHOLD = 0.48
# Default hard cap on the number of skills returned per input text
# (overridable per request via parameters["max_skills"]).
MAX_SKILLS = 50
# Number of input texts encoded per forward pass.
INTERNAL_BATCH = 32


class EndpointHandler:
    """HF Inference Endpoints handler for ConTeXT ESCO skill extraction.

    Loads a SentenceTransformer checkpoint together with a precomputed
    bundle of ESCO skill data (labels, URIs, sentence embeddings) and, per
    request, scores every input text against every skill with the
    ConTeXT-match score: match(x, s) = sum_j alpha_j * cos(x_j, s),
    where alpha_j = softmax(z_xj . z_s) over the text's tokens.
    """

    def __init__(self, path=""):
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"

        # The encoder; fp16 on GPU halves memory for both model and activations.
        self.model = SentenceTransformer(path, device=self.device)
        if use_cuda:
            self.model = self.model.half()

        # Precomputed ESCO bundle shipped alongside the model weights.
        # NOTE(review): weights_only=False unpickles arbitrary objects; the
        # file comes from the model repo itself, so it is treated as trusted.
        bundle = torch.load(os.path.join(path, "esco_data.pt"), map_location="cpu",
                            weights_only=False)
        self.esco_uris = bundle["uris"]
        self.esco_labels = bundle["labels"]
        self.skill_embs = bundle["embeddings"].to(self.device)
        if use_cuda:
            self.skill_embs = self.skill_embs.half()
        # Per-skill L2 norms, cached once and reused for every cosine.
        self.skill_norms = self.skill_embs.norm(dim=-1)
        print(f"[handler] ConTeXT ready — {len(self.esco_labels)} ESCO skills, "
              f"device={self.device}, batch={INTERNAL_BATCH}")

    def _encode_tokens(self, sentences):
        """Tokenize *sentences* and return (token_embeddings, attention_mask)."""
        feats = self.model.tokenize(sentences)
        feats = {name: tensor.to(self.device) for name, tensor in feats.items()}
        # Only the transformer module (index 0) is run: we need per-token
        # embeddings here, not the pooled sentence vector.
        with torch.no_grad():
            encoded = self.model[0](feats)
        mask = encoded["attention_mask"]
        token_vecs = encoded["token_embeddings"]
        if self.device == "cuda":
            token_vecs = token_vecs.half()
        return token_vecs, mask

    def _context_match(self, tok_embs, mask):
        """Compute ConTeXT match scores for a batch of token embeddings.

        Returns (match_scores [B, S], raw_dots [B, T, S]); the raw dot
        products are reused downstream by the redundancy filter.
        """
        # z_xj . z_s for every (batch, token, skill) triple.
        raw = torch.einsum("btd,sd->bts", tok_embs, self.skill_embs)
        # Cosine similarities derived from the same dot products.
        tok_norms = tok_embs.norm(dim=-1, keepdim=True)
        cos = raw / (tok_norms * self.skill_norms.unsqueeze(0).unsqueeze(0) + 1e-8)
        # Padding tokens must get zero attention weight, so their logits are
        # pushed to the dtype minimum before the softmax over the token axis.
        real = mask.unsqueeze(-1).bool()
        logits = raw.masked_fill(~real, torch.finfo(raw.dtype).min)
        alpha = torch.softmax(logits, dim=1)
        return (alpha * cos).sum(dim=1), raw

    def _redundancy_filter(self, candidate_indices, dots_row, mask_row):
        """Drop candidate skills that never "win" a content token.

        A candidate survives only if it has the highest raw dot product with
        at least one non-special token (the first and last valid positions —
        presumably CLS/SEP — are excluded from the competition).
        """
        valid_len = int(mask_row.sum().item())
        # With <= 2 valid tokens there are no content positions to compete on.
        if valid_len <= 2:
            return candidate_indices
        lo, hi = 1, valid_len - 1
        if lo >= hi:
            return candidate_indices
        per_token = dots_row[lo:hi, :][:, candidate_indices]
        winners = torch.unique(per_token.argmax(dim=1))
        return candidate_indices[winners.cpu()]

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Handle one inference request.

        Expects {"inputs": str | list[str], "parameters": {...}} and returns,
        per input text, a list of {"uri", "label", "score"} dicts sorted by
        descending score.
        """
        texts = data.get("inputs", data.get("input", ""))
        if isinstance(texts, str):
            texts = [texts]

        opts = data.get("parameters", {})
        threshold = opts.get("threshold", THRESHOLD)
        max_skills = opts.get("max_skills", MAX_SKILLS)
        do_filter = opts.get("redundancy_filter", True)

        results: List[List[Dict[str, Any]]] = []
        for start in range(0, len(texts), INTERNAL_BATCH):
            chunk = texts[start : start + INTERNAL_BATCH]
            token_vecs, mask = self._encode_tokens(chunk)
            match_scores, raw_dots = self._context_match(token_vecs, mask)

            for row in range(len(chunk)):
                scores = match_scores[row]
                candidates = (scores >= threshold).nonzero(as_tuple=True)[0]

                if candidates.numel() == 0:
                    results.append([])
                    continue

                if do_filter and candidates.numel() > 1:
                    candidates = self._redundancy_filter(
                        candidates, raw_dots[row], mask[row])

                ranking = scores[candidates].argsort(descending=True)[:max_skills]
                kept = candidates[ranking]
                entries = []
                for idx, val in zip(kept.cpu(), scores[kept].cpu()):
                    entries.append({
                        "uri": self.esco_uris[int(idx)],
                        "label": self.esco_labels[int(idx)],
                        "score": round(float(val), 4),
                    })
                results.append(entries)

        return results