🤖 Models: Add support for o1-mini o1-preview model
Browse files- main.py +4 -0
- request.py +58 -0
- utils.py +9 -0
main.py
CHANGED
|
@@ -201,6 +201,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
|
|
| 201 |
if "gemini" in provider['model'][request.model] and engine == "vertex":
|
| 202 |
engine = "vertex-gemini"
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
if endpoint == "/v1/images/generations":
|
| 205 |
engine = "dalle"
|
| 206 |
request.stream = False
|
|
|
|
| 201 |
if "gemini" in provider['model'][request.model] and engine == "vertex":
|
| 202 |
engine = "vertex-gemini"
|
| 203 |
|
| 204 |
+
if "o1-preview" in provider['model'][request.model] or "o1-mini" in provider['model'][request.model]:
|
| 205 |
+
engine = "o1"
|
| 206 |
+
request.stream = False
|
| 207 |
+
|
| 208 |
if endpoint == "/v1/images/generations":
|
| 209 |
engine = "dalle"
|
| 210 |
request.stream = False
|
request.py
CHANGED
|
@@ -737,6 +737,62 @@ async def get_cloudflare_payload(request, engine, provider):
|
|
| 737 |
|
| 738 |
return url, headers, payload
|
| 739 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
async def gpt2claude_tools_json(json_dict):
|
| 741 |
import copy
|
| 742 |
json_dict = copy.deepcopy(json_dict)
|
|
@@ -929,6 +985,8 @@ async def get_payload(request: RequestModel, engine, provider):
|
|
| 929 |
return await get_openrouter_payload(request, engine, provider)
|
| 930 |
elif engine == "cloudflare":
|
| 931 |
return await get_cloudflare_payload(request, engine, provider)
|
|
|
|
|
|
|
| 932 |
elif engine == "dalle":
|
| 933 |
return await get_dalle_payload(request, engine, provider)
|
| 934 |
else:
|
|
|
|
| 737 |
|
| 738 |
return url, headers, payload
|
| 739 |
|
| 740 |
+
async def get_o1_payload(request, engine, provider):
|
| 741 |
+
headers = {
|
| 742 |
+
'Content-Type': 'application/json'
|
| 743 |
+
}
|
| 744 |
+
if provider.get("api"):
|
| 745 |
+
headers['Authorization'] = f"Bearer {provider['api'].next()}"
|
| 746 |
+
|
| 747 |
+
url = provider['base_url']
|
| 748 |
+
|
| 749 |
+
messages = []
|
| 750 |
+
for msg in request.messages:
|
| 751 |
+
if isinstance(msg.content, list):
|
| 752 |
+
content = []
|
| 753 |
+
for item in msg.content:
|
| 754 |
+
if item.type == "text":
|
| 755 |
+
text_message = await get_text_message(msg.role, item.text, engine)
|
| 756 |
+
content.append(text_message)
|
| 757 |
+
else:
|
| 758 |
+
content = msg.content
|
| 759 |
+
|
| 760 |
+
if isinstance(content, list):
|
| 761 |
+
for item in content:
|
| 762 |
+
if item["type"] == "text":
|
| 763 |
+
messages.append({"role": msg.role, "content": item["text"]})
|
| 764 |
+
else:
|
| 765 |
+
messages.append({"role": msg.role, "content": content})
|
| 766 |
+
|
| 767 |
+
model = provider['model'][request.model]
|
| 768 |
+
payload = {
|
| 769 |
+
"model": model,
|
| 770 |
+
"messages": messages,
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
+
miss_fields = [
|
| 774 |
+
'model',
|
| 775 |
+
'messages',
|
| 776 |
+
'tools',
|
| 777 |
+
'tool_choice',
|
| 778 |
+
'temperature',
|
| 779 |
+
'top_p',
|
| 780 |
+
'max_tokens',
|
| 781 |
+
'presence_penalty',
|
| 782 |
+
'frequency_penalty',
|
| 783 |
+
'n',
|
| 784 |
+
'user',
|
| 785 |
+
'include_usage',
|
| 786 |
+
'logprobs',
|
| 787 |
+
'top_logprobs'
|
| 788 |
+
]
|
| 789 |
+
|
| 790 |
+
for field, value in request.model_dump(exclude_unset=True).items():
|
| 791 |
+
if field not in miss_fields and value is not None:
|
| 792 |
+
payload[field] = value
|
| 793 |
+
|
| 794 |
+
return url, headers, payload
|
| 795 |
+
|
| 796 |
async def gpt2claude_tools_json(json_dict):
|
| 797 |
import copy
|
| 798 |
json_dict = copy.deepcopy(json_dict)
|
|
|
|
| 985 |
return await get_openrouter_payload(request, engine, provider)
|
| 986 |
elif engine == "cloudflare":
|
| 987 |
return await get_cloudflare_payload(request, engine, provider)
|
| 988 |
+
elif engine == "o1":
|
| 989 |
+
return await get_o1_payload(request, engine, provider)
|
| 990 |
elif engine == "dalle":
|
| 991 |
return await get_dalle_payload(request, engine, provider)
|
| 992 |
else:
|
utils.py
CHANGED
|
@@ -53,6 +53,15 @@ def update_config(config_data):
|
|
| 53 |
async def load_config(app=None):
|
| 54 |
import yaml
|
| 55 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
with open('./api.yaml', 'r') as f:
|
| 57 |
# 判断是否为空文件
|
| 58 |
conf = yaml.safe_load(f)
|
|
|
|
| 53 |
async def load_config(app=None):
|
| 54 |
import yaml
|
| 55 |
try:
|
| 56 |
+
# with open('./api.yaml', 'r') as f:
|
| 57 |
+
# tokens = yaml.scan(f)
|
| 58 |
+
# for token in tokens:
|
| 59 |
+
# if isinstance(token, yaml.ScalarToken):
|
| 60 |
+
# value = token.value
|
| 61 |
+
# # 如果plain为False,表示字符串被引号包裹
|
| 62 |
+
# is_quoted = not token.plain
|
| 63 |
+
# print(f"值: {value}, 是否被引号包裹: {is_quoted}")
|
| 64 |
+
|
| 65 |
with open('./api.yaml', 'r') as f:
|
| 66 |
# 判断是否为空文件
|
| 67 |
conf = yaml.safe_load(f)
|