Spaces:
Runtime error
Runtime error
| import os | |
| import oss2 | |
| from pymongo import MongoClient | |
| from pymongo.server_api import ServerApi | |
| import bcrypt | |
| from datetime import datetime | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import json | |
| import websocket | |
| import uuid | |
| import urllib.request | |
| import urllib.parse | |
| import requests | |
| from bson.objectid import ObjectId | |
| from dotenv import load_dotenv | |
# Load environment variables from a local .env file (if present) into os.environ.
load_dotenv()

app = Flask(__name__)
CORS(app)

# ComfyUI settings.
# The server host can be overridden with COMFYUI_SERVER_ADDRESS; it defaults to
# the previously hard-coded host, so existing deployments behave the same.
SERVER_ADDRESS = os.getenv("COMFYUI_SERVER_ADDRESS", "paint.aixiao.xyz")
# One id per process. It is sent both with queued prompts (queue_prompt) and on
# the websocket URL (generate_image), presumably so the server can address
# progress messages to this client.
CLIENT_ID = str(uuid.uuid4())

# Aliyun OSS configuration, read from environment variables.
access_key_id = os.getenv("OSS_ACCESS_KEY_ID")
access_key_secret = os.getenv("OSS_ACCESS_KEY_SECRET")
bucket_name = os.getenv("OSS_BUCKET_NAME")
endpoint = os.getenv("OSS_ENDPOINT")

# MongoDB configuration, read from environment variables.
uri = os.getenv("MONGO_URI")

# Fail fast with a readable message when a required setting is missing, rather
# than letting the OSS/MongoDB client setup or the ping below fail later with a
# less obvious error.
_REQUIRED_ENV = {
    "OSS_ACCESS_KEY_ID": access_key_id,
    "OSS_ACCESS_KEY_SECRET": access_key_secret,
    "OSS_BUCKET_NAME": bucket_name,
    "OSS_ENDPOINT": endpoint,
    "MONGO_URI": uri,
}
_missing_env = [name for name, value in _REQUIRED_ENV.items() if not value]
if _missing_env:
    raise SystemExit(
        "Missing required environment variables: " + ", ".join(_missing_env)
    )

client = MongoClient(uri, server_api=ServerApi('1'))

# Aliyun OSS bucket handle used to upload generated images.
bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)

# Verify MongoDB connectivity at startup; abort the process if it is unreachable.
try:
    client.admin.command('ping')
    print("Successfully connected to MongoDB!")
except Exception as e:
    print(f"Failed to connect to MongoDB: {e}")
    # SystemExit does not depend on the site-provided exit() builtin.
    raise SystemExit(1) from e

db = client['ai_image_generator']
images_collection = db['images']
# ComfyUI workflow template (text-to-image with the SDXL base checkpoint).
# Top-level keys are node ids. A two-element list such as ["4", 1] references
# another node of this dict by id, presumably together with that node's output
# slot. Node "6" carries the positive prompt text; generate_image fills it in
# per request. Node "7" carries a fixed negative prompt.
WORKFLOW = {
    # Sampler: wires together model, positive/negative conditioning and latent.
    "3": {
        "inputs": dict(
            seed=1048756903667323,
            steps=20,
            cfg=8,
            sampler_name="euler",
            scheduler="normal",
            denoise=1,
            model=["4", 0],
            positive=["6", 0],
            negative=["7", 0],
            latent_image=["5", 0],
        ),
        "class_type": "KSampler",
    },
    # Checkpoint loader: its outputs feed the sampler, both text encoders and the VAE decode.
    "4": {
        "inputs": dict(ckpt_name="sd_xl_base_1.0.safetensors"),
        "class_type": "CheckpointLoaderSimple",
    },
    # Empty latent canvas: 512x512, one image per run.
    "5": {
        "inputs": dict(width=512, height=512, batch_size=1),
        "class_type": "EmptyLatentImage",
    },
    # Positive prompt; text is left empty here and set per request.
    "6": {
        "inputs": dict(text="", clip=["4", 1]),
        "class_type": "CLIPTextEncode",
    },
    # Negative prompt.
    "7": {
        "inputs": dict(text="text, watermark", clip=["4", 1]),
        "class_type": "CLIPTextEncode",
    },
    # Decode the sampled latent into an image.
    "8": {
        "inputs": dict(samples=["3", 0], vae=["4", 2]),
        "class_type": "VAEDecode",
    },
    # Save node whose "images" output is what get_images collects.
    "9": {
        "inputs": dict(filename_prefix="ComfyUI", images=["8", 0]),
        "class_type": "SaveImage",
    },
}
def queue_prompt(prompt):
    """Submit a workflow to the ComfyUI server's ``/prompt`` endpoint.

    Args:
        prompt: Workflow dict (same shape as ``WORKFLOW``).

    Returns:
        The decoded JSON response; callers read its ``"prompt_id"`` key.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError if the request fails.
    """
    payload = json.dumps({"prompt": prompt, "client_id": CLIENT_ID}).encode('utf-8')
    req = urllib.request.Request(
        f"http://{SERVER_ADDRESS}/prompt",
        data=payload,
        # Without this, urllib labels the body as form-encoded.
        headers={"Content-Type": "application/json"},
    )
    # Use the response as a context manager so the connection is always closed;
    # the original left it open.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download the raw bytes of one output image from ComfyUI's ``/view`` endpoint.

    Args:
        filename: Image file name, as listed in the prompt's history outputs.
        subfolder: Subfolder reported alongside the file name.
        folder_type: Value of the history entry's ``type`` field.

    Returns:
        The HTTP response body as bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{SERVER_ADDRESS}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        image_bytes = resp.read()
    return image_bytes
def get_history(prompt_id):
    """Fetch the ComfyUI execution history for *prompt_id*.

    Returns:
        The decoded JSON object from ``/history/<prompt_id>``. Callers index it
        by ``prompt_id`` to get that prompt's entry.
    """
    history_url = f"http://{SERVER_ADDRESS}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        body = resp.read()
    return json.loads(body)
def get_images(ws, prompt):
    """Queue *prompt* on ComfyUI, wait for it to finish, then download its images.

    Args:
        ws: An open websocket connection to the ComfyUI server; it should be
            opened with the same CLIENT_ID that queue_prompt sends.
        prompt: Workflow dict to execute.

    Returns:
        dict mapping output node id -> list of image bytes, for every node in
        the prompt's history whose output contains an ``images`` entry.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']
    print(f'Prompt ID: {prompt_id}')

    # Consume websocket messages until an "executing" event with node None is
    # seen for our prompt id; that is the completion signal. Non-text frames
    # (binary previews) are skipped.
    finished = False
    while not finished:
        frame = ws.recv()
        if not isinstance(frame, str):
            continue
        msg = json.loads(frame)
        if msg['type'] != 'executing':
            continue
        payload = msg['data']
        finished = payload['node'] is None and payload['prompt_id'] == prompt_id
    print('Execution completed')

    entry = get_history(prompt_id)[prompt_id]
    return {
        node_id: [
            get_image(img['filename'], img['subfolder'], img['type'])
            for img in node_output['images']
        ]
        for node_id, node_output in entry['outputs'].items()
        if 'images' in node_output
    }
def translate_to_english(text, *, timeout=10):
    """Translate *text* to English using the service at the ``DEEP_URI`` URL.

    Translation is best-effort. The original *text* is returned unchanged
    whenever translation cannot be performed:

    - ``DEEP_URI`` is unset or *text* is empty;
    - the network or HTTP request fails;
    - the response is not JSON, or not the expected object;
    - the translation comes back empty.

    Image generation can then proceed with the untranslated prompt.

    Args:
        text: Text to translate; requested with ``source_lang`` ``"auto"``.
        timeout: Seconds to wait for the service (keyword-only). Without a
            timeout, ``requests`` may block indefinitely.

    Returns:
        The translated string taken from the response's ``data`` field, or
        *text* on any failure.
    """
    url = os.getenv("DEEP_URI")
    if not url or not text:
        # No service configured, or nothing to translate: skip the request.
        return text
    payload = json.dumps({
        "text": text,
        "source_lang": "auto",
        "target_lang": "EN"
    })
    headers = {
        'Content-Type': 'application/json'
    }
    try:
        response = requests.post(url, headers=headers, data=payload, timeout=timeout)
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON bodies on requests versions whose
        # JSONDecodeError does not subclass RequestException.
        print(f"翻译请求失败: {e}")
        return text
    translated = result.get('data') if isinstance(result, dict) else None
    # Accept only a non-empty string; anything else would corrupt the prompt.
    if isinstance(translated, str) and translated.strip():
        return translated
    return text
def generate_image():
    """Generate an image from the request's ``prompt``, store it, and describe it.

    Flow:
    1. Translate the prompt to English.
    2. Run the ComfyUI workflow.
    3. Upload the first produced image to Aliyun OSS.
    4. Record it in MongoDB with ``is_public: False``.

    Request JSON:
        prompt (str): the user's prompt; it is translated to English before use.

    Returns:
        On success, JSON ``{"status": "success", "filename", "url", "id"}``.
        On failure, ``{"status": "error", "message": ...}``:
        - with HTTP 400 when ``prompt`` is missing or not a non-empty string;
        - with the default status when no image was produced, unchanged for
          existing clients.
    """
    # NOTE(review): no @app.route decorator is visible for this view; confirm
    # it is registered with the Flask app.
    import copy
    from datetime import timezone

    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"status": "error", "message": "A non-empty 'prompt' string is required"}), 400

    english_prompt = translate_to_english(prompt)
    print(f"Original prompt: {prompt}")
    print(f"Translated prompt: {english_prompt}")

    # Deep copy: WORKFLOW is nested, so a shallow .copy() would write this
    # request's prompt into the shared module-level template.
    workflow = copy.deepcopy(WORKFLOW)
    workflow["6"]["inputs"]["text"] = english_prompt

    ws = websocket.create_connection(f"ws://{SERVER_ADDRESS}/ws?clientId={CLIENT_ID}")
    try:
        images = get_images(ws, workflow)
    finally:
        # Release the websocket even when queueing or waiting fails.
        ws.close()

    # First image of the first output node that actually produced any images.
    image_data = next((node_images[0] for node_images in images.values() if node_images), None)
    if image_data is None:
        return jsonify({"status": "error", "message": "Failed to generate image"})

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # Random suffix so two generations within the same second do not
    # overwrite each other's object in OSS.
    filename = f"generated_image_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    oss_path = f"images/{filename}"

    # Upload to Aliyun OSS.
    bucket.put_object(oss_path, image_data)

    # Build the public URL as https://<bucket>.<endpoint host>/<path>. Strip any
    # scheme from the configured endpoint so it is not duplicated in the URL.
    endpoint_host = endpoint.split("://", 1)[-1].rstrip("/")
    image_url = f"https://{bucket_name}.{endpoint_host}/{oss_path}"

    # Save metadata to MongoDB. Images start private; see add_to_public_gallery.
    image_doc = {
        "prompt": prompt,
        "english_prompt": english_prompt,
        "url": image_url,
        "filename": filename,
        "created_at": datetime.now(timezone.utc),
        "is_public": False
    }
    result = images_collection.insert_one(image_doc)
    return jsonify({
        "status": "success",
        "filename": filename,
        "url": image_url,
        "id": str(result.inserted_id)
    })
def add_to_public_gallery():
    """Mark an existing image document as public.

    Request JSON:
        image_id (str): the hex ObjectId string that generate_image returns as ``id``.

    Returns:
        JSON ``{"status": "success", ...}`` when the image exists, whether it
        was just made public or already was. Otherwise
        ``{"status": "error", ...}``, with HTTP 400 for a missing or malformed id.
    """
    # NOTE(review): no @app.route decorator is visible for this view; confirm
    # it is registered with the Flask app.
    data = request.get_json(silent=True)
    image_id = data.get('image_id') if isinstance(data, dict) else None
    # ObjectId() raises on malformed ids. Validate first so clients get a 400
    # instead of an unhandled server error.
    if not isinstance(image_id, str) or not ObjectId.is_valid(image_id):
        return jsonify({"status": "error", "message": "A valid 'image_id' is required"}), 400

    # Update the image document in MongoDB.
    result = images_collection.update_one(
        {"_id": ObjectId(image_id)},
        {"$set": {"is_public": True}}
    )
    # Use matched_count, not modified_count: an image that is already public
    # matches without being modified and must still count as success.
    if result.matched_count > 0:
        return jsonify({"status": "success", "message": "Image added to public gallery"})
    else:
        return jsonify({"status": "error", "message": "Failed to add image to public gallery"})
def get_gallery_images():
    """Return up to 20 image documents, newest ``created_at`` first, as a JSON list.

    NOTE(review): this temporarily returns every image regardless of its
    ``is_public`` flag, so private images are exposed. Filter the query on
    ``{"is_public": True}`` before relying on privacy.
    """
    newest = []
    for doc in images_collection.find().sort("created_at", -1).limit(20):
        # ObjectId is not JSON-serializable; expose its string form instead.
        doc['_id'] = str(doc['_id'])
        newest.append(doc)
    print(f"Returning {len(newest)} images")  # debugging aid
    return jsonify(newest)
if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who reaches it execute code. Do not
    # enable it by default on a server bound to all interfaces; opt in with
    # FLASK_DEBUG=1 (or true/yes) for local development. PORT overrides the
    # listening port, which defaults to 7860 as before.
    _debug = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host='0.0.0.0', port=int(os.getenv("PORT", "7860")), debug=_debug)