Texttra committed on
Commit
0c0b7e8
·
verified ·
1 Parent(s): bf7ff83

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -24
handler.py CHANGED
@@ -11,12 +11,14 @@ class EndpointHandler:
11
  self.pipe = DiffusionPipeline.from_pretrained(
12
  "black-forest-labs/FLUX.1-dev",
13
  torch_dtype=torch.float16,
14
- use_auth_token=True
15
  )
16
 
 
17
  print("Loading LoRA weights from: Texttra/Cityscape_Studio")
18
  self.pipe.load_lora_weights("Texttra/Cityscape_Studio", weight_name="c1t3_v1.safetensors")
19
 
 
20
  if torch.cuda.is_available():
21
  self.pipe.to("cuda")
22
  else:
@@ -24,6 +26,7 @@ class EndpointHandler:
24
 
25
  self.pipe.enable_model_cpu_offload()
26
 
 
27
  self.compel = Compel(
28
  tokenizer=self.pipe.tokenizer,
29
  text_encoder=self.pipe.text_encoder
@@ -33,32 +36,29 @@ class EndpointHandler:
33
  def __call__(self, data: Dict) -> Dict:
34
  print("Received data:", data)
35
 
36
- try:
37
- inputs = data.get("inputs", {})
38
- if isinstance(inputs, str):
39
- # In case the input comes in raw string form (e.g., Postman tests)
40
- prompt = inputs
41
- else:
42
- prompt = inputs.get("prompt", "")
43
 
44
- print("Extracted prompt:", prompt)
 
45
 
46
- if not prompt:
47
- return {"error": "No prompt provided"}
 
48
 
49
- conditioning = self.compel(prompt)
50
- print("Conditioning complete.")
 
 
 
 
51
 
52
- image = self.pipe(prompt_embeds=conditioning).images[0]
53
- print("Image generated.")
 
 
 
54
 
55
- buffer = BytesIO()
56
- image.save(buffer, format="PNG")
57
- base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
58
 
59
- print("Returning image.")
60
- return {"image": base64_image}
61
-
62
- except Exception as e:
63
- print(f"Error occurred: {str(e)}")
64
- return {"error": str(e)}
 
11
  self.pipe = DiffusionPipeline.from_pretrained(
12
  "black-forest-labs/FLUX.1-dev",
13
  torch_dtype=torch.float16,
14
+ use_auth_token=True # Required for gated base model
15
  )
16
 
17
+ # Load LoRA weights from your Hugging Face repo
18
  print("Loading LoRA weights from: Texttra/Cityscape_Studio")
19
  self.pipe.load_lora_weights("Texttra/Cityscape_Studio", weight_name="c1t3_v1.safetensors")
20
 
21
+ # Send to GPU if available
22
  if torch.cuda.is_available():
23
  self.pipe.to("cuda")
24
  else:
 
26
 
27
  self.pipe.enable_model_cpu_offload()
28
 
29
+ # Initialize Compel for prompt conditioning
30
  self.compel = Compel(
31
  tokenizer=self.pipe.tokenizer,
32
  text_encoder=self.pipe.text_encoder
 
36
  def __call__(self, data: Dict) -> Dict:
37
  print("Received data:", data)
38
 
39
+ inputs = data.get("inputs", {})
40
+ prompt = inputs.get("prompt", "")
41
+ print("Extracted prompt:", prompt)
 
 
 
 
42
 
43
+ if not prompt:
44
+ return {"error": "No prompt provided"}
45
 
46
+ # Generate both prompt and pooled embeddings
47
+ conditioning, pooled = self.compel(prompt, return_pooled=True)
48
+ print("Conditioning complete.")
49
 
50
+ # Run the model
51
+ image = self.pipe(
52
+ prompt_embeds=conditioning,
53
+ pooled_prompt_embeds=pooled
54
+ ).images[0]
55
+ print("Image generated.")
56
 
57
+ # Encode image to base64
58
+ buffer = BytesIO()
59
+ image.save(buffer, format="PNG")
60
+ base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
61
+ print("Returning image.")
62
 
63
+ return {"image": base64_image}
 
 
64