File size: 4,082 Bytes
06e4a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Handler for QwenImageLayeredPipeline.
Decomposes an input RGBA image into semantic layers (foreground, background, objects, etc.)
"""
from typing import Dict, List, Any
import torch
import base64
import io
from PIL import Image

# Try to import the specific pipeline class
try:
    from diffusers import QwenImageLayeredPipeline
except ImportError:
    from diffusers import DiffusionPipeline
    QwenImageLayeredPipeline = None

class EndpointHandler:
    """Inference endpoint handler for Qwen-Image-Layered.

    Decomposes a base64-encoded RGBA image into semantic layers
    (foreground, background, objects, ...) and returns each layer as a
    base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the layered-decomposition pipeline.

        Args:
            path: Unused; kept for Inference Endpoints handler compatibility.
        """
        # The correct model for layered decomposition
        model_id = "Qwen/Qwen-Image-Layered"

        print(f"Loading model {model_id}...")

        if QwenImageLayeredPipeline:
            print("Using explicit QwenImageLayeredPipeline class.")
            self.pipeline = QwenImageLayeredPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
            )
        else:
            # Older diffusers without the class: let the hub config pick
            # the pipeline implementation via trust_remote_code.
            print("Falling back to DiffusionPipeline auto-load.")
            self.pipeline = DiffusionPipeline.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )
            print(f"Loaded pipeline class: {type(self.pipeline).__name__}")

        # Remember the device so __call__ autocasts on the device actually
        # in use (previously "cuda" was hardcoded and broke CPU-only hosts).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            self.pipeline.to("cuda")

        print("Model ready!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run layered decomposition on one image.

        Expects:
            inputs.image: base64-encoded RGBA image (raw base64 or a
                ``data:image/...;base64,`` data URL)
            parameters.layers: number of layers to decompose into (default: 4)
            parameters.num_inference_steps: inference steps (default: 50)
            parameters.resolution: output resolution (default: 640)
            parameters.prompt: optional prompt (default: "")

        Returns:
            List of dicts with "layer_index" and base64 PNG "image".

        Raises:
            ValueError: if the image is missing or cannot be decoded.
        """
        # Use .get, not .pop: don't mutate the caller's payload dict.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters") or {}

        # Parse the input image
        image_data = inputs.get("image")
        if not image_data:
            raise ValueError("Missing 'image' in inputs. Please provide a base64-encoded RGBA image.")

        # Accept data-URL payloads by stripping the "data:...;base64," prefix.
        if isinstance(image_data, str) and image_data.startswith("data:") and "," in image_data:
            image_data = image_data.split(",", 1)[1]

        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGBA")
        except Exception as e:
            # Chain the original decode error for debuggability.
            raise ValueError(f"Failed to decode image: {e}") from e

        # Get parameters with defaults
        layers = parameters.get("layers", 4)
        num_inference_steps = parameters.get("num_inference_steps", 50)
        resolution = parameters.get("resolution", 640)
        prompt = parameters.get("prompt", "")  # Usually empty for decomposition

        print(f"Decomposing image into {layers} layers at resolution {resolution}...")

        # Run the pipeline; autocast on whichever device the model lives on.
        with torch.autocast(self.device):
            output = self.pipeline(
                image,
                prompt,
                num_inference_steps=num_inference_steps,
                layers=layers,
                resolution=resolution,
                true_cfg_scale=4.0,
                cfg_normalize=False,
                use_en_prompt=True,
            )

        # Serialize output layers
        images_response = []

        if hasattr(output, "images") and output.images:
            # output.images is a list of lists (per batch), we take the first batch
            layer_images = output.images[0] if isinstance(output.images[0], list) else output.images

            for i, layer_img in enumerate(layer_images):
                if isinstance(layer_img, Image.Image):
                    buffered = io.BytesIO()
                    layer_img.save(buffered, format="PNG")
                    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
                    images_response.append({
                        "layer_index": i,
                        "image": img_str
                    })

        print(f"Returned {len(images_response)} layers.")
        return images_response