|
|
import os |
|
|
import requests |
|
|
import json |
|
|
import time |
|
|
import threading |
|
|
import shutil |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from http.server import HTTPServer, SimpleHTTPRequestHandler |
|
|
import base64 |
|
|
from dotenv import load_dotenv |
|
|
import gradio as gr |
|
|
import random |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
from functools import lru_cache |
|
|
|
|
|
|
|
|
# Load environment variables (API keys, endpoints, ...) from a local .env file.
load_dotenv()


# Hugging Face model id of the text-safety classifier used to screen prompts.
MODEL_URL = "TostAI/nsfw-text-detection-large"

# Human-readable labels for the classifier's three output classes
# (index == argmax of the model logits; 0 is the only class allowed through).
CLASS_NAMES = {
    0: "✅ SAFE",
    1: "⚠️ QUESTIONABLE",
    2: "🚫 UNSAFE"
}


# Load tokenizer and classifier once at import time; downloads the weights
# from the Hugging Face hub on first run.
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
|
|
|
|
|
@lru_cache(maxsize=128)
def classify_text(text):
    """Return the safety class index (key into CLASS_NAMES) for *text*.

    Results are memoized per exact input string, so repeated prompts skip
    the model forward pass entirely.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    return logits.argmax(dim=1).item()
|
|
|
|
|
|
|
|
class SessionManager:
    """Process-wide registry of per-session state, keyed by session id.

    Thread-safe: every access to the shared registry is serialized through
    a single class-level lock.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the state dict for *session_id*, creating it on first use."""
        with cls._lock:
            session = cls._instances.get(session_id)
            if session is None:
                session = {
                    'request_count': 0,
                    'last_request': time.time(),
                    'history': [],
                }
                cls._instances[session_id] = session
            return session

    @classmethod
    def cleanup_sessions(cls):
        """Drop every session whose last request is more than an hour old."""
        with cls._lock:
            cutoff = time.time() - 3600
            stale = [
                sid for sid, state in cls._instances.items()
                if state['last_request'] < cutoff
            ]
            for sid in stale:
                del cls._instances[sid]
|
|
|
|
|
|
|
|
class RateLimiter:
    """Fixed-window rate limiter: at most *max_requests* per *window* seconds.

    Not thread-safe; wrap calls in a lock if shared across threads.
    The defaults (20 requests / 3600 s) preserve the original behavior.
    """

    def __init__(self, max_requests=20, window=3600):
        # max_requests: allowed calls per window.
        # window: window length in seconds.
        self.max_requests = max_requests
        self.window = window
        # client_id -> {'count': int, 'reset_time': epoch seconds}
        self.client_data = {}

    def _fresh_window(self):
        """Return a new per-client window record starting now."""
        return {
            'count': 0,
            'reset_time': time.time() + self.window,
        }

    def check_limit(self, client_id):
        """Record one request for *client_id*.

        Returns True if the request fits within the current window's quota,
        False once the client has exhausted it (the rejected request is not
        counted).
        """
        data = self.client_data.get(client_id)
        # Start a fresh window on first sight of the client, or once the
        # previous window has expired.
        if data is None or time.time() > data['reset_time']:
            data = self._fresh_window()
            self.client_data[client_id] = data

        if data['count'] >= self.max_requests:
            return False
        data['count'] += 1
        return True
|
|
|
|
|
|
|
|
def create_error_image(message):
    """Render *message* centered on a red-tinted 832x480 placeholder image.

    Saves the image as "error.jpg" in the current directory and returns that
    path (the callers pass it straight to the UI as an image output).
    """
    # BUG FIX: PIL was used here but never imported anywhere in this file,
    # so every call raised NameError. Imported locally to make the
    # dependency explicit at the single point of use.
    from PIL import Image, ImageDraw, ImageFont

    width, height = 832, 480
    img = Image.new("RGB", (width, height), color="#ffdddd")
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except OSError:
        # Font file not present on this host — fall back to PIL's built-in
        # bitmap font rather than crashing.
        font = ImageFont.load_default()

    # draw.textsize() was removed in Pillow 10; textbbox() is the supported
    # replacement and yields the same width/height for centering.
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    text_width = right - left
    text_height = bottom - top
    x = (width - text_width) / 2
    y = (height - text_height) / 2

    draw.text((x, y), message, fill="#ff4444", font=font)
    img.save("error.jpg")
    return "error.jpg"
|
|
|
|
|
|
|
|
def generate_video(/* 保持原有参数 */):
    # NOTE(review): "/* 保持原有参数 */" ("keep the original parameters") is a
    # C-style placeholder left by the author — this is NOT valid Python. The
    # real signature (at least `prompt` and `session_id`, judging by the body)
    # must be restored before this file can run.

    # Gate generation on the NSFW classifier: anything not class 0 (SAFE)
    # is rejected with an error image instead of being sent to the backend.
    safety_level = classify_text(prompt)
    if safety_level != 0:
        error_msg = f"Content blocked: {CLASS_NAMES[safety_level]}"
        error_img = create_error_image(error_msg)
        yield f"❌ {error_msg}", error_img
        return

    # Per-session quota: hard cap of 20 requests per session (the hourly
    # reset is handled by SessionManager.cleanup_sessions dropping idle
    # sessions, not by resetting this counter in place).
    session = SessionManager.get_session(session_id)
    if session['request_count'] >= 20:
        yield "❌ Hourly limit exceeded (20 requests)", None
        return
    session['request_count'] += 1

    try:
        # Placeholder again — the original request payload/URL was elided.
        response = requests.post(/* 保持原有参数 */)

        # Poll the backend until the job completes. NOTE(review):
        # `get_status`, `request_id`, `progress` and `video_url` are not
        # defined anywhere in this file — presumably they come from the
        # elided code above; confirm when restoring it. There is also no
        # sleep between polls and no exit path for a failed job, so this
        # loop can spin hot or run forever — review when restoring.
        while True:
            status = get_status(request_id)

            # Keep the session alive while the job is running.
            session['last_request'] = time.time()

            if status == 'processing':
                # "生成进度" = generation progress
                yield f"⏳ 生成进度: {progress}%", None
            elif status == 'completed':
                session['history'].append(video_url)
                # "生成完成" = generation complete
                yield f"✅ 生成完成", video_url
                return

    except Exception as e:
        # Surface any failure to the UI as text + a rendered error image.
        error_img = create_error_image(str(e))
        # "生成失败" = generation failed
        yield f"❌ 生成失败: {str(e)}", error_img
|
|
|
|
|
|
|
|
def start_cleanup_task():
    """Start a background daemon thread that purges idle sessions hourly."""

    def _worker():
        # Runs for the life of the process; daemon flag lets the
        # interpreter exit without joining it.
        while True:
            SessionManager.cleanup_sessions()
            time.sleep(3600)

    threading.Thread(target=_worker, daemon=True).start()
|
|
|
|
|
|
|
|
# NOTE(review): "/* 保持原有参数 */" is a non-Python placeholder for the
# original gr.Blocks(...) arguments — this block does not parse as written.
with gr.Blocks(/* 保持原有参数 */) as app:

    # One status textbox per generation backend, keyed by backend name.
    status_bars = {}

    with gr.Row():
        for backend in ["WAN-2.1", "FLUX", "TURBO"]:
            with gr.Column():
                gr.Markdown(f"**{backend}**")
                # label "状态" = "status"; initial value "🟢 空闲" = "idle"
                status_bars[backend] = gr.Textbox(label="状态", value="🟢 空闲")

    # NOTE(review): `generate_btn`, `update_status`, `update_history` and
    # `video_output` are not defined anywhere in this file — presumably part
    # of the elided UI; restore them before wiring these callbacks.
    # NOTE(review): outputs=[status_bars] wraps the dict in a list —
    # Gradio outputs expect components, so this likely needs to be
    # list(status_bars.values()); verify against the original code.
    generate_btn.click(
        fn=update_status,
        inputs=[...],
        outputs=[status_bars]
    )

    # Gallery of previously generated videos ("生成历史" = generation history).
    history = gr.Gallery(label="生成历史")

    generate_btn.click(
        fn=update_history,
        inputs=[video_output],
        outputs=[history]
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Begin the hourly purge of stale sessions before serving requests.
    start_cleanup_task()
    # NOTE(review): placeholder arguments again, and .queue() only configures
    # queuing — there is no app.launch() call visible here, so the server is
    # never actually started as written; confirm against the original code.
    app.queue(/* 保持原有参数 */)