msgxai committed
Commit
bfeeedb
·
1 Parent(s): 8b41055

Prepare code for Hugging Face Inference Endpoint deployment

Files changed (12)
  1. .gitattributes +0 -35
  2. .python-version +0 -1
  3. Dockerfile +0 -16
  4. README.md +90 -48
  5. app.conf +4 -4
  6. app.py +0 -32
  7. src/handler.py → handler.py +92 -22
  8. requirements.txt +11 -9
  9. src/Procfile +0 -1
  10. src/app.conf +0 -11
  11. src/app.py +0 -145
  12. src/requirements.txt +0 -7
.gitattributes DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
.python-version DELETED
@@ -1 +0,0 @@
-msgxai-hg-api
Dockerfile DELETED
@@ -1,16 +0,0 @@
-# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
-# you will also find guides on how best to write your Dockerfile
-
-FROM python:3.9
-
-RUN useradd -m -u 1000 user
-USER user
-ENV PATH="/home/user/.local/bin:$PATH"
-
-WORKDIR /app
-
-COPY --chown=user ./requirements.txt requirements.txt
-RUN pip install --no-cache-dir --upgrade -r requirements.txt
-
-COPY --chown=user . /app
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,83 +1,125 @@
----
-title: Msgxai Hugging Face Inference API
-emoji: 🖼️
-colorFrom: yellow
-colorTo: indigo
-sdk: docker
-pinned: false
-license: mit
-short_description: Stable Diffusion XL image generation API
----
+# Hugging Face Inference API for Stable Diffusion XL

-# Msgxai Hugging Face Inference API
+This repository contains a text-to-image generation API designed to be deployed on Hugging Face Inference Endpoints, using Stable Diffusion XL models for image generation.

-A custom Hugging Face Inference Endpoint for Stable Diffusion XL image generation.
+## Features
+
+- Compatible with Hugging Face Inference Endpoints
+- Stable Diffusion XL (SDXL) model for high-quality image generation
+- Content filtering for safe image generation
+- Configurable image dimensions (default: 1024x768)
+- Base64-encoded image output
+- Performance optimizations (torch.compile, attention processors)
+
+## Project Structure
+
+The codebase has been simplified to only use a single file:
+
+- `handler.py`: Contains the `EndpointHandler` class that implements the Hugging Face Inference Endpoints interface. This file also includes a built-in FastAPI server for local development.

 ## Configuration

-The API is configured through the `app.conf` JSON file with the following parameters:
+The service is configured via the `app.conf` JSON file with the following parameters:

 ```json
 {
-  "model_id": "model-repo-id",      // The Hugging Face model repository ID
-  "name": "your-model-name",        // A name for your model (optional)
-  "prompt": "{prompt}",             // Prompt template with {prompt} placeholder
-  "negative_prompt": "...",         // Default negative prompt
-  "width": 1024,                    // Default image width
-  "height": 768,                    // Default image height
-  "inference_steps": 30,            // Default number of inference steps
-  "guidance_scale": 7,              // Default guidance scale
-  "use_safetensors": true,          // Whether to use safetensors
-  "clip_skip": 0                    // Optional CLIP skip value (0 = disabled)
+  "model_id": "your-huggingface-model-id",
+  "prompt": "template with {prompt} placeholder",
+  "negative_prompt": "default negative prompt",
+  "inference_steps": 30,
+  "guidance_scale": 7,
+  "use_safetensors": true,
+  "width": 1024,
+  "height": 768
 }
 ```

 ## API Usage

-### Health Check
-```
-GET /
-```
-Returns: `{"status": "healthy"}`
+### Hugging Face Inference Endpoints Format

-### Generate Image
-```
-POST /predict
-```
+When deployed to Hugging Face Inference Endpoints, the API accepts requests in the following format:

-Request Body:
 ```json
 {
-  "prompt": "your image prompt here",
+  "inputs": "your prompt here",
+  "parameters": {
   "negative_prompt": "optional negative prompt",
-  "width": 1024,
-  "height": 768,
+  "seed": 12345,
   "inference_steps": 30,
   "guidance_scale": 7,
-  "seed": 42
+  "width": 1024,
+  "height": 768
+  }
 }
 ```

-Response:
-```json
-{
-  "image_base64": "base64-encoded-image-data",
-  "seed": 42
-}
-```
+Response format:
+```json
+[
+  {
+    "generated_image": "base64-encoded-image",
+    "seed": 12345
+  }
+]
+```

-Note: All parameters except `prompt` are optional and will use defaults from `app.conf` if not provided.
+### Local Development Format
+
+When running locally, you can use the same format as above, or a simplified format:
+
+```json
+{
+  "prompt": "your prompt here",
+  "negative_prompt": "optional negative prompt",
+  "seed": 12345,
+  "inference_steps": 30,
+  "guidance_scale": 7,
+  "width": 1024,
+  "height": 768
+}
+```
+
+Response format from the local server:
+```json
+[
+  {
+    "generated_image": "base64-encoded-image",
+    "seed": 12345
+  }
+]
+```

-## Deployment
+## Deployment on Hugging Face Inference Endpoints

-1. Configure your `app.conf` file with desired model and parameters
-2. Ensure all dependencies are in `requirements.txt`
-3. Deploy to Hugging Face Inference Endpoints
+1. Push this repository to Hugging Face Hub or your Git repository
+2. Create a new Inference Endpoint on Hugging Face
+3. Select this repository as the source
+4. Configure compute resources (recommended: GPU with at least 16GB VRAM)
+5. Deploy the endpoint

-## Content Filtering
+### Required Files

-The API includes built-in filtering for child-related content in prompts.
+For deployment on Hugging Face Inference Endpoints, you need:
+- `handler.py` - Contains the `EndpointHandler` class implementation
+- `requirements.txt` - Lists the Python dependencies
+- `app.conf` - Contains configuration parameters
+
+Note: A `Procfile` is not needed for Hugging Face Inference Endpoints deployment, as the service automatically detects and uses the `EndpointHandler` class.
+
+## Local Development
+
+1. Install dependencies: `pip install -r requirements.txt`
+2. Run the API locally: `python handler.py [--port PORT] [--host HOST]`
+3. The API will be available at http://localhost:8000
+
+The local server uses the FastAPI implementation included in `handler.py` that provides the same functionality as the Hugging Face Inference Endpoints interface.

 ## Environment Variables

-- `USE_TORCH_COMPILE`: Set to "1" to enable torch compilation (default: "0")
+- `PORT`: Port to run the server on (default: 8000)
+- `USE_TORCH_COMPILE`: Set to "1" to enable torch.compile for performance (default: "0")
+
+## License
+
+This project is licensed under the terms of the MIT license.
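The documented response shape (`[{"generated_image": ..., "seed": ...}]`) can be consumed with a short client-side sketch. This is illustrative only: `save_generated_image` is a hypothetical helper, and the response is assumed to already match the format shown above (in practice it would come from an HTTP POST to the deployed endpoint).

```python
import base64

def save_generated_image(response, path):
    """Decode the first image from an endpoint response shaped like
    [{"generated_image": "<base64>", "seed": <int>}] and write it to disk.

    Returns the seed the endpoint reports it used for generation.
    """
    entry = response[0]
    with open(path, "wb") as f:
        # The endpoint returns the image as a base64 string (WEBP bytes)
        f.write(base64.b64decode(entry["generated_image"]))
    return entry["seed"]
```

Sending the request itself (e.g. with `requests.post(url, json={"inputs": "...", "parameters": {...}})`) is left out because it needs a live endpoint and token.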
app.conf CHANGED
@@ -3,9 +3,9 @@
     "name": "hentai-waianiv6-card",
     "prompt": "score_9, score_8_up, score_7_up,rating_explicit,BREAK, {prompt}",
     "negative_prompt": "source_furry, source_pony, source_cartoon,3d, blurry, incest, beastiality, children, loli, child, kids, teens, text, logo, timestamp, artist name, artist logo, watermark, web address, copyright name, copyright notice, emblem, comic, title, logo, character name, border, patreon username, signature, webpage, company name, caption, labels, comments",
-    "width": 1024,
-    "height": 768,
     "inference_steps": 30,
     "guidance_scale": 7,
-    "use_safetensors": true
-}
+    "use_safetensors": true,
+    "width": 1024,
+    "height": 768
+}
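The key reordering above does not change behavior, since `handler.py` reads config fields by name. The only field with templating semantics is `prompt`, whose `{prompt}` placeholder is filled with the user's text via a plain string replace. A minimal sketch of that substitution, using illustrative (not production) config values:

```python
import json

# Illustrative sample config in the same shape as app.conf
SAMPLE_CONF = """
{
  "prompt": "masterpiece, best quality, BREAK, {prompt}",
  "negative_prompt": "text, logo, watermark",
  "inference_steps": 30,
  "guidance_scale": 7,
  "use_safetensors": true,
  "width": 1024,
  "height": 768
}
"""

def build_prompt(cfg, user_prompt):
    # Same substitution handler.py performs: str.replace on the "{prompt}" marker,
    # falling back to the bare user prompt when no template is configured
    return cfg.get("prompt", "{prompt}").replace("{prompt}", user_prompt)

cfg = json.loads(SAMPLE_CONF)
```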
app.py DELETED
@@ -1,32 +0,0 @@
-from fastapi import FastAPI, Request
-from src.handler import EndpointHandler
-import json
-
-app = FastAPI()
-
-# Initialize the handler
-handler = EndpointHandler()
-
-@app.get("/")
-def greet_json():
-    """Simple health check endpoint."""
-    return {"status": "healthy"}
-
-@app.post("/predict")
-async def predict(request: Request):
-    """
-    Main prediction endpoint that processes image generation requests.
-
-    Args:
-        request (Request): The FastAPI request object
-
-    Returns:
-        dict: The generated image as base64 and other metadata
-    """
-    # Parse the request data
-    data = await request.json()
-
-    # Process the request using our handler
-    result = handler(data)
-
-    return result
src/handler.py → handler.py RENAMED
@@ -53,6 +53,10 @@ class EndpointHandler:
     """
     Custom handler for Hugging Face Inference Endpoints.
     This class follows the HF Inference Endpoints specification.
+
+    For Hugging Face Inference Endpoints, only this class is needed.
+    It provides both the initialization (__init__) and inference (__call__) methods
+    required by the Hugging Face Inference API.
     """

     def __init__(self, path="", config=None):
@@ -60,13 +64,20 @@
         Initialize the handler with model path and configurations.

         Args:
-            path (str): Path to the model. Not used for this implementation.
-            config (dict, optional): Configuration for the handler. Not used for this implementation.
+            path (str): Path to the model directory (used by HF Inference Endpoints).
+            config (dict, optional): Configuration for the handler, passed by HF Inference Endpoints.
         """
-        # Load configuration from app.conf
+        # Load configuration from app.conf or use provided config
         try:
-            with open("app.conf", "r") as f:
-                self.cfg = json.load(f)
+            if config:
+                # Use config provided by HF Inference Endpoints
+                self.cfg = config
+            else:
+                # Try to load from app.conf as fallback
+                config_path = os.path.join(path, "app.conf") if path else "app.conf"
+                with open(config_path, "r") as f:
+                    self.cfg = json.load(f)
+            print("Configuration loaded successfully")
         except Exception as e:
             print(f"Error loading configuration: {e}")
             self.cfg = {}
@@ -122,16 +133,18 @@
     def __call__(self, data):
         """
         Process the inference request.
-        This is called for each inference request.
+        This is called for each inference request by the Hugging Face Inference API.

         Args:
             data: The input data for the inference request
-
+                For HF Inference Endpoints, this is typically a dict with "inputs" field
+
         Returns:
-            dict: The result of the inference
+            list: A list containing the generated image as base64 string and seed
+                This follows the HF Inference Endpoints output format
         """
         # Validate that the model is loaded
-        if not self.pipe:
+        if not hasattr(self, 'pipe') or self.pipe is None:
            return {"error": "Model not loaded. Please check initialization logs."}

        # Parse the request payload
@@ -144,10 +157,20 @@
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}

+        # Extract parameters from the payload
+        parameters = {}
+        if "parameters" in payload and isinstance(payload["parameters"], dict):
+            # HF Inference Endpoints format: {"inputs": "prompt", "parameters": {...}}
+            parameters = payload["parameters"]
+
        # Get the prompt from the payload
-        prompt_text = payload.get("prompt", "")
+        prompt_text = payload.get("inputs", "")
        if not prompt_text:
-            return {"error": "No prompt provided"}
+            # Try to get prompt from different fields for compatibility
+            prompt_text = payload.get("prompt", "")
+
+        if not prompt_text:
+            return {"error": "No prompt provided. Please include 'inputs' or 'prompt' field."}

        # Apply child-content filtering to the prompt
        if contains_child_related_content(prompt_text):
@@ -155,17 +178,19 @@

        # Replace placeholder in the prompt template from config
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)
-        # Use negative_prompt if provided; otherwise, default to config
-        negative_prompt = payload.get("negative_prompt", self.cfg.get("negative_prompt", ""))
+        # Use negative_prompt from parameters or payload, fall back to config
+        negative_prompt = parameters.get("negative_prompt", payload.get("negative_prompt", self.cfg.get("negative_prompt", "")))
+
+        # Get dimensions from config (default to 1024x768 if not specified)
+        width = int(self.cfg.get("width", 1024))
+        height = int(self.cfg.get("height", 768))

-        # Get parameters from config or override with request params
-        width = int(payload.get("width", self.cfg.get("width", 1024)))
-        height = int(payload.get("height", self.cfg.get("height", 768)))
-        inference_steps = int(payload.get("inference_steps", self.cfg.get("inference_steps", 30)))
-        guidance_scale = float(payload.get("guidance_scale", self.cfg.get("guidance_scale", 7)))
+        # Other generation parameters
+        inference_steps = int(parameters.get("inference_steps", payload.get("inference_steps", self.cfg.get("inference_steps", 30))))
+        guidance_scale = float(parameters.get("guidance_scale", payload.get("guidance_scale", self.cfg.get("guidance_scale", 7))))

        # Use provided seed or generate a random one
-        seed = int(payload.get("seed", random.randint(0, MAX_SEED)))
+        seed = int(parameters.get("seed", payload.get("seed", random.randint(0, MAX_SEED))))
        generator = torch.Generator(self.pipe.device).manual_seed(seed)

        try:
@@ -184,11 +209,56 @@
            # Convert the first generated image to base64
            img_base64 = pil_image_to_base64(outputs.images[0])

-            # Return the response
-            return {"image_base64": img_base64, "seed": seed}
+            # Return the response formatted for Hugging Face Inference Endpoints
+            return [{"generated_image": img_base64, "seed": seed}]

        except Exception as e:
            # Log the error and return an error response
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}
+
+
+# For local testing without HF Inference Endpoints
+if __name__ == "__main__":
+    import argparse
+    import uvicorn
+    from fastapi import FastAPI, Request
+    from fastapi.responses import JSONResponse
+
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description="Run the text-to-image API locally")
+    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
+    args = parser.parse_args()
+
+    # Create FastAPI app
+    app = FastAPI(title="Text-to-Image API with Content Filtering")
+
+    # Initialize the handler
+    handler = EndpointHandler()
+
+    @app.get("/")
+    async def read_root():
+        """Health check endpoint."""
+        return {"status": "ok", "message": "Text-to-Image API is running"}
+
+    @app.post("/")
+    async def generate_image(request: Request):
+        """Main inference endpoint."""
+        try:
+            body = await request.json()
+            result = handler(body)
+
+            if "error" in result:
+                return JSONResponse(status_code=500, content={"error": result["error"]})
+
+            return result
+        except Exception as e:
+            return JSONResponse(
+                status_code=500,
+                content={"error": f"Failed to process request: {str(e)}"}
+            )
+
+    # Run the server
+    print(f"Starting server on http://{args.host}:{args.port}")
+    uvicorn.run(app, host=args.host, port=args.port)
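The handler changes above establish a precedence order for generation parameters: the HF-style `parameters` dict wins over top-level payload keys, which win over `app.conf` defaults, while width/height now come from config only. A minimal sketch of that lookup chain, assuming the same key names (this is a standalone illustration, not code from the repository):

```python
import random

MAX_SEED = 12211231  # same constant the handler uses for random seeds

def resolve_params(payload, cfg):
    """Resolve generation parameters with the handler's precedence:
    payload["parameters"] > top-level payload keys > cfg defaults."""
    parameters = payload.get("parameters")
    if not isinstance(parameters, dict):
        parameters = {}

    def pick(key, default, cast):
        # Walk the precedence chain and coerce to the expected type
        return cast(parameters.get(key, payload.get(key, cfg.get(key, default))))

    return {
        "inference_steps": pick("inference_steps", 30, int),
        "guidance_scale": pick("guidance_scale", 7, float),
        "seed": pick("seed", random.randint(0, MAX_SEED), int),
        # Dimensions intentionally ignore the request, mirroring the new handler
        "width": int(cfg.get("width", 1024)),
        "height": int(cfg.get("height", 768)),
    }
```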
requirements.txt CHANGED
@@ -1,9 +1,11 @@
-fastapi
-uvicorn[standard]
-diffusers
-transformers
-torch
-accelerate
-huggingface_hub
-pillow
-safetensors
+fastapi>=0.95.0
+uvicorn>=0.22.0
+torch>=2.0.0
+diffusers>=0.19.0
+transformers>=4.30.0
+accelerate>=0.20.0
+huggingface_hub>=0.16.0
+pydantic>=1.10.0
+Pillow>=9.0.0
+scipy>=1.10.0
+safetensors>=0.3.0
src/Procfile DELETED
@@ -1 +0,0 @@
-web: uvicorn app:app --host 0.0.0.0 --port $PORT
src/app.conf DELETED
@@ -1,11 +0,0 @@
-{
-    "model_id": "John6666/wai-ani-hentai-pony-v3-sdxl",
-    "name": "hentai-waianiv6-card",
-    "prompt": "score_9, score_8_up, score_7_up,rating_explicit,BREAK, {prompt}",
-    "negative_prompt": "source_furry, source_pony, source_cartoon,3d, blurry, incest, beastiality, children, loli, child, kids, teens, text, logo, timestamp, artist name, artist logo, watermark, web address, copyright name, copyright notice, emblem, comic, title, logo, character name, border, patreon username, signature, webpage, company name, caption, labels, comments",
-    "width": 1024,
-    "height": 768,
-    "inference_steps": 30,
-    "guidance_scale": 7,
-    "use_safetensors": true
-}
src/app.py DELETED
@@ -1,145 +0,0 @@
-import os
-import json
-import random
-import re
-import base64
-from io import BytesIO
-
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-from PIL import Image
-
-import torch
-from huggingface_hub import snapshot_download
-from diffusers import (
-    AutoencoderKL,
-    StableDiffusionXLPipeline,
-    EulerAncestralDiscreteScheduler,
-    DPMSolverSDEScheduler
-)
-from diffusers.models.attention_processor import AttnProcessor2_0
-
-# Global constants
-MAX_SEED = 12211231
-NUM_IMAGES_PER_PROMPT = 1
-USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
-
-# Load configuration from app.conf
-with open("app.conf", "r") as f:
-    cfg = json.load(f)
-
-# --- Child-Content Filtering Functions ---
-child_related_regex = re.compile(
-    r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
-    r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
-    r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
-    re.IGNORECASE
-)
-
-def remove_child_related_content(prompt: str) -> str:
-    """Remove any child-related references from the prompt."""
-    cleaned_prompt = re.sub(child_related_regex, '', prompt)
-    return cleaned_prompt.strip()
-
-def contains_child_related_content(prompt: str) -> bool:
-    """Check if the prompt contains child-related content."""
-    return bool(child_related_regex.search(prompt))
-
-# --- Model Pipeline Loading ---
-def load_pipeline_and_scheduler():
-    clip_skip = cfg.get("clip_skip", 0)
-
-    # Download model files from Hugging Face Hub
-    ckpt_dir = snapshot_download(repo_id=cfg["model_id"])
-
-    # Load the VAE model (for decoding latents)
-    vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
-
-    # Load the Stable Diffusion XL pipeline
-    pipe = StableDiffusionXLPipeline.from_pretrained(
-        ckpt_dir,
-        vae=vae,
-        torch_dtype=torch.float16,
-        use_safetensors=cfg.get("use_safetensors", True),
-        variant="fp16"
-    )
-    pipe = pipe.to("cuda")
-    pipe.unet.set_attn_processor(AttnProcessor2_0())
-
-    # Set up samplers/schedulers based on configuration
-    samplers = {
-        "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
-        "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
-    }
-    # Default to "DPM++ SDE Karras" if not specified
-    pipe.scheduler = samplers.get(cfg.get("sampler", "DPM++ SDE Karras"))
-
-    # Adjust the text encoder layers if needed using clip_skip
-    pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
-
-    if USE_TORCH_COMPILE:
-        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-        print("Model Compiled!")
-    return pipe
-
-# Load the model pipeline once at startup
-pipe = load_pipeline_and_scheduler()
-
-# --- Utility Function: Convert PIL Image to Base64 ---
-def pil_image_to_base64(img: Image.Image) -> str:
-    buffered = BytesIO()
-    img.convert("RGB").save(buffered, format="WEBP", quality=90)
-    return base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-# --- FastAPI Application Setup ---
-app = FastAPI(title="Text-to-Image API with Content Filtering")
-
-class GenerateRequest(BaseModel):
-    prompt: str
-
-@app.get("/")
-async def read_root():
-    return {"message": "Text-to-Image API with content filtering is running."}
-
-@app.post("/generate")
-async def generate(req: GenerateRequest):
-    # Apply child-content filtering to the prompt
-    prompt_text = req.prompt
-    if contains_child_related_content(prompt_text):
-        prompt_text = remove_child_related_content(prompt_text)
-
-    # Replace placeholder in the prompt template from config
-    combined_prompt = cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)
-    # Use negative_prompt if provided; otherwise, default to empty string
-    negative_prompt = cfg.get("negative_prompt", "")
-    width = cfg.get("width", 1024)
-    height = cfg.get("height", 768)
-    inference_steps = cfg.get("inference_steps", 30)
-    guidance_scale = cfg.get("guidance_scale", 7)
-
-    # Randomize the seed for generation
-    seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(pipe.device).manual_seed(seed)
-
-    try:
-        outputs = pipe(
-            prompt=combined_prompt,
-            negative_prompt=negative_prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            num_inference_steps=inference_steps,
-            generator=generator,
-            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
-            output_type="pil"
-        )
-        # Convert the first generated image to base64
-        img_base64 = pil_image_to_base64(outputs.images[0])
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
-
-    return {"image_base64": img_base64, "seed": seed}
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run("app:app", host="0.0.0.0", port=8000)
src/requirements.txt DELETED
@@ -1,7 +0,0 @@
-fastapi
-uvicorn
-torch
-diffusers
-huggingface_hub
-pydantic
-Pillow