ace-1 commited on
Commit
e890fdc
·
verified ·
1 Parent(s): 86c8190

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -15
handler.py CHANGED
@@ -1,20 +1,18 @@
1
  from typing import Dict, List, Any
2
  from transformers import pipeline
3
 
4
- # Load the model and tokenizer at startup
5
- classifier = pipeline("text-classification")
 
 
6
 
7
- def preprocess(inputs: Dict[str, Any]) -> List[str]:
8
- # Hugging Face Inference API sends {"inputs": ...}
9
- if isinstance(inputs["inputs"], list):
10
- return inputs["inputs"]
11
- return [inputs["inputs"]]
12
 
13
- def inference(inputs: Dict[str, Any]) -> List[Dict[str, Any]]:
14
- texts = preprocess(inputs)
15
- results = classifier(texts)
16
- return results
17
-
18
- # The entrypoint for the Inference API
19
- def handle(inputs: Dict[str, Any], context: Dict[str, Any] = None) -> Any:
20
- return inference(inputs)
 
1
  from typing import Dict, List, Any
2
  from transformers import pipeline
3
 
4
+ class EndpointHandler:
5
+ def __init__(self, path=""):
6
+ # Load the model and tokenizer at startup
7
+ self.classifier = pipeline("text-classification", model=path if path else None)
8
 
9
+ def preprocess(self, inputs: Dict[str, Any]) -> List[str]:
10
+ # Hugging Face Inference API sends {"inputs": ...}
11
+ if isinstance(inputs["inputs"], list):
12
+ return inputs["inputs"]
13
+ return [inputs["inputs"]]
14
 
15
+ def __call__(self, inputs: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ texts = self.preprocess(inputs)
17
+ results = self.classifier(texts)
18
+ return results