🐛 Bug: Fix the bug where weight polling cannot match the model.
Browse files
main.py
CHANGED
|
@@ -759,7 +759,9 @@ class ModelRequestHandler:
|
|
| 759 |
weights = safe_get(config, 'api_keys', api_index, "weights")
|
| 760 |
|
| 761 |
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
| 762 |
-
|
|
|
|
|
|
|
| 763 |
|
| 764 |
intersection = None
|
| 765 |
if weights and all_providers:
|
|
@@ -768,21 +770,25 @@ class ModelRequestHandler:
|
|
| 768 |
for model_rule in weight_keys:
|
| 769 |
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
| 770 |
provider_list = get_provider_list(provider_rules, config, request_model)
|
| 771 |
-
weight_keys = set([provider['provider'] for provider in provider_list])
|
| 772 |
# print("all_providers", all_providers)
|
| 773 |
-
# print("weights",
|
|
|
|
|
|
|
| 774 |
# 步骤 3: 计算交集
|
| 775 |
intersection = all_providers.intersection(weight_keys)
|
|
|
|
| 776 |
|
| 777 |
if weights and intersection:
|
| 778 |
-
|
|
|
|
| 779 |
|
| 780 |
if scheduling_algorithm == "weighted_round_robin":
|
| 781 |
-
weighted_provider_name_list = weighted_round_robin(
|
| 782 |
elif scheduling_algorithm == "lottery":
|
| 783 |
-
weighted_provider_name_list = lottery_scheduling(
|
| 784 |
else:
|
| 785 |
-
weighted_provider_name_list = list(
|
| 786 |
# print("weighted_provider_name_list", weighted_provider_name_list)
|
| 787 |
|
| 788 |
new_matching_providers = []
|
|
|
|
| 759 |
weights = safe_get(config, 'api_keys', api_index, "weights")
|
| 760 |
|
| 761 |
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
| 762 |
+
# print("matching_providers", matching_providers)
|
| 763 |
+
# print(type(matching_providers[0]['model'][0].keys()), list(matching_providers[0]['model'][0].keys())[0], matching_providers[0]['model'][0].keys())
|
| 764 |
+
all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
|
| 765 |
|
| 766 |
intersection = None
|
| 767 |
if weights and all_providers:
|
|
|
|
| 770 |
for model_rule in weight_keys:
|
| 771 |
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
| 772 |
provider_list = get_provider_list(provider_rules, config, request_model)
|
| 773 |
+
weight_keys = set([provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in provider_list])
|
| 774 |
# print("all_providers", all_providers)
|
| 775 |
+
# print("weights", weights)
|
| 776 |
+
# print("weight_keys", weight_keys)
|
| 777 |
+
|
| 778 |
# 步骤 3: 计算交集
|
| 779 |
intersection = all_providers.intersection(weight_keys)
|
| 780 |
+
# print("intersection", intersection)
|
| 781 |
|
| 782 |
if weights and intersection:
|
| 783 |
+
filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
|
| 784 |
+
# print("filtered_weights", filtered_weights)
|
| 785 |
|
| 786 |
if scheduling_algorithm == "weighted_round_robin":
|
| 787 |
+
weighted_provider_name_list = weighted_round_robin(filtered_weights)
|
| 788 |
elif scheduling_algorithm == "lottery":
|
| 789 |
+
weighted_provider_name_list = lottery_scheduling(filtered_weights)
|
| 790 |
else:
|
| 791 |
+
weighted_provider_name_list = list(filtered_weights.keys())
|
| 792 |
# print("weighted_provider_name_list", weighted_provider_name_list)
|
| 793 |
|
| 794 |
new_matching_providers = []
|