File size: 3,526 Bytes
1aabf84
4409dea
 
 
 
905bc0d
4409dea
 
 
 
 
17119de
4409dea
6f92f77
a2bc943
4409dea
17119de
d8adfd2
eaf5244
 
d8adfd2
17119de
d8adfd2
 
 
 
 
17119de
 
 
 
4409dea
17119de
4409dea
1aabf84
6f92f77
 
4409dea
6f92f77
 
 
 
9e8b405
17119de
6f92f77
 
 
 
 
 
 
 
 
4409dea
1aabf84
6f92f77
f08dfbf
6f92f77
4409dea
905bc0d
f08dfbf
 
 
905bc0d
f08dfbf
ec2c1d4
 
 
 
 
4409dea
 
1aabf84
 
 
 
6f92f77
1aabf84
 
 
 
 
4409dea
 
1aabf84
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Dict
import torch
from diffusers import FluxKontextPipeline
from io import BytesIO
import base64
from PIL import Image, ImageOps  # Updated import

class EndpointHandler:
    """Hugging Face Inference Endpoints handler for FLUX.1 Kontext image editing.

    On startup, loads the base ``FluxKontextPipeline`` plus a LoRA adapter.
    Requests are JSON payloads with ``"prompt"`` (str) and ``"image"``
    (base64-encoded image bytes), either at the top level or nested under
    ``"inputs"`` (HF Inference schema). Responses are ``{"image": <base64 PNG>}``
    on success or ``{"error": <message>}`` on any failure.
    """

    def __init__(self, path: str = ""):
        """Load the base pipeline, apply LoRA weights, and move to the best device.

        Args:
            path: Unused; present for the HF Inference Endpoints handler contract.
        """
        print("🚀 Initializing Flux Kontext pipeline...")

        # FLUX.1 Kontext in full float32 does not fit on typical inference GPUs;
        # use bfloat16 when CUDA is available and keep float32 for CPU fallback.
        use_cuda = torch.cuda.is_available()
        dtype = torch.bfloat16 if use_cuda else torch.float32

        # Load base model from Hugging Face
        self.pipe = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=dtype,
        )

        # Load LoRA weights; on failure the endpoint still serves the base model.
        try:
            self.pipe.load_lora_weights(
                "Texttra/BhoriKontext",
                weight_name="Bh0r1.safetensors"
            )
            print("✅ LoRA weights loaded from Texttra/BhoriKontext/Bh0r1.safetensors.")
        except Exception as e:
            print(f"⚠️ Failed to load LoRA weights: {str(e)}")

        # Move pipeline to GPU if available
        self.pipe.to("cuda" if use_cuda else "cpu")
        print("✅ Model ready with LoRA applied.")

    def __call__(self, data: Dict) -> Dict:
        """Handle one inference request.

        Args:
            data: JSON payload. Expects ``"prompt"`` and ``"image"`` (base64)
                either top-level or under ``"inputs"``. Data-URL prefixes on
                the image string are tolerated.

        Returns:
            ``{"image": <base64 PNG>}`` on success, ``{"error": <msg>}`` otherwise.
        """
        print("🔧 Received raw data type:", type(data))
        print("🔧 Received raw data content:", data)

        # Defensive parsing
        if isinstance(data, dict):
            prompt = data.get("prompt")
            image_input = data.get("image")

            # If 'inputs' key is used (HF Inference schema)
            if prompt is None and image_input is None:
                inputs = data.get("inputs")
                if isinstance(inputs, dict):
                    prompt = inputs.get("prompt")
                    image_input = inputs.get("image")
                else:
                    return {"error": "Expected 'inputs' to be a JSON object containing 'prompt' and 'image'."}
        else:
            return {"error": "Input payload must be a JSON object."}

        if not prompt:
            return {"error": "Missing 'prompt' in input data."}
        if not image_input:
            return {"error": "Missing 'image' (base64) in input data."}

        # Tolerate data-URL payloads ("data:image/png;base64,...."):
        # strip everything up to and including the first comma.
        if isinstance(image_input, str) and image_input.startswith("data:"):
            image_input = image_input.split(",", 1)[-1]

        # Decode the image. Apply EXIF orientation BEFORE convert(): convert("RGB")
        # may drop EXIF metadata, which would make a later exif_transpose() a
        # silent no-op (the original ordering had this bug).
        try:
            image_bytes = base64.b64decode(image_input)
            image = Image.open(BytesIO(image_bytes))
            image = ImageOps.exif_transpose(image)
            image = image.convert("RGB")
        except Exception as e:
            return {"error": f"Failed to decode 'image' as base64: {str(e)}"}

        # Debug prints for prompt and image size
        print(f"📝 Final prompt: {prompt}")
        print(f"🖼️ Image size: {image.size}")

        # Generate edited image with Kontext
        try:
            output = self.pipe(
                prompt=prompt,
                image=image,
                num_inference_steps=28,
                guidance_scale=3.5
            ).images[0]
            print("🎨 Image generated.")
        except Exception as e:
            return {"error": f"Model inference failed: {str(e)}"}

        # Encode output image to base64
        try:
            buffer = BytesIO()
            output.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
            print("✅ Returning image.")
            return {"image": base64_image}
        except Exception as e:
            return {"error": f"Failed to encode output image: {str(e)}"}