stuckdavis committed on
Commit
35c43d7
·
verified ·
1 Parent(s): cd4e30a

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +62 -0
handler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LongformerTokenizer, LongformerForSequenceClassification
2
+ import torch
3
+ from ts.torch_handler.base_handler import BaseHandler
4
+ from safetensors.torch import load_file
5
+ import os
6
+
7
+ class LongformerRegressionHandler(BaseHandler):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.initialized = False
11
+
12
+ def initialize(self, ctx):
13
+ """Load model and tokenizer"""
14
+ properties = ctx.system_properties
15
+ model_dir = properties.get("model_dir")
16
+
17
+ # Load tokenizer and config
18
+ self.tokenizer = LongformerTokenizer.from_pretrained(model_dir)
19
+ self.model = LongformerForSequenceClassification.from_pretrained(model_dir)
20
+
21
+ # Load safetensors weights
22
+ weights_path = os.path.join(model_dir, "model.safetensors")
23
+ state_dict = load_file(weights_path)
24
+ self.model.load_state_dict(state_dict)
25
+
26
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ self.model.to(self.device)
28
+ self.model.eval()
29
+ self.initialized = True
30
+
31
+ def preprocess(self, requests):
32
+ """Convert raw text into model-ready inputs"""
33
+ inputs = []
34
+ for req in requests:
35
+ text = req.get("data") or req.get("body")
36
+ if isinstance(text, (bytes, bytearray)):
37
+ text = text.decode("utf-8")
38
+ tokens = self.tokenizer(
39
+ text,
40
+ padding="max_length",
41
+ truncation=True,
42
+ max_length=512,
43
+ return_tensors="pt"
44
+ )
45
+ tokens = {k: v.to(self.device) for k, v in tokens.items()}
46
+ inputs.append(tokens)
47
+ return inputs
48
+
49
+ def inference(self, inputs):
50
+ """Run forward pass and return clipped regression output"""
51
+ results = []
52
+ with torch.no_grad():
53
+ for inp in inputs:
54
+ output = self.model(**inp)
55
+ score = output.logits.squeeze().item()
56
+ clipped_score = min(max(score, 0.0), 1.0)
57
+ results.append(clipped_score)
58
+ return results
59
+
60
+ def postprocess(self, inference_output):
61
+ """Convert scores to response-friendly format"""
62
+ return [{"score": float(out)} for out in inference_output]