# msgxai-hg-api — handler.py
# (Hugging Face file-page header; commit 4302ebf: "chore: fix code")
import os
import json
import random
import re
import base64
from io import BytesIO
import torch
from huggingface_hub import snapshot_download
from diffusers import (
AutoencoderKL,
StableDiffusionXLPipeline,
EulerAncestralDiscreteScheduler,
DPMSolverSDEScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image
# Global constants
MAX_SEED = 12211231 # Maximum seed value (inclusive upper bound for random.randint)
NUM_IMAGES_PER_PROMPT = 1 # Number of images to generate per pipeline call
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" # Opt-in: set USE_TORCH_COMPILE=1 to torch.compile the UNet
# --- Child-Content Filtering Functions ---
# Regex of child-related terms, compiled once at import time and shared by the
# two helpers below. NOTE(review): the pattern has no word boundaries, so it
# also matches inside longer words (e.g. "son" inside "person", "kid" inside
# "kidney") — presumably intentionally aggressive for a safety filter; confirm.
child_related_regex = re.compile(
r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
re.IGNORECASE
)


def remove_child_related_content(prompt: str) -> str:
    """Remove child-related references from *prompt*.

    Matched terms are deleted and any leftover runs of whitespace are
    collapsed to a single space (the original left double spaces where a
    term was cut out of the middle of a sentence).
    """
    cleaned_prompt = child_related_regex.sub('', prompt)
    # Collapse the whitespace gaps left by removed terms (bug fix).
    return re.sub(r'\s+', ' ', cleaned_prompt).strip()


def contains_child_related_content(prompt: str) -> bool:
    """Return True if *prompt* contains any child-related term."""
    return bool(child_related_regex.search(prompt))
# --- Utility Function: Convert PIL Image to Base64 ---
def pil_image_to_base64(img: Image.Image) -> str:
    """Encode *img* as the base64 string of its WEBP (quality 90) bytes."""
    buffer = BytesIO()
    # Force RGB so images with alpha or palette modes encode cleanly.
    rgb_image = img.convert("RGB")
    rgb_image.save(buffer, format="WEBP", quality=90)
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    Implements the two entry points the HF Inference Endpoints runtime
    expects: ``__init__`` (configuration + model loading) and ``__call__``
    (per-request inference). For HF Inference Endpoints, only this class
    is needed.
    """

    def __init__(self, path="", config=None):
        """
        Initialize the handler with model path and configuration.

        Args:
            path (str): Path to the model directory (supplied by HF
                Inference Endpoints).
            config (dict, optional): Handler configuration. When omitted,
                the handler falls back to reading ``app.conf`` (JSON)
                from ``path``.
        """
        try:
            if config:
                # Use config provided by HF Inference Endpoints.
                self.cfg = config
            else:
                # Fallback: read the JSON config shipped next to the model.
                config_path = os.path.join(path, "app.conf") if path else "app.conf"
                with open(config_path, "r") as f:
                    self.cfg = json.load(f)
            print("Configuration loaded successfully")
        except Exception as e:
            # Best-effort: an empty config still lets the defaults below apply,
            # though model loading will fail without "model_id".
            print(f"Error loading configuration: {e}")
            self.cfg = {}
        print("Loading the model pipeline...")
        self.pipe = self._load_pipeline_and_scheduler()
        print("Model loaded successfully!")

    def _load_pipeline_and_scheduler(self):
        """Download the SDXL checkpoint and build the diffusers pipeline.

        Returns:
            StableDiffusionXLPipeline: fp16 pipeline on CUDA with the
            configured scheduler installed.
        """
        clip_skip = self.cfg.get("clip_skip", 0)
        # Download model files from the Hugging Face Hub.
        ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])
        # Load the VAE separately so its fp16 dtype is explicit.
        vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
        pipe = StableDiffusionXLPipeline.from_pretrained(
            ckpt_dir,
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=self.cfg.get("use_safetensors", True)
        )
        pipe = pipe.to("cuda")
        # Use the memory-efficient SDPA attention processor.
        pipe.unet.set_attn_processor(AttnProcessor2_0())
        # Available samplers, built from the pipeline's default scheduler config.
        samplers = {
            "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
            "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
        }
        # Bug fix: the original used samplers.get(...), which silently set
        # pipe.scheduler to None for an unknown sampler name. Fall back to the
        # default sampler instead and log the substitution.
        sampler_name = self.cfg.get("sampler", "DPM++ SDE Karras")
        if sampler_name not in samplers:
            print(f"Unknown sampler '{sampler_name}', falling back to 'DPM++ SDE Karras'")
            sampler_name = "DPM++ SDE Karras"
        pipe.scheduler = samplers[sampler_name]
        # NOTE(review): shrinking num_hidden_layers approximates clip-skip;
        # verify this actually takes effect for the diffusers version in use —
        # the encoder weights are already loaded at this point.
        if clip_skip > 0:
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
        # Optionally compile the UNet (opt-in via USE_TORCH_COMPILE=1).
        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("Model Compiled!")
        return pipe

    def __call__(self, data):
        """
        Process one inference request.

        Args:
            data: Request payload — either a dict (typical HF format with
                "inputs" and optional "parameters") or a JSON string.

        Returns:
            list: ``[{"generated_image": <base64 WEBP>, "seed": <int>}]``
                on success.
            dict: ``{"error": <message>}`` on any failure.
        """
        # Validate that the model is loaded.
        if not hasattr(self, 'pipe') or self.pipe is None:
            return {"error": "Model not loaded. Please check initialization logs."}
        # Parse the request payload; accept a raw JSON string body too.
        try:
            if isinstance(data, dict):
                payload = data
            else:
                payload = json.loads(data)
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}
        # HF Inference Endpoints format: {"inputs": "prompt", "parameters": {...}}
        parameters = {}
        if "parameters" in payload and isinstance(payload["parameters"], dict):
            parameters = payload["parameters"]
        # Accept "inputs" (HF convention) or "prompt" (compatibility alias).
        prompt_text = payload.get("inputs", "")
        if not prompt_text:
            prompt_text = payload.get("prompt", "")
        if not prompt_text:
            return {"error": "No prompt provided. Please include 'inputs' or 'prompt' field."}
        # Strip child-related terms before generation (safety filter).
        if contains_child_related_content(prompt_text):
            prompt_text = remove_child_related_content(prompt_text)
        # Inject the user prompt into the configured template.
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)
        # Resolution order for tunables: parameters > payload > config > default.
        negative_prompt = parameters.get("negative_prompt", payload.get("negative_prompt", self.cfg.get("negative_prompt", "")))
        # Output dimensions come from config only (default 1024x768).
        width = int(self.cfg.get("width", 1024))
        height = int(self.cfg.get("height", 768))
        inference_steps = int(parameters.get("inference_steps", payload.get("inference_steps", self.cfg.get("inference_steps", 30))))
        guidance_scale = float(parameters.get("guidance_scale", payload.get("guidance_scale", self.cfg.get("guidance_scale", 7))))
        # Use the caller-supplied seed when given so results are reproducible.
        seed = int(parameters.get("seed", payload.get("seed", random.randint(0, MAX_SEED))))
        generator = torch.Generator(self.pipe.device).manual_seed(seed)
        try:
            outputs = self.pipe(
                prompt=combined_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=inference_steps,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                output_type="pil"
            )
            # Convert the first generated image to base64 for the response.
            img_base64 = pil_image_to_base64(outputs.images[0])
            return [{"generated_image": img_base64, "seed": seed}]
        except Exception as e:
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}
# For local testing without HF Inference Endpoints
if __name__ == "__main__":
    import argparse
    import uvicorn
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    # Command-line options for the local server.
    cli = argparse.ArgumentParser(description="Run the text-to-image API locally")
    cli.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    options = cli.parse_args()

    # Build the FastAPI app and load the model once at startup.
    app = FastAPI(title="Text-to-Image API with Content Filtering")
    handler = EndpointHandler()

    @app.get("/")
    async def read_root():
        """Health check endpoint."""
        return {"status": "ok", "message": "Text-to-Image API is running"}

    @app.post("/")
    async def generate_image(request: Request):
        """Main inference endpoint."""
        try:
            body = await request.json()
            result = handler(body)
            # A dict result from the handler always signals an error;
            # success is a list, for which this membership test is False.
            if "error" in result:
                return JSONResponse(status_code=500, content={"error": result["error"]})
            return result
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to process request: {str(e)}"}
            )

    # Start serving.
    print(f"Starting server on http://{options.host}:{options.port}")
    uvicorn.run(app, host=options.host, port=options.port)