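# Source: Hugging Face file view of app.py
# Last change: "Update app.py" by jiandan1998, commit e2f8807 (verified)
# File size: 5.57 kB
# (Page links "raw" and "history blame" omitted; they are not part of the script.)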
import os
import requests
import json
import time
import threading
import shutil
from datetime import datetime
from pathlib import Path
from http.server import HTTPServer, SimpleHTTPRequestHandler
import base64
from dotenv import load_dotenv
import gradio as gr
import random
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from functools import lru_cache
# Load environment variables.
load_dotenv()
# ==== Safety check module (new) ====
# Model identifier handed to both from_pretrained() calls below.
MODEL_URL = "TostAI/nsfw-text-detection-large"
# Display label for each class index; generate_video() looks up the value
# returned by classify_text() here to build its "Content blocked" message.
CLASS_NAMES = {
    0: "✅ SAFE",
    1: "⚠️ QUESTIONABLE",
    2: "🚫 UNSAFE"
}
# Load the model and tokenizer once, at import time, so classify_text() can
# reuse them on every call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
@lru_cache(maxsize=128)
def classify_text(text):
    """Return the predicted class index for ``text``.

    The text is tokenized (truncated to at most 512 tokens), run through the
    module-level sequence-classification ``model`` without gradient tracking,
    and the index of the largest logit is returned as a plain Python int.
    Results for the 128 most recently used inputs are memoized by
    ``lru_cache``, so repeated prompts skip inference.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    # A single input yields a (1, num_classes) tensor, so .item() is valid.
    best_class = torch.argmax(logits, dim=1)
    return best_class.item()
# ==== Session management ====
class SessionManager:
    """Process-wide registry of per-session state dictionaries.

    Every session is a plain dict with three keys:

    * ``'request_count'`` - number of requests counted for the session,
    * ``'last_request'``  - ``time.time()`` timestamp of the latest activity,
    * ``'history'``       - list of results recorded for the session.

    The registry is guarded by a class-level lock. The dict returned by
    ``get_session`` is the stored object itself, so callers mutate it
    directly (outside the lock).
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the state dict for ``session_id``, creating it on first use."""
        with cls._lock:
            session = cls._instances.get(session_id)
            if session is None:
                session = {
                    'request_count': 0,
                    'last_request': time.time(),
                    'history': [],
                }
                cls._instances[session_id] = session
            return session

    @classmethod
    def cleanup_sessions(cls, max_idle=3600):
        """Drop sessions idle for more than ``max_idle`` seconds.

        Args:
            max_idle: Idle threshold in seconds. Defaults to 3600 (one hour),
                the value that used to be hard-coded.

        Returns:
            The number of sessions that were removed.
        """
        with cls._lock:
            now = time.time()
            # Collect keys first: deleting while iterating a dict raises.
            expired = [
                key
                for key, state in cls._instances.items()
                if now - state['last_request'] > max_idle
            ]
            for key in expired:
                del cls._instances[key]
            return len(expired)
# ==== Rate limiting ====
class RateLimiter:
    """Fixed-window request counter keyed by client id.

    Each client gets ``max_requests`` allowed calls per window of
    ``window_seconds``. The window starts on the client's first request (or
    on the first request after the previous window expired).

    Attributes:
        max_requests: Calls allowed per client per window.
        window_seconds: Window length in seconds.
        client_data: Maps client id to ``{'count': int, 'reset_time': float}``.
    """

    def __init__(self, max_requests=20, window_seconds=3600):
        """Create a limiter; the defaults match the previous fixed limits."""
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.client_data = {}
        # Guard the check-then-increment below, in line with SessionManager,
        # since handlers may call check_limit from several threads.
        self._lock = threading.Lock()

    def check_limit(self, client_id):
        """Record one request for ``client_id`` and report whether it is allowed.

        Returns:
            True if the request fits in the current window (and has been
            counted), False if the client already used up its quota.
        """
        now = time.time()
        with self._lock:
            entry = self.client_data.get(client_id)
            # Unknown client and expired window are handled the same way:
            # start a fresh window.
            if entry is None or now > entry['reset_time']:
                entry = {
                    'count': 0,
                    'reset_time': now + self.window_seconds,
                }
                self.client_data[client_id] = entry
            if entry['count'] >= self.max_requests:
                return False
            entry['count'] += 1
            return True
# ==== Error handling ====
def create_error_image(message):
    """Render ``message`` centered on an 832x480 image and save it to disk.

    The image has a light red background with the text drawn in red. It is
    written to ``error.jpg`` in the current working directory, overwriting
    any previous file of that name.

    Args:
        message: Text to draw on the image.

    Returns:
        The path of the saved image, always ``"error.jpg"``.
    """
    # NOTE(review): Image, ImageDraw and ImageFont are not imported in the
    # visible part of this file - confirm that
    # `from PIL import Image, ImageDraw, ImageFont` exists elsewhere.
    width, height = 832, 480
    img = Image.new("RGB", (width, height), color="#ffdddd")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support unavailable.
        font = ImageFont.load_default()

    # ImageDraw.textsize() was removed in Pillow 10; textbbox() is the
    # replacement. Fall back to textsize() on older Pillow releases that
    # lack textbbox() or reject it for the font in use.
    try:
        left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    except (AttributeError, ValueError):
        left, top = 0, 0
        right, bottom = draw.textsize(message, font=font)

    # Subtract the bbox origin so the visible glyphs, not the text anchor,
    # end up centered.
    x = (width - (right - left)) / 2 - left
    y = (height - (bottom - top)) / 2 - top
    draw.text((x, y), message, fill="#ff4444", font=font)
    img.save("error.jpg")
    return "error.jpg"
# ==== Core generation logic ====
def generate_video(/* 保持原有参数 */):
    """Generator that drives one video request and yields UI updates.

    Every yielded item is a ``(status_text, media)`` pair, where ``media`` is
    an error-image path, ``None``, or the finished video URL.

    NOTE(review): the parameter list and the ``requests.post`` arguments are
    placeholders ("keep the original parameters"). ``prompt``,
    ``session_id``, ``request_id``, ``progress``, ``video_url`` and
    ``get_status`` are not defined in the visible source - confirm where
    they come from.
    """
    # Added: safety check - anything other than class 0 is rejected.
    safety_level = classify_text(prompt)
    if safety_level != 0:
        error_msg = f"Content blocked: {CLASS_NAMES[safety_level]}"
        error_img = create_error_image(error_msg)
        yield f"❌ {error_msg}", error_img
        return
    # Added: rate check, based on the per-session counter.
    # NOTE(review): nothing in this block resets request_count - confirm the
    # "hourly" limit is actually reset somewhere.
    session = SessionManager.get_session(session_id)
    if session['request_count'] >= 20:
        yield "❌ Hourly limit exceeded (20 requests)", None
        return
    session['request_count'] += 1
    # Original generation logic unchanged; status tracking added.
    try:
        # API call
        response = requests.post(/* 保持原有参数 */)
        # Polling loop with progress tracking.
        # NOTE(review): the loop has no delay between polls and only exits
        # on 'completed' (or an exception) - confirm other statuses such as
        # a failure state are handled.
        while True:
            # Fetch the current status.
            status = get_status(request_id)
            # Refresh the session's last-activity timestamp.
            session['last_request'] = time.time()
            # Handle the different statuses.
            if status == 'processing':
                yield f"⏳ 生成进度: {progress}%", None
            elif status == 'completed':
                session['history'].append(video_url)
                yield f"✅ 生成完成", video_url
                return
    except Exception as e:
        # Any failure is reported to the UI as text plus an error image.
        error_img = create_error_image(str(e))
        yield f"❌ 生成失败: {str(e)}", error_img
# ==== Periodic cleanup task (new) ====
def start_cleanup_task():
    """Start a background daemon thread that purges stale sessions.

    The thread calls ``SessionManager.cleanup_sessions()`` immediately and
    then once every 3600 seconds. Because it is a daemon thread it does not
    keep the process alive on exit.
    """
    def cleanup():
        while True:
            SessionManager.cleanup_sessions()
            time.sleep(3600)

    threading.Thread(target=cleanup, daemon=True).start()
# ==== UI enhancements ====
# NOTE(review): the gr.Blocks arguments are a placeholder ("keep the original
# parameters"); generate_btn, update_status, update_history and video_output
# are not defined in the visible source - confirm they exist elsewhere.
with gr.Blocks(/* 保持原有参数 */) as app:
    # Added: one status textbox per backend, keyed by backend name.
    status_bars = {}
    with gr.Row():
        for backend in ["WAN-2.1", "FLUX", "TURBO"]:  # example backends
            with gr.Column():
                gr.Markdown(f"**{backend}**")
                status_bars[backend] = gr.Textbox(label="状态", value="🟢 空闲")
    # Update the status boxes when the generate button is clicked.
    # NOTE(review): outputs receives the status_bars dict wrapped in a list,
    # not the Textbox components themselves - verify this is what Gradio
    # expects (list(status_bars.values()) may have been intended).
    generate_btn.click(
        fn=update_status,
        inputs=[...],
        outputs=[status_bars]
    )
    # Added: history gallery.
    history = gr.Gallery(label="生成历史")
    # Refresh the history gallery from the video output on click.
    generate_btn.click(
        fn=update_history,
        inputs=[video_output],
        outputs=[history]
    )
# ==== Startup initialization ====
if __name__ == "__main__":
    # Start the background session sweeper before serving requests.
    start_cleanup_task()
    # NOTE(review): the queue() arguments are a placeholder ("keep the
    # original parameters") and no app.launch() call is visible here -
    # confirm the launch happens in the part of the file not shown.
    app.queue(/* 保持原有参数 */)