# deplot/handler.py
# Author: convexray — "Update handler.py" (commit bbddc92, verified)
# Hub page metadata: raw / history / blame / 2.91 kB
from typing import Dict, List, Any
from PIL import Image
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
import torch
import base64
import io
class EndpointHandler:
    """Inference Endpoints handler for a DePlot (Pix2Struct) checkpoint.

    Decodes a base64-encoded chart image, pairs it with a text header
    (Pix2Struct requires text alongside the image), and returns the
    generated underlying data table as text.
    """

    def __init__(self, path: str = ""):
        """Called once when the endpoint starts. Load model and processor.

        Args:
            path: Local directory (or hub id) holding the pretrained weights.
        """
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)
        # Prefer GPU when one is visible; otherwise run on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        # Default prompt for DePlot: asks the model to emit the data table.
        self.default_header = "Generate underlying data table of the figure below:"

    def _resolve_header(self, data: Dict[str, Any], parameters: Dict[str, Any]) -> str:
        """Pick the text prompt from the request, checking common key names.

        Precedence: ``parameters`` keys over top-level ``data`` keys; within
        each, ``header`` > ``text`` > ``prompt``. Falls back to the DePlot
        default prompt when none are present (or all are falsy).
        """
        return (
            parameters.get("header")
            or parameters.get("text")
            or parameters.get("prompt")
            or data.get("header")
            or data.get("text")
            or data.get("prompt")
            or self.default_header
        )

    def _decode_image(self, inputs: Any):
        """Decode a base64 string (optionally a ``data:`` URI) into an RGB image.

        Returns:
            A PIL ``Image`` in RGB mode.

        Raises:
            ValueError: if ``inputs`` is not a string or cannot be decoded.
        """
        if not isinstance(inputs, str):
            raise ValueError("Expected base64 encoded image string in 'inputs'")
        # Generalization: accept "data:image/png;base64,...." URIs by
        # stripping everything up to and including the first comma.
        if inputs.startswith("data:") and "," in inputs:
            inputs = inputs.split(",", 1)[1]
        try:
            image_bytes = base64.b64decode(inputs)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to decode base64 image: {e}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64 encoded image string (plain or data URI)
                - parameters (optional): dict with:
                    - header / text / prompt: text prompt for the model
                      (default: DePlot prompt)
                    - max_new_tokens: max generation length (default: 512)

        Returns:
            List with one dict: ``{"generated_text": <table text>}``.

        Raises:
            ValueError: if the image payload is missing or undecodable.
        """
        # Guard against clients sending "parameters": null explicitly.
        parameters = data.get("parameters", {}) or {}
        header_text = self._resolve_header(data, parameters)
        image = self._decode_image(data.get("inputs"))
        # Pix2Struct requires the text header alongside the image; omitting
        # it produces degenerate output.
        model_inputs = self.processor(
            images=image,
            text=header_text,
            return_tensors="pt",
        ).to(self.device)
        # Coerce in case the client sent the value as a JSON string.
        max_new_tokens = int(parameters.get("max_new_tokens", 512))
        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )
        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True,
        )
        return [{"generated_text": output_text}]