refoundd committed on
Commit
2ede802
·
verified ·
1 Parent(s): c826456

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +58 -10
handler.py CHANGED
@@ -1,5 +1,3 @@
1
- # https://github.com/sayakpaul/diffusers-torchao
2
- #6.22s
3
  import os
4
  from typing import Any, Dict
5
  from PIL import Image
@@ -7,26 +5,71 @@ import torch
7
  from diffusers import FluxPipeline
8
  from huggingface_inference_toolkit.logging import logger
9
  from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
10
- # from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
11
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  class EndpointHandler:
14
- def __init__(self,path=""):
15
  self.pipe = FluxPipeline.from_pretrained(
16
  "NoMoreCopyrightOrg/flux-dev",
17
  torch_dtype=torch.bfloat16,
18
  ).to("cuda")
19
  apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
20
- # quantize_(self.pipe.text_encoder, float8_weight_only())
21
- # quantize_(self.pipe.transformer, float8_dynamic_activation_float8_weight())
22
  self.pipe.transformer = torch.compile(
23
  self.pipe.transformer, mode="max-autotune-no-cudagraphs",
24
  )
25
  self.pipe.vae = torch.compile(
26
  self.pipe.vae, mode="max-autotune-no-cudagraphs",
27
  )
28
-
29
- def __call__(self, data: Dict[str, Any]) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
30
  logger.info(f"Received incoming request with {data=}")
31
 
32
  if "inputs" in data and isinstance(data["inputs"], str):
@@ -57,9 +100,14 @@ class EndpointHandler:
57
  guidance_scale=guidance_scale,
58
  num_inference_steps=num_inference_steps,
59
  generator=generator,
60
- # output_type="pil" if dist.get_rank() == 0 else "pt",
61
  ).images[0]
62
  end_time = time.time()
63
  time_taken = end_time - start_time
64
  print(f"Time taken: {time_taken:.2f} seconds")
65
- return result
 
 
 
 
 
 
 
 
 
1
  import os
2
  from typing import Any, Dict
3
  from PIL import Image
 
5
  from diffusers import FluxPipeline
6
  from huggingface_inference_toolkit.logging import logger
7
  from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
 
8
  import time
9
+ import uuid
10
+ from huggingface_hub import HfApi
11
+
12
+ from pyngrok import ngrok
13
+ import subprocess
14
+ from fastapi import FastAPI
15
+ from fastapi.responses import FileResponse
16
+ import uvicorn
17
# Directory that the FastAPI image route serves files from.
image_directory = './images'
# exist_ok=True replaces the separate os.path.exists() check, which could
# race with another process creating the directory between check and create.
os.makedirs(image_directory, exist_ok=True)
21
+
22
# ASGI application whose single route serves files out of image_directory.
app = FastAPI()


@app.get("/images/{image_name}")
async def get_image(image_name: str):
    """Return the file named *image_name* from image_directory.

    Responds with the file when it exists, otherwise with a 404 whose JSON
    body is ``{"error": "Image not found"}`` (same body as before, but with
    a status code that tells clients the lookup failed).
    """
    # Local import: JSONResponse is only needed on the not-found path.
    from fastapi.responses import JSONResponse

    # Keep only the final path component so a crafted name cannot point
    # outside image_directory.
    safe_name = os.path.basename(image_name)
    image_path = os.path.join(image_directory, safe_name)

    # isfile rather than exists: a directory (e.g. "..") must not be handed
    # to FileResponse.
    if safe_name and os.path.isfile(image_path):
        return FileResponse(image_path)
    return JSONResponse(status_code=404, content={"error": "Image not found"})
32
+
33
# The ngrok auth token is a credential, so it is read from the environment
# instead of being committed to the repository.
# NOTE(review): the token that was previously hard-coded here has been
# published in the commit history and should be revoked.
authtoken = os.environ.get("NGROK_AUTHTOKEN", "")

# Shell commands that install the ngrok agent from its apt repository.
# They rely on pipes/redirection, hence shell=True below.
commands = [
    "curl -sSL https://ngrok-agent.s3.amazonaws.com/ngrok.asc | sudo tee /etc/apt/trusted.gpg.d/ngrok.asc >/dev/null",
    'echo "deb https://ngrok-agent.s3.amazonaws.com buster main" | sudo tee /etc/apt/sources.list.d/ngrok.list',
    "sudo apt update",
    "sudo apt install -y ngrok",
]
if authtoken:
    commands.append(f"ngrok config add-authtoken {authtoken}")
else:
    logger.warning("NGROK_AUTHTOKEN is not set; skipping ngrok authentication")

for command in commands:
    # Never write the token itself to the logs.
    shown = command.replace(authtoken, "***") if authtoken else command
    try:
        subprocess.run(command, shell=True, check=True)
    except subprocess.CalledProcessError as e:
        # Best effort: a failed step is reported but does not abort start-up.
        logger.error(f"Failed CMD: {shown} (exit status {e.returncode})")
    else:
        logger.info(f"SUCCESS CMD: {shown}")
47
 
48
  class EndpointHandler:
49
def __init__(self, path=""):
    """Load the Flux pipeline, start the image server and open the tunnel.

    Sets ``self.pipe`` (the compiled pipeline on CUDA) and
    ``self.public_url`` (the public ngrok URL forwarding to local port 5000).

    Args:
        path: Accepted for interface compatibility; not used by this method.
    """
    # Local import: only this method needs threading.
    import threading

    self.pipe = FluxPipeline.from_pretrained(
        "NoMoreCopyrightOrg/flux-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
    self.pipe.transformer = torch.compile(
        self.pipe.transformer, mode="max-autotune-no-cudagraphs",
    )
    self.pipe.vae = torch.compile(
        self.pipe.vae, mode="max-autotune-no-cudagraphs",
    )

    # uvicorn.run() blocks until the server stops, so calling it inline
    # would keep this constructor from ever returning. Run the FastAPI app
    # in a daemon thread so it serves in the background and exits with the
    # process.
    server_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "127.0.0.1", "port": 5000},
        daemon=True,
    )
    server_thread.start()

    # pyngrok opens the tunnel to the local server. The former
    # `ngrok http 5000` subprocess call was dropped: it runs in the
    # foreground (blocking forever) and duplicated this tunnel.
    self.public_url = ngrok.connect(5000).public_url
    logger.info(f"ngrok tunnel established at {self.public_url}")
71
+
72
+ def __call__(self, data: Dict[str, Any]) -> str:
73
  logger.info(f"Received incoming request with {data=}")
74
 
75
  if "inputs" in data and isinstance(data["inputs"], str):
 
100
  guidance_scale=guidance_scale,
101
  num_inference_steps=num_inference_steps,
102
  generator=generator,
 
103
  ).images[0]
104
  end_time = time.time()
105
  time_taken = end_time - start_time
106
  print(f"Time taken: {time_taken:.2f} seconds")
107
+ filename = f"{uuid.uuid4()}.png"
108
+ image_path = f"/images/{filename}"
109
+
110
+ result.save(image_path)
111
+ image_url = f"{self.public_url+image_path}"
112
+
113
+ return image_url