# Source revision b5e99b3 (ayushKishor): "Harden provider response handling"
# -*- coding: utf-8 -*-
"""
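pluto/modes.py — Real mode switching engine.
Provider selection (the first provider with a usable API key wins):
1. NVIDIA NIM (any NVIDIA_API_KEY* variable set):
- MODE_QUICK: meta/llama-3.2-3b-instruct (fast, lightweight)
- MODE_REASONING: nvidia/nemotron-3-nano-omni-30b-a3b-reasoning
- MODE_VISION: nvidia/llama-3.1-nemotron-nano-vl-8b-v1 (document parsing)
- MODE_ULTRA: nvidia/llama-3.1-nemotron-ultra-253b-v1 (escalation)
2. Groq (GROQ_API_KEY):
- MODE_QUICK: llama-3.1-8b-instant (fast, lightweight)
- MODE_REASONING / MODE_ULTRA: llama-3.3-70b-versatile (deep, accurate)
- MODE_VISION: llama-3.1-8b-instant (text/doc understanding only)
3. Mistral (MISTRAL_API_KEY; ignored if it holds an NVIDIA "nvapi-" key):
- All modes: mistral-small-latest
MODE_GEMINI is a legacy mode name kept for backward compatibility.
With no usable key, placeholder "unconfigured" modes are registered so the
module still imports; get_mode() raises EnvironmentError until a key is set.
The provider is chosen from key presence only; this module does not itself
retry another provider when a request fails.
Real switching is True for NVIDIA and Groq, where MODE_QUICK and
MODE_REASONING use different models, and False for Mistral.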
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from dotenv import load_dotenv
# Load variables from a .env file into the process environment at import time,
# so the os.getenv() lookups in _build_registry() (run below to build
# MODE_REGISTRY) can see them. _refresh_mode_registry() does not call this
# again, so later edits to .env are not picked up by a refresh.
load_dotenv()
def _clean_api_key(api_key: str | None) -> str:
cleaned = str(api_key or "").strip().strip('"').strip("'")
if cleaned.lower().startswith("bearer "):
cleaned = cleaned[7:].strip()
return cleaned
def _looks_like_nvidia_key(api_key: str) -> bool:
    """Tell whether *api_key*, once normalized, carries NVIDIA's ``nvapi-`` prefix.

    The registry builder uses this to skip a provider key variable that actually
    contains an NVIDIA key.
    """
    normalized = _clean_api_key(api_key)
    return normalized[:6] == "nvapi-"
@dataclass(frozen=True)
class ModeConfig:
    """Immutable model settings resolved for one processing mode.

    Instances are frozen (hashable, read-only). The field order below is also
    the positional constructor order.
    """

    # Registry key this config is stored under, e.g. "MODE_QUICK".
    mode_name: str
    # Provider-specific model identifier.
    model_id: str
    # Sampling temperature configured for this mode.
    temperature: float
    # Completion token cap configured for this mode.
    max_tokens: int
    # Free-form label such as "low-latency", "fallback" or "unconfigured".
    compute_profile: str
    # "nvidia" | "groq" | "mistral" | "unconfigured" (placeholder, no key set).
    provider: str

    def to_log_dict(self) -> dict:
        """Return the settings as a plain ``{field_name: value}`` dict for logging."""
        field_names = (
            "mode_name",
            "model_id",
            "temperature",
            "max_tokens",
            "compute_profile",
            "provider",
        )
        return {name: getattr(self, name) for name in field_names}
def _build_tiered_registry(
    provider: str,
    *,
    quick: str,
    reasoning: str,
    vision: str,
    ultra: str,
    gemini: str,
) -> dict[str, ModeConfig]:
    """Build the five-mode registry for a provider that has distinct model tiers.

    NVIDIA and Groq share identical per-mode temperatures, token caps and
    compute profiles; only the model ids and the provider tag differ.
    """
    # (mode_name, model_id, temperature, max_tokens, compute_profile)
    rows = (
        ("MODE_QUICK", quick, 0.1, 1024, "low-latency"),
        ("MODE_REASONING", reasoning, 0.3, 4096, "high-reasoning"),
        ("MODE_VISION", vision, 0.1, 4096, "vision-capable"),
        ("MODE_ULTRA", ultra, 0.2, 4096, "deep-reasoning"),
        ("MODE_GEMINI", gemini, 0.0, 4096, "high-throughput"),
    )
    return {
        name: ModeConfig(
            mode_name=name,
            model_id=model_id,
            temperature=temperature,
            max_tokens=max_tokens,
            compute_profile=profile,
            provider=provider,
        )
        for name, model_id, temperature, max_tokens, profile in rows
    }


def _build_registry() -> dict[str, ModeConfig]:
    """Pick the chat provider from the environment and build the mode registry.

    Provider priority (first match wins):
      1. NVIDIA NIM: any of the ``NVIDIA_API_KEY*`` variables below is non-empty.
         MODE_QUICK     -> meta/llama-3.2-3b-instruct
                           (high-volume: extraction workers, critic, judge)
         MODE_REASONING -> nvidia/nemotron-3-nano-omni-30b-a3b-reasoning
                           (synthesis, strategist audit, debate responder)
         MODE_VISION    -> nvidia/llama-3.1-nemotron-nano-vl-8b-v1
                           (doc parsing: tables, figures, scanned PDFs)
         MODE_ULTRA     -> nvidia/llama-3.1-nemotron-ultra-253b-v1
                           (escalation only: confidence < 0.6)
         MODE_GEMINI    -> same model as MODE_REASONING; the legacy name is
                           kept for backward compatibility.
      2. Groq: ``GROQ_API_KEY`` is set and is not an NVIDIA (``nvapi-``) key.
         The size tiers are the same: 8b-instant for QUICK/VISION and
         70b-versatile for REASONING/ULTRA/GEMINI.
      3. Mistral: ``MISTRAL_API_KEY`` is set and is not an NVIDIA key.
      4. Otherwise: placeholder "unconfigured" modes, so importing this
         module never fails for lack of credentials.

    Embedding and reranking are handled separately in embedder.py and
    dispatcher.py (they use /v1/embeddings and scoring endpoints, not chat
    completions).

    Returns:
        A fresh dict mapping mode name -> ModeConfig, in the order QUICK,
        REASONING, VISION, ULTRA, GEMINI.
    """
    nvidia_key_vars = (
        "NVIDIA_API_KEY", "NVIDIA_API_KEY_NANO", "NVIDIA_API_KEY_SUPER",
        "NVIDIA_API_KEY_VL", "NVIDIA_API_KEY_EMBED", "NVIDIA_API_KEY_RERANK",
        "NVIDIA_API_KEY_ULTRA",
    )
    # NOTE(review): the EMBED/RERANK keys alone are enough to route every chat
    # mode to NVIDIA. Confirm the chat client can actually use those keys, or
    # narrow this tuple to the chat-model keys.
    nvidia_ready = any(_clean_api_key(os.getenv(var)) for var in nvidia_key_vars)
    groq_key = _clean_api_key(os.getenv("GROQ_API_KEY", ""))
    mistral_key = _clean_api_key(os.getenv("MISTRAL_API_KEY", ""))

    if nvidia_ready:
        return _build_tiered_registry(
            "nvidia",
            quick="meta/llama-3.2-3b-instruct",
            reasoning="nvidia/nemotron-3-nano-omni-30b-a3b-reasoning",
            vision="nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
            ultra="nvidia/llama-3.1-nemotron-ultra-253b-v1",
            gemini="nvidia/nemotron-3-nano-omni-30b-a3b-reasoning",
        )
    # An nvapi- key in GROQ_API_KEY is presumably a misconfiguration that Groq
    # would reject; skip it exactly as the Mistral check below does.
    if groq_key and not _looks_like_nvidia_key(groq_key):
        return _build_tiered_registry(
            "groq",
            quick="llama-3.1-8b-instant",
            reasoning="llama-3.3-70b-versatile",
            vision="llama-3.1-8b-instant",
            ultra="llama-3.3-70b-versatile",
            gemini="llama-3.3-70b-versatile",
        )
    if mistral_key and not _looks_like_nvidia_key(mistral_key):
        return _build_mistral_registry()
    return _build_unconfigured_registry()
def _build_mistral_registry() -> dict[str, ModeConfig]:
    """Registry for the Mistral-only setup: one shared model serves every mode.

    Only the per-mode temperature and token cap differ; every entry uses the
    "fallback" compute profile.
    """
    shared_model = "mistral-small-latest"
    # (mode_name, temperature, max_tokens)
    per_mode = (
        ("MODE_QUICK", 0.1, 1024),
        ("MODE_REASONING", 0.3, 4096),
        ("MODE_VISION", 0.1, 4096),
        ("MODE_ULTRA", 0.2, 4096),
        ("MODE_GEMINI", 0.0, 4096),
    )
    return {
        mode: ModeConfig(
            mode_name=mode,
            model_id=shared_model,
            temperature=temp,
            max_tokens=limit,
            compute_profile="fallback",
            provider="mistral",
        )
        for mode, temp, limit in per_mode
    }
def _build_unconfigured_registry() -> dict[str, ModeConfig]:
    """Placeholder registry used when no provider credentials are available.

    Every entry is tagged provider="unconfigured", which lets the module import
    cleanly. get_mode() later re-reads the environment and raises if a mode
    is still unconfigured.
    """
    # (mode_name, temperature, max_tokens)
    placeholders = (
        ("MODE_QUICK", 0.1, 1024),
        ("MODE_REASONING", 0.3, 4096),
        ("MODE_VISION", 0.1, 4096),
        ("MODE_ULTRA", 0.2, 4096),
        ("MODE_GEMINI", 0.0, 4096),
    )
    registry = {}
    for mode, temp, token_cap in placeholders:
        registry[mode] = ModeConfig(
            mode_name=mode,
            model_id=f"unconfigured/{mode}",
            temperature=temp,
            max_tokens=token_cap,
            compute_profile="unconfigured",
            provider="unconfigured",
        )
    return registry
# Built once at import time from the current environment. Refreshes mutate this
# same dict in place (see _refresh_mode_registry) rather than rebinding the
# name, so modules that imported MODE_REGISTRY directly keep a live reference.
MODE_REGISTRY: dict[str, ModeConfig] = _build_registry()
def _missing_provider_error() -> EnvironmentError:
return EnvironmentError("None of NVIDIA_API_KEY, GROQ_API_KEY, or MISTRAL_API_KEY is set.")
def _is_unconfigured() -> bool:
    """Return True if any registered mode is still a placeholder with no provider."""
    for config in MODE_REGISTRY.values():
        if config.provider == "unconfigured":
            return True
    return False
def _refresh_mode_registry() -> None:
    """Rebuild MODE_REGISTRY from the current environment, mutating it in place.

    The dict object itself is reused (``clear`` + ``update``) rather than
    rebound, so imported MODE_REGISTRY references stay valid.

    The replacement is built before the shared dict is touched. If
    ``_build_registry()`` raises, the existing entries survive instead of
    leaving an empty registry behind, and the empty interval between
    ``clear`` and ``update`` is kept as short as possible.

    This re-reads environment variables only; it does not reload a .env file.
    """
    fresh = _build_registry()
    MODE_REGISTRY.clear()
    MODE_REGISTRY.update(fresh)
def is_real_switching() -> bool:
    """Report whether MODE_QUICK and MODE_REASONING are served by different models.

    If the registry holds placeholder ("unconfigured") modes, it is rebuilt once
    from the current environment first. If it is still unconfigured after that,
    the answer is False.
    """
    needs_refresh = _is_unconfigured()
    if needs_refresh:
        _refresh_mode_registry()
        if _is_unconfigured():
            return False
    quick_model, reasoning_model = (
        MODE_REGISTRY[name].model_id for name in ("MODE_QUICK", "MODE_REASONING")
    )
    return quick_model != reasoning_model
def get_mode(mode_name: str) -> ModeConfig:
    """Return the ModeConfig registered under *mode_name*.

    A placeholder ("unconfigured") entry triggers one registry rebuild from the
    current environment before giving up.

    Raises:
        ValueError: *mode_name* is not a registered mode.
        EnvironmentError: the mode is still unconfigured after the rebuild,
            meaning no provider key is available.
    """

    def unknown_mode() -> ValueError:
        # Built lazily so the "Valid:" list reflects the registry at raise time.
        return ValueError(f"Unknown mode: {mode_name}. Valid: {list(MODE_REGISTRY)}")

    config = MODE_REGISTRY.get(mode_name)
    if config is None:
        raise unknown_mode()
    if config.provider != "unconfigured":
        return config

    _refresh_mode_registry()
    config = MODE_REGISTRY.get(mode_name)
    if config is None:
        raise unknown_mode()
    if config.provider == "unconfigured":
        raise _missing_provider_error()
    return config