wealthcoders committed on
Commit
5053334
·
verified ·
1 Parent(s): bcf25d2

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +86 -0
handler.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ from typing import Dict, List, Any
3
+ import torch
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ import os
8
+ import tempfile
9
+
10
class EndpointHandler:
    """Custom inference handler that runs OCR on a base64-encoded image.

    The handler loads an image-text-to-text model together with its
    processor, and on each call decodes the image found in the request,
    wraps it in a single-turn chat message with a text prompt, and returns
    the text generated by the model.
    """

    # Prompt used when the request does not carry its own "prompt" value.
    # NOTE(review): the model card may recommend a specific OCR prompt —
    # confirm and replace this generic instruction if so.
    DEFAULT_PROMPT = "Extract all text from the image and return it as clean Markdown."

    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 10000

    def __init__(self, model_dir = 'scb10x/typhoon-ocr1.5-2b'):
        """Load the model and its processor.

        Args:
            model_dir: Model identifier or local directory handed to
                ``from_pretrained``.
        """
        # Imported locally so the classes actually used here are in scope;
        # the module-level import only brings in AutoModel / AutoTokenizer.
        from transformers import AutoModelForImageTextToText, AutoProcessor

        self.model = AutoModelForImageTextToText.from_pretrained(
            model_dir, dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_dir)

    @staticmethod
    def _extract_base64(data):
        """Return the base64 payload found in *data*, or None if absent."""
        # A bare string request is treated as the payload itself.
        if isinstance(data, str):
            return data
        if not isinstance(data, dict):
            return None

        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and inputs.get("base64"):
            return inputs["base64"]
        # Fall back to a root-level key even when "inputs" is present but
        # does not carry the image.
        return data.get("base64")

    @classmethod
    def _extract_prompt(cls, data):
        """Return a caller-supplied prompt, or DEFAULT_PROMPT when none is given."""
        if isinstance(data, dict):
            inputs = data.get("inputs")
            if isinstance(inputs, dict) and isinstance(inputs.get("prompt"), str):
                return inputs["prompt"]
            if isinstance(data.get("prompt"), str):
                return data["prompt"]
        return cls.DEFAULT_PROMPT

    @staticmethod
    def _describe_keys(data):
        """Describe the keys of *data* for the "missing input" error message."""
        if isinstance(data, dict):
            return str(data.keys())
        return "n/a (input is " + type(data).__name__ + ")"

    @staticmethod
    def _decode_image(base64_string):
        """Decode a base64 string (optionally a data URL) into an RGB PIL image."""
        # Drop a "data:image/...;base64," header if one is present.
        if "," in base64_string:
            base64_string = base64_string.split(",", 1)[1]
        image_bytes = base64.b64decode(base64_string)
        # The processor needs an image object rather than raw bytes.
        return Image.open(BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run OCR on the image contained in *data*.

        The image is looked up, in order, as: *data* itself when it is a
        string, ``data["inputs"]`` when that is a string,
        ``data["inputs"]["base64"]``, then ``data["base64"]``. An optional
        ``"prompt"`` string (inside ``inputs`` or at the root) overrides
        ``DEFAULT_PROMPT``.

        Args:
            data: Request payload.

        Returns:
            The generated text as a ``str`` on success, or a dict of the
            form ``{"error": message}`` when no image is found or any step
            fails. Errors are always reported as a dict so they cannot be
            mistaken for OCR output.
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                return {
                    "error": "No base64 string found in input data. Available keys: "
                    + self._describe_keys(data)
                }

            print("Found base64 string, length:", len(base64_string))

            image = self._decode_image(base64_string)
            prompt = self._extract_prompt(data)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # Preparation for inference
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            # Inference: generation of the output
            generated_ids = self.model.generate(
                **inputs, max_new_tokens=self.MAX_NEW_TOKENS
            )
            # Keep only the newly generated tokens by slicing off the prompt.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            return output_text[0]

        except Exception as e:
            # Top-level boundary of the handler: report the failure to the
            # caller instead of letting the request crash.
            print(f"Error processing image: {e}")
            return {"error": str(e)}