Texttra commited on
Commit
7e29046
·
verified ·
1 Parent(s): f6f3cdf

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -6
handler.py CHANGED
@@ -4,16 +4,19 @@ from diffusers import DiffusionPipeline
4
  from compel import Compel
5
  from io import BytesIO
6
  import base64
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
 
10
  self.pipe = DiffusionPipeline.from_pretrained(
11
  "black-forest-labs/FLUX.1-dev",
12
  torch_dtype=torch.float16,
13
  use_auth_token=True
14
  )
15
-
16
- self.pipe.load_lora_weights(path, weight_name="c1t3_v1.safetensors")
 
17
 
18
  if torch.cuda.is_available():
19
  self.pipe.to("cuda")
@@ -26,20 +29,27 @@ class EndpointHandler:
26
  tokenizer=self.pipe.tokenizer,
27
  text_encoder=self.pipe.text_encoder
28
  )
 
 
 
 
29
 
30
- def __call__(self, data: Dict[str, Dict[str, str]]) -> Dict:
31
  inputs = data.get("inputs", {})
32
  prompt = inputs.get("prompt", "")
 
 
33
  if not prompt:
34
  return {"error": "No prompt provided."}
35
 
36
- print(f"Received prompt: {prompt}")
37
-
38
  conditioning = self.compel(prompt)
 
 
39
  image = self.pipe(prompt_embeds=conditioning).images[0]
 
40
 
41
  buffer = BytesIO()
42
  image.save(buffer, format="PNG")
43
  base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
 
44
 
45
- return {"image": base64_image}
 
4
  from compel import Compel
5
  from io import BytesIO
6
  import base64
7
+ import os
8
 
9
  class EndpointHandler:
10
  def __init__(self, path: str = ""):
11
+ print(f"Initializing model from: {path}")
12
  self.pipe = DiffusionPipeline.from_pretrained(
13
  "black-forest-labs/FLUX.1-dev",
14
  torch_dtype=torch.float16,
15
  use_auth_token=True
16
  )
17
+ lora_path = os.path.join(path, "c1t3_v1.safetensors")
18
+ print(f"Loading LoRA weights from: {lora_path}")
19
+ self.pipe.load_lora_weights(lora_path)
20
 
21
  if torch.cuda.is_available():
22
  self.pipe.to("cuda")
 
29
  tokenizer=self.pipe.tokenizer,
30
  text_encoder=self.pipe.text_encoder
31
  )
32
+ print("Model initialized.")
33
+
34
+ def __call__(self, data: Dict) -> Dict:
35
+ print("Received data:", data)
36
 
 
37
  inputs = data.get("inputs", {})
38
  prompt = inputs.get("prompt", "")
39
+ print("Extracted prompt:", prompt)
40
+
41
  if not prompt:
42
  return {"error": "No prompt provided."}
43
 
 
 
44
  conditioning = self.compel(prompt)
45
+ print("Conditioning complete.")
46
+
47
  image = self.pipe(prompt_embeds=conditioning).images[0]
48
+ print("Image generated.")
49
 
50
  buffer = BytesIO()
51
  image.save(buffer, format="PNG")
52
  base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
53
+ print("Returning image.")
54
 
55
+ return {"image": base64_image}