| |
| """ |
| pluto/modes.py — Real mode switching engine. |
| |
| Groq primary: |
| - MODE_QUICK: llama-3.1-8b-instant (fast, lightweight) |
| - MODE_REASONING: llama-3.3-70b-versatile (deep, accurate) |
| - MODE_VISION: llama-3.1-8b-instant (text/doc understanding) |
| |
| Mistral fallback (if Groq fails or no key): |
| - All modes: mistral-small-latest |
| |
| Real switching = True because MODE_QUICK uses 8b and MODE_REASONING uses 70b. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from dataclasses import dataclass |
|
|
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
|
|
| def _clean_api_key(api_key: str | None) -> str: |
| cleaned = str(api_key or "").strip().strip('"').strip("'") |
| if cleaned.lower().startswith("bearer "): |
| cleaned = cleaned[7:].strip() |
| return cleaned |
|
|
|
|
def _looks_like_nvidia_key(api_key: str) -> bool:
    """Return True if the cleaned key carries NVIDIA's "nvapi-" prefix."""
    cleaned = _clean_api_key(api_key)
    return cleaned.startswith("nvapi-")
|
|
|
|
@dataclass(frozen=True)
class ModeConfig:
    """Immutable model configuration for a single processing mode."""

    mode_name: str        # registry key, e.g. "MODE_QUICK"
    model_id: str         # provider-specific model identifier
    temperature: float    # sampling temperature for this mode
    max_tokens: int       # completion token budget
    compute_profile: str  # human-readable latency/capability label
    provider: str         # e.g. "nvidia", "groq", "mistral", "unconfigured"

    def to_log_dict(self) -> dict:
        """Return every field as a plain dict for structured logging."""
        field_names = (
            "mode_name",
            "model_id",
            "temperature",
            "max_tokens",
            "compute_profile",
            "provider",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
|
|
# Sampling/limit settings shared by every chat provider; only model_id and
# provider differ between the NVIDIA and Groq registries below.
_MODE_PROFILES: dict[str, tuple[float, int, str]] = {
    "MODE_QUICK": (0.1, 1024, "low-latency"),
    "MODE_REASONING": (0.3, 4096, "high-reasoning"),
    "MODE_VISION": (0.1, 4096, "vision-capable"),
    "MODE_ULTRA": (0.2, 4096, "deep-reasoning"),
    "MODE_GEMINI": (0.0, 4096, "high-throughput"),
}


def _registry_for(provider: str, model_ids: dict[str, str]) -> dict[str, ModeConfig]:
    """Expand a mode -> model_id map into a full ModeConfig registry for one provider."""
    return {
        mode: ModeConfig(
            mode_name=mode,
            model_id=model_ids[mode],
            temperature=temperature,
            max_tokens=max_tokens,
            compute_profile=profile,
            provider=provider,
        )
        for mode, (temperature, max_tokens, profile) in _MODE_PROFILES.items()
    }


def _build_registry() -> dict[str, ModeConfig]:
    """Select the chat-model registry based on which provider keys are set.

    Priority:
      1. NVIDIA NIM  — any NVIDIA_API_KEY* variable is non-empty.
      2. Groq        — GROQ_API_KEY is set.
      3. Mistral     — MISTRAL_API_KEY is set and is not an NVIDIA key pasted
                       into the wrong variable.
      4. Placeholder — "unconfigured" modes so imports still work.

    Embedding + reranking are handled separately in embedder.py and
    dispatcher.py (they use /v1/embeddings and scoring endpoints, not chat
    completions).

    Returns:
        Mapping of mode name -> ModeConfig for MODE_QUICK, MODE_REASONING,
        MODE_VISION, MODE_ULTRA and MODE_GEMINI.
    """
    # NOTE(review): the EMBED/RERANK key variables below also switch the chat
    # modes onto NVIDIA even when no NVIDIA *chat* key is present — confirm
    # that is intentional.
    nvidia_keys = [
        "NVIDIA_API_KEY", "NVIDIA_API_KEY_NANO", "NVIDIA_API_KEY_SUPER",
        "NVIDIA_API_KEY_VL", "NVIDIA_API_KEY_EMBED", "NVIDIA_API_KEY_RERANK",
        "NVIDIA_API_KEY_ULTRA",
    ]
    nvidia_ready = any(_clean_api_key(os.getenv(key)) for key in nvidia_keys)
    groq_key = _clean_api_key(os.getenv("GROQ_API_KEY", ""))
    mistral_key = _clean_api_key(os.getenv("MISTRAL_API_KEY", ""))

    if nvidia_ready:
        return _registry_for("nvidia", {
            "MODE_QUICK": "meta/llama-3.2-3b-instruct",
            "MODE_REASONING": "nvidia/nemotron-3-nano-omni-30b-a3b-reasoning",
            "MODE_VISION": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
            "MODE_ULTRA": "nvidia/llama-3.1-nemotron-ultra-253b-v1",
            "MODE_GEMINI": "nvidia/nemotron-3-nano-omni-30b-a3b-reasoning",
        })
    if groq_key:
        return _registry_for("groq", {
            "MODE_QUICK": "llama-3.1-8b-instant",
            "MODE_REASONING": "llama-3.3-70b-versatile",
            "MODE_VISION": "llama-3.1-8b-instant",
            "MODE_ULTRA": "llama-3.3-70b-versatile",
            "MODE_GEMINI": "llama-3.3-70b-versatile",
        })
    if mistral_key and not _looks_like_nvidia_key(mistral_key):
        return _build_mistral_registry()
    return _build_unconfigured_registry()
|
|
|
|
def _build_mistral_registry() -> dict[str, ModeConfig]:
    """Map every mode onto mistral-small-latest when Mistral is the only chat provider."""
    # Per-mode temperature; insertion order matches the registry layout.
    temperatures = {
        "MODE_QUICK": 0.1,
        "MODE_REASONING": 0.3,
        "MODE_VISION": 0.1,
        "MODE_ULTRA": 0.2,
        "MODE_GEMINI": 0.0,
    }
    # Only MODE_QUICK has a smaller completion budget; every other mode gets 4096.
    token_limits = {"MODE_QUICK": 1024}
    return {
        mode: ModeConfig(
            mode_name=mode,
            model_id="mistral-small-latest",
            temperature=temperature,
            max_tokens=token_limits.get(mode, 4096),
            compute_profile="fallback",
            provider="mistral",
        )
        for mode, temperature in temperatures.items()
    }
|
|
|
|
def _build_unconfigured_registry() -> dict[str, ModeConfig]:
    """Return placeholder modes so imports work without provider credentials."""
    # (mode name, temperature, max_tokens) — mirrors the real registries' settings.
    specs = [
        ("MODE_QUICK", 0.1, 1024),
        ("MODE_REASONING", 0.3, 4096),
        ("MODE_VISION", 0.1, 4096),
        ("MODE_ULTRA", 0.2, 4096),
        ("MODE_GEMINI", 0.0, 4096),
    ]
    return {
        name: ModeConfig(
            mode_name=name,
            model_id=f"unconfigured/{name}",
            temperature=temperature,
            max_tokens=limit,
            compute_profile="unconfigured",
            provider="unconfigured",
        )
        for name, temperature, limit in specs
    }
|
|
|
|
# Shared module-level registry. Mutated in place by _refresh_mode_registry()
# so that `from ... import MODE_REGISTRY` references elsewhere stay valid.
MODE_REGISTRY: dict[str, ModeConfig] = _build_registry()
|
|
|
|
| def _missing_provider_error() -> EnvironmentError: |
| return EnvironmentError("None of NVIDIA_API_KEY, GROQ_API_KEY, or MISTRAL_API_KEY is set.") |
|
|
|
|
def _is_unconfigured() -> bool:
    """True if any registered mode still carries the placeholder provider."""
    for config in MODE_REGISTRY.values():
        if config.provider == "unconfigured":
            return True
    return False
|
|
|
|
def _refresh_mode_registry() -> None:
    """Rebuild MODE_REGISTRY in place so imported references stay valid.

    The fresh registry is built *before* the shared dict is cleared, so an
    unexpected failure inside _build_registry() cannot leave MODE_REGISTRY
    empty for other importers (the original cleared first, then built).
    """
    fresh = _build_registry()
    MODE_REGISTRY.clear()
    MODE_REGISTRY.update(fresh)
|
|
|
|
def is_real_switching() -> bool:
    """True if MODE_QUICK and MODE_REASONING use DIFFERENT model_ids.

    Attempts one lazy refresh when the registry is still unconfigured
    (e.g. credentials were loaded after import).
    """
    if _is_unconfigured():
        _refresh_mode_registry()
        if _is_unconfigured():
            return False
    return MODE_REGISTRY["MODE_QUICK"].model_id != MODE_REGISTRY["MODE_REASONING"].model_id
|
|
|
|
def get_mode(mode_name: str) -> ModeConfig:
    """Resolve a mode name to its ModeConfig.

    If the stored config is still the "unconfigured" placeholder, the registry
    is refreshed once before giving up.

    Raises:
        ValueError: unknown mode name.
        EnvironmentError: the mode exists but no provider is configured.
    """
    mode = MODE_REGISTRY.get(mode_name)
    if mode is None:
        raise ValueError(f"Unknown mode: {mode_name}. Valid: {list(MODE_REGISTRY)}")
    if mode.provider == "unconfigured":
        # Credentials may have appeared since import — rebuild and retry once.
        _refresh_mode_registry()
        mode = MODE_REGISTRY.get(mode_name)
        if mode is None:
            raise ValueError(f"Unknown mode: {mode_name}. Valid: {list(MODE_REGISTRY)}")
        if mode.provider == "unconfigured":
            raise _missing_provider_error()
    return mode
|
|