Texttra commited on
Commit
1aabf84
Β·
verified Β·
1 Parent(s): 9e8b405

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -33
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Union
2
  import torch
3
  from diffusers import FluxKontextPipeline
4
  from io import BytesIO
@@ -9,52 +9,63 @@ class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
  print("πŸš€ Initializing Flux Kontext pipeline...")
11
 
 
12
  self.pipe = FluxKontextPipeline.from_pretrained(
13
- "black-forest-labs/FLUX.1-Kontext-dev",
14
  torch_dtype=torch.float16,
15
  )
16
  self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
17
  print("βœ… Model ready.")
18
 
19
- def __call__(self, data: Union[Dict, Image.Image]) -> Dict:
20
  print("πŸ”§ Received data:", data)
21
 
22
- # Handle direct PIL image input
23
- if isinstance(data, Image.Image):
24
- return {"error": "Prompt input missing. Received raw image without prompt."}
25
 
26
- # Handle dict input
27
- inputs = data.get("inputs") if isinstance(data, dict) else None
28
- if inputs is None:
29
- return {"error": "Invalid input format. Expected dict with 'inputs'."}
30
 
31
  prompt = inputs.get("prompt")
32
- image_base64 = inputs.get("image")
33
 
34
- if not prompt or not image_base64:
35
- return {"error": "Both 'prompt' and 'image' inputs are required."}
36
 
37
- # Decode input image from base64
38
- try:
39
- image_bytes = base64.b64decode(image_base64)
40
- image = Image.open(BytesIO(image_bytes)).convert("RGB")
41
- except Exception as e:
42
- return {"error": f"Failed to decode image. Error: {str(e)}"}
 
 
 
 
 
 
 
 
43
 
44
  # Generate edited image with Kontext
45
- output = self.pipe(
46
- prompt=prompt,
47
- image=image,
48
- num_inference_steps=28,
49
- guidance_scale=3.5
50
- ).images[0]
51
-
52
- print("🎨 Image generated.")
 
 
53
 
54
  # Encode output image to base64
55
- buffer = BytesIO()
56
- output.save(buffer, format="PNG")
57
- base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
58
-
59
- print("βœ… Returning image.")
60
- return {"image": base64_image}
 
 
 
1
+ from typing import Dict
2
  import torch
3
  from diffusers import FluxKontextPipeline
4
  from io import BytesIO
 
9
  def __init__(self, path: str = ""):
10
  print("πŸš€ Initializing Flux Kontext pipeline...")
11
 
12
+ # Load Flux Kontext model from Hugging Face Hub
13
  self.pipe = FluxKontextPipeline.from_pretrained(
14
+ "black-forest-labs/FLUX.1-Kontext-dev", # replace with your specific Kontext model if different
15
  torch_dtype=torch.float16,
16
  )
17
  self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
18
  print("βœ… Model ready.")
19
 
20
+ def __call__(self, data: Dict) -> Dict:
21
  print("πŸ”§ Received data:", data)
22
 
23
+ inputs = data.get("inputs")
24
+ if not inputs:
25
+ return {"error": "'inputs' key missing. Payload must include an 'inputs' dictionary."}
26
 
27
+ if not isinstance(inputs, dict):
28
+ return {"error": "'inputs' must be a JSON object with 'prompt' and optionally 'image'."}
 
 
29
 
30
  prompt = inputs.get("prompt")
31
+ image_input = inputs.get("image")
32
 
33
+ if not prompt:
34
+ return {"error": "Prompt is required in 'inputs'."}
35
 
36
+ # Process image input if provided
37
+ image = None
38
+ if image_input:
39
+ if isinstance(image_input, str):
40
+ try:
41
+ # Assume it's base64 encoded
42
+ image_bytes = base64.b64decode(image_input)
43
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
44
+ except Exception as e:
45
+ return {"error": f"Failed to decode base64 image input: {str(e)}"}
46
+ elif isinstance(image_input, Image.Image):
47
+ image = image_input
48
+ else:
49
+ return {"error": "'image' must be a base64 string or a PIL.Image object."}
50
 
51
  # Generate edited image with Kontext
52
+ try:
53
+ output = self.pipe(
54
+ prompt=prompt,
55
+ image=image,
56
+ num_inference_steps=28, # context standard
57
+ guidance_scale=3.5
58
+ ).images[0]
59
+ print("🎨 Image generated.")
60
+ except Exception as e:
61
+ return {"error": f"Model inference failed: {str(e)}"}
62
 
63
  # Encode output image to base64
64
+ try:
65
+ buffer = BytesIO()
66
+ output.save(buffer, format="PNG")
67
+ base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
68
+ print("βœ… Returning image.")
69
+ return {"image": base64_image}
70
+ except Exception as e:
71
+ return {"error": f"Failed to encode output image: {str(e)}"}