Texttra commited on
Commit
e369958
·
verified ·
1 Parent(s): 10d876d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -28
handler.py CHANGED
@@ -1,36 +1,23 @@
1
  from typing import Dict
2
  import torch
3
- from diffusers import DiffusionPipeline
4
- from compel import Compel
5
  from io import BytesIO
6
  import base64
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
  print(f"Initializing model from: {path}")
11
- self.pipe = DiffusionPipeline.from_pretrained(
 
12
  "black-forest-labs/FLUX.1-dev",
13
- torch_dtype=torch.float16,
14
- use_auth_token=True # Required for gated base model
15
  )
16
 
17
- # Load LoRA weights from your Hugging Face repo
18
  print("Loading LoRA weights from: Texttra/Cityscape_Studio")
19
  self.pipe.load_lora_weights("Texttra/Cityscape_Studio", weight_name="c1t3_v1.safetensors")
 
20
 
21
- # Send to GPU if available
22
- if torch.cuda.is_available():
23
- self.pipe.to("cuda")
24
- else:
25
- self.pipe.to("cpu")
26
-
27
- self.pipe.enable_model_cpu_offload()
28
-
29
- # Initialize Compel for prompt conditioning
30
- self.compel = Compel(
31
- tokenizer=self.pipe.tokenizer,
32
- text_encoder=self.pipe.text_encoder
33
- )
34
  print("Model initialized successfully.")
35
 
36
  def __call__(self, data: Dict) -> Dict:
@@ -41,20 +28,15 @@ class EndpointHandler:
41
  print("Extracted prompt:", prompt)
42
 
43
  if not prompt:
44
- return {"error": "No prompt provided"}
45
-
46
- # Generate both prompt and pooled embeddings
47
- conditioning, pooled = self.compel(prompt, return_pooled=True)
48
- print("Conditioning complete.")
49
 
50
- # Run the model
51
  image = self.pipe(
52
- prompt_embeds=conditioning,
53
- pooled_prompt_embeds=pooled
 
54
  ).images[0]
55
  print("Image generated.")
56
 
57
- # Encode image to base64
58
  buffer = BytesIO()
59
  image.save(buffer, format="PNG")
60
  base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
 
1
  from typing import Dict
2
  import torch
3
+ from diffusers import FluxPipeline
 
4
  from io import BytesIO
5
  import base64
6
 
7
  class EndpointHandler:
8
  def __init__(self, path: str = ""):
9
  print(f"Initializing model from: {path}")
10
+
11
+ self.pipe = FluxPipeline.from_pretrained(
12
  "black-forest-labs/FLUX.1-dev",
13
+ torch_dtype=torch.float16
 
14
  )
15
 
 
16
  print("Loading LoRA weights from: Texttra/Cityscape_Studio")
17
  self.pipe.load_lora_weights("Texttra/Cityscape_Studio", weight_name="c1t3_v1.safetensors")
18
+ self.pipe.fuse_lora(lora_scale=1.0)
19
 
20
+ self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
21
  print("Model initialized successfully.")
22
 
23
  def __call__(self, data: Dict) -> Dict:
 
28
  print("Extracted prompt:", prompt)
29
 
30
  if not prompt:
31
+ return {"error": "No prompt provided."}
 
 
 
 
32
 
 
33
  image = self.pipe(
34
+ prompt,
35
+ num_inference_steps=28,
36
+ guidance_scale=4.5
37
  ).images[0]
38
  print("Image generated.")
39
 
 
40
  buffer = BytesIO()
41
  image.save(buffer, format="PNG")
42
  base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")