chore: main-backup
Browse files- README.md +75 -4
- app.conf +11 -0
- app.py +27 -2
- requirements.txt +7 -0
- src/handler.py +194 -0
README.md
CHANGED
|
@@ -1,12 +1,83 @@
|
|
| 1 |
---
|
| 2 |
-
title: Msgxai
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
| 9 |
-
short_description:
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Msgxai Hugging Face Inference API
|
| 3 |
+
emoji: 🖼️
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
| 9 |
+
short_description: Stable Diffusion XL image generation API
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Msgxai Hugging Face Inference API
|
| 13 |
+
|
| 14 |
+
A custom Hugging Face Inference Endpoint for Stable Diffusion XL image generation.
|
| 15 |
+
|
| 16 |
+
## Configuration
|
| 17 |
+
|
| 18 |
+
The API is configured through the `app.conf` JSON file with the following parameters:
|
| 19 |
+
|
| 20 |
+
```json
|
| 21 |
+
{
|
| 22 |
+
"model_id": "model-repo-id", // The Hugging Face model repository ID
|
| 23 |
+
"name": "your-model-name", // A name for your model (optional)
|
| 24 |
+
"prompt": "{prompt}", // Prompt template with {prompt} placeholder
|
| 25 |
+
"negative_prompt": "...", // Default negative prompt
|
| 26 |
+
"width": 1024, // Default image width
|
| 27 |
+
"height": 768, // Default image height
|
| 28 |
+
"inference_steps": 30, // Default number of inference steps
|
| 29 |
+
"guidance_scale": 7, // Default guidance scale
|
| 30 |
+
"use_safetensors": true, // Whether to use safetensors
|
| 31 |
+
"clip_skip": 0 // Optional CLIP skip value (0 = disabled)
|
| 32 |
+
}
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## API Usage
|
| 36 |
+
|
| 37 |
+
### Health Check
|
| 38 |
+
```
|
| 39 |
+
GET /
|
| 40 |
+
```
|
| 41 |
+
Returns: `{"status": "healthy"}`
|
| 42 |
+
|
| 43 |
+
### Generate Image
|
| 44 |
+
```
|
| 45 |
+
POST /predict
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Request Body:
|
| 49 |
+
```json
|
| 50 |
+
{
|
| 51 |
+
"prompt": "your image prompt here",
|
| 52 |
+
"negative_prompt": "optional negative prompt",
|
| 53 |
+
"width": 1024,
|
| 54 |
+
"height": 768,
|
| 55 |
+
"inference_steps": 30,
|
| 56 |
+
"guidance_scale": 7,
|
| 57 |
+
"seed": 42
|
| 58 |
+
}
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
Response:
|
| 62 |
+
```json
|
| 63 |
+
{
|
| 64 |
+
"image_base64": "base64-encoded-image-data",
|
| 65 |
+
"seed": 42
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Note: All parameters except `prompt` are optional and will use defaults from `app.conf` if not provided.
|
| 70 |
+
|
| 71 |
+
## Deployment
|
| 72 |
+
|
| 73 |
+
1. Configure your `app.conf` file with desired model and parameters
|
| 74 |
+
2. Ensure all dependencies are in `requirements.txt`
|
| 75 |
+
3. Deploy to Hugging Face Inference Endpoints
|
| 76 |
+
|
| 77 |
+
## Content Filtering
|
| 78 |
+
|
| 79 |
+
The API includes built-in filtering for child-related content in prompts.
|
| 80 |
+
|
| 81 |
+
## Environment Variables
|
| 82 |
+
|
| 83 |
+
- `USE_TORCH_COMPILE`: Set to "1" to enable torch compilation (default: "0")
|
app.conf
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_id": "John6666/wai-ani-hentai-pony-v3-sdxl",
|
| 3 |
+
"name": "hentai-waianiv6-card",
|
| 4 |
+
"prompt": "score_9, score_8_up, score_7_up,rating_explicit,BREAK, {prompt}",
|
| 5 |
+
"negative_prompt": "source_furry, source_pony, source_cartoon,3d, blurry, incest, beastiality, children, loli, child, kids, teens, text, logo, timestamp, artist name, artist logo, watermark, web address, copyright name, copyright notice, emblem, comic, title, logo, character name, border, patreon username, signature, webpage, company name, caption, labels, comments",
|
| 6 |
+
"width": 1024,
|
| 7 |
+
"height": 768,
|
| 8 |
+
"inference_steps": 30,
|
| 9 |
+
"guidance_scale": 7,
|
| 10 |
+
"use_safetensors": true
|
| 11 |
+
}
|
app.py
CHANGED
|
@@ -1,7 +1,32 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
|
|
|
|
|
|
| 2 |
|
| 3 |
app = FastAPI()
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
@app.get("/")
|
| 6 |
def greet_json():
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request
|
| 2 |
+
from src.handler import EndpointHandler
|
| 3 |
+
import json
|
| 4 |
|
| 5 |
app = FastAPI()
|
| 6 |
|
| 7 |
+
# Initialize the handler
|
| 8 |
+
handler = EndpointHandler()
|
| 9 |
+
|
| 10 |
@app.get("/")
|
| 11 |
def greet_json():
|
| 12 |
+
"""Simple health check endpoint."""
|
| 13 |
+
return {"status": "healthy"}
|
| 14 |
+
|
| 15 |
+
@app.post("/predict")
|
| 16 |
+
async def predict(request: Request):
|
| 17 |
+
"""
|
| 18 |
+
Main prediction endpoint that processes image generation requests.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
request (Request): The FastAPI request object
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
dict: The generated image as base64 and other metadata
|
| 25 |
+
"""
|
| 26 |
+
# Parse the request data
|
| 27 |
+
data = await request.json()
|
| 28 |
+
|
| 29 |
+
# Process the request using our handler
|
| 30 |
+
result = handler(data)
|
| 31 |
+
|
| 32 |
+
return result
|
requirements.txt
CHANGED
|
@@ -1,2 +1,9 @@
|
|
| 1 |
fastapi
|
| 2 |
uvicorn[standard]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
fastapi
|
| 2 |
uvicorn[standard]
|
| 3 |
+
diffusers
|
| 4 |
+
transformers
|
| 5 |
+
torch
|
| 6 |
+
accelerate
|
| 7 |
+
huggingface_hub
|
| 8 |
+
pillow
|
| 9 |
+
safetensors
|
src/handler.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
import base64
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from huggingface_hub import snapshot_download
|
| 10 |
+
from diffusers import (
|
| 11 |
+
AutoencoderKL,
|
| 12 |
+
StableDiffusionXLPipeline,
|
| 13 |
+
EulerAncestralDiscreteScheduler,
|
| 14 |
+
DPMSolverSDEScheduler
|
| 15 |
+
)
|
| 16 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 17 |
+
from PIL import Image
|
| 18 |
+
|
| 19 |
+
# Global constants
|
| 20 |
+
MAX_SEED = 12211231 # Maximum seed value for random generator
|
| 21 |
+
NUM_IMAGES_PER_PROMPT = 1 # Number of images to generate per prompt
|
| 22 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" # Flag to enable torch compilation
|
| 23 |
+
|
| 24 |
+
# --- Child-Content Filtering Functions ---
|
| 25 |
+
child_related_regex = re.compile(
|
| 26 |
+
r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
|
| 27 |
+
r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
|
| 28 |
+
r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
|
| 29 |
+
re.IGNORECASE
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
def remove_child_related_content(prompt: str) -> str:
|
| 33 |
+
"""Remove any child-related references from the prompt."""
|
| 34 |
+
# Filter out child-related words/phrases using regex
|
| 35 |
+
cleaned_prompt = re.sub(child_related_regex, '', prompt)
|
| 36 |
+
return cleaned_prompt.strip()
|
| 37 |
+
|
| 38 |
+
def contains_child_related_content(prompt: str) -> bool:
|
| 39 |
+
"""Check if the prompt contains child-related content."""
|
| 40 |
+
# Use regex to determine if prompt has child-related terms
|
| 41 |
+
return bool(child_related_regex.search(prompt))
|
| 42 |
+
|
| 43 |
+
# --- Utility Function: Convert PIL Image to Base64 ---
|
| 44 |
+
def pil_image_to_base64(img: Image.Image) -> str:
|
| 45 |
+
"""Convert a PIL Image to base64 encoded string."""
|
| 46 |
+
# Create a BytesIO buffer and save the image to it
|
| 47 |
+
buffered = BytesIO()
|
| 48 |
+
img.convert("RGB").save(buffered, format="WEBP", quality=90)
|
| 49 |
+
# Convert buffer to base64 string
|
| 50 |
+
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 51 |
+
|
| 52 |
+
class EndpointHandler:
|
| 53 |
+
"""
|
| 54 |
+
Custom handler for Hugging Face Inference Endpoints.
|
| 55 |
+
This class follows the HF Inference Endpoints specification.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self, path="", config=None):
|
| 59 |
+
"""
|
| 60 |
+
Initialize the handler with model path and configurations.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
path (str): Path to the model. Not used for this implementation.
|
| 64 |
+
config (dict, optional): Configuration for the handler. Not used for this implementation.
|
| 65 |
+
"""
|
| 66 |
+
# Load configuration from app.conf
|
| 67 |
+
try:
|
| 68 |
+
with open("app.conf", "r") as f:
|
| 69 |
+
self.cfg = json.load(f)
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Error loading configuration: {e}")
|
| 72 |
+
self.cfg = {}
|
| 73 |
+
|
| 74 |
+
# Load the model pipeline
|
| 75 |
+
print("Loading the model pipeline...")
|
| 76 |
+
self.pipe = self._load_pipeline_and_scheduler()
|
| 77 |
+
print("Model loaded successfully!")
|
| 78 |
+
|
| 79 |
+
def _load_pipeline_and_scheduler(self):
|
| 80 |
+
"""Load the Stable Diffusion pipeline and scheduler."""
|
| 81 |
+
# Get clip_skip from configuration, default to 0
|
| 82 |
+
clip_skip = self.cfg.get("clip_skip", 0)
|
| 83 |
+
|
| 84 |
+
# Download model files from Hugging Face Hub
|
| 85 |
+
ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])
|
| 86 |
+
|
| 87 |
+
# Load the VAE model (for decoding latents)
|
| 88 |
+
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
|
| 89 |
+
|
| 90 |
+
# Load the Stable Diffusion XL pipeline
|
| 91 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 92 |
+
ckpt_dir,
|
| 93 |
+
vae=vae,
|
| 94 |
+
torch_dtype=torch.float16,
|
| 95 |
+
use_safetensors=self.cfg.get("use_safetensors", True),
|
| 96 |
+
variant="fp16"
|
| 97 |
+
)
|
| 98 |
+
# Move model to GPU
|
| 99 |
+
pipe = pipe.to("cuda")
|
| 100 |
+
# Use efficient attention processor
|
| 101 |
+
pipe.unet.set_attn_processor(AttnProcessor2_0())
|
| 102 |
+
|
| 103 |
+
# Set up samplers/schedulers based on configuration
|
| 104 |
+
samplers = {
|
| 105 |
+
"Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
|
| 106 |
+
"DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
|
| 107 |
+
}
|
| 108 |
+
# Default to "DPM++ SDE Karras" if not specified
|
| 109 |
+
pipe.scheduler = samplers.get(self.cfg.get("sampler", "DPM++ SDE Karras"))
|
| 110 |
+
|
| 111 |
+
# Adjust the text encoder layers if needed using clip_skip
|
| 112 |
+
if clip_skip > 0:
|
| 113 |
+
pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
|
| 114 |
+
|
| 115 |
+
# Compile model if environment variable is set
|
| 116 |
+
if USE_TORCH_COMPILE:
|
| 117 |
+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 118 |
+
print("Model Compiled!")
|
| 119 |
+
|
| 120 |
+
return pipe
|
| 121 |
+
|
| 122 |
+
def __call__(self, data):
|
| 123 |
+
"""
|
| 124 |
+
Process the inference request.
|
| 125 |
+
This is called for each inference request.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
data: The input data for the inference request
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
dict: The result of the inference
|
| 132 |
+
"""
|
| 133 |
+
# Validate that the model is loaded
|
| 134 |
+
if not self.pipe:
|
| 135 |
+
return {"error": "Model not loaded. Please check initialization logs."}
|
| 136 |
+
|
| 137 |
+
# Parse the request payload
|
| 138 |
+
try:
|
| 139 |
+
if isinstance(data, dict):
|
| 140 |
+
payload = data
|
| 141 |
+
else:
|
| 142 |
+
# Assuming the request is a JSON string
|
| 143 |
+
payload = json.loads(data)
|
| 144 |
+
except Exception as e:
|
| 145 |
+
return {"error": f"Failed to parse request data: {str(e)}"}
|
| 146 |
+
|
| 147 |
+
# Get the prompt from the payload
|
| 148 |
+
prompt_text = payload.get("prompt", "")
|
| 149 |
+
if not prompt_text:
|
| 150 |
+
return {"error": "No prompt provided"}
|
| 151 |
+
|
| 152 |
+
# Apply child-content filtering to the prompt
|
| 153 |
+
if contains_child_related_content(prompt_text):
|
| 154 |
+
prompt_text = remove_child_related_content(prompt_text)
|
| 155 |
+
|
| 156 |
+
# Replace placeholder in the prompt template from config
|
| 157 |
+
combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)
|
| 158 |
+
# Use negative_prompt if provided; otherwise, default to config
|
| 159 |
+
negative_prompt = payload.get("negative_prompt", self.cfg.get("negative_prompt", ""))
|
| 160 |
+
|
| 161 |
+
# Get parameters from config or override with request params
|
| 162 |
+
width = int(payload.get("width", self.cfg.get("width", 1024)))
|
| 163 |
+
height = int(payload.get("height", self.cfg.get("height", 768)))
|
| 164 |
+
inference_steps = int(payload.get("inference_steps", self.cfg.get("inference_steps", 30)))
|
| 165 |
+
guidance_scale = float(payload.get("guidance_scale", self.cfg.get("guidance_scale", 7)))
|
| 166 |
+
|
| 167 |
+
# Use provided seed or generate a random one
|
| 168 |
+
seed = int(payload.get("seed", random.randint(0, MAX_SEED)))
|
| 169 |
+
generator = torch.Generator(self.pipe.device).manual_seed(seed)
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
# Generate the image using the pipeline
|
| 173 |
+
outputs = self.pipe(
|
| 174 |
+
prompt=combined_prompt,
|
| 175 |
+
negative_prompt=negative_prompt,
|
| 176 |
+
width=width,
|
| 177 |
+
height=height,
|
| 178 |
+
guidance_scale=guidance_scale,
|
| 179 |
+
num_inference_steps=inference_steps,
|
| 180 |
+
generator=generator,
|
| 181 |
+
num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
|
| 182 |
+
output_type="pil"
|
| 183 |
+
)
|
| 184 |
+
# Convert the first generated image to base64
|
| 185 |
+
img_base64 = pil_image_to_base64(outputs.images[0])
|
| 186 |
+
|
| 187 |
+
# Return the response
|
| 188 |
+
return {"image_base64": img_base64, "seed": seed}
|
| 189 |
+
|
| 190 |
+
except Exception as e:
|
| 191 |
+
# Log the error and return an error response
|
| 192 |
+
error_message = f"Image generation failed: {str(e)}"
|
| 193 |
+
print(error_message)
|
| 194 |
+
return {"error": error_message}
|