Graf-J commited on
Commit
a99a61d
·
verified ·
1 Parent(s): a9d6fd8

Upload Handler for Widget

Browse files
Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoModel, AutoProcessor
5
+
6
class EndpointHandler:
    """Custom Inference Endpoint handler for an image-to-text (OCR-style) model.

    Loads a processor/model pair that ships custom code in the repo
    (hence ``trust_remote_code=True``) and serves single-image requests
    from the Hub's image-to-text widget.
    """

    def __init__(self, path: str = "") -> None:
        """Load the processor and model from ``path`` and move the model to GPU/CPU.

        Args:
            path: Local directory of the model repository (provided by the
                Inference Endpoints runtime).
        """
        # trust_remote_code is required because the repo ships its own
        # processor/model implementation.
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)

        # Prefer GPU when available; inference-only, so switch to eval mode.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run OCR inference on a single image.

        Args:
            data: Request payload. The Hub's image-to-text widget places the
                deserialized image under the ``"inputs"`` key (a PIL Image,
                or raw image bytes for binary requests).

        Returns:
            A one-element list ``[{"generated_text": <prediction>}]`` —
            the shape the widget expects for image-to-text.

        Raises:
            ValueError: If the input is neither a PIL Image nor raw bytes.
        """
        # The widget sends the image under "inputs"; fall back to the raw
        # payload for clients that post the image directly.
        inputs_data = data.pop("inputs", data)

        # Normalize the input to a PIL Image. The original code silently
        # `pass`ed here, which let unsupported inputs crash later inside the
        # processor with an opaque error.
        if not isinstance(inputs_data, Image.Image):
            if isinstance(inputs_data, (bytes, bytearray)):
                import io

                # Binary request body: decode and force RGB so the processor
                # sees a consistent channel layout.
                inputs_data = Image.open(io.BytesIO(inputs_data)).convert("RGB")
            else:
                raise ValueError(
                    f"Unsupported input type {type(inputs_data).__name__}; "
                    "expected a PIL Image or raw image bytes."
                )

        # 1. Preprocess the image with the repo's custom processor.
        processed_inputs = self.processor(inputs_data)
        pixel_values = processed_inputs["pixel_values"].to(self.device)

        # 2. Run inference without tracking gradients.
        with torch.no_grad():
            outputs = self.model(pixel_values)
            logits = outputs.logits

        # 3. Decode the prediction (CTC decoding lives in the processor).
        prediction = self.processor.batch_decode(logits)[0]

        # 'generated_text' is the standard key the widget displays.
        return [{"generated_text": prediction}]