update handler class name
Browse files- handler.py +2 -2
handler.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import Dict, List, Any
|
|
| 2 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
| 3 |
|
| 4 |
|
| 5 |
-
class
|
| 6 |
def __init__(self, path=""):
|
| 7 |
self.processor = TrOCRProcessor.from_pretrained(path)
|
| 8 |
self.model = VisionEncoderDecoderModel.from_pretrained(path)
|
|
@@ -18,4 +18,4 @@ class PreTrainedPipeline():
|
|
| 18 |
|
| 19 |
# decode output
|
| 20 |
prediction = generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 21 |
-
return {"text":prediction[0]}
|
|
|
|
| 2 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
| 3 |
|
| 4 |
|
| 5 |
+
class EndpointHandler():
|
| 6 |
def __init__(self, path=""):
|
| 7 |
self.processor = TrOCRProcessor.from_pretrained(path)
|
| 8 |
self.model = VisionEncoderDecoderModel.from_pretrained(path)
|
|
|
|
| 18 |
|
| 19 |
# decode output
|
| 20 |
prediction = generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 21 |
+
return {"text":prediction[0]}
|