convexray committed on
Commit
088dc17
·
verified ·
1 Parent(s): 11bd521

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +62 -0
  2. requirements.txt +4 -0
handler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
4
+ import torch
5
+ import base64
6
+ import io
7
+
8
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a Pix2Struct
    image-to-text model (used here for table/text extraction from images).

    Loads the processor and model once at endpoint startup and serves
    generation requests that carry a base64-encoded image.
    """

    def __init__(self, path: str = ""):
        """Called once when the endpoint starts: load model and processor.

        Args:
            path: Directory (or hub id) containing the model weights and
                processor configuration.
        """
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)

        # Prefer GPU when available; fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64 encoded image string. A browser-style
                  ``data:image/...;base64,`` prefix is accepted and stripped.
                - parameters (optional): generation kwargs such as
                  ``max_new_tokens``; all entries are forwarded to
                  ``model.generate`` (``max_new_tokens`` defaults to 512).

        Returns:
            List containing one dict with the generated text under
            ``"generated_text"``.

        Raises:
            ValueError: if 'inputs' is missing, not a string, or not
                valid base64 data.
        """
        inputs = data.get("inputs")
        parameters = data.get("parameters", {})

        # Guard clause: only a base64 string payload is supported.
        if not isinstance(inputs, str):
            raise ValueError("Expected base64 encoded image string in 'inputs'")

        # Strip an optional data-URI prefix (e.g. "data:image/png;base64,")
        # so payloads produced directly by browsers also work.
        if inputs.startswith("data:"):
            _, _, inputs = inputs.partition(",")

        # binascii.Error (raised on malformed base64) subclasses ValueError,
        # so this catch converts it into a clear, consistent client error.
        try:
            image_bytes = base64.b64decode(inputs)
        except ValueError as err:
            raise ValueError("'inputs' is not valid base64-encoded data") from err
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Process image into model tensors on the target device.
        model_inputs = self.processor(
            images=image,
            return_tensors="pt"
        ).to(self.device)

        # Forward ALL client-supplied generation parameters, not just
        # max_new_tokens (previously every other parameter was silently
        # dropped). The 512 default is kept for backward compatibility.
        gen_kwargs = {"max_new_tokens": 512, **parameters}

        # Inference only — no autograd graph needed.
        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                **gen_kwargs
            )

        # Decode the first (and only) sequence to plain text.
        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True
        )

        return [{"generated_text": output_text}]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers>=4.30.0
2
+ torch
3
+ Pillow
4
+ sentencepiece