import time

from config import MODELS, MAX_PROMPT_CHARS, QUALITY_THRESHOLD, DAILY_RATE_LIMIT
from openrouter import call_models_parallel, calculate_cost
from judge import judge_responses
from rate_limiter import RateLimiter

rate_limiter = RateLimiter(max_requests=DAILY_RATE_LIMIT)


def pick_best_value(scores: dict, costs: dict, threshold: int) -> str | None:
    """Return the key of the cheapest model whose score meets the threshold."""
    candidates = []
    for key in costs:
        score = scores.get(key)
        if score is not None and score >= threshold:
            candidates.append((key, costs[key], score))
    if not candidates:
        return None
    # Sort by cost ascending, then by score descending for ties
    candidates.sort(key=lambda x: (x[1], -x[2]))
    return candidates[0][0]
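
# Illustrative walk-through (hypothetical numbers, not real pricing): with
# threshold=7, scores={"a": 9, "b": 8, "c": 6} and
# costs={"a": 0.0040, "b": 0.0010, "c": 0.0002}, "c" is dropped for scoring
# below the threshold and "b" wins as the cheapest remaining candidate, even
# though "a" scored higher.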


def validate_prompt(prompt: str, ip: str) -> str | None:
    """Returns an error message if invalid, None if OK."""
    if not prompt.strip():
        return "Please enter a prompt."
    if len(prompt) > MAX_PROMPT_CHARS:
        return f"Prompt too long. Maximum is {MAX_PROMPT_CHARS:,} characters."
    if not rate_limiter.check(ip):
        remaining_secs = rate_limiter.reset_time(ip) - time.time()
        hours = int(remaining_secs // 3600)
        return f"Rate limit reached. Resets in ~{hours} hours. Max {DAILY_RATE_LIMIT} comparisons per day."
    return None
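
# Example of the reset math above (made-up timestamps): if the window resets
# 9,000 seconds from now, hours = int(9000 // 3600) = 2 and the user sees
# "Resets in ~2 hours." Note that a reset under an hour away reads as "~0 hours".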


async def call_models(prompt: str) -> dict:
    """Call all models in parallel and return raw results."""
    return await call_models_parallel(prompt)


def compute_costs(model_results: dict) -> dict:
    """Calculate costs from model results."""
    costs = {}
    for key, result in model_results.items():
        model_conf = MODELS[key]
        costs[key] = calculate_cost(
            prompt_tokens=result["prompt_tokens"],
            completion_tokens=result["completion_tokens"],
            input_cost_per_m=model_conf["input_cost_per_m"],
            output_cost_per_m=model_conf["output_cost_per_m"],
        )
    return costs
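
# Sketch of the arithmetic calculate_cost presumably performs (the actual
# implementation lives in openrouter.py; these figures are made up):
#   cost = prompt_tokens / 1_000_000 * input_cost_per_m
#        + completion_tokens / 1_000_000 * output_cost_per_m
# e.g. 1,200 prompt tokens and 800 completion tokens at $0.50 / $1.50 per
# million tokens -> 0.0012 * 0.50 + 0.0008 * 1.50 = $0.0018.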


def check_all_failed(model_results: dict) -> str | None:
    """If all models failed, return an error message. Otherwise None."""
    if any(v["content"] is not None for v in model_results.values()):
        return None
    first_error = next((r["error"] for r in model_results.values() if r["error"]), "")
    if "401" in first_error or "Unauthorized" in first_error:
        return "API key is missing or invalid. Please check the OPENROUTER_API_KEY secret."
    return f"All models failed to respond. Please try again. ({first_error[:100]})"


async def score_responses(prompt: str, model_results: dict) -> dict | None:
    """Judge the valid responses and return scores."""
    valid_responses = {
        k: v["content"] for k, v in model_results.items() if v["content"] is not None
    }
    if not valid_responses:
        return None
    return await judge_responses(prompt, valid_responses)


async def run_comparison(prompt: str, ip: str) -> dict:
    """Full pipeline (non-streaming). Kept for backward compat."""
    error = validate_prompt(prompt, ip)
    if error:
        return {"error": error}

    model_results = await call_models(prompt)
    costs = compute_costs(model_results)

    all_failed_msg = check_all_failed(model_results)
    if all_failed_msg:
        return {
            "error": all_failed_msg,
            "responses": model_results,
            "costs": costs,
            "scores": None,
            "best_value": None,
        }

    scores = await score_responses(prompt, model_results)
    best_value = None
    if scores:
        best_value = pick_best_value(scores, costs, QUALITY_THRESHOLD)

    return {
        "error": None,
        "responses": model_results,
        "costs": costs,
        "scores": scores,
        "best_value": best_value,
    }
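
# Minimal local smoke test (a sketch, not part of the app). Assumes the
# OPENROUTER_API_KEY secret is configured and the sibling modules are
# importable; the prompt and the "127.0.0.1" client IP are placeholders.
if __name__ == "__main__":
    import asyncio

    result = asyncio.run(run_comparison("Summarize attention in one paragraph.", "127.0.0.1"))
    if result["error"]:
        print(f"Error: {result['error']}")
    else:
        print(f"Best value: {result['best_value']}")
        for key, cost in result["costs"].items():
            score = (result["scores"] or {}).get(key)
            print(f"  {key}: ${cost:.4f} (score: {score})")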