refoundd committed on
Commit
37ecb84
·
verified ·
1 Parent(s): cb76de1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -17
handler.py CHANGED
@@ -4,21 +4,16 @@ from typing import Any, Dict
4
  from PIL import Image
5
  import torch
6
  from diffusers import FluxPipeline
7
- from huggingface_inference_toolkit.logging import logger
8
  import time
9
- IS_TURBO=True
10
- class EndpointHandler:
11
- def __init__(self, path=""):
12
- #dtype = torch.float16 # for older nVidia GPUs
13
-
14
- self.pipeline =FluxPipeline.from_pretrained("NoMoreCopyrightOrg/flux-dev-8step", torch_dtype=torch.bfloat16)
15
-
16
-
17
- self.pipeline.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
18
 
19
- prompt = "A cat holding a sign that says hello world"
20
 
21
- torch.cuda.empty_cache()
 
 
 
 
 
22
 
23
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
24
  logger.info(f"Received incoming request with {data=}")
@@ -35,7 +30,7 @@ class EndpointHandler:
35
 
36
  parameters = data.pop("parameters", {})
37
 
38
- num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
39
  width = parameters.get("width", 1024)
40
  height = parameters.get("height", 1024)
41
  guidance_scale = parameters.get("guidance_scale", 3.5)
@@ -43,8 +38,8 @@ class EndpointHandler:
43
  # seed generator (seed cannot be provided as is but via a generator)
44
  seed = parameters.get("seed", 0)
45
  generator = torch.manual_seed(seed)
46
- start_time=time.time()
47
- result = self.pipeline( # type: ignore
48
  prompt,
49
  height=height,
50
  width=width,
@@ -57,5 +52,3 @@ class EndpointHandler:
57
  time_taken = end_time - start_time
58
  print(f"Time taken: {time_taken:.2f} seconds")
59
  return result
60
-
61
-
 
4
  from PIL import Image
5
  import torch
6
  from diffusers import FluxPipeline
7
+ from transformers import logger
8
  import time
 
 
 
 
 
 
 
 
 
9
 
 
10
 
11
+ class EndpointHandler:
12
+ def __init__(self):
13
+ self.pipe = FluxPipeline.from_pretrained(
14
+ "NoMoreCopyrightOrg/flux-dev",
15
+ torch_dtype=torch.bfloat16,
16
+ ).to("cuda")
17
 
18
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
19
  logger.info(f"Received incoming request with {data=}")
 
30
 
31
  parameters = data.pop("parameters", {})
32
 
33
+ num_inference_steps = parameters.get("num_inference_steps", 28)
34
  width = parameters.get("width", 1024)
35
  height = parameters.get("height", 1024)
36
  guidance_scale = parameters.get("guidance_scale", 3.5)
 
38
  # seed generator (seed cannot be provided as is but via a generator)
39
  seed = parameters.get("seed", 0)
40
  generator = torch.manual_seed(seed)
41
+ start_time = time.time()
42
+ result = self.pipe( # type: ignore
43
  prompt,
44
  height=height,
45
  width=width,
 
52
  time_taken = end_time - start_time
53
  print(f"Time taken: {time_taken:.2f} seconds")
54
  return result