bie-nhd commited on
Commit
a19b0dc
·
1 Parent(s): 65f125e

add all options to handler

Browse files
Files changed (1) hide show
  1. handler.py +21 -10
handler.py CHANGED
@@ -71,11 +71,7 @@ class EndpointHandler:
71
 
72
  return encoding
73
 
74
- def predict(self, data: Dict[str, Any], preprocessed: Dict[str, Any]) -> Dict[str, Any]:
75
- task = data.get('task', None)
76
-
77
- if task is None or task not in self.task_config.keys():
78
- raise ValueError(f"Invalid task: {task}")
79
 
80
  logits = self.task_heads[task](self.model(**preprocessed).last_hidden_state[:, 0, :])
81
  config = self.task_config[task]
@@ -89,19 +85,34 @@ class EndpointHandler:
89
  pred_idx = int(np.argmax(probs))
90
  return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}
91
 
92
- def postprocess(self, outputs: Dict[str, Any]) -> List[Dict[str, Any]]:
93
- return [outputs]
 
 
94
 
95
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
96
  task = data.get('task', None)
97
  print(f"Task: {task}")
98
  if task is None:
99
  raise ValueError("'task' key is required in the input dictionary")
100
- if task not in self.task_config.keys():
 
 
 
 
 
 
 
 
 
 
 
101
  raise ValueError(f"Invalid task: {task}")
102
  preprocessed = self.preprocess(data)
103
- outputs = self.predict(data, preprocessed)
104
- return self.postprocess(outputs)
 
 
105
 
106
  class TaskClassificationHead(torch.nn.Module):
107
  def __init__(self, hidden_size: int, num_labels: int, dropout: float):
 
71
 
72
  return encoding
73
 
74
+ def predict(self, task:str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
 
 
 
 
75
 
76
  logits = self.task_heads[task](self.model(**preprocessed).last_hidden_state[:, 0, :])
77
  config = self.task_config[task]
 
85
  pred_idx = int(np.argmax(probs))
86
  return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}
87
 
88
+ # def postprocess(self, outputs: Dict[str, Any]) -> List[Dict[str, Any]]:
89
+ # return [outputs]
90
+
91
+
92
 
93
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
94
  task = data.get('task', None)
95
  print(f"Task: {task}")
96
  if task is None:
97
  raise ValueError("'task' key is required in the input dictionary")
98
+ task = task.lower()
99
+
100
+ if task == "all":
101
+ results = {}
102
+
103
+ for _t in self.task_config.keys():
104
+ data['task'] = _t
105
+ preprocessed = self.preprocess(data)
106
+ outputs = self.predict(_t, preprocessed)
107
+ results[_t] = outputs
108
+ return results
109
+ elif task not in self.task_config.keys():
110
  raise ValueError(f"Invalid task: {task}")
111
  preprocessed = self.preprocess(data)
112
+
113
+ outputs = self.predict(task, preprocessed)
114
+ # return self.postprocess(outputs)
115
+ return outputs
116
 
117
  class TaskClassificationHead(torch.nn.Module):
118
  def __init__(self, hidden_size: int, num_labels: int, dropout: float):