Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| from collections import defaultdict | |
| from typing import Any, AsyncIterator | |
| from app.clients.hana_client import hana_client | |
| from app.clients.openai_compat import openai_chat_completion | |
| from app.config import settings | |
| from app.services.session_store import ComparisonResult, session_store | |
LOG = logging.getLogger(__name__)

# Seconds to pause after each comparison-provider request before the
# provider's semaphore is released, spacing out consecutive calls.
PROVIDER_STAGGER_SECONDS = 0.1


def _new_provider_semaphore() -> asyncio.Semaphore:
    """Create a per-provider gate that admits one request at a time."""
    return asyncio.Semaphore(1)


# Lazily-created semaphore per provider id (first lookup creates it).
_provider_semaphores: dict[str, asyncio.Semaphore] = defaultdict(_new_provider_semaphore)
def _neon_display_name(model_id: str) -> str:
    """Return the short display name for a Neon model id.

    Ids containing ``/`` are reduced to their last path segment with any
    ``@...`` suffix removed (``"org/model@rev"`` -> ``"model"``); other ids
    are returned unchanged.
    """
    if "/" not in model_id:
        return model_id
    return model_id.split("/")[-1].split("@")[0]


def _neon_result(
    model_id: str,
    persona_name: str,
    has_persona: bool,
    response: str,
    elapsed_seconds: Any,
    *,
    error: bool = False,
) -> dict[str, Any]:
    """Build the response record emitted for a Neon model.

    The ``"error"`` key is only present (and ``True``) for failed requests.
    """
    record: dict[str, Any] = {
        "provider": "neon",
        "model_id": model_id,
        "model_name": _neon_display_name(model_id),
        "persona_name": persona_name,
        "response": response,
        "elapsed_seconds": elapsed_seconds,
        "is_neon": True,
    }
    if error:
        record["error"] = True
    record["params"] = "24B (quantized)"
    record["has_persona"] = has_persona
    return record


async def _query_neon(
    query: str,
    model_id: str,
    persona_name: str,
    system_prompt: str | None,
    has_persona: bool = True,
    history: list[tuple[str, str]] | None = None,
) -> dict[str, Any]:
    """Query a Neon model and return a normalized response record.

    When ``settings._neon_security_direct_vllm_enabled(model_id)`` is true, the
    request goes directly to the OpenAI-compatible endpoint at
    ``settings.neon_security_vllm_base_url``. The system message there joins
    the caller's ``system_prompt`` (only when ``has_persona``) with the
    persona base prompt returned by ``hana_client.get_persona_system_prompt``.
    Otherwise the request is routed through ``hana_client.get_inference``.

    Args:
        query: User message to send.
        model_id: Neon model identifier.
        persona_name: Persona to use / report.
        system_prompt: Optional caller-supplied system prompt.
        has_persona: Whether the caller's system prompt should be applied.
        history: Forwarded to ``hana_client.get_inference`` only.

    Returns:
        A record with ``provider``, ``model_id``, ``model_name`` (short display
        name), ``persona_name``, ``response``, ``elapsed_seconds``, ``is_neon``,
        ``params`` and ``has_persona``. It never raises for inference failures:
        any ``Exception`` is logged and returned as a record with
        ``"error": True`` and ``"[Error]: ..."`` as the response.
    """
    try:
        if settings._neon_security_direct_vllm_enabled(model_id):
            # NOTE(review): `history` is not forwarded on this path, only on
            # the HANA path below — confirm whether that is intended.
            builtin = hana_client.get_persona_system_prompt(model_id, persona_name)
            parts: list[str] = []
            if system_prompt and has_persona:
                parts.append(system_prompt)
            if builtin:
                # The HANA persona base is included regardless of has_persona.
                parts.append("[Neon persona base from HANA]\n" + builtin)
            messages: list[dict[str, str]] = []
            if parts:
                messages.append({"role": "system", "content": "\n\n".join(parts)})
            messages.append({"role": "user", "content": query})
            base = f"{settings.neon_security_vllm_base_url.rstrip('/')}/v1"
            result = await openai_chat_completion(
                base_url=base,
                api_key=settings.hana_password_klatchat,
                model=model_id,
                messages=messages,
            )
        else:
            result = await hana_client.get_inference(
                query=query,
                model_id=model_id,
                persona_name=persona_name,
                system_prompt=system_prompt,
                history=history,
            )
        return _neon_result(
            model_id,
            persona_name,
            has_persona,
            result["response"],
            result["elapsed_seconds"],
        )
    except Exception as exc:
        LOG.exception("Neon inference failed for %s: %s", model_id, exc)
        # Same display name as successful records, so session keys line up.
        return _neon_result(
            model_id,
            persona_name,
            has_persona,
            f"[Error]: {exc}",
            0,
            error=True,
        )
async def _query_comparison(
    query: str,
    provider_id: str,
    provider_name: str,
    base_url: str,
    api_key: str,
    model_id: str,
    model_name: str,
    system_prompt: str | None,
    params: str = "",
    has_persona: bool = True,
) -> dict[str, Any]:
    """Query one OpenAI-compatible comparison model and return a response record.

    Requests sharing a ``provider_id`` are serialized through that provider's
    entry in ``_provider_semaphores``. After each request the semaphore is
    held for an extra ``PROVIDER_STAGGER_SECONDS`` to space out consecutive
    calls to the same provider.

    The system prompt is sent only when ``system_prompt`` is truthy and
    ``has_persona`` is true.

    Returns:
        A record with ``provider``, ``provider_name``, ``model_id``,
        ``model_name``, ``response``, ``elapsed_seconds``, ``is_neon`` (False),
        ``error``, ``params`` and ``has_persona``. If the chat-completion call
        raises an ``Exception``, it is logged and returned as a record with
        ``"error": True`` and ``"[Error]: ..."`` as the response. Failures
        therefore stay attributed to this provider/model instead of surfacing
        as anonymous exceptions in the callers.
    """
    messages: list[dict[str, str]] = []
    if system_prompt and has_persona:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": query})

    async with _provider_semaphores[provider_id]:
        try:
            result = await openai_chat_completion(
                base_url=base_url,
                api_key=api_key,
                model=model_id,
                messages=messages,
            )
            response = result["response"]
            elapsed_seconds = result["elapsed_seconds"]
            error = result.get("error", False)
        except Exception as exc:
            LOG.exception(
                "Comparison inference failed for %s/%s: %s", provider_id, model_id, exc
            )
            response = f"[Error]: {exc}"
            elapsed_seconds = 0
            error = True
        # Stagger applies after failures too: a struggling provider benefits most.
        await asyncio.sleep(PROVIDER_STAGGER_SECONDS)

    return {
        "provider": provider_id,
        "provider_name": provider_name,
        "model_id": model_id,
        "model_name": model_name,
        "response": response,
        "elapsed_seconds": elapsed_seconds,
        "is_neon": False,
        "error": error,
        "params": params,
        "has_persona": has_persona,
    }
def _resolve_comparison_models(
    selected_ids: list[str],
    providers: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Map selected comparison model IDs to their provider config.

    Args:
        selected_ids: Model ids chosen by the caller. Ids not found under any
            provider's ``"models"`` are ignored.
        providers: Provider configs to search. Defaults to
            ``settings.comparison_providers``. Each provider must have
            ``"id"``, ``"name"``, ``"base_url"``, ``"api_key"`` and
            ``"models"``. Each model must have ``"id"`` and ``"name"``, and may
            override ``"base_url"``/``"api_key"`` and set ``"params"``.

    Returns:
        One flat entry per matching model, in provider-config order (not the
        order of ``selected_ids``). A model id listed under several providers
        yields one entry per provider.
    """
    if providers is None:
        providers = settings.comparison_providers
    # Set lookup instead of scanning the list once per configured model.
    wanted = set(selected_ids)
    resolved: list[dict[str, Any]] = []
    for provider in providers:
        for model in provider["models"]:
            if model["id"] not in wanted:
                continue
            resolved.append({
                "provider_id": provider["id"],
                "provider_name": provider["name"],
                "base_url": model.get("base_url", provider["base_url"]),
                "api_key": model.get("api_key", provider["api_key"]),
                "model_id": model["id"],
                "model_name": model["name"],
                "params": model.get("params", ""),
            })
    return resolved
async def stream_comparison(
    query: str,
    neon_selections: list[dict[str, str]],
    comparison_model_ids: list[str],
    session_id: str | None = None,
    persona_target: str = "neon-only",
) -> AsyncIterator[str]:
    """Yield SSE events as each model response completes.

    Groups (one per entry in ``neon_selections``) are processed one after
    another. For each group this yields:

    * a ``group_start`` event with the group's metadata,
    * one ``response`` event per model, in completion order,

    and then, if ``session_id`` is given, records the group in
    ``session_store``. A final ``done`` event closes the stream.

    ``persona_target`` selects who gets the group's ``system_prompt``:
    ``"all"`` gives it to Neon and comparison models, ``"neon-only"`` to Neon
    only, and any other value to neither.

    If the consumer stops iterating mid-group (e.g. the client disconnects and
    the generator is closed or cancelled), the still-running model tasks of
    that group are cancelled rather than left running in the background.
    """
    comparison_models = _resolve_comparison_models(comparison_model_ids)
    # Persona routing depends only on persona_target, so decide it once.
    neon_gets_persona = persona_target in ("all", "neon-only")
    comp_gets_persona = persona_target == "all"

    for gi, neon_sel in enumerate(neon_selections):
        system_prompt = neon_sel.get("system_prompt")
        group_meta = {
            "neon_model_id": neon_sel["model_id"],
            "neon_persona": neon_sel["persona_name"],
            "system_prompt": system_prompt or "",
            "query": query,
            "group_index": gi,
            "persona_target": persona_target,
        }
        yield f"event: group_start\ndata: {json.dumps(group_meta)}\n\n"

        session_responses: list[dict[str, Any]] = []
        pending: set[asyncio.Task] = set()
        neon_task = asyncio.create_task(_query_neon(
            query=query,
            model_id=neon_sel["model_id"],
            persona_name=neon_sel["persona_name"],
            system_prompt=system_prompt if neon_gets_persona else None,
            has_persona=neon_gets_persona,
        ))
        neon_task.set_name("neon")
        pending.add(neon_task)
        for cm in comparison_models:
            t = asyncio.create_task(_query_comparison(
                query=query,
                provider_id=cm["provider_id"],
                provider_name=cm["provider_name"],
                base_url=cm["base_url"],
                api_key=cm["api_key"],
                model_id=cm["model_id"],
                model_name=cm["model_name"],
                system_prompt=system_prompt,
                params=cm.get("params", ""),
                has_persona=comp_gets_persona,
            ))
            # The task name is reused as model id/name in the error fallback.
            t.set_name(cm["model_id"])
            pending.add(t)

        try:
            while pending:
                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                for task in done:
                    try:
                        result = task.result()
                    except Exception as exc:
                        result = {
                            "provider": "unknown",
                            "model_id": task.get_name(),
                            "model_name": task.get_name(),
                            "response": f"[Error]: {exc}",
                            "elapsed_seconds": 0,
                            "is_neon": False,
                            "error": True,
                        }
                    result["group_index"] = gi
                    session_responses.append(result)
                    yield f"event: response\ndata: {json.dumps(result)}\n\n"
        finally:
            # `pending` is non-empty here only when iteration stopped early
            # (generator closed/cancelled or an error above); stop those requests.
            for task in pending:
                task.cancel()

        if session_id:
            resp_map: dict[str, Any] = {}
            timing_map: dict[str, Any] = {}
            for r in session_responses:
                key = f"{r.get('provider_name', r['provider'])} / {r['model_name']}"
                if r.get("is_neon"):
                    key = f"Neon / {r['model_name']} ({r.get('persona_name', '')})"
                resp_map[key] = r["response"]
                timing_map[key] = r["elapsed_seconds"]
            session_store.add_result(session_id, ComparisonResult(
                query=query,
                neon_model_id=neon_sel["model_id"],
                neon_persona=neon_sel["persona_name"],
                responses=resp_map,
                timings=timing_map,
            ))
    yield "event: done\ndata: {}\n\n"
async def _run_single_group(
    query: str,
    neon_sel: dict[str, str],
    comparison_models: list[dict[str, Any]],
    persona_target: str = "neon-only",
) -> dict[str, Any]:
    """Run one Neon model + all comparison models for a single query.

    All requests are awaited together via ``asyncio.gather``. ``persona_target``
    routes the system prompt exactly as in ``stream_comparison``: ``"all"``
    means everyone, ``"neon-only"`` means Neon only, and anything else means
    no one.

    Returns:
        ``{"neon_model_id", "neon_persona", "query", "responses"}``.
        ``responses`` holds the Neon record first, then one record per entry
        of ``comparison_models`` in the same order. A request that raised is
        replaced by an error record that still names the model it belongs to.
    """
    system_prompt = neon_sel.get("system_prompt")
    neon_gets_persona = persona_target in ("all", "neon-only")
    comp_gets_persona = persona_target == "all"

    # Each coroutine is paired with the identity fields used if it raises,
    # so a failure can still be attributed (gather preserves input order).
    coros = [_query_neon(
        query=query,
        model_id=neon_sel["model_id"],
        persona_name=neon_sel["persona_name"],
        system_prompt=system_prompt if neon_gets_persona else None,
        has_persona=neon_gets_persona,
    )]
    identities: list[dict[str, Any]] = [{
        "provider": "neon",
        "model_id": neon_sel["model_id"],
        "model_name": neon_sel["model_id"],
        "persona_name": neon_sel["persona_name"],
        "is_neon": True,
    }]
    for cm in comparison_models:
        coros.append(_query_comparison(
            query=query,
            provider_id=cm["provider_id"],
            provider_name=cm["provider_name"],
            base_url=cm["base_url"],
            api_key=cm["api_key"],
            model_id=cm["model_id"],
            model_name=cm["model_name"],
            system_prompt=system_prompt,
            params=cm.get("params", ""),
            has_persona=comp_gets_persona,
        ))
        identities.append({
            "provider": cm["provider_id"],
            "provider_name": cm["provider_name"],
            "model_id": cm["model_id"],
            "model_name": cm["model_name"],
            "is_neon": False,
        })

    results = await asyncio.gather(*coros, return_exceptions=True)
    responses: list[dict[str, Any]] = []
    for identity, outcome in zip(identities, results):
        # return_exceptions=True may hand back BaseException subclasses
        # (e.g. asyncio.CancelledError), not only Exception.
        if isinstance(outcome, BaseException):
            responses.append({
                **identity,
                "response": f"[Error]: {outcome}",
                "elapsed_seconds": 0,
                "error": True,
            })
        else:
            responses.append(outcome)
    return {
        "neon_model_id": neon_sel["model_id"],
        "neon_persona": neon_sel["persona_name"],
        "query": query,
        "responses": responses,
    }
async def run_comparison(
    query: str,
    neon_selections: list[dict[str, str]],
    comparison_model_ids: list[str],
    session_id: str | None = None,
    persona_target: str = "neon-only",
) -> list[dict[str, Any]]:
    """Non-streaming version. Neon groups run in parallel; semaphores handle rate limiting.

    Returns one group dict per entry in ``neon_selections`` (same order), as
    built by ``_run_single_group``. When ``session_id`` is given, each group is
    also stored via ``session_store.add_result`` as a ``ComparisonResult``.
    Its responses/timings are keyed ``"<provider name> / <model name>"``, or
    ``"Neon / <model name> (<persona>)"`` for Neon records.
    """
    comparison_models = _resolve_comparison_models(comparison_model_ids)
    # asyncio.gather already returns a list in input order.
    groups: list[dict[str, Any]] = await asyncio.gather(*(
        _run_single_group(query, neon_sel, comparison_models, persona_target=persona_target)
        for neon_sel in neon_selections
    ))
    if session_id:
        for group in groups:
            resp_map: dict[str, Any] = {}
            timing_map: dict[str, Any] = {}
            for r in group["responses"]:
                if r.get("is_neon"):
                    # .get(): an error record may lack persona_name.
                    key = f"Neon / {r['model_name']} ({r.get('persona_name', '')})"
                else:
                    key = f"{r.get('provider_name', r['provider'])} / {r['model_name']}"
                resp_map[key] = r["response"]
                timing_map[key] = r["elapsed_seconds"]
            session_store.add_result(session_id, ComparisonResult(
                query=query,
                neon_model_id=group["neon_model_id"],
                neon_persona=group["neon_persona"],
                responses=resp_map,
                timings=timing_map,
            ))
    return groups
async def run_csv_comparison(
    questions: list[str],
    neon_selections: list[dict[str, str]],
    comparison_model_ids: list[str],
    persona_target: str = "neon-only",
) -> list[list[dict[str, Any]]]:
    """Run comparison for a batch of questions. Returns list of group-lists.

    Each question is handled by a separate ``run_comparison`` call. Each call
    is awaited before the next begins, so questions are processed strictly one
    at a time. No ``session_id`` is passed, so nothing is recorded in a
    session. Element ``i`` of the result holds the groups for ``questions[i]``.
    """
    return [
        await run_comparison(
            query=question,
            neon_selections=neon_selections,
            comparison_model_ids=comparison_model_ids,
            persona_target=persona_target,
        )
        for question in questions
    ]