dreyyyy commited on
Commit
4dd3e86
·
verified ·
1 Parent(s): 8c34feb

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -20
handler.py CHANGED
@@ -1,31 +1,35 @@
1
- import os
2
  import easyocr
 
 
 
 
 
3
 
4
  class EndpointHandler:
5
- def __init__(self, model_dir):
6
- """
7
- Initialize the handler and load the model.
8
- """
9
- self.model_dir = model_dir
10
- # Load the EasyOCR reader
11
- self.reader = easyocr.Reader(['en'], gpu=False) # Use GPU=True if GPU is available
12
 
13
- def __call__(self, data):
14
  """
15
- Handle the inference request.
16
  Args:
17
- data (dict): Input data for the model.
 
18
  Returns:
19
- dict: Model predictions.
20
  """
21
- # Extract the image path or image content from the data
22
- image_path = data.get("inputs")
23
- if not image_path:
24
  return {"error": "No input image provided"}
25
 
 
 
 
26
  # Perform OCR
27
- try:
28
- results = self.reader.readtext(image_path, detail=0)
29
- return {"predictions": results}
30
- except Exception as e:
31
- return {"error": str(e)}
 
 
1
  import easyocr
2
+ import torch
3
+ from typing import Dict
4
+ from PIL import Image
5
+ import io
6
+ import json
7
 
8
  class EndpointHandler:
9
+ def __init__(self, model_dir: str):
10
+ # Path to your model file
11
+ model_path = f"{model_dir}/easyocr_reader.pkl"
12
+ self.reader = torch.load(model_path) # Load the EasyOCR model
 
 
 
13
 
14
+ def __call__(self, data: Dict):
15
  """
16
+ Perform inference on the input image.
17
  Args:
18
+ data (Dict): Input dictionary with keys:
19
+ - `inputs` containing image bytes or a file.
20
  Returns:
21
+ str: Extracted text.
22
  """
23
+ # Check if the input is an image
24
+ if "inputs" not in data:
 
25
  return {"error": "No input image provided"}
26
 
27
+ image_bytes = data["inputs"]
28
+ image = Image.open(io.BytesIO(image_bytes))
29
+
30
  # Perform OCR
31
+ results = self.reader.readtext(image)
32
+
33
+ # Extract and return the detected text
34
+ extracted_text = " ".join([text for (_, text, _) in results])
35
+ return {"text": extracted_text}