slapmack committed on
Commit
20efac6
·
1 Parent(s): d6978ab

attempted batching

Browse files
Files changed (2) hide show
  1. handler.py +27 -13
  2. handler_test.py +14 -5
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Any
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
  from torch.cuda.amp import autocast
@@ -13,7 +13,9 @@ class EndpointHandler:
13
  self.tokenizer = AutoTokenizer.from_pretrained(path)
14
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device).half()
15
 
16
- def process_chunks(self, chunks: list, titles: list, dates: list) -> list:
 
 
17
  """
18
  Process multiple text chunks with the model.
19
 
@@ -59,7 +61,7 @@ class EndpointHandler:
59
 
60
  return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
61
 
62
- def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
63
  """
64
  Handle the inference request.
65
 
@@ -67,15 +69,27 @@ class EndpointHandler:
67
  data (dict): The payload with text inputs.
68
 
69
  Returns:
70
- dict: The processed outputs containing the generated text.
71
  """
72
- inputs = data.pop("inputs", {})
73
- missing_keys = [key for key in ["chunk", "title", "date"] if key not in inputs]
74
- if missing_keys:
75
- raise ValueError(
76
- f"The inputs dictionary is missing required keys: {', '.join(missing_keys)}."
77
- )
78
 
79
- chunk, title, date = inputs["chunk"], inputs["title"], inputs["date"]
80
- prediction = self.process_chunks([chunk], [title], [date])[0]
81
- return {"generated_text": prediction}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
  from torch.cuda.amp import autocast
 
13
  self.tokenizer = AutoTokenizer.from_pretrained(path)
14
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device).half()
15
 
16
+ def process_chunks(
17
+ self, chunks: List[str], titles: List[str], dates: List[str]
18
+ ) -> List[str]:
19
  """
20
  Process multiple text chunks with the model.
21
 
 
61
 
62
  return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
63
 
64
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
65
  """
66
  Handle the inference request.
67
 
 
69
  data (dict): The payload with text inputs.
70
 
71
  Returns:
72
+ dict: The processed outputs containing the generated text for each input along with their IDs.
73
  """
74
+ inputs = data.get("inputs", [])
 
 
 
 
 
75
 
76
+ # Ensure inputs is a list of dictionaries
77
+ if not isinstance(inputs, list) or not all(isinstance(i, dict) for i in inputs):
78
+ raise ValueError("The inputs must be a list of dictionaries.")
79
+
80
+ chunks, titles, dates, ids = [], [], [], []
81
+ for item in inputs:
82
+ for key in ["id", "chunk", "title", "date"]:
83
+ if key not in item:
84
+ raise ValueError(f"Each input must contain the key: {key}.")
85
+ ids.append(item["id"])
86
+ chunks.append(item["chunk"])
87
+ titles.append(item["title"])
88
+ dates.append(item["date"])
89
+
90
+ predictions = self.process_chunks(chunks, titles, dates)
91
+ result = [
92
+ {"id": id_, "generated_text": prediction}
93
+ for id_, prediction in zip(ids, predictions)
94
+ ]
95
+ return {"results": result}
handler_test.py CHANGED
@@ -5,11 +5,20 @@ my_handler = EndpointHandler(path=".")
5
 
6
  # Example payload
7
  data = {
8
- "inputs": {
9
- "chunk": "Prior to restoration work performed between 1990 and 2001, Leaning Tower of Pisa leaned at an angle of 5.5 degrees, but the tower now leans at about 3.99 degrees. This means the top of the tower is displaced horizontally 3.9 meters (12 ft 10 in) from the center.",
10
- "title": "Leaning Tower of Pisa",
11
- "date": "2025-01-15 12:22:44",
12
- },
 
 
 
 
 
 
 
 
 
13
  }
14
 
15
  # Call the handler and print the output
 
5
 
6
  # Example payload
7
  data = {
8
+ "inputs": [
9
+ {
10
+ "id": "1",
11
+ "chunk": "Prior to restoration work performed between 1990 and 2001, Leaning Tower of Pisa leaned at an angle of 5.5 degrees, but the tower now leans at about 3.99 degrees. This means the top of the tower is displaced horizontally 3.9 meters (12 ft 10 in) from the center.",
12
+ "title": "Leaning Tower of Pisa",
13
+ "date": "2023-01-01",
14
+ },
15
+ {
16
+ "id": "2",
17
+ "chunk": "Prior to restoration work performed between 1990 and 2001, Leaning Tower of Pisa leaned at an angle of 5.5 degrees, but the tower now leans at about 3.99 degrees. This means the top of the tower is displaced horizontally 3.9 meters (12 ft 10 in) from the center.",
18
+ "title": "Leaning Tower of Pisa",
19
+ "date": "2023-01-02",
20
+ },
21
+ ]
22
  }
23
 
24
  # Call the handler and print the output