yamildiego committed on
Commit
808dc28
·
1 Parent(s): cef6574

BASE endpoint

Browse files
Files changed (2) hide show
  1. handler.py +29 -228
  2. requirements.txt +2 -0
handler.py CHANGED
@@ -1,228 +1,29 @@
1
- # import cv2
2
- # import torch
3
- # import random
4
- # import numpy as np
5
-
6
- # from PIL import Image
7
- # from pathlib import Path
8
-
9
- # from huggingface_hub import hf_hub_download, snapshot_download
10
- # from ip_adapter.ip_adapter import IPAdapterXL
11
- # from safetensors.torch import load_file
12
- # import os
13
-
14
- # from diffusers import (
15
- # ControlNetModel,
16
- # StableDiffusionXLControlNetPipeline,
17
- # UNet2DConditionModel,
18
- # EulerDiscreteScheduler,
19
- # )
20
-
21
- # # global variable
22
- # MAX_SEED = np.iinfo(np.int32).max
23
- # device = "cuda" if torch.cuda.is_available() else "cpu"
24
- # dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
25
-
26
- # # initialization
27
- # base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
28
- # # image_encoder_path = "sdxl_models/image_encoder"
29
- # # ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
30
- # controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
31
-
32
-
33
-
34
class EndpointHandler:
    """Placeholder inference endpoint.

    The original InstantStyle (SDXL-Lightning + ControlNet + IP-Adapter)
    image-generation implementation was disabled, reducing this handler to a
    health-check stub: ``__call__`` ignores its input and always returns the
    literal string ``"Hello World"``.
    """

    def __init__(self, model_dir):
        # model_dir: model directory supplied by the serving framework;
        # unused while the handler is a stub.
        # NOTE(review): the SDXL pipeline / IP-Adapter loading code that
        # previously lived here was commented out; reintroduce it if the
        # image-generation endpoint is restored.
        print("Model loaded successfully")

    def __call__(self, data):
        """Ignore ``data`` and return a constant placeholder response.

        Args:
            data: request payload (any type); not inspected.

        Returns:
            str: the fixed string ``"Hello World"``.
        """
        # NOTE(review): the canny-edge / ControlNet generation path was
        # commented out; this stub only confirms the endpoint is reachable.
        return "Hello World"
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ import holidays
4
+
5
class EndpointHandler():
    """Inference endpoint: text classification with a US-holiday shortcut.

    If the request carries a date that is a US holiday, a fixed "happy"
    prediction is returned without running the model; otherwise the input
    text is classified by the loaded ``text-classification`` pipeline.
    """

    def __init__(self, path=""):
        # path: local directory (or hub id) of the fine-tuned model to load.
        self.pipeline = pipeline("text-classification", model=path)
        # Pre-built US holiday calendar; supports `"YYYY-MM-DD" in self.holidays`.
        self.holidays = holidays.US()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): text to classify
            date (:obj:`str`, optional): date string, e.g. "2024-07-04"
        Return:
            A :obj:`list` of ``{"label", "score"}`` dicts; will be
            serialized and returned by the serving framework.
        """
        # Fall back to the raw payload if no "inputs" key is present
        # (standard custom-handler convention).
        inputs = data.pop("inputs", data)
        date = data.pop("date", None)

        # Holiday shortcut: any US holiday is always classified as "happy".
        if date is not None and date in self.holidays:
            # Score emitted as a float to match the pipeline's output schema.
            return [{"label": "happy", "score": 1.0}]

        # Normal path: run the model on the input text.
        return self.pipeline(inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.18.0
2
+ holidays==0.13