marutitecblic committed on
Commit
aa16163
·
verified ·
1 Parent(s): 7eab698

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +78 -0
handler.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
import os

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# HF_TASK = os.getenv('HF_TASK')

# BUG FIX: os.getenv is a function, not a mapping — os.getenv['API_TOKEN']
# raised "TypeError: 'builtin_function_or_method' is not subscriptable" at
# import time. It must be called.
API_TOKEN = os.getenv("API_TOKEN")  # Ensure you replace this with your actual API token

# Load processor and model once at module import so every request reuses them.
PROCESSOR = AutoProcessor.from_pretrained(
    "marutitecblic/HtmlTocode",
    trust_remote_code=True,
    # token=API_TOKEN,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "marutitecblic/HtmlTocode",
    # token=API_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
# Number of "<image>" placeholder tokens per image, taken from the model's
# perceiver resampler config (IDEFICS-style architecture — TODO confirm).
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token-id sequences the generator must never emit (the image placeholders).
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
30
+
31
def convert_to_rgb(image):
    """Return *image* in RGB mode, flattening any transparency onto white.

    Images already in RGB are returned unchanged; anything else is composited
    over an opaque white background so transparent pixels render as white
    rather than black.
    """
    if image.mode == "RGB":
        # Fast path: nothing to convert.
        return image
    with_alpha = image.convert("RGBA")
    white = Image.new("RGBA", with_alpha.size, (255, 255, 255))
    flattened = Image.alpha_composite(white, with_alpha)
    return flattened.convert("RGB")
39
+
40
def custom_transform(x):
    """Turn a PIL image into a normalized CHW float tensor for the model.

    Pipeline: force RGB, resize to a fixed 960x960 (bilinear), rescale byte
    values into [0, 1], normalize with the processor's mean/std, move
    channels first, and wrap in a torch tensor.
    """
    proc = PROCESSOR.image_processor
    arr = to_numpy_array(convert_to_rgb(x))
    arr = resize(arr, (960, 960), resample=PILImageResampling.BILINEAR)
    arr = proc.rescale(arr, scale=1 / 255)  # uint8 range -> [0, 1]
    arr = proc.normalize(arr, mean=proc.image_mean, std=proc.image_std)
    arr = to_channel_dimension_format(arr, ChannelDimension.FIRST)  # HWC -> CHW
    return torch.tensor(arr)
53
+
54
def preprocess(event):
    """Build model inputs (token ids + pixel values) from the request event.

    Assumes ``event["file"]`` is a path or file-like object PIL can open —
    TODO confirm against the caller. All tensors are moved to DEVICE.
    """
    image = Image.open(event["file"]).convert("RGB")
    # Prompt = BOS token plus the image placeholder block the model expects.
    batch = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    batch["pixel_values"] = PROCESSOR.image_processor([image], transform=custom_transform)
    # Move every tensor onto the model's device before returning.
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(DEVICE)
    return on_device
64
+
65
def inference(model_inputs):
    """Preprocess the event, run generation, and decode the first sequence."""
    prepared = preprocess(model_inputs)
    output_ids = MODEL.generate(
        **prepared,
        bad_words_ids=BAD_WORDS_IDS,  # forbid emitting image placeholder tokens
        max_length=4096,
    )
    decoded = PROCESSOR.batch_decode(output_ids, skip_special_tokens=True)
    return {"generated_text": decoded[0]}
70
+
71
def postprocess(model_outputs):
    """Identity pass-through, kept as an extension point for output shaping."""
    return model_outputs
73
+
74
def handle(event, context):
    """Serverless entry point: run inference on *event* and return the result.

    *context* is accepted for runtime-interface compatibility but is unused.
    """
    result = inference(event)
    return postprocess(result)