mistpe commited on
Commit
1be96fc
·
verified ·
1 Parent(s): d3e1ba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -277
app.py CHANGED
@@ -1,256 +1,3 @@
1
- # import os
2
- # import oss2
3
- # from pymongo import MongoClient
4
- # from pymongo.server_api import ServerApi
5
- # import bcrypt
6
- # from datetime import datetime
7
- # from flask import Flask, request, jsonify
8
- # from flask_cors import CORS
9
- # import json
10
- # import websocket
11
- # import uuid
12
- # import urllib.request
13
- # import urllib.parse
14
- # import requests
15
- # from bson.objectid import ObjectId
16
- # from dotenv import load_dotenv
17
-
18
- # # 加载 .env 文件中的环境变量
19
- # load_dotenv()
20
-
21
- # app = Flask(__name__)
22
- # CORS(app)
23
-
24
- # # ComfyUI 设置
25
- # SERVER_ADDRESS = "paint.aixiao.xyz"
26
- # CLIENT_ID = str(uuid.uuid4())
27
-
28
- # # 从环境变量中获取阿里云 OSS 配置信息
29
- # access_key_id = os.getenv("OSS_ACCESS_KEY_ID")
30
- # access_key_secret = os.getenv("OSS_ACCESS_KEY_SECRET")
31
- # bucket_name = os.getenv("OSS_BUCKET_NAME")
32
- # endpoint = os.getenv("OSS_ENDPOINT")
33
-
34
- # # MongoDB configuration
35
- # # 从环境变量中获取 MongoDB 配置信息
36
- # uri = os.getenv("MONGO_URI")
37
- # client = MongoClient(uri, server_api=ServerApi('1'))
38
-
39
- # # Create Aliyun OSS bucket object
40
- # bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
41
-
42
- # try:
43
- # client.admin.command('ping')
44
- # print("Successfully connected to MongoDB!")
45
- # except Exception as e:
46
- # print(f"Failed to connect to MongoDB: {e}")
47
- # exit(1)
48
-
49
- # db = client['ai_image_generator']
50
- # images_collection = db['images']
51
-
52
- # # ComfyUI workflow (省略具体内容)
53
- # WORKFLOW = {
54
- # "3": {
55
- # "inputs": {
56
- # "seed": 1048756903667323,
57
- # "steps": 20,
58
- # "cfg": 8,
59
- # "sampler_name": "euler",
60
- # "scheduler": "normal",
61
- # "denoise": 1,
62
- # "model": ["4", 0],
63
- # "positive": ["6", 0],
64
- # "negative": ["7", 0],
65
- # "latent_image": ["5", 0]
66
- # },
67
- # "class_type": "KSampler"
68
- # },
69
- # "4": {
70
- # "inputs": {
71
- # "ckpt_name": "sd_xl_base_1.0.safetensors"
72
- # },
73
- # "class_type": "CheckpointLoaderSimple"
74
- # },
75
- # "5": {
76
- # "inputs": {
77
- # "width": 512,
78
- # "height": 512,
79
- # "batch_size": 1
80
- # },
81
- # "class_type": "EmptyLatentImage"
82
- # },
83
- # "6": {
84
- # "inputs": {
85
- # "text": "",
86
- # "clip": ["4", 1]
87
- # },
88
- # "class_type": "CLIPTextEncode"
89
- # },
90
- # "7": {
91
- # "inputs": {
92
- # "text": "text, watermark",
93
- # "clip": ["4", 1]
94
- # },
95
- # "class_type": "CLIPTextEncode"
96
- # },
97
- # "8": {
98
- # "inputs": {
99
- # "samples": ["3", 0],
100
- # "vae": ["4", 2]
101
- # },
102
- # "class_type": "VAEDecode"
103
- # },
104
- # "9": {
105
- # "inputs": {
106
- # "filename_prefix": "ComfyUI",
107
- # "images": ["8", 0]
108
- # },
109
- # "class_type": "SaveImage"
110
- # }
111
- # }
112
-
113
- # def queue_prompt(prompt):
114
- # p = {"prompt": prompt, "client_id": CLIENT_ID}
115
- # data = json.dumps(p).encode('utf-8')
116
- # req = urllib.request.Request(f"http://{SERVER_ADDRESS}/prompt", data=data)
117
- # return json.loads(urllib.request.urlopen(req).read())
118
-
119
- # def get_image(filename, subfolder, folder_type):
120
- # data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
121
- # url_values = urllib.parse.urlencode(data)
122
- # with urllib.request.urlopen(f"http://{SERVER_ADDRESS}/view?{url_values}") as response:
123
- # return response.read()
124
-
125
- # def get_history(prompt_id):
126
- # with urllib.request.urlopen(f"http://{SERVER_ADDRESS}/history/{prompt_id}") as response:
127
- # return json.loads(response.read())
128
-
129
- # def get_images(ws, prompt):
130
- # prompt_id = queue_prompt(prompt)['prompt_id']
131
- # print(f'Prompt ID: {prompt_id}')
132
-
133
- # while True:
134
- # out = ws.recv()
135
- # if isinstance(out, str):
136
- # message = json.loads(out)
137
- # if message['type'] == 'executing':
138
- # data = message['data']
139
- # if data['node'] is None and data['prompt_id'] == prompt_id:
140
- # print('Execution completed')
141
- # break
142
- # else:
143
- # continue # Ignore binary data (previews)
144
-
145
- # history = get_history(prompt_id)[prompt_id]
146
- # output_images = {}
147
-
148
- # for node_id, node_output in history['outputs'].items():
149
- # if 'images' in node_output:
150
- # images_output = []
151
- # for image in node_output['images']:
152
- # image_data = get_image(image['filename'], image['subfolder'], image['type'])
153
- # images_output.append(image_data)
154
- # output_images[node_id] = images_output
155
-
156
- # return output_images
157
-
158
- # def translate_to_english(text):
159
- # url = os.getenv("DEEP_URI")
160
- # payload = json.dumps({
161
- # "text": text,
162
- # "source_lang": "auto",
163
- # "target_lang": "EN"
164
- # })
165
- # headers = {
166
- # 'Content-Type': 'application/json'
167
- # }
168
- # try:
169
- # response = requests.post(url, headers=headers, data=payload)
170
- # response.raise_for_status()
171
- # result = response.json()
172
- # return result.get('data', text)
173
- # except requests.RequestException as e:
174
- # print(f"翻译请求失败: {e}")
175
- # return text
176
-
177
- # @app.route('/generate', methods=['POST'])
178
- # def generate_image():
179
- # prompt = request.json['prompt']
180
-
181
- # english_prompt = translate_to_english(prompt)
182
- # print(f"Original prompt: {prompt}")
183
- # print(f"Translated prompt: {english_prompt}")
184
-
185
- # ws = websocket.create_connection(f"ws://{SERVER_ADDRESS}/ws?clientId={CLIENT_ID}")
186
-
187
- # workflow = WORKFLOW.copy()
188
- # workflow["6"]["inputs"]["text"] = english_prompt
189
-
190
- # images = get_images(ws, workflow)
191
- # ws.close()
192
-
193
- # if images:
194
- # image_data = list(images.values())[0][0]
195
- # timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
196
- # filename = f"generated_image_{timestamp}.png"
197
- # oss_path = f"images/{filename}"
198
-
199
- # # Upload to Aliyun OSS
200
- # bucket.put_object(oss_path, image_data)
201
-
202
- # # Get the public URL
203
- # image_url = f"https://{bucket_name}.{endpoint}/{oss_path}"
204
-
205
- # # Save to MongoDB
206
- # image_doc = {
207
- # "prompt": prompt,
208
- # "english_prompt": english_prompt,
209
- # "url": image_url,
210
- # "filename": filename,
211
- # "created_at": datetime.utcnow(),
212
- # "is_public": False
213
- # }
214
- # result = images_collection.insert_one(image_doc)
215
-
216
- # return jsonify({
217
- # "status": "success",
218
- # "filename": filename,
219
- # "url": image_url,
220
- # "id": str(result.inserted_id)
221
- # })
222
- # else:
223
- # return jsonify({"status": "error", "message": "Failed to generate image"})
224
-
225
- # @app.route('/add-to-public-gallery', methods=['POST'])
226
- # def add_to_public_gallery():
227
- # image_id = request.json['image_id']
228
-
229
- # # Update the image document in MongoDB
230
- # result = images_collection.update_one(
231
- # {"_id": ObjectId(image_id)},
232
- # {"$set": {"is_public": True}}
233
- # )
234
-
235
- # if result.modified_count > 0:
236
- # return jsonify({"status": "success", "message": "Image added to public gallery"})
237
- # else:
238
- # return jsonify({"status": "error", "message": "Failed to add image to public gallery"})
239
-
240
- # @app.route('/api/gallery', methods=['GET'])
241
- # def get_gallery_images():
242
- # # Temporarily return all images, regardless of is_public status
243
- # all_images = list(images_collection.find().sort("created_at", -1).limit(20))
244
-
245
- # # Convert ObjectId to string for JSON serialization
246
- # for image in all_images:
247
- # image['_id'] = str(image['_id'])
248
-
249
- # print(f"Returning {len(all_images)} images") # Add this line for debugging
250
- # return jsonify(all_images)
251
-
252
- # if __name__ == '__main__':
253
- # app.run(host='0.0.0.0', port=7860, debug=True)
254
  import os
255
  import oss2
256
  from pymongo import MongoClient
@@ -260,7 +7,7 @@ from datetime import datetime
260
  from flask import Flask, request, jsonify
261
  from flask_cors import CORS
262
  import json
263
- from websocket import WebSocketApp
264
  import uuid
265
  import urllib.request
266
  import urllib.parse
@@ -366,43 +113,45 @@ WORKFLOW = {
366
  def queue_prompt(prompt):
367
  p = {"prompt": prompt, "client_id": CLIENT_ID}
368
  data = json.dumps(p).encode('utf-8')
369
- req = urllib.request.Request(f"https://{SERVER_ADDRESS}/prompt", data=data)
370
  return json.loads(urllib.request.urlopen(req).read())
371
 
372
  def get_image(filename, subfolder, folder_type):
373
  data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
374
  url_values = urllib.parse.urlencode(data)
375
- with urllib.request.urlopen(f"https://{SERVER_ADDRESS}/view?{url_values}") as response:
376
  return response.read()
377
 
378
  def get_history(prompt_id):
379
- with urllib.request.urlopen(f"https://{SERVER_ADDRESS}/history/{prompt_id}") as response:
380
  return json.loads(response.read())
381
 
382
  def get_images(ws, prompt):
383
  prompt_id = queue_prompt(prompt)['prompt_id']
384
  print(f'Prompt ID: {prompt_id}')
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  output_images = {}
387
-
388
- def on_message(ws, message):
389
- nonlocal output_images
390
- if isinstance(message, str):
391
- data = json.loads(message)
392
- if data['type'] == 'executing':
393
- if data['data']['node'] is None and data['data']['prompt_id'] == prompt_id:
394
- ws.close()
395
- elif data['type'] == 'executed':
396
- node_id = data['data']['node']
397
- if 'images' in data['data']['output']:
398
- images_output = []
399
- for image in data['data']['output']['images']:
400
- image_data = get_image(image['filename'], image['subfolder'], image['type'])
401
- images_output.append(image_data)
402
- output_images[node_id] = images_output
403
 
404
- ws.on_message = on_message
405
- ws.run_forever()
 
 
 
 
 
406
 
407
  return output_images
408
 
@@ -433,12 +182,13 @@ def generate_image():
433
  print(f"Original prompt: {prompt}")
434
  print(f"Translated prompt: {english_prompt}")
435
 
436
- ws = WebSocketApp(f"wss://{SERVER_ADDRESS}/ws?clientId={CLIENT_ID}")
437
 
438
  workflow = WORKFLOW.copy()
439
  workflow["6"]["inputs"]["text"] = english_prompt
440
 
441
  images = get_images(ws, workflow)
 
442
 
443
  if images:
444
  image_data = list(images.values())[0][0]
@@ -500,4 +250,4 @@ def get_gallery_images():
500
  return jsonify(all_images)
501
 
502
  if __name__ == '__main__':
503
- app.run(host='0.0.0.0', port=7860, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import oss2
3
  from pymongo import MongoClient
 
7
  from flask import Flask, request, jsonify
8
  from flask_cors import CORS
9
  import json
10
+ import websocket
11
  import uuid
12
  import urllib.request
13
  import urllib.parse
 
113
  def queue_prompt(prompt):
114
  p = {"prompt": prompt, "client_id": CLIENT_ID}
115
  data = json.dumps(p).encode('utf-8')
116
+ req = urllib.request.Request(f"http://{SERVER_ADDRESS}/prompt", data=data)
117
  return json.loads(urllib.request.urlopen(req).read())
118
 
119
  def get_image(filename, subfolder, folder_type):
120
  data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
121
  url_values = urllib.parse.urlencode(data)
122
+ with urllib.request.urlopen(f"http://{SERVER_ADDRESS}/view?{url_values}") as response:
123
  return response.read()
124
 
125
  def get_history(prompt_id):
126
+ with urllib.request.urlopen(f"http://{SERVER_ADDRESS}/history/{prompt_id}") as response:
127
  return json.loads(response.read())
128
 
129
  def get_images(ws, prompt):
130
  prompt_id = queue_prompt(prompt)['prompt_id']
131
  print(f'Prompt ID: {prompt_id}')
132
 
133
+ while True:
134
+ out = ws.recv()
135
+ if isinstance(out, str):
136
+ message = json.loads(out)
137
+ if message['type'] == 'executing':
138
+ data = message['data']
139
+ if data['node'] is None and data['prompt_id'] == prompt_id:
140
+ print('Execution completed')
141
+ break
142
+ else:
143
+ continue # Ignore binary data (previews)
144
+
145
+ history = get_history(prompt_id)[prompt_id]
146
  output_images = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+ for node_id, node_output in history['outputs'].items():
149
+ if 'images' in node_output:
150
+ images_output = []
151
+ for image in node_output['images']:
152
+ image_data = get_image(image['filename'], image['subfolder'], image['type'])
153
+ images_output.append(image_data)
154
+ output_images[node_id] = images_output
155
 
156
  return output_images
157
 
 
182
  print(f"Original prompt: {prompt}")
183
  print(f"Translated prompt: {english_prompt}")
184
 
185
+ ws = websocket.create_connection(f"ws://{SERVER_ADDRESS}/ws?clientId={CLIENT_ID}")
186
 
187
  workflow = WORKFLOW.copy()
188
  workflow["6"]["inputs"]["text"] = english_prompt
189
 
190
  images = get_images(ws, workflow)
191
+ ws.close()
192
 
193
  if images:
194
  image_data = list(images.values())[0][0]
 
250
  return jsonify(all_images)
251
 
252
  if __name__ == '__main__':
253
+ app.run(host='0.0.0.0', port=7860, debug=True)