Update app.py
Browse files
app.py
CHANGED
|
@@ -17,10 +17,8 @@ from PIL import Image, ImageDraw, ImageFont
|
|
| 17 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 18 |
from functools import lru_cache
|
| 19 |
|
| 20 |
-
# 初始化环境
|
| 21 |
load_dotenv()
|
| 22 |
|
| 23 |
-
# 安全检测配置
|
| 24 |
MODEL_URL = "TostAI/nsfw-text-detection-large"
|
| 25 |
CLASS_NAMES = {
|
| 26 |
0: "✅ SAFE",
|
|
@@ -28,11 +26,9 @@ CLASS_NAMES = {
|
|
| 28 |
2: "🚫 UNSAFE"
|
| 29 |
}
|
| 30 |
|
| 31 |
-
# 加载模型
|
| 32 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
|
| 33 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
| 34 |
|
| 35 |
-
# 会话管理
|
| 36 |
class SessionManager:
|
| 37 |
_instances = {}
|
| 38 |
_lock = threading.Lock()
|
|
@@ -56,7 +52,6 @@ class SessionManager:
|
|
| 56 |
for k in expired:
|
| 57 |
del cls._instances[k]
|
| 58 |
|
| 59 |
-
# 频率限制
|
| 60 |
class RateLimiter:
|
| 61 |
def __init__(self):
|
| 62 |
self.clients = {}
|
|
@@ -79,11 +74,9 @@ class RateLimiter:
|
|
| 79 |
self.clients[client_id]['count'] += 1
|
| 80 |
return True
|
| 81 |
|
| 82 |
-
# 初始化模块
|
| 83 |
session_manager = SessionManager()
|
| 84 |
rate_limiter = RateLimiter()
|
| 85 |
|
| 86 |
-
# 图像处理函数
|
| 87 |
def image_to_base64(file_path):
|
| 88 |
try:
|
| 89 |
with open(file_path, "rb") as f:
|
|
@@ -112,7 +105,7 @@ def create_error_image(message):
|
|
| 112 |
img.save("error.jpg")
|
| 113 |
return "error.jpg"
|
| 114 |
|
| 115 |
-
|
| 116 |
@lru_cache(maxsize=100)
|
| 117 |
def classify_prompt(prompt):
|
| 118 |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
|
@@ -133,27 +126,23 @@ def generate_video(
|
|
| 133 |
size,
|
| 134 |
session_id
|
| 135 |
):
|
| 136 |
-
|
| 137 |
safety_level = classify_prompt(prompt)
|
| 138 |
if safety_level != 0:
|
| 139 |
error_img = create_error_image(CLASS_NAMES[safety_level])
|
| 140 |
yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
|
| 141 |
return
|
| 142 |
|
| 143 |
-
# 频率检查
|
| 144 |
if not rate_limiter.check(session_id):
|
| 145 |
error_img = create_error_image("Hourly limit exceeded (20 requests)")
|
| 146 |
yield "❌ 请求过于频繁,请稍后再试", error_img
|
| 147 |
return
|
| 148 |
|
| 149 |
-
# 会话更新
|
| 150 |
session = session_manager.get_session(session_id)
|
| 151 |
session['last_active'] = time.time()
|
| 152 |
session['count'] += 1
|
| 153 |
|
| 154 |
-
# API调用
|
| 155 |
try:
|
| 156 |
-
# 准备请求
|
| 157 |
api_key = os.getenv("WAVESPEED_API_KEY")
|
| 158 |
if not api_key:
|
| 159 |
raise ValueError("API key missing")
|
|
@@ -192,7 +181,6 @@ def generate_video(
|
|
| 192 |
yield f"❌ 提交失败: {str(e)}", error_img
|
| 193 |
return
|
| 194 |
|
| 195 |
-
# 轮询结果
|
| 196 |
result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
|
| 197 |
start_time = time.time()
|
| 198 |
|
|
@@ -224,13 +212,11 @@ def generate_video(
|
|
| 224 |
yield f"❌ 生成失败: {str(e)}", error_img
|
| 225 |
return
|
| 226 |
|
| 227 |
-
# 后台清理线程
|
| 228 |
def cleanup_task():
|
| 229 |
while True:
|
| 230 |
session_manager.cleanup_sessions()
|
| 231 |
time.sleep(3600)
|
| 232 |
|
| 233 |
-
# Gradio界面
|
| 234 |
with gr.Blocks(
|
| 235 |
theme=gr.themes.Soft(),
|
| 236 |
css="""
|
|
@@ -267,7 +253,7 @@ with gr.Blocks(
|
|
| 267 |
guidance = gr.Slider(1, 20, value=7, label="Guidance Scale")
|
| 268 |
with gr.Row():
|
| 269 |
seed = gr.Number(-1, label="Seed")
|
| 270 |
-
random_seed_btn = gr.Button("Random
|
| 271 |
with gr.Row():
|
| 272 |
enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",value=True,interactive=True)
|
| 273 |
flow_shift = gr.Number(3, label="Flow Shift",interactive=False)
|
|
@@ -277,8 +263,8 @@ with gr.Blocks(
|
|
| 277 |
status_output = gr.Textbox(label="System Status", interactive=False, lines=4)
|
| 278 |
generate_btn = gr.Button("Generated", variant="primary")
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
|
| 283 |
with gr.Accordion("Safety Status", open=True):
|
| 284 |
gr.Markdown("""
|
|
|
|
| 17 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 18 |
from functools import lru_cache
|
| 19 |
|
|
|
|
| 20 |
load_dotenv()
|
| 21 |
|
|
|
|
| 22 |
MODEL_URL = "TostAI/nsfw-text-detection-large"
|
| 23 |
CLASS_NAMES = {
|
| 24 |
0: "✅ SAFE",
|
|
|
|
| 26 |
2: "🚫 UNSAFE"
|
| 27 |
}
|
| 28 |
|
|
|
|
| 29 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
|
| 30 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
| 31 |
|
|
|
|
| 32 |
class SessionManager:
|
| 33 |
_instances = {}
|
| 34 |
_lock = threading.Lock()
|
|
|
|
| 52 |
for k in expired:
|
| 53 |
del cls._instances[k]
|
| 54 |
|
|
|
|
| 55 |
class RateLimiter:
|
| 56 |
def __init__(self):
|
| 57 |
self.clients = {}
|
|
|
|
| 74 |
self.clients[client_id]['count'] += 1
|
| 75 |
return True
|
| 76 |
|
|
|
|
| 77 |
session_manager = SessionManager()
|
| 78 |
rate_limiter = RateLimiter()
|
| 79 |
|
|
|
|
| 80 |
def image_to_base64(file_path):
|
| 81 |
try:
|
| 82 |
with open(file_path, "rb") as f:
|
|
|
|
| 105 |
img.save("error.jpg")
|
| 106 |
return "error.jpg"
|
| 107 |
|
| 108 |
+
|
| 109 |
@lru_cache(maxsize=100)
|
| 110 |
def classify_prompt(prompt):
|
| 111 |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
|
|
|
| 126 |
size,
|
| 127 |
session_id
|
| 128 |
):
|
| 129 |
+
|
| 130 |
safety_level = classify_prompt(prompt)
|
| 131 |
if safety_level != 0:
|
| 132 |
error_img = create_error_image(CLASS_NAMES[safety_level])
|
| 133 |
yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
|
| 134 |
return
|
| 135 |
|
|
|
|
| 136 |
if not rate_limiter.check(session_id):
|
| 137 |
error_img = create_error_image("Hourly limit exceeded (20 requests)")
|
| 138 |
yield "❌ 请求过于频繁,请稍后再试", error_img
|
| 139 |
return
|
| 140 |
|
|
|
|
| 141 |
session = session_manager.get_session(session_id)
|
| 142 |
session['last_active'] = time.time()
|
| 143 |
session['count'] += 1
|
| 144 |
|
|
|
|
| 145 |
try:
|
|
|
|
| 146 |
api_key = os.getenv("WAVESPEED_API_KEY")
|
| 147 |
if not api_key:
|
| 148 |
raise ValueError("API key missing")
|
|
|
|
| 181 |
yield f"❌ 提交失败: {str(e)}", error_img
|
| 182 |
return
|
| 183 |
|
|
|
|
| 184 |
result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
|
| 185 |
start_time = time.time()
|
| 186 |
|
|
|
|
| 212 |
yield f"❌ 生成失败: {str(e)}", error_img
|
| 213 |
return
|
| 214 |
|
|
|
|
| 215 |
def cleanup_task():
|
| 216 |
while True:
|
| 217 |
session_manager.cleanup_sessions()
|
| 218 |
time.sleep(3600)
|
| 219 |
|
|
|
|
| 220 |
with gr.Blocks(
|
| 221 |
theme=gr.themes.Soft(),
|
| 222 |
css="""
|
|
|
|
| 253 |
guidance = gr.Slider(1, 20, value=7, label="Guidance Scale")
|
| 254 |
with gr.Row():
|
| 255 |
seed = gr.Number(-1, label="Seed")
|
| 256 |
+
random_seed_btn = gr.Button("Random🎲Seed", variant="secondary")
|
| 257 |
with gr.Row():
|
| 258 |
enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",value=True,interactive=True)
|
| 259 |
flow_shift = gr.Number(3, label="Flow Shift",interactive=False)
|
|
|
|
| 263 |
status_output = gr.Textbox(label="System Status", interactive=False, lines=4)
|
| 264 |
generate_btn = gr.Button("Generated", variant="primary")
|
| 265 |
|
| 266 |
+
# with gr.Accordion("Generation History", open=False):
|
| 267 |
+
# history_gallery = gr.Gallery(label="History", columns=3)
|
| 268 |
|
| 269 |
with gr.Accordion("Safety Status", open=True):
|
| 270 |
gr.Markdown("""
|