dejanseo commited on
Commit
6f1c63c
·
verified ·
1 Parent(s): 6e375a5

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +61 -0
handler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import os
5
+ import json
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, path: str = ""):
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
11
+ self.tokenizer.add_special_tokens({
12
+ "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
13
+ })
14
+ self.model = AutoModel.from_pretrained(path).to(self.device)
15
+
16
+ head_path = os.path.join(path, "classifier_head.json")
17
+ with open(head_path, "r") as f:
18
+ head = json.load(f)
19
+
20
+ self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
21
+ self.classifier.weight.data = torch.tensor(head["scorer_weight"]).to(self.device)
22
+ self.classifier.bias.data = torch.tensor(head["scorer_bias"]).to(self.device)
23
+
24
+ self.model.eval()
25
+
26
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
+ """
28
+ Expected input format:
29
+ {
30
+ "query": "how to sharpen kitchen knives",
31
+ "candidates": [
32
+ {"label": "Tool-Specific", "description": "..."},
33
+ {"label": "Local Intent", "description": "..."}
34
+ ]
35
+ }
36
+ """
37
+ query = data["query"]
38
+ candidates = data["candidates"]
39
+ results = []
40
+
41
+ with torch.no_grad():
42
+ for entry in candidates:
43
+ text = f"[QUERY] {query} [LABEL_NAME] {entry['label']} [LABEL_DESCRIPTION] {entry['description']}"
44
+ tokens = self.tokenizer(
45
+ text,
46
+ return_tensors="pt",
47
+ padding="max_length",
48
+ truncation=True,
49
+ max_length=64
50
+ ).to(self.device)
51
+
52
+ out = self.model(**tokens)
53
+ cls = out.last_hidden_state[:, 0, :]
54
+ score = torch.sigmoid(self.classifier(cls)).item()
55
+ results.append({
56
+ "label": entry["label"],
57
+ "description": entry["description"],
58
+ "score": round(score, 4)
59
+ })
60
+
61
+ return sorted(results, key=lambda x: x["score"], reverse=True)