# trinity-arena / engine.py
# Author: Julien Simon
# feat: run Trinity Nano locally via ZeroGPU, Mini/Large via OpenRouter (commit d98b8b5)
import time

from config import MODELS, MAX_PROMPT_CHARS, QUALITY_THRESHOLD, DAILY_RATE_LIMIT
from openrouter import call_models_parallel, calculate_cost
from judge import judge_responses
from rate_limiter import RateLimiter
# Module-level limiter shared by all requests in this process; the daily
# cap comes from config.DAILY_RATE_LIMIT and is enforced per client IP
# (see validate_prompt below).
rate_limiter = RateLimiter(max_requests=DAILY_RATE_LIMIT)
def pick_best_value(scores: dict, costs: dict, threshold: int) -> str | None:
candidates = []
for key in costs:
score = scores.get(key)
if score is not None and score >= threshold:
candidates.append((key, costs[key], score))
if not candidates:
return None
# Sort by cost ascending, then by score descending for ties
candidates.sort(key=lambda x: (x[1], -x[2]))
return candidates[0][0]
def validate_prompt(prompt: str, ip: str) -> str | None:
"""Returns an error message if invalid, None if OK."""
if not prompt.strip():
return "Please enter a prompt."
if len(prompt) > MAX_PROMPT_CHARS:
return f"Prompt too long. Maximum is {MAX_PROMPT_CHARS:,} characters."
if not rate_limiter.check(ip):
remaining_secs = rate_limiter.reset_time(ip) - __import__("time").time()
hours = int(remaining_secs // 3600)
return f"Rate limit reached. Resets in ~{hours} hours. Max {DAILY_RATE_LIMIT} comparisons per day."
return None
async def call_models(prompt: str) -> dict:
    """Call all models in parallel and return raw results.

    Thin wrapper around openrouter.call_models_parallel, kept so callers
    in this module depend on the engine rather than the client directly.

    Returns:
        Mapping of model key -> result dict; siblings in this module read
        "content", "error", "prompt_tokens" and "completion_tokens" from it.
    """
    return await call_models_parallel(prompt)
def compute_costs(model_results: dict) -> dict:
    """Price every model response from its token usage.

    Args:
        model_results: Mapping of model key -> result dict containing
            "prompt_tokens" and "completion_tokens".

    Returns:
        Mapping of model key -> dollar cost, computed with calculate_cost
        and the per-million-token pricing configured in MODELS.
    """
    return {
        key: calculate_cost(
            prompt_tokens=result["prompt_tokens"],
            completion_tokens=result["completion_tokens"],
            input_cost_per_m=MODELS[key]["input_cost_per_m"],
            output_cost_per_m=MODELS[key]["output_cost_per_m"],
        )
        for key, result in model_results.items()
    }
def check_all_failed(model_results: dict) -> str | None:
"""If all models failed, return an error message. Otherwise None."""
valid = {k: v["content"] for k, v in model_results.items() if v["content"] is not None}
if valid:
return None
first_error = next((r["error"] for r in model_results.values() if r["error"]), "")
if "401" in first_error or "Unauthorized" in first_error:
return "API key is missing or invalid. Please check the OPENROUTER_API_KEY secret."
return f"All models failed to respond. Please try again. ({first_error[:100]})"
async def score_responses(prompt: str, model_results: dict) -> dict | None:
"""Judge the valid responses and return scores."""
valid_responses = {k: v["content"] for k, v in model_results.items() if v["content"] is not None}
if not valid_responses:
return None
return await judge_responses(prompt, valid_responses)
async def run_comparison(prompt: str, ip: str) -> dict:
    """Full pipeline (non-streaming). Kept for backward compat.

    Validates the prompt, fans out to all models, prices the responses,
    and — when at least one model answered — judges them and picks the
    best-value model.

    Returns:
        Dict with keys "error", "responses", "costs", "scores",
        "best_value". On validation failure only "error" is present.
    """
    validation_error = validate_prompt(prompt, ip)
    if validation_error:
        return {"error": validation_error}

    responses = await call_models(prompt)
    costs = compute_costs(responses)

    # Build the result shell once; the failure and success paths only
    # differ in which fields stay at their defaults.
    failure_msg = check_all_failed(responses)
    result = {
        "error": failure_msg,
        "responses": responses,
        "costs": costs,
        "scores": None,
        "best_value": None,
    }
    if failure_msg:
        return result

    scores = await score_responses(prompt, responses)
    result["scores"] = scores
    if scores:
        result["best_value"] = pick_best_value(scores, costs, QUALITY_THRESHOLD)
    return result