Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu commited on
Commit ·
a6ebff0
1
Parent(s): a27db7d
feat: 加入DALLE3支持
Browse files- modules/models/DALLE3.py +38 -0
- modules/models/base_model.py +3 -0
- modules/models/models.py +4 -0
- modules/presets.py +1 -0
modules/models/DALLE3.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import openai
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from .base_model import BaseLLMModel
|
| 6 |
+
from .. import shared
|
| 7 |
+
from ..config import retrieve_proxy
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class OpenAI_DALLE3_Client(BaseLLMModel):
|
| 11 |
+
def __init__(self, model_name, api_key, user_name="") -> None:
|
| 12 |
+
super().__init__(model_name=model_name, user=user_name)
|
| 13 |
+
self.api_key = api_key
|
| 14 |
+
|
| 15 |
+
def _get_dalle3_prompt(self):
|
| 16 |
+
prompt = self.history[-1]["content"]
|
| 17 |
+
if prompt.endswith("--raw"):
|
| 18 |
+
prompt = "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:" + prompt
|
| 19 |
+
return prompt
|
| 20 |
+
|
| 21 |
+
@shared.state.switching_api_key
|
| 22 |
+
def get_answer_at_once(self):
|
| 23 |
+
prompt = self._get_dalle3_prompt()
|
| 24 |
+
with retrieve_proxy():
|
| 25 |
+
client = OpenAI(api_key=openai.api_key)
|
| 26 |
+
try:
|
| 27 |
+
response = client.images.generate(
|
| 28 |
+
model="dall-e-3",
|
| 29 |
+
prompt=prompt,
|
| 30 |
+
size="1024x1024",
|
| 31 |
+
quality="standard",
|
| 32 |
+
n=1,
|
| 33 |
+
)
|
| 34 |
+
except openai.BadRequestError as e:
|
| 35 |
+
msg = str(e)
|
| 36 |
+
match = re.search(r"'message': '([^']*)'", msg)
|
| 37 |
+
return match.group(1), 0
|
| 38 |
+
return f'<img src="{response.data[0].url}"> {response.data[0].revised_prompt}', 0
|
modules/models/base_model.py
CHANGED
|
@@ -153,6 +153,7 @@ class ModelType(Enum):
|
|
| 153 |
Qwen = 15
|
| 154 |
OpenAIVision = 16
|
| 155 |
ERNIE = 17
|
|
|
|
| 156 |
|
| 157 |
@classmethod
|
| 158 |
def get_type(cls, model_name: str):
|
|
@@ -195,6 +196,8 @@ class ModelType(Enum):
|
|
| 195 |
model_type = ModelType.Qwen
|
| 196 |
elif "ernie" in model_name_lower:
|
| 197 |
model_type = ModelType.ERNIE
|
|
|
|
|
|
|
| 198 |
else:
|
| 199 |
model_type = ModelType.LLaMA
|
| 200 |
return model_type
|
|
|
|
| 153 |
Qwen = 15
|
| 154 |
OpenAIVision = 16
|
| 155 |
ERNIE = 17
|
| 156 |
+
DALLE3 = 18
|
| 157 |
|
| 158 |
@classmethod
|
| 159 |
def get_type(cls, model_name: str):
|
|
|
|
| 196 |
model_type = ModelType.Qwen
|
| 197 |
elif "ernie" in model_name_lower:
|
| 198 |
model_type = ModelType.ERNIE
|
| 199 |
+
elif "dall" in model_name_lower:
|
| 200 |
+
model_type = ModelType.DALLE3
|
| 201 |
else:
|
| 202 |
model_type = ModelType.LLaMA
|
| 203 |
return model_type
|
modules/models/models.py
CHANGED
|
@@ -129,6 +129,10 @@ def get_model(
|
|
| 129 |
elif model_type == ModelType.ERNIE:
|
| 130 |
from .ERNIE import ERNIE_Client
|
| 131 |
model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
elif model_type == ModelType.Unknown:
|
| 133 |
raise ValueError(f"未知模型: {model_name}")
|
| 134 |
logging.info(msg)
|
|
|
|
| 129 |
elif model_type == ModelType.ERNIE:
|
| 130 |
from .ERNIE import ERNIE_Client
|
| 131 |
model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY"))
|
| 132 |
+
elif model_type == ModelType.DALLE3:
|
| 133 |
+
from .DALLE3 import OpenAI_DALLE3_Client
|
| 134 |
+
access_key = os.environ.get("OPENAI_API_KEY", access_key)
|
| 135 |
+
model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
|
| 136 |
elif model_type == ModelType.Unknown:
|
| 137 |
raise ValueError(f"未知模型: {model_name}")
|
| 138 |
logging.info(msg)
|
modules/presets.py
CHANGED
|
@@ -62,6 +62,7 @@ ONLINE_MODELS = [
|
|
| 62 |
"GPT4 Vision",
|
| 63 |
"川虎助理",
|
| 64 |
"川虎助理 Pro",
|
|
|
|
| 65 |
"GooglePaLM",
|
| 66 |
"xmchat",
|
| 67 |
"Azure OpenAI",
|
|
|
|
| 62 |
"GPT4 Vision",
|
| 63 |
"川虎助理",
|
| 64 |
"川虎助理 Pro",
|
| 65 |
+
"DALL-E 3",
|
| 66 |
"GooglePaLM",
|
| 67 |
"xmchat",
|
| 68 |
"Azure OpenAI",
|