import io
from typing import Dict, List, Any

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for an image-to-text model.

    Loads a custom (``trust_remote_code``) processor/model pair from *path*
    and serves single-image inference through ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model and move the model to GPU if available.

        Args:
            path: Local directory (or hub repo id) containing the model.
        """
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on a single image.

        Args:
            data: Request payload. The image is expected under the "inputs"
                key, either as a ``PIL.Image.Image`` or as raw image bytes.

        Returns:
            A one-element list: ``[{"generated_text": <decoded prediction>}]``.

        Raises:
            ValueError: If the payload is neither a PIL image nor bytes.
        """
        inputs_data = data.pop("inputs", data)

        # Coerce the payload to a PIL image. Inference Endpoints commonly
        # deliver the raw POSTed body as bytes; the original code had a
        # no-op `pass` here and crashed later inside the processor with an
        # opaque error for any non-PIL payload.
        if not isinstance(inputs_data, Image.Image):
            if isinstance(inputs_data, (bytes, bytearray)):
                inputs_data = Image.open(io.BytesIO(inputs_data)).convert("RGB")
            else:
                raise ValueError(
                    f"Unsupported input type: {type(inputs_data).__name__}; "
                    "expected a PIL image or raw image bytes under 'inputs'."
                )

        # assumes the custom processor returns a dict with a torch tensor
        # under "pixel_values" — TODO confirm against the repo's remote code.
        processed_inputs = self.processor(inputs_data)
        pixel_values = processed_inputs["pixel_values"].to(self.device)

        # Inference only — no autograd graph needed.
        with torch.no_grad():
            outputs = self.model(pixel_values)
            logits = outputs.logits

        # NOTE(review): batch_decode is applied to logits directly; most
        # tokenizers expect token ids (argmax of logits). Presumably the
        # custom processor accepts logits — verify against the model repo
        # before changing.
        prediction = self.processor.batch_decode(logits)[0]

        return [{"generated_text": prediction}]