from __future__ import annotations import asyncio import json import logging from collections import defaultdict from typing import Any, AsyncIterator from app.clients.hana_client import hana_client from app.clients.openai_compat import openai_chat_completion from app.config import settings from app.services.session_store import ComparisonResult, session_store LOG = logging.getLogger(__name__) PROVIDER_STAGGER_SECONDS = 0.1 _provider_semaphores: dict[str, asyncio.Semaphore] = defaultdict(lambda: asyncio.Semaphore(1)) async def _query_neon( query: str, model_id: str, persona_name: str, system_prompt: str | None, has_persona: bool = True, history: list[tuple[str, str]] | None = None, ) -> dict[str, Any]: try: if settings._neon_security_direct_vllm_enabled(model_id): builtin = hana_client.get_persona_system_prompt(model_id, persona_name) parts: list[str] = [] if system_prompt and has_persona: parts.append(system_prompt) if builtin: parts.append("[Neon persona base from HANA]\n" + builtin) combined = "\n\n".join(parts) if parts else "" messages: list[dict[str, str]] = [] if combined: messages.append({"role": "system", "content": combined}) messages.append({"role": "user", "content": query}) base = f"{settings.neon_security_vllm_base_url.rstrip('/')}/v1" result = await openai_chat_completion( base_url=base, api_key=settings.hana_password_klatchat, model=model_id, messages=messages, ) return { "provider": "neon", "model_id": model_id, "model_name": model_id.split("/")[-1].split("@")[0] if "/" in model_id else model_id, "persona_name": persona_name, "response": result["response"], "elapsed_seconds": result["elapsed_seconds"], "is_neon": True, "params": "24B (quantized)", "has_persona": has_persona, } result = await hana_client.get_inference( query=query, model_id=model_id, persona_name=persona_name, system_prompt=system_prompt, history=history, ) return { "provider": "neon", "model_id": model_id, "model_name": model_id.split("/")[-1].split("@")[0] if "/" in model_id else model_id, "persona_name": persona_name, "response": result["response"], "elapsed_seconds": result["elapsed_seconds"], "is_neon": True, "params": "24B (quantized)", "has_persona": has_persona, } except Exception as exc: LOG.exception("Neon inference failed for %s: %s", model_id, exc) return { "provider": "neon", "model_id": model_id, "model_name": model_id, "persona_name": persona_name, "response": f"[Error]: {exc}", "elapsed_seconds": 0, "is_neon": True, "error": True, "params": "24B (quantized)", "has_persona": has_persona, } async def _query_comparison( query: str, provider_id: str, provider_name: str, base_url: str, api_key: str, model_id: str, model_name: str, system_prompt: str | None, params: str = "", has_persona: bool = True, ) -> dict[str, Any]: sem = _provider_semaphores[provider_id] async with sem: messages: list[dict[str, str]] = [] if system_prompt and has_persona: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": query}) result = await openai_chat_completion( base_url=base_url, api_key=api_key, model=model_id, messages=messages, ) await asyncio.sleep(PROVIDER_STAGGER_SECONDS) return { "provider": provider_id, "provider_name": provider_name, "model_id": model_id, "model_name": model_name, "response": result["response"], "elapsed_seconds": result["elapsed_seconds"], "is_neon": False, "error": result.get("error", False), "params": params, "has_persona": has_persona, } def _resolve_comparison_models(selected_ids: list[str]) -> list[dict[str, Any]]: """Map selected comparison model IDs to their provider config.""" resolved = [] for provider in settings.comparison_providers: for model in provider["models"]: if model["id"] in selected_ids: resolved.append({ "provider_id": provider["id"], "provider_name": provider["name"], "base_url": model.get("base_url", provider["base_url"]), "api_key": model.get("api_key", provider["api_key"]), "model_id": model["id"], "model_name": model["name"], "params": model.get("params", ""), }) return resolved async def stream_comparison( query: str, neon_selections: list[dict[str, str]], comparison_model_ids: list[str], session_id: str | None = None, persona_target: str = "neon-only", ) -> AsyncIterator[str]: """Yield SSE events as each model response completes.""" comparison_models = _resolve_comparison_models(comparison_model_ids) for gi, neon_sel in enumerate(neon_selections): system_prompt = neon_sel.get("system_prompt") neon_gets_persona = persona_target in ("all", "neon-only") comp_gets_persona = persona_target == "all" group_meta = { "neon_model_id": neon_sel["model_id"], "neon_persona": neon_sel["persona_name"], "system_prompt": system_prompt or "", "query": query, "group_index": gi, "persona_target": persona_target, } yield f"event: group_start\ndata: {json.dumps(group_meta)}\n\n" session_responses: list[dict[str, Any]] = [] pending: set[asyncio.Task] = set() neon_task = asyncio.create_task(_query_neon( query=query, model_id=neon_sel["model_id"], persona_name=neon_sel["persona_name"], system_prompt=system_prompt if neon_gets_persona else None, has_persona=neon_gets_persona, )) neon_task.set_name("neon") pending.add(neon_task) for cm in comparison_models: t = asyncio.create_task(_query_comparison( query=query, provider_id=cm["provider_id"], provider_name=cm["provider_name"], base_url=cm["base_url"], api_key=cm["api_key"], model_id=cm["model_id"], model_name=cm["model_name"], system_prompt=system_prompt, params=cm.get("params", ""), has_persona=comp_gets_persona, )) t.set_name(cm["model_id"]) pending.add(t) while pending: done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: try: result = task.result() except Exception as exc: result = { "provider": "unknown", "model_id": task.get_name(), "model_name": task.get_name(), "response": f"[Error]: {exc}", "elapsed_seconds": 0, "is_neon": False, "error": True, } result["group_index"] = gi session_responses.append(result) yield f"event: response\ndata: {json.dumps(result)}\n\n" if session_id: resp_map = {} timing_map = {} for r in session_responses: key = f"{r.get('provider_name', r['provider'])} / {r['model_name']}" if r.get("is_neon"): key = f"Neon / {r['model_name']} ({r.get('persona_name', '')})" resp_map[key] = r["response"] timing_map[key] = r["elapsed_seconds"] session_store.add_result(session_id, ComparisonResult( query=query, neon_model_id=neon_sel["model_id"], neon_persona=neon_sel["persona_name"], responses=resp_map, timings=timing_map, )) yield "event: done\ndata: {}\n\n" async def _run_single_group( query: str, neon_sel: dict[str, str], comparison_models: list[dict[str, Any]], persona_target: str = "neon-only", ) -> dict[str, Any]: """Run one Neon model + all comparison models for a single query.""" system_prompt = neon_sel.get("system_prompt") neon_gets_persona = persona_target in ("all", "neon-only") comp_gets_persona = persona_target == "all" tasks = [] tasks.append(_query_neon( query=query, model_id=neon_sel["model_id"], persona_name=neon_sel["persona_name"], system_prompt=system_prompt if neon_gets_persona else None, has_persona=neon_gets_persona, )) for cm in comparison_models: tasks.append(_query_comparison( query=query, provider_id=cm["provider_id"], provider_name=cm["provider_name"], base_url=cm["base_url"], api_key=cm["api_key"], model_id=cm["model_id"], model_name=cm["model_name"], system_prompt=system_prompt, params=cm.get("params", ""), has_persona=comp_gets_persona, )) results = await asyncio.gather(*tasks, return_exceptions=True) responses: list[dict[str, Any]] = [] for r in results: if isinstance(r, Exception): responses.append({ "provider": "unknown", "model_id": "unknown", "model_name": "unknown", "response": f"[Error]: {r}", "elapsed_seconds": 0, "is_neon": False, "error": True, }) else: responses.append(r) return { "neon_model_id": neon_sel["model_id"], "neon_persona": neon_sel["persona_name"], "query": query, "responses": responses, } async def run_comparison( query: str, neon_selections: list[dict[str, str]], comparison_model_ids: list[str], session_id: str | None = None, persona_target: str = "neon-only", ) -> list[dict[str, Any]]: """Non-streaming version. Neon groups run in parallel; semaphores handle rate limiting.""" comparison_models = _resolve_comparison_models(comparison_model_ids) group_tasks = [ _run_single_group(query, neon_sel, comparison_models, persona_target=persona_target) for neon_sel in neon_selections ] groups = await asyncio.gather(*group_tasks) groups = list(groups) if session_id: for group in groups: resp_map = {} timing_map = {} for r in group["responses"]: key = f"{r.get('provider_name', r['provider'])} / {r['model_name']}" if r.get("is_neon"): key = f"Neon / {r['model_name']} ({r['persona_name']})" resp_map[key] = r["response"] timing_map[key] = r["elapsed_seconds"] session_store.add_result(session_id, ComparisonResult( query=query, neon_model_id=group["neon_model_id"], neon_persona=group["neon_persona"], responses=resp_map, timings=timing_map, )) return groups async def run_csv_comparison( questions: list[str], neon_selections: list[dict[str, str]], comparison_model_ids: list[str], persona_target: str = "neon-only", ) -> list[list[dict[str, Any]]]: """Run comparison for a batch of questions. Returns list of group-lists.""" all_results = [] for q in questions: groups = await run_comparison( query=q, neon_selections=neon_selections, comparison_model_ids=comparison_model_ids, persona_target=persona_target, ) all_results.append(groups) return all_results