convexray committed on
Commit
bbddc92
·
verified ·
1 Parent(s): 088dc17

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +27 -6
handler.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import base64
6
  import io
7
 
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
  """Called when the endpoint starts. Load model and processor."""
@@ -14,7 +15,10 @@ class EndpointHandler:
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  self.model.to(self.device)
16
  self.model.eval()
17
-
 
 
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  """
20
  Called on every request.
@@ -22,24 +26,41 @@ class EndpointHandler:
22
  Args:
23
  data: Dictionary containing:
24
  - inputs: base64 encoded image string
25
- - parameters (optional): generation params like max_new_tokens
26
-
 
 
27
  Returns:
28
  List containing the generated table text
29
  """
30
  inputs = data.get("inputs")
31
  parameters = data.get("parameters", {})
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Decode base64 image
34
  if isinstance(inputs, str):
35
- image_bytes = base64.b64decode(inputs)
36
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
37
  else:
38
  raise ValueError("Expected base64 encoded image string in 'inputs'")
39
 
40
- # Process image
41
  model_inputs = self.processor(
42
  images=image,
 
43
  return_tensors="pt"
44
  ).to(self.device)
45
 
 
5
  import base64
6
  import io
7
 
8
+
9
  class EndpointHandler:
10
  def __init__(self, path: str = ""):
11
  """Called when the endpoint starts. Load model and processor."""
 
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  self.model.to(self.device)
17
  self.model.eval()
18
+
19
+ # Default prompt for DePlot
20
+ self.default_header = "Generate underlying data table of the figure below:"
21
+
22
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
23
  """
24
  Called on every request.
 
26
  Args:
27
  data: Dictionary containing:
28
  - inputs: base64 encoded image string
29
+ - parameters (optional): dict with:
30
+ - header: text prompt for the model (default: DePlot prompt)
31
+ - max_new_tokens: max generation length (default: 512)
32
+
33
  Returns:
34
  List containing the generated table text
35
  """
36
  inputs = data.get("inputs")
37
  parameters = data.get("parameters", {})
38
 
39
+ # Get header text - check multiple possible keys
40
+ header_text = (
41
+ parameters.get("header") or
42
+ parameters.get("text") or
43
+ parameters.get("prompt") or
44
+ data.get("header") or
45
+ data.get("text") or
46
+ data.get("prompt") or
47
+ self.default_header
48
+ )
49
+
50
  # Decode base64 image
51
  if isinstance(inputs, str):
52
+ try:
53
+ image_bytes = base64.b64decode(inputs)
54
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
55
+ except Exception as e:
56
+ raise ValueError(f"Failed to decode base64 image: {e}")
57
  else:
58
  raise ValueError("Expected base64 encoded image string in 'inputs'")
59
 
60
+ # Process image WITH header text (required for Pix2Struct!)
61
  model_inputs = self.processor(
62
  images=image,
63
+ text=header_text, # <-- THIS WAS MISSING
64
  return_tensors="pt"
65
  ).to(self.device)
66