File size: 3,772 Bytes
3be2d14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56ef1e6
 
3be2d14
56ef1e6
 
 
3be2d14
 
 
56ef1e6
 
3be2d14
 
56ef1e6
 
 
 
 
 
 
3be2d14
 
 
 
 
 
 
 
 
56ef1e6
3be2d14
 
56ef1e6
 
 
 
 
 
 
d98b8b5
f08e25e
56ef1e6
 
 
 
 
3be2d14
56ef1e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3be2d14
56ef1e6
3be2d14
 
 
 
 
 
56ef1e6
3be2d14
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from config import MODELS, MAX_PROMPT_CHARS, QUALITY_THRESHOLD, DAILY_RATE_LIMIT
from openrouter import call_models_parallel, calculate_cost
from judge import judge_responses
from rate_limiter import RateLimiter

# Module-level limiter consulted by validate_prompt(); constructed once at import
# time with DAILY_RATE_LIMIT as its request cap.
# NOTE(review): whether its state is in-memory/per-process or persisted is decided
# by RateLimiter, not shown here — confirm before running multiple workers.
rate_limiter = RateLimiter(max_requests=DAILY_RATE_LIMIT)


def pick_best_value(scores: dict, costs: dict, threshold: int) -> str | None:
    candidates = []
    for key in costs:
        score = scores.get(key)
        if score is not None and score >= threshold:
            candidates.append((key, costs[key], score))

    if not candidates:
        return None

    # Sort by cost ascending, then by score descending for ties
    candidates.sort(key=lambda x: (x[1], -x[2]))
    return candidates[0][0]


def validate_prompt(prompt: str, ip: str) -> str | None:
    """Check a prompt and the caller's rate limit before running a comparison.

    Args:
        prompt: The user-supplied prompt text.
        ip: Client identifier passed to the module-level ``rate_limiter``.

    Returns:
        A user-facing error message when the prompt is missing or
        whitespace-only, longer than ``MAX_PROMPT_CHARS``, or the caller has
        hit the rate limit; ``None`` if the request may proceed.
    """
    import math
    import time

    # Treat None like an empty prompt instead of crashing on .strip().
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    if len(prompt) > MAX_PROMPT_CHARS:
        return f"Prompt too long. Maximum is {MAX_PROMPT_CHARS:,} characters."
    # The limiter is consulted last, so malformed prompts are rejected first.
    # NOTE(review): presumably check() records the request — verify in RateLimiter.
    if not rate_limiter.check(ip):
        # Clamp so a reset time already in the past never yields a negative count.
        remaining_secs = max(0.0, rate_limiter.reset_time(ip) - time.time())
        if remaining_secs >= 3600:
            amount, unit = int(remaining_secs // 3600), "hour"
        else:
            # Under an hour, report minutes (rounded up, at least 1) so the
            # message never reads "~0 hours".
            amount, unit = max(1, math.ceil(remaining_secs / 60)), "minute"
        plural = "" if amount == 1 else "s"
        return (
            f"Rate limit reached. Resets in ~{amount} {unit}{plural}. "
            f"Max {DAILY_RATE_LIMIT} comparisons per day."
        )
    return None


async def call_models(prompt: str) -> dict:
    """Send ``prompt`` to every configured model in parallel.

    A thin async wrapper around ``call_models_parallel``; its raw result
    mapping is handed back untouched.
    """
    raw_results = await call_models_parallel(prompt)
    return raw_results


def compute_costs(model_results: dict) -> dict:
    """Price each model's token usage using its per-million rates from ``MODELS``.

    Args:
        model_results: Mapping of model key -> result dict that carries
            ``"prompt_tokens"`` and ``"completion_tokens"``. Every key must
            exist in ``MODELS`` (a missing key raises ``KeyError``).

    Returns:
        Mapping of the same model keys to the value ``calculate_cost`` returns.
    """

    def _price(model_key, outcome):
        # Look up the pricing row before reading token counts.
        pricing = MODELS[model_key]
        return calculate_cost(
            prompt_tokens=outcome["prompt_tokens"],
            completion_tokens=outcome["completion_tokens"],
            input_cost_per_m=pricing["input_cost_per_m"],
            output_cost_per_m=pricing["output_cost_per_m"],
        )

    return {model_key: _price(model_key, outcome) for model_key, outcome in model_results.items()}


def check_all_failed(model_results: dict) -> str | None:
    """If all models failed, return an error message. Otherwise None."""
    valid = {k: v["content"] for k, v in model_results.items() if v["content"] is not None}
    if valid:
        return None
    first_error = next((r["error"] for r in model_results.values() if r["error"]), "")
    if "401" in first_error or "Unauthorized" in first_error:
        return "API key is missing or invalid. Please check the OPENROUTER_API_KEY secret."
    return f"All models failed to respond. Please try again. ({first_error[:100]})"


async def score_responses(prompt: str, model_results: dict) -> dict | None:
    """Have the judge score every model response that has content.

    Responses whose ``"content"`` is ``None`` are left out. If none remain,
    ``None`` is returned and the judge is not called; otherwise the result of
    ``judge_responses(prompt, {model_key: content, ...})`` is returned.
    """
    answered = {}
    for model_key, outcome in model_results.items():
        text = outcome["content"]
        if text is not None:
            answered[model_key] = text

    if answered:
        return await judge_responses(prompt, answered)
    return None


async def run_comparison(prompt: str, ip: str) -> dict:
    """Run the full compare-and-judge pipeline in one call (non-streaming).

    Kept for backward compatibility. In order, it:

    1. validates the prompt and rate limit;
    2. queries every model;
    3. prices the results;
    4. bails out if all models failed;
    5. judges the responses and picks the cheapest one scoring at least
       ``QUALITY_THRESHOLD``.

    Returns:
        ``{"error": msg}`` alone when validation fails. Otherwise a dict with
        keys ``error``, ``responses``, ``costs``, ``scores`` and
        ``best_value``. When every model failed, ``error`` holds the failure
        message and ``scores``/``best_value`` are ``None``.
    """
    validation_error = validate_prompt(prompt, ip)
    if validation_error:
        return {"error": validation_error}

    responses = await call_models(prompt)
    costs = compute_costs(responses)

    failure_msg = check_all_failed(responses)
    if failure_msg:
        return {
            "error": failure_msg,
            "responses": responses,
            "costs": costs,
            "scores": None,
            "best_value": None,
        }

    scores = await score_responses(prompt, responses)
    best_value = pick_best_value(scores, costs, QUALITY_THRESHOLD) if scores else None

    return {
        "error": None,
        "responses": responses,
        "costs": costs,
        "scores": scores,
        "best_value": best_value,
    }