jiandan1998 committed on
Commit
7909335
·
verified ·
1 Parent(s): 70f6ec8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -109
app.py CHANGED
@@ -12,13 +12,14 @@ from dotenv import load_dotenv
12
  import gradio as gr
13
  import random
14
  import torch
 
15
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
16
  from functools import lru_cache
17
 
18
- # 加载环境变量
19
  load_dotenv()
20
 
21
- # ==== 新增安全检测模块 ====
22
  MODEL_URL = "TostAI/nsfw-text-detection-large"
23
  CLASS_NAMES = {
24
  0: "✅ SAFE",
@@ -26,18 +27,11 @@ CLASS_NAMES = {
26
  2: "🚫 UNSAFE"
27
  }
28
 
29
- # 加载模型和tokenizer
30
  tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
31
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
32
 
33
- @lru_cache(maxsize=128)
34
- def classify_text(text):
35
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
36
- with torch.no_grad():
37
- outputs = model(**inputs)
38
- return torch.argmax(outputs.logits, dim=1).item()
39
-
40
- # ==== 会话管理模块 ====
41
  class SessionManager:
42
  _instances = {}
43
  _lock = threading.Lock()
@@ -47,9 +41,9 @@ class SessionManager:
47
  with cls._lock:
48
  if session_id not in cls._instances:
49
  cls._instances[session_id] = {
50
- 'request_count': 0,
51
- 'last_request': time.time(),
52
- 'history': []
53
  }
54
  return cls._instances[session_id]
55
 
@@ -57,133 +51,256 @@ class SessionManager:
57
  def cleanup_sessions(cls):
58
  with cls._lock:
59
  now = time.time()
60
- expired = [k for k, v in cls._instances.items() if now - v['last_request'] > 3600]
61
  for k in expired:
62
  del cls._instances[k]
63
 
64
- # ==== 频率限制模块 ====
65
  class RateLimiter:
66
  def __init__(self):
67
- self.client_data = {}
68
-
69
- def check_limit(self, client_id):
70
- if client_id not in self.client_data:
71
- self.client_data[client_id] = {
72
- 'count': 0,
73
- 'reset_time': time.time() + 3600
74
- }
75
-
76
- if time.time() > self.client_data[client_id]['reset_time']:
77
- self.client_data[client_id] = {
78
- 'count': 0,
79
- 'reset_time': time.time() + 3600
80
- }
81
-
82
- if self.client_data[client_id]['count'] >= 20:
83
- return False
84
- self.client_data[client_id]['count'] += 1
85
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # ==== 错误处理模块 ====
88
  def create_error_image(message):
89
- img = Image.new("RGB", (832, 480), color="#ffdddd")
90
- draw = ImageDraw.Draw(img)
91
-
92
  try:
93
  font = ImageFont.truetype("arial.ttf", 24)
94
  except:
95
  font = ImageFont.load_default()
96
-
97
- text_width, text_height = draw.textsize(message, font=font)
98
- x = (832 - text_width) / 2
99
- y = (480 - text_height) / 2
100
-
101
- draw.text((x, y), message, fill="#ff4444", font=font)
102
  img.save("error.jpg")
103
  return "error.jpg"
104
 
105
- # ==== 核心生成逻辑 ====
106
- def generate_video(/* 保持原有参数 */):
107
- # 新增安全检测
108
- safety_level = classify_text(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  if safety_level != 0:
110
- error_msg = f"Content blocked: {CLASS_NAMES[safety_level]}"
111
- error_img = create_error_image(error_msg)
112
- yield f"❌ {error_msg}", error_img
113
  return
114
 
115
- # 新增频率检查
116
- session = SessionManager.get_session(session_id)
117
- if session['request_count'] >= 20:
118
- yield "❌ Hourly limit exceeded (20 requests)", None
119
  return
120
- session['request_count'] += 1
121
 
122
- # 原有生成逻辑保持不变,增加状态跟踪
 
 
 
 
 
123
  try:
124
- # API调用部分
125
- response = requests.post(/* 保持原有参数 */)
126
-
127
- # 轮询部分增加进度跟踪
128
- while True:
129
- # 获取状态
130
- status = get_status(request_id)
131
 
132
- # 更新会话最后活动时间
133
- session['last_request'] = time.time()
134
-
135
- # 处理不同状态
136
- if status == 'processing':
137
- yield f" 生成进度: {progress}%", None
138
- elif status == 'completed':
139
- session['history'].append(video_url)
140
- yield f" 生成完成", video_url
141
- return
 
 
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  except Exception as e:
144
  error_img = create_error_image(str(e))
145
- yield f"❌ 生成失败: {str(e)}", error_img
146
-
147
- # ==== 新增定时清理任务 ====
148
- def start_cleanup_task():
149
- def cleanup():
150
- while True:
151
- SessionManager.cleanup_sessions()
152
- time.sleep(3600)
 
 
 
 
 
 
 
 
153
 
154
- thread = threading.Thread(target=cleanup)
155
- thread.daemon = True
156
- thread.start()
157
-
158
- # ==== 界面增强 ====
159
- with gr.Blocks(/* 保持原有参数 */) as app:
160
- # 新增状态组件
161
- status_bars = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  with gr.Row():
164
- for backend in ["WAN-2.1", "FLUX", "TURBO"]: # 示例后端
165
- with gr.Column():
166
- gr.Markdown(f"**{backend}**")
167
- status_bars[backend] = gr.Textbox(label="状态", value="🟢 空闲")
 
 
 
 
 
 
 
 
 
168
 
169
- # 在生成按钮点击时更新状态
170
- generate_btn.click(
171
- fn=update_status,
172
- inputs=[...],
173
- outputs=[status_bars]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  )
175
 
176
- # 新增历史记录模块
177
- history = gr.Gallery(label="生成历史")
 
 
 
178
 
179
- # 在生成完成后更新历史
180
  generate_btn.click(
181
- fn=update_history,
182
- inputs=[video_output],
183
- outputs=[history]
 
184
  )
185
 
186
- # ==== 启动时初始化 ====
187
  if __name__ == "__main__":
188
- start_cleanup_task()
189
- app.queue(/* 保持原有参数 */)
 
 
 
 
 
12
  import gradio as gr
13
  import random
14
  import torch
15
+ from PIL import Image, ImageDraw, ImageFont
16
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
17
  from functools import lru_cache
18
 
19
+ # 初始化环境
20
  load_dotenv()
21
 
22
+ # 安全检测配置
23
  MODEL_URL = "TostAI/nsfw-text-detection-large"
24
  CLASS_NAMES = {
25
  0: "✅ SAFE",
 
27
  2: "🚫 UNSAFE"
28
  }
29
 
# Load the tokenizer and the sequence-classification model named by MODEL_URL
# once, at import time, so every request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
33
 
34
+ # 会话管理
 
 
 
 
 
 
 
35
  class SessionManager:
36
  _instances = {}
37
  _lock = threading.Lock()
 
41
  with cls._lock:
42
  if session_id not in cls._instances:
43
  cls._instances[session_id] = {
44
+ 'count': 0,
45
+ 'history': [],
46
+ 'last_active': time.time()
47
  }
48
  return cls._instances[session_id]
49
 
 
51
  def cleanup_sessions(cls):
52
  with cls._lock:
53
  now = time.time()
54
+ expired = [k for k, v in cls._instances.items() if now - v['last_active'] > 3600]
55
  for k in expired:
56
  del cls._instances[k]
57
 
# Rate limiting
class RateLimiter:
    """Thread-safe fixed-window request counter keyed by client id.

    A client's window opens on its first request and lasts
    ``window_seconds``. Within a window at most ``max_requests`` calls to
    :meth:`check` succeed; the first call after the window has expired
    opens a fresh window.
    """

    def __init__(self, max_requests=20, window_seconds=3600):
        """Create a limiter.

        Args:
            max_requests: Requests allowed per client per window.
            window_seconds: Length of one window in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client_id -> {'count': requests used, 'reset': window end (epoch s)}
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record one request for *client_id* and report whether it is allowed.

        Returns:
            True if the request fits in the client's current window (the
            request is then counted), False if the quota is exhausted.
        """
        with self.lock:
            now = time.time()
            entry = self.clients.get(client_id)

            # Unknown client or expired window: open a new window.
            if entry is None or now > entry['reset']:
                if self.max_requests < 1:
                    return False
                self.clients[client_id] = {
                    'count': 1,
                    'reset': now + self.window_seconds,
                }
                return True

            if entry['count'] >= self.max_requests:
                return False

            entry['count'] += 1
            return True
80
+
# Module-level singletons shared by every request handler
session_manager = SessionManager()
rate_limiter = RateLimiter()
84
+
# Image helpers
def image_to_base64(file_path):
    """Encode an image file as a ``data:image/<type>;base64,<payload>`` URI.

    Args:
        file_path: Path of the image file to read.

    Returns:
        The data URI as a string. The MIME subtype is chosen from the file
        extension (case-insensitive); unknown or missing extensions fall
        back to ``jpeg``.

    Raises:
        ValueError: If no path is given or the file cannot be read. The
            underlying exception, when there is one, is chained as
            ``__cause__``.
    """
    if not file_path:
        raise ValueError("Base64 Error: no image file provided")

    mime_map = {'jpg': 'jpeg', 'jpeg': 'jpeg', 'png': 'png', 'webp': 'webp', 'gif': 'gif'}
    try:
        path = Path(file_path)
        mime = mime_map.get(path.suffix.lower()[1:], 'jpeg')
        # b64encode always emits correctly padded output, so no manual
        # '=' padding is needed afterwards.
        encoded = base64.b64encode(path.read_bytes()).decode()
    except Exception as e:
        raise ValueError(f"Base64 Error: {str(e)}") from e

    return f"data:image/{mime};base64,{encoded}"
100
 
 
def create_error_image(message):
    """Render an error message onto an image and save it as ``error.jpg``.

    The image is 832x480 with a pale-red background; the text is drawn in
    red at (50, 200). Messages longer than 60 characters are cut to their
    first 60 characters, prefixed with ``Error: `` and suffixed with
    ``...``; shorter messages are drawn unchanged.

    Args:
        message: Text to display. Non-string values are converted with
            ``str()``.

    Returns:
        The path of the written file, always ``"error.jpg"`` (relative to
        the current working directory, so each call overwrites the last).
    """
    message = str(message)
    img = Image.new("RGB", (832, 480), "#ffdddd")

    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except OSError:
        # The font file could not be loaded; fall back to Pillow's
        # built-in default font instead of failing the error path itself.
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(img)
    text = f"Error: {message[:60]}..." if len(message) > 60 else message
    draw.text((50, 200), text, fill="#ff0000", font=font)

    img.save("error.jpg")
    return "error.jpg"
113
 
# Core generation logic
@lru_cache(maxsize=100)
def classify_prompt(prompt):
    """Return the index of the highest-scoring class for *prompt*.

    The prompt is tokenized (truncated to 512 tokens) and run through the
    module-level ``model`` without gradient tracking. Results are memoized
    for the 100 most recently used prompts.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**encoded).logits
    best_class = torch.argmax(logits)
    return best_class.item()
121
+
def generate_video(
    image,
    prompt,
    duration,
    enable_safety,
    flow_shift,
    guidance_scale,
    negative_prompt,
    inference_steps,
    seed,
    size,
    session_id
):
    """Submit an image-to-video job and stream its progress.

    This is a generator. Every item it yields is a ``(status_text, media)``
    pair, where ``media`` is ``None`` while the job is pending, the result
    URL on success, or the path of an error image on failure.

    Steps: classify the prompt, apply the per-session rate limit, submit the
    job to the WaveSpeed endpoint, then poll the result URL once per second
    until the job completes, fails, or the polling deadline passes.

    Args:
        image: Path of the input image file.
        prompt: Text description; also the text that is safety-classified.
        duration: Forwarded to the API as ``duration``.
        enable_safety: Accepted for interface compatibility; this function
            does not read it (the prompt is always classified).
        flow_shift: Accepted for interface compatibility; this function
            does not read it.
        guidance_scale: Forwarded to the API as ``guidance_scale``.
        negative_prompt: Forwarded to the API as ``negative_prompt``.
        inference_steps: Forwarded to the API as ``num_inference_steps``.
        seed: Seed for the job; ``-1`` or ``None`` selects a random seed.
            The value is converted to ``int`` before it is sent.
        size: Forwarded to the API as ``size``.
        session_id: Key used for rate limiting and session bookkeeping.
    """
    request_timeout = 30  # seconds allowed for each individual HTTP call
    poll_deadline = 600   # seconds to wait for the job before giving up

    # Safety check: any class other than 0 blocks the request.
    safety_level = classify_prompt(prompt)
    if safety_level != 0:
        error_img = create_error_image(CLASS_NAMES[safety_level])
        yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
        return

    # Rate-limit check
    if not rate_limiter.check(session_id):
        error_img = create_error_image("Hourly limit exceeded (20 requests)")
        yield "❌ 请求过于频繁,请稍后再试", error_img
        return

    # Session bookkeeping
    session = session_manager.get_session(session_id)
    session['last_active'] = time.time()
    session['count'] += 1

    # Submit the job
    try:
        api_key = os.getenv("WAVESPEED_API_KEY")
        if not api_key:
            raise ValueError("API key missing")

        base64_img = image_to_base64(image)
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

        # The UI's number field may deliver a float or None; normalise so
        # the JSON payload always carries an integer seed.
        if seed is None or int(seed) == -1:
            seed_value = random.randint(0, 999999)
        else:
            seed_value = int(seed)

        payload = {
            "image": base64_img,
            "prompt": prompt,
            "duration": duration,
            "guidance_scale": guidance_scale,
            "negative_prompt": negative_prompt,
            "num_inference_steps": inference_steps,
            "seed": seed_value,
            "size": size
        }

        # Without a timeout a stalled connection would block this worker
        # indefinitely.
        response = requests.post(
            "https://api.wavespeed.ai/api/v2/wavespeed-ai/wan-2.1/i2v-480p-ultra-fast",
            headers=headers,
            json=payload,
            timeout=request_timeout
        )

        if response.status_code != 200:
            raise Exception(f"API Error {response.status_code}: {response.text}")

        request_id = response.json()["data"]["id"]
        yield f"✅ 任务已提交 (ID: {request_id})", None

    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ 提交失败: {str(e)}", error_img
        return

    # Poll for the result
    result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
    start_time = time.time()

    while True:
        time.sleep(1)
        try:
            # Bound the total wait so a job that never leaves a pending
            # state cannot keep this generator alive forever.
            if time.time() - start_time > poll_deadline:
                raise TimeoutError(f"Generation timed out after {poll_deadline}s")

            resp = requests.get(result_url, headers=headers, timeout=request_timeout)
            if resp.status_code != 200:
                raise Exception(f"状态查询失败: {resp.text}")

            data = resp.json()["data"]
            status = data["status"]

            if status == "completed":
                elapsed = time.time() - start_time
                video_url = data["outputs"][0]
                session["history"].append(video_url)
                yield f"🎉 生成成功! 耗时 {elapsed:.1f}s", video_url
                return

            elif status == "failed":
                raise Exception(data.get("error", "Unknown error"))

            else:
                yield f"⏳ 当前状态: {status.capitalize()}...", None

        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ 生成失败: {str(e)}", error_img
            return
223
+
# Background cleanup thread
def cleanup_task():
    """Run session cleanup forever, pausing one hour after each pass.

    Never returns; intended to be the target of a daemon thread.
    """
    pause_seconds = 3600
    while True:
        session_manager.cleanup_sessions()
        time.sleep(pause_seconds)
229
+
# Gradio interface
# NOTE(review): this layout was reconstructed from a diff rendering that lost
# all indentation. The widget order is as given, but the nesting of the two
# accordions (placed at the top level here) should be confirmed.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .video-preview { max-width: 600px !important; }
    .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
    .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
    .warning { background: #fff3e0; border: 1px solid #ffcc80; }
    .error { background: #ffebee; border: 1px solid #ef9a9a; }
    """
) as app:

    # The initial value is computed once, when the UI is built, so on its
    # own it would give every visitor the same id. The load handler
    # registered below replaces it with a fresh id per page load.
    session_id = gr.State(str(uuid.uuid4()))

    gr.Markdown("# 🌊 视频生成系统 - WaveSpeedAI")

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="filepath", label="上传图片")
            prompt = gr.Textbox(label="描述文本", lines=3, placeholder="请输入画面描述...")
            negative_prompt = gr.Textbox(label="排除内容", lines=2)
            with gr.Row():
                # Explicit default so the handler never receives None.
                size = gr.Dropdown(["832 * 480"], value="832 * 480", label="分辨率")
                steps = gr.Slider(1, 50, value=30, label="推理步数")
            with gr.Row():
                duration = gr.Slider(1, 10, value=5, step=1, label="时长(秒)")
                guidance = gr.Slider(1, 20, value=7, label="引导强度")
            with gr.Row():
                # precision=0 makes the component deliver an int seed.
                seed = gr.Number(-1, label="随机种子", precision=0)
                random_seed_btn = gr.Button("随机生成", variant="secondary")

        with gr.Column(scale=1):
            video_output = gr.Video(label="生成结果", format="mp4", elem_classes=["video-preview"])
            status_output = gr.Textbox(label="系统状态", interactive=False, lines=4)
            generate_btn = gr.Button("开始生成", variant="primary")

    with gr.Accordion("生成历史", open=False):
        history_gallery = gr.Gallery(label="历史记录", columns=3)

    with gr.Accordion("安全状态", open=True):
        gr.Markdown("""
<div class="status-box safe">
✅ 当前内容安全检测通过
</div>
""")

    # Examples
    gr.Examples(
        examples=[
            ["19世纪绅士在石板街", "example1.jpg"],
            ["赛博朋克女战士在雨夜", "example2.jpg"]
        ],
        inputs=[prompt, img_input],
        label="示例输入"
    )

    # Event wiring
    # Give each page load its own session id so the rate limit and history
    # are tracked per visitor rather than shared by everyone.
    app.load(
        fn=lambda: str(uuid.uuid4()),
        outputs=session_id
    )

    random_seed_btn.click(
        fn=lambda: random.randint(0, 999999),
        outputs=seed
    )

    generate_btn.click(
        generate_video,
        inputs=[img_input, prompt, duration, gr.State(True), gr.State(3),
                guidance, negative_prompt, steps, seed, size, session_id],
        outputs=[status_output, video_output]
    )
298
 
# Entry point
if __name__ == "__main__":
    # Daemon thread: it must not keep the process alive once the server exits.
    janitor = threading.Thread(target=cleanup_task, daemon=True)
    janitor.start()

    queued_app = app.queue(max_size=4)
    queued_app.launch(server_name="0.0.0.0", max_threads=16, share=False)