|
|
import os |
|
|
import json |
|
|
import random |
|
|
import re |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
|
|
|
import torch |
|
|
from huggingface_hub import snapshot_download |
|
|
from diffusers import ( |
|
|
AutoencoderKL, |
|
|
StableDiffusionXLPipeline, |
|
|
EulerAncestralDiscreteScheduler, |
|
|
DPMSolverSDEScheduler |
|
|
) |
|
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Inclusive upper bound for randomly drawn generation seeds.
MAX_SEED = 12211231

# How many images the pipeline produces per request.
NUM_IMAGES_PER_PROMPT = 1

# Opt-in flag for torch.compile acceleration of the UNet ("1" enables it).
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
|
|
|
|
|
|
|
|
# Case-insensitive pattern covering minor-related nouns, age phrases such as
# "12 years old" (ages 1-17), and size adjectives followed by a person noun.
# NOTE(review): the pattern has no word boundaries, so it also matches inside
# longer words (e.g. "kid" within "kidney") — presumably intentional
# over-matching for safety; confirm.
child_related_regex = re.compile(
    r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
    r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
    r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
    re.IGNORECASE
)


def remove_child_related_content(prompt: str) -> str:
    """Strip every match of the child-content pattern and trim the result."""
    return child_related_regex.sub('', prompt).strip()


def contains_child_related_content(prompt: str) -> bool:
    """Return True when the prompt matches the child-content pattern."""
    return child_related_regex.search(prompt) is not None
|
|
|
|
|
|
|
|
def pil_image_to_base64(img: Image.Image) -> str:
    """Serialize *img* as lossy WEBP (quality 90) and return it base64-encoded."""
    buffer = BytesIO()
    # Force RGB so modes like RGBA/P encode cleanly to WEBP.
    rgb = img.convert("RGB")
    rgb.save(buffer, format="WEBP", quality=90)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
|
|
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    This class follows the HF Inference Endpoints specification.

    For Hugging Face Inference Endpoints, only this class is needed.
    It provides both the initialization (__init__) and inference (__call__) methods
    required by the Hugging Face Inference API.
    """

    def __init__(self, path="", config=None):
        """
        Initialize the handler with model path and configurations.

        Args:
            path (str): Path to the model directory (used by HF Inference Endpoints).
            config (dict, optional): Configuration for the handler, passed by HF Inference Endpoints.
        """
        try:
            if config:
                # An explicitly supplied config dict wins over the on-disk file.
                self.cfg = config
            else:
                # Fall back to a JSON config file next to the model files.
                config_path = os.path.join(path, "app.conf") if path else "app.conf"
                with open(config_path, "r") as f:
                    self.cfg = json.load(f)
            print("Configuration loaded successfully")
        except Exception as e:
            # Best-effort: keep going with an empty config so the error
            # surfaces clearly at request time rather than crashing on import.
            print(f"Error loading configuration: {e}")
            self.cfg = {}

        print("Loading the model pipeline...")
        self.pipe = self._load_pipeline_and_scheduler()
        print("Model loaded successfully!")

    def _load_pipeline_and_scheduler(self):
        """Download the checkpoint and build the SDXL pipeline and scheduler.

        Returns:
            StableDiffusionXLPipeline: fp16 pipeline moved to CUDA with the
            configured sampler installed as its scheduler.
        """
        clip_skip = self.cfg.get("clip_skip", 0)

        # Download (or reuse the cached copy of) the full model repository.
        ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])

        # Load the VAE separately so it can be passed into the pipeline in fp16.
        vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)

        pipe = StableDiffusionXLPipeline.from_pretrained(
            ckpt_dir,
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=self.cfg.get("use_safetensors", True)
        )
        pipe = pipe.to("cuda")

        # Use PyTorch 2.x scaled-dot-product attention kernels.
        pipe.unet.set_attn_processor(AttnProcessor2_0())

        samplers = {
            "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
            "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
        }

        # FIX: fall back to the default sampler when the configured name is
        # unknown; previously samplers.get() could return None and silently
        # leave the pipeline without a scheduler.
        sampler_name = self.cfg.get("sampler", "DPM++ SDE Karras")
        pipe.scheduler = samplers.get(sampler_name, samplers["DPM++ SDE Karras"])

        if clip_skip > 0:
            # Emulate "clip skip" by trimming the top text-encoder layers.
            # NOTE(review): mutates the encoder config in place — confirm this
            # matches the intended clip-skip convention for SDXL's dual encoders.
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)

        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("Model Compiled!")

        return pipe

    def __call__(self, data):
        """
        Process the inference request.
        This is called for each inference request by the Hugging Face Inference API.

        Args:
            data: The input data for the inference request.
                For HF Inference Endpoints, this is typically a dict with an
                "inputs" field; a JSON string is also accepted.

        Returns:
            list: A list containing the generated image as base64 string and seed
                (HF Inference Endpoints output format), or a dict with an
                "error" key on failure.
        """
        if not hasattr(self, 'pipe') or self.pipe is None:
            return {"error": "Model not loaded. Please check initialization logs."}

        # Accept either a ready dict or a raw JSON string body.
        try:
            if isinstance(data, dict):
                payload = data
            else:
                payload = json.loads(data)
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}

        # Optional per-request overrides nested under "parameters".
        parameters = {}
        if "parameters" in payload and isinstance(payload["parameters"], dict):
            parameters = payload["parameters"]

        # The prompt may arrive under "inputs" (HF convention) or "prompt".
        prompt_text = payload.get("inputs", "")
        if not prompt_text:
            prompt_text = payload.get("prompt", "")

        if not prompt_text:
            return {"error": "No prompt provided. Please include 'inputs' or 'prompt' field."}

        # Safety filter: strip child-related terms before generation.
        if contains_child_related_content(prompt_text):
            prompt_text = remove_child_related_content(prompt_text)

        # Splice the user prompt into the configured prompt template.
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)

        # Precedence for each knob: request parameters > top-level payload > config default.
        # FIX: malformed numeric values previously raised out of __call__
        # uncaught; report them as a structured error like other failures.
        try:
            negative_prompt = parameters.get("negative_prompt", payload.get("negative_prompt", self.cfg.get("negative_prompt", "")))
            width = int(self.cfg.get("width", 1024))
            height = int(self.cfg.get("height", 768))
            inference_steps = int(parameters.get("inference_steps", payload.get("inference_steps", self.cfg.get("inference_steps", 30))))
            guidance_scale = float(parameters.get("guidance_scale", payload.get("guidance_scale", self.cfg.get("guidance_scale", 7))))
            seed = int(parameters.get("seed", payload.get("seed", random.randint(0, MAX_SEED))))
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid generation parameter: {str(e)}"}

        # Seeded generator makes results reproducible for a given seed.
        generator = torch.Generator(self.pipe.device).manual_seed(seed)

        try:
            outputs = self.pipe(
                prompt=combined_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=inference_steps,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                output_type="pil"
            )

            img_base64 = pil_image_to_base64(outputs.images[0])

            return [{"generated_image": img_base64, "seed": seed}]

        except Exception as e:
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import argparse |
|
|
import uvicorn |
|
|
from fastapi import FastAPI, Request |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Run the text-to-image API locally") |
|
|
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") |
|
|
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on") |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
app = FastAPI(title="Text-to-Image API with Content Filtering") |
|
|
|
|
|
|
|
|
handler = EndpointHandler() |
|
|
|
|
|
@app.get("/") |
|
|
async def read_root(): |
|
|
"""Health check endpoint.""" |
|
|
return {"status": "ok", "message": "Text-to-Image API is running"} |
|
|
|
|
|
@app.post("/") |
|
|
async def generate_image(request: Request): |
|
|
"""Main inference endpoint.""" |
|
|
try: |
|
|
body = await request.json() |
|
|
result = handler(body) |
|
|
|
|
|
if "error" in result: |
|
|
return JSONResponse(status_code=500, content={"error": result["error"]}) |
|
|
|
|
|
return result |
|
|
except Exception as e: |
|
|
return JSONResponse( |
|
|
status_code=500, |
|
|
content={"error": f"Failed to process request: {str(e)}"} |
|
|
) |
|
|
|
|
|
|
|
|
print(f"Starting server on http://{args.host}:{args.port}") |
|
|
uvicorn.run(app, host=args.host, port=args.port) |