File size: 10,971 Bytes
8b41055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfeeedb
 
 
 
8b41055
 
 
 
 
 
 
bfeeedb
 
8b41055
bfeeedb
8b41055
bfeeedb
 
 
 
 
 
 
 
 
8b41055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4302ebf
8b41055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfeeedb
8b41055
 
 
bfeeedb
 
8b41055
bfeeedb
 
8b41055
 
bfeeedb
8b41055
 
 
 
 
 
 
 
 
 
 
 
bfeeedb
 
 
 
 
 
8b41055
bfeeedb
8b41055
bfeeedb
 
 
 
 
8b41055
 
 
 
 
 
 
bfeeedb
 
 
 
 
 
8b41055
bfeeedb
 
 
8b41055
 
bfeeedb
8b41055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfeeedb
 
8b41055
 
 
 
 
bfeeedb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import os
import json
import random
import re
import base64
from io import BytesIO

import torch
from huggingface_hub import snapshot_download
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLPipeline,
    EulerAncestralDiscreteScheduler,
    DPMSolverSDEScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image

# --- Global constants ---
# Inclusive upper bound for seeds drawn with random.randint(0, MAX_SEED).
MAX_SEED = 12_211_231
# Number of images requested from the pipeline for each prompt.
NUM_IMAGES_PER_PROMPT = 1
# Opt-in torch.compile of the UNet; enabled only when USE_TORCH_COMPILE=1.
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"

# --- Child-Content Filtering Functions ---
child_related_regex = re.compile(
    # Standalone child-related words. Longer forms are listed before their
    # prefixes ("children" before "child", "kids" before "kid") so the whole
    # word is removed instead of leaving a fragment such as "ren" or "s".
    # NOTE(review): there are no word boundaries, so short entries also match
    # inside unrelated words (e.g. "son" in "person", "kid" in "kidney").
    r'(children|child|kids|kid|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
    # Ages 1-17 followed by "year(s) old" with optional separators. The
    # (?<!\d) lookbehind keeps the pattern from matching the trailing digit
    # of an adult age (e.g. the "5" in "25 year old").
    r'(?<!\d)([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
    # NOTE(review): by alternation precedence "little", "small", "tiny",
    # "short" and "young" match on their own; only "new born" requires a
    # following boy/girl/male/... word.
    r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
    re.IGNORECASE,
)

def remove_child_related_content(prompt: str) -> str:
    """Remove every ``child_related_regex`` match from *prompt*.

    Deleting matched words can leave runs of whitespace behind
    (e.g. "a  cat"); those runs are collapsed to a single space and the
    result is stripped.

    Args:
        prompt: User-supplied prompt text.

    Returns:
        The prompt with all matches removed, whitespace-normalized and stripped.
    """
    cleaned_prompt = child_related_regex.sub("", prompt)
    # Close the gaps left where matched words used to be.
    cleaned_prompt = re.sub(r"\s{2,}", " ", cleaned_prompt)
    return cleaned_prompt.strip()

def contains_child_related_content(prompt: str) -> bool:
    """Return True if ``child_related_regex`` matches anywhere in *prompt*."""
    first_hit = child_related_regex.search(prompt)
    return first_hit is not None

# --- Utility Function: Convert PIL Image to Base64 ---
def pil_image_to_base64(img: Image.Image) -> str:
    """Encode *img* as an RGB WEBP (quality 90) and return it as base64 text.

    Args:
        img: Image to encode; it is converted to RGB first.

    Returns:
        The base64-encoded WEBP bytes, decoded to a UTF-8 string.
    """
    with BytesIO() as stream:
        rgb_image = img.convert("RGB")
        rgb_image.save(stream, format="WEBP", quality=90)
        encoded_bytes = stream.getvalue()
    return base64.b64encode(encoded_bytes).decode("utf-8")

class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.
    This class follows the HF Inference Endpoints specification.
    
    For Hugging Face Inference Endpoints, only this class is needed.
    It provides both the initialization (__init__) and inference (__call__) methods
    required by the Hugging Face Inference API.
    """
    
    def __init__(self, path="", config=None):
        """
        Initialize the handler with model path and configurations.
        
        Args:
            path (str): Path to the model directory (used by HF Inference Endpoints).
            config (dict, optional): Configuration for the handler, passed by HF Inference Endpoints.
        """
        # Load configuration from app.conf or use provided config
        try:
            if config:
                # Use config provided by HF Inference Endpoints
                self.cfg = config
            else:
                # Try to load from app.conf as fallback
                config_path = os.path.join(path, "app.conf") if path else "app.conf"
                with open(config_path, "r") as f:
                    self.cfg = json.load(f)
            print("Configuration loaded successfully")
        except Exception as e:
            print(f"Error loading configuration: {e}")
            self.cfg = {}
            
        # Load the model pipeline
        print("Loading the model pipeline...")
        self.pipe = self._load_pipeline_and_scheduler()
        print("Model loaded successfully!")
    
    def _load_pipeline_and_scheduler(self):
        """Load the Stable Diffusion pipeline and scheduler."""
        # Get clip_skip from configuration, default to 0
        clip_skip = self.cfg.get("clip_skip", 0)
        
        # Download model files from Hugging Face Hub
        ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])
        
        # Load the VAE model (for decoding latents)
        vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
        
        # Load the Stable Diffusion XL pipeline
        pipe = StableDiffusionXLPipeline.from_pretrained(
            ckpt_dir,
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=self.cfg.get("use_safetensors", True)
        )
        # Move model to GPU
        pipe = pipe.to("cuda")
        # Use efficient attention processor
        pipe.unet.set_attn_processor(AttnProcessor2_0())
        
        # Set up samplers/schedulers based on configuration
        samplers = {
            "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
            "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
        }
        # Default to "DPM++ SDE Karras" if not specified
        pipe.scheduler = samplers.get(self.cfg.get("sampler", "DPM++ SDE Karras"))
        
        # Adjust the text encoder layers if needed using clip_skip
        if clip_skip > 0:
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
        
        # Compile model if environment variable is set
        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("Model Compiled!")
        
        return pipe
    
    def __call__(self, data):
        """
        Process the inference request.
        This is called for each inference request by the Hugging Face Inference API.
        
        Args:
            data: The input data for the inference request
                 For HF Inference Endpoints, this is typically a dict with "inputs" field
                 
        Returns:
            list: A list containing the generated image as base64 string and seed
                 This follows the HF Inference Endpoints output format
        """
        # Validate that the model is loaded
        if not hasattr(self, 'pipe') or self.pipe is None:
            return {"error": "Model not loaded. Please check initialization logs."}
        
        # Parse the request payload
        try:
            if isinstance(data, dict):
                payload = data
            else:
                # Assuming the request is a JSON string
                payload = json.loads(data)
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}
        
        # Extract parameters from the payload
        parameters = {}
        if "parameters" in payload and isinstance(payload["parameters"], dict):
            # HF Inference Endpoints format: {"inputs": "prompt", "parameters": {...}}
            parameters = payload["parameters"]
        
        # Get the prompt from the payload
        prompt_text = payload.get("inputs", "")
        if not prompt_text:
            # Try to get prompt from different fields for compatibility
            prompt_text = payload.get("prompt", "")
            
        if not prompt_text:
            return {"error": "No prompt provided. Please include 'inputs' or 'prompt' field."}
        
        # Apply child-content filtering to the prompt
        if contains_child_related_content(prompt_text):
            prompt_text = remove_child_related_content(prompt_text)
        
        # Replace placeholder in the prompt template from config
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)
        # Use negative_prompt from parameters or payload, fall back to config
        negative_prompt = parameters.get("negative_prompt", payload.get("negative_prompt", self.cfg.get("negative_prompt", "")))
        
        # Get dimensions from config (default to 1024x768 if not specified)
        width = int(self.cfg.get("width", 1024))
        height = int(self.cfg.get("height", 768))
        
        # Other generation parameters
        inference_steps = int(parameters.get("inference_steps", payload.get("inference_steps", self.cfg.get("inference_steps", 30))))
        guidance_scale = float(parameters.get("guidance_scale", payload.get("guidance_scale", self.cfg.get("guidance_scale", 7))))
        
        # Use provided seed or generate a random one
        seed = int(parameters.get("seed", payload.get("seed", random.randint(0, MAX_SEED))))
        generator = torch.Generator(self.pipe.device).manual_seed(seed)
        
        try:
            # Generate the image using the pipeline
            outputs = self.pipe(
                prompt=combined_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=inference_steps,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                output_type="pil"
            )
            # Convert the first generated image to base64
            img_base64 = pil_image_to_base64(outputs.images[0])
            
            # Return the response formatted for Hugging Face Inference Endpoints
            return [{"generated_image": img_base64, "seed": seed}]
        
        except Exception as e:
            # Log the error and return an error response
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}

# For local testing without HF Inference Endpoints
if __name__ == "__main__":
    import argparse
    import uvicorn
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run the text-to-image API locally")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    parser.add_argument(
        "--path",
        type=str,
        default="",
        help="Directory containing app.conf (defaults to the current directory)",
    )
    args = parser.parse_args()

    app = FastAPI(title="Text-to-Image API with Content Filtering")

    # Load configuration and model once at startup.
    handler = EndpointHandler(path=args.path)

    @app.get("/")
    async def read_root():
        """Health check endpoint."""
        return {"status": "ok", "message": "Text-to-Image API is running"}

    @app.post("/")
    async def generate_image(request: Request):
        """Main inference endpoint.

        Returns the handler's list response on success, 400 when the request
        body is not valid JSON, and 500 when the handler reports or raises an
        error.
        """
        try:
            body = await request.json()
        except ValueError as e:
            return JSONResponse(status_code=400, content={"error": f"Invalid JSON body: {e}"})

        try:
            # NOTE: the handler runs synchronously inside this async endpoint,
            # so it blocks the event loop while an image is generated.
            result = handler(body)
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to process request: {str(e)}"},
            )

        # Success responses are lists; only dict responses carry "error".
        if isinstance(result, dict) and "error" in result:
            return JSONResponse(status_code=500, content={"error": result["error"]})
        return result

    # Run the server
    print(f"Starting server on http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)