jiandan1998 commited on
Commit
b3e34bb
·
verified ·
1 Parent(s): c74f0f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -19
app.py CHANGED
@@ -17,10 +17,8 @@ from PIL import Image, ImageDraw, ImageFont
17
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
18
  from functools import lru_cache
19
 
20
- # 初始化环境
21
  load_dotenv()
22
 
23
- # 安全检测配置
24
  MODEL_URL = "TostAI/nsfw-text-detection-large"
25
  CLASS_NAMES = {
26
  0: "✅ SAFE",
@@ -28,11 +26,9 @@ CLASS_NAMES = {
28
  2: "🚫 UNSAFE"
29
  }
30
 
31
- # 加载模型
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
33
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
34
 
35
- # 会话管理
36
  class SessionManager:
37
  _instances = {}
38
  _lock = threading.Lock()
@@ -56,7 +52,6 @@ class SessionManager:
56
  for k in expired:
57
  del cls._instances[k]
58
 
59
- # 频率限制
60
  class RateLimiter:
61
  def __init__(self):
62
  self.clients = {}
@@ -79,11 +74,9 @@ class RateLimiter:
79
  self.clients[client_id]['count'] += 1
80
  return True
81
 
82
- # 初始化模块
83
  session_manager = SessionManager()
84
  rate_limiter = RateLimiter()
85
 
86
- # 图像处理函数
87
  def image_to_base64(file_path):
88
  try:
89
  with open(file_path, "rb") as f:
@@ -112,7 +105,7 @@ def create_error_image(message):
112
  img.save("error.jpg")
113
  return "error.jpg"
114
 
115
- # 核心生成逻辑
116
  @lru_cache(maxsize=100)
117
  def classify_prompt(prompt):
118
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
@@ -133,27 +126,23 @@ def generate_video(
133
  size,
134
  session_id
135
  ):
136
- # 安全检查
137
  safety_level = classify_prompt(prompt)
138
  if safety_level != 0:
139
  error_img = create_error_image(CLASS_NAMES[safety_level])
140
  yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
141
  return
142
 
143
- # 频率检查
144
  if not rate_limiter.check(session_id):
145
  error_img = create_error_image("Hourly limit exceeded (20 requests)")
146
  yield "❌ 请求过于频繁,请稍后再试", error_img
147
  return
148
 
149
- # 会话更新
150
  session = session_manager.get_session(session_id)
151
  session['last_active'] = time.time()
152
  session['count'] += 1
153
 
154
- # API调用
155
  try:
156
- # 准备请求
157
  api_key = os.getenv("WAVESPEED_API_KEY")
158
  if not api_key:
159
  raise ValueError("API key missing")
@@ -192,7 +181,6 @@ def generate_video(
192
  yield f"❌ 提交失败: {str(e)}", error_img
193
  return
194
 
195
- # 轮询结果
196
  result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
197
  start_time = time.time()
198
 
@@ -224,13 +212,11 @@ def generate_video(
224
  yield f"❌ 生成失败: {str(e)}", error_img
225
  return
226
 
227
- # 后台清理线程
228
  def cleanup_task():
229
  while True:
230
  session_manager.cleanup_sessions()
231
  time.sleep(3600)
232
 
233
- # Gradio界面
234
  with gr.Blocks(
235
  theme=gr.themes.Soft(),
236
  css="""
@@ -267,7 +253,7 @@ with gr.Blocks(
267
  guidance = gr.Slider(1, 20, value=7, label="Guidance Scale")
268
  with gr.Row():
269
  seed = gr.Number(-1, label="Seed")
270
- random_seed_btn = gr.Button("Random Seed", variant="secondary")
271
  with gr.Row():
272
  enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",value=True,interactive=True)
273
  flow_shift = gr.Number(3, label="Flow Shift",interactive=False)
@@ -277,8 +263,8 @@ with gr.Blocks(
277
  status_output = gr.Textbox(label="System Status", interactive=False, lines=4)
278
  generate_btn = gr.Button("Generated", variant="primary")
279
 
280
- with gr.Accordion("Generation History", open=False):
281
- history_gallery = gr.Gallery(label="History", columns=3)
282
 
283
  with gr.Accordion("Safety Status", open=True):
284
  gr.Markdown("""
 
17
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
18
  from functools import lru_cache
19
 
 
20
  load_dotenv()
21
 
 
22
  MODEL_URL = "TostAI/nsfw-text-detection-large"
23
  CLASS_NAMES = {
24
  0: "✅ SAFE",
 
26
  2: "🚫 UNSAFE"
27
  }
28
 
 
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
30
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
31
 
 
32
  class SessionManager:
33
  _instances = {}
34
  _lock = threading.Lock()
 
52
  for k in expired:
53
  del cls._instances[k]
54
 
 
55
  class RateLimiter:
56
  def __init__(self):
57
  self.clients = {}
 
74
  self.clients[client_id]['count'] += 1
75
  return True
76
 
 
77
  session_manager = SessionManager()
78
  rate_limiter = RateLimiter()
79
 
 
80
  def image_to_base64(file_path):
81
  try:
82
  with open(file_path, "rb") as f:
 
105
  img.save("error.jpg")
106
  return "error.jpg"
107
 
108
+
109
  @lru_cache(maxsize=100)
110
  def classify_prompt(prompt):
111
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
126
  size,
127
  session_id
128
  ):
129
+
130
  safety_level = classify_prompt(prompt)
131
  if safety_level != 0:
132
  error_img = create_error_image(CLASS_NAMES[safety_level])
133
  yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
134
  return
135
 
 
136
  if not rate_limiter.check(session_id):
137
  error_img = create_error_image("Hourly limit exceeded (20 requests)")
138
  yield "❌ 请求过于频繁,请稍后再试", error_img
139
  return
140
 
 
141
  session = session_manager.get_session(session_id)
142
  session['last_active'] = time.time()
143
  session['count'] += 1
144
 
 
145
  try:
 
146
  api_key = os.getenv("WAVESPEED_API_KEY")
147
  if not api_key:
148
  raise ValueError("API key missing")
 
181
  yield f"❌ 提交失败: {str(e)}", error_img
182
  return
183
 
 
184
  result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
185
  start_time = time.time()
186
 
 
212
  yield f"❌ 生成失败: {str(e)}", error_img
213
  return
214
 
 
215
  def cleanup_task():
216
  while True:
217
  session_manager.cleanup_sessions()
218
  time.sleep(3600)
219
 
 
220
  with gr.Blocks(
221
  theme=gr.themes.Soft(),
222
  css="""
 
253
  guidance = gr.Slider(1, 20, value=7, label="Guidance Scale")
254
  with gr.Row():
255
  seed = gr.Number(-1, label="Seed")
256
+ random_seed_btn = gr.Button("Random🎲Seed", variant="secondary")
257
  with gr.Row():
258
  enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",value=True,interactive=True)
259
  flow_shift = gr.Number(3, label="Flow Shift",interactive=False)
 
263
  status_output = gr.Textbox(label="System Status", interactive=False, lines=4)
264
  generate_btn = gr.Button("Generated", variant="primary")
265
 
266
+ # with gr.Accordion("Generation History", open=False):
267
+ # history_gallery = gr.Gallery(label="History", columns=3)
268
 
269
  with gr.Accordion("Safety Status", open=True):
270
  gr.Markdown("""