Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import requests
|
| 3 |
import time
|
|
|
|
| 4 |
import threading
|
| 5 |
import uuid
|
| 6 |
import base64
|
|
@@ -18,78 +19,142 @@ API_KEY = os.getenv("WAVESPEED_API_KEY")
|
|
| 18 |
if not API_KEY:
|
| 19 |
raise ValueError("WAVESPEED_API_KEY is not set in environment variables")
|
| 20 |
|
|
|
|
| 21 |
MODEL_URL = "TostAI/nsfw-text-detection-large"
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
except Exception as e:
|
| 28 |
-
raise RuntimeError(f"Failed to load safety model: {str(e)}")
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
_instances = {}
|
| 33 |
_lock = threading.Lock()
|
| 34 |
|
| 35 |
@classmethod
|
| 36 |
-
def
|
|
|
|
|
|
|
|
|
|
| 37 |
with cls._lock:
|
| 38 |
-
if
|
| 39 |
-
cls._instances[
|
| 40 |
-
|
| 41 |
-
'history': [],
|
| 42 |
-
'last_active': time.time()
|
| 43 |
-
}
|
| 44 |
-
return cls._instances[session_id]
|
| 45 |
|
| 46 |
@classmethod
|
| 47 |
-
def
|
|
|
|
| 48 |
with cls._lock:
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
-
class
|
| 59 |
|
| 60 |
def __init__(self):
|
| 61 |
-
self.clients = {}
|
| 62 |
self.lock = threading.Lock()
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
def
|
| 65 |
with self.lock:
|
| 66 |
-
|
| 67 |
-
if client_id not in self.clients:
|
| 68 |
-
self.clients[client_id] = {'count': 1, 'reset': now + 3600}
|
| 69 |
-
return True
|
| 70 |
-
if now > self.clients[client_id]['reset']:
|
| 71 |
-
self.clients[client_id] = {'count': 1, 'reset': now + 3600}
|
| 72 |
-
return True
|
| 73 |
-
if self.clients[client_id]['count'] >= 20:
|
| 74 |
-
return False
|
| 75 |
-
self.clients[client_id]['count'] += 1
|
| 76 |
-
return True
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
@torch.no_grad()
|
|
@@ -112,28 +177,48 @@ def decode_base64_to_image(base64_str):
|
|
| 112 |
return Image.open(io.BytesIO(image_data))
|
| 113 |
|
| 114 |
|
| 115 |
-
def generate_image(
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
| 120 |
try:
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
return
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
error_messages = []
|
| 139 |
if not image_file:
|
|
@@ -143,16 +228,27 @@ def generate_image(image_file,
|
|
| 143 |
if not prompt.strip():
|
| 144 |
error_messages.append("Prompt cannot be empty")
|
| 145 |
if error_messages:
|
| 146 |
-
|
| 147 |
-
|
|
|
|
| 148 |
return
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
try:
|
| 151 |
base64_image = image_to_base64(image_file)
|
| 152 |
input_image = decode_base64_to_image(base64_image)
|
| 153 |
except Exception as e:
|
| 154 |
-
|
| 155 |
-
yield
|
| 156 |
return
|
| 157 |
|
| 158 |
headers = {
|
|
@@ -178,7 +274,7 @@ def generate_image(image_file,
|
|
| 178 |
start_time = time.time()
|
| 179 |
|
| 180 |
for _ in range(60):
|
| 181 |
-
time.sleep(1)
|
| 182 |
resp = requests.get(result_url, headers=headers)
|
| 183 |
resp.raise_for_status()
|
| 184 |
|
|
@@ -188,25 +284,28 @@ def generate_image(image_file,
|
|
| 188 |
if status == "completed":
|
| 189 |
elapsed = time.time() - start_time
|
| 190 |
output_url = data["outputs"][0]
|
| 191 |
-
|
| 192 |
-
|
| 193 |
return
|
| 194 |
elif status == "failed":
|
| 195 |
raise Exception(data.get("error", "Unknown error"))
|
| 196 |
else:
|
| 197 |
-
|
|
|
|
| 198 |
|
| 199 |
raise Exception("Generation timed out")
|
| 200 |
|
| 201 |
except Exception as e:
|
| 202 |
-
|
| 203 |
-
yield
|
| 204 |
|
| 205 |
|
|
|
|
| 206 |
def cleanup_task():
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 210 |
|
| 211 |
|
| 212 |
# Store recent generations
|
|
@@ -238,14 +337,16 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
| 238 |
|
| 239 |
with gr.Row():
|
| 240 |
with gr.Column(scale=1):
|
| 241 |
-
prompt = gr.Textbox(label="Prompt",
|
| 242 |
-
placeholder="Please enter your prompt...",
|
| 243 |
-
lines=3)
|
| 244 |
image_file = gr.Image(label="Upload Image",
|
| 245 |
type="filepath",
|
| 246 |
-
sources=["upload"],
|
| 247 |
interactive=True,
|
| 248 |
-
image_mode="RGB"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
seed = gr.Number(label="seed",
|
| 250 |
value=-1,
|
| 251 |
minimum=-1,
|
|
@@ -256,8 +357,8 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
| 256 |
value=True,
|
| 257 |
interactive=False)
|
| 258 |
with gr.Column(scale=1):
|
| 259 |
-
status = gr.Textbox(label="Status", elem_classes=["status-box"])
|
| 260 |
output_image = gr.Image(label="Generated Result")
|
|
|
|
| 261 |
output_url = gr.Textbox(label="Image URL",
|
| 262 |
interactive=True,
|
| 263 |
visible=False)
|
|
@@ -266,15 +367,15 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
| 266 |
examples=[
|
| 267 |
[
|
| 268 |
"Convert the image into Claymation style.",
|
| 269 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/
|
| 270 |
],
|
| 271 |
[
|
| 272 |
"Convert the image into Ghibli style.",
|
| 273 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/
|
| 274 |
],
|
| 275 |
[
|
| 276 |
-
"Add sunglasses to the face of the
|
| 277 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/
|
| 278 |
],
|
| 279 |
# [
|
| 280 |
# 'Convert the image into an ink sketch style.',
|
|
@@ -319,11 +420,16 @@ with gr.Blocks(theme=gr.themes.Soft(),
|
|
| 319 |
inputs=[image_file, prompt, seed, session_id, enable_safety],
|
| 320 |
outputs=[status, output_image, output_url, recent_gallery],
|
| 321 |
api_name=False,
|
|
|
|
|
|
|
|
|
|
| 322 |
)
|
| 323 |
|
| 324 |
if __name__ == "__main__":
|
| 325 |
-
|
| 326 |
-
|
|
|
|
| 327 |
server_name="0.0.0.0",
|
|
|
|
| 328 |
share=False,
|
| 329 |
)
|
|
|
|
| 1 |
import os
|
| 2 |
import requests
|
| 3 |
import time
|
| 4 |
+
import functools
|
| 5 |
import threading
|
| 6 |
import uuid
|
| 7 |
import base64
|
|
|
|
| 19 |
if not API_KEY:
|
| 20 |
raise ValueError("WAVESPEED_API_KEY is not set in environment variables")
|
| 21 |
|
| 22 |
+
|
| 23 |
MODEL_URL = "TostAI/nsfw-text-detection-large"
|
| 24 |
+
TITLE = "πΌοΈπ Image Prompt Safety Classifier π‘οΈ"
|
| 25 |
+
DESCRIPTION = "β¨ Enter an image generation prompt to classify its safety level! β¨"
|
| 26 |
|
| 27 |
+
# Load model and tokenizer
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
|
| 29 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
# Define class names with emojis and detailed descriptions
|
| 32 |
+
CLASS_NAMES = {
|
| 33 |
+
0: "β
SAFE - This prompt is appropriate and harmless.",
|
| 34 |
+
1: "β οΈ QUESTIONABLE - This prompt may require further review.",
|
| 35 |
+
2: "π« UNSAFE - This prompt is likely to generate inappropriate content."
|
| 36 |
+
}
|
| 37 |
|
| 38 |
+
|
| 39 |
+
@functools.lru_cache(maxsize=128)
|
| 40 |
+
def classify_text(text):
|
| 41 |
+
inputs = tokenizer(text,
|
| 42 |
+
return_tensors="pt",
|
| 43 |
+
truncation=True,
|
| 44 |
+
padding=True,
|
| 45 |
+
max_length=1024)
|
| 46 |
+
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
outputs = model(**inputs)
|
| 49 |
+
|
| 50 |
+
logits = outputs.logits
|
| 51 |
+
predicted_class = torch.argmax(logits, dim=1).item()
|
| 52 |
+
|
| 53 |
+
return predicted_class, CLASS_NAMES[predicted_class]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ClientManager:
|
| 57 |
_instances = {}
|
| 58 |
_lock = threading.Lock()
|
| 59 |
|
| 60 |
@classmethod
|
| 61 |
+
def get_manager(cls, client_id=None):
|
| 62 |
+
if not client_id:
|
| 63 |
+
client_id = str(uuid.uuid4())
|
| 64 |
+
|
| 65 |
with cls._lock:
|
| 66 |
+
if client_id not in cls._instances:
|
| 67 |
+
cls._instances[client_id] = ClientGenerationManager()
|
| 68 |
+
return cls._instances[client_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
@classmethod
|
| 71 |
+
def cleanup_old_clients(cls, max_age=3600): # 1 hour default
|
| 72 |
+
current_time = time.time()
|
| 73 |
with cls._lock:
|
| 74 |
+
to_remove = []
|
| 75 |
+
for client_id, manager in cls._instances.items():
|
| 76 |
+
if (hasattr(manager, "last_activity")
|
| 77 |
+
and current_time - manager.last_activity > max_age):
|
| 78 |
+
to_remove.append(client_id)
|
| 79 |
+
|
| 80 |
+
for client_id in to_remove:
|
| 81 |
+
del cls._instances[client_id]
|
| 82 |
|
| 83 |
|
| 84 |
+
class ClientGenerationManager:
|
| 85 |
|
| 86 |
def __init__(self):
|
|
|
|
| 87 |
self.lock = threading.Lock()
|
| 88 |
+
self.last_activity = time.time()
|
| 89 |
+
self.request_timestamps = [] # Track timestamps of requests
|
| 90 |
|
| 91 |
+
def update_activity(self):
|
| 92 |
with self.lock:
|
| 93 |
+
self.last_activity = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
def add_request_timestamp(self):
|
| 96 |
+
with self.lock:
|
| 97 |
+
self.request_timestamps.append(time.time())
|
| 98 |
|
| 99 |
+
def has_exceeded_limit(self, limit=20):
|
| 100 |
+
with self.lock:
|
| 101 |
+
current_time = time.time()
|
| 102 |
+
# Filter timestamps to only include those within the last hour
|
| 103 |
+
self.request_timestamps = [
|
| 104 |
+
ts for ts in self.request_timestamps
|
| 105 |
+
if current_time - ts <= 3600
|
| 106 |
+
]
|
| 107 |
+
return len(self.request_timestamps) >= limit
|
| 108 |
|
| 109 |
|
| 110 |
+
class SessionManager:
|
| 111 |
+
_instances = {}
|
| 112 |
+
_lock = threading.Lock()
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def get_manager(cls, session_id=None):
|
| 116 |
+
if session_id is None:
|
| 117 |
+
session_id = str(uuid.uuid4())
|
| 118 |
+
|
| 119 |
+
with cls._lock:
|
| 120 |
+
if session_id not in cls._instances:
|
| 121 |
+
cls._instances[session_id] = GenerationManager()
|
| 122 |
+
return session_id, cls._instances[session_id]
|
| 123 |
+
|
| 124 |
+
@classmethod
|
| 125 |
+
def cleanup_old_sessions(cls, max_age=3600): # 1 hour default
|
| 126 |
+
current_time = time.time()
|
| 127 |
+
with cls._lock:
|
| 128 |
+
to_remove = []
|
| 129 |
+
for session_id, manager in cls._instances.items():
|
| 130 |
+
if (hasattr(manager, "last_activity")
|
| 131 |
+
and current_time - manager.last_activity > max_age):
|
| 132 |
+
to_remove.append(session_id)
|
| 133 |
+
|
| 134 |
+
for session_id in to_remove:
|
| 135 |
+
del cls._instances[session_id]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class GenerationManager:
|
| 139 |
+
|
| 140 |
+
def __init__(self):
|
| 141 |
+
self.last_activity = time.time()
|
| 142 |
+
self.request_timestamps = [] # Track timestamps of requests
|
| 143 |
+
|
| 144 |
+
def update_activity(self):
|
| 145 |
+
self.last_activity = time.time()
|
| 146 |
+
|
| 147 |
+
def add_request_timestamp(self):
|
| 148 |
+
self.request_timestamps.append(time.time())
|
| 149 |
+
|
| 150 |
+
def has_exceeded_limit(self,
|
| 151 |
+
limit=10): # Default limit: 10 requests per hour
|
| 152 |
+
current_time = time.time()
|
| 153 |
+
# Filter timestamps to only include those within the last hour
|
| 154 |
+
self.request_timestamps = [
|
| 155 |
+
ts for ts in self.request_timestamps if current_time - ts <= 3600
|
| 156 |
+
]
|
| 157 |
+
return len(self.request_timestamps) >= limit
|
| 158 |
|
| 159 |
|
| 160 |
@torch.no_grad()
|
|
|
|
| 177 |
return Image.open(io.BytesIO(image_data))
|
| 178 |
|
| 179 |
|
| 180 |
+
def generate_image(
|
| 181 |
+
image_file,
|
| 182 |
+
prompt,
|
| 183 |
+
seed,
|
| 184 |
+
session_id,
|
| 185 |
+
enable_safety_checker,
|
| 186 |
+
request: gr.Request,
|
| 187 |
+
):
|
| 188 |
try:
|
| 189 |
+
client_ip = request.client.host
|
| 190 |
+
x_forwarded_for = request.headers.get('x-forwarded-for')
|
| 191 |
+
if x_forwarded_for:
|
| 192 |
+
client_ip = x_forwarded_for
|
| 193 |
+
print(f"Client IP: {client_ip}")
|
| 194 |
+
client_generation_manager = ClientManager.get_manager(client_ip)
|
| 195 |
+
client_generation_manager.update_activity()
|
| 196 |
+
if client_generation_manager.has_exceeded_limit(limit=20):
|
| 197 |
+
error_message = "β Your network has exceeded the limit of 20 requests per hour. Please try again later."
|
| 198 |
+
yield error_message, None, "", None
|
| 199 |
+
return
|
| 200 |
|
| 201 |
+
client_generation_manager.add_request_timestamp()
|
| 202 |
+
"""Generate images with big status box during generation"""
|
| 203 |
+
# Get or create a session manager
|
| 204 |
+
session_id, manager = SessionManager.get_manager(session_id)
|
| 205 |
+
manager.update_activity()
|
| 206 |
+
|
| 207 |
+
# Check if the user has exceeded the request limit
|
| 208 |
+
if manager.has_exceeded_limit(
|
| 209 |
+
limit=10): # Set the limit to 10 requests per hour
|
| 210 |
+
error_message = "β You have exceeded the limit of 10 requests per hour. Please try again later."
|
| 211 |
+
yield error_message, None, "", None
|
| 212 |
return
|
| 213 |
|
| 214 |
+
# Add the current request timestamp
|
| 215 |
+
manager.add_request_timestamp()
|
| 216 |
+
|
| 217 |
+
if not prompt or prompt.strip() == "":
|
| 218 |
+
# Handle empty prompt case
|
| 219 |
+
error_message = "β οΈ Please enter a prompt first"
|
| 220 |
+
yield error_message, None, "", None
|
| 221 |
+
return
|
| 222 |
|
| 223 |
error_messages = []
|
| 224 |
if not image_file:
|
|
|
|
| 228 |
if not prompt.strip():
|
| 229 |
error_messages.append("Prompt cannot be empty")
|
| 230 |
if error_messages:
|
| 231 |
+
error_message = "β Input validation failed: " + ", ".join(
|
| 232 |
+
error_messages)
|
| 233 |
+
yield error_message, None, "", None
|
| 234 |
return
|
| 235 |
|
| 236 |
+
# Check if the prompt is safe
|
| 237 |
+
classification, message = classify_text(prompt)
|
| 238 |
+
if classification == 2: # UNSAFE
|
| 239 |
+
yield "β NSFW prompt detected", None, "", None
|
| 240 |
+
return
|
| 241 |
+
|
| 242 |
+
# Status message
|
| 243 |
+
status_message = f"π PROCESSING: '{prompt}'"
|
| 244 |
+
yield status_message, None, "", None
|
| 245 |
+
|
| 246 |
try:
|
| 247 |
base64_image = image_to_base64(image_file)
|
| 248 |
input_image = decode_base64_to_image(base64_image)
|
| 249 |
except Exception as e:
|
| 250 |
+
error_message = f"β File processing failed: {str(e)}"
|
| 251 |
+
yield error_message, None, "", None
|
| 252 |
return
|
| 253 |
|
| 254 |
headers = {
|
|
|
|
| 274 |
start_time = time.time()
|
| 275 |
|
| 276 |
for _ in range(60):
|
| 277 |
+
time.sleep(1.0)
|
| 278 |
resp = requests.get(result_url, headers=headers)
|
| 279 |
resp.raise_for_status()
|
| 280 |
|
|
|
|
| 284 |
if status == "completed":
|
| 285 |
elapsed = time.time() - start_time
|
| 286 |
output_url = data["outputs"][0]
|
| 287 |
+
yield f"π Generation successful! Time taken {elapsed:.1f}s", output_url, output_url, update_recent_gallery(
|
| 288 |
+
prompt, input_image, output_url)
|
| 289 |
return
|
| 290 |
elif status == "failed":
|
| 291 |
raise Exception(data.get("error", "Unknown error"))
|
| 292 |
else:
|
| 293 |
+
error_message = f"β³ Current status: {status.capitalize()}..."
|
| 294 |
+
yield error_message, None, "", None
|
| 295 |
|
| 296 |
raise Exception("Generation timed out")
|
| 297 |
|
| 298 |
except Exception as e:
|
| 299 |
+
error_message = f"β Generation failed: {str(e)}"
|
| 300 |
+
yield error_message, None, "", None
|
| 301 |
|
| 302 |
|
| 303 |
+
# Schedule periodic cleanup of old sessions
|
| 304 |
def cleanup_task():
|
| 305 |
+
SessionManager.cleanup_old_sessions()
|
| 306 |
+
ClientManager.cleanup_old_clients()
|
| 307 |
+
# Schedule the next cleanup
|
| 308 |
+
threading.Timer(3600, cleanup_task).start() # Run every hour
|
| 309 |
|
| 310 |
|
| 311 |
# Store recent generations
|
|
|
|
| 337 |
|
| 338 |
with gr.Row():
|
| 339 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
|
|
|
| 340 |
image_file = gr.Image(label="Upload Image",
|
| 341 |
type="filepath",
|
| 342 |
+
sources=["upload", "clipboard"],
|
| 343 |
interactive=True,
|
| 344 |
+
image_mode="RGB",
|
| 345 |
+
value="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png")
|
| 346 |
+
prompt = gr.Textbox(label="Prompt",
|
| 347 |
+
placeholder="Please enter your prompt...",
|
| 348 |
+
lines=3,
|
| 349 |
+
value="Convert the image into Claymation style.")
|
| 350 |
seed = gr.Number(label="seed",
|
| 351 |
value=-1,
|
| 352 |
minimum=-1,
|
|
|
|
| 357 |
value=True,
|
| 358 |
interactive=False)
|
| 359 |
with gr.Column(scale=1):
|
|
|
|
| 360 |
output_image = gr.Image(label="Generated Result")
|
| 361 |
+
status = gr.Textbox(label="Status", elem_classes=["status-box"])
|
| 362 |
output_url = gr.Textbox(label="Image URL",
|
| 363 |
interactive=True,
|
| 364 |
visible=False)
|
|
|
|
| 367 |
examples=[
|
| 368 |
[
|
| 369 |
"Convert the image into Claymation style.",
|
| 370 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png"
|
| 371 |
],
|
| 372 |
[
|
| 373 |
"Convert the image into Ghibli style.",
|
| 374 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
|
| 375 |
],
|
| 376 |
[
|
| 377 |
+
"Add sunglasses to the face of the statue.",
|
| 378 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
|
| 379 |
],
|
| 380 |
# [
|
| 381 |
# 'Convert the image into an ink sketch style.',
|
|
|
|
| 420 |
inputs=[image_file, prompt, seed, session_id, enable_safety],
|
| 421 |
outputs=[status, output_image, output_url, recent_gallery],
|
| 422 |
api_name=False,
|
| 423 |
+
max_batch_size=10,
|
| 424 |
+
concurrency_limit=20,
|
| 425 |
+
concurrency_id="generation",
|
| 426 |
)
|
| 427 |
|
| 428 |
if __name__ == "__main__":
|
| 429 |
+
# Start the cleanup task
|
| 430 |
+
cleanup_task()
|
| 431 |
+
app.queue(max_size=20).launch(
|
| 432 |
server_name="0.0.0.0",
|
| 433 |
+
max_threads=10,
|
| 434 |
share=False,
|
| 435 |
)
|