from config import MODELS, MAX_PROMPT_CHARS, QUALITY_THRESHOLD, DAILY_RATE_LIMIT
import time

from judge import judge_responses
from openrouter import call_models_parallel, calculate_cost
from rate_limiter import RateLimiter
# Process-wide limiter shared by all requests; validate_prompt consults it
# per client IP via check(ip) / reset_time(ip).
rate_limiter = RateLimiter(max_requests=DAILY_RATE_LIMIT)
def pick_best_value(scores: dict, costs: dict, threshold: int) -> str | None:
candidates = []
for key in costs:
score = scores.get(key)
if score is not None and score >= threshold:
candidates.append((key, costs[key], score))
if not candidates:
return None
# Sort by cost ascending, then by score descending for ties
candidates.sort(key=lambda x: (x[1], -x[2]))
return candidates[0][0]
def validate_prompt(prompt: str, ip: str) -> str | None:
"""Returns an error message if invalid, None if OK."""
if not prompt.strip():
return "Please enter a prompt."
if len(prompt) > MAX_PROMPT_CHARS:
return f"Prompt too long. Maximum is {MAX_PROMPT_CHARS:,} characters."
if not rate_limiter.check(ip):
remaining_secs = rate_limiter.reset_time(ip) - __import__("time").time()
hours = int(remaining_secs // 3600)
return f"Rate limit reached. Resets in ~{hours} hours. Max {DAILY_RATE_LIMIT} comparisons per day."
return None
async def call_models(prompt: str) -> dict:
    """Call all models in parallel and return raw results.

    Thin async passthrough to call_models_parallel; the returned dict is
    whatever that helper produces (presumably model-key -> result dict, as
    consumed by compute_costs/check_all_failed — confirm in openrouter.py).
    """
    return await call_models_parallel(prompt)
def compute_costs(model_results: dict) -> dict:
    """Return a dict mapping each model key to its dollar cost for this run.

    Token counts come from the per-model result; per-million pricing comes
    from the MODELS config entry with the same key.
    """
    return {
        name: calculate_cost(
            prompt_tokens=result["prompt_tokens"],
            completion_tokens=result["completion_tokens"],
            input_cost_per_m=MODELS[name]["input_cost_per_m"],
            output_cost_per_m=MODELS[name]["output_cost_per_m"],
        )
        for name, result in model_results.items()
    }
def check_all_failed(model_results: dict) -> str | None:
"""If all models failed, return an error message. Otherwise None."""
valid = {k: v["content"] for k, v in model_results.items() if v["content"] is not None}
if valid:
return None
first_error = next((r["error"] for r in model_results.values() if r["error"]), "")
if "401" in first_error or "Unauthorized" in first_error:
return "API key is missing or invalid. Please check the OPENROUTER_API_KEY secret."
return f"All models failed to respond. Please try again. ({first_error[:100]})"
async def score_responses(prompt: str, model_results: dict) -> dict | None:
"""Judge the valid responses and return scores."""
valid_responses = {k: v["content"] for k, v in model_results.items() if v["content"] is not None}
if not valid_responses:
return None
return await judge_responses(prompt, valid_responses)
async def run_comparison(prompt: str, ip: str) -> dict:
    """Full pipeline (non-streaming). Kept for backward compat.

    Validates the prompt, fans out to all models, prices the responses,
    and judges them. Returns a result dict with error/responses/costs/
    scores/best_value keys — except on validation failure, where only an
    "error" key is returned.
    """
    validation_error = validate_prompt(prompt, ip)
    if validation_error:
        return {"error": validation_error}

    responses = await call_models(prompt)
    response_costs = compute_costs(responses)

    failure_message = check_all_failed(responses)
    if failure_message:
        return {
            "error": failure_message,
            "responses": responses,
            "costs": response_costs,
            "scores": None,
            "best_value": None,
        }

    quality_scores = await score_responses(prompt, responses)
    winner = (
        pick_best_value(quality_scores, response_costs, QUALITY_THRESHOLD)
        if quality_scores
        else None
    )
    return {
        "error": None,
        "responses": responses,
        "costs": response_costs,
        "scores": quality_scores,
        "best_value": winner,
    }