Texttra committed on
Commit
d22fba2
·
verified ·
1 Parent(s): 110a12d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +49 -16
handler.py CHANGED
@@ -1,23 +1,56 @@
1
def __call__(self, data: Dict) -> Dict:
    """Run one text-to-image request.

    Expects ``data["inputs"]["prompt"]`` and returns either
    ``{"image": <base64-encoded PNG>}`` or ``{"error": ...}``.
    """
    print("Received data:", data)

    prompt = data.get("inputs", {}).get("prompt", "")
    print("Extracted prompt:", prompt)

    # Guard clause: nothing to generate without a prompt.
    if not prompt:
        return {"error": "No prompt provided."}

    # Build prompt embeddings, then run the diffusion pipeline on them.
    embeds = self.compel(prompt)
    print("Conditioning complete.")
    generated = self.pipe(prompt_embeds=embeds).images[0]
    print("Image generated.")

    # Serialize the image to an in-memory PNG and base64-encode it.
    png_buffer = BytesIO()
    generated.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    print("Returning image.")
    return {"image": encoded}
1
+ from typing import Dict
2
+ import torch
3
+ from diffusers import DiffusionPipeline
4
+ from compel import Compel
5
+ from io import BytesIO
6
+ import base64
7
+ import os
8
 
9
class EndpointHandler:
    """Hugging Face Inference Endpoint handler.

    Loads the FLUX.1-dev diffusion pipeline plus a LoRA shipped with the
    repository, and turns a ``{"inputs": {"prompt": ...}}`` request into a
    base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Load the base pipeline and the LoRA weights found under *path*.

        Args:
            path: Directory of the deployed model repository; the LoRA file
                ``c1t3_v1.safetensors`` is expected directly inside it.
        """
        print(f"Initializing model from: {path}")
        # NOTE(review): FLUX generally recommends bfloat16; float16 is kept
        # here to preserve the original numeric behavior — confirm it does
        # not overflow with this checkpoint.
        self.pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.float16,
            # `use_auth_token` is deprecated/removed in the diffusers
            # versions that support FLUX; `token` is the replacement.
            token=True,
        )

        lora_path = os.path.join(path, "c1t3_v1.safetensors")
        print(f"Loading LoRA weights from: {lora_path}")
        self.pipe.load_lora_weights(lora_path)

        if torch.cuda.is_available():
            # enable_model_cpu_offload() manages device placement itself and
            # must not be combined with a prior .to("cuda") — diffusers
            # rejects that combination.  It also requires CUDA, so it is
            # only enabled on GPU hosts (the original called it
            # unconditionally after .to(...), which breaks CPU-only hosts).
            self.pipe.enable_model_cpu_offload()
        else:
            self.pipe.to("cpu")

        # NOTE(review): FLUX pipelines carry two text encoders (CLIP + T5);
        # a Compel built from only `tokenizer`/`text_encoder` may not yield
        # embeddings FluxPipeline accepts (it also expects
        # pooled_prompt_embeds).  Verify against the deployed
        # diffusers/compel versions.
        self.compel = Compel(
            tokenizer=self.pipe.tokenizer,
            text_encoder=self.pipe.text_encoder
        )
        print("Model initialized successfully.")

    def __call__(self, data: Dict) -> Dict:
        """Generate one image for the request payload.

        Args:
            data: Request body.  ``data["inputs"]`` may be either a dict
                with a ``"prompt"`` key or a raw prompt string (some HF
                endpoint clients send ``{"inputs": "text"}``; the original
                crashed on that shape).

        Returns:
            ``{"image": <base64 PNG>}`` on success, or
            ``{"error": "No prompt provided."}`` for a missing/empty prompt.
        """
        print("Received data:", data)

        inputs = data.get("inputs", {})
        # Accept both {"inputs": {"prompt": "..."}} and {"inputs": "..."}.
        if isinstance(inputs, str):
            prompt = inputs
        else:
            prompt = inputs.get("prompt", "")
        print("Extracted prompt:", prompt)

        if not prompt:
            return {"error": "No prompt provided."}

        conditioning = self.compel(prompt)
        print("Conditioning complete.")

        image = self.pipe(prompt_embeds=conditioning).images[0]
        print("Image generated.")

        # Serialize to an in-memory PNG and base64-encode for the response.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
        print("Returning image.")

        return {"image": base64_image}