| import base64 |
| import io |
| import json |
| import logging |
| import pathlib |
| import time |
| import tempfile |
| import os |
|
|
| from datetime import datetime |
|
|
| import requests |
| import tiktoken |
| from PIL import Image |
|
|
| from modules.config import retrieve_proxy |
| from modules.models.models import XMChat |
|
|
# Midjourney settings, read once at import time; each is None when the variable is unset.
mj_proxy_api_base = os.environ.get("MIDJOURNEY_PROXY_API_BASE")  # base URL of the proxy API
mj_discord_proxy_url = os.environ.get("MIDJOURNEY_DISCORD_PROXY_URL")  # substitute for https://cdn.discordapp.com/
mj_temp_folder = os.environ.get("MIDJOURNEY_TEMP_FOLDER")  # truthy value enables local image splitting
|
|
|
|
class Midjourney_Client(XMChat):
    """Chat model that drives Midjourney through a midjourney-proxy style HTTP API.

    A user message starting with ``/mj`` is parsed into an action (see
    ``get_help``), submitted to ``MIDJOURNEY_PROXY_API_BASE`` and then polled
    via ``fetch_status`` until the task finishes or times out. Any other
    message is answered with the help text.
    """

    # Actions accepted after "/mj"; shared by the blocking and streaming paths.
    _SUPPORTED_ACTIONS = ("IMAGINE", "DESCRIBE", "UPSCALE", "VARIATION", "BLEND", "REROLL")
    # Actions that act on one image of an earlier task: ACTION::<index>::<task id>.
    _CHANGE_ACTIONS = ("UPSCALE", "VARIATION", "REROLL")

    class FetchDataPack:
        """
        Mutable polling state for one submitted Midjourney task.
        """

        action: str  # action the task was submitted with, e.g. "IMAGINE"
        prefix_content: str  # markdown header prepended to the final answer
        task_id: str  # task id returned by the submit endpoint
        start_time: float  # time.time() when the pack was created
        timeout: int  # seconds after start_time before polling gives up
        finished: bool  # set by fetch_status once polling should stop
        prompt: str  # user prompt, echoed in the timeout message

        def __init__(self, action, prefix_content, task_id, timeout=900, prompt=""):
            self.action = action
            self.prefix_content = prefix_content
            self.task_id = task_id
            self.start_time = time.time()
            self.timeout = timeout
            self.finished = False
            # Always set so fetch_status can build its timeout message even when
            # the caller never assigns a prompt.
            self.prompt = prompt

    def __init__(self, model_name, api_key, user_name=""):
        """
        Args:
            model_name: name reported for this model.
            api_key: secret sent to the proxy in the ``mj-api-secret`` header.
            user_name: optional; when local image saving is enabled, used as a
                sub-directory of ./tmp.
        """
        super().__init__(api_key, user_name)
        self.model_name = model_name
        self.history = []
        self.api_key = api_key
        self.headers = {
            "Content-Type": "application/json",
            "mj-api-secret": f"{api_key}"
        }
        self.proxy_url = mj_proxy_api_base
        self.command_splitter = "::"

        # MIDJOURNEY_TEMP_FOLDER only acts as an on/off switch here; images are
        # always written under ./tmp[/<user_name>]/<random dir>.
        # NOTE(review): the variable's value is ignored — confirm that is intended.
        if mj_temp_folder:
            temp = "./tmp"
            if user_name:
                temp = os.path.join(temp, user_name)
            os.makedirs(temp, exist_ok=True)
            self.temp_path = tempfile.mkdtemp(dir=temp)
            logging.info("mj temp folder: " + self.temp_path)
        else:
            self.temp_path = None

    def use_mj_self_proxy_url(self, img_url):
        """
        Replace the Discord CDN prefix of ``img_url`` with MIDJOURNEY_DISCORD_PROXY_URL.

        When that variable is unset or empty the URL is returned unchanged.
        """
        return img_url.replace(
            "https://cdn.discordapp.com/",
            mj_discord_proxy_url or "https://cdn.discordapp.com/"
        )

    def split_image(self, image_url, timeout=60):
        """
        Download an image and cut it into four equal quadrants.

        Args:
            image_url: URL of the image to download (fetched through retrieve_proxy()).
            timeout: seconds to wait for the download.

        Returns:
            Four PIL images: top-left, top-right, bottom-left, bottom-right.

        Raises:
            requests.HTTPError: the download returned an error status.
        """
        with retrieve_proxy():
            response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content))
        width, height = img.size

        half_width = width // 2
        half_height = height // 2

        coordinates = [(0, 0, half_width, half_height),
                       (half_width, 0, width, half_height),
                       (0, half_height, half_width, height),
                       (half_width, half_height, width, height)]

        return [img.crop(c) for c in coordinates]

    def auth_mj(self):
        """
        Authentication hook for the proxy API; currently always succeeds.

        request_mj treats a returned dict with an 'error' key as an auth failure.
        """
        return {'status': 'ok'}

    def request_mj(self, path: str, action: str, data: str, retries=3, timeout=60):
        """
        Send one HTTP request to the midjourney proxy API.

        Args:
            path: endpoint path relative to the proxy base URL, e.g. "submit/imagine".
            action: HTTP method, e.g. "GET" or "POST".
            data: request body (a JSON string, or '' for none).
            retries: number of attempts made while the request itself raises
                (connection errors, timeouts).
            timeout: per-attempt timeout in seconds.

        Returns:
            The requests.Response; its status code is always 200.

        Raises:
            Exception: the proxy URL is not configured, auth fails, every attempt
                raised, or the response status is not 200.
        """
        mj_proxy_url = self.proxy_url
        if mj_proxy_url is None or not mj_proxy_url.startswith(("http://", "https://")):
            raise Exception('please set MIDJOURNEY_PROXY_API_BASE in ENV or in config.json')

        auth_ = self.auth_mj()
        if auth_.get('error'):
            raise Exception('auth not set')

        fetch_url = f"{mj_proxy_url}/{path}"

        res = None
        last_error = None
        for _ in range(retries):
            try:
                with retrieve_proxy():
                    res = requests.request(method=action, url=fetch_url, headers=self.headers,
                                           data=data, timeout=timeout)
                break
            except requests.RequestException as e:
                last_error = e
                logging.warning("midjourney request to %s failed: %s", fetch_url, e)

        if res is None:
            # Every attempt raised; fail with the real cause instead of an UnboundLocalError.
            raise Exception(f'request to {fetch_url} failed after {retries} attempts: {last_error}') \
                from last_error

        if res.status_code != 200:
            raise Exception(f'{res.status_code} - {res.content}')

        return res

    def fetch_status(self, fetch_data: FetchDataPack):
        """
        Poll the task once (after a 3 second pause) and render its state as markdown.

        Sets ``fetch_data.finished`` when the task succeeded, failed or timed out;
        callers loop on this method until that flag is set.

        Returns:
            A progress line while the task runs, or the final answer (image
            links plus follow-up commands, the DESCRIBE prompt text, or a
            failure/timeout message).

        Raises:
            Exception: propagated from request_mj, or the status lookup failed.
        """
        if fetch_data.start_time + fetch_data.timeout < time.time():
            fetch_data.finished = True
            return "任务超时,请检查 dc 输出。描述:" + fetch_data.prompt

        time.sleep(3)
        status_res = self.request_mj(f"task/{fetch_data.task_id}/fetch", "GET", '')
        status_res_json = status_res.json()
        if not (200 <= status_res.status_code < 300):
            # Parenthesised: "+" binds tighter than "or", so a missing 'error'
            # used to raise TypeError (str + None) instead of falling back.
            raise Exception("任务状态获取失败:" + (status_res_json.get('error')
                                              or status_res_json.get('description')
                                              or '未知错误'))

        status = status_res_json['status']
        fetch_data.finished = False
        if status == "SUCCESS":
            fetch_data.finished = True
            return self._format_finished(fetch_data, status_res_json)
        if status in ("FAILED", "FAILURE"):
            # Both spellings are terminal failures; neither carries a result image.
            fetch_data.finished = True
            return "任务处理失败,原因:" + (status_res_json.get('failReason') or '未知原因')

        if status == "NOT_START":
            content = f'任务未开始,已等待 {time.time() - fetch_data.start_time:.2f} 秒'
        elif status == "IN_PROGRESS":
            content = '任务正在运行'
            if status_res_json.get('progress'):
                content += f",进度:{status_res_json['progress']}"
        elif status == "SUBMITTED":
            content = '任务已提交处理'
        else:
            content = status

        content = f"**任务状态:** [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - {content}"
        content += f"\n\n花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        preview_url = status_res_json.get('imageUrl')
        if status == 'IN_PROGRESS' and preview_url:
            # Same CDN rewrite as the final image, so previews load behind the proxy too.
            preview_url = self.use_mj_self_proxy_url(preview_url)
            return f"{content}\n[![{fetch_data.task_id}]({preview_url})]({preview_url})"
        return content

    def _format_finished(self, fetch_data, status_res_json):
        """Build the final markdown answer for a task that reported SUCCESS."""
        if fetch_data.action == "DESCRIBE":
            # DESCRIBE yields prompt text, not an image; don't touch imageUrl.
            return f"\n{status_res_json['prompt']}"

        img_url = self.use_mj_self_proxy_url(status_res_json.get('imageUrl') or '')
        time_cost_str = f"\n\n{fetch_data.action} 花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        upscale_str = ""
        variation_str = ""
        if fetch_data.action in ["IMAGINE", "UPSCALE", "VARIATION"]:
            upscale = [f'/mj UPSCALE{self.command_splitter}{i+1}{self.command_splitter}{fetch_data.task_id}'
                       for i in range(4)]
            upscale_str = '\n放大图片:\n\n' + '\n\n'.join(upscale)
            variation = [f'/mj VARIATION{self.command_splitter}{i+1}{self.command_splitter}{fetch_data.task_id}'
                         for i in range(4)]
            variation_str = '\n图片变体:\n\n' + '\n\n'.join(variation)

        if self.temp_path and fetch_data.action in ["IMAGINE", "VARIATION"]:
            try:
                images = self.split_image(img_url)
                for i, image in enumerate(images):
                    image.save(pathlib.Path(self.temp_path) / f"{fetch_data.task_id}_{i}.png")
                # Assumes the UI serves local files through Gradio's "/file=" route — confirm.
                img_str = '\n'.join(
                    f"![{fetch_data.task_id}](/file={self.temp_path}/{fetch_data.task_id}_{i}.png)"
                    for i in range(len(images)))
                return fetch_data.prefix_content + f"{time_cost_str}\n\n{img_str}{upscale_str}{variation_str}"
            except Exception as e:
                # Best effort: fall through to linking the remote grid image.
                logging.error("failed to split midjourney image: %s", e)

        return fetch_data.prefix_content + \
            f"{time_cost_str}[![{fetch_data.task_id}]({img_url})]({img_url}){upscale_str}{variation_str}"

    def handle_file_upload(self, files, chatbot, language):
        """
        Read uploaded files as the input image for the next command.

        Each named file is passed to the inherited ``try_read_image``, which
        presumably sets ``image_path``/``image_bytes`` (not visible here). When an
        image path is set it is echoed into the chat history.

        Returns:
            (None, chatbot, None), where chatbot may have gained an image entry.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            if self.image_bytes is not None:
                logging.info("使用图片作为输入")
        return None, chatbot, None

    def reset(self):
        """Forget the uploaded image; returns an empty history and a status message."""
        self.image_bytes = None
        self.image_path = None
        return [], "已重置"

    def _parse_command(self, content):
        """
        Split a "/mj ..." message into its action and arguments.

        Returns:
            (action, prompt, action_index, action_use_task_id). ``prompt`` is the
            text after "/mj", with an explicit "IMAGINE::" prefix removed so it is
            not sent to Midjourney. Index and task id are only set for
            UPSCALE/VARIATION/REROLL ("ACTION::<index>::<task id>").

        Raises:
            ValueError: unknown action, or malformed index/task id.
        """
        prompt = content[3:].strip()
        action = "IMAGINE"
        first_split_index = prompt.find(self.command_splitter)
        if first_split_index > 0:
            action = prompt[:first_split_index]
        if action not in self._SUPPORTED_ACTIONS:
            raise ValueError("任务提交失败:未知的任务类型")

        args = prompt[first_split_index + len(self.command_splitter):] if first_split_index > 0 else prompt
        if action == "IMAGINE" and first_split_index > 0:
            prompt = args.strip()

        action_index = None
        action_use_task_id = None
        if action in self._CHANGE_ACTIONS:
            index_str, _, task_id = args.partition(self.command_splitter)
            task_id = task_id.strip()
            try:
                action_index = int(index_str)
            except ValueError:
                action_index = None
            if action_index is None or not task_id:
                raise ValueError(
                    f"任务提交失败:命令格式错误,示例:/mj {action}{self.command_splitter}1"
                    f"{self.command_splitter}<任务ID>")
            action_use_task_id = task_id
        return action, prompt, action_index, action_use_task_id

    def _submit_task(self, action, prompt, action_index, action_use_task_id):
        """
        POST a parsed command to the matching proxy submit endpoint.

        ``self.image_bytes`` is assumed to be a base64 string (it is concatenated
        onto a data-URL prefix) — confirm against XMChat.try_read_image.

        Returns:
            The requests.Response from request_mj.

        Raises:
            Exception: DESCRIBE/BLEND without an uploaded image, or any request_mj error.
        """
        if action in ("DESCRIBE", "BLEND") and self.image_bytes is None:
            raise Exception("请先上传图片")
        if action == "IMAGINE":
            data = {
                "prompt": prompt
            }
            if self.image_bytes is not None:
                data["base64"] = 'data:image/png;base64,' + self.image_bytes
            return self.request_mj("submit/imagine", "POST", json.dumps(data))
        if action == "DESCRIBE":
            return self.request_mj("submit/describe", "POST", json.dumps(
                {"base64": 'data:image/png;base64,' + self.image_bytes}))
        if action == "BLEND":
            # NOTE(review): only one uploaded image is kept, so it is sent twice.
            return self.request_mj("submit/blend", "POST", json.dumps(
                {"base64Array": [self.image_bytes, self.image_bytes]}))
        # UPSCALE / VARIATION / REROLL (action was validated by _parse_command).
        return self.request_mj(
            "submit/change", "POST",
            json.dumps({"action": action, "index": action_index, "taskId": action_use_task_id}))

    @staticmethod
    def _submit_failure(res, res_json):
        """Return the failure message for a rejected submission, or None if accepted.

        Result codes 1 and 22 are treated as accepted.
        """
        if 200 <= res.status_code < 300 and res_json.get('code') in (1, 22):
            return None
        return "任务提交失败:" + (res_json.get('error') or res_json.get('description') or '未知错误')

    @staticmethod
    def _error_message(e):
        """Format an exception raised during submission/polling for the user."""
        return "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')

    def get_answer_at_once(self):
        """
        Run the last user message as a Midjourney command and block until it ends.

        Returns:
            (answer, token_count): the final markdown answer (or the help text for
            messages not starting with /mj) and an integer count for the message —
            its character length for non-/mj messages, its cl100k_base token
            count otherwise.

        Raises:
            ValueError: the /mj command names an unknown action or is malformed.
        """
        content = self.history[-1]['content']
        answer = self.get_help()

        if not content.lower().startswith("/mj"):
            return answer, len(content)

        action, prompt, action_index, action_use_task_id = self._parse_command(content)

        try:
            res = self._submit_task(action, prompt, action_index, action_use_task_id)
            res_json = res.json()
            failure = self._submit_failure(res, res_json)
            if failure is not None:
                answer = failure
            else:
                task_id = res_json['result']
                fetch_data = Midjourney_Client.FetchDataPack(
                    action=action,
                    prefix_content=f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n",
                    task_id=task_id,
                    prompt=prompt,
                )
                while not fetch_data.finished:
                    answer = self.fetch_status(fetch_data)
        except Exception as e:
            logging.error("submit failed: %s", e)
            answer = self._error_message(e)

        # Return a count, not the token list, so callers can do arithmetic on it.
        return answer, len(tiktoken.get_encoding("cl100k_base").encode(content))

    def get_answer_stream_iter(self):
        """
        Run the last user message as a Midjourney command, yielding progress.

        Yields the help text for messages not starting with /mj. Otherwise yields
        a submission line, then one status string per poll until the task ends.
        Errors, including malformed commands, are yielded as messages rather than raised.
        """
        content = self.history[-1]['content']
        answer = self.get_help()

        if not content.lower().startswith("/mj"):
            yield answer
            return

        try:
            action, prompt, action_index, action_use_task_id = self._parse_command(content)
        except ValueError as e:
            yield str(e)
            return

        try:
            res = self._submit_task(action, prompt, action_index, action_use_task_id)
            res_json = res.json()
            failure = self._submit_failure(res, res_json)
            if failure is not None:
                yield failure
                return
            task_id = res_json['result']
            yield f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - 任务提交成功:" + \
                (res_json.get('description') or '请稍等片刻')

            fetch_data = Midjourney_Client.FetchDataPack(
                action=action,
                prefix_content=f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n",
                task_id=task_id,
                prompt=prompt,
            )
            while not fetch_data.finished:
                yield self.fetch_status(fetch_data)
        except Exception as e:
            logging.error("submit failed: %s", e)
            yield self._error_message(e)

    def get_help(self):
        """Return the markdown help text listing /mj commands and common parameters."""
        return """```
【绘图帮助】
所有命令都需要以 /mj 开头,如:/mj a dog
IMAGINE - 绘图,可以省略该命令,后面跟上绘图内容
/mj a dog
/mj IMAGINE::a cat
DESCRIBE - 描述图片,需要在右下角上传需要描述的图片内容
/mj DESCRIBE::
UPSCALE - 确认后放大图片,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj UPSCALE::1::123456789
请使用SD进行UPSCALE
VARIATION - 图片变体,第一个数值为需要生成变体的图片(1~4),第二参数为任务ID
/mj VARIATION::1::123456789

【绘图参数】
所有命令默认会带上参数--v 5.2
其他参数参照 https://docs.midjourney.com/docs/parameter-list
长宽比 --aspect/--ar
--ar 1:2
--ar 16:9
负面tag --no
--no plants
--no hands
随机种子 --seed
--seed 1
生成动漫风格(NijiJourney) --niji
--niji
```
"""
|
|