msgxai committed on
Commit
8b41055
·
1 Parent(s): 60b3314

chore: main-backup

Browse files
Files changed (5) hide show
  1. README.md +75 -4
  2. app.conf +11 -0
  3. app.py +27 -2
  4. requirements.txt +7 -0
  5. src/handler.py +194 -0
README.md CHANGED
@@ -1,12 +1,83 @@
1
  ---
2
- title: Msgxai Hg Api
3
- emoji: 🦀
4
  colorFrom: yellow
5
  colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
- short_description: msgxai backend api
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Msgxai Hugging Face Inference API
3
+ emoji: 🖼️
4
  colorFrom: yellow
5
  colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
+ short_description: Stable Diffusion XL image generation API
10
  ---
11
 
12
+ # Msgxai Hugging Face Inference API
13
+
14
+ A custom Hugging Face Inference Endpoint for Stable Diffusion XL image generation.
15
+
16
+ ## Configuration
17
+
18
+ The API is configured through the `app.conf` JSON file with the following parameters:
19
+
20
+ ```json
21
+ {
22
+ "model_id": "model-repo-id", // The Hugging Face model repository ID
23
+ "name": "your-model-name", // A name for your model (optional)
24
+ "prompt": "{prompt}", // Prompt template with {prompt} placeholder
25
+ "negative_prompt": "...", // Default negative prompt
26
+ "width": 1024, // Default image width
27
+ "height": 768, // Default image height
28
+ "inference_steps": 30, // Default number of inference steps
29
+ "guidance_scale": 7, // Default guidance scale
30
+ "use_safetensors": true, // Whether to use safetensors
31
+ "clip_skip": 0 // Optional CLIP skip value (0 = disabled)
32
+ }
33
+ ```
34
+
35
+ ## API Usage
36
+
37
+ ### Health Check
38
+ ```
39
+ GET /
40
+ ```
41
+ Returns: `{"status": "healthy"}`
42
+
43
+ ### Generate Image
44
+ ```
45
+ POST /predict
46
+ ```
47
+
48
+ Request Body:
49
+ ```json
50
+ {
51
+ "prompt": "your image prompt here",
52
+ "negative_prompt": "optional negative prompt",
53
+ "width": 1024,
54
+ "height": 768,
55
+ "inference_steps": 30,
56
+ "guidance_scale": 7,
57
+ "seed": 42
58
+ }
59
+ ```
60
+
61
+ Response:
62
+ ```json
63
+ {
64
+ "image_base64": "base64-encoded-image-data",
65
+ "seed": 42
66
+ }
67
+ ```
68
+
69
+ Note: All parameters except `prompt` are optional and will use defaults from `app.conf` if not provided.
70
+
71
+ ## Deployment
72
+
73
+ 1. Configure your `app.conf` file with desired model and parameters
74
+ 2. Ensure all dependencies are in `requirements.txt`
75
+ 3. Deploy to Hugging Face Inference Endpoints
76
+
77
+ ## Content Filtering
78
+
79
+ The API includes built-in filtering for child-related content in prompts.
80
+
81
+ ## Environment Variables
82
+
83
+ - `USE_TORCH_COMPILE`: Set to "1" to enable torch compilation (default: "0")
app.conf ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_id": "John6666/wai-ani-hentai-pony-v3-sdxl",
3
+ "name": "hentai-waianiv6-card",
4
+ "prompt": "score_9, score_8_up, score_7_up,rating_explicit,BREAK, {prompt}",
5
+ "negative_prompt": "source_furry, source_pony, source_cartoon,3d, blurry, incest, beastiality, children, loli, child, kids, teens, text, logo, timestamp, artist name, artist logo, watermark, web address, copyright name, copyright notice, emblem, comic, title, logo, character name, border, patreon username, signature, webpage, company name, caption, labels, comments",
6
+ "width": 1024,
7
+ "height": 768,
8
+ "inference_steps": 30,
9
+ "guidance_scale": 7,
10
+ "use_safetensors": true
11
+ }
app.py CHANGED
@@ -1,7 +1,32 @@
1
- from fastapi import FastAPI
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, Request
from src.handler import EndpointHandler
import asyncio

app = FastAPI()

# Initialize the handler once at import time; loading the SDXL pipeline is
# expensive, so it must never happen per-request.
handler = EndpointHandler()


@app.get("/")
def greet_json():
    """Simple health check endpoint."""
    return {"status": "healthy"}


@app.post("/predict")
async def predict(request: Request):
    """
    Main prediction endpoint that processes image generation requests.

    Args:
        request (Request): The FastAPI request object

    Returns:
        dict: The generated image as base64 and other metadata
    """
    # Parse the request data
    data = await request.json()

    # BUGFIX: the handler is synchronous, GPU-bound work. Calling it directly
    # inside an async endpoint blocks the event loop, stalling health checks
    # and all concurrent requests for the full generation time. Run it in the
    # default thread-pool executor instead.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, handler, data)

    return result
requirements.txt CHANGED
@@ -1,2 +1,9 @@
1
  fastapi
2
  uvicorn[standard]
 
 
 
 
 
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ diffusers
4
+ transformers
5
+ torch
6
+ accelerate
7
+ huggingface_hub
8
+ pillow
9
+ safetensors
src/handler.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import re
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ import torch
9
+ from huggingface_hub import snapshot_download
10
+ from diffusers import (
11
+ AutoencoderKL,
12
+ StableDiffusionXLPipeline,
13
+ EulerAncestralDiscreteScheduler,
14
+ DPMSolverSDEScheduler
15
+ )
16
+ from diffusers.models.attention_processor import AttnProcessor2_0
17
+ from PIL import Image
18
+
19
+ # Global constants
20
+ MAX_SEED = 12211231 # Maximum seed value for random generator
21
+ NUM_IMAGES_PER_PROMPT = 1 # Number of images to generate per prompt
22
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" # Flag to enable torch compilation
23
+
24
+ # --- Child-Content Filtering Functions ---
25
+ child_related_regex = re.compile(
26
+ r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
27
+ r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
28
+ r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
29
+ re.IGNORECASE
30
+ )
31
+
32
+ def remove_child_related_content(prompt: str) -> str:
33
+ """Remove any child-related references from the prompt."""
34
+ # Filter out child-related words/phrases using regex
35
+ cleaned_prompt = re.sub(child_related_regex, '', prompt)
36
+ return cleaned_prompt.strip()
37
+
38
+ def contains_child_related_content(prompt: str) -> bool:
39
+ """Check if the prompt contains child-related content."""
40
+ # Use regex to determine if prompt has child-related terms
41
+ return bool(child_related_regex.search(prompt))
42
+
43
# --- Utility Function: Convert PIL Image to Base64 ---
def pil_image_to_base64(img: Image.Image) -> str:
    """Serialize a PIL image as a base64-encoded WEBP string."""
    # Flatten any alpha channel, then write lossy WEBP into a memory buffer.
    buffer = BytesIO()
    rgb_image = img.convert("RGB")
    rgb_image.save(buffer, format="WEBP", quality=90)
    # Encode the raw bytes and return them as ASCII text.
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
51
+
52
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.
    This class follows the HF Inference Endpoints specification.

    The pipeline described by ``app.conf`` is loaded once at construction
    time; each call to the instance generates one image and returns it as a
    base64-encoded WEBP string together with the seed that produced it.
    """

    def __init__(self, path="", config=None):
        """
        Initialize the handler with model path and configurations.

        Args:
            path (str): Path to the model. Not used for this implementation.
            config (dict, optional): Configuration for the handler. Not used for this implementation.
        """
        # Load configuration from app.conf; fall back to an empty dict so the
        # per-request defaults below still apply (model loading will then fail
        # loudly on the missing "model_id").
        try:
            with open("app.conf", "r") as f:
                self.cfg = json.load(f)
        except Exception as e:
            print(f"Error loading configuration: {e}")
            self.cfg = {}

        # Load the model pipeline (downloads weights on first run).
        print("Loading the model pipeline...")
        self.pipe = self._load_pipeline_and_scheduler()
        print("Model loaded successfully!")

    def _load_pipeline_and_scheduler(self):
        """Load the Stable Diffusion XL pipeline and configure its scheduler.

        Returns:
            StableDiffusionXLPipeline: the pipeline, on CUDA in fp16, with the
            configured (or default) scheduler attached.
        """
        # Get clip_skip from configuration, default to 0 (disabled).
        clip_skip = self.cfg.get("clip_skip", 0)

        # Download model files from the Hugging Face Hub (cached locally
        # after the first call).
        ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])

        # Load the VAE (decodes latents to pixels) in fp16.
        vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)

        # Load the Stable Diffusion XL pipeline.
        pipe = StableDiffusionXLPipeline.from_pretrained(
            ckpt_dir,
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=self.cfg.get("use_safetensors", True),
            variant="fp16"
        )
        # Move model to GPU.
        pipe = pipe.to("cuda")
        # Use efficient scaled-dot-product attention processor.
        pipe.unet.set_attn_processor(AttnProcessor2_0())

        # Map user-facing sampler names to scheduler instances.
        samplers = {
            "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
            "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
        }
        # BUGFIX: a bare samplers.get() on an unrecognized sampler name used
        # to set pipe.scheduler to None; explicitly fall back to the
        # "DPM++ SDE Karras" default instead.
        requested = self.cfg.get("sampler", "DPM++ SDE Karras")
        pipe.scheduler = samplers.get(requested, samplers["DPM++ SDE Karras"])

        # Adjust the text encoder layers if needed using clip_skip.
        # NOTE(review): mutating config.num_hidden_layers after the encoder is
        # already built may not actually skip layers at inference time — verify.
        if clip_skip > 0:
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)

        # Compile the UNet if the environment variable is set.
        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("Model Compiled!")

        return pipe

    def __call__(self, data):
        """
        Process one inference request.

        Args:
            data: dict payload or JSON string. Requires a non-empty "prompt";
                optional keys (negative_prompt, width, height,
                inference_steps, guidance_scale, seed) override app.conf.

        Returns:
            dict: {"image_base64": ..., "seed": ...} on success, or
            {"error": ...} on any failure (never raises to the caller).
        """
        # Validate that the model is loaded.
        if not self.pipe:
            return {"error": "Model not loaded. Please check initialization logs."}

        # Parse the request payload (dict passthrough or JSON string).
        try:
            if isinstance(data, dict):
                payload = data
            else:
                # Assuming the request is a JSON string
                payload = json.loads(data)
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}

        # A non-empty prompt is required.
        prompt_text = payload.get("prompt", "")
        if not prompt_text:
            return {"error": "No prompt provided"}

        # Strip child-related terms from the prompt before generation.
        if contains_child_related_content(prompt_text):
            prompt_text = remove_child_related_content(prompt_text)

        # Splice the user prompt into the configured prompt template.
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)
        # Request-level negative_prompt wins over the configured default.
        negative_prompt = payload.get("negative_prompt", self.cfg.get("negative_prompt", ""))

        # Generation parameters: request overrides config overrides built-ins.
        width = int(payload.get("width", self.cfg.get("width", 1024)))
        height = int(payload.get("height", self.cfg.get("height", 768)))
        inference_steps = int(payload.get("inference_steps", self.cfg.get("inference_steps", 30)))
        guidance_scale = float(payload.get("guidance_scale", self.cfg.get("guidance_scale", 7)))

        # Use the provided seed, or draw a random one, so the result is
        # reproducible from the returned seed.
        seed = int(payload.get("seed", random.randint(0, MAX_SEED)))
        generator = torch.Generator(self.pipe.device).manual_seed(seed)

        try:
            # Generate the image using the pipeline.
            outputs = self.pipe(
                prompt=combined_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=inference_steps,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                output_type="pil"
            )
            # Convert the first generated image to base64.
            img_base64 = pil_image_to_base64(outputs.images[0])

            # Return the response.
            return {"image_base64": img_base64, "seed": seed}

        except Exception as e:
            # Log the error and return an error response instead of raising.
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}