import os
import json
import random
import re
import base64
from io import BytesIO

import torch
from huggingface_hub import snapshot_download
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLPipeline,
    EulerAncestralDiscreteScheduler,
    DPMSolverSDEScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image

# Global constants
MAX_SEED = 12211231  # Maximum seed value for random generator
NUM_IMAGES_PER_PROMPT = 1  # Number of images to generate per prompt
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"  # Flag to enable torch compilation

# Sampler name used when the config does not specify one (or names an unknown one).
DEFAULT_SAMPLER = "DPM++ SDE Karras"

# --- Child-Content Filtering Functions ---
# Compiled once at module level so per-request filtering does not recompile.
child_related_regex = re.compile(
    r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
    r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
    r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
    re.IGNORECASE
)


def remove_child_related_content(prompt: str) -> str:
    """Remove any child-related references from the prompt.

    Args:
        prompt: Raw user prompt text.

    Returns:
        The prompt with all regex matches removed and surrounding
        whitespace stripped.
    """
    cleaned_prompt = child_related_regex.sub('', prompt)
    return cleaned_prompt.strip()


def contains_child_related_content(prompt: str) -> bool:
    """Return True if the prompt contains any child-related term."""
    return bool(child_related_regex.search(prompt))


# --- Utility Function: Convert PIL Image to Base64 ---
def pil_image_to_base64(img: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded WEBP string.

    The image is converted to RGB first (WEBP with quality=90), then the
    raw bytes are base64-encoded and decoded to an ASCII str for JSON
    transport.
    """
    buffered = BytesIO()
    img.convert("RGB").save(buffered, format="WEBP", quality=90)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    This class follows the HF Inference Endpoints specification. For
    Hugging Face Inference Endpoints, only this class is needed. It
    provides both the initialization (__init__) and inference (__call__)
    methods required by the Hugging Face Inference API.
    """

    def __init__(self, path="", config=None):
        """Initialize the handler with model path and configurations.

        Args:
            path (str): Path to the model directory (used by HF Inference
                Endpoints).
            config (dict, optional): Configuration for the handler, passed
                by HF Inference Endpoints. When absent, ``app.conf`` is
                loaded from ``path`` (or the working directory) as a
                fallback; any failure leaves an empty config so the
                handler still constructs.
        """
        try:
            if config:
                # Use config provided by HF Inference Endpoints.
                self.cfg = config
            else:
                # Fall back to an app.conf JSON file next to the model.
                config_path = os.path.join(path, "app.conf") if path else "app.conf"
                with open(config_path, "r") as f:
                    self.cfg = json.load(f)
            print("Configuration loaded successfully")
        except Exception as e:
            # Best-effort: a missing/broken config must not prevent startup,
            # but model loading below will fail loudly if "model_id" is absent.
            print(f"Error loading configuration: {e}")
            self.cfg = {}

        # Load the model pipeline once at startup; __call__ reuses it.
        print("Loading the model pipeline...")
        self.pipe = self._load_pipeline_and_scheduler()
        print("Model loaded successfully!")

    def _load_pipeline_and_scheduler(self):
        """Download and assemble the SDXL pipeline and its scheduler.

        Returns:
            StableDiffusionXLPipeline: A CUDA pipeline with fp16 weights,
            the configured sampler, optional clip-skip adjustment, and
            (optionally) a torch-compiled UNet.
        """
        # clip_skip=0 disables the text-encoder layer adjustment below.
        clip_skip = self.cfg.get("clip_skip", 0)

        # Download model files from the Hugging Face Hub (cached locally).
        ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])

        # Load the VAE (decodes latents to pixels) in fp16 to match the pipe.
        vae = AutoencoderKL.from_pretrained(
            os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16
        )

        # Load the Stable Diffusion XL pipeline.
        pipe = StableDiffusionXLPipeline.from_pretrained(
            ckpt_dir,
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=self.cfg.get("use_safetensors", True),
        )
        pipe = pipe.to("cuda")

        # Use the memory-efficient scaled-dot-product attention processor.
        pipe.unet.set_attn_processor(AttnProcessor2_0())

        # Build the supported samplers from the pipeline's own scheduler config.
        samplers = {
            "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
            "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(
                pipe.scheduler.config, use_karras_sigmas=True
            ),
        }
        # BUG FIX: the original used samplers.get(name) with no default, so an
        # unrecognized sampler name in the config silently set the scheduler
        # to None and crashed at inference time. Fall back to the default.
        sampler_name = self.cfg.get("sampler", DEFAULT_SAMPLER)
        pipe.scheduler = samplers.get(sampler_name, samplers[DEFAULT_SAMPLER])

        # Adjust the text encoder depth for clip-skip.
        # NOTE(review): this trims (clip_skip - 1) layers, so clip_skip=1 is a
        # no-op — confirm this matches the intended clip-skip convention.
        if clip_skip > 0:
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)

        # Optionally compile the UNet (opt-in via USE_TORCH_COMPILE=1).
        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("Model Compiled!")

        return pipe

    def __call__(self, data):
        """Process one inference request.

        Called for each request by the Hugging Face Inference API.

        Args:
            data: The request payload — either a dict (HF format:
                ``{"inputs": "...", "parameters": {...}}``) or a JSON string.

        Returns:
            list: ``[{"generated_image": <base64 webp>, "seed": <int>}]``
            on success, or a dict ``{"error": <message>}`` on failure.
        """
        # Validate that the model is loaded.
        if not hasattr(self, 'pipe') or self.pipe is None:
            return {"error": "Model not loaded. Please check initialization logs."}

        # Parse the request payload (accept dict or JSON string).
        try:
            if isinstance(data, dict):
                payload = data
            else:
                payload = json.loads(data)
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}

        # Extract the optional "parameters" sub-dict (HF Endpoints format).
        parameters = {}
        if "parameters" in payload and isinstance(payload["parameters"], dict):
            parameters = payload["parameters"]

        # Get the prompt: "inputs" (HF convention) or "prompt" (compat).
        prompt_text = payload.get("inputs", "")
        if not prompt_text:
            prompt_text = payload.get("prompt", "")
        if not prompt_text:
            return {"error": "No prompt provided. Please include 'inputs' or 'prompt' field."}

        # Apply child-content filtering to the prompt.
        if contains_child_related_content(prompt_text):
            prompt_text = remove_child_related_content(prompt_text)

        # Substitute the user prompt into the config's prompt template.
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)

        # Resolution order for options: parameters > payload > config > default.
        negative_prompt = parameters.get(
            "negative_prompt",
            payload.get("negative_prompt", self.cfg.get("negative_prompt", "")),
        )

        # Output dimensions come from config only (default 1024x768).
        width = int(self.cfg.get("width", 1024))
        height = int(self.cfg.get("height", 768))

        inference_steps = int(parameters.get(
            "inference_steps",
            payload.get("inference_steps", self.cfg.get("inference_steps", 30)),
        ))
        guidance_scale = float(parameters.get(
            "guidance_scale",
            payload.get("guidance_scale", self.cfg.get("guidance_scale", 7)),
        ))

        # Use the provided seed, or draw a random one for reproducibility
        # reporting in the response.
        seed = int(parameters.get("seed", payload.get("seed", random.randint(0, MAX_SEED))))
        generator = torch.Generator(self.pipe.device).manual_seed(seed)

        try:
            # Generate the image using the pipeline.
            outputs = self.pipe(
                prompt=combined_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=inference_steps,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                output_type="pil",
            )
            # Convert the first generated image to base64 for the JSON response.
            img_base64 = pil_image_to_base64(outputs.images[0])
            return [{"generated_image": img_base64, "seed": seed}]
        except Exception as e:
            # Log and report generation failures without crashing the worker.
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}


# For local testing without HF Inference Endpoints
if __name__ == "__main__":
    import argparse
    import uvicorn
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="Run the text-to-image API locally")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    args = parser.parse_args()

    # Create the FastAPI app and a single shared handler instance.
    app = FastAPI(title="Text-to-Image API with Content Filtering")
    handler = EndpointHandler()

    @app.get("/")
    async def read_root():
        """Health check endpoint."""
        return {"status": "ok", "message": "Text-to-Image API is running"}

    @app.post("/")
    async def generate_image(request: Request):
        """Main inference endpoint."""
        try:
            body = await request.json()
            result = handler(body)
            # Success is a list; errors are a dict containing "error".
            if "error" in result:
                return JSONResponse(status_code=500, content={"error": result["error"]})
            return result
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to process request: {str(e)}"},
            )

    # Run the server.
    print(f"Starting server on http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)