diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..edb45cf22255b187f0d56ca3f2664b06b4fe9278 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM maven:3.8.5-openjdk-17 + +ARG user=spring +ARG group=spring + +ENV SPRING_HOME=/home/spring + +RUN groupadd -g 1000 ${group} \ + && useradd -d "$SPRING_HOME" -u 1000 -g 1000 -m -s /bin/bash ${user} \ + && mkdir -p $SPRING_HOME/config \ + && mkdir -p $SPRING_HOME/logs \ + && chown -R ${user}:${group} $SPRING_HOME/config $SPRING_HOME/logs + +# Railway 不支持使用 VOLUME, 本地需要构建时,取消下一行的注释 +# VOLUME ["$SPRING_HOME/config", "$SPRING_HOME/logs"] + +USER ${user} +WORKDIR $SPRING_HOME + +COPY . . + +RUN mvn clean package \ + && mv target/midjourney-proxy-*.jar ./app.jar \ + && rm -rf target + +EXPOSE 8080 9876 + +ENV JAVA_OPTS -XX:MaxRAMPercentage=85 -Djava.awt.headless=true -XX:+HeapDumpOnOutOfMemoryError \ + -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -Xlog:gc:file=/home/spring/logs/gc.log \ + -Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.port=9876 -Dcom.sun.management.jmxremote.ssl=false \ + -Dcom.sun.management.jmxremote.authenticate=false -Dlogging.file.path=/home/spring/logs \ + -Dserver.port=8080 -Duser.timezone=Asia/Shanghai + +ENTRYPOINT ["bash","-c","java $JAVA_OPTS -jar app.jar"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README3333.md b/README3333.md new file mode 100644 index 0000000000000000000000000000000000000000..a36f76c58afa091bf8df39d87064e7105df2107f --- /dev/null +++ b/README3333.md @@ -0,0 +1,79 @@ +# midjourney-proxy + +代理 MidJourney 的discord频道,实现api形式调用AI绘图 + +[![GitHub release](https://img.shields.io/static/v1?label=release&message=v2.5&color=blue)](https://www.github.com/novicezk/midjourney-proxy) +[![License](https://img.shields.io/badge/license-Apache%202-4EB1BA.svg)](https://www.apache.org/licenses/LICENSE-2.0.html) + +## 主要功能 +- [x] 支持 Imagine 指令和相关动作 +- [x] Imagine 时支持添加图片base64,作为垫图 +- [x] 支持 Blend(图片混合)、Describe(图生文) 指令 +- [x] 支持任务实时进度 +- [x] 支持中英文翻译,需配置百度翻译或gpt +- [x] prompt 敏感词判断,支持覆盖调整 +- [x] user-token 连接 wss,可以获取错误信息和完整功能 +- [x] 支持 discord域名(server、cdn、wss)反代,配置 mj.ng-discord +- [x] 支持多账号配置,每个账号可设置对应的任务队列 + +**🚀 更多功能请查看 [midjourney-proxy-plus](https://github.com/litter-coder/midjourney-proxy-plus)** +> - [x] 支持开源版的所有功能 +> - [x] 支持 Shorten(prompt分析) 指令 +> - [x] 支持焦点移动: Pan ⬅️ ➡️ ⬆️ ⬇️ +> - [x] 支持图片变焦: Zoom 🔍 +> - [x] 支持局部重绘: Vary (Region) 🖌 +> - [x] 支持几乎所有的关联按钮动作和🎛️ Remix模式 +> - [x] 支持获取图片的seed值 +> - [x] 中英文翻译额外支持deepl +> - [x] 账号池持久化,动态维护 +> - [x] 支持获取账号/info、/settings信息 +> - [x] 内嵌管理后台页面 + +## 使用前提 +1. 注册并订阅 MidJourney,创建自己的频道,参考 https://docs.midjourney.com/docs/quick-start +2. 获取用户Token、服务器ID、频道ID:[获取方式](./docs/discord-params.md) + +## 快速启动 +1. `Railway`: 基于Railway平台,不需要自己的服务器: [部署方式](./docs/railway-start.md);若Railway不能使用,可使用Zeabur启动 +2. `Zeabur`: 基于Zeabur平台,不需要自己的服务器: [部署方式](./docs/zeabur-start.md) +3. `Docker`: 在服务器或本地使用Docker启动: [部署方式](./docs/docker-start.md) + +## 本地开发 +- 依赖java17和maven +- 更改配置项: 修改src/main/application.yml +- 项目运行: 启动ProxyApplication的main函数 +- 更改代码后,构建镜像: Dockerfile取消VOLUME的注释,执行 `docker build . -t midjourney-proxy` + +## 配置项 +- mj.accounts: 参考 [账号池配置](./docs/config.md#%E8%B4%A6%E5%8F%B7%E6%B1%A0%E9%85%8D%E7%BD%AE%E5%8F%82%E8%80%83) +- mj.task-store.type: 任务存储方式,默认in_memory(内存\重启后丢失),可选redis +- mj.task-store.timeout: 任务存储过期时间,过期后删除,默认30天 +- mj.api-secret: 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret +- mj.translate-way: 中文prompt翻译成英文的方式,可选null(默认)、baidu、gpt、deepl +- 更多配置查看 [配置项](./docs/config.md) + +## 相关文档 +1. [API接口说明](./docs/api.md) +2. [版本更新记录](https://github.com/novicezk/midjourney-proxy/wiki/%E6%9B%B4%E6%96%B0%E8%AE%B0%E5%BD%95) + +## 注意事项 +1. 作图频繁等行为,可能会触发midjourney账号警告,请谨慎使用 +2. 常见问题及解决办法见 [Wiki / FAQ](https://github.com/novicezk/midjourney-proxy/wiki/FAQ) +3. 在 [Issues](https://github.com/novicezk/midjourney-proxy/issues) 中提出其他问题或建议 +4. 感兴趣的朋友也欢迎加入交流群讨论一下,扫码进群名额已满,加管理员微信邀请进群 + + 微信二维码 + +## 应用项目 +依赖此项目且开源的,欢迎联系作者,加到此处展示 +- [wechat-midjourney](https://github.com/novicezk/wechat-midjourney) : 代理微信客户端,接入MidJourney,仅示例应用场景,不再更新 +- [stable-diffusion-mobileui](https://github.com/yuanyuekeji/stable-diffusion-mobileui) : SDUI,基于本接口和SD,可一键打包生成H5和小程序 +- [ChatGPT-Midjourney](https://github.com/Licoy/ChatGPT-Midjourney) : 一键拥有你自己的 ChatGPT+Midjourney 网页服务 +- [MidJourney-Web](https://github.com/ConnectAI-E/MidJourney-Web) : 🍎 Supercharged Experience For MidJourney On Web UI + +## 其它 +如果觉得这个项目对你有所帮助,请帮忙点个star;也可以请作者喝杯茶~ + + 二维码 + +[![Star History Chart](https://api.star-history.com/svg?repos=novicezk/midjourney-proxy&type=Date)](https://star-history.com/#novicezk/midjourney-proxy&Date) diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d6112a611ad9cf345b44777080c9f54cb02caa54 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,35 @@ +FROM openjdk:17.0 + +ARG user=spring +ARG group=spring + +ENV SPRING_HOME=/home/spring +ENV APP_HOME=$SPRING_HOME/app + +ENV JAVA_OPTS -XX:MaxRAMPercentage=85 -Djava.awt.headless=true -XX:+HeapDumpOnOutOfMemoryError \ + -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -Xlog:gc:file=/home/spring/logs/gc.log \ + -Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.port=9876 -Dcom.sun.management.jmxremote.ssl=false \ + -Dcom.sun.management.jmxremote.authenticate=false -Dlogging.file.path=/home/spring/logs \ + -Dserver.port=8080 -Duser.timezone=Asia/Shanghai + +RUN groupadd -g 1000 ${group} \ + && useradd -d "$SPRING_HOME" -u 1000 -g 1000 -m -s /bin/bash ${user} \ + && mkdir -p $SPRING_HOME/config \ + && mkdir -p $SPRING_HOME/logs \ + && mkdir -p $APP_HOME \ + && chown -R ${user}:${group} $SPRING_HOME/config $SPRING_HOME/logs $APP_HOME + +VOLUME ["$SPRING_HOME/config", "$SPRING_HOME/logs"] + +USER ${user} + +WORKDIR $SPRING_HOME + +EXPOSE 8080 9876 + +ENTRYPOINT ["bash","-c","java $JAVA_OPTS -cp ./app org.springframework.boot.loader.JarLauncher"] + +COPY --chown=${user}:${group} dependencies $APP_HOME/ +COPY --chown=${user}:${group} spring-boot-loader $APP_HOME/ +COPY --chown=${user}:${group} snapshot-dependencies $APP_HOME/ +COPY --chown=${user}:${group} application $APP_HOME/ diff --git a/docker/build-image.sh b/docker/build-image.sh new file mode 100644 index 0000000000000000000000000000000000000000..2499e9a1d634fd8db11293fe981da064f638edaa --- /dev/null +++ b/docker/build-image.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e -u -o pipefail + +if [ $# -lt 1 ]; then + echo 'version is required' + exit 1 +fi + +VERSION=$1 +ARCH=amd64 + +if [ $# -ge 2 ]; then + ARCH=$2 +fi + +JAR_FILE_COUNT=$(find "../target/" -maxdepth 1 -name '*.jar' | wc -l) +if [ $JAR_FILE_COUNT == 0 ]; then + echo "jar file not found, please execute: mvn clean package" + exit 1 +fi + +JAR_FILE_NAME=$(ls ../target/*.jar|grep -v source) +echo ${JAR_FILE_NAME} + +cp ${JAR_FILE_NAME} ./app.jar + +java -Djarmode=layertools -jar app.jar extract + +docker build . -t midjourney-proxy:${VERSION} + +rm -rf application dependencies snapshot-dependencies spring-boot-loader app.jar + +docker tag midjourney-proxy:${VERSION} novicezk/midjourney-proxy-${ARCH}:${VERSION} +docker push novicezk/midjourney-proxy-${ARCH}:${VERSION} \ No newline at end of file diff --git a/docker/build-manifest.sh b/docker/build-manifest.sh new file mode 100644 index 0000000000000000000000000000000000000000..ce04849e83b20ff2570a795cbe7fdec1be300876 --- /dev/null +++ b/docker/build-manifest.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e -u -o pipefail + +if [ $# -lt 1 ]; then + echo 'version is required' + exit 1 +fi + +VERSION=$1 + +echo "create manifest..." +docker manifest create novicezk/midjourney-proxy:${VERSION} novicezk/midjourney-proxy-amd64:${VERSION} novicezk/midjourney-proxy-arm64v8:${VERSION} + +echo "annotate amd64..." +docker manifest annotate novicezk/midjourney-proxy:${VERSION} novicezk/midjourney-proxy-amd64:${VERSION} --os linux --arch amd64 + +echo "annotate arm64v8..." +docker manifest annotate novicezk/midjourney-proxy:${VERSION} novicezk/midjourney-proxy-arm64v8:${VERSION} --os linux --arch arm64 --variant v8 + +echo "push manifest..." +docker manifest push novicezk/midjourney-proxy:${VERSION} \ No newline at end of file diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..6d38c1e59dbf13e30a540b5216e9e54dc2d59149 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,139 @@ +# API接口说明 + +`http://ip:port/mj` 已有api文档,此处仅作补充 + +## 1. 数据结构 + +### 任务 +| 字段 | 类型 | 示例 | 描述 | +|:-----:|:----:|:----|:----| +| id | string | 1689231405853400 | 任务ID | +| action | string | IMAGINE | 任务类型: IMAGINE(绘图)、UPSCALE(选中放大)、VARIATION(选中变换)、REROLL(重新执行)、DESCRIBE(图生文)、BLEAND(图片混合) | +| status | string | SUCCESS | 任务状态: NOT_START(未启动)、SUBMITTED(已提交处理)、IN_PROGRESS(执行中)、FAILURE(失败)、SUCCESS(成功) | +| prompt | string | 猫猫 | 提示词 | +| promptEn | string | Cat | 英文提示词 | +| description | string | /imagine 猫猫 | 任务描述 | +| submitTime | number | 1689231405854 | 提交时间 | +| startTime | number | 1689231442755 | 开始执行时间 | +| finishTime | number | 1689231544312 | 结束时间 | +| progress | string | 100% | 任务进度 | +| imageUrl | string | https://cdn.discordapp.com/attachments/xxx/xxx/xxxx.png | 生成图片的url, 成功或执行中时有值,可能为png或webp | +| failReason | string | [Invalid parameter] Invalid value | 失败原因, 失败时有值 | +| properties | object | {"finalPrompt": "Cat"} | 任务的扩展属性,系统内部使用 | + + +## 2. 任务提交返回 +- code=1: 提交成功,result为任务ID + ```json + { + "code": 1, + "description": "成功", + "result": "8498455807619990", + "properties": { + "discordInstanceId": "1118138338562560102" + } + } + ``` +- code=21: 任务已存在,U时可能发生 + ```json + { + "code": 21, + "description": "任务已存在", + "result": "0741798445574458", + "properties": { + "status": "SUCCESS", + "imageUrl": "https://xxxx" + } + } + ``` +- code=22: 提交成功,进入队列等待 + ```json + { + "code": 22, + "description": "排队中,前面还有1个任务", + "result": "0741798445574458", + "properties": { + "numberOfQueues": 1, + "discordInstanceId": "1118138338562560102" + } + } + ``` +- code=23: 队列已满,请稍后尝试 + ```json + { + "code": 23, + "description": "队列已满,请稍后尝试", + "result": "14001929738841620", + "properties": { + "discordInstanceId": "1118138338562560102" + } + } + ``` +- code=24: prompt包含敏感词 + ```json + { + "code": 24, + "description": "可能包含敏感词", + "properties": { + "promptEn": "nude body", + "bannedWord": "nude" + } + } + ``` +- other: 提交错误,description为错误描述 + +## 3. `/mj/submit/simple-change` 绘图变化-simple +接口作用同 `/mj/submit/change`(绘图变化),传参方式不同,该接口接收content,格式为`ID 操作`,例如:1320098173412546 U2 + +- 放大 U1~U4 +- 变换 V1~V4 +- 重新执行 R + +## 4. `/mj/submit/describe` 图生文 +```json +{ + // 图片的base64字符串 + "base64": "data:image/png;base64,xxx" +} +``` + +后续任务完成后,properties中finalPrompt即为图片生成的prompt +```json +{ + "id":"14001929738841620", + "action":"DESCRIBE", + "status": "SUCCESS", + "description":"/describe 14001929738841620.png", + "imageUrl":"https://cdn.discordapp.com/attachments/xxx/xxx/14001929738841620.png", + "properties": { + "finalPrompt": "1️⃣ Cat --ar 5:4\n\n2️⃣ Cat2 --ar 5:4\n\n3️⃣ Cat3 --ar 5:4\n\n4️⃣ Cat4 --ar 5:4" + } + // ... +} +``` + +## 5. 任务变更回调 +任务状态变化或进度改变时,会调用业务系统的接口 +- 接口地址为配置的 mj.notify-hook,任务提交时支持传`notifyHook`以改变此任务的回调地址 +- 两者都为空时,不触发回调 + +POST application/json +```json +{ + "id": "14001929738841620", + "action": "IMAGINE", + "status": "SUCCESS", + "prompt": "猫猫", + "promptEn": "Cat", + "description": "/imagine 猫猫", + "submitTime": 1689231405854, + "startTime": 1689231442755, + "finishTime": 1689231544312, + "progress": "100%", + "imageUrl": "https://cdn.discordapp.com/attachments/xxx/xxx/xxxx.png", + "failReason": null, + "properties": { + "finalPrompt": "Cat" + } +} +``` diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 0000000000000000000000000000000000000000..af449b6465a7782cd91d0c9761aeb334f508b8b8 --- /dev/null +++ b/docs/config.md @@ -0,0 +1,68 @@ +## 配置项 + +| 变量名 | 非空 | 描述 | +|:------------------------------|:--:|:----------------------------------------------| +| mj.accounts | 是 | [账号池配置](./config.md#%E8%B4%A6%E5%8F%B7%E6%B1%A0%E9%85%8D%E7%BD%AE%E5%8F%82%E8%80%83),配置后不需要额外设置mj.discord | +| mj.discord.guild-id | 是 | discord服务器ID | +| mj.discord.channel-id | 是 | discord频道ID | +| mj.discord.user-token | 是 | discord用户Token | +| mj.discord.user-agent | 否 | 调用discord接口、连接wss时的user-agent,建议从浏览器network复制 | +| mj.discord.core-size | 否 | 并发数,默认为3 | +| mj.discord.queue-size | 否 | 等待队列,默认长度10 | +| mj.discord.timeout-minutes | 否 | 任务超时时间,默认为5分钟 | +| mj.api-secret | 否 | 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret | +| mj.notify-hook | 否 | 全局的任务状态变更回调地址 | +| mj.notify-notify-pool-size | 否 | 通知回调线程池大小,默认10 | +| mj.task-store.type | 否 | 任务存储方式,默认in_memory(内存\重启后丢失),可选redis | +| mj.task-store.timeout | 否 | 任务过期时间,过期后删除,默认30天 | +| mj.proxy.host | 否 | 代理host,全局代理不生效时设置 | +| mj.proxy.port | 否 | 代理port,全局代理不生效时设置 | +| mj.ng-discord.server | 否 | https://discord.com 反代地址 | +| mj.ng-discord.cdn | 否 | https://cdn.discordapp.com 反代地址 | +| mj.ng-discord.wss | 否 | wss://gateway.discord.gg 反代地址 | +| mj.translate-way | 否 | 中文prompt翻译成英文的方式,可选null(默认)、baidu、gpt | +| mj.baidu-translate.appid | 否 | 百度翻译的appid | +| mj.baidu-translate.app-secret | 否 | 百度翻译的app-secret | +| mj.openai.gpt-api-url | 否 | 自定义gpt的接口地址,默认不需要配置 | +| mj.openai.gpt-api-key | 否 | gpt的api-key | +| mj.openai.timeout | 否 | openai调用的超时时间,默认30秒 | +| mj.openai.model | 否 | openai的模型,默认gpt-3.5-turbo | +| mj.openai.max-tokens | 否 | 返回结果的最大分词数,默认2048 | +| mj.openai.temperature | 否 | 相似度(0-2.0),默认0 | +| spring.redis | 否 | 任务存储方式设置为redis,需配置redis相关属性 | + +### 账号池配置参考 +```yaml +mj: + accounts: + - guild-id: xxx + channel-id: xxx + user-token: xxxx + user-agent: xxxx + - guild-id: xxx + channel-id: xxx + user-token: xxxx + user-agent: xxxx +``` + +账号字段说明 + +| 名称 | 非空 | 描述 | +|:------------------| :----: |:--------------------------------------------------------------------| +| guild-id | 是 | discord服务器ID | +| channel-id | 是 | discord频道ID | +| user-token | 是 | discord用户Token | +| user-agent | 否 | 调用discord接口、连接wss时的user-agent,建议从浏览器network复制 | +| enable | 否 | 是否可用,默认true | +| core-size | 否 | 并发数,默认3 | +| queue-size | 否 | 等待队列长度,默认10 | +| timeout-minutes | 否 | 任务超时时间(分钟),默认5 | + +### spring.redis配置参考 +```yaml +spring: + redis: + host: 10.107.xxx.xxx + port: 6379 + password: xxx +``` \ No newline at end of file diff --git a/docs/discord-params.md b/docs/discord-params.md new file mode 100644 index 0000000000000000000000000000000000000000..b826cc18cacce241dfc4c4b85a1b19ff789d0d11 --- /dev/null +++ b/docs/discord-params.md @@ -0,0 +1,11 @@ +## 获取discord配置参数 + +### 1. 获取用户Token +进入频道,打开network,刷新页面,找到 `messages` 的请求,这里的 authorization 即用户Token,后续设置到 `mj.discord.user-token` + +![User Token](img_8.png) + +### 2. 获取服务器ID、频道ID + +频道的url里取出 服务器ID、频道ID,后续设置到配置项 +![Guild Channel ID](img_9.png) diff --git a/docs/docker-start.md b/docs/docker-start.md new file mode 100644 index 0000000000000000000000000000000000000000..4c6da8b687650f3b59b462af477a60a5f8c42dab --- /dev/null +++ b/docs/docker-start.md @@ -0,0 +1,21 @@ +## Docker 部署教程 + +1. /xxx/xxx/config目录下创建 application.yml(mj配置项)、banned-words.txt(可选,覆盖默认的敏感词文件);参考src/main/resources下的文件 +2. 启动容器,映射config目录 +```shell +docker run -d --name midjourney-proxy \ + -p 8080:8080 \ + -v /xxx/xxx/config:/home/spring/config \ + novicezk/midjourney-proxy:2.5 +``` +3. 访问 `http://ip:port/mj` 查看API文档 + +附: 不映射config目录方式,直接在启动命令中设置参数 +```shell +docker run -d --name midjourney-proxy \ + -p 8080:8080 \ + -e mj.discord.guild-id=xxx \ + -e mj.discord.channel-id=xxx \ + -e mj.discord.user-token=xxx \ + novicezk/midjourney-proxy:2.5 +``` diff --git a/docs/img_10.png b/docs/img_10.png new file mode 100644 index 0000000000000000000000000000000000000000..4e6185d5059d31a9e7aafbc61230b9b6388e32ec Binary files /dev/null and b/docs/img_10.png differ diff --git a/docs/img_8.png b/docs/img_8.png new file mode 100644 index 0000000000000000000000000000000000000000..33127d2842fc76b9d70ae764e70041a5cbe65053 Binary files /dev/null and b/docs/img_8.png differ diff --git a/docs/img_9.png b/docs/img_9.png new file mode 100644 index 0000000000000000000000000000000000000000..1e7bf5be281a5a58abc1303cd465c895a37d5b9f Binary files /dev/null and b/docs/img_9.png differ diff --git a/docs/manager-qrcode.png b/docs/manager-qrcode.png new file mode 100644 index 0000000000000000000000000000000000000000..2594f9a405a3219498f0cbfe1639cda670bdac4a Binary files /dev/null and b/docs/manager-qrcode.png differ diff --git a/docs/params_user.png b/docs/params_user.png new file mode 100644 index 0000000000000000000000000000000000000000..3603b199e8bcf16302362fce66658970853ddcd9 Binary files /dev/null and b/docs/params_user.png differ diff --git a/docs/railway-start.md b/docs/railway-start.md new file mode 100644 index 0000000000000000000000000000000000000000..630f6131b6fb886d89462c28876aa4fc8a0e2ac2 --- /dev/null +++ b/docs/railway-start.md @@ -0,0 +1,36 @@ +## Railway 部署教程 + +Railway是一个提供弹性部署方案的平台,服务器在海外,方便MidJourney的调用。 + +**Railway 提供 5 美元,500 个小时/月的免费额度** + +### 1. Fork本仓库 +### 2. Railway使用github账号登录 +进入 [railway官网](https://railway.app) 选择 `Login` -> `Github`,登录github账号 + +### 3. [New Project](https://railway.app/new) 添加对fork仓库的授权 +![railway_img_1](./railway_img_1.png) +![railway_img_2](./railway_img_2.png) +![railway_img_3](./railway_img_3.png) + +### 4. 选择该fork仓库,新建项目,设置环境变量 +![railway_img_4](./railway_img_4.png) +![railway_img_5](./railway_img_5.png) +![railway_img_6](./railway_img_6.png) +![railway_img_7](./railway_img_7.png) +此处配置项参考 [Wiki / 配置项](https://github.com/novicezk/midjourney-proxy/wiki/%E9%85%8D%E7%BD%AE%E9%A1%B9) ,建议配置api密钥启用鉴权,接口调用时需添加请求头 `mj-api-secret` + +### 5. 启动服务 +进入刚才的Project,它应该已经在自动部署了,后续更新配置之后会自动重新部署 +![railway_img_8](./railway_img_8.png) + +若部署启动失败请查看日志,检查配置项 +![railway_img_9](./railway_img_9.png) +![railway_img_10](./railway_img_10.png) + +### 6. 开始使用 +等待部署成功后,生成随机域名 +![railway_img_11](./railway_img_11.png) +![railway_img_12](./railway_img_12.png) + +访问 `https://midjourney-proxy-***.app/mj` diff --git a/docs/railway_img_1.png b/docs/railway_img_1.png new file mode 100644 index 0000000000000000000000000000000000000000..1fa8f0577f20974a8a36e0082c5012f7589dfdf4 Binary files /dev/null and b/docs/railway_img_1.png differ diff --git a/docs/railway_img_10.png b/docs/railway_img_10.png new file mode 100644 index 0000000000000000000000000000000000000000..4bc7fee8221c4e4c00333837083552d26c4db1b4 Binary files /dev/null and b/docs/railway_img_10.png differ diff --git a/docs/railway_img_11.png b/docs/railway_img_11.png new file mode 100644 index 0000000000000000000000000000000000000000..f6864099a023cca4e66dbb56455c33ab3afcf725 Binary files /dev/null and b/docs/railway_img_11.png differ diff --git a/docs/railway_img_12.png b/docs/railway_img_12.png new file mode 100644 index 0000000000000000000000000000000000000000..35d107c76ad19a690ab90d38bcb264d5eca15541 Binary files /dev/null and b/docs/railway_img_12.png differ diff --git a/docs/railway_img_2.png b/docs/railway_img_2.png new file mode 100644 index 0000000000000000000000000000000000000000..41385d6bce71b861ed6934443c021336179fb789 Binary files /dev/null and b/docs/railway_img_2.png differ diff --git a/docs/railway_img_3.png b/docs/railway_img_3.png new file mode 100644 index 0000000000000000000000000000000000000000..4d1bfcbe9b57a28530e8e3d37fa07ac7d6522cc6 Binary files /dev/null and b/docs/railway_img_3.png differ diff --git a/docs/railway_img_4.png b/docs/railway_img_4.png new file mode 100644 index 0000000000000000000000000000000000000000..be1fec5ae230a4f8ec4597d29d773a342979e7f7 Binary files /dev/null and b/docs/railway_img_4.png differ diff --git a/docs/railway_img_5.png b/docs/railway_img_5.png new file mode 100644 index 0000000000000000000000000000000000000000..8a1e88bc278258817794fcc4fa4dc6c09d484878 Binary files /dev/null and b/docs/railway_img_5.png differ diff --git a/docs/railway_img_6.png b/docs/railway_img_6.png new file mode 100644 index 0000000000000000000000000000000000000000..f98c1969330a1e316c028e1693a97f9eb62a3113 Binary files /dev/null and b/docs/railway_img_6.png differ diff --git a/docs/railway_img_7.png b/docs/railway_img_7.png new file mode 100644 index 0000000000000000000000000000000000000000..3359bec9ff9e5ff782025183963803d7fbabc3ed Binary files /dev/null and b/docs/railway_img_7.png differ diff --git a/docs/railway_img_8.png b/docs/railway_img_8.png new file mode 100644 index 0000000000000000000000000000000000000000..51180fdc79b1ac4bcb2299c697f203538e32c035 Binary files /dev/null and b/docs/railway_img_8.png differ diff --git a/docs/railway_img_9.png b/docs/railway_img_9.png new file mode 100644 index 0000000000000000000000000000000000000000..5e07f10ec3cca077507f94909fa9ad5a27868cc0 Binary files /dev/null and b/docs/railway_img_9.png differ diff --git a/docs/receipt-code.png b/docs/receipt-code.png new file mode 100644 index 0000000000000000000000000000000000000000..0746aea2b3837b6bd5e7293bdeaf904d589e292e Binary files /dev/null and b/docs/receipt-code.png differ diff --git a/docs/zeabur-start.md b/docs/zeabur-start.md new file mode 100644 index 0000000000000000000000000000000000000000..26c7b9bb5d47b0402c04c47a6dac50004ffcdd43 --- /dev/null +++ b/docs/zeabur-start.md @@ -0,0 +1,33 @@ +## Zeabur 部署教程 + +### Zeabur 优势 +1. 新注册的 `Github` 账号可能无法使用 `Railway`,但是能用 `Zeabur` +2. 通过 `Railway` 部署的项目会自动生成一个域名,然而因为某些原因,形如 `*.up.railway.app` 的域名在国内无法访问 +3. `Zeabur` 服务器运行在国外,但是其生成的域名 `*.zeabur.app` 没有被污染,国内可直接访问 + +### 开始部署 + +1. 打开网址 https://zeabur.com/zh-CN +2. 点击现在开始 +3. 点击 `Sign in with GitHub` +4. 登陆你的 `Github` 账号 +5. 点击 `Authorize zeabur` 授权 +6. 点击 `创建项目` 并输入一个项目名称,点击 `创建` +7. 点击 `+` 添加服务,选择 `Git-Deploy service from source code in GitHub repository.` +8. 点击 `Configure GitHub` 根据需要选择 `All repositories` 或者 `Only select repositories` +9. 点击 `install`,之后自动跳转,最好再刷新一下页面 +10. 点击 你 fork 的 `midjourney-proxy` 项目 +11. 点击环境变量,点击编辑原始环境变量,添加你需要的环境变量 +12. 关于环境变量,与 `Railway` 稍有不同,需要把 `.` 和 `-` 全部换成 `_`,例如如下格式 + ```properties + PORT=8080 + mj_discord_guild_id=xxx + mj_discord_channel_id=xxx + mj_discord_user_token=xxx + mj_api_secret=*** + ``` + 此处配置项参考 [Wiki / 配置项](https://github.com/novicezk/midjourney-proxy/wiki/%E9%85%8D%E7%BD%AE%E9%A1%B9) ,建议配置api密钥启用鉴权,接口调用时需添加请求头 `mj-api-secret` +13. 然后取消 `Building`,点击 `Redeploy` (此做法是为了让环境变量生效) +14. 部署 `midjourney-proxy` 大概需要 `2` 分钟,此时你可以做的是:配置域名 +15. 点击下方的域名,点击生成域名,输入前缀,例如 `midjourney-proxy-demo`,点击保存;或者添加自定义域名,之后加上 `CNAME` 解析 +16. 等待部署成功,访问 `https://midjourney-proxy-demo.zeabur.app/mj` \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..43f2f031b22f058b4fd5e983935fe9c92efc9194 --- /dev/null +++ b/pom.xml @@ -0,0 +1,120 @@ + + + 4.0.0 + + + org.springframework.boot + spring-boot-starter-parent + 2.6.14 + + + com.github.novicezk + midjourney-proxy + 2.5 + + + 5.8.18 + 20220924 + 5.0.0-beta.9 + 1.0.14-beta1 + 2.0.0 + 4.1.0 + 1.21 + 4.5.14 + 17 + ${java.version} + ${java.version} + + + + + org.springframework.boot + spring-boot-starter-web + + + org.springframework.boot + spring-boot-starter-data-redis + + + + cn.hutool + hutool-core + ${hutool.version} + + + cn.hutool + hutool-cache + ${hutool.version} + + + cn.hutool + hutool-crypto + ${hutool.version} + + + org.json + json + ${org-json.version} + + + net.dv8tion + JDA + ${jda.version} + + + club.minnced + opus-java + + + + + com.unfbx + chatgpt-java + ${chatgpt-java.version} + + + slf4j-simple + org.slf4j + + + + + eu.maxschuster + dataurl + ${dataurl.version} + + + com.github.xiaoymin + knife4j-openapi2-spring-boot-starter + ${knife4j.verison} + + + eu.bitwalker + UserAgentUtils + ${user-agent-utils.verison} + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + org.projectlombok + lombok + true + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + + diff --git a/src/main/java/com/github/novicezk/midjourney/Constants.java b/src/main/java/com/github/novicezk/midjourney/Constants.java new file mode 100644 index 0000000000000000000000000000000000000000..0bd8dcb2544971f0d6bc768d045dc98602d7e4bc --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/Constants.java @@ -0,0 +1,20 @@ +package com.github.novicezk.midjourney; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class Constants { + // 任务扩展属性 start + public static final String TASK_PROPERTY_NOTIFY_HOOK = "notifyHook"; + public static final String TASK_PROPERTY_FINAL_PROMPT = "finalPrompt"; + public static final String TASK_PROPERTY_MESSAGE_ID = "messageId"; + public static final String TASK_PROPERTY_MESSAGE_HASH = "messageHash"; + public static final String TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId"; + public static final String TASK_PROPERTY_FLAGS = "flags"; + public static final String TASK_PROPERTY_NONCE = "nonce"; + public static final String TASK_PROPERTY_DISCORD_INSTANCE_ID = "discordInstanceId"; + // 任务扩展属性 end + + public static final String API_SECRET_HEADER_NAME = "mj-api-secret"; + public static final String DEFAULT_DISCORD_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36"; +} diff --git a/src/main/java/com/github/novicezk/midjourney/ProxyApplication.java b/src/main/java/com/github/novicezk/midjourney/ProxyApplication.java new file mode 100644 index 0000000000000000000000000000000000000000..f46750039f3f633d4486bd517aeaba1cb0c9d80f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/ProxyApplication.java @@ -0,0 +1,19 @@ +package com.github.novicezk.midjourney; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Import; +import org.springframework.scheduling.annotation.EnableScheduling; +import spring.config.BeanConfig; +import spring.config.WebMvcConfig; + +@EnableScheduling +@SpringBootApplication +@Import({BeanConfig.class, WebMvcConfig.class}) +public class ProxyApplication { + + public static void main(String[] args) { + SpringApplication.run(ProxyApplication.class, args); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/ProxyProperties.java b/src/main/java/com/github/novicezk/midjourney/ProxyProperties.java new file mode 100644 index 0000000000000000000000000000000000000000..ae4ae4880e2eba3dbde44d575ed298acbfa6e32f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/ProxyProperties.java @@ -0,0 +1,207 @@ +package com.github.novicezk.midjourney; + +import com.github.novicezk.midjourney.enums.TranslateWay; +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +@Data +@Component +@ConfigurationProperties(prefix = "mj") +public class ProxyProperties { + /** + * task存储配置. + */ + private final TaskStore taskStore = new TaskStore(); + /** + * discord账号选择规则. + */ + private String accountChooseRule = "BestWaitIdleRule"; + /** + * discord单账号配置. + */ + private final DiscordAccountConfig discord = new DiscordAccountConfig(); + /** + * discord账号池配置. + */ + private final List accounts = new ArrayList<>(); + /** + * 代理配置. + */ + private final ProxyConfig proxy = new ProxyConfig(); + /** + * 反代配置. + */ + private final NgDiscordConfig ngDiscord = new NgDiscordConfig(); + /** + * 百度翻译配置. + */ + private final BaiduTranslateConfig baiduTranslate = new BaiduTranslateConfig(); + /** + * openai配置. + */ + private final OpenaiConfig openai = new OpenaiConfig(); + /** + * 中文prompt翻译方式. + */ + private TranslateWay translateWay = TranslateWay.NULL; + /** + * 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret. + */ + private String apiSecret; + /** + * 任务状态变更回调地址. + */ + private String notifyHook; + /** + * 通知回调线程池大小. + */ + private int notifyPoolSize = 10; + + @Data + public static class DiscordAccountConfig { + /** + * 服务器ID. + */ + private String guildId; + /** + * 频道ID. + */ + private String channelId; + /** + * 用户Token. + */ + private String userToken; + /** + * 用户UserAgent. + */ + private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT; + /** + * 是否可用. + */ + private boolean enable = true; + /** + * 并发数. + */ + private int coreSize = 3; + /** + * 等待队列长度. + */ + private int queueSize = 10; + /** + * 任务超时时间(分钟). + */ + private int timeoutMinutes = 5; + } + + @Data + public static class BaiduTranslateConfig { + /** + * 百度翻译的APP_ID. + */ + private String appid; + /** + * 百度翻译的密钥. + */ + private String appSecret; + } + + @Data + public static class OpenaiConfig { + /** + * 自定义gpt的api-url. + */ + private String gptApiUrl; + /** + * gpt的api-key. + */ + private String gptApiKey; + /** + * 超时时间. + */ + private Duration timeout = Duration.ofSeconds(30); + /** + * 使用的模型. + */ + private String model = "gpt-3.5-turbo"; + /** + * 返回结果的最大分词数. + */ + private int maxTokens = 2048; + /** + * 相似度,取值 0-2. + */ + private double temperature = 0; + } + + @Data + public static class TaskStore { + /** + * 任务过期时间,默认30天. + */ + private Duration timeout = Duration.ofDays(30); + /** + * 任务存储方式: redis(默认)、in_memory. + */ + private Type type = Type.IN_MEMORY; + + public enum Type { + /** + * redis. + */ + REDIS, + /** + * in_memory. + */ + IN_MEMORY + } + } + + @Data + public static class ProxyConfig { + /** + * 代理host. + */ + private String host; + /** + * 代理端口. + */ + private Integer port; + } + + @Data + public static class NgDiscordConfig { + /** + * https://discord.com 反代. + */ + private String server; + /** + * https://cdn.discordapp.com 反代. + */ + private String cdn; + /** + * wss://gateway.discord.gg 反代. + */ + private String wss; + } + + @Data + public static class TaskQueueConfig { + /** + * 并发数. + */ + private int coreSize = 3; + /** + * 等待队列长度. + */ + private int queueSize = 10; + /** + * 任务超时时间(分钟). + */ + private int timeoutMinutes = 5; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/ReturnCode.java b/src/main/java/com/github/novicezk/midjourney/ReturnCode.java new file mode 100644 index 0000000000000000000000000000000000000000..ec60264215181d9aaea3c4781d0cd7c1ccf1ea23 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/ReturnCode.java @@ -0,0 +1,42 @@ +package com.github.novicezk.midjourney; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class ReturnCode { + /** + * 成功. + */ + public static final int SUCCESS = 1; + /** + * 数据未找到. + */ + public static final int NOT_FOUND = 3; + /** + * 校验错误. + */ + public static final int VALIDATION_ERROR = 4; + /** + * 系统异常. + */ + public static final int FAILURE = 9; + + /** + * 已存在. + */ + public static final int EXISTED = 21; + /** + * 排队中. + */ + public static final int IN_QUEUE = 22; + /** + * 队列已满. + */ + public static final int QUEUE_REJECTED = 23; + /** + * prompt包含敏感词. + */ + public static final int BANNED_PROMPT = 24; + + +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/controller/AccountController.java b/src/main/java/com/github/novicezk/midjourney/controller/AccountController.java new file mode 100644 index 0000000000000000000000000000000000000000..784a114939f9079371795595f2f4f06bd317b43a --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/controller/AccountController.java @@ -0,0 +1,36 @@ +package com.github.novicezk.midjourney.controller; + +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; +import com.github.novicezk.midjourney.loadbalancer.DiscordLoadBalancer; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import io.swagger.annotations.ApiParam; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.List; + +@Api(tags = "账号查询") +@RestController +@RequestMapping("/account") +@RequiredArgsConstructor +public class AccountController { + private final DiscordLoadBalancer loadBalancer; + + @ApiOperation(value = "指定ID获取账号") + @GetMapping("/{id}/fetch") + public DiscordAccount fetch(@ApiParam(value = "账号ID") @PathVariable String id) { + DiscordInstance instance = this.loadBalancer.getDiscordInstance(id); + return instance == null ? null : instance.account(); + } + + @ApiOperation(value = "查询所有账号") + @GetMapping("/list") + public List list() { + return this.loadBalancer.getAllInstances().stream().map(DiscordInstance::account).toList(); + } +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/controller/SubmitController.java b/src/main/java/com/github/novicezk/midjourney/controller/SubmitController.java new file mode 100644 index 0000000000000000000000000000000000000000..4823decf3ed9549877173a556fafd2d7577a09f2 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/controller/SubmitController.java @@ -0,0 +1,228 @@ +package com.github.novicezk.midjourney.controller; + +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.RandomUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.dto.BaseSubmitDTO; +import com.github.novicezk.midjourney.dto.SubmitBlendDTO; +import com.github.novicezk.midjourney.dto.SubmitChangeDTO; +import com.github.novicezk.midjourney.dto.SubmitDescribeDTO; +import com.github.novicezk.midjourney.dto.SubmitImagineDTO; +import com.github.novicezk.midjourney.dto.SubmitSimpleChangeDTO; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.exception.BannedPromptException; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.service.TaskService; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.service.TranslateService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.BannedPromptUtils; +import com.github.novicezk.midjourney.util.ConvertUtils; +import com.github.novicezk.midjourney.util.MimeTypeUtils; +import com.github.novicezk.midjourney.util.SnowFlake; +import com.github.novicezk.midjourney.util.TaskChangeParams; +import eu.maxschuster.dataurl.DataUrl; +import eu.maxschuster.dataurl.DataUrlSerializer; +import eu.maxschuster.dataurl.IDataUrlSerializer; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.net.MalformedURLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +@Api(tags = "任务提交") +@RestController +@RequestMapping("/submit") +@RequiredArgsConstructor +public class SubmitController { + private final TranslateService translateService; + private final TaskStoreService taskStoreService; + private final ProxyProperties properties; + private final TaskService taskService; + + @ApiOperation(value = "提交Imagine任务") + @PostMapping("/imagine") + public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) { + String prompt = imagineDTO.getPrompt(); + if (CharSequenceUtil.isBlank(prompt)) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空"); + } + prompt = prompt.trim(); + Task task = newTask(imagineDTO); + task.setAction(TaskAction.IMAGINE); + task.setPrompt(prompt); + String promptEn = translatePrompt(prompt); + try { + BannedPromptUtils.checkBanned(promptEn); + } catch (BannedPromptException e) { + return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词") + .setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage()); + } + List base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>()); + if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) { + base64Array.add(imagineDTO.getBase64()); + } + List dataUrls; + try { + dataUrls = ConvertUtils.convertBase64Array(base64Array); + } catch (MalformedURLException e) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"); + } + task.setPromptEn(promptEn); + task.setDescription("/imagine " + prompt); + return this.taskService.submitImagine(task, dataUrls); + } + + @ApiOperation(value = "绘图变化-simple") + @PostMapping("/simple-change") + public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) { + TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent()); + if (changeParams == null) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误"); + } + SubmitChangeDTO changeDTO = new SubmitChangeDTO(); + changeDTO.setAction(changeParams.getAction()); + changeDTO.setTaskId(changeParams.getId()); + changeDTO.setIndex(changeParams.getIndex()); + changeDTO.setState(simpleChangeDTO.getState()); + changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook()); + return change(changeDTO); + } + + @ApiOperation(value = "绘图变化") + @PostMapping("/change") + public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) { + if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空"); + } + if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误"); + } + String description = "/up " + changeDTO.getTaskId(); + if (TaskAction.REROLL.equals(changeDTO.getAction())) { + description += " R"; + } else { + description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex(); + } + if (TaskAction.UPSCALE.equals(changeDTO.getAction())) { + TaskCondition condition = new TaskCondition().setDescription(description); + Task existTask = this.taskStoreService.findOne(condition); + if (existTask != null) { + return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId()) + .setProperty("status", existTask.getStatus()) + .setProperty("imageUrl", existTask.getImageUrl()); + } + } + Task targetTask = this.taskStoreService.get(changeDTO.getTaskId()); + if (targetTask == null) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效"); + } + if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误"); + } + if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化"); + } + Task task = newTask(changeDTO); + task.setAction(changeDTO.getAction()); + task.setPrompt(targetTask.getPrompt()); + task.setPromptEn(targetTask.getPromptEn()); + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT)); + task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID)); + task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID)); + task.setDescription(description); + int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS); + String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID); + String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH); + if (TaskAction.UPSCALE.equals(changeDTO.getAction())) { + return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags); + } else if (TaskAction.VARIATION.equals(changeDTO.getAction())) { + return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags); + } else { + return this.taskService.submitReroll(task, messageId, messageHash, messageFlags); + } + } + + @ApiOperation(value = "提交Describe任务") + @PostMapping("/describe") + public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) { + if (CharSequenceUtil.isBlank(describeDTO.getBase64())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空"); + } + IDataUrlSerializer serializer = new DataUrlSerializer(); + DataUrl dataUrl; + try { + dataUrl = serializer.unserialize(describeDTO.getBase64()); + } catch (MalformedURLException e) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"); + } + Task task = newTask(describeDTO); + task.setAction(TaskAction.DESCRIBE); + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + task.setDescription("/describe " + taskFileName); + return this.taskService.submitDescribe(task, dataUrl); + } + + @ApiOperation(value = "提交Blend任务") + @PostMapping("/blend") + public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) { + List base64Array = blendDTO.getBase64Array(); + if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误"); + } + if (blendDTO.getDimensions() == null) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误"); + } + IDataUrlSerializer serializer = new DataUrlSerializer(); + List dataUrlList = new ArrayList<>(); + try { + for (String base64 : base64Array) { + DataUrl dataUrl = serializer.unserialize(base64); + dataUrlList.add(dataUrl); + } + } catch (MalformedURLException e) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"); + } + Task task = newTask(blendDTO); + task.setAction(TaskAction.BLEND); + task.setDescription("/blend " + task.getId() + " " + dataUrlList.size()); + return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions()); + } + + private Task newTask(BaseSubmitDTO base) { + Task task = new Task(); + task.setId(System.currentTimeMillis() + "" + RandomUtil.randomNumbers(3)); + task.setSubmitTime(System.currentTimeMillis()); + task.setState(base.getState()); + String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook(); + task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook); + task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId()); + return task; + } + + private String translatePrompt(String prompt) { + String promptEn; + int paramStart = prompt.indexOf(" --"); + if (paramStart > 0) { + promptEn = this.translateService.translateToEnglish(prompt.substring(0, paramStart)).trim() + prompt.substring(paramStart); + } else { + promptEn = this.translateService.translateToEnglish(prompt).trim(); + } + if (CharSequenceUtil.isBlank(promptEn)) { + promptEn = prompt; + } + return promptEn; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/controller/TaskController.java b/src/main/java/com/github/novicezk/midjourney/controller/TaskController.java new file mode 100644 index 0000000000000000000000000000000000000000..b41845fca286c0b78a6c807123273b129a3dc99f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/controller/TaskController.java @@ -0,0 +1,64 @@ +package com.github.novicezk.midjourney.controller; + +import cn.hutool.core.comparator.CompareUtil; +import com.github.novicezk.midjourney.dto.TaskConditionDTO; +import com.github.novicezk.midjourney.loadbalancer.DiscordLoadBalancer; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import io.swagger.annotations.ApiParam; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +@Api(tags = "任务查询") +@RestController +@RequestMapping("/task") +@RequiredArgsConstructor +public class TaskController { + private final TaskStoreService taskStoreService; + private final DiscordLoadBalancer discordLoadBalancer; + + @ApiOperation(value = "指定ID获取任务") + @GetMapping("/{id}/fetch") + public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) { + return this.taskStoreService.get(id); + } + + @ApiOperation(value = "查询任务队列") + @GetMapping("/queue") + public List queue() { + return this.discordLoadBalancer.getQueueTaskIds().stream() + .map(this.taskStoreService::get).filter(Objects::nonNull) + .sorted(Comparator.comparing(Task::getSubmitTime)) + .toList(); + } + + @ApiOperation(value = "查询所有任务") + @GetMapping("/list") + public List list() { + return this.taskStoreService.list().stream() + .sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime())) + .toList(); + } + + @ApiOperation(value = "根据ID列表查询任务") + @PostMapping("/list-by-condition") + public List listByIds(@RequestBody TaskConditionDTO conditionDTO) { + if (conditionDTO.getIds() == null) { + return Collections.emptyList(); + } + return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/domain/DiscordAccount.java b/src/main/java/com/github/novicezk/midjourney/domain/DiscordAccount.java new file mode 100644 index 0000000000000000000000000000000000000000..24ac0039d0367111bc55d80ded898ae40fe1e291 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/domain/DiscordAccount.java @@ -0,0 +1,38 @@ +package com.github.novicezk.midjourney.domain; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.github.novicezk.midjourney.Constants; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@EqualsAndHashCode(callSuper = true) +@ApiModel("Discord账号") +public class DiscordAccount extends DomainObject { + + @ApiModelProperty("服务器ID") + private String guildId; + @ApiModelProperty("频道ID") + private String channelId; + @ApiModelProperty("用户Token") + private String userToken; + @ApiModelProperty("用户UserAgent") + private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT; + + @ApiModelProperty("是否可用") + private boolean enable = true; + + @ApiModelProperty("并发数") + private int coreSize = 3; + @ApiModelProperty("等待队列长度") + private int queueSize = 10; + @ApiModelProperty("任务超时时间(分钟)") + private int timeoutMinutes = 5; + + @JsonIgnore + public String getDisplay() { + return this.channelId; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/domain/DomainObject.java b/src/main/java/com/github/novicezk/midjourney/domain/DomainObject.java new file mode 100644 index 0000000000000000000000000000000000000000..959d7c2ee63d4a48c7597a20f9f4cda7da9325cb --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/domain/DomainObject.java @@ -0,0 +1,72 @@ +package com.github.novicezk.midjourney.domain; + + +import com.fasterxml.jackson.annotation.JsonIgnore; +import io.swagger.annotations.ApiModelProperty; +import lombok.Getter; +import lombok.Setter; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + + +public class DomainObject implements Serializable { + @Getter + @Setter + @ApiModelProperty("ID") + protected String id; + + @Setter + protected Map properties; // 扩展属性,仅支持基本类型 + + @JsonIgnore + private final transient Object lock = new Object(); + + public void sleep() throws InterruptedException { + synchronized (this.lock) { + this.lock.wait(); + } + } + + public void awake() { + synchronized (this.lock) { + this.lock.notifyAll(); + } + } + + public DomainObject setProperty(String name, Object value) { + getProperties().put(name, value); + return this; + } + + public DomainObject removeProperty(String name) { + getProperties().remove(name); + return this; + } + + public Object getProperty(String name) { + return getProperties().get(name); + } + + @SuppressWarnings("unchecked") + public T getPropertyGeneric(String name) { + return (T) getProperty(name); + } + + public T getProperty(String name, Class clz) { + return getProperty(name, clz, null); + } + + public T getProperty(String name, Class clz, T defaultValue) { + Object value = getProperty(name); + return value == null ? defaultValue : clz.cast(value); + } + + public Map getProperties() { + if (this.properties == null) { + this.properties = new HashMap<>(); + } + return this.properties; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/BaseSubmitDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/BaseSubmitDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..50a8a4f63a3a8966e30dd1c47c30030fa0348b47 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/BaseSubmitDTO.java @@ -0,0 +1,16 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public abstract class BaseSubmitDTO { + + @ApiModelProperty("自定义参数") + protected String state; + + @ApiModelProperty("回调地址, 为空时使用全局notifyHook") + protected String notifyHook; +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitBlendDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitBlendDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..dd56aecc92ceec496a0dfb282a22ee331ffcf5bf --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitBlendDTO.java @@ -0,0 +1,21 @@ +package com.github.novicezk.midjourney.dto; + +import com.github.novicezk.midjourney.enums.BlendDimensions; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +import java.util.List; + +@Data +@ApiModel("Blend提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitBlendDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "图片base64数组", required = true, example = "[\"data:image/png;base64,xxx1\", \"data:image/png;base64,xxx2\"]") + private List base64Array; + + @ApiModelProperty(value = "比例: PORTRAIT(2:3); SQUARE(1:1); LANDSCAPE(3:2)", example = "SQUARE") + private BlendDimensions dimensions = BlendDimensions.SQUARE; +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitChangeDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitChangeDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..1b6b493b6a7dd3fab3146937ab30f04c15fbe226 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitChangeDTO.java @@ -0,0 +1,25 @@ +package com.github.novicezk.midjourney.dto; + +import com.github.novicezk.midjourney.enums.TaskAction; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + + +@Data +@ApiModel("变化任务提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitChangeDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "任务ID", required = true, example = "\"1320098173412546\"") + private String taskId; + + @ApiModelProperty(value = "UPSCALE(放大); VARIATION(变换); REROLL(重新生成)", required = true, + allowableValues = "UPSCALE, VARIATION, REROLL", example = "UPSCALE") + private TaskAction action; + + @ApiModelProperty(value = "序号(1~4), action为UPSCALE,VARIATION时必传", allowableValues = "range[1, 4]", example = "1") + private Integer index; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitDescribeDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitDescribeDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..8c34a35fb8b47a2dd1099b5c3028191ecfd77938 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitDescribeDTO.java @@ -0,0 +1,15 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@ApiModel("Describe提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitDescribeDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "图片base64", required = true, example = "data:image/png;base64,xxx") + private String base64; +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitImagineDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitImagineDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..51ff49de114a56fabd9f490286af11b7634573aa --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitImagineDTO.java @@ -0,0 +1,26 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +import java.util.List; + + +@Data +@ApiModel("Imagine提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitImagineDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "提示词", required = true, example = "Cat") + private String prompt; + + @ApiModelProperty(value = "垫图base64数组") + private List base64Array; + + @ApiModelProperty(hidden = true) + @Deprecated(since = "3.0", forRemoval = true) + private String base64; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitSimpleChangeDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitSimpleChangeDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..6d5e4537a21413b2455b159ea1e0dd5adb2b1d2b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitSimpleChangeDTO.java @@ -0,0 +1,17 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + + +@Data +@ApiModel("变化任务提交参数-simple") +@EqualsAndHashCode(callSuper = true) +public class SubmitSimpleChangeDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "变化描述: ID $action$index", required = true, example = "1320098173412546 U2") + private String content; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/TaskConditionDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/TaskConditionDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..09ae83e2a5b0ff669d9e03f53197b79ef8cccc57 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/TaskConditionDTO.java @@ -0,0 +1,14 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import lombok.Data; + +import java.util.List; + +@Data +@ApiModel("任务查询参数") +public class TaskConditionDTO { + + private List ids; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/BlendDimensions.java b/src/main/java/com/github/novicezk/midjourney/enums/BlendDimensions.java new file mode 100644 index 0000000000000000000000000000000000000000..ea1c4d6cd71a6672a9e4a507bf00d3024e756104 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/BlendDimensions.java @@ -0,0 +1,21 @@ +package com.github.novicezk.midjourney.enums; + + +public enum BlendDimensions { + + PORTRAIT("2:3"), + + SQUARE("1:1"), + + LANDSCAPE("3:2"); + + private final String value; + + BlendDimensions(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/MessageType.java b/src/main/java/com/github/novicezk/midjourney/enums/MessageType.java new file mode 100644 index 0000000000000000000000000000000000000000..330654635d44338b19a9013e74025828cea0c639 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/MessageType.java @@ -0,0 +1,26 @@ +package com.github.novicezk.midjourney.enums; + + +public enum MessageType { + /** + * 创建. + */ + CREATE, + /** + * 修改. + */ + UPDATE, + /** + * 删除. + */ + DELETE; + + public static MessageType of(String type) { + return switch (type) { + case "MESSAGE_CREATE" -> CREATE; + case "MESSAGE_UPDATE" -> UPDATE; + case "MESSAGE_DELETE" -> DELETE; + default -> null; + }; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/TaskAction.java b/src/main/java/com/github/novicezk/midjourney/enums/TaskAction.java new file mode 100644 index 0000000000000000000000000000000000000000..35811002960728c2b37bcc7243128de60d19900b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/TaskAction.java @@ -0,0 +1,30 @@ +package com.github.novicezk.midjourney.enums; + + +public enum TaskAction { + /** + * 生成图片. + */ + IMAGINE, + /** + * 选中放大. + */ + UPSCALE, + /** + * 选中其中的一张图,生成四张相似的. + */ + VARIATION, + /** + * 重新执行. + */ + REROLL, + /** + * 图转prompt. + */ + DESCRIBE, + /** + * 多图混合. + */ + BLEND + +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/TaskStatus.java b/src/main/java/com/github/novicezk/midjourney/enums/TaskStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..4b4fa5408fe9520f727a9a2bb46c1ee07f50f0ef --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/TaskStatus.java @@ -0,0 +1,26 @@ +package com.github.novicezk.midjourney.enums; + + +public enum TaskStatus { + /** + * 未启动. + */ + NOT_START, + /** + * 已提交. + */ + SUBMITTED, + /** + * 执行中. + */ + IN_PROGRESS, + /** + * 失败. + */ + FAILURE, + /** + * 成功. + */ + SUCCESS + +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/TranslateWay.java b/src/main/java/com/github/novicezk/midjourney/enums/TranslateWay.java new file mode 100644 index 0000000000000000000000000000000000000000..495297dffca6b9bbdc7c03656d4725f718dbb324 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/TranslateWay.java @@ -0,0 +1,18 @@ +package com.github.novicezk.midjourney.enums; + + +public enum TranslateWay { + /** + * 百度翻译. + */ + BAIDU, + /** + * GPT翻译. + */ + GPT, + /** + * 不翻译. + */ + NULL + +} diff --git a/src/main/java/com/github/novicezk/midjourney/exception/BannedPromptException.java b/src/main/java/com/github/novicezk/midjourney/exception/BannedPromptException.java new file mode 100644 index 0000000000000000000000000000000000000000..fd5bfc5e4a386ea7c36185715d510c4a38ae220b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/exception/BannedPromptException.java @@ -0,0 +1,8 @@ +package com.github.novicezk.midjourney.exception; + +public class BannedPromptException extends Exception { + + public BannedPromptException(String message) { + super(message); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/exception/SnowFlakeException.java b/src/main/java/com/github/novicezk/midjourney/exception/SnowFlakeException.java new file mode 100644 index 0000000000000000000000000000000000000000..648bc0029a3cafc0d1edf43b21cb1987e076af5f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/exception/SnowFlakeException.java @@ -0,0 +1,16 @@ +package com.github.novicezk.midjourney.exception; + +public class SnowFlakeException extends RuntimeException { + + public SnowFlakeException(String message) { + super(message); + } + + public SnowFlakeException(String message, Throwable cause) { + super(message, cause); + } + + public SnowFlakeException(Throwable cause) { + super(cause); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordInstance.java b/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordInstance.java new file mode 100644 index 0000000000000000000000000000000000000000..61994dbff46eab90c21342ea86ddcbde8ba97704 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordInstance.java @@ -0,0 +1,34 @@ +package com.github.novicezk.midjourney.loadbalancer; + + +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.result.Message; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.service.DiscordService; +import com.github.novicezk.midjourney.support.Task; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; + +public interface DiscordInstance extends DiscordService { + + String getInstanceId(); + + DiscordAccount account(); + + boolean isAlive(); + + void startWss() throws Exception; + + List getRunningTasks(); + + void exitTask(Task task); + + Map> getRunningFutures(); + + SubmitResultVO submitTask(Task task, Callable> discordSubmit); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordInstanceImpl.java b/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordInstanceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..549318cb20ec8d976a5b3f57e6c47a31cc28f77d --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordInstanceImpl.java @@ -0,0 +1,206 @@ +package com.github.novicezk.midjourney.loadbalancer; + + +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.result.Message; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.service.DiscordService; +import com.github.novicezk.midjourney.service.DiscordServiceImpl; +import com.github.novicezk.midjourney.service.NotifyService; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.wss.WebSocketStarter; +import com.github.novicezk.midjourney.wss.user.UserWebSocketStarter; +import eu.maxschuster.dataurl.DataUrl; +import lombok.extern.slf4j.Slf4j; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.web.client.RestTemplate; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; + +@Slf4j +public class DiscordInstanceImpl implements DiscordInstance { + private final DiscordAccount account; + private final WebSocketStarter socketStarter; + private final DiscordService service; + private final TaskStoreService taskStoreService; + private final NotifyService notifyService; + + private final ThreadPoolTaskExecutor taskExecutor; + private final List runningTasks; + private final Map> taskFutureMap = Collections.synchronizedMap(new HashMap<>()); + + public DiscordInstanceImpl(DiscordAccount account, UserWebSocketStarter socketStarter, RestTemplate restTemplate, + TaskStoreService taskStoreService, NotifyService notifyService, + String discordServer, Map paramsMap) { + this.account = account; + this.socketStarter = socketStarter; + this.taskStoreService = taskStoreService; + this.notifyService = notifyService; + this.service = new DiscordServiceImpl(account, restTemplate, discordServer, paramsMap); + this.runningTasks = new CopyOnWriteArrayList<>(); + this.taskExecutor = new ThreadPoolTaskExecutor(); + this.taskExecutor.setCorePoolSize(account.getCoreSize()); + this.taskExecutor.setMaxPoolSize(account.getCoreSize()); + this.taskExecutor.setQueueCapacity(account.getQueueSize()); + this.taskExecutor.setThreadNamePrefix("TaskQueue-" + account.getDisplay() + "-"); + this.taskExecutor.initialize(); + } + + @Override + public String getInstanceId() { + return this.account.getChannelId(); + } + + @Override + public DiscordAccount account() { + return this.account; + } + + @Override + public boolean isAlive() { + return this.account.isEnable(); + } + + @Override + public void startWss() throws Exception { + this.socketStarter.setTrying(true); + this.socketStarter.start(); + } + + @Override + public List getRunningTasks() { + return this.runningTasks; + } + + @Override + public void exitTask(Task task) { + try { + Future future = this.taskFutureMap.get(task.getId()); + if (future != null) { + future.cancel(true); + } + saveAndNotify(task); + } finally { + this.runningTasks.remove(task); + this.taskFutureMap.remove(task.getId()); + } + } + + @Override + public Map> getRunningFutures() { + return this.taskFutureMap; + } + + @Override + public synchronized SubmitResultVO submitTask(Task task, Callable> discordSubmit) { + this.taskStoreService.save(task); + int currentWaitNumbers; + try { + currentWaitNumbers = this.taskExecutor.getThreadPoolExecutor().getQueue().size(); + Future future = this.taskExecutor.submit(() -> executeTask(task, discordSubmit)); + this.taskFutureMap.put(task.getId(), future); + } catch (RejectedExecutionException e) { + this.taskStoreService.delete(task.getId()); + return SubmitResultVO.fail(ReturnCode.QUEUE_REJECTED, "队列已满,请稍后尝试") + .setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId()); + } catch (Exception e) { + log.error("submit task error", e); + return SubmitResultVO.fail(ReturnCode.FAILURE, "提交失败,系统异常") + .setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId()); + } + if (currentWaitNumbers == 0) { + return SubmitResultVO.of(ReturnCode.SUCCESS, "提交成功", task.getId()) + .setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId()); + } else { + return SubmitResultVO.of(ReturnCode.IN_QUEUE, "排队中,前面还有" + currentWaitNumbers + "个任务", task.getId()) + .setProperty("numberOfQueues", currentWaitNumbers) + .setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId()); + } + } + + private void executeTask(Task task, Callable> discordSubmit) { + this.runningTasks.add(task); + try { + task.start(); + Message result = discordSubmit.call(); + if (result.getCode() != ReturnCode.SUCCESS) { + task.fail(result.getDescription()); + saveAndNotify(task); + return; + } + saveAndNotify(task); + do { + task.sleep(); + saveAndNotify(task); + } while (task.getStatus() == TaskStatus.IN_PROGRESS); + log.debug("task finished, id: {}, status: {}", task.getId(), task.getStatus()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + log.error("task execute error", e); + task.fail("执行错误,系统异常"); + saveAndNotify(task); + } finally { + this.runningTasks.remove(task); + this.taskFutureMap.remove(task.getId()); + } + } + + private void saveAndNotify(Task task) { + this.taskStoreService.save(task); + this.notifyService.notifyTaskChange(task); + } + + @Override + public Message imagine(String prompt, String nonce) { + return this.service.imagine(prompt, nonce); + } + + @Override + public Message upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) { + return this.service.upscale(messageId, index, messageHash, messageFlags, nonce); + } + + @Override + public Message variation(String messageId, int index, String messageHash, int messageFlags, String nonce) { + return this.service.variation(messageId, index, messageHash, messageFlags, nonce); + } + + @Override + public Message reroll(String messageId, String messageHash, int messageFlags, String nonce) { + return this.service.reroll(messageId, messageHash, messageFlags, nonce); + } + + @Override + public Message describe(String finalFileName, String nonce) { + return this.service.describe(finalFileName, nonce); + } + + @Override + public Message blend(List finalFileNames, BlendDimensions dimensions, String nonce) { + return this.service.blend(finalFileNames, dimensions, nonce); + } + + @Override + public Message upload(String fileName, DataUrl dataUrl) { + return this.service.upload(fileName, dataUrl); + } + + @Override + public Message sendImageMessage(String content, String finalFileName) { + return this.service.sendImageMessage(content, finalFileName); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordLoadBalancer.java b/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordLoadBalancer.java new file mode 100644 index 0000000000000000000000000000000000000000..56bee1975e2c523f92ce9bb492af62e14b97cd25 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/loadbalancer/DiscordLoadBalancer.java @@ -0,0 +1,83 @@ +package com.github.novicezk.midjourney.loadbalancer; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.loadbalancer.rule.IRule; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +@Component +@RequiredArgsConstructor +public class DiscordLoadBalancer { + private final IRule rule; + + private final List instances = Collections.synchronizedList(new ArrayList<>()); + + public List getAllInstances() { + return this.instances; + } + + public List getAliveInstances() { + return this.instances.stream().filter(DiscordInstance::isAlive).toList(); + } + + public DiscordInstance chooseInstance() { + return this.rule.choose(getAliveInstances()); + } + + public DiscordInstance getDiscordInstance(String instanceId) { + if (CharSequenceUtil.isBlank(instanceId)) { + return null; + } + return this.instances.stream() + .filter(instance -> CharSequenceUtil.equals(instanceId, instance.getInstanceId())) + .findFirst().orElse(null); + } + + public Set getQueueTaskIds() { + Set taskIds = Collections.synchronizedSet(new HashSet<>()); + for (DiscordInstance instance : getAliveInstances()) { + taskIds.addAll(instance.getRunningFutures().keySet()); + } + return taskIds; + } + + public Stream findRunningTask(TaskCondition condition) { + return getAliveInstances().stream().flatMap(instance -> instance.getRunningTasks().stream().filter(condition)); + } + + public Task getRunningTask(String id) { + for (DiscordInstance instance : getAliveInstances()) { + Optional optional = instance.getRunningTasks().stream().filter(t -> id.equals(t.getId())).findFirst(); + if (optional.isPresent()) { + return optional.get(); + } + } + return null; + } + + public Task getRunningTaskByNonce(String nonce) { + if (CharSequenceUtil.isBlank(nonce)) { + return null; + } + TaskCondition condition = new TaskCondition().setNonce(nonce); + for (DiscordInstance instance : getAliveInstances()) { + Optional optional = instance.getRunningTasks().stream().filter(condition).findFirst(); + if (optional.isPresent()) { + return optional.get(); + } + } + return null; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/BestWaitIdleRule.java b/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/BestWaitIdleRule.java new file mode 100644 index 0000000000000000000000000000000000000000..f214df790993f89c3c6171627b677cd57d4e1422 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/BestWaitIdleRule.java @@ -0,0 +1,31 @@ +package com.github.novicezk.midjourney.loadbalancer.rule; + +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; + +import java.util.List; + +/** + * 最少等待空闲. + * 选择等待数最少的实例,如果都不需要等待,则选择空闲数最多的实例 + */ +public class BestWaitIdleRule implements IRule { + + @Override + public DiscordInstance choose(List instances) { + if (instances.isEmpty()) { + return null; + } + return instances.stream().min((i1, i2) -> { + int wait1 = i1.getRunningFutures().size() - i1.account().getCoreSize(); + int wait2 = i2.getRunningFutures().size() - i2.account().getCoreSize(); + if (wait1 == wait2 && wait1 == 0) { + // 都不需要等待时,选择空闲数最多的 + int idle1 = i1.account().getCoreSize() - i1.getRunningTasks().size(); + int idle2 = i2.account().getCoreSize() - i2.getRunningTasks().size(); + return idle2 - idle1; + } + return wait1 - wait2; + }).orElse(null); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/IRule.java b/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/IRule.java new file mode 100644 index 0000000000000000000000000000000000000000..ea1320588e9ec7fc3795c789396b5665125c2ad4 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/IRule.java @@ -0,0 +1,10 @@ +package com.github.novicezk.midjourney.loadbalancer.rule; + +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; + +import java.util.List; + +public interface IRule { + + DiscordInstance choose(List instances); +} diff --git a/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/RoundRobinRule.java b/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/RoundRobinRule.java new file mode 100644 index 0000000000000000000000000000000000000000..f02871f43be5115c00c8872c07438be573d43c97 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/loadbalancer/rule/RoundRobinRule.java @@ -0,0 +1,32 @@ +package com.github.novicezk.midjourney.loadbalancer.rule; + +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * 轮询. + */ +public class RoundRobinRule implements IRule { + private final AtomicInteger position = new AtomicInteger(0); + + @Override + public DiscordInstance choose(List instances) { + if (instances.isEmpty()) { + return null; + } + int pos = incrementAndGet(); + return instances.get(pos % instances.size()); + } + + private int incrementAndGet() { + int current; + int next; + do { + current = this.position.get(); + next = current == Integer.MAX_VALUE ? 0 : current + 1; + } while (!this.position.compareAndSet(current, next)); + return next; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/result/Message.java b/src/main/java/com/github/novicezk/midjourney/result/Message.java new file mode 100644 index 0000000000000000000000000000000000000000..50868aa882b69eb216e249fc91d7407ee8e62174 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/result/Message.java @@ -0,0 +1,57 @@ +package com.github.novicezk.midjourney.result; + +import com.github.novicezk.midjourney.ReturnCode; +import lombok.Getter; + +@Getter +public class Message { + private final int code; + private final String description; + private final T result; + + public static Message success() { + return new Message<>(ReturnCode.SUCCESS, "成功"); + } + + public static Message success(T result) { + return new Message<>(ReturnCode.SUCCESS, "成功", result); + } + + public static Message success(int code, String description, T result) { + return new Message<>(code, description, result); + } + + public static Message notFound() { + return new Message<>(ReturnCode.NOT_FOUND, "数据未找到"); + } + + public static Message validationError() { + return new Message<>(ReturnCode.VALIDATION_ERROR, "校验错误"); + } + + public static Message failure() { + return new Message<>(ReturnCode.FAILURE, "系统异常"); + } + + public static Message failure(String description) { + return new Message<>(ReturnCode.FAILURE, description); + } + + public static Message of(int code, String description) { + return new Message<>(code, description); + } + + public static Message of(int code, String description, T result) { + return new Message<>(code, description, result); + } + + private Message(int code, String description) { + this(code, description, null); + } + + private Message(int code, String description, T result) { + this.code = code; + this.description = description; + this.result = result; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/result/SubmitResultVO.java b/src/main/java/com/github/novicezk/midjourney/result/SubmitResultVO.java new file mode 100644 index 0000000000000000000000000000000000000000..d6d1ffb43c391ade678a0d297d5808d0dbb2f742 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/result/SubmitResultVO.java @@ -0,0 +1,62 @@ +package com.github.novicezk.midjourney.result; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.HashMap; +import java.util.Map; + +@Data +@ApiModel("提交结果") +public class SubmitResultVO { + + @ApiModelProperty(value = "状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)", required = true, example = "1") + private int code; + + @ApiModelProperty(value = "描述", required = true, example = "提交成功") + private String description; + + @ApiModelProperty(value = "任务ID", example = "1320098173412546") + private String result; + + @ApiModelProperty(value = "扩展字段") + private Map properties = new HashMap<>(); + + public SubmitResultVO setProperty(String name, Object value) { + this.properties.put(name, value); + return this; + } + + public SubmitResultVO removeProperty(String name) { + this.properties.remove(name); + return this; + } + + public Object getProperty(String name) { + return this.properties.get(name); + } + + @SuppressWarnings("unchecked") + public T getPropertyGeneric(String name) { + return (T) getProperty(name); + } + + public T getProperty(String name, Class clz) { + return clz.cast(getProperty(name)); + } + + public static SubmitResultVO of(int code, String description, String result) { + return new SubmitResultVO(code, description, result); + } + + public static SubmitResultVO fail(int code, String description) { + return new SubmitResultVO(code, description, null); + } + + private SubmitResultVO(int code, String description, String result) { + this.code = code; + this.description = description; + this.result = result; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/DiscordService.java b/src/main/java/com/github/novicezk/midjourney/service/DiscordService.java new file mode 100644 index 0000000000000000000000000000000000000000..9ed0059b88624974d7c94a8b9b54121a7fe290bb --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/DiscordService.java @@ -0,0 +1,28 @@ +package com.github.novicezk.midjourney.service; + + +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.Message; +import eu.maxschuster.dataurl.DataUrl; + +import java.util.List; + +public interface DiscordService { + + Message imagine(String prompt, String nonce); + + Message upscale(String messageId, int index, String messageHash, int messageFlags, String nonce); + + Message variation(String messageId, int index, String messageHash, int messageFlags, String nonce); + + Message reroll(String messageId, String messageHash, int messageFlags, String nonce); + + Message describe(String finalFileName, String nonce); + + Message blend(List finalFileNames, BlendDimensions dimensions, String nonce); + + Message upload(String fileName, DataUrl dataUrl); + + Message sendImageMessage(String content, String finalFileName); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/DiscordServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/DiscordServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..79e2079c9a40f94fe4b1d108dcda6d8e1ce45230 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/DiscordServiceImpl.java @@ -0,0 +1,213 @@ +package com.github.novicezk.midjourney.service; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.Message; +import eu.maxschuster.dataurl.DataUrl; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONArray; +import org.json.JSONObject; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.HttpStatusCodeException; +import org.springframework.web.client.RestTemplate; + +import java.util.List; +import java.util.Map; + +@Slf4j +public class DiscordServiceImpl implements DiscordService { + private static final String DEFAULT_SESSION_ID = "f1a313a09ce079ce252459dc70231f30"; + + private final DiscordAccount account; + private final Map paramsMap; + private final RestTemplate restTemplate; + + private final String discordInteractionUrl; + private final String discordAttachmentUrl; + private final String discordMessageUrl; + + public DiscordServiceImpl(DiscordAccount account, RestTemplate restTemplate, String discordServer, Map paramsMap) { + this.account = account; + this.restTemplate = restTemplate; + this.paramsMap = paramsMap; + this.discordInteractionUrl = discordServer + "/api/v9/interactions"; + this.discordAttachmentUrl = discordServer + "/api/v9/channels/" + account.getChannelId() + "/attachments"; + this.discordMessageUrl = discordServer + "/api/v9/channels/" + account.getChannelId() + "/messages"; + } + + @Override + public Message imagine(String prompt, String nonce) { + String paramsStr = replaceInteractionParams(this.paramsMap.get("imagine"), nonce); + JSONObject params = new JSONObject(paramsStr); + params.getJSONObject("data").getJSONArray("options").getJSONObject(0) + .put("value", prompt); + return postJsonAndCheckStatus(params.toString()); + } + + @Override + public Message upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) { + String paramsStr = replaceInteractionParams(this.paramsMap.get("upscale"), nonce) + .replace("$message_id", messageId) + .replace("$index", String.valueOf(index)) + .replace("$message_hash", messageHash); + paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString(); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message variation(String messageId, int index, String messageHash, int messageFlags, String nonce) { + String paramsStr = replaceInteractionParams(this.paramsMap.get("variation"), nonce) + .replace("$message_id", messageId) + .replace("$index", String.valueOf(index)) + .replace("$message_hash", messageHash); + paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString(); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message reroll(String messageId, String messageHash, int messageFlags, String nonce) { + String paramsStr = replaceInteractionParams(this.paramsMap.get("reroll"), nonce) + .replace("$message_id", messageId) + .replace("$message_hash", messageHash); + paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString(); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message describe(String finalFileName, String nonce) { + String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true); + String paramsStr = replaceInteractionParams(this.paramsMap.get("describe"), nonce) + .replace("$file_name", fileName) + .replace("$final_file_name", finalFileName); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message blend(List finalFileNames, BlendDimensions dimensions, String nonce) { + String paramsStr = replaceInteractionParams(this.paramsMap.get("blend"), nonce); + JSONObject params = new JSONObject(paramsStr); + JSONArray options = params.getJSONObject("data").getJSONArray("options"); + JSONArray attachments = params.getJSONObject("data").getJSONArray("attachments"); + for (int i = 0; i < finalFileNames.size(); i++) { + String finalFileName = finalFileNames.get(i); + String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true); + JSONObject attachment = new JSONObject().put("id", String.valueOf(i)) + .put("filename", fileName) + .put("uploaded_filename", finalFileName); + attachments.put(attachment); + JSONObject option = new JSONObject().put("type", 11) + .put("name", "image" + (i + 1)) + .put("value", i); + options.put(option); + } + options.put(new JSONObject().put("type", 3) + .put("name", "dimensions") + .put("value", "--ar " + dimensions.getValue())); + return postJsonAndCheckStatus(params.toString()); + } + + private String replaceInteractionParams(String paramsStr, String nonce) { + return paramsStr.replace("$guild_id", this.account.getGuildId()) + .replace("$channel_id", this.account.getChannelId()) + .replace("$session_id", DEFAULT_SESSION_ID) + .replace("$nonce", nonce); + } + + @Override + public Message upload(String fileName, DataUrl dataUrl) { + try { + JSONObject fileObj = new JSONObject(); + fileObj.put("filename", fileName); + fileObj.put("file_size", dataUrl.getData().length); + fileObj.put("id", "0"); + JSONObject params = new JSONObject() + .put("files", new JSONArray().put(fileObj)); + ResponseEntity responseEntity = postJson(this.discordAttachmentUrl, params.toString()); + if (responseEntity.getStatusCode() != HttpStatus.OK) { + log.error("上传图片到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody()); + return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败"); + } + JSONArray array = new JSONObject(responseEntity.getBody()).getJSONArray("attachments"); + if (array.length() == 0) { + return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败"); + } + String uploadUrl = array.getJSONObject(0).getString("upload_url"); + String uploadFilename = array.getJSONObject(0).getString("upload_filename"); + putFile(uploadUrl, dataUrl); + return Message.success(uploadFilename); + } catch (Exception e) { + log.error("上传图片到discord失败", e); + return Message.of(ReturnCode.FAILURE, "上传图片到discord失败"); + } + } + + @Override + public Message sendImageMessage(String content, String finalFileName) { + String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true); + String paramsStr = this.paramsMap.get("message").replace("$content", content) + .replace("$channel_id", this.account.getChannelId()) + .replace("$file_name", fileName) + .replace("$final_file_name", finalFileName); + ResponseEntity responseEntity = postJson(this.discordMessageUrl, paramsStr); + if (responseEntity.getStatusCode() != HttpStatus.OK) { + log.error("发送图片消息到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody()); + return Message.of(ReturnCode.VALIDATION_ERROR, "发送图片消息到discord失败"); + } + JSONObject result = new JSONObject(responseEntity.getBody()); + JSONArray attachments = result.optJSONArray("attachments"); + if (!attachments.isEmpty()) { + return Message.success(attachments.getJSONObject(0).optString("url")); + } + return Message.failure("发送图片消息到discord失败: 图片不存在"); + } + + private void putFile(String uploadUrl, DataUrl dataUrl) { + HttpHeaders headers = new HttpHeaders(); + headers.add("User-Agent", this.account.getUserAgent()); + headers.setContentType(MediaType.valueOf(dataUrl.getMimeType())); + headers.setContentLength(dataUrl.getData().length); + HttpEntity requestEntity = new HttpEntity<>(dataUrl.getData(), headers); + this.restTemplate.put(uploadUrl, requestEntity); + } + + private ResponseEntity postJson(String paramsStr) { + return postJson(this.discordInteractionUrl, paramsStr); + } + + private ResponseEntity postJson(String url, String paramsStr) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("Authorization", this.account.getUserToken()); + headers.set("User-Agent", this.account.getUserAgent()); + HttpEntity httpEntity = new HttpEntity<>(paramsStr, headers); + return this.restTemplate.postForEntity(url, httpEntity, String.class); + } + + private Message postJsonAndCheckStatus(String paramsStr) { + try { + ResponseEntity responseEntity = postJson(paramsStr); + if (responseEntity.getStatusCode() == HttpStatus.NO_CONTENT) { + return Message.success(); + } + return Message.of(responseEntity.getStatusCodeValue(), CharSequenceUtil.sub(responseEntity.getBody(), 0, 100)); + } catch (HttpStatusCodeException e) { + return convertHttpStatusCodeException(e); + } + } + + private Message convertHttpStatusCodeException(HttpStatusCodeException e) { + try { + JSONObject error = new JSONObject(e.getResponseBodyAsString()); + return Message.of(error.optInt("code", e.getRawStatusCode()), error.optString("message")); + } catch (Exception je) { + return Message.of(e.getRawStatusCode(), CharSequenceUtil.sub(e.getMessage(), 0, 100)); + } + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/NotifyService.java b/src/main/java/com/github/novicezk/midjourney/service/NotifyService.java new file mode 100644 index 0000000000000000000000000000000000000000..7f18454b8721b4eaea02b3dc949afebda54b8a1e --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/NotifyService.java @@ -0,0 +1,10 @@ +package com.github.novicezk.midjourney.service; + + +import com.github.novicezk.midjourney.support.Task; + +public interface NotifyService { + + void notifyTaskChange(Task task); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/NotifyServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/NotifyServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..d850c13694ea52ed2cded2201b0c1861677080d1 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/NotifyServiceImpl.java @@ -0,0 +1,76 @@ +package com.github.novicezk.midjourney.service; + +import cn.hutool.cache.CacheUtil; +import cn.hutool.cache.impl.TimedCache; +import cn.hutool.core.exceptions.CheckedUtil; +import cn.hutool.core.text.CharSequenceUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.support.Task; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +import java.time.Duration; + +@Slf4j +@Service +public class NotifyServiceImpl implements NotifyService { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private final ThreadPoolTaskExecutor executor; + private final TimedCache taskLocks = CacheUtil.newTimedCache(Duration.ofHours(1).toMillis()); + + public NotifyServiceImpl(ProxyProperties properties) { + this.executor = new ThreadPoolTaskExecutor(); + this.executor.setCorePoolSize(properties.getNotifyPoolSize()); + this.executor.setThreadNamePrefix("TaskNotify-"); + this.executor.initialize(); + } + + @Override + public void notifyTaskChange(Task task) { + String notifyHook = task.getPropertyGeneric(Constants.TASK_PROPERTY_NOTIFY_HOOK); + if (CharSequenceUtil.isBlank(notifyHook)) { + return; + } + String taskId = task.getId(); + TaskStatus taskStatus = task.getStatus(); + Object taskLock = this.taskLocks.get(taskId, (CheckedUtil.Func0Rt) Object::new); + try { + String paramsStr = OBJECT_MAPPER.writeValueAsString(task); + this.executor.execute(() -> { + synchronized (taskLock) { + try { + ResponseEntity responseEntity = postJson(notifyHook, paramsStr); + if (responseEntity.getStatusCode() == HttpStatus.OK) { + log.debug("推送任务变更成功, 任务ID: {}, status: {}, notifyHook: {}", taskId, taskStatus, notifyHook); + } else { + log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, code: {}, msg: {}", taskId, notifyHook, responseEntity.getStatusCodeValue(), responseEntity.getBody()); + } + } catch (Exception e) { + log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage()); + } + } + }); + } catch (JsonProcessingException e) { + log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage()); + } + } + + private ResponseEntity postJson(String notifyHook, String paramsJson) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity httpEntity = new HttpEntity<>(paramsJson, headers); + return new RestTemplate().postForEntity(notifyHook, httpEntity, String.class); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/TaskService.java b/src/main/java/com/github/novicezk/midjourney/service/TaskService.java new file mode 100644 index 0000000000000000000000000000000000000000..3477d7e0e32baa24f1da507d111af3eda28d9a1b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TaskService.java @@ -0,0 +1,23 @@ +package com.github.novicezk.midjourney.service; + +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.support.Task; +import eu.maxschuster.dataurl.DataUrl; + +import java.util.List; + +public interface TaskService { + + SubmitResultVO submitImagine(Task task, List dataUrls); + + SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags); + + SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags); + + SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags); + + SubmitResultVO submitDescribe(Task task, DataUrl dataUrl); + + SubmitResultVO submitBlend(Task task, List dataUrls, BlendDimensions dimensions); +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/service/TaskServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/TaskServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..c98c7c4d3fb395e64e61ffeb04b2afe1cd39a5a5 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TaskServiceImpl.java @@ -0,0 +1,128 @@ +package com.github.novicezk.midjourney.service; + +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; +import com.github.novicezk.midjourney.loadbalancer.DiscordLoadBalancer; +import com.github.novicezk.midjourney.result.Message; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.util.MimeTypeUtils; +import eu.maxschuster.dataurl.DataUrl; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; + +@Slf4j +@Service +@RequiredArgsConstructor +public class TaskServiceImpl implements TaskService { + private final TaskStoreService taskStoreService; + private final DiscordLoadBalancer discordLoadBalancer; + + @Override + public SubmitResultVO submitImagine(Task task, List dataUrls) { + DiscordInstance instance = this.discordLoadBalancer.chooseInstance(); + if (instance == null) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例"); + } + task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, instance.getInstanceId()); + return instance.submitTask(task, () -> { + List imageUrls = new ArrayList<>(); + for (DataUrl dataUrl : dataUrls) { + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + Message uploadResult = instance.upload(taskFileName, dataUrl); + if (uploadResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(uploadResult.getCode(), uploadResult.getDescription()); + } + String finalFileName = uploadResult.getResult(); + Message sendImageResult = instance.sendImageMessage("upload image: " + finalFileName, finalFileName); + if (sendImageResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(sendImageResult.getCode(), sendImageResult.getDescription()); + } + imageUrls.add(sendImageResult.getResult()); + } + if (!imageUrls.isEmpty()) { + task.setPrompt(String.join(" ", imageUrls) + " " + task.getPrompt()); + task.setPromptEn(String.join(" ", imageUrls) + " " + task.getPromptEn()); + task.setDescription("/imagine " + task.getPrompt()); + this.taskStoreService.save(task); + } + return instance.imagine(task.getPromptEn(), task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)); + }); + } + + @Override + public SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) { + String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID); + DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId); + if (discordInstance == null || !discordInstance.isAlive()) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId); + } + return discordInstance.submitTask(task, () -> discordInstance.upscale(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE))); + } + + @Override + public SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) { + String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID); + DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId); + if (discordInstance == null || !discordInstance.isAlive()) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId); + } + return discordInstance.submitTask(task, () -> discordInstance.variation(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE))); + } + + @Override + public SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags) { + String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID); + DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId); + if (discordInstance == null || !discordInstance.isAlive()) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId); + } + return discordInstance.submitTask(task, () -> discordInstance.reroll(targetMessageId, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE))); + } + + @Override + public SubmitResultVO submitDescribe(Task task, DataUrl dataUrl) { + DiscordInstance discordInstance = this.discordLoadBalancer.chooseInstance(); + if (discordInstance == null) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例"); + } + task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.getInstanceId()); + return discordInstance.submitTask(task, () -> { + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + Message uploadResult = discordInstance.upload(taskFileName, dataUrl); + if (uploadResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(uploadResult.getCode(), uploadResult.getDescription()); + } + String finalFileName = uploadResult.getResult(); + return discordInstance.describe(finalFileName, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)); + }); + } + + @Override + public SubmitResultVO submitBlend(Task task, List dataUrls, BlendDimensions dimensions) { + DiscordInstance discordInstance = this.discordLoadBalancer.chooseInstance(); + if (discordInstance == null) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例"); + } + task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.getInstanceId()); + return discordInstance.submitTask(task, () -> { + List finalFileNames = new ArrayList<>(); + for (DataUrl dataUrl : dataUrls) { + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + Message uploadResult = discordInstance.upload(taskFileName, dataUrl); + if (uploadResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(uploadResult.getCode(), uploadResult.getDescription()); + } + finalFileNames.add(uploadResult.getResult()); + } + return discordInstance.blend(finalFileNames, dimensions, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)); + }); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/TaskStoreService.java b/src/main/java/com/github/novicezk/midjourney/service/TaskStoreService.java new file mode 100644 index 0000000000000000000000000000000000000000..cec8e46adce16e46e16a05fd4162bf83e395afe8 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TaskStoreService.java @@ -0,0 +1,23 @@ +package com.github.novicezk.midjourney.service; + + +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; + +import java.util.List; + +public interface TaskStoreService { + + void save(Task task); + + void delete(String id); + + Task get(String id); + + List list(); + + List list(TaskCondition condition); + + Task findOne(TaskCondition condition); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/TranslateService.java b/src/main/java/com/github/novicezk/midjourney/service/TranslateService.java new file mode 100644 index 0000000000000000000000000000000000000000..a413d483567727b5452f65d2496880a99d6a6021 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TranslateService.java @@ -0,0 +1,13 @@ +package com.github.novicezk.midjourney.service; + +import java.util.regex.Pattern; + +public interface TranslateService { + + String translateToEnglish(String prompt); + + default boolean containsChinese(String prompt) { + return Pattern.compile("[\u4e00-\u9fa5]").matcher(prompt).find(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/store/InMemoryTaskStoreServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/store/InMemoryTaskStoreServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..4e25a5e279e2be7bea1224b2ca87fb320c7a61e6 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/store/InMemoryTaskStoreServiceImpl.java @@ -0,0 +1,52 @@ +package com.github.novicezk.midjourney.service.store; + +import cn.hutool.cache.CacheUtil; +import cn.hutool.cache.impl.TimedCache; +import cn.hutool.core.collection.ListUtil; +import cn.hutool.core.stream.StreamUtil; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; + +import java.time.Duration; +import java.util.List; + + +public class InMemoryTaskStoreServiceImpl implements TaskStoreService { + private final TimedCache taskMap; + + public InMemoryTaskStoreServiceImpl(Duration timeout) { + this.taskMap = CacheUtil.newTimedCache(timeout.toMillis()); + } + + @Override + public void save(Task task) { + this.taskMap.put(task.getId(), task); + } + + @Override + public void delete(String key) { + this.taskMap.remove(key); + } + + @Override + public Task get(String key) { + return this.taskMap.get(key); + } + + @Override + public List list() { + return ListUtil.toList(this.taskMap.iterator()); + } + + @Override + public List list(TaskCondition condition) { + return StreamUtil.of(this.taskMap.iterator()).filter(condition).toList(); + } + + @Override + public Task findOne(TaskCondition condition) { + return StreamUtil.of(this.taskMap.iterator()).filter(condition).findFirst().orElse(null); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/store/RedisTaskStoreServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/store/RedisTaskStoreServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..287c7bac8e33716f581504ad470003ff4e18ac1c --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/store/RedisTaskStoreServiceImpl.java @@ -0,0 +1,74 @@ +package com.github.novicezk.midjourney.service.store; + +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import org.springframework.data.redis.core.Cursor; +import org.springframework.data.redis.core.RedisCallback; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.ScanOptions; +import org.springframework.data.redis.core.ValueOperations; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +public class RedisTaskStoreServiceImpl implements TaskStoreService { + private static final String KEY_PREFIX = "mj-task-store::"; + + private final Duration timeout; + private final RedisTemplate redisTemplate; + + public RedisTaskStoreServiceImpl(Duration timeout, RedisTemplate redisTemplate) { + this.timeout = timeout; + this.redisTemplate = redisTemplate; + } + + @Override + public void save(Task task) { + this.redisTemplate.opsForValue().set(getRedisKey(task.getId()), task, this.timeout); + } + + @Override + public void delete(String id) { + this.redisTemplate.delete(getRedisKey(id)); + } + + @Override + public Task get(String id) { + return this.redisTemplate.opsForValue().get(getRedisKey(id)); + } + + @Override + public List list() { + Set keys = this.redisTemplate.execute((RedisCallback>) connection -> { + Cursor cursor = connection.scan(ScanOptions.scanOptions().match(KEY_PREFIX + "*").count(1000).build()); + return cursor.stream().map(String::new).collect(Collectors.toSet()); + }); + if (keys == null || keys.isEmpty()) { + return Collections.emptyList(); + } + ValueOperations operations = this.redisTemplate.opsForValue(); + return keys.stream().map(operations::get) + .filter(Objects::nonNull) + .toList(); + } + + @Override + public List list(TaskCondition condition) { + return list().stream().filter(condition).toList(); + } + + @Override + public Task findOne(TaskCondition condition) { + return list().stream().filter(condition).findFirst().orElse(null); + } + + private String getRedisKey(String id) { + return KEY_PREFIX + id; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/translate/BaiduTranslateServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/translate/BaiduTranslateServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..e672862c3e5bb697b77aa24e90be406963b68560 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/translate/BaiduTranslateServiceImpl.java @@ -0,0 +1,56 @@ +package com.github.novicezk.midjourney.service.translate; + + +import cn.hutool.core.exceptions.ValidateException; +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.RandomUtil; +import cn.hutool.crypto.digest.MD5; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.service.TranslateService; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONObject; +import org.springframework.beans.factory.support.BeanDefinitionValidationException; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +@Slf4j +public class BaiduTranslateServiceImpl implements TranslateService { + private static final String TRANSLATE_API = "https://fanyi-api.baidu.com/api/trans/vip/translate"; + + private final String appid; + private final String appSecret; + + public BaiduTranslateServiceImpl(ProxyProperties.BaiduTranslateConfig translateConfig) { + this.appid = translateConfig.getAppid(); + this.appSecret = translateConfig.getAppSecret(); + if (!CharSequenceUtil.isAllNotBlank(this.appid, this.appSecret)) { + throw new BeanDefinitionValidationException("mj-proxy.baidu-translate.appid或mj-proxy.baidu-translate.app-secret未配置"); + } + } + + @Override + public String translateToEnglish(String prompt) { + if (!containsChinese(prompt)) { + return prompt; + } + String salt = RandomUtil.randomNumbers(5); + String sign = MD5.create().digestHex(this.appid + prompt + salt + this.appSecret); + String url = TRANSLATE_API + "?from=zh&to=en&appid=" + this.appid + "&salt=" + salt + "&q=" + prompt + "&sign=" + sign; + try { + ResponseEntity responseEntity = new RestTemplate().getForEntity(url, String.class); + if (responseEntity.getStatusCode() != HttpStatus.OK || CharSequenceUtil.isBlank(responseEntity.getBody())) { + throw new ValidateException(responseEntity.getStatusCodeValue() + " - " + responseEntity.getBody()); + } + JSONObject result = new JSONObject(responseEntity.getBody()); + if (result.has("error_code")) { + throw new ValidateException(result.getString("error_code") + " - " + result.getString("error_msg")); + } + return result.getJSONArray("trans_result").getJSONObject(0).getString("dst"); + } catch (Exception e) { + log.warn("调用百度翻译失败: {}", e.getMessage()); + } + return prompt; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/translate/GPTTranslateServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/translate/GPTTranslateServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..f391f1987ecfcfe1d6d30002c6b09657d6d491ab --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/translate/GPTTranslateServiceImpl.java @@ -0,0 +1,83 @@ +package com.github.novicezk.midjourney.service.translate; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.service.TranslateService; +import com.unfbx.chatgpt.OpenAiClient; +import com.unfbx.chatgpt.entity.chat.ChatChoice; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; +import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.function.KeyRandomStrategy; +import com.unfbx.chatgpt.interceptor.OpenAILogger; +import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.logging.HttpLoggingInterceptor; +import org.springframework.beans.factory.support.BeanDefinitionValidationException; + +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; + +@Slf4j +public class GPTTranslateServiceImpl implements TranslateService { + private final OpenAiClient openAiClient; + private final ProxyProperties.OpenaiConfig openaiConfig; + + public GPTTranslateServiceImpl(ProxyProperties properties) { + this.openaiConfig = properties.getOpenai(); + if (CharSequenceUtil.isBlank(this.openaiConfig.getGptApiKey())) { + throw new BeanDefinitionValidationException("mj-proxy.openai.gpt-api-key未配置"); + } + HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); + httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); + OkHttpClient.Builder okHttpBuilder = new OkHttpClient.Builder() + .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new OpenAiResponseInterceptor()) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .readTimeout(30, TimeUnit.SECONDS); + if (CharSequenceUtil.isNotBlank(properties.getProxy().getHost())) { + Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(properties.getProxy().getHost(), properties.getProxy().getPort())); + okHttpBuilder.proxy(proxy); + } + OpenAiClient.Builder apiBuilder = OpenAiClient.builder() + .apiKey(Collections.singletonList(this.openaiConfig.getGptApiKey())) + .keyStrategy(new KeyRandomStrategy()) + .okHttpClient(okHttpBuilder.build()); + if (CharSequenceUtil.isNotBlank(this.openaiConfig.getGptApiUrl())) { + apiBuilder.apiHost(this.openaiConfig.getGptApiUrl()); + } + this.openAiClient = apiBuilder.build(); + } + + @Override + public String translateToEnglish(String prompt) { + if (!containsChinese(prompt)) { + return prompt; + } + Message m1 = Message.builder().role(Message.Role.SYSTEM).content("把中文翻译成英文").build(); + Message m2 = Message.builder().role(Message.Role.USER).content(prompt).build(); + ChatCompletion chatCompletion = ChatCompletion.builder() + .messages(Arrays.asList(m1, m2)) + .model(this.openaiConfig.getModel()) + .temperature(this.openaiConfig.getTemperature()) + .maxTokens(this.openaiConfig.getMaxTokens()) + .build(); + ChatCompletionResponse chatCompletionResponse = this.openAiClient.chatCompletion(chatCompletion); + try { + List choices = chatCompletionResponse.getChoices(); + if (!choices.isEmpty()) { + return choices.get(0).getMessage().getContent(); + } + } catch (Exception e) { + log.warn("调用chat-gpt接口翻译中文失败: {}", e.getMessage()); + } + return prompt; + } +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/service/translate/NoTranslateServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/translate/NoTranslateServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..3db1e688053e83cdae8be21fecfd8522a6c6c40d --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/translate/NoTranslateServiceImpl.java @@ -0,0 +1,14 @@ +package com.github.novicezk.midjourney.service.translate; + + +import com.github.novicezk.midjourney.service.TranslateService; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class NoTranslateServiceImpl implements TranslateService { + + @Override + public String translateToEnglish(String prompt) { + return prompt; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/ApiAuthorizeInterceptor.java b/src/main/java/com/github/novicezk/midjourney/support/ApiAuthorizeInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..cf5eef87ec6219cb8df198e9e3136ad97304df03 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/ApiAuthorizeInterceptor.java @@ -0,0 +1,32 @@ +package com.github.novicezk.midjourney.support; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; +import org.springframework.web.servlet.HandlerInterceptor; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +@Component +@RequiredArgsConstructor +public class ApiAuthorizeInterceptor implements HandlerInterceptor { + private final ProxyProperties properties; + + @Override + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { + if (CharSequenceUtil.isBlank(this.properties.getApiSecret())) { + return true; + } + String apiSecret = request.getHeader(Constants.API_SECRET_HEADER_NAME); + boolean authorized = CharSequenceUtil.equals(apiSecret, this.properties.getApiSecret()); + if (!authorized) { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + return authorized; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/DiscordAccountHelper.java b/src/main/java/com/github/novicezk/midjourney/support/DiscordAccountHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..4ce5875f477dbd79a58398f12de7db0c04f7cbc2 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/DiscordAccountHelper.java @@ -0,0 +1,44 @@ +package com.github.novicezk.midjourney.support; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; +import com.github.novicezk.midjourney.loadbalancer.DiscordInstanceImpl; +import com.github.novicezk.midjourney.service.NotifyService; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.wss.handle.MessageHandler; +import com.github.novicezk.midjourney.wss.user.UserMessageListener; +import com.github.novicezk.midjourney.wss.user.UserWebSocketStarter; +import lombok.RequiredArgsConstructor; +import org.springframework.web.client.RestTemplate; + +import java.util.List; +import java.util.Map; + +@RequiredArgsConstructor +public class DiscordAccountHelper { + private final DiscordHelper discordHelper; + private final ProxyProperties properties; + private final RestTemplate restTemplate; + private final TaskStoreService taskStoreService; + private final NotifyService notifyService; + private final List messageHandlers; + private final Map paramsMap; + + public DiscordInstance createDiscordInstance(DiscordAccount account) { + if (!CharSequenceUtil.isAllNotBlank(account.getGuildId(), account.getChannelId(), account.getUserToken())) { + throw new IllegalArgumentException("guildId, channelId, userToken must not be blank"); + } + if (CharSequenceUtil.isBlank(account.getUserAgent())) { + account.setUserAgent(Constants.DEFAULT_DISCORD_USER_AGENT); + } + var messageListener = new UserMessageListener(account, this.messageHandlers); + var webSocketStarter = new UserWebSocketStarter(this.discordHelper.getWss(), account, messageListener, this.properties.getProxy()); + return new DiscordInstanceImpl(account, webSocketStarter, this.restTemplate, + this.taskStoreService, this.notifyService, + this.discordHelper.getServer(), this.paramsMap); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/DiscordAccountInitializer.java b/src/main/java/com/github/novicezk/midjourney/support/DiscordAccountInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..1bba6bea75d7f5b4e62340e3b416f7342e99e510 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/DiscordAccountInitializer.java @@ -0,0 +1,72 @@ +package com.github.novicezk.midjourney.support; + + +import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.exceptions.ValidateException; +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.loadbalancer.DiscordInstance; +import com.github.novicezk.midjourney.loadbalancer.DiscordLoadBalancer; +import com.github.novicezk.midjourney.util.AsyncLockUtils; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.logging.log4j.util.Strings; +import org.springframework.boot.ApplicationArguments; +import org.springframework.boot.ApplicationRunner; +import org.springframework.stereotype.Component; + +import java.time.Duration; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +@Slf4j +@Component +@RequiredArgsConstructor +public class DiscordAccountInitializer implements ApplicationRunner { + private final DiscordLoadBalancer discordLoadBalancer; + private final DiscordAccountHelper discordAccountHelper; + private final ProxyProperties properties; + + @Override + public void run(ApplicationArguments args) throws Exception { + ProxyProperties.ProxyConfig proxy = this.properties.getProxy(); + if (Strings.isNotBlank(proxy.getHost())) { + System.setProperty("http.proxyHost", proxy.getHost()); + System.setProperty("http.proxyPort", String.valueOf(proxy.getPort())); + System.setProperty("https.proxyHost", proxy.getHost()); + System.setProperty("https.proxyPort", String.valueOf(proxy.getPort())); + } + + List configAccounts = this.properties.getAccounts(); + if (CharSequenceUtil.isNotBlank(this.properties.getDiscord().getChannelId())) { + configAccounts.add(this.properties.getDiscord()); + } + List instances = this.discordLoadBalancer.getAllInstances(); + for (ProxyProperties.DiscordAccountConfig configAccount : configAccounts) { + DiscordAccount account = new DiscordAccount(); + BeanUtil.copyProperties(configAccount, account); + account.setId(configAccount.getChannelId()); + try { + DiscordInstance instance = this.discordAccountHelper.createDiscordInstance(account); + if (!account.isEnable()) { + continue; + } + instance.startWss(); + AsyncLockUtils.LockObject lock = AsyncLockUtils.waitForLock("wss:" + account.getChannelId(), Duration.ofSeconds(10)); + if (ReturnCode.SUCCESS != lock.getProperty("code", Integer.class, 0)) { + throw new ValidateException(lock.getProperty("description", String.class)); + } + instances.add(instance); + } catch (Exception e) { + log.error("Account({}) init fail, disabled: {}", account.getDisplay(), e.getMessage()); + account.setEnable(false); + } + } + Set enableInstanceIds = instances.stream().filter(DiscordInstance::isAlive).map(DiscordInstance::getInstanceId).collect(Collectors.toSet()); + log.info("当前可用账号数 [{}] - {}", enableInstanceIds.size(), String.join(", ", enableInstanceIds)); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/DiscordHelper.java b/src/main/java/com/github/novicezk/midjourney/support/DiscordHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..0107c1122e0da21acc901dce507c55fa120738a1 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/DiscordHelper.java @@ -0,0 +1,88 @@ +package com.github.novicezk.midjourney.support; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; + +@Component +@RequiredArgsConstructor +public class DiscordHelper { + private final ProxyProperties properties; + /** + * DISCORD_SERVER_URL. + */ + public static final String DISCORD_SERVER_URL = "https://discord.com"; + /** + * DISCORD_CDN_URL. + */ + public static final String DISCORD_CDN_URL = "https://cdn.discordapp.com"; + /** + * DISCORD_WSS_URL. + */ + public static final String DISCORD_WSS_URL = "wss://gateway.discord.gg"; + + public String getServer() { + if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getServer())) { + return DISCORD_SERVER_URL; + } + String serverUrl = this.properties.getNgDiscord().getServer(); + if (serverUrl.endsWith("/")) { + serverUrl = serverUrl.substring(0, serverUrl.length() - 1); + } + return serverUrl; + } + + public String getCdn() { + if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getCdn())) { + return DISCORD_CDN_URL; + } + String cdnUrl = this.properties.getNgDiscord().getCdn(); + if (cdnUrl.endsWith("/")) { + cdnUrl = cdnUrl.substring(0, cdnUrl.length() - 1); + } + return cdnUrl; + } + + public String getWss() { + if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getWss())) { + return DISCORD_WSS_URL; + } + String wssUrl = this.properties.getNgDiscord().getWss(); + if (wssUrl.endsWith("/")) { + wssUrl = wssUrl.substring(0, wssUrl.length() - 1); + } + return wssUrl; + } + + public String findTaskIdWithCdnUrl(String url) { + if (!CharSequenceUtil.startWith(url, DISCORD_CDN_URL)) { + return null; + } + int hashStartIndex = url.lastIndexOf("/"); + String taskId = CharSequenceUtil.subBefore(url.substring(hashStartIndex + 1), ".", true); + if (CharSequenceUtil.length(taskId) == 16) { + return taskId; + } + return null; + } + + public String getMessageHash(String imageUrl) { + if (CharSequenceUtil.isBlank(imageUrl)) { + return null; + } + if (CharSequenceUtil.endWith(imageUrl, "_grid_0.webp")) { + int hashStartIndex = imageUrl.lastIndexOf("/"); + if (hashStartIndex < 0) { + return null; + } + return CharSequenceUtil.sub(imageUrl, hashStartIndex + 1, imageUrl.length() - "_grid_0.webp".length()); + } + int hashStartIndex = imageUrl.lastIndexOf("_"); + if (hashStartIndex < 0) { + return null; + } + return CharSequenceUtil.subBefore(imageUrl.substring(hashStartIndex + 1), ".", true); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/Task.java b/src/main/java/com/github/novicezk/midjourney/support/Task.java new file mode 100644 index 0000000000000000000000000000000000000000..3f10fdc9b921c1df17a75e1d9934bc3cd520ed2f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/Task.java @@ -0,0 +1,68 @@ +package com.github.novicezk.midjourney.support; + +import com.github.novicezk.midjourney.domain.DomainObject; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.enums.TaskStatus; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +import java.io.Serial; + +@Data +@EqualsAndHashCode(callSuper = true) +@ApiModel("任务") +public class Task extends DomainObject { + @Serial + private static final long serialVersionUID = -674915748204390789L; + + @ApiModelProperty("任务类型") + private TaskAction action; + @ApiModelProperty("任务状态") + private TaskStatus status = TaskStatus.NOT_START; + + @ApiModelProperty("提示词") + private String prompt; + @ApiModelProperty("提示词-英文") + private String promptEn; + + @ApiModelProperty("任务描述") + private String description; + @ApiModelProperty("自定义参数") + private String state; + + @ApiModelProperty("提交时间") + private Long submitTime; + @ApiModelProperty("开始执行时间") + private Long startTime; + @ApiModelProperty("结束时间") + private Long finishTime; + + @ApiModelProperty("图片url") + private String imageUrl; + + @ApiModelProperty("任务进度") + private String progress; + @ApiModelProperty("失败原因") + private String failReason; + + public void start() { + this.startTime = System.currentTimeMillis(); + this.status = TaskStatus.SUBMITTED; + this.progress = "0%"; + } + + public void success() { + this.finishTime = System.currentTimeMillis(); + this.status = TaskStatus.SUCCESS; + this.progress = "100%"; + } + + public void fail(String reason) { + this.finishTime = System.currentTimeMillis(); + this.status = TaskStatus.FAILURE; + this.failReason = reason; + this.progress = ""; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/TaskCondition.java b/src/main/java/com/github/novicezk/midjourney/support/TaskCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..18977811dbd211a527307af272309841f498d916 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/TaskCondition.java @@ -0,0 +1,74 @@ +package com.github.novicezk.midjourney.support; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.enums.TaskStatus; +import lombok.Data; +import lombok.experimental.Accessors; + +import java.util.Set; +import java.util.function.Predicate; + + +@Data +@Accessors(chain = true) +public class TaskCondition implements Predicate { + private String id; + + private Set statusSet; + private Set actionSet; + + private String prompt; + private String promptEn; + private String description; + + private String finalPromptEn; + private String messageId; + private String messageHash; + private String progressMessageId; + private String nonce; + + @Override + public boolean test(Task task) { + if (task == null) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.id) && !this.id.equals(task.getId())) { + return false; + } + if (this.statusSet != null && !this.statusSet.isEmpty() && !this.statusSet.contains(task.getStatus())) { + return false; + } + if (this.actionSet != null && !this.actionSet.isEmpty() && !this.actionSet.contains(task.getAction())) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.prompt) && !this.prompt.equals(task.getPrompt())) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.promptEn) && !this.promptEn.equals(task.getPromptEn())) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.description) && !CharSequenceUtil.contains(task.getDescription(), this.description)) { + return false; + } + + if (CharSequenceUtil.isNotBlank(this.finalPromptEn) && !this.finalPromptEn.equals(task.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.messageId) && !this.messageId.equals(task.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.messageHash) && !this.messageHash.equals(task.getProperty(Constants.TASK_PROPERTY_MESSAGE_HASH))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.progressMessageId) && !this.progressMessageId.equals(task.getProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.nonce) && !this.nonce.equals(task.getProperty(Constants.TASK_PROPERTY_NONCE))) { + return false; + } + return true; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/TaskTimeoutSchedule.java b/src/main/java/com/github/novicezk/midjourney/support/TaskTimeoutSchedule.java new file mode 100644 index 0000000000000000000000000000000000000000..8350d955b1e223ec5e2baa894adc4950b4c9c005 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/TaskTimeoutSchedule.java @@ -0,0 +1,38 @@ +package com.github.novicezk.midjourney.support; + +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.loadbalancer.DiscordLoadBalancer; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Component +@RequiredArgsConstructor +public class TaskTimeoutSchedule { + private final DiscordLoadBalancer discordLoadBalancer; + + @Scheduled(fixedRate = 30000L) + public void checkTasks() { + this.discordLoadBalancer.getAliveInstances().forEach(instance -> { + long timeout = TimeUnit.MINUTES.toMillis(instance.account().getTimeoutMinutes()); + List tasks = instance.getRunningTasks().stream() + .filter(t -> System.currentTimeMillis() - t.getStartTime() > timeout) + .toList(); + for (Task task : tasks) { + if (Set.of(TaskStatus.FAILURE, TaskStatus.SUCCESS).contains(task.getStatus())) { + log.warn("task status is failure/success but is in the queue, end it. id: {}", task.getId()); + } else { + log.debug("task timeout, id: {}", task.getId()); + task.fail("任务超时"); + } + instance.exitTask(task); + } + }); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/AsyncLockUtils.java b/src/main/java/com/github/novicezk/midjourney/util/AsyncLockUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..645379b8dc9e376855b0c8c2dbdf9c8b06b3e122 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/AsyncLockUtils.java @@ -0,0 +1,59 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.cache.CacheUtil; +import cn.hutool.cache.impl.TimedCache; +import cn.hutool.core.thread.ThreadUtil; +import com.github.novicezk.midjourney.domain.DomainObject; +import lombok.experimental.UtilityClass; + +import java.time.Duration; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +@UtilityClass +public class AsyncLockUtils { + private static final TimedCache LOCK_MAP = CacheUtil.newTimedCache(Duration.ofDays(1).toMillis()); + + public static synchronized LockObject getLock(String key) { + return LOCK_MAP.get(key); + } + + public static LockObject waitForLock(String key, Duration duration) throws TimeoutException { + LockObject lockObject; + synchronized (LOCK_MAP) { + if (!LOCK_MAP.containsKey(key)) { + LOCK_MAP.put(key, new LockObject(key)); + } + lockObject = LOCK_MAP.get(key); + } + Future future = ThreadUtil.execAsync(() -> { + try { + lockObject.sleep(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + try { + future.get(duration.toMillis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + // do nothing + } catch (TimeoutException e) { + future.cancel(true); + throw new TimeoutException("Wait Timeout"); + } finally { + LOCK_MAP.remove(lockObject.getId()); + } + return lockObject; + } + + public static class LockObject extends DomainObject { + + public LockObject(String id) { + this.id = id; + } + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/BannedPromptUtils.java b/src/main/java/com/github/novicezk/midjourney/util/BannedPromptUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..13b155b25f61479c95914010f1f7424639e708bd --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/BannedPromptUtils.java @@ -0,0 +1,43 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.io.FileUtil; +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.exception.BannedPromptException; +import lombok.experimental.UtilityClass; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +@UtilityClass +public class BannedPromptUtils { + private static final String BANNED_WORDS_FILE_PATH = "/home/spring/config/banned-words.txt"; + private final List BANNED_WORDS; + + static { + List lines; + File file = new File(BANNED_WORDS_FILE_PATH); + if (file.exists()) { + lines = FileUtil.readLines(file, StandardCharsets.UTF_8); + } else { + var resource = BannedPromptUtils.class.getResource("/banned-words.txt"); + lines = FileUtil.readLines(resource, StandardCharsets.UTF_8); + } + BANNED_WORDS = lines.stream().filter(CharSequenceUtil::isNotBlank).toList(); + } + + public static void checkBanned(String promptEn) throws BannedPromptException { + String finalPromptEn = promptEn.toLowerCase(Locale.ENGLISH); + for (String word : BANNED_WORDS) { + Matcher matcher = Pattern.compile("\\b" + word + "\\b").matcher(finalPromptEn); + if (matcher.find()) { + int index = CharSequenceUtil.indexOfIgnoreCase(promptEn, word); + throw new BannedPromptException(promptEn.substring(index, index + word.length())); + } + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/ContentParseData.java b/src/main/java/com/github/novicezk/midjourney/util/ContentParseData.java new file mode 100644 index 0000000000000000000000000000000000000000..4ff1af679ff969b33dca434a462889738cdd509b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/ContentParseData.java @@ -0,0 +1,9 @@ +package com.github.novicezk.midjourney.util; + +import lombok.Data; + +@Data +public class ContentParseData { + protected String prompt; + protected String status; +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/ConvertUtils.java b/src/main/java/com/github/novicezk/midjourney/util/ConvertUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..e56dc8b7df2522eb45ef6afa96182b13f3d5b5f7 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/ConvertUtils.java @@ -0,0 +1,85 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.enums.TaskAction; +import eu.maxschuster.dataurl.DataUrl; +import eu.maxschuster.dataurl.DataUrlSerializer; +import eu.maxschuster.dataurl.IDataUrlSerializer; +import lombok.experimental.UtilityClass; + +import java.net.MalformedURLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +@UtilityClass +public class ConvertUtils { + /** + * content正则匹配prompt和进度. + */ + public static final String CONTENT_REGEX = ".*?\\*\\*(.*?)\\*\\*.+<@\\d+> \\((.*?)\\)"; + + public static ContentParseData parseContent(String content) { + return parseContent(content, CONTENT_REGEX); + } + + public static ContentParseData parseContent(String content, String regex) { + if (CharSequenceUtil.isBlank(content)) { + return null; + } + Matcher matcher = Pattern.compile(regex).matcher(content); + if (!matcher.find()) { + return null; + } + ContentParseData parseData = new ContentParseData(); + parseData.setPrompt(matcher.group(1)); + parseData.setStatus(matcher.group(2)); + return parseData; + } + + public static List convertBase64Array(List base64Array) throws MalformedURLException { + if (base64Array == null || base64Array.isEmpty()) { + return Collections.emptyList(); + } + IDataUrlSerializer serializer = new DataUrlSerializer(); + List dataUrlList = new ArrayList<>(); + for (String base64 : base64Array) { + DataUrl dataUrl = serializer.unserialize(base64); + dataUrlList.add(dataUrl); + } + return dataUrlList; + } + + public static TaskChangeParams convertChangeParams(String content) { + List split = CharSequenceUtil.split(content, " "); + if (split.size() != 2) { + return null; + } + String action = split.get(1).toLowerCase(); + TaskChangeParams changeParams = new TaskChangeParams(); + changeParams.setId(split.get(0)); + if (action.charAt(0) == 'u') { + changeParams.setAction(TaskAction.UPSCALE); + } else if (action.charAt(0) == 'v') { + changeParams.setAction(TaskAction.VARIATION); + } else if (action.equals("r")) { + changeParams.setAction(TaskAction.REROLL); + return changeParams; + } else { + return null; + } + try { + int index = Integer.parseInt(action.substring(1, 2)); + if (index < 1 || index > 4) { + return null; + } + changeParams.setIndex(index); + } catch (Exception e) { + return null; + } + return changeParams; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/MimeTypeUtils.java b/src/main/java/com/github/novicezk/midjourney/util/MimeTypeUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..6d7108372bb23c553d95ecaef28ac999c8bb36f2 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/MimeTypeUtils.java @@ -0,0 +1,45 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.io.FileUtil; +import cn.hutool.core.text.CharSequenceUtil; +import lombok.experimental.UtilityClass; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@UtilityClass +public class MimeTypeUtils { + private final Map> MIME_TYPE_MAP; + + static { + MIME_TYPE_MAP = new HashMap<>(); + var resource = MimeTypeUtils.class.getResource("/mime.types"); + var lines = FileUtil.readLines(resource, StandardCharsets.UTF_8); + for (var line : lines) { + if (CharSequenceUtil.isBlank(line)) { + continue; + } + var arr = line.split(":"); + MIME_TYPE_MAP.put(arr[0], CharSequenceUtil.split(arr[1], ' ')); + } + } + + public static String guessFileSuffix(String mimeType) { + if (CharSequenceUtil.isBlank(mimeType)) { + return null; + } + String key = mimeType; + if (!MIME_TYPE_MAP.containsKey(key)) { + key = MIME_TYPE_MAP.keySet().stream().filter(k -> CharSequenceUtil.startWithIgnoreCase(mimeType, k)) + .findFirst().orElse(null); + } + var suffixList = MIME_TYPE_MAP.get(key); + if (suffixList == null || suffixList.isEmpty()) { + return null; + } + return suffixList.iterator().next(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/SnowFlake.java b/src/main/java/com/github/novicezk/midjourney/util/SnowFlake.java new file mode 100644 index 0000000000000000000000000000000000000000..5b59c6ff4aac69d5167a5e84eabd66627b783bf9 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/SnowFlake.java @@ -0,0 +1,152 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.exceptions.ValidateException; +import com.github.novicezk.midjourney.exception.SnowFlakeException; +import lombok.extern.slf4j.Slf4j; + +import java.lang.management.ManagementFactory; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.util.Date; +import java.util.concurrent.ThreadLocalRandom; + +@Slf4j +public class SnowFlake { + private long workerId; + private long datacenterId; + private long sequence = 0L; + private final long twepoch; + private final long sequenceMask; + private final long workerIdShift; + private final long datacenterIdShift; + private final long timestampLeftShift; + private long lastTimestamp = -1L; + private final boolean randomSequence; + private long count = 0L; + private final long timeOffset; + private final ThreadLocalRandom tlr = ThreadLocalRandom.current(); + + public static final SnowFlake INSTANCE = new SnowFlake(); + + private SnowFlake() { + this(false, 10, null, 5L, 5L, 12L); + } + + private SnowFlake(boolean randomSequence, long timeOffset, Date epochDate, long workerIdBits, long datacenterIdBits, long sequenceBits) { + if (null != epochDate) { + this.twepoch = epochDate.getTime(); + } else { + // 2012/12/12 23:59:59 GMT + this.twepoch = 1355327999000L; + } + long maxWorkerId = ~(-1L << workerIdBits); + long maxDatacenterId = ~(-1L << datacenterIdBits); + this.sequenceMask = ~(-1L << sequenceBits); + this.workerIdShift = sequenceBits; + this.datacenterIdShift = sequenceBits + workerIdBits; + this.timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits; + this.randomSequence = randomSequence; + this.timeOffset = timeOffset; + try { + this.datacenterId = getDatacenterId(maxDatacenterId); + this.workerId = getMaxWorkerId(datacenterId, maxWorkerId); + } catch (Exception e) { + log.warn("datacenterId or workerId generate error: {}, set default value", e.getMessage()); + this.datacenterId = 4; + this.workerId = 1; + } + } + + public synchronized String nextId() { + long currentTimestamp = timeGen(); + if (currentTimestamp < this.lastTimestamp) { + long offset = this.lastTimestamp - currentTimestamp; + if (offset > this.timeOffset) { + throw new ValidateException("Clock moved backwards, refusing to generate id for [" + offset + "ms]"); + } + try { + this.wait(offset << 1); + } catch (InterruptedException e) { + throw new SnowFlakeException(e); + } + currentTimestamp = timeGen(); + if (currentTimestamp < this.lastTimestamp) { + throw new SnowFlakeException("Clock moved backwards, refusing to generate id for [" + offset + "ms]"); + } + } + if (this.lastTimestamp == currentTimestamp) { + long tempSequence = this.sequence + 1; + if (this.randomSequence) { + this.sequence = tempSequence & this.sequenceMask; + this.count = (this.count + 1) & this.sequenceMask; + if (this.count == 0) { + currentTimestamp = this.tillNextMillis(this.lastTimestamp); + } + } else { + this.sequence = tempSequence & this.sequenceMask; + if (this.sequence == 0) { + currentTimestamp = this.tillNextMillis(lastTimestamp); + } + } + } else { + this.sequence = this.randomSequence ? this.tlr.nextLong(this.sequenceMask + 1) : 0L; + this.count = 0L; + } + this.lastTimestamp = currentTimestamp; + long id = ((currentTimestamp - this.twepoch) << this.timestampLeftShift) | + (this.datacenterId << this.datacenterIdShift) | + (this.workerId << this.workerIdShift) | + this.sequence; + return String.valueOf(id); + } + + private static long getDatacenterId(long maxDatacenterId) { + long id = 0L; + try { + InetAddress ip = InetAddress.getLocalHost(); + NetworkInterface network = NetworkInterface.getByInetAddress(ip); + if (network == null) { + id = 1L; + } else { + byte[] mac = network.getHardwareAddress(); + if (null != mac) { + id = ((0x000000FF & (long) mac[mac.length - 1]) | (0x0000FF00 & (((long) mac[mac.length - 2]) << 8))) >> 6; + id = id % (maxDatacenterId + 1); + } + } + } catch (Exception e) { + throw new SnowFlakeException(e); + } + return id; + } + + private static long getMaxWorkerId(long datacenterId, long maxWorkerId) { + StringBuilder macIpPid = new StringBuilder(); + macIpPid.append(datacenterId); + try { + String name = ManagementFactory.getRuntimeMXBean().getName(); + if (name != null && !name.isEmpty()) { + macIpPid.append(name.split("@")[0]); + } + String hostIp = InetAddress.getLocalHost().getHostAddress(); + String ipStr = hostIp.replace("\\.", ""); + macIpPid.append(ipStr); + } catch (Exception e) { + throw new SnowFlakeException(e); + } + return (macIpPid.toString().hashCode() & 0xffff) % (maxWorkerId + 1); + } + + private long tillNextMillis(long lastTimestamp) { + long timestamp = timeGen(); + while (timestamp <= lastTimestamp) { + timestamp = timeGen(); + } + return timestamp; + } + + private long timeGen() { + return System.currentTimeMillis(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/TaskChangeParams.java b/src/main/java/com/github/novicezk/midjourney/util/TaskChangeParams.java new file mode 100644 index 0000000000000000000000000000000000000000..94f234ff3f5c21a6ec28681faaff4305d7f0adeb --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/TaskChangeParams.java @@ -0,0 +1,11 @@ +package com.github.novicezk.midjourney.util; + +import com.github.novicezk.midjourney.enums.TaskAction; +import lombok.Data; + +@Data +public class TaskChangeParams { + private String id; + private TaskAction action; + private Integer index; +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/UVContentParseData.java b/src/main/java/com/github/novicezk/midjourney/util/UVContentParseData.java new file mode 100644 index 0000000000000000000000000000000000000000..377bb32caec681f2e2562da4fd327b329547d755 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/UVContentParseData.java @@ -0,0 +1,10 @@ +package com.github.novicezk.midjourney.util; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@EqualsAndHashCode(callSuper = true) +public class UVContentParseData extends ContentParseData { + protected Integer index; +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/WebSocketStarter.java b/src/main/java/com/github/novicezk/midjourney/wss/WebSocketStarter.java new file mode 100644 index 0000000000000000000000000000000000000000..bc5d00efb5d024b9cdef3fc4c5f031afa23d8e97 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/WebSocketStarter.java @@ -0,0 +1,23 @@ +package com.github.novicezk.midjourney.wss; + +import com.github.novicezk.midjourney.ProxyProperties; +import com.neovisionaries.ws.client.ProxySettings; +import com.neovisionaries.ws.client.WebSocketFactory; +import org.apache.logging.log4j.util.Strings; + +public interface WebSocketStarter { + + void setTrying(boolean trying); + + void start() throws Exception; + + default WebSocketFactory createWebSocketFactory(ProxyProperties.ProxyConfig proxy) { + WebSocketFactory webSocketFactory = new WebSocketFactory().setConnectionTimeout(10000); + if (Strings.isNotBlank(proxy.getHost())) { + ProxySettings proxySettings = webSocketFactory.getProxySettings(); + proxySettings.setHost(proxy.getHost()); + proxySettings.setPort(proxy.getPort()); + } + return webSocketFactory; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/BlendSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/BlendSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..be4fbeb10ec592916289fa0346a4809140dfc00a --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/BlendSuccessHandler.java @@ -0,0 +1,47 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Optional; +import java.util.Set; + +/** + * blend消息处理. + * 完成(create): ** --v 5.1** - <@1012983546824114217> (relaxed) + */ +@Component +public class BlendSuccessHandler extends MessageHandler { + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = ConvertUtils.parseContent(content); + if (parseData == null || !MessageType.CREATE.equals(messageType)) { + return; + } + Optional interaction = message.optObject("interaction"); + if (interaction.isPresent() && "blend".equals(interaction.get().getString("name"))) { + // blend任务开始时,设置prompt + Task task = this.discordLoadBalancer.getRunningTaskByNonce(getMessageNonce(message)); + if (task != null) { + task.setPromptEn(parseData.getPrompt()); + task.setPrompt(parseData.getPrompt()); + } + } + if (hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.BLEND)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/DescribeSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/DescribeSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..a10d28b8f357c2a116683019def68c8f5321395a --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/DescribeSuccessHandler.java @@ -0,0 +1,47 @@ +package com.github.novicezk.midjourney.wss.handle; + +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.support.Task; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Optional; + +/** + * describe消息处理. + */ +@Component +public class DescribeSuccessHandler extends MessageHandler { + + @Override + public void handle(MessageType messageType, DataObject message) { + Optional interaction = message.optObject("interaction"); + if (!MessageType.UPDATE.equals(messageType) || interaction.isEmpty() || !"describe".equals(interaction.get().getString("name"))) { + return; + } + DataArray embeds = message.getArray("embeds"); + if (embeds.isEmpty()) { + return; + } + String description = embeds.getObject(0).getString("description"); + Optional imageOptional = embeds.getObject(0).optObject("image"); + if (imageOptional.isEmpty()) { + return; + } + String imageUrl = imageOptional.get().getString("url"); + String taskId = this.discordHelper.findTaskIdWithCdnUrl(imageUrl); + Task task = this.discordLoadBalancer.getRunningTask(taskId); + if (task == null) { + return; + } + task.setPrompt(description); + task.setPromptEn(description); + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, description); + task.setImageUrl(replaceCdnUrl(imageUrl)); + finishTask(task, message); + task.awake(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/ErrorMessageHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/ErrorMessageHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..1a60bc9715619f40b307504087bb2518b4a72af8 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/ErrorMessageHandler.java @@ -0,0 +1,68 @@ +package com.github.novicezk.midjourney.wss.handle; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class ErrorMessageHandler extends MessageHandler { + @Autowired + protected ProxyProperties properties; + + @Override + public void handle(MessageType messageType, DataObject message) { + Optional embedsOptional = message.optArray("embeds"); + if (!MessageType.CREATE.equals(messageType) || embedsOptional.isEmpty() || embedsOptional.get().isEmpty()) { + return; + } + DataObject embed = embedsOptional.get().getObject(0); + String title = embed.getString("title", null); + String description = embed.getString("description", null); + String footerText = ""; + Optional footer = embed.optObject("footer"); + if (footer.isPresent()) { + footerText = footer.get().getString("text", ""); + } + String channelId = message.getString("channel_id", ""); + int color = embed.getInt("color", 0); + if (color == 16239475) { + log.warn("{} - MJ警告信息: {}\n{}\nfooter: {}", channelId, title, description, footerText); + } else if (color == 16711680) { + log.error("{} - MJ异常信息: {}\n{}\nfooter: {}", channelId, title, description, footerText); + String nonce = getMessageNonce(message); + Task task = this.discordLoadBalancer.getRunningTaskByNonce(nonce); + if (task != null) { + task.fail("[" + title + "] " + description); + task.awake(); + } + } else if (CharSequenceUtil.contains(title, "Invalid link")) { + // 兼容 Invalid link! 错误 + log.error("{} - MJ异常信息: {}\n{}\nfooter: {}", channelId, title, description, footerText); + DataObject messageReference = message.optObject("message_reference").orElse(DataObject.empty()); + String referenceMessageId = messageReference.getString("message_id", ""); + if (CharSequenceUtil.isBlank(referenceMessageId)) { + return; + } + TaskCondition condition = new TaskCondition().setStatusSet(Set.of(TaskStatus.IN_PROGRESS)) + .setProgressMessageId(referenceMessageId); + Task task = this.discordLoadBalancer.findRunningTask(condition).findFirst().orElse(null); + if (task != null) { + task.fail("[" + title + "] " + description); + task.awake(); + } + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/ImagineSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/ImagineSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..68baafe9861fea39ddbb4f42a4863fa64b7cb222 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/ImagineSuccessHandler.java @@ -0,0 +1,34 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * imagine消息处理. + * 完成(create): **cat** - <@1012983546824114217> (relaxed) + */ +@Component +public class ImagineSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX = "\\*\\*(.*?)\\*\\* - <@\\d+> \\((.*?)\\)"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.IMAGINE)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/MessageHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/MessageHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..c279d4f96daade9efc10db77903001b0f3f2560a --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/MessageHandler.java @@ -0,0 +1,85 @@ +package com.github.novicezk.midjourney.wss.handle; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.loadbalancer.DiscordLoadBalancer; +import com.github.novicezk.midjourney.support.DiscordHelper; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; + +import javax.annotation.Resource; +import java.util.Comparator; + +public abstract class MessageHandler { + @Resource + protected DiscordLoadBalancer discordLoadBalancer; + @Resource + protected DiscordHelper discordHelper; + + public abstract void handle(MessageType messageType, DataObject message); + + protected String getMessageContent(DataObject message) { + return message.hasKey("content") ? message.getString("content") : ""; + } + + protected String getMessageNonce(DataObject message) { + return message.hasKey("nonce") ? message.getString("nonce") : ""; + } + + protected void findAndFinishImageTask(TaskCondition condition, String finalPrompt, DataObject message) { + String imageUrl = getImageUrl(message); + String messageHash = this.discordHelper.getMessageHash(imageUrl); + condition.setMessageHash(messageHash); + Task task = this.discordLoadBalancer.findRunningTask(condition) + .findFirst().orElseGet(() -> { + condition.setMessageHash(null); + return this.discordLoadBalancer.findRunningTask(condition) + .min(Comparator.comparing(Task::getStartTime)) + .orElse(null); + }); + if (task == null) { + return; + } + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, finalPrompt); + task.setProperty(Constants.TASK_PROPERTY_MESSAGE_HASH, messageHash); + task.setImageUrl(imageUrl); + finishTask(task, message); + task.awake(); + } + + protected void finishTask(Task task, DataObject message) { + task.setProperty(Constants.TASK_PROPERTY_MESSAGE_ID, message.getString("id")); + task.setProperty(Constants.TASK_PROPERTY_FLAGS, message.getInt("flags", 0)); + task.setProperty(Constants.TASK_PROPERTY_MESSAGE_HASH, this.discordHelper.getMessageHash(task.getImageUrl())); + task.success(); + } + + protected boolean hasImage(DataObject message) { + DataArray attachments = message.optArray("attachments").orElse(DataArray.empty()); + return !attachments.isEmpty(); + } + + protected String getImageUrl(DataObject message) { + DataArray attachments = message.getArray("attachments"); + if (!attachments.isEmpty()) { + String imageUrl = attachments.getObject(0).getString("url"); + return replaceCdnUrl(imageUrl); + } + return null; + } + + protected String replaceCdnUrl(String imageUrl) { + if (CharSequenceUtil.isBlank(imageUrl)) { + return imageUrl; + } + String cdn = this.discordHelper.getCdn(); + if (CharSequenceUtil.startWith(imageUrl, cdn)) { + return imageUrl; + } + return CharSequenceUtil.replaceFirst(imageUrl, DiscordHelper.DISCORD_CDN_URL, cdn); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/RerollSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/RerollSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..96e0c17b9b62727655f7c84a66f47ee2b40207fc --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/RerollSuccessHandler.java @@ -0,0 +1,49 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * reroll 消息处理. + * 完成(create): **cat** - <@1012983546824114217> (relaxed) + * 完成(create): **cat** - Variations by <@1012983546824114217> (relaxed) + * 完成(create): **cat** - Variations (Strong或Subtle) by <@1012983546824114217> (relaxed) + */ +@Component +public class RerollSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX_1 = "\\*\\*(.*?)\\*\\* - <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_2 = "\\*\\*(.*?)\\*\\* - Variations by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_3 = "\\*\\*(.*?)\\*\\* - Variations \\(.*?\\) by <@\\d+> \\((.*?)\\)"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = getParseData(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.REROLL)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + + private ContentParseData getParseData(String content) { + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_1); + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_2); + } + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_3); + } + return parseData; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/StartAndProgressHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/StartAndProgressHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..def0e0465673ac311922b3445dfd43206a9b60c3 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/StartAndProgressHandler.java @@ -0,0 +1,72 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class StartAndProgressHandler extends MessageHandler { + + @Override + public void handle(MessageType messageType, DataObject message) { + String nonce = getMessageNonce(message); + String content = getMessageContent(message); + ContentParseData parseData = ConvertUtils.parseContent(content); + if (MessageType.CREATE.equals(messageType) && CharSequenceUtil.isNotBlank(nonce)) { + if (isError(message)) { + return; + } + // 任务开始 + Task task = this.discordLoadBalancer.getRunningTaskByNonce(nonce); + if (task == null) { + return; + } + task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, message.getString("id")); + // 兼容少数content为空的场景 + if (parseData != null) { + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, parseData.getPrompt()); + } + task.setStatus(TaskStatus.IN_PROGRESS); + task.awake(); + } else if (MessageType.UPDATE.equals(messageType) && parseData != null) { + // 任务进度 + TaskCondition condition = new TaskCondition().setStatusSet(Set.of(TaskStatus.IN_PROGRESS)) + .setProgressMessageId(message.getString("id")); + Task task = this.discordLoadBalancer.findRunningTask(condition).findFirst().orElse(null); + if (task == null) { + return; + } + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, parseData.getPrompt()); + task.setStatus(TaskStatus.IN_PROGRESS); + task.setProgress(parseData.getStatus()); + String imageUrl = getImageUrl(message); + task.setImageUrl(imageUrl); + task.setProperty(Constants.TASK_PROPERTY_MESSAGE_HASH, this.discordHelper.getMessageHash(imageUrl)); + task.awake(); + } + } + + private boolean isError(DataObject message) { + Optional embedsOptional = message.optArray("embeds"); + if (embedsOptional.isEmpty() || embedsOptional.get().isEmpty()) { + return false; + } + DataObject embed = embedsOptional.get().getObject(0); + return embed.getInt("color", 0) == 16711680; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/UpscaleSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/UpscaleSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..5c05d9973e468cb5f191e8c78c6a6fdb2c70bad4 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/UpscaleSuccessHandler.java @@ -0,0 +1,57 @@ +package com.github.novicezk.midjourney.wss.handle; + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * upscale消息处理. + * 完成(create): **cat** - Upscaled (Beta或Light) by <@1083152202048217169> (fast) + * 完成(create): **cat** - Upscaled by <@1083152202048217169> (fast) + * 完成(create): **cat** - Image #1 <@1012983546824114217> + */ +@Component +public class UpscaleSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX_1 = "\\*\\*(.*?)\\*\\* - Upscaled \\(.*?\\) by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_2 = "\\*\\*(.*?)\\*\\* - Upscaled by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_3 = "\\*\\*(.*?)\\*\\* - Image #\\d <@\\d+>"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = getParseData(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.UPSCALE)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + + private ContentParseData getParseData(String content) { + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_1); + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_2); + } + if (parseData != null) { + return parseData; + } + Matcher matcher = Pattern.compile(CONTENT_REGEX_3).matcher(content); + if (!matcher.find()) { + return null; + } + parseData = new ContentParseData(); + parseData.setPrompt(matcher.group(1)); + parseData.setStatus("done"); + return parseData; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/VariationSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/VariationSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..fdf44e4b599ed74cd0b9b7a1e086f0bbbd038c42 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/VariationSuccessHandler.java @@ -0,0 +1,43 @@ +package com.github.novicezk.midjourney.wss.handle; + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * variation消息处理. + * 完成(create): **cat** - Variations (Strong或Subtle) by <@1012983546824114217> (relaxed) + * 完成(create): **cat** - Variations by <@1012983546824114217> (relaxed) + */ +@Component +public class VariationSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX_1 = "\\*\\*(.*?)\\*\\* - Variations by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_2 = "\\*\\*(.*?)\\*\\* - Variations \\(.*?\\) by <@\\d+> \\((.*?)\\)"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = getParseData(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.VARIATION)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + + private ContentParseData getParseData(String content) { + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_1); + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_2); + } + return parseData; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/user/UserMessageListener.java b/src/main/java/com/github/novicezk/midjourney/wss/user/UserMessageListener.java new file mode 100644 index 0000000000000000000000000000000000000000..d978950ca1dcf22396ab0228f2a973a57a082263 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/user/UserMessageListener.java @@ -0,0 +1,48 @@ +package com.github.novicezk.midjourney.wss.user; + + +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.thread.ThreadUtil; +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.wss.handle.MessageHandler; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataObject; + +import java.util.List; + +@Slf4j +public class UserMessageListener { + private final DiscordAccount account; + private final List messageHandlers; + + public UserMessageListener(DiscordAccount account, List messageHandlers) { + this.account = account; + this.messageHandlers = messageHandlers; + } + + public void onMessage(DataObject raw) { + MessageType messageType = MessageType.of(raw.getString("t")); + if (messageType == null || MessageType.DELETE == messageType) { + return; + } + DataObject data = raw.getObject("d"); + if (ignoreAndLogMessage(data, messageType)) { + return; + } + ThreadUtil.sleep(50); + for (MessageHandler messageHandler : this.messageHandlers) { + messageHandler.handle(messageType, data); + } + } + + private boolean ignoreAndLogMessage(DataObject data, MessageType messageType) { + String channelId = data.getString("channel_id"); + if (!CharSequenceUtil.equals(channelId, this.account.getChannelId())) { + return true; + } + String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System"); + log.debug("{} - {} - {}: {}", this.account.getDisplay(), messageType.name(), authorName, data.opt("content").orElse("")); + return false; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/user/UserWebSocketStarter.java b/src/main/java/com/github/novicezk/midjourney/wss/user/UserWebSocketStarter.java new file mode 100644 index 0000000000000000000000000000000000000000..6e29baea526cea319bd435c99cd9966f8c0a03b5 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/user/UserWebSocketStarter.java @@ -0,0 +1,347 @@ +package com.github.novicezk.midjourney.wss.user; + +import cn.hutool.core.exceptions.ValidateException; +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.thread.ThreadUtil; +import cn.hutool.core.util.RandomUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.domain.DiscordAccount; +import com.github.novicezk.midjourney.util.AsyncLockUtils; +import com.github.novicezk.midjourney.wss.WebSocketStarter; +import com.neovisionaries.ws.client.WebSocket; +import com.neovisionaries.ws.client.WebSocketAdapter; +import com.neovisionaries.ws.client.WebSocketFactory; +import com.neovisionaries.ws.client.WebSocketFrame; +import eu.bitwalker.useragentutils.UserAgent; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import net.dv8tion.jda.api.utils.data.DataType; +import net.dv8tion.jda.internal.requests.WebSocketCode; +import net.dv8tion.jda.internal.utils.compress.Decompressor; +import net.dv8tion.jda.internal.utils.compress.ZlibDecompressor; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +@Slf4j +public class UserWebSocketStarter extends WebSocketAdapter implements WebSocketStarter { + private static final int CONNECT_RETRY_LIMIT = 3; + + private final ProxyProperties.ProxyConfig proxyConfig; + private final DiscordAccount account; + private final UserMessageListener userMessageListener; + private final ScheduledExecutorService heartExecutor; + private final String wssServer; + private final DataObject authData; + + private Decompressor decompressor; + private WebSocket socket = null; + private String resumeGatewayUrl; + private String sessionId; + + private Future heartbeatInterval; + private Future heartbeatTimeout; + private boolean heartbeatAck = false; + private Object sequence = null; + private long interval = 41250; + private boolean trying = false; + + public UserWebSocketStarter(String wssServer, DiscordAccount account, UserMessageListener userMessageListener, ProxyProperties.ProxyConfig proxyConfig) { + this.wssServer = wssServer; + this.account = account; + this.userMessageListener = userMessageListener; + this.proxyConfig = proxyConfig; + this.heartExecutor = Executors.newSingleThreadScheduledExecutor(); + this.authData = createAuthData(); + } + + @Override + public void setTrying(boolean trying) { + this.trying = trying; + } + + @Override + public synchronized void start() throws Exception { + this.decompressor = new ZlibDecompressor(2048); + WebSocketFactory webSocketFactory = createWebSocketFactory(this.proxyConfig); + String gatewayUrl = CharSequenceUtil.isNotBlank(this.resumeGatewayUrl) ? this.resumeGatewayUrl : this.wssServer; + this.socket = webSocketFactory.createSocket(gatewayUrl + "/?encoding=json&v=9&compress=zlib-stream"); + this.socket.addListener(this); + this.socket.addHeader("Accept-Encoding", "gzip, deflate, br") + .addHeader("Accept-Language", "zh-CN,zh;q=0.9") + .addHeader("Cache-Control", "no-cache") + .addHeader("Pragma", "no-cache") + .addHeader("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits") + .addHeader("User-Agent", this.account.getUserAgent()); + this.socket.connect(); + } + + @Override + public void onConnected(WebSocket websocket, Map> headers) { + log.debug("[wss-{}] Connected to websocket.", this.account.getDisplay()); + } + + @Override + public void handleCallbackError(WebSocket websocket, Throwable cause) throws Exception { + log.error("[wss-{}] There was some websocket error.", this.account.getDisplay(), cause); + } + + @Override + public void onDisconnected(WebSocket websocket, WebSocketFrame serverCloseFrame, WebSocketFrame clientCloseFrame, boolean closedByServer) throws Exception { + int code; + String closeReason; + if (closedByServer) { + code = serverCloseFrame.getCloseCode(); + closeReason = serverCloseFrame.getCloseReason(); + } else { + code = clientCloseFrame.getCloseCode(); + closeReason = clientCloseFrame.getCloseReason(); + } + connectFinish(code, closeReason); + if (this.trying) { + return; + } + if (code == 5240) { + // 隐式关闭wss + clearAllStates(); + } else if (code >= 4000) { + log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.account.getDisplay(), code, closeReason); + clearAllStates(); + this.account.setEnable(false); + } else if (code == 2001) { + // reconnect + log.warn("[wss-{}] Waiting reconnect...", this.account.getDisplay()); + clearSocketStates(); + start(); + } else { + log.warn("[wss-{}] Closed by {}({}). Waiting try new connection...", this.account.getDisplay(), code, closeReason); + clearAllStates(); + tryNewConnect(); + } + } + + private void tryNewConnect() { + this.trying = true; + for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) { + try { + clearAllStates(); + start(); + AsyncLockUtils.LockObject lock = AsyncLockUtils.waitForLock("wss:" + this.account.getChannelId(), Duration.ofSeconds(20)); + int code = lock.getProperty("code", Integer.class, 0); + if (code == ReturnCode.SUCCESS) { + log.debug("[wss-{}] New connection success.", this.account.getDisplay()); + return; + } + throw new ValidateException(lock.getProperty("description", String.class)); + } catch (Exception e) { + if (e instanceof TimeoutException) { + close(5240, "try new connect"); + } + log.warn("[wss-{}] Try new connection fail ({}): {}", this.account.getDisplay(), i, e.getMessage()); + ThreadUtil.sleep(5000); + } + } + log.error("[wss-{}] Account disabled", this.account.getDisplay()); + this.account.setEnable(false); + } + + @Override + public void onBinaryMessage(WebSocket websocket, byte[] binary) throws Exception { + if (this.decompressor == null) { + return; + } + byte[] decompressBinary = this.decompressor.decompress(binary); + if (decompressBinary == null) { + return; + } + String json = new String(decompressBinary, StandardCharsets.UTF_8); + DataObject data = DataObject.fromJson(json); + int opCode = data.getInt("op"); + switch (opCode) { + case WebSocketCode.HEARTBEAT -> { + log.debug("[wss-{}] Receive heartbeat.", this.account.getDisplay()); + handleHeartbeat(); + } + case WebSocketCode.HEARTBEAT_ACK -> { + this.heartbeatAck = true; + clearHeartbeatTimeout(); + } + case WebSocketCode.HELLO -> { + handleHello(data); + doResumeOrIdentify(); + } + case WebSocketCode.RESUME -> { + log.debug("[wss-{}] Receive resumed.", this.account.getDisplay()); + connectSuccess(); + } + case WebSocketCode.RECONNECT -> reconnect("receive server reconnect"); + case WebSocketCode.INVALIDATE_SESSION -> close(1009, "receive session invalid"); + case WebSocketCode.DISPATCH -> handleDispatch(data); + default -> log.debug("[wss-{}] Receive unknown code: {}.", this.account.getDisplay(), data); + } + } + + private void handleHello(DataObject data) { + clearHeartbeatInterval(); + this.interval = data.getObject("d").getLong("heartbeat_interval"); + this.heartbeatAck = true; + this.heartbeatInterval = this.heartExecutor.scheduleAtFixedRate(() -> { + if (this.heartbeatAck) { + this.heartbeatAck = false; + send(WebSocketCode.HEARTBEAT, this.sequence); + } else { + reconnect("heartbeat has not ack interval"); + } + }, (long) Math.floor(RandomUtil.randomDouble(0, 1) * this.interval), this.interval, TimeUnit.MILLISECONDS); + } + + private void doResumeOrIdentify() { + if (CharSequenceUtil.isBlank(this.sessionId)) { + log.debug("[wss-{}] Send identify msg.", this.account.getDisplay()); + send(WebSocketCode.IDENTIFY, this.authData); + } else { + log.debug("[wss-{}] Send resume msg.", this.account.getDisplay()); + send(WebSocketCode.RESUME, DataObject.empty().put("token", this.account.getUserToken()) + .put("session_id", this.sessionId).put("seq", this.sequence)); + } + } + + private void handleHeartbeat() { + send(WebSocketCode.HEARTBEAT, this.sequence); + this.heartbeatTimeout = ThreadUtil.execAsync(() -> { + ThreadUtil.sleep(this.interval); + reconnect("heartbeat has not ack"); + }); + } + + private void clearAllStates() { + clearSocketStates(); + clearResumeStates(); + } + + private void clearSocketStates() { + clearHeartbeatTimeout(); + clearHeartbeatInterval(); + this.socket = null; + this.decompressor = null; + } + + private void clearResumeStates() { + this.sessionId = null; + this.sequence = null; + this.resumeGatewayUrl = null; + } + + private void clearHeartbeatInterval() { + if (this.heartbeatInterval != null) { + this.heartbeatInterval.cancel(true); + this.heartbeatInterval = null; + } + } + + private void clearHeartbeatTimeout() { + if (this.heartbeatTimeout != null) { + this.heartbeatTimeout.cancel(true); + this.heartbeatTimeout = null; + } + } + + private void reconnect(String reason) { + close(2001, reason); + } + + private void close(int code, String reason) { + if (this.socket != null) { + this.socket.sendClose(code, reason); + } + } + + private void send(int op, Object d) { + if (this.socket != null) { + this.socket.sendText(DataObject.empty().put("op", op).put("d", d).toString()); + } + } + + private void connectSuccess() { + this.trying = false; + connectFinish(ReturnCode.SUCCESS, ""); + } + + private void connectFinish(int code, String description) { + AsyncLockUtils.LockObject lock = AsyncLockUtils.getLock("wss:" + this.account.getChannelId()); + if (lock != null) { + lock.setProperty("code", code); + lock.setProperty("description", description); + lock.awake(); + } + } + + private void handleDispatch(DataObject raw) { + this.sequence = raw.opt("s").orElse(null); + if (!raw.isType("d", DataType.OBJECT)) { + return; + } + DataObject content = raw.getObject("d"); + String t = raw.getString("t", null); + if ("READY".equals(t)) { + this.sessionId = content.getString("session_id"); + this.resumeGatewayUrl = content.getString("resume_gateway_url"); + log.debug("[wss-{}] Dispatch ready.", this.account.getDisplay()); + connectSuccess(); + return; + } + try { + this.userMessageListener.onMessage(raw); + } catch (Exception e) { + log.error("[wss-{}] Handle message error", this.account.getDisplay(), e); + } + } + + private DataObject createAuthData() { + UserAgent agent = UserAgent.parseUserAgentString(this.account.getUserAgent()); + DataObject connectionProperties = DataObject.empty() + .put("browser", agent.getBrowser().getGroup().getName()) + .put("browser_user_agent", this.account.getUserAgent()) + .put("browser_version", agent.getBrowserVersion().toString()) + .put("client_build_number", 222963) + .put("client_event_source", null) + .put("device", "") + .put("os", agent.getOperatingSystem().getName()) + .put("referer", "https://www.midjourney.com") + .put("referrer_current", "") + .put("referring_domain", "www.midjourney.com") + .put("referring_domain_current", "") + .put("release_channel", "stable") + .put("system_locale", "zh-CN"); + DataObject presence = DataObject.empty() + .put("activities", DataArray.empty()) + .put("afk", false) + .put("since", 0) + .put("status", "online"); + DataObject clientState = DataObject.empty() + .put("api_code_version", 0) + .put("guild_versions", DataObject.empty()) + .put("highest_last_message_id", "0") + .put("private_channels_version", "0") + .put("read_state_version", 0) + .put("user_guild_settings_version", -1) + .put("user_settings_version", -1); + return DataObject.empty() + .put("capabilities", 16381) + .put("client_state", clientState) + .put("compress", false) + .put("presence", presence) + .put("properties", connectionProperties) + .put("token", this.account.getUserToken()); + } + +} diff --git a/src/main/java/spring/config/BeanConfig.java b/src/main/java/spring/config/BeanConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..26eeac4f82493b7a3aaff093ef59cf1d116ae5fe --- /dev/null +++ b/src/main/java/spring/config/BeanConfig.java @@ -0,0 +1,100 @@ +package spring.config; + +import cn.hutool.core.io.IoUtil; +import cn.hutool.core.util.ReflectUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.loadbalancer.rule.IRule; +import com.github.novicezk.midjourney.service.NotifyService; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.service.TranslateService; +import com.github.novicezk.midjourney.service.store.InMemoryTaskStoreServiceImpl; +import com.github.novicezk.midjourney.service.store.RedisTaskStoreServiceImpl; +import com.github.novicezk.midjourney.service.translate.BaiduTranslateServiceImpl; +import com.github.novicezk.midjourney.service.translate.GPTTranslateServiceImpl; +import com.github.novicezk.midjourney.service.translate.NoTranslateServiceImpl; +import com.github.novicezk.midjourney.support.DiscordAccountHelper; +import com.github.novicezk.midjourney.support.DiscordHelper; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.wss.handle.MessageHandler; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; +import org.springframework.data.redis.serializer.StringRedisSerializer; +import org.springframework.web.client.RestTemplate; + +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Configuration +public class BeanConfig { + @Autowired + private ApplicationContext applicationContext; + @Autowired + private ProxyProperties properties; + + @Bean + TranslateService translateService() { + return switch (this.properties.getTranslateWay()) { + case BAIDU -> new BaiduTranslateServiceImpl(this.properties.getBaiduTranslate()); + case GPT -> new GPTTranslateServiceImpl(this.properties); + default -> new NoTranslateServiceImpl(); + }; + } + + @Bean + TaskStoreService taskStoreService(RedisConnectionFactory redisConnectionFactory) { + ProxyProperties.TaskStore.Type type = this.properties.getTaskStore().getType(); + Duration timeout = this.properties.getTaskStore().getTimeout(); + return switch (type) { + case IN_MEMORY -> new InMemoryTaskStoreServiceImpl(timeout); + case REDIS -> new RedisTaskStoreServiceImpl(timeout, taskRedisTemplate(redisConnectionFactory)); + }; + } + + @Bean + RedisTemplate taskRedisTemplate(RedisConnectionFactory redisConnectionFactory) { + RedisTemplate redisTemplate = new RedisTemplate<>(); + redisTemplate.setConnectionFactory(redisConnectionFactory); + redisTemplate.setKeySerializer(new StringRedisSerializer()); + redisTemplate.setHashKeySerializer(new StringRedisSerializer()); + redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(Task.class)); + return redisTemplate; + } + + @Bean + public RestTemplate restTemplate() { + return new RestTemplate(); + } + + @Bean + public IRule loadBalancerRule() { + String ruleClassName = IRule.class.getPackageName() + "." + this.properties.getAccountChooseRule(); + return ReflectUtil.newInstance(ruleClassName); + } + + @Bean + List messageHandlers() { + return this.applicationContext.getBeansOfType(MessageHandler.class).values().stream().toList(); + } + + @Bean + DiscordAccountHelper discordAccountHelper(DiscordHelper discordHelper, TaskStoreService taskStoreService, NotifyService notifyService) throws IOException { + var resources = this.applicationContext.getResources("classpath:api-params/*.json"); + Map paramsMap = new HashMap<>(); + for (var resource : resources) { + String filename = resource.getFilename(); + String params = IoUtil.readUtf8(resource.getInputStream()); + paramsMap.put(filename.substring(0, filename.length() - 5), params); + } + return new DiscordAccountHelper(discordHelper, this.properties, restTemplate(), taskStoreService, notifyService, messageHandlers(), paramsMap); + } + + +} diff --git a/src/main/java/spring/config/WebMvcConfig.java b/src/main/java/spring/config/WebMvcConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..d2e775a2d1e95d6dc63b6d7ce496fae76f471a68 --- /dev/null +++ b/src/main/java/spring/config/WebMvcConfig.java @@ -0,0 +1,33 @@ +package spring.config; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.support.ApiAuthorizeInterceptor; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.InterceptorRegistry; +import org.springframework.web.servlet.config.annotation.ViewControllerRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import javax.annotation.Resource; + +@Configuration +public class WebMvcConfig implements WebMvcConfigurer { + @Resource + private ApiAuthorizeInterceptor apiAuthorizeInterceptor; + @Resource + private ProxyProperties properties; + + @Override + public void addViewControllers(ViewControllerRegistry registry) { + registry.addViewController("/").setViewName("redirect:doc.html"); + } + + @Override + public void addInterceptors(InterceptorRegistry registry) { + if (CharSequenceUtil.isNotBlank(this.properties.getApiSecret())) { + registry.addInterceptor(this.apiAuthorizeInterceptor) + .addPathPatterns("/submit/**", "/task/**", "/account/**"); + } + } + +} diff --git a/src/main/resources/api-params/blend.json b/src/main/resources/api-params/blend.json new file mode 100644 index 0000000000000000000000000000000000000000..e082cf06c91edbf21e38c1439dbc5d5735ae7fe7 --- /dev/null +++ b/src/main/resources/api-params/blend.json @@ -0,0 +1,16 @@ +{ + "type":2, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "application_id":"936929561302675456", + "session_id":"$session_id", + "nonce": "$nonce", + "data":{ + "version":"1118961510123847773", + "id":"1062880104792997970", + "name":"blend", + "type":1, + "options":[], + "attachments":[] + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/describe.json b/src/main/resources/api-params/describe.json new file mode 100644 index 0000000000000000000000000000000000000000..4402b0012f0ef73cb22c7df30027dc264a8c4d10 --- /dev/null +++ b/src/main/resources/api-params/describe.json @@ -0,0 +1,28 @@ +{ + "type": 2, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "data": { + "version": "1118961510123847774", + "id": "1092492867185950852", + "name": "describe", + "type": 1, + "options": [ + { + "type": 11, + "name": "image", + "value": 0 + } + ], + "attachments": [ + { + "id": "0", + "filename": "$file_name", + "uploaded_filename": "$final_file_name" + } + ] + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/imagine.json b/src/main/resources/api-params/imagine.json new file mode 100644 index 0000000000000000000000000000000000000000..f027d6f2c8317881fffb763f8b1e9b3997e73029 --- /dev/null +++ b/src/main/resources/api-params/imagine.json @@ -0,0 +1,21 @@ +{ + "type": 2, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "data": { + "version": "1118961510123847772", + "id": "938956540159881230", + "name": "imagine", + "type": 1, + "options": [ + { + "type": 3, + "name": "prompt", + "value": "$prompt" + } + ] + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/message.json b/src/main/resources/api-params/message.json new file mode 100644 index 0000000000000000000000000000000000000000..9f8cccbc952fab9752c04874971b1ce061db156a --- /dev/null +++ b/src/main/resources/api-params/message.json @@ -0,0 +1,13 @@ +{ + "content":"$content", + "channel_id":"$channel_id", + "type":0, + "sticker_ids":[], + "attachments":[ + { + "id":"0", + "filename": "$file_name", + "uploaded_filename": "$final_file_name" + } + ] +} \ No newline at end of file diff --git a/src/main/resources/api-params/reroll.json b/src/main/resources/api-params/reroll.json new file mode 100644 index 0000000000000000000000000000000000000000..84435b6de77afadde1a11e47c1ce066b8b0662ee --- /dev/null +++ b/src/main/resources/api-params/reroll.json @@ -0,0 +1,14 @@ +{ + "type": 3, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "message_id": "$message_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "message_flags": 0, + "data": { + "component_type": 2, + "custom_id": "MJ::JOB::reroll::0::$message_hash::SOLO" + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/upscale.json b/src/main/resources/api-params/upscale.json new file mode 100644 index 0000000000000000000000000000000000000000..f1afdd5aac6f34adccd2f6b097b664493430ea06 --- /dev/null +++ b/src/main/resources/api-params/upscale.json @@ -0,0 +1,14 @@ +{ + "type": 3, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "message_id": "$message_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "message_flags": 0, + "data": { + "component_type": 2, + "custom_id": "MJ::JOB::upsample::$index::$message_hash" + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/variation.json b/src/main/resources/api-params/variation.json new file mode 100644 index 0000000000000000000000000000000000000000..0f750aef94d1b280e95e96aca373ff74753710e2 --- /dev/null +++ b/src/main/resources/api-params/variation.json @@ -0,0 +1,14 @@ +{ + "type": 3, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "message_id": "$message_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "message_flags": 0, + "data": { + "component_type": 2, + "custom_id": "MJ::JOB::variation::$index::$message_hash" + } +} \ No newline at end of file diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml new file mode 100644 index 0000000000000000000000000000000000000000..6337a64690902d04d56e16e36bfaf071ee69e7f3 --- /dev/null +++ b/src/main/resources/application.yml @@ -0,0 +1,5 @@ +mj: + task-store: + type: in_memory + timeout: 30d + translate-way: null diff --git a/src/main/resources/banned-words.txt b/src/main/resources/banned-words.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d38603750927ff761a206f273081418b4c002a4 --- /dev/null +++ b/src/main/resources/banned-words.txt @@ -0,0 +1,211 @@ +blood +twerk +making love +voluptuous +naughty +wincest +orgy +no clothes +au naturel +no shirt +decapitate +bare +nude +barely dressed +nude +bra +risque +scantily clad +cleavage +stripped +infested +full frontal +unclothed +invisible clothes +wearing nothing +lingerie +with no shirt +naked +without clothes on +negligee +zero clothes +gruesome +fascist +nazi +prophet mohammed +slave +coon +honkey +cocaine +heroin +meth +crack +kill +belle delphine +hitler +jinping +lolita +president xi +torture +disturbing +farts +fart +poop +infected +warts +shit +brown pudding +bunghole +vomit +voluptuous +seductive +sperm +sexy +sadist +sensored +censored +silenced +deepfake +inappropriate +waifu +succubus +slaughter +surgery +reproduce +crucified +seductively +explicit +inappropriate +large bust +explicit +wang +inappropriate +teratoma +intimate +see through +tryphophobia +bloodbath +wound +cronenberg +khorne +cannibal +cannibalism +visceral +guts +bloodshot +gory +killing +crucifixion +surgery +vivisection +massacre +hemoglobin +suicide +arse +labia +ass +mammaries +badonkers +bloody +minge +big ass +mommy milker +booba +nipple +oppai +booty +organs +bosom +ovaries +flesh +breasts +penis +busty +phallus +clunge +sexy female +crotch +skimpy +dick +thick +bruises +girth +titty +honkers +vagina +hooters +veiny +knob +ahegao +pinup +ballgag +car crash +playboy +bimbo +pleasure +bodily fluids +pleasures +boudoir +rule34 +brothel +seducing +dominatrix +corpse +seductive +erotic +seductive +fuck +sensual +hardcore +sexy +hentai +shag +horny +crucified +shibari +incest +smut +jav +succubus +jerk off king at pic +thot +kinbaku +legs spread +sensuality +belly button +porn +patriotic +bleed +excrement +petite +seduction +mccurry +provocative +sultry +erected +camisole +tight white +arrest +see-through +feces +anus +revealing clothing +vein +loli +-edge +boobs +-backed +tied up +zedong +bathing +jail +reticulum +rear end +sakimichan +behind bars +shirtless +sakimichan +seductive +sexi +sexualiz +sexual \ No newline at end of file diff --git a/src/main/resources/banner.txt b/src/main/resources/banner.txt new file mode 100644 index 0000000000000000000000000000000000000000..2eef25d92e3405884a984eb0473db5b745a6bb6b --- /dev/null +++ b/src/main/resources/banner.txt @@ -0,0 +1,8 @@ + + , /) , +___ _(/ ___ __ __ _ __ __ _____/ +// (__(_(_(_ /_(_)(_(_/ (_/ (__(/_(_/_ /_)_/ (_(_) /(__(_/_ + .-/ .-/ .-/ / .-/ + (_/ (_/ (_/ (_/ + +:: MidJourney Proxy :: v2.5 diff --git a/src/main/resources/config/application.yml b/src/main/resources/config/application.yml new file mode 100644 index 0000000000000000000000000000000000000000..b9097b2d6af31ee336b358e45bc3ff1d054af9e1 --- /dev/null +++ b/src/main/resources/config/application.yml @@ -0,0 +1,23 @@ +server: + port: 8080 + servlet: + context-path: /mj +logging: + level: + ROOT: info + com.github.novicezk.midjourney: debug +knife4j: + enable: true + openapi: + title: Midjourney Proxy API文档 + description: 代理 MidJourney 的discord频道,实现api形式调用AI绘图 + concat: novicezk + url: https://github.com/novicezk/midjourney-proxy + version: v2.5 + terms-of-service-url: https://github.com/novicezk/midjourney-proxy + group: + api: + group-name: API + api-rule: package + api-rule-resources: + - com.github.novicezk.midjourney.controller \ No newline at end of file diff --git a/src/main/resources/mime.types b/src/main/resources/mime.types new file mode 100644 index 0000000000000000000000000000000000000000..fcff7ef006540d0747ba3616d2edeef8fc1bd961 --- /dev/null +++ b/src/main/resources/mime.types @@ -0,0 +1,82 @@ +text/html:html htm shtml +text/css:css +text/xml:xml + +text/mathml:mml +text/plain:txt +text/vnd.sun.j2me.app-descriptor:jad +text/vnd.wap.wml:wml +text/x-component:htc + +image/gif:gif +image/jpeg:jpg jpeg +image/png:png +image/tiff:tif tiff +image/vnd.wap.wbmp:wbmp +image/x-icon:ico +image/x-jng:jng +image/x-ms-bmp:bmp +image/svg+xml:svg svgz +image/webp:webp + +application/javascript:js +application/x-javascript:js +application/atom+xml:atom +application/rss+xml:rss + +application/font-woff:woff +application/java-archive:jar war ear +application/json:json +application/mac-binhex40:hqx +application/msword:doc +application/pdf:pdf +application/postscript:ps eps ai +application/rtf:rtf +application/vnd.apple.mpegurl:m3u8 +application/vnd.ms-excel:xls +application/vnd.ms-fontobject:eot +application/vnd.ms-powerpoint:ppt +application/vnd.wap.wmlc:wmlc +application/vnd.google-earth.kml+xml:kml +application/vnd.google-earth.kmz:kmz +application/x-7z-compressed:7z +application/x-cocoa:cco +application/x-java-archive-diff:jardiff +application/x-java-jnlp-file:jnlp +application/x-makeself:run +application/x-perl:pl pm +application/x-pilot:prc pdb +application/x-rar-compressed:rar +application/x-redhat-package-manager:rpm +application/x-sea:sea +application/x-shockwave-flash:swf +application/x-stuffit:sit +application/x-tcl:tcl tk +application/x-x509-ca-cert:der pem crt +application/x-xpinstall:xpi +application/xhtml+xml:xhtml +application/xspf+xml:xspf +application/zip:zip + +application/vnd.openxmlformats-officedocument.wordprocessingml.document:docx +application/vnd.openxmlformats-officedocument.spreadsheetml.sheet:xlsx +application/vnd.openxmlformats-officedocument.presentationml.presentation:pptx + +audio/midi:mid midi kar +audio/mpeg:mp3 +audio/ogg:ogg +audio/x-m4a:m4a +audio/x-realaudio:ra + +video/3gpp:3gpp 3gp +video/mp2t:ts +video/mp4:mp4 +video/mpeg:mpeg mpg +video/quicktime:mov +video/webm:webm +video/x-flv:flv +video/x-m4v:m4v +video/x-mng:mng +video/x-ms-asf:asx asf +video/x-ms-wmv:wmv +video/x-msvideo:avi