# Hugging Face Space (status: Paused)
import base64
import json
import os
import random
import textwrap
import threading
import time
import uuid
from pathlib import Path

import gradio as gr
import requests
import torch
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load environment variables (load_dotenv also reads a local .env file if present).
load_dotenv()
API_KEY = os.getenv("WAVESPEED_API_KEY")
if not API_KEY:
    # Fail fast at import: every generation request needs this key.
    raise ValueError("WAVESPEED_API_KEY 未在环境变量中设置")

# Safety-classifier configuration.
MODEL_URL = "TostAI/nsfw-text-detection-large"
# Maps the classifier's predicted class index to a display label; index 0 is
# the only one generate_image lets through.
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}

# Load the prompt-safety model once at import time so requests reuse it.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
except Exception as e:
    # Chain the original exception so the real download/load failure stays
    # visible in the traceback as the direct cause.
    raise RuntimeError(f"安全模型加载失败: {str(e)}") from e
# Session management
class SessionManager:
    """Registry of per-session state shared across all instances.

    State is stored on the class itself (``_instances``), so every
    ``SessionManager()`` instance and the class share the same sessions.
    Each session is a dict with ``count`` (requests made), ``history``
    (result image URLs appended by the generator) and ``last_active``
    (epoch seconds). Dict creation and removal happen under ``_lock``.
    """

    _instances = {}  # session_id -> session dict
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        """Return the session dict for ``session_id``, creating it on first use."""
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time(),
                }
            return cls._instances[session_id]

    @classmethod
    def cleanup_sessions(cls, max_idle=3600):
        """Remove sessions whose ``last_active`` is more than ``max_idle`` seconds ago.

        Args:
            max_idle: Idle threshold in seconds (default one hour).
        """
        with cls._lock:
            now = time.time()
            expired = [
                key for key, sess in cls._instances.items()
                if now - sess['last_active'] > max_idle
            ]
            for key in expired:
                del cls._instances[key]
# Rate limiting
class RateLimiter:
    """Fixed-window request limiter keyed by client id.

    A client's window starts at its first call to :meth:`check` and lasts
    ``window`` seconds; within it at most ``limit`` calls are accepted. The
    first call after the window expires starts a new one.

    Args:
        limit: Maximum accepted requests per window (default 20; assumed >= 1).
        window: Window length in seconds (default 3600, one hour).
    """

    def __init__(self, limit=20, window=3600):
        self.limit = limit
        self.window = window
        self.clients = {}  # client_id -> {'count': int, 'reset': epoch seconds}
        self.lock = threading.Lock()

    def check(self, client_id):
        """Record a request for ``client_id`` and report whether it is allowed.

        Returns:
            True if the request is within the limit (and has been counted),
            False otherwise. Rejected requests are not counted.
        """
        with self.lock:
            now = time.time()
            entry = self.clients.get(client_id)
            if entry is None or now > entry['reset']:
                # First request, or the previous window has expired.
                self.clients[client_id] = {'count': 1, 'reset': now + self.window}
                return True
            if entry['count'] >= self.limit:
                return False
            entry['count'] += 1
            return True
# Process-wide singletons used by the request handler below.
session_manager: SessionManager = SessionManager()  # per-session bookkeeping
rate_limiter: RateLimiter = RateLimiter()  # per-session request throttling
# Utility functions
def create_error_image(message):
    """Render an error message as red text on a 512x512 light-red image.

    Messages longer than 60 characters are truncated with "...". Every message
    gets an "Error: " prefix, and the text is word-wrapped so it stays inside
    the image.

    Args:
        message: Error text to display (non-strings are converted with ``str``).

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    img = Image.new("RGB", (512, 512), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except (OSError, ImportError):
        # arial.ttf is commonly absent on Linux hosts (and FreeType support may
        # be missing); fall back to Pillow's built-in font.
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(img)
    message = str(message)
    if len(message) > 60:
        message = f"{message[:60]}..."
    # Wrap at ~32 chars so text drawn from x=50 is not clipped at the right edge.
    text = textwrap.fill(f"Error: {message}", width=32)
    draw.multiline_text((50, 200), text, fill="#ff0000", font=font)
    return img
def classify_prompt(prompt):
    """Classify a prompt with the module-level NSFW text classifier.

    Args:
        prompt: Text to classify; tokenized with truncation to 512 tokens.

    Returns:
        The predicted class index as an ``int`` (a key of ``CLASS_NAMES``).
    """
    inputs = tokenizer(prompt,
                       return_tensors="pt",
                       truncation=True,
                       max_length=512)
    # Pure inference: disable autograd so no graph/gradient buffers are built
    # for every request.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Argmax over the class dimension; a single prompt gives a batch of one row.
    return int(torch.argmax(logits, dim=-1)[0])
def image_to_base64(file_path):
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type comes from the file extension (case-insensitive): ``.png``
    maps to ``image/png``; ``.jpg``/``.jpeg`` map to the registered type
    ``image/jpeg``. Any other extension falls back to ``image/jpeg``.

    Raises:
        OSError: If the file cannot be read.
    """
    mime_by_ext = {"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg"}
    path = Path(file_path)
    mime_type = mime_by_ext.get(path.suffix.lower()[1:], "image/jpeg")
    encoded = base64.b64encode(path.read_bytes()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"
# Core generation logic
def _validate_inputs(image_file, prompt):
    """Return user-facing validation errors for a request (empty list if valid)."""
    errors = []
    if not image_file:
        errors.append("请上传图片文件")
    elif not Path(image_file).exists():
        errors.append("文件不存在")
    elif Path(image_file).suffix.lower()[1:] not in ["jpg", "jpeg", "png"]:
        errors.append("仅支持JPG/PNG格式")
    # `or ""` guards against a None prompt from the UI.
    if not (prompt or "").strip():
        errors.append("提示语不能为空")
    return errors


def _resolve_seed(seed):
    """Return ``seed`` as an int, or a random seed in [0, 999999] when it is -1 or unset."""
    if seed is None or int(seed) == -1:
        return random.randint(0, 999999)
    return int(seed)


def generate_image(image_file, prompt, seed, session_id, enable_safety=True):
    """Edit an uploaded image through the WaveSpeed HiDream-E1 API, streaming progress.

    Generator used as a Gradio event handler. Every yield is a
    ``(status_text, image, image_url)`` triple for the
    ``[status, output_image, output_url]`` outputs.

    Args:
        image_file: Path to the uploaded JPG/PNG file.
        prompt: Edit instruction text.
        seed: Seed number; -1 (or empty) selects a random seed.
        session_id: Key used for rate limiting and session history.
        enable_safety: Run the local prompt classifier and request the API's
            safety checker.

    Errors are not raised to the caller; they end the stream with a final
    yield that carries an error image.
    """
    try:
        # Validate first so malformed requests are rejected before they are
        # classified or charged against the session's rate limit.
        error_messages = _validate_inputs(image_file, prompt)
        if error_messages:
            error_img = create_error_image(" | ".join(error_messages))
            yield "❌ 输入验证失败", error_img, ""
            return

        # Safety check: only class 0 (safe) prompts are allowed through.
        if enable_safety:
            safety_level = classify_prompt(prompt)
            if safety_level != 0:
                error_img = create_error_image(CLASS_NAMES[safety_level])
                yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img, ""
                return

        # Rate limiting
        if not rate_limiter.check(session_id):
            error_img = create_error_image(
                "Hourly limit exceeded (20 requests)")
            yield "❌ 请求过于频繁,请稍后再试", error_img, ""
            return

        # Session bookkeeping
        session = session_manager.get_session(session_id)
        session['last_active'] = time.time()
        session['count'] += 1

        # Encode the upload as a data URI for the JSON payload.
        try:
            base64_image = image_to_base64(image_file)
        except OSError as e:
            error_img = create_error_image(f"文件处理失败: {str(e)}")
            yield "❌ 文件处理失败", error_img, ""
            return

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        payload = {
            "enable_base64_output": True,
            "enable_safety_checker": enable_safety,
            "image": base64_image,
            "prompt": prompt,
            "seed": _resolve_seed(seed),
        }

        # Submit the job; the response carries a prediction id to poll.
        response = requests.post(
            "https://api.wavespeed.ai/api/v2/wavespeed-ai/hidream-e1-full",
            headers=headers,
            json=payload,
            timeout=30)
        response.raise_for_status()
        request_id = response.json()["data"]["id"]
        result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
        start_time = time.time()

        # Poll up to 60 times, one second apart. Each poll has its own timeout
        # so a stalled connection cannot block this worker indefinitely.
        for _ in range(60):
            time.sleep(1)
            resp = requests.get(result_url, headers=headers, timeout=30)
            resp.raise_for_status()
            data = resp.json()["data"]
            status = data["status"]
            if status == "completed":
                elapsed = time.time() - start_time
                image_url = data["outputs"][0]
                session["history"].append(image_url)
                yield f"🎉 生成成功! 耗时 {elapsed:.1f}s", image_url, image_url
                return
            if status == "failed":
                # `or` also covers an error field that is present but empty/null.
                raise RuntimeError(data.get("error") or "Unknown error")
            yield f"⏳ 当前状态: {status.capitalize()}...", None, None
        raise TimeoutError("生成超时")
    except Exception as e:
        # Top-level boundary of a UI handler: report any failure to the user.
        error_img = create_error_image(str(e))
        yield f"❌ 生成失败: {str(e)}", error_img, ""
# Background cleanup task
def cleanup_task():
    """Run forever, expiring idle sessions via the session manager once per hour.

    Intended to run on a background daemon thread; it never returns.
    """
    sweep_interval_s = 3600
    while True:
        session_manager.cleanup_sessions()
        time.sleep(sweep_interval_s)
# UI construction
with gr.Blocks(theme=gr.themes.Soft(),
               css="""
.status-box { padding: 10px; border-radius: 5px; margin: 5px; }
.safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
.warning { background: #fff3e0; border: 1px solid #ffcc80; }
.error { background: #ffebee; border: 1px solid #ef9a9a; }
""") as app:
    # A callable initial value is evaluated on each page load, giving every
    # browser session its own id. A plain string would be computed once at
    # import and shared by all visitors, turning the per-session rate limit
    # (keyed by this id in generate_image) into one global limit.
    session_id = gr.State(lambda: str(uuid.uuid4()))
    gr.Markdown("# 🖼️Hidream-E1-Full Live On Wavespeed Ai")
    gr.Markdown("HiDream-E1 is an image editing model built on HiDream-I1.")

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            image_file = gr.Image(label="Upload Image",
                                  type="filepath",
                                  sources=["upload"],
                                  interactive=True,
                                  image_mode="RGB")
            prompt = gr.Textbox(
                label="prompt",
                placeholder="Please enter an English prompt...",
                lines=3)
            seed = gr.Number(label="seed",
                             value=-1,
                             minimum=-1,
                             maximum=999999,
                             step=1)
            random_btn = gr.Button("random🎲seed", variant="secondary")
            # Shown but locked: users cannot turn the safety checker off.
            enable_safety = gr.Checkbox(label="🔒 Enable Safety Checker",
                                        value=True,
                                        interactive=False)
        # Right column: outputs.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Result")
            output_url = gr.Textbox(label="image url",
                                    interactive=True,
                                    visible=False)
            status = gr.Textbox(label="Status", elem_classes=["status-box"])

    submit_btn = gr.Button("开始生成", variant="primary")

    gr.Examples(examples=[
        [
            "Convert the image into Claymation style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
        ],
        [
            "Convert the image into a Ghibli style.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg"
        ],
        [
            "Add sunglasses to the face of the girl.",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png"
        ],
        [
            'Convert the image into an ink sketch style.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ],
        [
            'Add a butterfly to the scene.',
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_result.png"
        ]
    ],
                inputs=[prompt, image_file],
                label="Examples")

    random_btn.click(fn=lambda: random.randint(0, 999999), outputs=seed)
    # generate_image is a generator, so status/image updates stream to the UI.
    submit_btn.click(
        generate_image,
        inputs=[image_file, prompt, seed, session_id, enable_safety],
        outputs=[status, output_image, output_url])
if __name__ == "__main__":
    # Daemon thread so the session sweeper never keeps the process alive on exit.
    sweeper = threading.Thread(target=cleanup_task, daemon=True)
    sweeper.start()

    # Queue holds at most 4 pending requests; bind on all interfaces, port 7860.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )
    app.queue(max_size=4).launch(**launch_options)