PyTorch
bart
jsoars committed on
Commit
e8d0baa
·
verified ·
1 Parent(s): 45a698d

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +56 -0
handler.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+
6
class EndpointHandler:
    """Callable handler that runs a seq2seq model and splits its output into keywords.

    The model output is treated as a ``;``-separated list. The response
    contains both the raw generated text and the de-duplicated keyword list.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from ``path`` and prepare them for inference.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        # Use the GPU when one is visible, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret ``value`` as a flag, so the string ``"false"`` is not truthy."""
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return bool(value)

    @staticmethod
    def _split_keywords(raw_text: str) -> List[str]:
        """Split ``raw_text`` on ``;`` and drop blanks and case-insensitive duplicates.

        The first spelling of each keyword is kept, in order of appearance.
        """
        seen = set()
        keywords: List[str] = []
        for piece in raw_text.split(";"):
            keyword = piece.strip()
            if not keyword:
                continue
            key = keyword.lower()
            if key not in seen:
                seen.add(key)
                keywords.append(keyword)
        return keywords

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate text for ``data["inputs"]`` and return it with extracted keywords.

        Args:
            data: Request payload. ``"inputs"`` is required. ``"parameters"``
                is optional and may carry ``max_input_length`` (default 1024),
                ``max_new_tokens`` (48), ``num_beams`` (4), ``do_sample``
                (False), ``temperature`` (1.0, used only when sampling) and
                ``no_repeat_ngram_size`` (3).

        Returns:
            ``{"generated_text": str, "keywords": list[str]}`` on success, or
            ``{"error": str}`` when ``"inputs"`` is missing. Only the first
            generated sequence is decoded.
        """
        text = data.get("inputs")
        if text is None:
            return {"error": "Missing required field: inputs"}

        # ``parameters`` may be absent *or* an explicit null; both mean "defaults".
        parameters = data.get("parameters") or {}

        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=int(parameters.get("max_input_length", 1024)),
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        num_beams = int(parameters.get("num_beams", 4))
        do_sample = self._as_bool(parameters.get("do_sample", False))

        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": int(parameters.get("max_new_tokens", 48)),
            "num_beams": num_beams,
            "do_sample": do_sample,
            "no_repeat_ngram_size": int(parameters.get("no_repeat_ngram_size", 3)),
        }
        # Only forward flags that apply to the selected decoding mode; passing
        # them otherwise has no effect and makes ``generate`` emit warnings.
        if do_sample:
            generate_kwargs["temperature"] = float(parameters.get("temperature", 1.0))
        if num_beams > 1:
            generate_kwargs["early_stopping"] = True

        with torch.inference_mode():
            output_ids = self.model.generate(**encoded, **generate_kwargs)

        raw_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

        return {
            "generated_text": raw_text,
            "keywords": self._split_keywords(raw_text),
        }