diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..70b2d635f2b542524d3459aa051b213b6a683e05 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,22 @@ +target/ + +### IntelliJ IDEA ### +.idea +*.iws +*.iml +*.ipr + +### VS Code ### +.vscode/ + +### Macos ### +.DS_Store + +### application config # +config/application.yml + +.git +.gitignore +docker +docs +README.md \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..646719d4fbe6ebf26b33e0210895237082eada43 --- /dev/null +++ b/.gitignore @@ -0,0 +1,40 @@ +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/** +!**/src/test/** +bin/ + +### STS ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### IntelliJ IDEA ### +.idea +*.iws +*.iml +*.ipr + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ + +### VS Code ### +.vscode/ + +### Macos ### +.DS_Store + +### application config # +config/application.yml + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..edb45cf22255b187f0d56ca3f2664b06b4fe9278 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM maven:3.8.5-openjdk-17 + +ARG user=spring +ARG group=spring + +ENV SPRING_HOME=/home/spring + +RUN groupadd -g 1000 ${group} \ + && useradd -d "$SPRING_HOME" -u 1000 -g 1000 -m -s /bin/bash ${user} \ + && mkdir -p $SPRING_HOME/config \ + && mkdir -p $SPRING_HOME/logs \ + && chown -R ${user}:${group} $SPRING_HOME/config $SPRING_HOME/logs + +# Railway 不支持使用 VOLUME, 本地需要构建时,取消下一行的注释 +# VOLUME ["$SPRING_HOME/config", "$SPRING_HOME/logs"] + +USER ${user} +WORKDIR $SPRING_HOME + +COPY . . + +RUN mvn clean package \ + && mv target/midjourney-proxy-*.jar ./app.jar \ + && rm -rf target + +EXPOSE 8080 9876 + +ENV JAVA_OPTS -XX:MaxRAMPercentage=85 -Djava.awt.headless=true -XX:+HeapDumpOnOutOfMemoryError \ + -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -Xlog:gc:file=/home/spring/logs/gc.log \ + -Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.port=9876 -Dcom.sun.management.jmxremote.ssl=false \ + -Dcom.sun.management.jmxremote.authenticate=false -Dlogging.file.path=/home/spring/logs \ + -Dserver.port=8080 -Duser.timezone=Asia/Shanghai + +ENTRYPOINT ["bash","-c","java $JAVA_OPTS -jar app.jar"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d6112a611ad9cf345b44777080c9f54cb02caa54 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,35 @@ +FROM openjdk:17.0 + +ARG user=spring +ARG group=spring + +ENV SPRING_HOME=/home/spring +ENV APP_HOME=$SPRING_HOME/app + +ENV JAVA_OPTS -XX:MaxRAMPercentage=85 -Djava.awt.headless=true -XX:+HeapDumpOnOutOfMemoryError \ + -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -Xlog:gc:file=/home/spring/logs/gc.log \ + -Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.port=9876 -Dcom.sun.management.jmxremote.ssl=false \ + -Dcom.sun.management.jmxremote.authenticate=false -Dlogging.file.path=/home/spring/logs \ + -Dserver.port=8080 -Duser.timezone=Asia/Shanghai + +RUN groupadd -g 1000 ${group} \ + && useradd -d "$SPRING_HOME" -u 1000 -g 1000 -m -s /bin/bash ${user} \ + && mkdir -p $SPRING_HOME/config \ + && mkdir -p $SPRING_HOME/logs \ + && mkdir -p $APP_HOME \ + && chown -R ${user}:${group} $SPRING_HOME/config $SPRING_HOME/logs $APP_HOME + +VOLUME ["$SPRING_HOME/config", "$SPRING_HOME/logs"] + +USER ${user} + +WORKDIR $SPRING_HOME + +EXPOSE 8080 9876 + +ENTRYPOINT ["bash","-c","java $JAVA_OPTS -cp ./app org.springframework.boot.loader.JarLauncher"] + +COPY --chown=${user}:${group} dependencies $APP_HOME/ +COPY --chown=${user}:${group} spring-boot-loader $APP_HOME/ +COPY --chown=${user}:${group} snapshot-dependencies $APP_HOME/ +COPY --chown=${user}:${group} application $APP_HOME/ diff --git a/docker/build-image.sh b/docker/build-image.sh new file mode 100644 index 0000000000000000000000000000000000000000..2499e9a1d634fd8db11293fe981da064f638edaa --- /dev/null +++ b/docker/build-image.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e -u -o pipefail + +if [ $# -lt 1 ]; then + echo 'version is required' + exit 1 +fi + +VERSION=$1 +ARCH=amd64 + +if [ $# -ge 2 ]; then + ARCH=$2 +fi + +JAR_FILE_COUNT=$(find "../target/" -maxdepth 1 -name '*.jar' | wc -l) +if [ $JAR_FILE_COUNT == 0 ]; then + echo "jar file not found, please execute: mvn clean package" + exit 1 +fi + +JAR_FILE_NAME=$(ls ../target/*.jar|grep -v source) +echo ${JAR_FILE_NAME} + +cp ${JAR_FILE_NAME} ./app.jar + +java -Djarmode=layertools -jar app.jar extract + +docker build . -t midjourney-proxy:${VERSION} + +rm -rf application dependencies snapshot-dependencies spring-boot-loader app.jar + +docker tag midjourney-proxy:${VERSION} novicezk/midjourney-proxy-${ARCH}:${VERSION} +docker push novicezk/midjourney-proxy-${ARCH}:${VERSION} \ No newline at end of file diff --git a/docker/build-manifest.sh b/docker/build-manifest.sh new file mode 100644 index 0000000000000000000000000000000000000000..ce04849e83b20ff2570a795cbe7fdec1be300876 --- /dev/null +++ b/docker/build-manifest.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e -u -o pipefail + +if [ $# -lt 1 ]; then + echo 'version is required' + exit 1 +fi + +VERSION=$1 + +echo "create manifest..." +docker manifest create novicezk/midjourney-proxy:${VERSION} novicezk/midjourney-proxy-amd64:${VERSION} novicezk/midjourney-proxy-arm64v8:${VERSION} + +echo "annotate amd64..." +docker manifest annotate novicezk/midjourney-proxy:${VERSION} novicezk/midjourney-proxy-amd64:${VERSION} --os linux --arch amd64 + +echo "annotate arm64v8..." +docker manifest annotate novicezk/midjourney-proxy:${VERSION} novicezk/midjourney-proxy-arm64v8:${VERSION} --os linux --arch arm64 --variant v8 + +echo "push manifest..." +docker manifest push novicezk/midjourney-proxy:${VERSION} \ No newline at end of file diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..a22e3173043f53c365c46a2405e122c6f06081e6 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,124 @@ +# API接口说明 + +`http://ip:port/mj` 已有api文档,此处仅作补充 + +## 1. 数据结构 + +### 任务 +| 字段 | 类型 | 示例 | 描述 | +|:-----:|:----:|:----|:----| +| id | string | 1689231405853400 | 任务ID | +| action | string | IMAGINE | 任务类型: IMAGINE(绘图)、UPSCALE(选中放大)、VARIATION(选中变换)、REROLL(重新执行)、DESCRIBE(图生文)、BLEAND(图片混合) | +| status | string | SUCCESS | 任务状态: NOT_START(未启动)、SUBMITTED(已提交处理)、IN_PROGRESS(执行中)、FAILURE(失败)、SUCCESS(成功) | +| prompt | string | 猫猫 | 提示词 | +| promptEn | string | Cat | 英文提示词 | +| description | string | /imagine 猫猫 | 任务描述 | +| submitTime | number | 1689231405854 | 提交时间 | +| startTime | number | 1689231442755 | 开始执行时间 | +| finishTime | number | 1689231544312 | 结束时间 | +| progress | string | 100% | 任务进度 | +| imageUrl | string | https://cdn.discordapp.com/attachments/xxx/xxx/xxxx.png | 生成图片的url, 成功或执行中时有值,可能为png或webp | +| failReason | string | [Invalid parameter] Invalid value | 失败原因, 失败时有值 | +| properties | object | {"finalPrompt": "Cat"} | 任务的扩展属性,系统内部使用 | + + +## 2. 任务提交返回 +- code=1: 提交成功,result为任务ID + ```json + { + "code": 1, + "description": "成功", + "result": "8498455807619990" + } + ``` +- code=21: 任务已存在,U时可能发生 + ```json + { + "code": 21, + "description": "任务已存在", + "result": "0741798445574458", + "properties": { + "status": "SUCCESS", + "imageUrl": "https://xxxx" + } + } + ``` +- code=22: 提交成功,进入队列等待 + ```json + { + "code": 22, + "description": "排队中,前面还有1个任务", + "result": "0741798445574458", + "properties": { + "numberOfQueues": 1 + } + } + ``` +- code=24: prompt包含敏感词 + ```json + { + "code": 24, + "description": "可能包含敏感词", + "properties": { + "promptEn": "nude body", + "bannedWord": "nude" + } + } + ``` +- other: 提交错误,description为错误描述 + +## 3. `/mj/submit/simple-change` 绘图变化-simple +接口作用同 `/mj/submit/change`(绘图变化),传参方式不同,该接口接收content,格式为`ID 操作`,例如:1320098173412546 U2 + +- 放大 U1~U4 +- 变换 V1~V4 +- 重新执行 R + +## 4. `/mj/submit/describe` 图生文 +```json +{ + // 图片的base64字符串 + "base64": "data:image/png;base64,xxx" +} +``` + +后续任务完成后,properties中finalPrompt即为图片生成的prompt +```json +{ + "id":"14001929738841620", + "action":"DESCRIBE", + "status": "SUCCESS", + "description":"/describe 14001929738841620.png", + "imageUrl":"https://cdn.discordapp.com/attachments/xxx/xxx/14001929738841620.png", + "properties": { + "finalPrompt": "1️⃣ Cat --ar 5:4\n\n2️⃣ Cat2 --ar 5:4\n\n3️⃣ Cat3 --ar 5:4\n\n4️⃣ Cat4 --ar 5:4" + } + // ... +} +``` + +## 5. 任务变更回调 +任务状态变化或进度改变时,会调用业务系统的接口 +- 接口地址为配置的 mj.notify-hook,任务提交时支持传`notifyHook`以改变此任务的回调地址 +- 两者都为空时,不触发回调 + +POST application/json +```json +{ + "id": "14001929738841620", + "action": "IMAGINE", + "status": "SUCCESS", + "prompt": "猫猫", + "promptEn": "Cat", + "description": "/imagine 猫猫", + "submitTime": 1689231405854, + "startTime": 1689231442755, + "finishTime": 1689231544312, + "progress": "100%", + "imageUrl": "https://cdn.discordapp.com/attachments/xxx/xxx/xxxx.png", + "failReason": null, + "properties": { + "finalPrompt": "Cat" + } +} +``` diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 0000000000000000000000000000000000000000..89b386ee91016f60834f3a1dca4fdd8d8278fb19 --- /dev/null +++ b/docs/config.md @@ -0,0 +1,41 @@ +## 配置项 + +| 变量名 | 非空 | 描述 | +| :-----| :----: | :---- | +| mj.discord.guild-id | 是 | discord服务器ID | +| mj.discord.channel-id | 是 | discord频道ID | +| mj.discord.user-token | 是 | discord用户Token | +| mj.discord.session-id | 否 | discord用户SessionId,建议从interactions请求中复制替换掉 | +| mj.discord.user-agent | 否 | 调用discord接口、连接wss时的user-agent,建议从浏览器network复制 | +| mj.api-secret | 否 | 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret | +| mj.notify-hook | 否 | 全局的任务状态变更回调地址 | +| mj.notify-notify-pool-size | 否 | 通知回调线程池大小,默认10 | +| mj.task-store.type | 否 | 任务存储方式,默认in_memory(内存\重启后丢失),可选redis | +| mj.task-store.timeout | 否 | 任务过期时间,过期后删除,默认30天 | +| mj.queue.core-size | 否 | 并发数,默认为3 | +| mj.queue.queue-size | 否 | 等待队列,默认长度10 | +| mj.queue.timeout-minutes | 否 | 任务超时时间,默认为5分钟 | +| mj.proxy.host | 否 | 代理host,全局代理不生效时设置 | +| mj.proxy.port | 否 | 代理port,全局代理不生效时设置 | +| mj.ng-discord.server | 否 | https://discord.com 反代地址 | +| mj.ng-discord.cdn | 否 | https://cdn.discordapp.com 反代地址 | +| mj.ng-discord.wss | 否 | wss://gateway.discord.gg 反代地址 | +| mj.translate-way | 否 | 中文prompt翻译成英文的方式,可选null(默认)、baidu、gpt | +| mj.baidu-translate.appid | 否 | 百度翻译的appid | +| mj.baidu-translate.app-secret | 否 | 百度翻译的app-secret | +| mj.openai.gpt-api-url | 否 | 自定义gpt的接口地址,默认不需要配置 | +| mj.openai.gpt-api-key | 否 | gpt的api-key | +| mj.openai.timeout | 否 | openai调用的超时时间,默认30秒 | +| mj.openai.model | 否 | openai的模型,默认gpt-3.5-turbo | +| mj.openai.max-tokens | 否 | 返回结果的最大分词数,默认2048 | +| mj.openai.temperature | 否 | 相似度(0-2.0),默认0 | +| spring.redis | 否 | 任务存储方式设置为redis,需配置redis相关属性 | + +### spring.redis配置参考 +```yaml +spring: + redis: + host: 10.107.xxx.xxx + port: 6379 + password: xxx +``` \ No newline at end of file diff --git a/docs/discord-params.md b/docs/discord-params.md new file mode 100644 index 0000000000000000000000000000000000000000..5861845767e0328f8d9155645a2ba8d733259c67 --- /dev/null +++ b/docs/discord-params.md @@ -0,0 +1,16 @@ +## 获取discord配置参数 + +### 1. 获取用户Token +进入频道,打开network,刷新页面,找到 `messages` 的请求,这里的 authorization 即用户Token,后续设置到 `mj.discord.user-token` + +![User Token](img_8.png) + +### 2. 获取用户sessionId +进入频道,打开network,发送/imagine作图指令,找到 `interactions` 的请求,这里的 session_id 即用户sessionId,后续设置到 `mj.discord.session-id` + +![User Session](params_session_id.png) + +### 3. 获取服务器ID、频道ID + +频道的url里取出 服务器ID、频道ID,后续设置到配置项 +![Guild Channel ID](img_9.png) diff --git a/docs/docker-start.md b/docs/docker-start.md new file mode 100644 index 0000000000000000000000000000000000000000..6292c57242d784b52ad96f374af73426527d63fd --- /dev/null +++ b/docs/docker-start.md @@ -0,0 +1,21 @@ +## Docker 部署教程 + +1. /xxx/xxx/config目录下创建 application.yml(mj配置项)、banned-words.txt(可选,覆盖默认的敏感词文件);参考src/main/resources下的文件 +2. 启动容器,映射config目录 +```shell +docker run -d --name midjourney-proxy \ + -p 8080:8080 \ + -v /xxx/xxx/config:/home/spring/config \ + novicezk/midjourney-proxy:2.4 +``` +3. 访问 `http://ip:port/mj` 查看API文档 + +附: 不映射config目录方式,直接在启动命令中设置参数 +```shell +docker run -d --name midjourney-proxy \ + -p 8080:8080 \ + -e mj.discord.guild-id=xxx \ + -e mj.discord.channel-id=xxx \ + -e mj.discord.user-token=xxx \ + novicezk/midjourney-proxy:2.4 +``` diff --git a/docs/img_1.png b/docs/img_1.png new file mode 100644 index 0000000000000000000000000000000000000000..ef462a6d28c492342d017e86ec65b25d359f4852 Binary files /dev/null and b/docs/img_1.png differ diff --git a/docs/img_10.png b/docs/img_10.png new file mode 100644 index 0000000000000000000000000000000000000000..4e6185d5059d31a9e7aafbc61230b9b6388e32ec Binary files /dev/null and b/docs/img_10.png differ diff --git a/docs/img_2.png b/docs/img_2.png new file mode 100644 index 0000000000000000000000000000000000000000..f4ee6ccd7352e4d3849cf1553a41070ad1be6d9d Binary files /dev/null and b/docs/img_2.png differ diff --git a/docs/img_3.png b/docs/img_3.png new file mode 100644 index 0000000000000000000000000000000000000000..19772033d8154527dfdc973c55c6670ec6f1ab54 Binary files /dev/null and b/docs/img_3.png differ diff --git a/docs/img_4.png b/docs/img_4.png new file mode 100644 index 0000000000000000000000000000000000000000..85658e256ff09c993187c4aec357d700166dc08d Binary files /dev/null and b/docs/img_4.png differ diff --git a/docs/img_5.png b/docs/img_5.png new file mode 100644 index 0000000000000000000000000000000000000000..6a5bfdaaa5ba38bcef04003d00ddb09dd93a8380 Binary files /dev/null and b/docs/img_5.png differ diff --git a/docs/img_6.png b/docs/img_6.png new file mode 100644 index 0000000000000000000000000000000000000000..fc0b74d95603c48d4c49dd3fb17003bfd8e4473d Binary files /dev/null and b/docs/img_6.png differ diff --git a/docs/img_7.png b/docs/img_7.png new file mode 100644 index 0000000000000000000000000000000000000000..a14683ae7f37a0bcfb549bff9fed1ee4495a0a8c Binary files /dev/null and b/docs/img_7.png differ diff --git a/docs/img_8.png b/docs/img_8.png new file mode 100644 index 0000000000000000000000000000000000000000..33127d2842fc76b9d70ae764e70041a5cbe65053 Binary files /dev/null and b/docs/img_8.png differ diff --git a/docs/img_9.png b/docs/img_9.png new file mode 100644 index 0000000000000000000000000000000000000000..1e7bf5be281a5a58abc1303cd465c895a37d5b9f Binary files /dev/null and b/docs/img_9.png differ diff --git a/docs/manager-qrcode.png b/docs/manager-qrcode.png new file mode 100644 index 0000000000000000000000000000000000000000..2594f9a405a3219498f0cbfe1639cda670bdac4a Binary files /dev/null and b/docs/manager-qrcode.png differ diff --git a/docs/params_session_id.png b/docs/params_session_id.png new file mode 100644 index 0000000000000000000000000000000000000000..7ae18471a4542fde1c3de82cde3caa460395981c Binary files /dev/null and b/docs/params_session_id.png differ diff --git a/docs/params_user.png b/docs/params_user.png new file mode 100644 index 0000000000000000000000000000000000000000..3603b199e8bcf16302362fce66658970853ddcd9 Binary files /dev/null and b/docs/params_user.png differ diff --git a/docs/railway-start.md b/docs/railway-start.md new file mode 100644 index 0000000000000000000000000000000000000000..630f6131b6fb886d89462c28876aa4fc8a0e2ac2 --- /dev/null +++ b/docs/railway-start.md @@ -0,0 +1,36 @@ +## Railway 部署教程 + +Railway是一个提供弹性部署方案的平台,服务器在海外,方便MidJourney的调用。 + +**Railway 提供 5 美元,500 个小时/月的免费额度** + +### 1. Fork本仓库 +### 2. Railway使用github账号登录 +进入 [railway官网](https://railway.app) 选择 `Login` -> `Github`,登录github账号 + +### 3. [New Project](https://railway.app/new) 添加对fork仓库的授权 +![railway_img_1](./railway_img_1.png) +![railway_img_2](./railway_img_2.png) +![railway_img_3](./railway_img_3.png) + +### 4. 选择该fork仓库,新建项目,设置环境变量 +![railway_img_4](./railway_img_4.png) +![railway_img_5](./railway_img_5.png) +![railway_img_6](./railway_img_6.png) +![railway_img_7](./railway_img_7.png) +此处配置项参考 [Wiki / 配置项](https://github.com/novicezk/midjourney-proxy/wiki/%E9%85%8D%E7%BD%AE%E9%A1%B9) ,建议配置api密钥启用鉴权,接口调用时需添加请求头 `mj-api-secret` + +### 5. 启动服务 +进入刚才的Project,它应该已经在自动部署了,后续更新配置之后会自动重新部署 +![railway_img_8](./railway_img_8.png) + +若部署启动失败请查看日志,检查配置项 +![railway_img_9](./railway_img_9.png) +![railway_img_10](./railway_img_10.png) + +### 6. 开始使用 +等待部署成功后,生成随机域名 +![railway_img_11](./railway_img_11.png) +![railway_img_12](./railway_img_12.png) + +访问 `https://midjourney-proxy-***.app/mj` diff --git a/docs/railway_img_1.png b/docs/railway_img_1.png new file mode 100644 index 0000000000000000000000000000000000000000..1fa8f0577f20974a8a36e0082c5012f7589dfdf4 Binary files /dev/null and b/docs/railway_img_1.png differ diff --git a/docs/railway_img_10.png b/docs/railway_img_10.png new file mode 100644 index 0000000000000000000000000000000000000000..4bc7fee8221c4e4c00333837083552d26c4db1b4 Binary files /dev/null and b/docs/railway_img_10.png differ diff --git a/docs/railway_img_11.png b/docs/railway_img_11.png new file mode 100644 index 0000000000000000000000000000000000000000..f6864099a023cca4e66dbb56455c33ab3afcf725 Binary files /dev/null and b/docs/railway_img_11.png differ diff --git a/docs/railway_img_12.png b/docs/railway_img_12.png new file mode 100644 index 0000000000000000000000000000000000000000..35d107c76ad19a690ab90d38bcb264d5eca15541 Binary files /dev/null and b/docs/railway_img_12.png differ diff --git a/docs/railway_img_2.png b/docs/railway_img_2.png new file mode 100644 index 0000000000000000000000000000000000000000..41385d6bce71b861ed6934443c021336179fb789 Binary files /dev/null and b/docs/railway_img_2.png differ diff --git a/docs/railway_img_3.png b/docs/railway_img_3.png new file mode 100644 index 0000000000000000000000000000000000000000..4d1bfcbe9b57a28530e8e3d37fa07ac7d6522cc6 Binary files /dev/null and b/docs/railway_img_3.png differ diff --git a/docs/railway_img_4.png b/docs/railway_img_4.png new file mode 100644 index 0000000000000000000000000000000000000000..be1fec5ae230a4f8ec4597d29d773a342979e7f7 Binary files /dev/null and b/docs/railway_img_4.png differ diff --git a/docs/railway_img_5.png b/docs/railway_img_5.png new file mode 100644 index 0000000000000000000000000000000000000000..8a1e88bc278258817794fcc4fa4dc6c09d484878 Binary files /dev/null and b/docs/railway_img_5.png differ diff --git a/docs/railway_img_6.png b/docs/railway_img_6.png new file mode 100644 index 0000000000000000000000000000000000000000..f98c1969330a1e316c028e1693a97f9eb62a3113 Binary files /dev/null and b/docs/railway_img_6.png differ diff --git a/docs/railway_img_7.png b/docs/railway_img_7.png new file mode 100644 index 0000000000000000000000000000000000000000..3359bec9ff9e5ff782025183963803d7fbabc3ed Binary files /dev/null and b/docs/railway_img_7.png differ diff --git a/docs/railway_img_8.png b/docs/railway_img_8.png new file mode 100644 index 0000000000000000000000000000000000000000..51180fdc79b1ac4bcb2299c697f203538e32c035 Binary files /dev/null and b/docs/railway_img_8.png differ diff --git a/docs/railway_img_9.png b/docs/railway_img_9.png new file mode 100644 index 0000000000000000000000000000000000000000..5e07f10ec3cca077507f94909fa9ad5a27868cc0 Binary files /dev/null and b/docs/railway_img_9.png differ diff --git a/docs/receipt-code.png b/docs/receipt-code.png new file mode 100644 index 0000000000000000000000000000000000000000..0746aea2b3837b6bd5e7293bdeaf904d589e292e Binary files /dev/null and b/docs/receipt-code.png differ diff --git a/docs/zeabur-start.md b/docs/zeabur-start.md new file mode 100644 index 0000000000000000000000000000000000000000..26c7b9bb5d47b0402c04c47a6dac50004ffcdd43 --- /dev/null +++ b/docs/zeabur-start.md @@ -0,0 +1,33 @@ +## Zeabur 部署教程 + +### Zeabur 优势 +1. 新注册的 `Github` 账号可能无法使用 `Railway`,但是能用 `Zeabur` +2. 通过 `Railway` 部署的项目会自动生成一个域名,然而因为某些原因,形如 `*.up.railway.app` 的域名在国内无法访问 +3. `Zeabur` 服务器运行在国外,但是其生成的域名 `*.zeabur.app` 没有被污染,国内可直接访问 + +### 开始部署 + +1. 打开网址 https://zeabur.com/zh-CN +2. 点击现在开始 +3. 点击 `Sign in with GitHub` +4. 登陆你的 `Github` 账号 +5. 点击 `Authorize zeabur` 授权 +6. 点击 `创建项目` 并输入一个项目名称,点击 `创建` +7. 点击 `+` 添加服务,选择 `Git-Deploy service from source code in GitHub repository.` +8. 点击 `Configure GitHub` 根据需要选择 `All repositories` 或者 `Only select repositories` +9. 点击 `install`,之后自动跳转,最好再刷新一下页面 +10. 点击 你 fork 的 `midjourney-proxy` 项目 +11. 点击环境变量,点击编辑原始环境变量,添加你需要的环境变量 +12. 关于环境变量,与 `Railway` 稍有不同,需要把 `.` 和 `-` 全部换成 `_`,例如如下格式 + ```properties + PORT=8080 + mj_discord_guild_id=xxx + mj_discord_channel_id=xxx + mj_discord_user_token=xxx + mj_api_secret=*** + ``` + 此处配置项参考 [Wiki / 配置项](https://github.com/novicezk/midjourney-proxy/wiki/%E9%85%8D%E7%BD%AE%E9%A1%B9) ,建议配置api密钥启用鉴权,接口调用时需添加请求头 `mj-api-secret` +13. 然后取消 `Building`,点击 `Redeploy` (此做法是为了让环境变量生效) +14. 部署 `midjourney-proxy` 大概需要 `2` 分钟,此时你可以做的是:配置域名 +15. 点击下方的域名,点击生成域名,输入前缀,例如 `midjourney-proxy-demo`,点击保存;或者添加自定义域名,之后加上 `CNAME` 解析 +16. 等待部署成功,访问 `https://midjourney-proxy-demo.zeabur.app/mj` \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..e3f9ae787dc29e822cdf58101c17b891cf22b2c7 --- /dev/null +++ b/pom.xml @@ -0,0 +1,125 @@ + + + 4.0.0 + + + org.springframework.boot + spring-boot-starter-parent + 2.6.14 + + + com.github.novicezk + midjourney-proxy + 2.4 + + + 5.8.18 + 20220924 + 5.0.0-beta.9 + 1.0.14-beta1 + 2.0.0 + 4.1.0 + 1.21 + 4.5.14 + 17 + ${java.version} + ${java.version} + + + + + org.springframework.boot + spring-boot-starter-web + + + org.springframework.boot + spring-boot-starter-data-redis + + + + cn.hutool + hutool-core + ${hutool.version} + + + cn.hutool + hutool-cache + ${hutool.version} + + + cn.hutool + hutool-crypto + ${hutool.version} + + + org.json + json + ${org-json.version} + + + net.dv8tion + JDA + ${jda.version} + + + club.minnced + opus-java + + + + + com.unfbx + chatgpt-java + ${chatgpt-java.version} + + + slf4j-simple + org.slf4j + + + + + eu.maxschuster + dataurl + ${dataurl.version} + + + com.github.xiaoymin + knife4j-openapi2-spring-boot-starter + ${knife4j.verison} + + + eu.bitwalker + UserAgentUtils + ${user-agent-utils.verison} + + + org.apache.httpcomponents + httpclient + ${httpclient.verison} + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + org.projectlombok + lombok + true + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + + diff --git a/src/main/java/com/github/novicezk/midjourney/Constants.java b/src/main/java/com/github/novicezk/midjourney/Constants.java new file mode 100644 index 0000000000000000000000000000000000000000..bb8b7fa6c79e9ed43b2330e87b7092b9e467ace6 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/Constants.java @@ -0,0 +1,19 @@ +package com.github.novicezk.midjourney; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class Constants { + // 任务扩展属性 start + public static final String TASK_PROPERTY_NOTIFY_HOOK = "notifyHook"; + public static final String TASK_PROPERTY_FINAL_PROMPT = "finalPrompt"; + public static final String TASK_PROPERTY_MESSAGE_ID = "messageId"; + public static final String TASK_PROPERTY_MESSAGE_HASH = "messageHash"; + + public static final String TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId"; + public static final String TASK_PROPERTY_FLAGS = "flags"; + public static final String TASK_PROPERTY_NONCE = "nonce"; + // 任务扩展属性 end + + public static final String API_SECRET_HEADER_NAME = "mj-api-secret"; +} diff --git a/src/main/java/com/github/novicezk/midjourney/ProxyApplication.java b/src/main/java/com/github/novicezk/midjourney/ProxyApplication.java new file mode 100644 index 0000000000000000000000000000000000000000..f46750039f3f633d4486bd517aeaba1cb0c9d80f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/ProxyApplication.java @@ -0,0 +1,19 @@ +package com.github.novicezk.midjourney; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Import; +import org.springframework.scheduling.annotation.EnableScheduling; +import spring.config.BeanConfig; +import spring.config.WebMvcConfig; + +@EnableScheduling +@SpringBootApplication +@Import({BeanConfig.class, WebMvcConfig.class}) +public class ProxyApplication { + + public static void main(String[] args) { + SpringApplication.run(ProxyApplication.class, args); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/ProxyProperties.java b/src/main/java/com/github/novicezk/midjourney/ProxyProperties.java new file mode 100644 index 0000000000000000000000000000000000000000..6969b647bf637d9a921309c9ac2a3d1ae126deb7 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/ProxyProperties.java @@ -0,0 +1,201 @@ +package com.github.novicezk.midjourney; + +import com.github.novicezk.midjourney.enums.TranslateWay; +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +import java.time.Duration; + +@Data +@Component +@ConfigurationProperties(prefix = "mj") +public class ProxyProperties { + /** + * task存储配置. + */ + private final TaskStore taskStore = new TaskStore(); + /** + * discord配置. + */ + private final DiscordConfig discord = new DiscordConfig(); + /** + * 代理配置. + */ + private final ProxyConfig proxy = new ProxyConfig(); + /** + * 反代配置. + */ + private final NgDiscordConfig ngDiscord = new NgDiscordConfig(); + /** + * 任务队列配置. + */ + private final TaskQueueConfig queue = new TaskQueueConfig(); + /** + * 百度翻译配置. + */ + private final BaiduTranslateConfig baiduTranslate = new BaiduTranslateConfig(); + /** + * openai配置. + */ + private final OpenaiConfig openai = new OpenaiConfig(); + /** + * 中文prompt翻译方式. + */ + private TranslateWay translateWay = TranslateWay.NULL; + /** + * 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret. + */ + private String apiSecret; + /** + * 任务状态变更回调地址. + */ + private String notifyHook; + /** + * 通知回调线程池大小. + */ + private int notifyPoolSize = 10; + /** + * 接口是否返回任务扩展属性. + */ + private boolean includeTaskExtended = false; + + @Data + public static class DiscordConfig { + /** + * 你的服务器id. + */ + private String guildId; + /** + * 你的频道id. + */ + private String channelId; + /** + * 你的登录token. + */ + private String userToken; + /** + * 你的频道id. + */ + private String sessionId = "9c4055428e13bcbf2248a6b36084c5f3"; + /** + * 调用discord接口、连接wss时的user-agent. + */ + private String userAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36"; + /** + * 是否使用user_token连接wss,默认启用. + */ + private boolean userWss = true; + /** + * 你的机器人token. + */ + private String botToken; + } + + @Data + public static class BaiduTranslateConfig { + /** + * 百度翻译的APP_ID. + */ + private String appid; + /** + * 百度翻译的密钥. + */ + private String appSecret; + } + + @Data + public static class OpenaiConfig { + /** + * 自定义gpt的api-url. + */ + private String gptApiUrl; + /** + * gpt的api-key. + */ + private String gptApiKey; + /** + * 超时时间. + */ + private Duration timeout = Duration.ofSeconds(30); + /** + * 使用的模型. + */ + private String model = "gpt-3.5-turbo"; + /** + * 返回结果的最大分词数. + */ + private int maxTokens = 2048; + /** + * 相似度,取值 0-2. + */ + private double temperature = 0; + } + + @Data + public static class TaskStore { + /** + * 任务过期时间,默认30天. + */ + private Duration timeout = Duration.ofDays(30); + /** + * 任务存储方式: redis(默认)、in_memory. + */ + private Type type = Type.IN_MEMORY; + + public enum Type { + /** + * redis. + */ + REDIS, + /** + * in_memory. + */ + IN_MEMORY + } + } + + @Data + public static class ProxyConfig { + /** + * 代理host. + */ + private String host; + /** + * 代理端口. + */ + private Integer port; + } + + @Data + public static class NgDiscordConfig { + /** + * https://discord.com 反代. + */ + private String server; + /** + * https://cdn.discordapp.com 反代. + */ + private String cdn; + /** + * wss://gateway.discord.gg 反代. + */ + private String wss; + } + + @Data + public static class TaskQueueConfig { + /** + * 并发数. + */ + private int coreSize = 3; + /** + * 等待队列长度. + */ + private int queueSize = 10; + /** + * 任务超时时间(分钟). + */ + private int timeoutMinutes = 5; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/ReturnCode.java b/src/main/java/com/github/novicezk/midjourney/ReturnCode.java new file mode 100644 index 0000000000000000000000000000000000000000..ec60264215181d9aaea3c4781d0cd7c1ccf1ea23 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/ReturnCode.java @@ -0,0 +1,42 @@ +package com.github.novicezk.midjourney; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class ReturnCode { + /** + * 成功. + */ + public static final int SUCCESS = 1; + /** + * 数据未找到. + */ + public static final int NOT_FOUND = 3; + /** + * 校验错误. + */ + public static final int VALIDATION_ERROR = 4; + /** + * 系统异常. + */ + public static final int FAILURE = 9; + + /** + * 已存在. + */ + public static final int EXISTED = 21; + /** + * 排队中. + */ + public static final int IN_QUEUE = 22; + /** + * 队列已满. + */ + public static final int QUEUE_REJECTED = 23; + /** + * prompt包含敏感词. + */ + public static final int BANNED_PROMPT = 24; + + +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/controller/SubmitController.java b/src/main/java/com/github/novicezk/midjourney/controller/SubmitController.java new file mode 100644 index 0000000000000000000000000000000000000000..8009fe72e0900214de99be0dfb34b0031c735a8c --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/controller/SubmitController.java @@ -0,0 +1,227 @@ +package com.github.novicezk.midjourney.controller; + +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.RandomUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.dto.BaseSubmitDTO; +import com.github.novicezk.midjourney.dto.SubmitBlendDTO; +import com.github.novicezk.midjourney.dto.SubmitChangeDTO; +import com.github.novicezk.midjourney.dto.SubmitDescribeDTO; +import com.github.novicezk.midjourney.dto.SubmitImagineDTO; +import com.github.novicezk.midjourney.dto.SubmitSimpleChangeDTO; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.exception.BannedPromptException; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.service.TaskService; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.service.TranslateService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.BannedPromptUtils; +import com.github.novicezk.midjourney.util.ConvertUtils; +import com.github.novicezk.midjourney.util.MimeTypeUtils; +import com.github.novicezk.midjourney.util.SnowFlake; +import com.github.novicezk.midjourney.util.TaskChangeParams; +import eu.maxschuster.dataurl.DataUrl; +import eu.maxschuster.dataurl.DataUrlSerializer; +import eu.maxschuster.dataurl.IDataUrlSerializer; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.net.MalformedURLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +@Api(tags = "任务提交") +@RestController +@RequestMapping("/submit") +@RequiredArgsConstructor +public class SubmitController { + private final TranslateService translateService; + private final TaskStoreService taskStoreService; + private final ProxyProperties properties; + private final TaskService taskService; + + @ApiOperation(value = "提交Imagine任务") + @PostMapping("/imagine") + public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) { + String prompt = imagineDTO.getPrompt(); + if (CharSequenceUtil.isBlank(prompt)) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空"); + } + prompt = prompt.trim(); + Task task = newTask(imagineDTO); + task.setAction(TaskAction.IMAGINE); + task.setPrompt(prompt); + String promptEn = translatePrompt(prompt); + try { + BannedPromptUtils.checkBanned(promptEn); + } catch (BannedPromptException e) { + return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词") + .setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage()); + } + List base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>()); + if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) { + base64Array.add(imagineDTO.getBase64()); + } + List dataUrls; + try { + dataUrls = ConvertUtils.convertBase64Array(base64Array); + } catch (MalformedURLException e) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"); + } + task.setPromptEn(promptEn); + task.setDescription("/imagine " + prompt); + return this.taskService.submitImagine(task, dataUrls); + } + + @ApiOperation(value = "绘图变化-simple") + @PostMapping("/simple-change") + public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) { + TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent()); + if (changeParams == null) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误"); + } + SubmitChangeDTO changeDTO = new SubmitChangeDTO(); + changeDTO.setAction(changeParams.getAction()); + changeDTO.setTaskId(changeParams.getId()); + changeDTO.setIndex(changeParams.getIndex()); + changeDTO.setState(simpleChangeDTO.getState()); + changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook()); + return change(changeDTO); + } + + @ApiOperation(value = "绘图变化") + @PostMapping("/change") + public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) { + if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空"); + } + if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误"); + } + String description = "/up " + changeDTO.getTaskId(); + if (TaskAction.REROLL.equals(changeDTO.getAction())) { + description += " R"; + } else { + description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex(); + } + if (TaskAction.UPSCALE.equals(changeDTO.getAction())) { + TaskCondition condition = new TaskCondition().setDescription(description); + Task existTask = this.taskStoreService.findOne(condition); + if (existTask != null) { + return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId()) + .setProperty("status", existTask.getStatus()) + .setProperty("imageUrl", existTask.getImageUrl()); + } + } + Task targetTask = this.taskStoreService.get(changeDTO.getTaskId()); + if (targetTask == null) { + return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效"); + } + if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误"); + } + if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化"); + } + Task task = newTask(changeDTO); + task.setAction(changeDTO.getAction()); + task.setPrompt(targetTask.getPrompt()); + task.setPromptEn(targetTask.getPromptEn()); + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT)); + task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID)); + task.setDescription(description); + int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS); + String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID); + String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH); + if (TaskAction.UPSCALE.equals(changeDTO.getAction())) { + return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags); + } else if (TaskAction.VARIATION.equals(changeDTO.getAction())) { + return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags); + } else { + return this.taskService.submitReroll(task, messageId, messageHash, messageFlags); + } + } + + @ApiOperation(value = "提交Describe任务") + @PostMapping("/describe") + public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) { + if (CharSequenceUtil.isBlank(describeDTO.getBase64())) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空"); + } + IDataUrlSerializer serializer = new DataUrlSerializer(); + DataUrl dataUrl; + try { + dataUrl = serializer.unserialize(describeDTO.getBase64()); + } catch (MalformedURLException e) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"); + } + Task task = newTask(describeDTO); + task.setAction(TaskAction.DESCRIBE); + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + task.setDescription("/describe " + taskFileName); + return this.taskService.submitDescribe(task, dataUrl); + } + + @ApiOperation(value = "提交Blend任务") + @PostMapping("/blend") + public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) { + List base64Array = blendDTO.getBase64Array(); + if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误"); + } + if (blendDTO.getDimensions() == null) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误"); + } + IDataUrlSerializer serializer = new DataUrlSerializer(); + List dataUrlList = new ArrayList<>(); + try { + for (String base64 : base64Array) { + DataUrl dataUrl = serializer.unserialize(base64); + dataUrlList.add(dataUrl); + } + } catch (MalformedURLException e) { + return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误"); + } + Task task = newTask(blendDTO); + task.setAction(TaskAction.BLEND); + task.setDescription("/blend " + task.getId() + " " + dataUrlList.size()); + return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions()); + } + + private Task newTask(BaseSubmitDTO base) { + Task task = new Task(); + task.setId(System.currentTimeMillis() + "" + RandomUtil.randomNumbers(3)); + task.setSubmitTime(System.currentTimeMillis()); + task.setState(base.getState()); + String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook(); + task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook); + task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId()); + return task; + } + + private String translatePrompt(String prompt) { + String promptEn; + int paramStart = prompt.indexOf(" --"); + if (paramStart > 0) { + promptEn = this.translateService.translateToEnglish(prompt.substring(0, paramStart)).trim() + prompt.substring(paramStart); + } else { + promptEn = this.translateService.translateToEnglish(prompt).trim(); + } + if (CharSequenceUtil.isBlank(promptEn)) { + promptEn = prompt; + } + return promptEn; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/controller/TaskController.java b/src/main/java/com/github/novicezk/midjourney/controller/TaskController.java new file mode 100644 index 0000000000000000000000000000000000000000..c4f5b37e7fcb48ddbbff75a181bcbaa304601c92 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/controller/TaskController.java @@ -0,0 +1,65 @@ +package com.github.novicezk.midjourney.controller; + +import cn.hutool.core.comparator.CompareUtil; +import com.github.novicezk.midjourney.dto.TaskConditionDTO; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskQueueHelper; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import io.swagger.annotations.ApiParam; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +@Api(tags = "任务查询") +@RestController +@RequestMapping("/task") +@RequiredArgsConstructor +public class TaskController { + private final TaskStoreService taskStoreService; + private final TaskQueueHelper taskQueueHelper; + + @ApiOperation(value = "查询所有任务") + @GetMapping("/list") + public List list() { + return this.taskStoreService.list().stream() + .sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime())) + .toList(); + } + + @ApiOperation(value = "指定ID获取任务") + @GetMapping("/{id}/fetch") + public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) { + return this.taskStoreService.get(id); + } + + @ApiOperation(value = "查询任务队列") + @GetMapping("/queue") + public List queue() { + Set queueTaskIds = this.taskQueueHelper.getQueueTaskIds(); + return queueTaskIds.stream().map(this.taskStoreService::get).filter(Objects::nonNull) + .sorted(Comparator.comparing(Task::getSubmitTime)) + .toList(); + } + + @ApiOperation(value = "根据条件查询任务") + @PostMapping("/list-by-condition") + public List listByCondition(@RequestBody TaskConditionDTO conditionDTO) { + if (conditionDTO.getIds() == null) { + return Collections.emptyList(); + } + return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/BaseSubmitDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/BaseSubmitDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..50a8a4f63a3a8966e30dd1c47c30030fa0348b47 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/BaseSubmitDTO.java @@ -0,0 +1,16 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public abstract class BaseSubmitDTO { + + @ApiModelProperty("自定义参数") + protected String state; + + @ApiModelProperty("回调地址, 为空时使用全局notifyHook") + protected String notifyHook; +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitBlendDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitBlendDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..dd56aecc92ceec496a0dfb282a22ee331ffcf5bf --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitBlendDTO.java @@ -0,0 +1,21 @@ +package com.github.novicezk.midjourney.dto; + +import com.github.novicezk.midjourney.enums.BlendDimensions; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +import java.util.List; + +@Data +@ApiModel("Blend提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitBlendDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "图片base64数组", required = true, example = "[\"data:image/png;base64,xxx1\", \"data:image/png;base64,xxx2\"]") + private List base64Array; + + @ApiModelProperty(value = "比例: PORTRAIT(2:3); SQUARE(1:1); LANDSCAPE(3:2)", example = "SQUARE") + private BlendDimensions dimensions = BlendDimensions.SQUARE; +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitChangeDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitChangeDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..1b6b493b6a7dd3fab3146937ab30f04c15fbe226 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitChangeDTO.java @@ -0,0 +1,25 @@ +package com.github.novicezk.midjourney.dto; + +import com.github.novicezk.midjourney.enums.TaskAction; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + + +@Data +@ApiModel("变化任务提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitChangeDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "任务ID", required = true, example = "\"1320098173412546\"") + private String taskId; + + @ApiModelProperty(value = "UPSCALE(放大); VARIATION(变换); REROLL(重新生成)", required = true, + allowableValues = "UPSCALE, VARIATION, REROLL", example = "UPSCALE") + private TaskAction action; + + @ApiModelProperty(value = "序号(1~4), action为UPSCALE,VARIATION时必传", allowableValues = "range[1, 4]", example = "1") + private Integer index; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitDescribeDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitDescribeDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..8c34a35fb8b47a2dd1099b5c3028191ecfd77938 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitDescribeDTO.java @@ -0,0 +1,15 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@ApiModel("Describe提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitDescribeDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "图片base64", required = true, example = "data:image/png;base64,xxx") + private String base64; +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitImagineDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitImagineDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..51ff49de114a56fabd9f490286af11b7634573aa --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitImagineDTO.java @@ -0,0 +1,26 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + +import java.util.List; + + +@Data +@ApiModel("Imagine提交参数") +@EqualsAndHashCode(callSuper = true) +public class SubmitImagineDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "提示词", required = true, example = "Cat") + private String prompt; + + @ApiModelProperty(value = "垫图base64数组") + private List base64Array; + + @ApiModelProperty(hidden = true) + @Deprecated(since = "3.0", forRemoval = true) + private String base64; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/SubmitSimpleChangeDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/SubmitSimpleChangeDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..6d5e4537a21413b2455b159ea1e0dd5adb2b1d2b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/SubmitSimpleChangeDTO.java @@ -0,0 +1,17 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.EqualsAndHashCode; + + +@Data +@ApiModel("变化任务提交参数-simple") +@EqualsAndHashCode(callSuper = true) +public class SubmitSimpleChangeDTO extends BaseSubmitDTO { + + @ApiModelProperty(value = "变化描述: ID $action$index", required = true, example = "1320098173412546 U2") + private String content; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/dto/TaskConditionDTO.java b/src/main/java/com/github/novicezk/midjourney/dto/TaskConditionDTO.java new file mode 100644 index 0000000000000000000000000000000000000000..09ae83e2a5b0ff669d9e03f53197b79ef8cccc57 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/dto/TaskConditionDTO.java @@ -0,0 +1,14 @@ +package com.github.novicezk.midjourney.dto; + +import io.swagger.annotations.ApiModel; +import lombok.Data; + +import java.util.List; + +@Data +@ApiModel("任务查询参数") +public class TaskConditionDTO { + + private List ids; + +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/BlendDimensions.java b/src/main/java/com/github/novicezk/midjourney/enums/BlendDimensions.java new file mode 100644 index 0000000000000000000000000000000000000000..ea1c4d6cd71a6672a9e4a507bf00d3024e756104 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/BlendDimensions.java @@ -0,0 +1,21 @@ +package com.github.novicezk.midjourney.enums; + + +public enum BlendDimensions { + + PORTRAIT("2:3"), + + SQUARE("1:1"), + + LANDSCAPE("3:2"); + + private final String value; + + BlendDimensions(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/MessageType.java b/src/main/java/com/github/novicezk/midjourney/enums/MessageType.java new file mode 100644 index 0000000000000000000000000000000000000000..330654635d44338b19a9013e74025828cea0c639 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/MessageType.java @@ -0,0 +1,26 @@ +package com.github.novicezk.midjourney.enums; + + +public enum MessageType { + /** + * 创建. + */ + CREATE, + /** + * 修改. + */ + UPDATE, + /** + * 删除. + */ + DELETE; + + public static MessageType of(String type) { + return switch (type) { + case "MESSAGE_CREATE" -> CREATE; + case "MESSAGE_UPDATE" -> UPDATE; + case "MESSAGE_DELETE" -> DELETE; + default -> null; + }; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/TaskAction.java b/src/main/java/com/github/novicezk/midjourney/enums/TaskAction.java new file mode 100644 index 0000000000000000000000000000000000000000..35811002960728c2b37bcc7243128de60d19900b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/TaskAction.java @@ -0,0 +1,30 @@ +package com.github.novicezk.midjourney.enums; + + +public enum TaskAction { + /** + * 生成图片. + */ + IMAGINE, + /** + * 选中放大. + */ + UPSCALE, + /** + * 选中其中的一张图,生成四张相似的. + */ + VARIATION, + /** + * 重新执行. + */ + REROLL, + /** + * 图转prompt. + */ + DESCRIBE, + /** + * 多图混合. + */ + BLEND + +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/TaskStatus.java b/src/main/java/com/github/novicezk/midjourney/enums/TaskStatus.java new file mode 100644 index 0000000000000000000000000000000000000000..4b4fa5408fe9520f727a9a2bb46c1ee07f50f0ef --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/TaskStatus.java @@ -0,0 +1,26 @@ +package com.github.novicezk.midjourney.enums; + + +public enum TaskStatus { + /** + * 未启动. + */ + NOT_START, + /** + * 已提交. + */ + SUBMITTED, + /** + * 执行中. + */ + IN_PROGRESS, + /** + * 失败. + */ + FAILURE, + /** + * 成功. + */ + SUCCESS + +} diff --git a/src/main/java/com/github/novicezk/midjourney/enums/TranslateWay.java b/src/main/java/com/github/novicezk/midjourney/enums/TranslateWay.java new file mode 100644 index 0000000000000000000000000000000000000000..495297dffca6b9bbdc7c03656d4725f718dbb324 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/enums/TranslateWay.java @@ -0,0 +1,18 @@ +package com.github.novicezk.midjourney.enums; + + +public enum TranslateWay { + /** + * 百度翻译. + */ + BAIDU, + /** + * GPT翻译. + */ + GPT, + /** + * 不翻译. + */ + NULL + +} diff --git a/src/main/java/com/github/novicezk/midjourney/exception/BannedPromptException.java b/src/main/java/com/github/novicezk/midjourney/exception/BannedPromptException.java new file mode 100644 index 0000000000000000000000000000000000000000..fd5bfc5e4a386ea7c36185715d510c4a38ae220b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/exception/BannedPromptException.java @@ -0,0 +1,8 @@ +package com.github.novicezk.midjourney.exception; + +public class BannedPromptException extends Exception { + + public BannedPromptException(String message) { + super(message); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/exception/SnowFlakeException.java b/src/main/java/com/github/novicezk/midjourney/exception/SnowFlakeException.java new file mode 100644 index 0000000000000000000000000000000000000000..648bc0029a3cafc0d1edf43b21cb1987e076af5f --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/exception/SnowFlakeException.java @@ -0,0 +1,16 @@ +package com.github.novicezk.midjourney.exception; + +public class SnowFlakeException extends RuntimeException { + + public SnowFlakeException(String message) { + super(message); + } + + public SnowFlakeException(String message, Throwable cause) { + super(message, cause); + } + + public SnowFlakeException(Throwable cause) { + super(cause); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/result/Message.java b/src/main/java/com/github/novicezk/midjourney/result/Message.java new file mode 100644 index 0000000000000000000000000000000000000000..50868aa882b69eb216e249fc91d7407ee8e62174 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/result/Message.java @@ -0,0 +1,57 @@ +package com.github.novicezk.midjourney.result; + +import com.github.novicezk.midjourney.ReturnCode; +import lombok.Getter; + +@Getter +public class Message { + private final int code; + private final String description; + private final T result; + + public static Message success() { + return new Message<>(ReturnCode.SUCCESS, "成功"); + } + + public static Message success(T result) { + return new Message<>(ReturnCode.SUCCESS, "成功", result); + } + + public static Message success(int code, String description, T result) { + return new Message<>(code, description, result); + } + + public static Message notFound() { + return new Message<>(ReturnCode.NOT_FOUND, "数据未找到"); + } + + public static Message validationError() { + return new Message<>(ReturnCode.VALIDATION_ERROR, "校验错误"); + } + + public static Message failure() { + return new Message<>(ReturnCode.FAILURE, "系统异常"); + } + + public static Message failure(String description) { + return new Message<>(ReturnCode.FAILURE, description); + } + + public static Message of(int code, String description) { + return new Message<>(code, description); + } + + public static Message of(int code, String description, T result) { + return new Message<>(code, description, result); + } + + private Message(int code, String description) { + this(code, description, null); + } + + private Message(int code, String description, T result) { + this.code = code; + this.description = description; + this.result = result; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/result/SubmitResultVO.java b/src/main/java/com/github/novicezk/midjourney/result/SubmitResultVO.java new file mode 100644 index 0000000000000000000000000000000000000000..d6d1ffb43c391ade678a0d297d5808d0dbb2f742 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/result/SubmitResultVO.java @@ -0,0 +1,62 @@ +package com.github.novicezk.midjourney.result; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.HashMap; +import java.util.Map; + +@Data +@ApiModel("提交结果") +public class SubmitResultVO { + + @ApiModelProperty(value = "状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)", required = true, example = "1") + private int code; + + @ApiModelProperty(value = "描述", required = true, example = "提交成功") + private String description; + + @ApiModelProperty(value = "任务ID", example = "1320098173412546") + private String result; + + @ApiModelProperty(value = "扩展字段") + private Map properties = new HashMap<>(); + + public SubmitResultVO setProperty(String name, Object value) { + this.properties.put(name, value); + return this; + } + + public SubmitResultVO removeProperty(String name) { + this.properties.remove(name); + return this; + } + + public Object getProperty(String name) { + return this.properties.get(name); + } + + @SuppressWarnings("unchecked") + public T getPropertyGeneric(String name) { + return (T) getProperty(name); + } + + public T getProperty(String name, Class clz) { + return clz.cast(getProperty(name)); + } + + public static SubmitResultVO of(int code, String description, String result) { + return new SubmitResultVO(code, description, result); + } + + public static SubmitResultVO fail(int code, String description) { + return new SubmitResultVO(code, description, null); + } + + private SubmitResultVO(int code, String description, String result) { + this.code = code; + this.description = description; + this.result = result; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/DiscordService.java b/src/main/java/com/github/novicezk/midjourney/service/DiscordService.java new file mode 100644 index 0000000000000000000000000000000000000000..9ed0059b88624974d7c94a8b9b54121a7fe290bb --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/DiscordService.java @@ -0,0 +1,28 @@ +package com.github.novicezk.midjourney.service; + + +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.Message; +import eu.maxschuster.dataurl.DataUrl; + +import java.util.List; + +public interface DiscordService { + + Message imagine(String prompt, String nonce); + + Message upscale(String messageId, int index, String messageHash, int messageFlags, String nonce); + + Message variation(String messageId, int index, String messageHash, int messageFlags, String nonce); + + Message reroll(String messageId, String messageHash, int messageFlags, String nonce); + + Message describe(String finalFileName, String nonce); + + Message blend(List finalFileNames, BlendDimensions dimensions, String nonce); + + Message upload(String fileName, DataUrl dataUrl); + + Message sendImageMessage(String content, String finalFileName); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/DiscordServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/DiscordServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..b8eb42b3a01d2e3a0a327c8d9716647fd6ef0317 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/DiscordServiceImpl.java @@ -0,0 +1,241 @@ +package com.github.novicezk.midjourney.service; + + +import cn.hutool.core.io.resource.ResourceUtil; +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.Message; +import com.github.novicezk.midjourney.support.DiscordHelper; +import eu.maxschuster.dataurl.DataUrl; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONArray; +import org.json.JSONObject; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Service; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.RestTemplate; + +import javax.annotation.PostConstruct; +import java.util.List; + +@Slf4j +@Service +@RequiredArgsConstructor +public class DiscordServiceImpl implements DiscordService { + private final ProxyProperties properties; + private final DiscordHelper discordHelper; + + private String discordApiUrl; + private String userAgent; + + private String discordUploadUrl; + private String discordSendMessageUrl; + + private String imagineParamsJson; + private String upscaleParamsJson; + private String variationParamsJson; + private String rerollParamsJson; + private String describeParamsJson; + private String blendParamsJson; + private String messageParamsJson; + + private String discordUserToken; + private String discordGuildId; + private String discordChannelId; + private String discordSessionId; + + @PostConstruct + void init() { + ProxyProperties.DiscordConfig discord = this.properties.getDiscord(); + this.discordUserToken = discord.getUserToken(); + this.discordGuildId = discord.getGuildId(); + this.discordChannelId = discord.getChannelId(); + this.discordSessionId = discord.getSessionId(); + this.userAgent = discord.getUserAgent(); + + String serverUrl = this.discordHelper.getServer(); + this.discordApiUrl = serverUrl + "/api/v9/interactions"; + this.discordUploadUrl = serverUrl + "/api/v9/channels/" + this.discordChannelId + "/attachments"; + this.discordSendMessageUrl = serverUrl + "/api/v9/channels/" + this.discordChannelId + "/messages"; + + this.imagineParamsJson = ResourceUtil.readUtf8Str("api-params/imagine.json"); + this.upscaleParamsJson = ResourceUtil.readUtf8Str("api-params/upscale.json"); + this.variationParamsJson = ResourceUtil.readUtf8Str("api-params/variation.json"); + this.rerollParamsJson = ResourceUtil.readUtf8Str("api-params/reroll.json"); + this.describeParamsJson = ResourceUtil.readUtf8Str("api-params/describe.json"); + this.blendParamsJson = ResourceUtil.readUtf8Str("api-params/blend.json"); + this.messageParamsJson = ResourceUtil.readUtf8Str("api-params/message.json"); + } + + @Override + public Message imagine(String prompt, String nonce) { + String paramsStr = replaceInteractionParams(this.imagineParamsJson, nonce); + JSONObject params = new JSONObject(paramsStr); + params.getJSONObject("data").getJSONArray("options").getJSONObject(0) + .put("value", prompt); + return postJsonAndCheckStatus(params.toString()); + } + + @Override + public Message upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) { + String paramsStr = replaceInteractionParams(this.upscaleParamsJson, nonce) + .replace("$message_id", messageId) + .replace("$index", String.valueOf(index)) + .replace("$message_hash", messageHash); + paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString(); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message variation(String messageId, int index, String messageHash, int messageFlags, String nonce) { + String paramsStr = replaceInteractionParams(this.variationParamsJson, nonce) + .replace("$message_id", messageId) + .replace("$index", String.valueOf(index)) + .replace("$message_hash", messageHash); + paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString(); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message reroll(String messageId, String messageHash, int messageFlags, String nonce) { + String paramsStr = replaceInteractionParams(this.rerollParamsJson, nonce) + .replace("$message_id", messageId) + .replace("$message_hash", messageHash); + paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString(); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message describe(String finalFileName, String nonce) { + String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true); + String paramsStr = replaceInteractionParams(this.describeParamsJson, nonce) + .replace("$file_name", fileName) + .replace("$final_file_name", finalFileName); + return postJsonAndCheckStatus(paramsStr); + } + + @Override + public Message blend(List finalFileNames, BlendDimensions dimensions, String nonce) { + String paramsStr = replaceInteractionParams(this.blendParamsJson, nonce); + JSONObject params = new JSONObject(paramsStr); + JSONArray options = params.getJSONObject("data").getJSONArray("options"); + JSONArray attachments = params.getJSONObject("data").getJSONArray("attachments"); + for (int i = 0; i < finalFileNames.size(); i++) { + String finalFileName = finalFileNames.get(i); + String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true); + JSONObject attachment = new JSONObject().put("id", String.valueOf(i)) + .put("filename", fileName) + .put("uploaded_filename", finalFileName); + attachments.put(attachment); + JSONObject option = new JSONObject().put("type", 11) + .put("name", "image" + (i + 1)) + .put("value", i); + options.put(option); + } + options.put(new JSONObject().put("type", 3) + .put("name", "dimensions") + .put("value", "--ar " + dimensions.getValue())); + return postJsonAndCheckStatus(params.toString()); + } + + private String replaceInteractionParams(String paramsStr, String nonce) { + return paramsStr.replace("$guild_id", this.discordGuildId) + .replace("$channel_id", this.discordChannelId) + .replace("$session_id", this.discordSessionId) + .replace("$nonce", nonce); + } + + @Override + public Message upload(String fileName, DataUrl dataUrl) { + try { + JSONObject fileObj = new JSONObject(); + fileObj.put("filename", fileName); + fileObj.put("file_size", dataUrl.getData().length); + fileObj.put("id", "0"); + JSONObject params = new JSONObject() + .put("files", new JSONArray().put(fileObj)); + ResponseEntity responseEntity = postJson(this.discordUploadUrl, params.toString()); + if (responseEntity.getStatusCode() != HttpStatus.OK) { + log.error("上传图片到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody()); + return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败"); + } + JSONArray array = new JSONObject(responseEntity.getBody()).getJSONArray("attachments"); + if (array.length() == 0) { + return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败"); + } + String uploadUrl = array.getJSONObject(0).getString("upload_url"); + String uploadFilename = array.getJSONObject(0).getString("upload_filename"); + putFile(uploadUrl, dataUrl); + return Message.success(uploadFilename); + } catch (Exception e) { + log.error("上传图片到discord失败", e); + return Message.of(ReturnCode.FAILURE, "上传图片到discord失败"); + } + } + + @Override + public Message sendImageMessage(String content, String finalFileName) { + String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true); + String paramsStr = this.messageParamsJson.replace("$content", content) + .replace("$channel_id", this.discordChannelId) + .replace("$file_name", fileName) + .replace("$final_file_name", finalFileName); + ResponseEntity responseEntity = postJson(this.discordSendMessageUrl, paramsStr); + if (responseEntity.getStatusCode() != HttpStatus.OK) { + log.error("发送图片消息到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody()); + return Message.of(ReturnCode.VALIDATION_ERROR, "发送图片消息到discord失败"); + } + JSONObject result = new JSONObject(responseEntity.getBody()); + JSONArray attachments = result.optJSONArray("attachments"); + if (!attachments.isEmpty()) { + return Message.success(attachments.getJSONObject(0).optString("url")); + } + return Message.failure("发送图片消息到discord失败: 图片不存在"); + } + + private void putFile(String uploadUrl, DataUrl dataUrl) { + HttpHeaders headers = new HttpHeaders(); + headers.add("User-Agent", this.userAgent); + headers.setContentType(MediaType.valueOf(dataUrl.getMimeType())); + headers.setContentLength(dataUrl.getData().length); + HttpEntity requestEntity = new HttpEntity<>(dataUrl.getData(), headers); + new RestTemplate().put(uploadUrl, requestEntity); + } + + private ResponseEntity postJson(String paramsStr) { + return postJson(discordApiUrl, paramsStr); + } + + private ResponseEntity postJson(String url, String paramsStr) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("Authorization", this.discordUserToken); + headers.add("User-Agent", this.userAgent); + HttpEntity httpEntity = new HttpEntity<>(paramsStr, headers); + return new RestTemplate().postForEntity(url, httpEntity, String.class); + } + + private Message postJsonAndCheckStatus(String paramsStr) { + try { + ResponseEntity responseEntity = postJson(paramsStr); + if (responseEntity.getStatusCode() == HttpStatus.NO_CONTENT) { + return Message.success(); + } + return Message.of(responseEntity.getStatusCodeValue(), CharSequenceUtil.sub(responseEntity.getBody(), 0, 100)); + } catch (HttpClientErrorException e) { + try { + JSONObject error = new JSONObject(e.getResponseBodyAsString()); + return Message.of(error.optInt("code", e.getRawStatusCode()), error.optString("message")); + } catch (Exception je) { + return Message.of(e.getRawStatusCode(), CharSequenceUtil.sub(e.getMessage(), 0, 100)); + } + } + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/NotifyService.java b/src/main/java/com/github/novicezk/midjourney/service/NotifyService.java new file mode 100644 index 0000000000000000000000000000000000000000..7f18454b8721b4eaea02b3dc949afebda54b8a1e --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/NotifyService.java @@ -0,0 +1,10 @@ +package com.github.novicezk.midjourney.service; + + +import com.github.novicezk.midjourney.support.Task; + +public interface NotifyService { + + void notifyTaskChange(Task task); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/NotifyServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/NotifyServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..d850c13694ea52ed2cded2201b0c1861677080d1 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/NotifyServiceImpl.java @@ -0,0 +1,76 @@ +package com.github.novicezk.midjourney.service; + +import cn.hutool.cache.CacheUtil; +import cn.hutool.cache.impl.TimedCache; +import cn.hutool.core.exceptions.CheckedUtil; +import cn.hutool.core.text.CharSequenceUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.support.Task; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +import java.time.Duration; + +@Slf4j +@Service +public class NotifyServiceImpl implements NotifyService { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private final ThreadPoolTaskExecutor executor; + private final TimedCache taskLocks = CacheUtil.newTimedCache(Duration.ofHours(1).toMillis()); + + public NotifyServiceImpl(ProxyProperties properties) { + this.executor = new ThreadPoolTaskExecutor(); + this.executor.setCorePoolSize(properties.getNotifyPoolSize()); + this.executor.setThreadNamePrefix("TaskNotify-"); + this.executor.initialize(); + } + + @Override + public void notifyTaskChange(Task task) { + String notifyHook = task.getPropertyGeneric(Constants.TASK_PROPERTY_NOTIFY_HOOK); + if (CharSequenceUtil.isBlank(notifyHook)) { + return; + } + String taskId = task.getId(); + TaskStatus taskStatus = task.getStatus(); + Object taskLock = this.taskLocks.get(taskId, (CheckedUtil.Func0Rt) Object::new); + try { + String paramsStr = OBJECT_MAPPER.writeValueAsString(task); + this.executor.execute(() -> { + synchronized (taskLock) { + try { + ResponseEntity responseEntity = postJson(notifyHook, paramsStr); + if (responseEntity.getStatusCode() == HttpStatus.OK) { + log.debug("推送任务变更成功, 任务ID: {}, status: {}, notifyHook: {}", taskId, taskStatus, notifyHook); + } else { + log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, code: {}, msg: {}", taskId, notifyHook, responseEntity.getStatusCodeValue(), responseEntity.getBody()); + } + } catch (Exception e) { + log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage()); + } + } + }); + } catch (JsonProcessingException e) { + log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage()); + } + } + + private ResponseEntity postJson(String notifyHook, String paramsJson) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity httpEntity = new HttpEntity<>(paramsJson, headers); + return new RestTemplate().postForEntity(notifyHook, httpEntity, String.class); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/TaskService.java b/src/main/java/com/github/novicezk/midjourney/service/TaskService.java new file mode 100644 index 0000000000000000000000000000000000000000..3477d7e0e32baa24f1da507d111af3eda28d9a1b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TaskService.java @@ -0,0 +1,23 @@ +package com.github.novicezk.midjourney.service; + +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.support.Task; +import eu.maxschuster.dataurl.DataUrl; + +import java.util.List; + +public interface TaskService { + + SubmitResultVO submitImagine(Task task, List dataUrls); + + SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags); + + SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags); + + SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags); + + SubmitResultVO submitDescribe(Task task, DataUrl dataUrl); + + SubmitResultVO submitBlend(Task task, List dataUrls, BlendDimensions dimensions); +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/service/TaskServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/TaskServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..e98493b9f1404bf5adc6f48a550df332b7dce3f6 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TaskServiceImpl.java @@ -0,0 +1,98 @@ +package com.github.novicezk.midjourney.service; + +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.enums.BlendDimensions; +import com.github.novicezk.midjourney.result.Message; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskQueueHelper; +import com.github.novicezk.midjourney.util.MimeTypeUtils; +import eu.maxschuster.dataurl.DataUrl; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; + +@Slf4j +@Service +@RequiredArgsConstructor +public class TaskServiceImpl implements TaskService { + private final TaskStoreService taskStoreService; + private final DiscordService discordService; + private final TaskQueueHelper taskQueueHelper; + + @Override + public SubmitResultVO submitImagine(Task task, List dataUrls) { + return this.taskQueueHelper.submitTask(task, () -> { + List imageUrls = new ArrayList<>(); + for (DataUrl dataUrl : dataUrls) { + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + Message uploadResult = this.discordService.upload(taskFileName, dataUrl); + if (uploadResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(uploadResult.getCode(), uploadResult.getDescription()); + } + String finalFileName = uploadResult.getResult(); + Message sendImageResult = this.discordService.sendImageMessage("upload image: " + finalFileName, finalFileName); + if (sendImageResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(sendImageResult.getCode(), sendImageResult.getDescription()); + } + imageUrls.add(sendImageResult.getResult()); + } + if (!imageUrls.isEmpty()) { + task.setPrompt(String.join(" ", imageUrls) + " " + task.getPrompt()); + task.setPromptEn(String.join(" ", imageUrls) + " " + task.getPromptEn()); + task.setDescription("/imagine " + task.getPrompt()); + this.taskStoreService.save(task); + } + return this.discordService.imagine(task.getPromptEn(), task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)); + }); + } + + @Override + public SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) { + return this.taskQueueHelper.submitTask(task, () -> this.discordService.upscale(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE))); + } + + @Override + public SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) { + return this.taskQueueHelper.submitTask(task, () -> this.discordService.variation(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE))); + } + + @Override + public SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags) { + return this.taskQueueHelper.submitTask(task, () -> this.discordService.reroll(targetMessageId, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE))); + } + + @Override + public SubmitResultVO submitDescribe(Task task, DataUrl dataUrl) { + return this.taskQueueHelper.submitTask(task, () -> { + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + Message uploadResult = this.discordService.upload(taskFileName, dataUrl); + if (uploadResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(uploadResult.getCode(), uploadResult.getDescription()); + } + String finalFileName = uploadResult.getResult(); + return this.discordService.describe(finalFileName, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)); + }); + } + + @Override + public SubmitResultVO submitBlend(Task task, List dataUrls, BlendDimensions dimensions) { + return this.taskQueueHelper.submitTask(task, () -> { + List finalFileNames = new ArrayList<>(); + for (DataUrl dataUrl : dataUrls) { + String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType()); + Message uploadResult = this.discordService.upload(taskFileName, dataUrl); + if (uploadResult.getCode() != ReturnCode.SUCCESS) { + return Message.of(uploadResult.getCode(), uploadResult.getDescription()); + } + finalFileNames.add(uploadResult.getResult()); + } + return this.discordService.blend(finalFileNames, dimensions, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)); + }); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/TaskStoreService.java b/src/main/java/com/github/novicezk/midjourney/service/TaskStoreService.java new file mode 100644 index 0000000000000000000000000000000000000000..cec8e46adce16e46e16a05fd4162bf83e395afe8 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TaskStoreService.java @@ -0,0 +1,23 @@ +package com.github.novicezk.midjourney.service; + + +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; + +import java.util.List; + +public interface TaskStoreService { + + void save(Task task); + + void delete(String id); + + Task get(String id); + + List list(); + + List list(TaskCondition condition); + + Task findOne(TaskCondition condition); + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/TranslateService.java b/src/main/java/com/github/novicezk/midjourney/service/TranslateService.java new file mode 100644 index 0000000000000000000000000000000000000000..a413d483567727b5452f65d2496880a99d6a6021 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/TranslateService.java @@ -0,0 +1,13 @@ +package com.github.novicezk.midjourney.service; + +import java.util.regex.Pattern; + +public interface TranslateService { + + String translateToEnglish(String prompt); + + default boolean containsChinese(String prompt) { + return Pattern.compile("[\u4e00-\u9fa5]").matcher(prompt).find(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/store/InMemoryTaskStoreServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/store/InMemoryTaskStoreServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..4e25a5e279e2be7bea1224b2ca87fb320c7a61e6 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/store/InMemoryTaskStoreServiceImpl.java @@ -0,0 +1,52 @@ +package com.github.novicezk.midjourney.service.store; + +import cn.hutool.cache.CacheUtil; +import cn.hutool.cache.impl.TimedCache; +import cn.hutool.core.collection.ListUtil; +import cn.hutool.core.stream.StreamUtil; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; + +import java.time.Duration; +import java.util.List; + + +public class InMemoryTaskStoreServiceImpl implements TaskStoreService { + private final TimedCache taskMap; + + public InMemoryTaskStoreServiceImpl(Duration timeout) { + this.taskMap = CacheUtil.newTimedCache(timeout.toMillis()); + } + + @Override + public void save(Task task) { + this.taskMap.put(task.getId(), task); + } + + @Override + public void delete(String key) { + this.taskMap.remove(key); + } + + @Override + public Task get(String key) { + return this.taskMap.get(key); + } + + @Override + public List list() { + return ListUtil.toList(this.taskMap.iterator()); + } + + @Override + public List list(TaskCondition condition) { + return StreamUtil.of(this.taskMap.iterator()).filter(condition).toList(); + } + + @Override + public Task findOne(TaskCondition condition) { + return StreamUtil.of(this.taskMap.iterator()).filter(condition).findFirst().orElse(null); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/store/RedisTaskStoreServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/store/RedisTaskStoreServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..287c7bac8e33716f581504ad470003ff4e18ac1c --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/store/RedisTaskStoreServiceImpl.java @@ -0,0 +1,74 @@ +package com.github.novicezk.midjourney.service.store; + +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import org.springframework.data.redis.core.Cursor; +import org.springframework.data.redis.core.RedisCallback; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.ScanOptions; +import org.springframework.data.redis.core.ValueOperations; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +public class RedisTaskStoreServiceImpl implements TaskStoreService { + private static final String KEY_PREFIX = "mj-task-store::"; + + private final Duration timeout; + private final RedisTemplate redisTemplate; + + public RedisTaskStoreServiceImpl(Duration timeout, RedisTemplate redisTemplate) { + this.timeout = timeout; + this.redisTemplate = redisTemplate; + } + + @Override + public void save(Task task) { + this.redisTemplate.opsForValue().set(getRedisKey(task.getId()), task, this.timeout); + } + + @Override + public void delete(String id) { + this.redisTemplate.delete(getRedisKey(id)); + } + + @Override + public Task get(String id) { + return this.redisTemplate.opsForValue().get(getRedisKey(id)); + } + + @Override + public List list() { + Set keys = this.redisTemplate.execute((RedisCallback>) connection -> { + Cursor cursor = connection.scan(ScanOptions.scanOptions().match(KEY_PREFIX + "*").count(1000).build()); + return cursor.stream().map(String::new).collect(Collectors.toSet()); + }); + if (keys == null || keys.isEmpty()) { + return Collections.emptyList(); + } + ValueOperations operations = this.redisTemplate.opsForValue(); + return keys.stream().map(operations::get) + .filter(Objects::nonNull) + .toList(); + } + + @Override + public List list(TaskCondition condition) { + return list().stream().filter(condition).toList(); + } + + @Override + public Task findOne(TaskCondition condition) { + return list().stream().filter(condition).findFirst().orElse(null); + } + + private String getRedisKey(String id) { + return KEY_PREFIX + id; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/translate/BaiduTranslateServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/translate/BaiduTranslateServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..e672862c3e5bb697b77aa24e90be406963b68560 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/translate/BaiduTranslateServiceImpl.java @@ -0,0 +1,56 @@ +package com.github.novicezk.midjourney.service.translate; + + +import cn.hutool.core.exceptions.ValidateException; +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.RandomUtil; +import cn.hutool.crypto.digest.MD5; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.service.TranslateService; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONObject; +import org.springframework.beans.factory.support.BeanDefinitionValidationException; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +@Slf4j +public class BaiduTranslateServiceImpl implements TranslateService { + private static final String TRANSLATE_API = "https://fanyi-api.baidu.com/api/trans/vip/translate"; + + private final String appid; + private final String appSecret; + + public BaiduTranslateServiceImpl(ProxyProperties.BaiduTranslateConfig translateConfig) { + this.appid = translateConfig.getAppid(); + this.appSecret = translateConfig.getAppSecret(); + if (!CharSequenceUtil.isAllNotBlank(this.appid, this.appSecret)) { + throw new BeanDefinitionValidationException("mj-proxy.baidu-translate.appid或mj-proxy.baidu-translate.app-secret未配置"); + } + } + + @Override + public String translateToEnglish(String prompt) { + if (!containsChinese(prompt)) { + return prompt; + } + String salt = RandomUtil.randomNumbers(5); + String sign = MD5.create().digestHex(this.appid + prompt + salt + this.appSecret); + String url = TRANSLATE_API + "?from=zh&to=en&appid=" + this.appid + "&salt=" + salt + "&q=" + prompt + "&sign=" + sign; + try { + ResponseEntity responseEntity = new RestTemplate().getForEntity(url, String.class); + if (responseEntity.getStatusCode() != HttpStatus.OK || CharSequenceUtil.isBlank(responseEntity.getBody())) { + throw new ValidateException(responseEntity.getStatusCodeValue() + " - " + responseEntity.getBody()); + } + JSONObject result = new JSONObject(responseEntity.getBody()); + if (result.has("error_code")) { + throw new ValidateException(result.getString("error_code") + " - " + result.getString("error_msg")); + } + return result.getJSONArray("trans_result").getJSONObject(0).getString("dst"); + } catch (Exception e) { + log.warn("调用百度翻译失败: {}", e.getMessage()); + } + return prompt; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/service/translate/GPTTranslateServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/translate/GPTTranslateServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..f391f1987ecfcfe1d6d30002c6b09657d6d491ab --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/translate/GPTTranslateServiceImpl.java @@ -0,0 +1,83 @@ +package com.github.novicezk.midjourney.service.translate; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.service.TranslateService; +import com.unfbx.chatgpt.OpenAiClient; +import com.unfbx.chatgpt.entity.chat.ChatChoice; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; +import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.function.KeyRandomStrategy; +import com.unfbx.chatgpt.interceptor.OpenAILogger; +import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.logging.HttpLoggingInterceptor; +import org.springframework.beans.factory.support.BeanDefinitionValidationException; + +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; + +@Slf4j +public class GPTTranslateServiceImpl implements TranslateService { + private final OpenAiClient openAiClient; + private final ProxyProperties.OpenaiConfig openaiConfig; + + public GPTTranslateServiceImpl(ProxyProperties properties) { + this.openaiConfig = properties.getOpenai(); + if (CharSequenceUtil.isBlank(this.openaiConfig.getGptApiKey())) { + throw new BeanDefinitionValidationException("mj-proxy.openai.gpt-api-key未配置"); + } + HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); + httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); + OkHttpClient.Builder okHttpBuilder = new OkHttpClient.Builder() + .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new OpenAiResponseInterceptor()) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .readTimeout(30, TimeUnit.SECONDS); + if (CharSequenceUtil.isNotBlank(properties.getProxy().getHost())) { + Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(properties.getProxy().getHost(), properties.getProxy().getPort())); + okHttpBuilder.proxy(proxy); + } + OpenAiClient.Builder apiBuilder = OpenAiClient.builder() + .apiKey(Collections.singletonList(this.openaiConfig.getGptApiKey())) + .keyStrategy(new KeyRandomStrategy()) + .okHttpClient(okHttpBuilder.build()); + if (CharSequenceUtil.isNotBlank(this.openaiConfig.getGptApiUrl())) { + apiBuilder.apiHost(this.openaiConfig.getGptApiUrl()); + } + this.openAiClient = apiBuilder.build(); + } + + @Override + public String translateToEnglish(String prompt) { + if (!containsChinese(prompt)) { + return prompt; + } + Message m1 = Message.builder().role(Message.Role.SYSTEM).content("把中文翻译成英文").build(); + Message m2 = Message.builder().role(Message.Role.USER).content(prompt).build(); + ChatCompletion chatCompletion = ChatCompletion.builder() + .messages(Arrays.asList(m1, m2)) + .model(this.openaiConfig.getModel()) + .temperature(this.openaiConfig.getTemperature()) + .maxTokens(this.openaiConfig.getMaxTokens()) + .build(); + ChatCompletionResponse chatCompletionResponse = this.openAiClient.chatCompletion(chatCompletion); + try { + List choices = chatCompletionResponse.getChoices(); + if (!choices.isEmpty()) { + return choices.get(0).getMessage().getContent(); + } + } catch (Exception e) { + log.warn("调用chat-gpt接口翻译中文失败: {}", e.getMessage()); + } + return prompt; + } +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/service/translate/NoTranslateServiceImpl.java b/src/main/java/com/github/novicezk/midjourney/service/translate/NoTranslateServiceImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..3db1e688053e83cdae8be21fecfd8522a6c6c40d --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/service/translate/NoTranslateServiceImpl.java @@ -0,0 +1,14 @@ +package com.github.novicezk.midjourney.service.translate; + + +import com.github.novicezk.midjourney.service.TranslateService; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class NoTranslateServiceImpl implements TranslateService { + + @Override + public String translateToEnglish(String prompt) { + return prompt; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/ApiAuthorizeInterceptor.java b/src/main/java/com/github/novicezk/midjourney/support/ApiAuthorizeInterceptor.java new file mode 100644 index 0000000000000000000000000000000000000000..cf5eef87ec6219cb8df198e9e3136ad97304df03 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/ApiAuthorizeInterceptor.java @@ -0,0 +1,32 @@ +package com.github.novicezk.midjourney.support; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.ProxyProperties; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; +import org.springframework.web.servlet.HandlerInterceptor; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +@Component +@RequiredArgsConstructor +public class ApiAuthorizeInterceptor implements HandlerInterceptor { + private final ProxyProperties properties; + + @Override + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { + if (CharSequenceUtil.isBlank(this.properties.getApiSecret())) { + return true; + } + String apiSecret = request.getHeader(Constants.API_SECRET_HEADER_NAME); + boolean authorized = CharSequenceUtil.equals(apiSecret, this.properties.getApiSecret()); + if (!authorized) { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED); + } + return authorized; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/DiscordHelper.java b/src/main/java/com/github/novicezk/midjourney/support/DiscordHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..e5bf79e3f3527f201028962b194591a2e12bb8cf --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/DiscordHelper.java @@ -0,0 +1,115 @@ +package com.github.novicezk.midjourney.support; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import lombok.RequiredArgsConstructor; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.stereotype.Component; +import org.springframework.web.client.RestTemplate; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +@Component +@RequiredArgsConstructor +public class DiscordHelper { + private final ProxyProperties properties; + /** + * SIMPLE_URL_PREFIX. + */ + public static final String SIMPLE_URL_PREFIX = "https://s.mj.run/"; + /** + * DISCORD_SERVER_URL. + */ + public static final String DISCORD_SERVER_URL = "https://discord.com"; + /** + * DISCORD_CDN_URL. + */ + public static final String DISCORD_CDN_URL = "https://cdn.discordapp.com"; + /** + * DISCORD_WSS_URL. + */ + public static final String DISCORD_WSS_URL = "wss://gateway.discord.gg"; + + public String getServer() { + if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getServer())) { + return DISCORD_SERVER_URL; + } + String serverUrl = this.properties.getNgDiscord().getServer(); + if (serverUrl.endsWith("/")) { + serverUrl = serverUrl.substring(0, serverUrl.length() - 1); + } + return serverUrl; + } + + public String getCdn() { + if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getCdn())) { + return DISCORD_CDN_URL; + } + String cdnUrl = this.properties.getNgDiscord().getCdn(); + if (cdnUrl.endsWith("/")) { + cdnUrl = cdnUrl.substring(0, cdnUrl.length() - 1); + } + return cdnUrl; + } + + public String getWss() { + if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getWss())) { + return DISCORD_WSS_URL; + } + String wssUrl = this.properties.getNgDiscord().getWss(); + if (wssUrl.endsWith("/")) { + wssUrl = wssUrl.substring(0, wssUrl.length() - 1); + } + return wssUrl; + } + + + public String getRealPrompt(String prompt) { + String regex = ""; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(prompt); + while (matcher.find()) { + String url = matcher.group(); + String realUrl = getRealUrl(url.substring(1, url.length() - 1)); + prompt = prompt.replace(url, realUrl); + } + return prompt; + } + + public String getRealUrl(String url) { + if (!CharSequenceUtil.startWith(url, SIMPLE_URL_PREFIX)) { + return url; + } + ResponseEntity res = getDisableRedirectRestTemplate().getForEntity(url, Void.class); + if (res.getStatusCode() == HttpStatus.FOUND) { + return res.getHeaders().getFirst("Location"); + } + return url; + } + + public String findTaskIdWithCdnUrl(String url) { + if (!CharSequenceUtil.startWith(url, DISCORD_CDN_URL)) { + return null; + } + int hashStartIndex = url.lastIndexOf("/"); + String taskId = CharSequenceUtil.subBefore(url.substring(hashStartIndex + 1), ".", true); + if (CharSequenceUtil.length(taskId) == 16) { + return taskId; + } + return null; + } + + private RestTemplate getDisableRedirectRestTemplate() { + CloseableHttpClient httpClient = HttpClientBuilder.create() + .disableRedirectHandling() + .build(); + HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory(httpClient); + return new RestTemplate(factory); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/Task.java b/src/main/java/com/github/novicezk/midjourney/support/Task.java new file mode 100644 index 0000000000000000000000000000000000000000..45691928d0ba3662c228d29d34cb97594f7f1212 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/Task.java @@ -0,0 +1,114 @@ +package com.github.novicezk.midjourney.support; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.enums.TaskStatus; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.io.Serial; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +@Data +@ApiModel("任务") +public class Task implements Serializable { + @Serial + private static final long serialVersionUID = -674915748204390789L; + + private TaskAction action; + @ApiModelProperty("任务ID") + private String id; + @ApiModelProperty("提示词") + private String prompt; + @ApiModelProperty("提示词-英文") + private String promptEn; + + @ApiModelProperty("任务描述") + private String description; + @ApiModelProperty("自定义参数") + private String state; + @ApiModelProperty("提交时间") + private Long submitTime; + @ApiModelProperty("开始执行时间") + private Long startTime; + @ApiModelProperty("结束时间") + private Long finishTime; + @ApiModelProperty("图片url") + private String imageUrl; + @ApiModelProperty("任务状态") + private TaskStatus status = TaskStatus.NOT_START; + @ApiModelProperty("任务进度") + private String progress; + @ApiModelProperty("失败原因") + private String failReason; + + // 任务扩展属性,仅支持基本类型 + private Map properties; + + @JsonIgnore + private final transient Object lock = new Object(); + + public void sleep() throws InterruptedException { + synchronized (this.lock) { + this.lock.wait(); + } + } + + public void awake() { + synchronized (this.lock) { + this.lock.notifyAll(); + } + } + + public void start() { + this.startTime = System.currentTimeMillis(); + this.status = TaskStatus.SUBMITTED; + this.progress = "0%"; + } + + public void success() { + this.finishTime = System.currentTimeMillis(); + this.status = TaskStatus.SUCCESS; + this.progress = "100%"; + } + + public void fail(String reason) { + this.finishTime = System.currentTimeMillis(); + this.status = TaskStatus.FAILURE; + this.failReason = reason; + this.progress = ""; + } + + public Task setProperty(String name, Object value) { + getProperties().put(name, value); + return this; + } + + public Task removeProperty(String name) { + getProperties().remove(name); + return this; + } + + public Object getProperty(String name) { + return getProperties().get(name); + } + + @SuppressWarnings("unchecked") + public T getPropertyGeneric(String name) { + return (T) getProperty(name); + } + + public T getProperty(String name, Class clz) { + return clz.cast(getProperty(name)); + } + + public Map getProperties() { + if (this.properties == null) { + this.properties = new HashMap<>(); + } + return this.properties; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/TaskCondition.java b/src/main/java/com/github/novicezk/midjourney/support/TaskCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..047ffe02e3ed3f59f2e40fb81640bc35d2de16fd --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/TaskCondition.java @@ -0,0 +1,70 @@ +package com.github.novicezk.midjourney.support; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.enums.TaskStatus; +import lombok.Data; +import lombok.experimental.Accessors; + +import java.util.Set; +import java.util.function.Predicate; + + +@Data +@Accessors(chain = true) +public class TaskCondition implements Predicate { + private String id; + + private Set statusSet; + private Set actionSet; + + private String prompt; + private String promptEn; + private String description; + + private String finalPromptEn; + private String messageId; + private String progressMessageId; + private String nonce; + + @Override + public boolean test(Task task) { + if (task == null) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.id) && !this.id.equals(task.getId())) { + return false; + } + if (this.statusSet != null && !this.statusSet.isEmpty() && !this.statusSet.contains(task.getStatus())) { + return false; + } + if (this.actionSet != null && !this.actionSet.isEmpty() && !this.actionSet.contains(task.getAction())) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.prompt) && !this.prompt.equals(task.getPrompt())) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.promptEn) && !this.promptEn.equals(task.getPromptEn())) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.description) && !this.description.equals(task.getDescription())) { + return false; + } + + if (CharSequenceUtil.isNotBlank(this.finalPromptEn) && !this.finalPromptEn.equals(task.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.messageId) && !this.messageId.equals(task.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.progressMessageId) && !this.progressMessageId.equals(task.getProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID))) { + return false; + } + if (CharSequenceUtil.isNotBlank(this.nonce) && !this.nonce.equals(task.getProperty(Constants.TASK_PROPERTY_NONCE))) { + return false; + } + return true; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/TaskQueueHelper.java b/src/main/java/com/github/novicezk/midjourney/support/TaskQueueHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..5e8212b24ebf42e8c27dbfdbe5c3c9416aeaaa9d --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/TaskQueueHelper.java @@ -0,0 +1,132 @@ +package com.github.novicezk.midjourney.support; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.ReturnCode; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.result.Message; +import com.github.novicezk.midjourney.result.SubmitResultVO; +import com.github.novicezk.midjourney.service.NotifyService; +import com.github.novicezk.midjourney.service.TaskStoreService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.function.Predicate; +import java.util.stream.Stream; + +@Slf4j +@Component +public class TaskQueueHelper { + @Resource + private TaskStoreService taskStoreService; + @Resource + private NotifyService notifyService; + + private final ThreadPoolTaskExecutor taskExecutor; + private final List runningTasks; + private final Map> taskFutureMap = Collections.synchronizedMap(new HashMap<>()); + + public TaskQueueHelper(ProxyProperties properties) { + ProxyProperties.TaskQueueConfig queueConfig = properties.getQueue(); + this.runningTasks = new CopyOnWriteArrayList<>(); + this.taskExecutor = new ThreadPoolTaskExecutor(); + this.taskExecutor.setCorePoolSize(queueConfig.getCoreSize()); + this.taskExecutor.setMaxPoolSize(queueConfig.getCoreSize()); + this.taskExecutor.setQueueCapacity(queueConfig.getQueueSize()); + this.taskExecutor.setThreadNamePrefix("TaskQueue-"); + this.taskExecutor.initialize(); + } + + public Set getQueueTaskIds() { + return this.taskFutureMap.keySet(); + } + + public Task getRunningTask(String id) { + if (CharSequenceUtil.isBlank(id)) { + return null; + } + return this.runningTasks.stream().filter(t -> id.equals(t.getId())).findFirst().orElse(null); + } + + public Task getRunningTaskByNonce(String nonce) { + if (CharSequenceUtil.isBlank(nonce)) { + return null; + } + TaskCondition condition = new TaskCondition().setNonce(nonce); + return findRunningTask(condition).findFirst().orElse(null); + } + + public Stream findRunningTask(Predicate condition) { + return this.runningTasks.stream().filter(condition); + } + + public Future getRunningFuture(String taskId) { + return this.taskFutureMap.get(taskId); + } + + public SubmitResultVO submitTask(Task task, Callable> discordSubmit) { + this.taskStoreService.save(task); + int size; + try { + size = this.taskExecutor.getThreadPoolExecutor().getQueue().size(); + Future future = this.taskExecutor.submit(() -> executeTask(task, discordSubmit)); + this.taskFutureMap.put(task.getId(), future); + } catch (RejectedExecutionException e) { + this.taskStoreService.delete(task.getId()); + return SubmitResultVO.fail(ReturnCode.QUEUE_REJECTED, "队列已满,请稍后尝试"); + } catch (Exception e) { + log.error("submit task error", e); + return SubmitResultVO.fail(ReturnCode.FAILURE, "提交失败,系统异常"); + } + if (size == 0) { + return SubmitResultVO.of(ReturnCode.SUCCESS, "提交成功", task.getId()); + } else { + return SubmitResultVO.of(ReturnCode.IN_QUEUE, "排队中,前面还有" + size + "个任务", task.getId()) + .setProperty("numberOfQueues", size); + } + } + + private void executeTask(Task task, Callable> discordSubmit) { + this.runningTasks.add(task); + try { + task.start(); + Message result = discordSubmit.call(); + if (result.getCode() != ReturnCode.SUCCESS) { + task.fail(result.getDescription()); + saveAndNotify(task); + return; + } + saveAndNotify(task); + do { + task.sleep(); + saveAndNotify(task); + } while (task.getStatus() == TaskStatus.IN_PROGRESS); + log.debug("task finished, id: {}, status: {}", task.getId(), task.getStatus()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + log.error("task execute error", e); + task.fail("执行错误,系统异常"); + saveAndNotify(task); + } finally { + this.runningTasks.remove(task); + this.taskFutureMap.remove(task.getId()); + } + } + + public void saveAndNotify(Task task) { + this.taskStoreService.save(task); + this.notifyService.notifyTaskChange(task); + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/support/TaskTimeoutSchedule.java b/src/main/java/com/github/novicezk/midjourney/support/TaskTimeoutSchedule.java new file mode 100644 index 0000000000000000000000000000000000000000..b1da68af31e5f6c95be47936fc98f16a1f15df27 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/support/TaskTimeoutSchedule.java @@ -0,0 +1,43 @@ +package com.github.novicezk.midjourney.support; + +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.enums.TaskStatus; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Set; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Component +@RequiredArgsConstructor +public class TaskTimeoutSchedule { + private final TaskQueueHelper taskQueueHelper; + private final ProxyProperties properties; + + @Scheduled(fixedRate = 30000L) + public void checkTasks() { + long currentTime = System.currentTimeMillis(); + long timeout = TimeUnit.MINUTES.toMillis(this.properties.getQueue().getTimeoutMinutes()); + List tasks = this.taskQueueHelper.findRunningTask(new TaskCondition()) + .filter(t -> currentTime - t.getStartTime() > timeout) + .toList(); + for (Task task : tasks) { + if (Set.of(TaskStatus.FAILURE, TaskStatus.SUCCESS).contains(task.getStatus())) { + log.warn("task status is failure/success but is in the queue, end it. id: {}", task.getId()); + } else { + log.debug("task timeout, id: {}", task.getId()); + task.fail("任务超时"); + } + Future future = this.taskQueueHelper.getRunningFuture(task.getId()); + if (future != null) { + future.cancel(true); + } + this.taskQueueHelper.saveAndNotify(task); + } + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/BannedPromptUtils.java b/src/main/java/com/github/novicezk/midjourney/util/BannedPromptUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..13b155b25f61479c95914010f1f7424639e708bd --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/BannedPromptUtils.java @@ -0,0 +1,43 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.io.FileUtil; +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.exception.BannedPromptException; +import lombok.experimental.UtilityClass; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +@UtilityClass +public class BannedPromptUtils { + private static final String BANNED_WORDS_FILE_PATH = "/home/spring/config/banned-words.txt"; + private final List BANNED_WORDS; + + static { + List lines; + File file = new File(BANNED_WORDS_FILE_PATH); + if (file.exists()) { + lines = FileUtil.readLines(file, StandardCharsets.UTF_8); + } else { + var resource = BannedPromptUtils.class.getResource("/banned-words.txt"); + lines = FileUtil.readLines(resource, StandardCharsets.UTF_8); + } + BANNED_WORDS = lines.stream().filter(CharSequenceUtil::isNotBlank).toList(); + } + + public static void checkBanned(String promptEn) throws BannedPromptException { + String finalPromptEn = promptEn.toLowerCase(Locale.ENGLISH); + for (String word : BANNED_WORDS) { + Matcher matcher = Pattern.compile("\\b" + word + "\\b").matcher(finalPromptEn); + if (matcher.find()) { + int index = CharSequenceUtil.indexOfIgnoreCase(promptEn, word); + throw new BannedPromptException(promptEn.substring(index, index + word.length())); + } + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/ContentParseData.java b/src/main/java/com/github/novicezk/midjourney/util/ContentParseData.java new file mode 100644 index 0000000000000000000000000000000000000000..4ff1af679ff969b33dca434a462889738cdd509b --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/ContentParseData.java @@ -0,0 +1,9 @@ +package com.github.novicezk.midjourney.util; + +import lombok.Data; + +@Data +public class ContentParseData { + protected String prompt; + protected String status; +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/ConvertUtils.java b/src/main/java/com/github/novicezk/midjourney/util/ConvertUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..e56dc8b7df2522eb45ef6afa96182b13f3d5b5f7 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/ConvertUtils.java @@ -0,0 +1,85 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.enums.TaskAction; +import eu.maxschuster.dataurl.DataUrl; +import eu.maxschuster.dataurl.DataUrlSerializer; +import eu.maxschuster.dataurl.IDataUrlSerializer; +import lombok.experimental.UtilityClass; + +import java.net.MalformedURLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +@UtilityClass +public class ConvertUtils { + /** + * content正则匹配prompt和进度. + */ + public static final String CONTENT_REGEX = ".*?\\*\\*(.*?)\\*\\*.+<@\\d+> \\((.*?)\\)"; + + public static ContentParseData parseContent(String content) { + return parseContent(content, CONTENT_REGEX); + } + + public static ContentParseData parseContent(String content, String regex) { + if (CharSequenceUtil.isBlank(content)) { + return null; + } + Matcher matcher = Pattern.compile(regex).matcher(content); + if (!matcher.find()) { + return null; + } + ContentParseData parseData = new ContentParseData(); + parseData.setPrompt(matcher.group(1)); + parseData.setStatus(matcher.group(2)); + return parseData; + } + + public static List convertBase64Array(List base64Array) throws MalformedURLException { + if (base64Array == null || base64Array.isEmpty()) { + return Collections.emptyList(); + } + IDataUrlSerializer serializer = new DataUrlSerializer(); + List dataUrlList = new ArrayList<>(); + for (String base64 : base64Array) { + DataUrl dataUrl = serializer.unserialize(base64); + dataUrlList.add(dataUrl); + } + return dataUrlList; + } + + public static TaskChangeParams convertChangeParams(String content) { + List split = CharSequenceUtil.split(content, " "); + if (split.size() != 2) { + return null; + } + String action = split.get(1).toLowerCase(); + TaskChangeParams changeParams = new TaskChangeParams(); + changeParams.setId(split.get(0)); + if (action.charAt(0) == 'u') { + changeParams.setAction(TaskAction.UPSCALE); + } else if (action.charAt(0) == 'v') { + changeParams.setAction(TaskAction.VARIATION); + } else if (action.equals("r")) { + changeParams.setAction(TaskAction.REROLL); + return changeParams; + } else { + return null; + } + try { + int index = Integer.parseInt(action.substring(1, 2)); + if (index < 1 || index > 4) { + return null; + } + changeParams.setIndex(index); + } catch (Exception e) { + return null; + } + return changeParams; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/MimeTypeUtils.java b/src/main/java/com/github/novicezk/midjourney/util/MimeTypeUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..6d7108372bb23c553d95ecaef28ac999c8bb36f2 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/MimeTypeUtils.java @@ -0,0 +1,45 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.io.FileUtil; +import cn.hutool.core.text.CharSequenceUtil; +import lombok.experimental.UtilityClass; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@UtilityClass +public class MimeTypeUtils { + private final Map> MIME_TYPE_MAP; + + static { + MIME_TYPE_MAP = new HashMap<>(); + var resource = MimeTypeUtils.class.getResource("/mime.types"); + var lines = FileUtil.readLines(resource, StandardCharsets.UTF_8); + for (var line : lines) { + if (CharSequenceUtil.isBlank(line)) { + continue; + } + var arr = line.split(":"); + MIME_TYPE_MAP.put(arr[0], CharSequenceUtil.split(arr[1], ' ')); + } + } + + public static String guessFileSuffix(String mimeType) { + if (CharSequenceUtil.isBlank(mimeType)) { + return null; + } + String key = mimeType; + if (!MIME_TYPE_MAP.containsKey(key)) { + key = MIME_TYPE_MAP.keySet().stream().filter(k -> CharSequenceUtil.startWithIgnoreCase(mimeType, k)) + .findFirst().orElse(null); + } + var suffixList = MIME_TYPE_MAP.get(key); + if (suffixList == null || suffixList.isEmpty()) { + return null; + } + return suffixList.iterator().next(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/SnowFlake.java b/src/main/java/com/github/novicezk/midjourney/util/SnowFlake.java new file mode 100644 index 0000000000000000000000000000000000000000..5b59c6ff4aac69d5167a5e84eabd66627b783bf9 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/SnowFlake.java @@ -0,0 +1,152 @@ +package com.github.novicezk.midjourney.util; + +import cn.hutool.core.exceptions.ValidateException; +import com.github.novicezk.midjourney.exception.SnowFlakeException; +import lombok.extern.slf4j.Slf4j; + +import java.lang.management.ManagementFactory; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.util.Date; +import java.util.concurrent.ThreadLocalRandom; + +@Slf4j +public class SnowFlake { + private long workerId; + private long datacenterId; + private long sequence = 0L; + private final long twepoch; + private final long sequenceMask; + private final long workerIdShift; + private final long datacenterIdShift; + private final long timestampLeftShift; + private long lastTimestamp = -1L; + private final boolean randomSequence; + private long count = 0L; + private final long timeOffset; + private final ThreadLocalRandom tlr = ThreadLocalRandom.current(); + + public static final SnowFlake INSTANCE = new SnowFlake(); + + private SnowFlake() { + this(false, 10, null, 5L, 5L, 12L); + } + + private SnowFlake(boolean randomSequence, long timeOffset, Date epochDate, long workerIdBits, long datacenterIdBits, long sequenceBits) { + if (null != epochDate) { + this.twepoch = epochDate.getTime(); + } else { + // 2012/12/12 23:59:59 GMT + this.twepoch = 1355327999000L; + } + long maxWorkerId = ~(-1L << workerIdBits); + long maxDatacenterId = ~(-1L << datacenterIdBits); + this.sequenceMask = ~(-1L << sequenceBits); + this.workerIdShift = sequenceBits; + this.datacenterIdShift = sequenceBits + workerIdBits; + this.timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits; + this.randomSequence = randomSequence; + this.timeOffset = timeOffset; + try { + this.datacenterId = getDatacenterId(maxDatacenterId); + this.workerId = getMaxWorkerId(datacenterId, maxWorkerId); + } catch (Exception e) { + log.warn("datacenterId or workerId generate error: {}, set default value", e.getMessage()); + this.datacenterId = 4; + this.workerId = 1; + } + } + + public synchronized String nextId() { + long currentTimestamp = timeGen(); + if (currentTimestamp < this.lastTimestamp) { + long offset = this.lastTimestamp - currentTimestamp; + if (offset > this.timeOffset) { + throw new ValidateException("Clock moved backwards, refusing to generate id for [" + offset + "ms]"); + } + try { + this.wait(offset << 1); + } catch (InterruptedException e) { + throw new SnowFlakeException(e); + } + currentTimestamp = timeGen(); + if (currentTimestamp < this.lastTimestamp) { + throw new SnowFlakeException("Clock moved backwards, refusing to generate id for [" + offset + "ms]"); + } + } + if (this.lastTimestamp == currentTimestamp) { + long tempSequence = this.sequence + 1; + if (this.randomSequence) { + this.sequence = tempSequence & this.sequenceMask; + this.count = (this.count + 1) & this.sequenceMask; + if (this.count == 0) { + currentTimestamp = this.tillNextMillis(this.lastTimestamp); + } + } else { + this.sequence = tempSequence & this.sequenceMask; + if (this.sequence == 0) { + currentTimestamp = this.tillNextMillis(lastTimestamp); + } + } + } else { + this.sequence = this.randomSequence ? this.tlr.nextLong(this.sequenceMask + 1) : 0L; + this.count = 0L; + } + this.lastTimestamp = currentTimestamp; + long id = ((currentTimestamp - this.twepoch) << this.timestampLeftShift) | + (this.datacenterId << this.datacenterIdShift) | + (this.workerId << this.workerIdShift) | + this.sequence; + return String.valueOf(id); + } + + private static long getDatacenterId(long maxDatacenterId) { + long id = 0L; + try { + InetAddress ip = InetAddress.getLocalHost(); + NetworkInterface network = NetworkInterface.getByInetAddress(ip); + if (network == null) { + id = 1L; + } else { + byte[] mac = network.getHardwareAddress(); + if (null != mac) { + id = ((0x000000FF & (long) mac[mac.length - 1]) | (0x0000FF00 & (((long) mac[mac.length - 2]) << 8))) >> 6; + id = id % (maxDatacenterId + 1); + } + } + } catch (Exception e) { + throw new SnowFlakeException(e); + } + return id; + } + + private static long getMaxWorkerId(long datacenterId, long maxWorkerId) { + StringBuilder macIpPid = new StringBuilder(); + macIpPid.append(datacenterId); + try { + String name = ManagementFactory.getRuntimeMXBean().getName(); + if (name != null && !name.isEmpty()) { + macIpPid.append(name.split("@")[0]); + } + String hostIp = InetAddress.getLocalHost().getHostAddress(); + String ipStr = hostIp.replace("\\.", ""); + macIpPid.append(ipStr); + } catch (Exception e) { + throw new SnowFlakeException(e); + } + return (macIpPid.toString().hashCode() & 0xffff) % (maxWorkerId + 1); + } + + private long tillNextMillis(long lastTimestamp) { + long timestamp = timeGen(); + while (timestamp <= lastTimestamp) { + timestamp = timeGen(); + } + return timestamp; + } + + private long timeGen() { + return System.currentTimeMillis(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/TaskChangeParams.java b/src/main/java/com/github/novicezk/midjourney/util/TaskChangeParams.java new file mode 100644 index 0000000000000000000000000000000000000000..94f234ff3f5c21a6ec28681faaff4305d7f0adeb --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/TaskChangeParams.java @@ -0,0 +1,11 @@ +package com.github.novicezk.midjourney.util; + +import com.github.novicezk.midjourney.enums.TaskAction; +import lombok.Data; + +@Data +public class TaskChangeParams { + private String id; + private TaskAction action; + private Integer index; +} diff --git a/src/main/java/com/github/novicezk/midjourney/util/UVContentParseData.java b/src/main/java/com/github/novicezk/midjourney/util/UVContentParseData.java new file mode 100644 index 0000000000000000000000000000000000000000..377bb32caec681f2e2562da4fd327b329547d755 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/util/UVContentParseData.java @@ -0,0 +1,10 @@ +package com.github.novicezk.midjourney.util; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@EqualsAndHashCode(callSuper = true) +public class UVContentParseData extends ContentParseData { + protected Integer index; +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/WebSocketStarter.java b/src/main/java/com/github/novicezk/midjourney/wss/WebSocketStarter.java new file mode 100644 index 0000000000000000000000000000000000000000..c4f4da5798a6cf21384627d14f94ec378ffc5abe --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/WebSocketStarter.java @@ -0,0 +1,32 @@ +package com.github.novicezk.midjourney.wss; + +import com.github.novicezk.midjourney.ProxyProperties; +import com.neovisionaries.ws.client.ProxySettings; +import com.neovisionaries.ws.client.WebSocketFactory; +import org.apache.logging.log4j.util.Strings; + +public interface WebSocketStarter { + + void start() throws Exception; + + default void initProxy(ProxyProperties properties) { + ProxyProperties.ProxyConfig proxy = properties.getProxy(); + if (Strings.isNotBlank(proxy.getHost())) { + System.setProperty("http.proxyHost", proxy.getHost()); + System.setProperty("http.proxyPort", String.valueOf(proxy.getPort())); + System.setProperty("https.proxyHost", proxy.getHost()); + System.setProperty("https.proxyPort", String.valueOf(proxy.getPort())); + } + } + + default WebSocketFactory createWebSocketFactory(ProxyProperties properties) { + ProxyProperties.ProxyConfig proxy = properties.getProxy(); + WebSocketFactory webSocketFactory = new WebSocketFactory().setConnectionTimeout(10000); + if (Strings.isNotBlank(proxy.getHost())) { + ProxySettings proxySettings = webSocketFactory.getProxySettings(); + proxySettings.setHost(proxy.getHost()); + proxySettings.setPort(proxy.getPort()); + } + return webSocketFactory; + } +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/BlendSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/BlendSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..42dea1dda0c4ce6b908f8dc9c9000f7fb34fd187 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/BlendSuccessHandler.java @@ -0,0 +1,33 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * blend消息处理. + * 完成(create): ** --v 5.1** - <@1012983546824114217> (relaxed) + */ +@Component +public class BlendSuccessHandler extends MessageHandler { + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = ConvertUtils.parseContent(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.BLEND)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/DescribeSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/DescribeSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..629c5524313ef1cccebf24025626eb73ef50eaac --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/DescribeSuccessHandler.java @@ -0,0 +1,51 @@ +package com.github.novicezk.midjourney.wss.handle; + +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.service.TranslateService; +import com.github.novicezk.midjourney.support.Task; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.Optional; + +/** + * describe消息处理. + */ +@Component +public class DescribeSuccessHandler extends MessageHandler { + @Autowired + private TranslateService translateService; + + @Override + public void handle(MessageType messageType, DataObject message) { + Optional interaction = message.optObject("interaction"); + if (!MessageType.UPDATE.equals(messageType) || interaction.isEmpty() || !"describe".equals(interaction.get().getString("name"))) { + return; + } + DataArray embeds = message.getArray("embeds"); + if (embeds.isEmpty()) { + return; + } + String description = embeds.getObject(0).getString("description"); + Optional imageOptional = embeds.getObject(0).optObject("image"); + if (imageOptional.isEmpty()) { + return; + } + String imageUrl = imageOptional.get().getString("url"); + String taskId = this.discordHelper.findTaskIdWithCdnUrl(imageUrl); + Task task = this.taskQueueHelper.getRunningTask(taskId); + if (task == null) { + return; + } + task.setPrompt(description); + task.setPromptEn(description); + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, description); + task.setImageUrl(replaceCdnUrl(imageUrl)); + finishTask(task, message); + task.awake(); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/ErrorMessageHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/ErrorMessageHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..8e1a8baef226cb499f2fa6763820272dd359a6b8 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/ErrorMessageHandler.java @@ -0,0 +1,68 @@ +package com.github.novicezk.midjourney.wss.handle; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class ErrorMessageHandler extends MessageHandler { + @Resource + protected ProxyProperties properties; + + @Override + public void handle(MessageType messageType, DataObject message) { + Optional embedsOptional = message.optArray("embeds"); + if (!MessageType.CREATE.equals(messageType) || embedsOptional.isEmpty() || embedsOptional.get().isEmpty()) { + return; + } + DataObject embed = embedsOptional.get().getObject(0); + String title = embed.getString("title", null); + String description = embed.getString("description", null); + String footerText = ""; + Optional footer = embed.optObject("footer"); + if (footer.isPresent()) { + footerText = footer.get().getString("text", ""); + } + String channelId = message.getString("channel_id", ""); + int color = embed.getInt("color", 0); + if (color == 16239475) { + log.warn("{} - MJ警告信息: {}\n{}\nfooter: {}", channelId, title, description, footerText); + } else if (color == 16711680) { + log.error("{} - MJ异常信息: {}\n{}\nfooter: {}", channelId, title, description, footerText); + String nonce = getMessageNonce(message); + Task task = this.taskQueueHelper.getRunningTaskByNonce(nonce); + if (task != null) { + task.fail("[" + title + "] " + description); + task.awake(); + } + } else if (CharSequenceUtil.contains(title, "Invalid link")) { + // 兼容 Invalid link! 错误 + log.error("{} - MJ异常信息: {}\n{}\nfooter: {}", channelId, title, description, footerText); + DataObject messageReference = message.optObject("message_reference").orElse(DataObject.empty()); + String referenceMessageId = messageReference.getString("message_id", ""); + if (CharSequenceUtil.isBlank(referenceMessageId)) { + return; + } + TaskCondition condition = new TaskCondition().setStatusSet(Set.of(TaskStatus.IN_PROGRESS)) + .setProgressMessageId(referenceMessageId); + Task task = this.taskQueueHelper.findRunningTask(condition).findFirst().orElse(null); + if (task != null) { + task.fail("[" + title + "] " + description); + task.awake(); + } + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/ImagineSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/ImagineSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..68baafe9861fea39ddbb4f42a4863fa64b7cb222 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/ImagineSuccessHandler.java @@ -0,0 +1,34 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * imagine消息处理. + * 完成(create): **cat** - <@1012983546824114217> (relaxed) + */ +@Component +public class ImagineSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX = "\\*\\*(.*?)\\*\\* - <@\\d+> \\((.*?)\\)"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.IMAGINE)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/MessageHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/MessageHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..a8f9d5f3b0ef97e15eaf5fe04d13f2f67ff94a33 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/MessageHandler.java @@ -0,0 +1,82 @@ +package com.github.novicezk.midjourney.wss.handle; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.support.DiscordHelper; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.support.TaskQueueHelper; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; + +import javax.annotation.Resource; +import java.util.Comparator; + +public abstract class MessageHandler { + @Resource + protected TaskQueueHelper taskQueueHelper; + @Resource + protected DiscordHelper discordHelper; + + public abstract void handle(MessageType messageType, DataObject message); + + protected String getMessageContent(DataObject message) { + return message.hasKey("content") ? message.getString("content") : ""; + } + + protected String getMessageNonce(DataObject message) { + return message.hasKey("nonce") ? message.getString("nonce") : ""; + } + + protected void findAndFinishImageTask(TaskCondition condition, String finalPrompt, DataObject message) { + Task task = this.taskQueueHelper.findRunningTask(condition) + .max(Comparator.comparing(Task::getProgress)) + .orElse(null); + if (task == null) { + return; + } + task.setImageUrl(getImageUrl(message)); + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, finalPrompt); + finishTask(task, message); + task.awake(); + } + + protected void finishTask(Task task, DataObject message) { + task.setProperty(Constants.TASK_PROPERTY_MESSAGE_ID, message.getString("id")); + task.setProperty(Constants.TASK_PROPERTY_MESSAGE_HASH, getMessageHash(task.getImageUrl())); + task.setProperty(Constants.TASK_PROPERTY_FLAGS, message.getInt("flags", 0)); + task.success(); + } + + protected boolean hasImage(DataObject message) { + DataArray attachments = message.optArray("attachments").orElse(DataArray.empty()); + return !attachments.isEmpty(); + } + + protected String getImageUrl(DataObject message) { + DataArray attachments = message.getArray("attachments"); + if (!attachments.isEmpty()) { + String imageUrl = attachments.getObject(0).getString("url"); + return replaceCdnUrl(imageUrl); + } + return null; + } + + protected String replaceCdnUrl(String imageUrl) { + if (CharSequenceUtil.isBlank(imageUrl)) { + return imageUrl; + } + String cdn = this.discordHelper.getCdn(); + if (CharSequenceUtil.startWith(imageUrl, cdn)) { + return imageUrl; + } + return CharSequenceUtil.replaceFirst(imageUrl, DiscordHelper.DISCORD_CDN_URL, cdn); + } + + protected String getMessageHash(String imageUrl) { + int hashStartIndex = imageUrl.lastIndexOf("_"); + return CharSequenceUtil.subBefore(imageUrl.substring(hashStartIndex + 1), ".", true); + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/RerollSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/RerollSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..96e0c17b9b62727655f7c84a66f47ee2b40207fc --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/RerollSuccessHandler.java @@ -0,0 +1,49 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * reroll 消息处理. + * 完成(create): **cat** - <@1012983546824114217> (relaxed) + * 完成(create): **cat** - Variations by <@1012983546824114217> (relaxed) + * 完成(create): **cat** - Variations (Strong或Subtle) by <@1012983546824114217> (relaxed) + */ +@Component +public class RerollSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX_1 = "\\*\\*(.*?)\\*\\* - <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_2 = "\\*\\*(.*?)\\*\\* - Variations by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_3 = "\\*\\*(.*?)\\*\\* - Variations \\(.*?\\) by <@\\d+> \\((.*?)\\)"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = getParseData(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.REROLL)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + + private ContentParseData getParseData(String content) { + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_1); + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_2); + } + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_3); + } + return parseData; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/StartAndProgressHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/StartAndProgressHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..daff7184eedbc2c9f1c0191f396b2dd333f2d1b2 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/StartAndProgressHandler.java @@ -0,0 +1,70 @@ +package com.github.novicezk.midjourney.wss.handle; + + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.Constants; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskStatus; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class StartAndProgressHandler extends MessageHandler { + + @Override + public void handle(MessageType messageType, DataObject message) { + String nonce = getMessageNonce(message); + String content = getMessageContent(message); + ContentParseData parseData = ConvertUtils.parseContent(content); + if (MessageType.CREATE.equals(messageType) && CharSequenceUtil.isNotBlank(nonce)) { + if (isError(message)) { + return; + } + // 任务开始 + Task task = this.taskQueueHelper.getRunningTaskByNonce(nonce); + if (task == null) { + return; + } + task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, message.getString("id")); + // 兼容少数content为空的场景 + if (parseData != null) { + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, parseData.getPrompt()); + } + task.setStatus(TaskStatus.IN_PROGRESS); + task.awake(); + } else if (MessageType.UPDATE.equals(messageType) && parseData != null) { + // 任务进度 + TaskCondition condition = new TaskCondition().setStatusSet(Set.of(TaskStatus.IN_PROGRESS)) + .setProgressMessageId(message.getString("id")); + Task task = this.taskQueueHelper.findRunningTask(condition).findFirst().orElse(null); + if (task == null) { + return; + } + task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, parseData.getPrompt()); + task.setStatus(TaskStatus.IN_PROGRESS); + task.setProgress(parseData.getStatus()); + task.setImageUrl(getImageUrl(message)); + task.awake(); + } + } + + private boolean isError(DataObject message) { + Optional embedsOptional = message.optArray("embeds"); + if (embedsOptional.isEmpty() || embedsOptional.get().isEmpty()) { + return false; + } + DataObject embed = embedsOptional.get().getObject(0); + return embed.getInt("color", 0) == 16711680; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/UpscaleSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/UpscaleSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..5c05d9973e468cb5f191e8c78c6a6fdb2c70bad4 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/UpscaleSuccessHandler.java @@ -0,0 +1,57 @@ +package com.github.novicezk.midjourney.wss.handle; + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * upscale消息处理. + * 完成(create): **cat** - Upscaled (Beta或Light) by <@1083152202048217169> (fast) + * 完成(create): **cat** - Upscaled by <@1083152202048217169> (fast) + * 完成(create): **cat** - Image #1 <@1012983546824114217> + */ +@Component +public class UpscaleSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX_1 = "\\*\\*(.*?)\\*\\* - Upscaled \\(.*?\\) by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_2 = "\\*\\*(.*?)\\*\\* - Upscaled by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_3 = "\\*\\*(.*?)\\*\\* - Image #\\d <@\\d+>"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = getParseData(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.UPSCALE)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + + private ContentParseData getParseData(String content) { + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_1); + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_2); + } + if (parseData != null) { + return parseData; + } + Matcher matcher = Pattern.compile(CONTENT_REGEX_3).matcher(content); + if (!matcher.find()) { + return null; + } + parseData = new ContentParseData(); + parseData.setPrompt(matcher.group(1)); + parseData.setStatus("done"); + return parseData; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/handle/VariationSuccessHandler.java b/src/main/java/com/github/novicezk/midjourney/wss/handle/VariationSuccessHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..fdf44e4b599ed74cd0b9b7a1e086f0bbbd038c42 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/handle/VariationSuccessHandler.java @@ -0,0 +1,43 @@ +package com.github.novicezk.midjourney.wss.handle; + +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.enums.TaskAction; +import com.github.novicezk.midjourney.support.TaskCondition; +import com.github.novicezk.midjourney.util.ContentParseData; +import com.github.novicezk.midjourney.util.ConvertUtils; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.stereotype.Component; + +import java.util.Set; + +/** + * variation消息处理. + * 完成(create): **cat** - Variations (Strong或Subtle) by <@1012983546824114217> (relaxed) + * 完成(create): **cat** - Variations by <@1012983546824114217> (relaxed) + */ +@Component +public class VariationSuccessHandler extends MessageHandler { + private static final String CONTENT_REGEX_1 = "\\*\\*(.*?)\\*\\* - Variations by <@\\d+> \\((.*?)\\)"; + private static final String CONTENT_REGEX_2 = "\\*\\*(.*?)\\*\\* - Variations \\(.*?\\) by <@\\d+> \\((.*?)\\)"; + + @Override + public void handle(MessageType messageType, DataObject message) { + String content = getMessageContent(message); + ContentParseData parseData = getParseData(content); + if (MessageType.CREATE.equals(messageType) && parseData != null && hasImage(message)) { + TaskCondition condition = new TaskCondition() + .setActionSet(Set.of(TaskAction.VARIATION)) + .setFinalPromptEn(parseData.getPrompt()); + findAndFinishImageTask(condition, parseData.getPrompt(), message); + } + } + + private ContentParseData getParseData(String content) { + ContentParseData parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_1); + if (parseData == null) { + parseData = ConvertUtils.parseContent(content, CONTENT_REGEX_2); + } + return parseData; + } + +} diff --git a/src/main/java/com/github/novicezk/midjourney/wss/user/UserMessageListener.java b/src/main/java/com/github/novicezk/midjourney/wss/user/UserMessageListener.java new file mode 100644 index 0000000000000000000000000000000000000000..ebc8ef9e25d86bf70983e72649c44953411db66c --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/user/UserMessageListener.java @@ -0,0 +1,52 @@ +package com.github.novicezk.midjourney.wss.user; + + +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.enums.MessageType; +import com.github.novicezk.midjourney.wss.handle.MessageHandler; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataObject; +import org.springframework.boot.context.event.ApplicationStartedEvent; +import org.springframework.context.ApplicationListener; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.util.ArrayList; +import java.util.List; + +@Slf4j +@Component +public class UserMessageListener implements ApplicationListener { + @Resource + private ProxyProperties properties; + private final List messageHandlers = new ArrayList<>(); + + @Override + public void onApplicationEvent(ApplicationStartedEvent event) { + this.messageHandlers.addAll(event.getApplicationContext().getBeansOfType(MessageHandler.class).values()); + } + + public void onMessage(DataObject raw) { + MessageType messageType = MessageType.of(raw.getString("t")); + if (messageType == null || MessageType.DELETE == messageType) { + return; + } + DataObject data = raw.getObject("d"); + if (ignoreAndLogMessage(data, messageType)) { + return; + } + for (MessageHandler messageHandler : this.messageHandlers) { + messageHandler.handle(messageType, data); + } + } + + private boolean ignoreAndLogMessage(DataObject data, MessageType messageType) { + String channelId = data.getString("channel_id"); + if (!this.properties.getDiscord().getChannelId().equals(channelId)) { + return true; + } + String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System"); + log.debug("{} - {}: {}", messageType.name(), authorName, data.opt("content").orElse("")); + return false; + } +} \ No newline at end of file diff --git a/src/main/java/com/github/novicezk/midjourney/wss/user/UserWebSocketStarter.java b/src/main/java/com/github/novicezk/midjourney/wss/user/UserWebSocketStarter.java new file mode 100644 index 0000000000000000000000000000000000000000..4ee81928cc2984ca214b5f6b09f53bec88fee382 --- /dev/null +++ b/src/main/java/com/github/novicezk/midjourney/wss/user/UserWebSocketStarter.java @@ -0,0 +1,244 @@ +package com.github.novicezk.midjourney.wss.user; + +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.thread.ThreadUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.support.DiscordHelper; +import com.github.novicezk.midjourney.wss.WebSocketStarter; +import com.neovisionaries.ws.client.WebSocket; +import com.neovisionaries.ws.client.WebSocketAdapter; +import com.neovisionaries.ws.client.WebSocketFactory; +import com.neovisionaries.ws.client.WebSocketFrame; +import eu.bitwalker.useragentutils.UserAgent; +import lombok.extern.slf4j.Slf4j; +import net.dv8tion.jda.api.utils.data.DataArray; +import net.dv8tion.jda.api.utils.data.DataObject; +import net.dv8tion.jda.api.utils.data.DataType; +import net.dv8tion.jda.internal.requests.WebSocketCode; +import net.dv8tion.jda.internal.utils.compress.Decompressor; +import net.dv8tion.jda.internal.utils.compress.ZlibDecompressor; + +import javax.annotation.Resource; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +@Slf4j +public class UserWebSocketStarter extends WebSocketAdapter implements WebSocketStarter { + private static final int CONNECT_RETRY_LIMIT = 3; + + private final String userToken; + private final String userAgent; + private final DataObject auth; + + private ScheduledExecutorService heartExecutor; + private WebSocket socket = null; + private String sessionId; + private Future heartbeatTask; + private Decompressor decompressor; + + private boolean connected = false; + private final AtomicInteger sequence = new AtomicInteger(0); + + @Resource + private UserMessageListener userMessageListener; + @Resource + private DiscordHelper discordHelper; + + private final ProxyProperties properties; + + public UserWebSocketStarter(ProxyProperties properties) { + initProxy(properties); + this.properties = properties; + this.userToken = properties.getDiscord().getUserToken(); + this.userAgent = properties.getDiscord().getUserAgent(); + this.auth = createAuthData(); + } + + @Override + public synchronized void start() throws Exception { + this.decompressor = new ZlibDecompressor(2048); + this.heartExecutor = Executors.newSingleThreadScheduledExecutor(); + WebSocketFactory webSocketFactory = createWebSocketFactory(this.properties); + this.socket = webSocketFactory.createSocket(this.discordHelper.getWss() + "/?encoding=json&v=9&compress=zlib-stream"); + this.socket.addListener(this); + this.socket.addHeader("Accept-Encoding", "gzip, deflate, br").addHeader("Accept-Language", "en-US,en;q=0.9") + .addHeader("Cache-Control", "no-cache").addHeader("Pragma", "no-cache") + .addHeader("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits") + .addHeader("User-Agent", this.userAgent); + this.socket.connect(); + } + + @Override + public void onConnected(WebSocket websocket, Map> headers) { + log.debug("[gateway] Connected to websocket."); + this.connected = true; + } + + @Override + public void onBinaryMessage(WebSocket websocket, byte[] binary) throws Exception { + byte[] decompressBinary = this.decompressor.decompress(binary); + if (decompressBinary == null) { + return; + } + String json = new String(decompressBinary, StandardCharsets.UTF_8); + DataObject data = DataObject.fromJson(json); + int opCode = data.getInt("op"); + if (opCode != WebSocketCode.HEARTBEAT_ACK) { + this.sequence.incrementAndGet(); + } + if (opCode == WebSocketCode.HELLO) { + if (this.heartbeatTask == null && this.heartExecutor != null) { + long interval = data.getObject("d").getLong("heartbeat_interval"); + this.heartbeatTask = + this.heartExecutor.scheduleAtFixedRate(this::heartbeat, interval, interval, TimeUnit.MILLISECONDS); + } + sayHello(); + } else if (opCode == WebSocketCode.HEARTBEAT_ACK) { + log.trace("[gateway] Heartbeat ack."); + } else if (opCode == WebSocketCode.HEARTBEAT) { + send(DataObject.empty().put("op", WebSocketCode.HEARTBEAT).put("d", this.sequence)); + } else if (opCode == WebSocketCode.INVALIDATE_SESSION) { + log.debug("[gateway] Invalid session."); + close("session invalid"); + } else if (opCode == WebSocketCode.RECONNECT) { + log.debug("[gateway] Received opcode 7 (reconnect)."); + close("reconnect"); + } else if (opCode == WebSocketCode.DISPATCH) { + onDispatch(data); + } + } + + @Override + public void onDisconnected(WebSocket websocket, WebSocketFrame serverCloseFrame, WebSocketFrame clientCloseFrame, + boolean closedByServer) { + reset(); + int code = 1000; + String closeReason = ""; + if (clientCloseFrame != null) { + code = clientCloseFrame.getCloseCode(); + closeReason = clientCloseFrame.getCloseReason(); + } else if (serverCloseFrame != null) { + code = serverCloseFrame.getCloseCode(); + closeReason = serverCloseFrame.getCloseReason(); + } + if (code >= 4010 || code == 4004) { + log.warn("[gateway] Websocket closed and can't reconnect! code: {}, reason: {}", code, closeReason); + System.exit(code); + return; + } + log.warn("[gateway] Websocket closed and will be reconnect... code: {}, reason: {}", code, closeReason); + ThreadUtil.execute(() -> { + try { + retryStart(0); + } catch (Exception e) { + log.error("[gateway] Websocket reconnect error", e); + System.exit(1); + } + }); + } + + private void retryStart(int currentRetryTime) throws Exception { + try { + start(); + } catch (Exception e) { + if (currentRetryTime < CONNECT_RETRY_LIMIT) { + currentRetryTime++; + log.warn("[gateway] Websocket start fail, retry {} time... error: {}", currentRetryTime, + e.getMessage()); + Thread.sleep(5000L); + retryStart(currentRetryTime); + } else { + throw e; + } + } + } + + @Override + public void handleCallbackError(WebSocket websocket, Throwable cause) throws Exception { + log.error("[gateway] There was some websocket error", cause); + } + + private void sayHello() { + DataObject data; + if (CharSequenceUtil.isBlank(this.sessionId)) { + data = DataObject.empty().put("op", WebSocketCode.IDENTIFY).put("d", this.auth); + log.trace("[gateway] Say hello: identify"); + } else { + data = DataObject.empty().put("op", WebSocketCode.RESUME).put("d", + DataObject.empty().put("token", this.userToken).put("session_id", this.sessionId).put("seq", + Math.max(this.sequence.get() - 1, 0))); + log.trace("[gateway] Say hello: resume"); + } + send(data); + } + + private void close(String reason) { + this.connected = false; + this.socket.disconnect(1000, reason); + } + + private void reset() { + this.connected = false; + this.sessionId = null; + this.sequence.set(0); + this.decompressor = null; + this.socket = null; + if (this.heartbeatTask != null) { + this.heartbeatTask.cancel(true); + this.heartbeatTask = null; + } + } + + private void heartbeat() { + if (!this.connected) { + return; + } + send(DataObject.empty().put("op", WebSocketCode.HEARTBEAT).put("d", this.sequence)); + } + + private void onDispatch(DataObject raw) { + if (!raw.isType("d", DataType.OBJECT)) { + return; + } + DataObject content = raw.getObject("d"); + String t = raw.getString("t", null); + if ("READY".equals(t)) { + this.sessionId = content.getString("session_id"); + return; + } + try { + this.userMessageListener.onMessage(raw); + } catch (Exception e) { + log.error("user-wss handle message error", e); + } + } + + protected void send(DataObject message) { + log.trace("[gateway] > {}", message); + this.socket.sendText(message.toString()); + } + + private DataObject createAuthData() { + UserAgent agent = UserAgent.parseUserAgentString(this.userAgent); + DataObject connectionProperties = DataObject.empty().put("os", agent.getOperatingSystem().getName()) + .put("browser", agent.getBrowser().getGroup().getName()).put("device", "").put("system_locale", "zh-CN") + .put("browser_version", agent.getBrowserVersion().toString()).put("browser_user_agent", this.userAgent) + .put("referer", "").put("referring_domain", "").put("referrer_current", "") + .put("referring_domain_current", "").put("release_channel", "stable").put("client_build_number", 117300) + .put("client_event_source", null); + DataObject presence = DataObject.empty().put("status", "online").put("since", 0) + .put("activities", DataArray.empty()).put("afk", false); + DataObject clientState = DataObject.empty().put("guild_hashes", DataArray.empty()).put("highest_last_message_id", "0") + .put("read_state_version", 0).put("user_guild_settings_version", -1).put("user_settings_version", -1); + return DataObject.empty().put("token", this.userToken).put("capabilities", 4093) + .put("properties", connectionProperties).put("presence", presence).put("compress", false) + .put("client_state", clientState); + } + +} diff --git a/src/main/java/spring/config/BeanConfig.java b/src/main/java/spring/config/BeanConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..f283b9ab8bd32b0dc07dc2537224d5565575f2ab --- /dev/null +++ b/src/main/java/spring/config/BeanConfig.java @@ -0,0 +1,66 @@ +package spring.config; + +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.service.TaskStoreService; +import com.github.novicezk.midjourney.service.TranslateService; +import com.github.novicezk.midjourney.service.store.InMemoryTaskStoreServiceImpl; +import com.github.novicezk.midjourney.service.store.RedisTaskStoreServiceImpl; +import com.github.novicezk.midjourney.service.translate.BaiduTranslateServiceImpl; +import com.github.novicezk.midjourney.service.translate.GPTTranslateServiceImpl; +import com.github.novicezk.midjourney.service.translate.NoTranslateServiceImpl; +import com.github.novicezk.midjourney.support.Task; +import com.github.novicezk.midjourney.wss.WebSocketStarter; +import com.github.novicezk.midjourney.wss.user.UserWebSocketStarter; +import org.springframework.boot.ApplicationRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; +import org.springframework.data.redis.serializer.StringRedisSerializer; + +import java.time.Duration; + +@Configuration +public class BeanConfig { + + @Bean + TranslateService translateService(ProxyProperties properties) { + return switch (properties.getTranslateWay()) { + case BAIDU -> new BaiduTranslateServiceImpl(properties.getBaiduTranslate()); + case GPT -> new GPTTranslateServiceImpl(properties); + default -> new NoTranslateServiceImpl(); + }; + } + + @Bean + TaskStoreService taskStoreService(ProxyProperties proxyProperties, RedisConnectionFactory redisConnectionFactory) { + ProxyProperties.TaskStore.Type type = proxyProperties.getTaskStore().getType(); + Duration timeout = proxyProperties.getTaskStore().getTimeout(); + return switch (type) { + case IN_MEMORY -> new InMemoryTaskStoreServiceImpl(timeout); + case REDIS -> new RedisTaskStoreServiceImpl(timeout, taskRedisTemplate(redisConnectionFactory)); + }; + } + + @Bean + RedisTemplate taskRedisTemplate(RedisConnectionFactory redisConnectionFactory) { + RedisTemplate redisTemplate = new RedisTemplate<>(); + redisTemplate.setConnectionFactory(redisConnectionFactory); + redisTemplate.setKeySerializer(new StringRedisSerializer()); + redisTemplate.setHashKeySerializer(new StringRedisSerializer()); + redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(Task.class)); + return redisTemplate; + } + + @Bean + WebSocketStarter webSocketStarter(ProxyProperties properties) { + return new UserWebSocketStarter(properties); + } + + @Bean + ApplicationRunner enableMetaChangeReceiverInitializer(WebSocketStarter webSocketStarter) { + return args -> webSocketStarter.start(); + } + +} diff --git a/src/main/java/spring/config/WebMvcConfig.java b/src/main/java/spring/config/WebMvcConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..b710261ae05c65a550df1b00feddcc3add95864e --- /dev/null +++ b/src/main/java/spring/config/WebMvcConfig.java @@ -0,0 +1,33 @@ +package spring.config; + +import cn.hutool.core.text.CharSequenceUtil; +import com.github.novicezk.midjourney.ProxyProperties; +import com.github.novicezk.midjourney.support.ApiAuthorizeInterceptor; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.InterceptorRegistry; +import org.springframework.web.servlet.config.annotation.ViewControllerRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +import javax.annotation.Resource; + +@Configuration +public class WebMvcConfig implements WebMvcConfigurer { + @Resource + private ApiAuthorizeInterceptor apiAuthorizeInterceptor; + @Resource + private ProxyProperties properties; + + @Override + public void addViewControllers(ViewControllerRegistry registry) { + registry.addViewController("/").setViewName("redirect:doc.html"); + } + + @Override + public void addInterceptors(InterceptorRegistry registry) { + if (CharSequenceUtil.isNotBlank(this.properties.getApiSecret())) { + registry.addInterceptor(this.apiAuthorizeInterceptor) + .addPathPatterns("/submit/**", "/task/**"); + } + } + +} diff --git a/src/main/resources/api-params/blend.json b/src/main/resources/api-params/blend.json new file mode 100644 index 0000000000000000000000000000000000000000..e082cf06c91edbf21e38c1439dbc5d5735ae7fe7 --- /dev/null +++ b/src/main/resources/api-params/blend.json @@ -0,0 +1,16 @@ +{ + "type":2, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "application_id":"936929561302675456", + "session_id":"$session_id", + "nonce": "$nonce", + "data":{ + "version":"1118961510123847773", + "id":"1062880104792997970", + "name":"blend", + "type":1, + "options":[], + "attachments":[] + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/describe.json b/src/main/resources/api-params/describe.json new file mode 100644 index 0000000000000000000000000000000000000000..4402b0012f0ef73cb22c7df30027dc264a8c4d10 --- /dev/null +++ b/src/main/resources/api-params/describe.json @@ -0,0 +1,28 @@ +{ + "type": 2, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "data": { + "version": "1118961510123847774", + "id": "1092492867185950852", + "name": "describe", + "type": 1, + "options": [ + { + "type": 11, + "name": "image", + "value": 0 + } + ], + "attachments": [ + { + "id": "0", + "filename": "$file_name", + "uploaded_filename": "$final_file_name" + } + ] + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/imagine.json b/src/main/resources/api-params/imagine.json new file mode 100644 index 0000000000000000000000000000000000000000..f027d6f2c8317881fffb763f8b1e9b3997e73029 --- /dev/null +++ b/src/main/resources/api-params/imagine.json @@ -0,0 +1,21 @@ +{ + "type": 2, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "data": { + "version": "1118961510123847772", + "id": "938956540159881230", + "name": "imagine", + "type": 1, + "options": [ + { + "type": 3, + "name": "prompt", + "value": "$prompt" + } + ] + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/message.json b/src/main/resources/api-params/message.json new file mode 100644 index 0000000000000000000000000000000000000000..9f8cccbc952fab9752c04874971b1ce061db156a --- /dev/null +++ b/src/main/resources/api-params/message.json @@ -0,0 +1,13 @@ +{ + "content":"$content", + "channel_id":"$channel_id", + "type":0, + "sticker_ids":[], + "attachments":[ + { + "id":"0", + "filename": "$file_name", + "uploaded_filename": "$final_file_name" + } + ] +} \ No newline at end of file diff --git a/src/main/resources/api-params/reroll.json b/src/main/resources/api-params/reroll.json new file mode 100644 index 0000000000000000000000000000000000000000..84435b6de77afadde1a11e47c1ce066b8b0662ee --- /dev/null +++ b/src/main/resources/api-params/reroll.json @@ -0,0 +1,14 @@ +{ + "type": 3, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "message_id": "$message_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "message_flags": 0, + "data": { + "component_type": 2, + "custom_id": "MJ::JOB::reroll::0::$message_hash::SOLO" + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/upscale.json b/src/main/resources/api-params/upscale.json new file mode 100644 index 0000000000000000000000000000000000000000..f1afdd5aac6f34adccd2f6b097b664493430ea06 --- /dev/null +++ b/src/main/resources/api-params/upscale.json @@ -0,0 +1,14 @@ +{ + "type": 3, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "message_id": "$message_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "message_flags": 0, + "data": { + "component_type": 2, + "custom_id": "MJ::JOB::upsample::$index::$message_hash" + } +} \ No newline at end of file diff --git a/src/main/resources/api-params/variation.json b/src/main/resources/api-params/variation.json new file mode 100644 index 0000000000000000000000000000000000000000..0f750aef94d1b280e95e96aca373ff74753710e2 --- /dev/null +++ b/src/main/resources/api-params/variation.json @@ -0,0 +1,14 @@ +{ + "type": 3, + "guild_id": "$guild_id", + "channel_id": "$channel_id", + "message_id": "$message_id", + "application_id": "936929561302675456", + "session_id": "$session_id", + "nonce": "$nonce", + "message_flags": 0, + "data": { + "component_type": 2, + "custom_id": "MJ::JOB::variation::$index::$message_hash" + } +} \ No newline at end of file diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml new file mode 100644 index 0000000000000000000000000000000000000000..59f06a3b552c04a55ca345494cfc00bd130cc84a --- /dev/null +++ b/src/main/resources/application.yml @@ -0,0 +1,15 @@ +mj: + discord: + guild-id: xxx + channel-id: xxx + user-token: xxx + session-id: 9c4055428e13bcbf2248a6b36084c5f3 + user-agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36 + task-store: + type: in_memory + timeout: 30d + translate-way: null + queue: + timeout-minutes: 5 + core-size: 3 + queue-size: 10 \ No newline at end of file diff --git a/src/main/resources/banned-words.txt b/src/main/resources/banned-words.txt new file mode 100644 index 0000000000000000000000000000000000000000..5d38603750927ff761a206f273081418b4c002a4 --- /dev/null +++ b/src/main/resources/banned-words.txt @@ -0,0 +1,211 @@ +blood +twerk +making love +voluptuous +naughty +wincest +orgy +no clothes +au naturel +no shirt +decapitate +bare +nude +barely dressed +nude +bra +risque +scantily clad +cleavage +stripped +infested +full frontal +unclothed +invisible clothes +wearing nothing +lingerie +with no shirt +naked +without clothes on +negligee +zero clothes +gruesome +fascist +nazi +prophet mohammed +slave +coon +honkey +cocaine +heroin +meth +crack +kill +belle delphine +hitler +jinping +lolita +president xi +torture +disturbing +farts +fart +poop +infected +warts +shit +brown pudding +bunghole +vomit +voluptuous +seductive +sperm +sexy +sadist +sensored +censored +silenced +deepfake +inappropriate +waifu +succubus +slaughter +surgery +reproduce +crucified +seductively +explicit +inappropriate +large bust +explicit +wang +inappropriate +teratoma +intimate +see through +tryphophobia +bloodbath +wound +cronenberg +khorne +cannibal +cannibalism +visceral +guts +bloodshot +gory +killing +crucifixion +surgery +vivisection +massacre +hemoglobin +suicide +arse +labia +ass +mammaries +badonkers +bloody +minge +big ass +mommy milker +booba +nipple +oppai +booty +organs +bosom +ovaries +flesh +breasts +penis +busty +phallus +clunge +sexy female +crotch +skimpy +dick +thick +bruises +girth +titty +honkers +vagina +hooters +veiny +knob +ahegao +pinup +ballgag +car crash +playboy +bimbo +pleasure +bodily fluids +pleasures +boudoir +rule34 +brothel +seducing +dominatrix +corpse +seductive +erotic +seductive +fuck +sensual +hardcore +sexy +hentai +shag +horny +crucified +shibari +incest +smut +jav +succubus +jerk off king at pic +thot +kinbaku +legs spread +sensuality +belly button +porn +patriotic +bleed +excrement +petite +seduction +mccurry +provocative +sultry +erected +camisole +tight white +arrest +see-through +feces +anus +revealing clothing +vein +loli +-edge +boobs +-backed +tied up +zedong +bathing +jail +reticulum +rear end +sakimichan +behind bars +shirtless +sakimichan +seductive +sexi +sexualiz +sexual \ No newline at end of file diff --git a/src/main/resources/banner.txt b/src/main/resources/banner.txt new file mode 100644 index 0000000000000000000000000000000000000000..053d7e07ec646e13ca20feab0bebcda9d4c08131 --- /dev/null +++ b/src/main/resources/banner.txt @@ -0,0 +1,8 @@ + + , /) , +___ _(/ ___ __ __ _ __ __ _____/ +// (__(_(_(_ /_(_)(_(_/ (_/ (__(/_(_/_ /_)_/ (_(_) /(__(_/_ + .-/ .-/ .-/ / .-/ + (_/ (_/ (_/ (_/ + +:: MidJourney Proxy :: v2.4 diff --git a/src/main/resources/config/application.yml b/src/main/resources/config/application.yml new file mode 100644 index 0000000000000000000000000000000000000000..7c6b6ac5e462c3652936a6b71f44a2f23df8e55e --- /dev/null +++ b/src/main/resources/config/application.yml @@ -0,0 +1,23 @@ +server: + port: 8080 + servlet: + context-path: /mj +logging: + level: + ROOT: info + com.github.novicezk.midjourney: debug +knife4j: + enable: true + openapi: + title: Midjourney Proxy API文档 + description: 代理 MidJourney 的discord频道,实现api形式调用AI绘图 + concat: novicezk + url: https://github.com/novicezk/midjourney-proxy + version: v2.4 + terms-of-service-url: https://github.com/novicezk/midjourney-proxy + group: + api: + group-name: API + api-rule: package + api-rule-resources: + - com.github.novicezk.midjourney.controller \ No newline at end of file diff --git a/src/main/resources/mime.types b/src/main/resources/mime.types new file mode 100644 index 0000000000000000000000000000000000000000..fcff7ef006540d0747ba3616d2edeef8fc1bd961 --- /dev/null +++ b/src/main/resources/mime.types @@ -0,0 +1,82 @@ +text/html:html htm shtml +text/css:css +text/xml:xml + +text/mathml:mml +text/plain:txt +text/vnd.sun.j2me.app-descriptor:jad +text/vnd.wap.wml:wml +text/x-component:htc + +image/gif:gif +image/jpeg:jpg jpeg +image/png:png +image/tiff:tif tiff +image/vnd.wap.wbmp:wbmp +image/x-icon:ico +image/x-jng:jng +image/x-ms-bmp:bmp +image/svg+xml:svg svgz +image/webp:webp + +application/javascript:js +application/x-javascript:js +application/atom+xml:atom +application/rss+xml:rss + +application/font-woff:woff +application/java-archive:jar war ear +application/json:json +application/mac-binhex40:hqx +application/msword:doc +application/pdf:pdf +application/postscript:ps eps ai +application/rtf:rtf +application/vnd.apple.mpegurl:m3u8 +application/vnd.ms-excel:xls +application/vnd.ms-fontobject:eot +application/vnd.ms-powerpoint:ppt +application/vnd.wap.wmlc:wmlc +application/vnd.google-earth.kml+xml:kml +application/vnd.google-earth.kmz:kmz +application/x-7z-compressed:7z +application/x-cocoa:cco +application/x-java-archive-diff:jardiff +application/x-java-jnlp-file:jnlp +application/x-makeself:run +application/x-perl:pl pm +application/x-pilot:prc pdb +application/x-rar-compressed:rar +application/x-redhat-package-manager:rpm +application/x-sea:sea +application/x-shockwave-flash:swf +application/x-stuffit:sit +application/x-tcl:tcl tk +application/x-x509-ca-cert:der pem crt +application/x-xpinstall:xpi +application/xhtml+xml:xhtml +application/xspf+xml:xspf +application/zip:zip + +application/vnd.openxmlformats-officedocument.wordprocessingml.document:docx +application/vnd.openxmlformats-officedocument.spreadsheetml.sheet:xlsx +application/vnd.openxmlformats-officedocument.presentationml.presentation:pptx + +audio/midi:mid midi kar +audio/mpeg:mp3 +audio/ogg:ogg +audio/x-m4a:m4a +audio/x-realaudio:ra + +video/3gpp:3gpp 3gp +video/mp2t:ts +video/mp4:mp4 +video/mpeg:mpeg mpg +video/quicktime:mov +video/webm:webm +video/x-flv:flv +video/x-m4v:m4v +video/x-mng:mng +video/x-ms-asf:asx asf +video/x-ms-wmv:wmv +video/x-msvideo:avi