stuckdavis committed on
Commit
d67623d
·
verified ·
1 Parent(s): 1ef7ea2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -40
handler.py CHANGED
@@ -1,47 +1,45 @@
1
- from typing import Dict, List, Any
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
 
5
class EndpointHandler():
    """Inference Endpoints handler for a Longformer-style regression model.

    Loads a single-label (num_labels=1) sequence-classification model in
    regression mode and scores one longform text per request.
    """

    def __init__(self, path=""):
        # Load the model and tokenizer from the given path (default: cwd).
        self.model = AutoModelForSequenceClassification.from_pretrained(
            path if path else ".",
            num_labels=1,  # single output head -> regression task
            problem_type="regression"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            path if path else ".",
            use_fast=False  # Use the slow tokenizer
        )
        # Inference only: disable dropout etc.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj: `str`): The longform text input to analyze
        Return:
            A :obj:`dict`: containing the regression prediction
        """
        # Read the input text WITHOUT mutating the caller's payload.
        # (Previously this used data.pop(), which destructively removed the
        # "inputs" key from the request dict owned by the caller.)
        inputs = data.get("inputs", data)

        # Tokenize the input (max_length=4096 — Longformer's max length).
        tokenized = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=4096,
            return_tensors="pt"
        )

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(**tokenized)
            prediction = outputs.logits.item()  # Single regression value

        return {
            "prediction": prediction,
            "confidence": 1.0,  # Not applicable for regression
            "raw_scores": [prediction]  # Just the regression score
        }
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
2
  import torch
3
 
4
class EndpointHandler:
    """Inference Endpoints handler: scores texts with a regression model.

    Accepts {"inputs": "text"} or {"inputs": ["a", "b", ...]} and returns
    one {"score": float} dict per input, with scores clipped to [0, 1].
    """

    def __init__(self, path=""):
        # Load model and tokenizer from the repo path
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()  # inference mode: disable dropout etc.
        # Run on GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, data):
        """
        This method is called when the endpoint receives a request.
        Expected input: { "inputs": "some string" } or { "inputs": ["a", "b", ...] }

        Returns a list of {"score": float} dicts (one per input text),
        or {"error": ...} when no input was provided.
        """
        inputs = data.get("inputs", None)

        if inputs is None:
            return {"error": "No input provided"}

        # Normalize a single string to a one-element batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        results = []
        for text in inputs:
            # Texts are scored one at a time, so no padding is required.
            # (Previously every input was padded to max_length=4096, which
            # wasted compute; with an attention mask the padded positions
            # should not affect the logits for the real tokens.)
            encoded = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=4096,  # Longformer's maximum sequence length
            )
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            # Forward pass without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**encoded)

            raw_score = outputs.logits.squeeze().item()
            # Clamp to [0, 1]; a regression head is otherwise unbounded.
            clipped_score = min(max(raw_score, 0.0), 1.0)

            results.append({"score": round(clipped_score, 4)})

        return results