| """
|
| model_context.py
|
|
|
| Query and cache model context window sizes from OpenAI-compatible APIs.
|
| Provides token estimation for context usage tracking.
|
| """
|
|
|
| import ipaddress
|
| import logging
|
| import sys
|
| from typing import Dict, List, Optional, Tuple
|
|
|
| from urllib.parse import urlparse
|
|
|
| import httpx
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Host names / literals accepted as local by exact match.
_LOCAL_HOSTS = {
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "::1",
    "host.docker.internal",
}

# Dotted-quad string prefixes used for private-range matching:
# "10.", then "172.16." through "172.31.", then "192.168.".
_PRIVATE_PREFIXES = (
    "10.",
    *(f"172.{second_octet}." for second_octet in range(16, 32)),
    "192.168.",
)
|
|
|
|
|
|
|
|
|
|
|
# The 100.64.0.0/10 block that _in_tailscale_range() tests membership in.
_TAILSCALE_CGNAT = ipaddress.ip_network("100.64.0.0/10")


def _in_tailscale_range(host: str) -> bool:
    """Tell whether *host* is an IP literal inside ``_TAILSCALE_CGNAT``.

    Anything that does not parse as an IP address (DNS names, the empty
    string) yields ``False`` instead of raising.
    """
    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        return False
    return address in _TAILSCALE_CGNAT
|
|
|
|
|
def _normalize_base_for_compare(url: str) -> str:
    """Reduce an endpoint URL to a comparable base form.

    Surrounding whitespace and trailing slashes are removed, then each of
    the known API path suffixes is stripped (checked once each, in a fixed
    order) together with any slash left dangling in front of it. A falsy
    *url* normalizes to the empty string.
    """
    normalized = (url or "").strip().rstrip("/")
    api_suffixes = ("/chat/completions", "/models", "/completions", "/v1/messages")
    for tail in api_suffixes:
        if not normalized.endswith(tail):
            continue
        keep = len(normalized) - len(tail)
        normalized = normalized[:keep].rstrip("/")
    return normalized
|
|
|
|
|
def _configured_endpoint_kind(url: str) -> Optional[str]:
    """Return configured endpoint kind for a chat/base URL when available.

    The URL is normalized with ``_normalize_base_for_compare`` and compared
    with the normalized ``base_url`` of every enabled ``ModelEndpoint`` row.
    A row matches when the target equals its base or lies underneath it
    (``base + "/..."``). The first matching row decides the result.

    Args:
        url: Chat or base URL of the endpoint to classify.

    Returns:
        * ``"local"``, ``"api"`` or ``"proxy"`` when the matching row
          declares that kind explicitly;
        * ``"proxy"`` when the row has no explicit kind but carries an API
          key, is not on port 11434, has no ``"ollama"`` in its host name,
          and its path ends with ``/v1`` or contains ``/openai``;
        * ``"auto"`` for any other matching row;
        * ``None`` when the URL is empty, ``core.database`` has not been
          imported yet, no row matches, or the lookup fails.
    """
    target = _normalize_base_for_compare(url)
    if not target:
        return None
    # Never trigger the import of core.database from here: the lookup is
    # only attempted when the application has already loaded that module.
    if "core.database" not in sys.modules:
        return None
    try:
        from core.database import SessionLocal, ModelEndpoint

        db = SessionLocal()
        try:
            # "== True" is intentional: it builds a SQL expression.
            rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()  # noqa: E712
            for ep in rows:
                base = _normalize_base_for_compare(getattr(ep, "base_url", "") or "")
                if not base:
                    continue
                if target != base and not target.startswith(base + "/"):
                    continue
                kind = (getattr(ep, "endpoint_kind", None) or "auto").strip().lower()
                if kind in ("local", "api", "proxy"):
                    return kind
                if getattr(ep, "api_key", None):
                    parsed = urlparse(base)
                    host = (parsed.hostname or "").lower()
                    path = (parsed.path or "").rstrip("/")
                    if parsed.port != 11434 and "ollama" not in host and (path.endswith("/v1") or "/openai" in path):
                        return "proxy"
                return "auto"
        finally:
            db.close()
    except Exception:
        # Best-effort lookup: callers treat None as "not configured".
        # Leave a trace so database problems are not completely invisible.
        logger.debug("Configured endpoint kind lookup failed", exc_info=True)
        return None
    return None
|
|
|
|
|
def is_local_endpoint(url: str) -> bool:
    """Check if URL points to a local/private/tailscale address.

    An explicitly configured endpoint kind wins: ``"api"`` and ``"proxy"``
    are never local, ``"local"`` always is. Otherwise the URL's host is
    inspected:

    * an exact match in ``_LOCAL_HOSTS`` is local;
    * an IPv4 literal starting with one of ``_PRIVATE_PREFIXES`` is local;
    * an IP literal inside ``_TAILSCALE_CGNAT`` is local.

    DNS names are only ever matched exactly. They are deliberately NOT
    prefix-matched, so a public name such as ``10.example.com`` is not
    mistaken for a 10.0.0.0/8 address.

    Args:
        url: Endpoint URL to classify.

    Returns:
        ``True`` for a local endpoint, ``False`` otherwise (including when
        the URL cannot be parsed).
    """
    kind = _configured_endpoint_kind(url)
    if kind in ("api", "proxy"):
        return False
    if kind == "local":
        return True
    try:
        host = urlparse(url).hostname or ""
        if host in _LOCAL_HOSTS:
            return True
        # Raises ValueError for anything that is not an IP literal (i.e. a
        # DNS name); the handler below turns that into "not local".
        address = ipaddress.ip_address(host)
        if address.version == 4 and host.startswith(_PRIVATE_PREFIXES):
            return True
        return address in _TAILSCALE_CGNAT
    except Exception:
        # Deliberate best-effort: an unparseable URL is simply not local.
        return False
|
|
|
|
|
|
|
|
|
# Context window returned when neither the endpoint nor the known-model
# table yields a value.
DEFAULT_CONTEXT = 128_000

# Passed as ``timeout=`` to every httpx request issued by this module.
REQUEST_TIMEOUT = 5
|
|
|
|
|
|
|
|
|
# Context window sizes keyed by a lowercase fragment of the model name.
# Keys are matched as substrings of the requested model name, so insertion
# order is kept stable (it decides ties between equally long keys).
KNOWN_CONTEXT_WINDOWS: Dict[str, int] = {
    # claude-* keys
    "claude-sonnet-4-5": 200_000,
    "claude-sonnet-4-6": 200_000,
    "claude-sonnet-4": 200_000,
    "claude-opus-4": 200_000,
    "claude-haiku-4": 200_000,
    "claude-haiku-3-5": 200_000,
    "claude-3-5-sonnet": 200_000,
    "claude-3-5-haiku": 200_000,
    "claude-3-opus": 200_000,
    "claude-3-sonnet": 200_000,
    "claude-3-haiku": 200_000,
    # gpt-* and o-series keys
    "gpt-5": 400_000,
    "gpt-4.1": 1_047_576,
    "gpt-4.1-mini": 1_047_576,
    "gpt-4.1-nano": 1_047_576,
    "gpt-4o": 128_000,
    "gpt-4o-mini": 128_000,
    "gpt-4-turbo": 128_000,
    "gpt-4": 8_192,
    "gpt-3.5-turbo": 16_385,
    "o1": 200_000,
    "o1-mini": 128_000,
    "o1-pro": 200_000,
    "o3": 200_000,
    "o3-mini": 200_000,
    "o4-mini": 200_000,
    # deepseek-* keys
    "deepseek-chat": 64_000,
    "deepseek-coder": 64_000,
    "deepseek-reasoner": 64_000,
    "deepseek-r1": 64_000,
    "deepseek-v3": 64_000,
    "deepseek-v2": 64_000,
    # gemini-* and gemma-* keys
    "gemini-2.5-pro": 1_048_576,
    "gemini-2.5-flash": 1_048_576,
    "gemini-2.0-flash": 1_048_576,
    "gemini-1.5-pro": 1_048_576,
    "gemini-1.5-flash": 1_048_576,
    "gemma-4": 262_144,
    "gemma-3": 128_000,
    "gemma-2": 8_192,
    # mistral-*, mixtral, codestral, pixtral keys
    "mistral-large": 128_000,
    "mistral-medium": 32_000,
    "mistral-small": 32_000,
    "mistral-nemo": 128_000,
    "mistral-7b": 32_000,
    "mixtral": 32_000,
    "codestral": 32_000,
    "pixtral": 128_000,
    # grok-* keys
    "grok-4": 131_072,
    "grok-3": 131_072,
    "grok-2": 131_072,
    # llama-* keys
    "llama-4": 1_048_576,
    "llama-3.3": 131_072,
    "llama-3.2": 131_072,
    "llama-3.1": 131_072,
    "llama-3": 131_072,
    # qwen* / qwq keys
    "qwen3": 131_072,
    "qwen2.5": 131_072,
    "qwen2": 32_768,
    "qwq": 32_768,
    # command-* keys
    "command-r-plus": 128_000,
    "command-r": 128_000,
    "command-a": 256_000,
    # sonar keys
    "sonar-pro": 200_000,
    "sonar": 128_000,
    # minimax key
    "minimax": 1_000_000,
    # moonshot / kimi keys
    "moonshot": 128_000,
    "kimi": 128_000,
    # phi-* keys
    "phi-4": 16_000,
    "phi-3": 128_000,
    # nemotron key
    "nemotron": 131_072,
    # yi-* keys
    "yi-large": 32_768,
    "yi-1.5": 16_384,
    "yi-lightning": 16_384,
    # hermes keys
    "hermes": 131_072,
    "nous-hermes": 131_072,
    # remaining model-name fragments
    "dolphin": 32_768,
    "mythomax": 4_096,
    "wizard": 32_768,
    "openchat": 8_192,
    "solar": 32_768,
}
|
|
|
|
|
|
|
|
|
# Context sizes already resolved, keyed by (endpoint_url, model).
_context_cache: Dict[Tuple[str, str], int] = {}


def get_context_length(endpoint_url: str, model: str) -> int:
    """Resolve the context window size for *model* on *endpoint_url*.

    Local endpoints are queried on every call and never cached. For all
    other endpoints a cached value is returned when present; a freshly
    queried value is stored only if it differs from ``DEFAULT_CONTEXT`` or
    the endpoint is configured as ``"api"`` / ``"proxy"``.

    Returns:
        The context length reported by ``_query_context_length`` (which
        itself falls back to ``DEFAULT_CONTEXT``), or the cached value.
    """
    configured_kind = _configured_endpoint_kind(endpoint_url)
    remote = not is_local_endpoint(endpoint_url)
    cache_key = (endpoint_url, model)

    if remote:
        try:
            return _context_cache[cache_key]
        except KeyError:
            pass

    ctx = _query_context_length(endpoint_url, model)

    cacheable = remote and (
        ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")
    )
    if cacheable:
        _context_cache[cache_key] = ctx
    logger.info(f"Context length for {model}: {ctx}")
    return ctx
|
|
|
|
|
def _lookup_known(model: str) -> Optional[int]:
    """Look up *model* in ``KNOWN_CONTEXT_WINDOWS`` by substring match.

    The comparison is case-insensitive on the model side. Among all keys
    contained in the model name the LONGEST one wins, so a short key such
    as 'o1' never shadows a more specific one such as 'o1-mini'. When
    several matching keys share the greatest length, the one listed first
    in the table is used.

    Returns:
        The context size of the best matching key, or ``None`` when no key
        occurs in the model name.
    """
    name = model.lower()
    # The part after the last "/" and before the first ":" is itself a
    # substring of `name`, so testing `name` alone covers it as well.
    candidates = [key for key in KNOWN_CONTEXT_WINDOWS if key in name]
    if not candidates:
        return None
    # max() keeps the first of equally long keys, i.e. table order.
    return KNOWN_CONTEXT_WINDOWS[max(candidates, key=len)]
|
|
|
|
|
def _first_positive_number(mapping, fields) -> Optional[int]:
    """Return ``int(mapping[field])`` for the first field holding a positive number."""
    for field in fields:
        val = mapping.get(field)
        if val and isinstance(val, (int, float)) and val > 0:
            return int(val)
    return None


def _entry_context_length(entry) -> Optional[int]:
    """Extract a context length from one entry of a models listing.

    Top-level fields are tried first; if they give nothing usable, the
    nested ``meta`` (or ``model_extra``) dict is consulted.
    """
    ctx = _first_positive_number(
        entry,
        (
            "context_length",
            "context_window",
            "max_model_len",
            "max_context_length",
            "max_seq_len",
        ),
    )
    if not ctx:
        meta = entry.get("meta") or entry.get("model_extra") or {}
        if isinstance(meta, dict):
            nested = _first_positive_number(
                meta, ("n_ctx", "context_length", "context_window", "max_model_len")
            )
            if nested is not None:
                ctx = nested
    return ctx


def _slots_n_ctx(endpoint_url: str) -> Optional[int]:
    """Return ``n_ctx`` of the first entry served at ``<base>/slots``.

    ``<base>`` is everything before the first ``/v1`` in the URL, or the
    URL without its last path segment when it contains no ``/v1``. This is
    a best-effort probe: any failure yields ``None``.
    """
    try:
        if "/v1" in endpoint_url:
            base = endpoint_url.split("/v1")[0]
        else:
            base = endpoint_url.rsplit("/", 1)[0]
        r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
        if r.is_success:
            slots = r.json()
            if isinstance(slots, list) and slots:
                n_ctx = slots[0].get("n_ctx")
                if n_ctx and isinstance(n_ctx, int) and n_ctx > 0:
                    return n_ctx
    except Exception as e:
        # Endpoints without /slots are expected; keep only a debug trace.
        logger.debug(f"/slots probe failed: {e}")
    return None


def _query_context_length(endpoint_url: str, model: str) -> int:
    """Query the model API for context length.

    Resolution order:

    1. Endpoints configured as ``"api"`` / ``"proxy"`` are never queried:
       the known-model table value (or ``DEFAULT_CONTEXT``) is returned.
    2. Local endpoints are probed at ``/slots`` first; a positive
       ``n_ctx`` there is returned as-is.
    3. Copilot base URLs use the known-model table (or ``DEFAULT_CONTEXT``).
    4. Otherwise the models listing is fetched and the entry whose id
       matches *model* (fully, or by last ``/`` segment) is inspected.

    When both an API value and a table value exist, a local endpoint's
    smaller API value is trusted; in every other case the larger of the
    two is returned.

    Args:
        endpoint_url: Endpoint URL the model is served from.
        model: Model identifier to resolve.

    Returns:
        The resolved context length, or ``DEFAULT_CONTEXT`` when nothing
        better is available.
    """
    known = _lookup_known(model)
    configured_kind = _configured_endpoint_kind(endpoint_url)

    if configured_kind in ("api", "proxy"):
        if known:
            logger.info(f"Using known context window for {model}: {known}")
            return known
        return DEFAULT_CONTEXT

    # Evaluated once and reused below; each evaluation may hit the database.
    is_local = is_local_endpoint(endpoint_url)
    if is_local:
        n_ctx = _slots_n_ctx(endpoint_url)
        if n_ctx is not None:
            logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}")
            return n_ctx

    # Imported lazily, at the point of use.
    from src.copilot import is_copilot_base

    if is_copilot_base(endpoint_url):
        if known:
            logger.info(f"Using known context window for {model}: {known}")
        return known or DEFAULT_CONTEXT

    from src.endpoint_resolver import build_models_url

    models_url = build_models_url(endpoint_url)
    api_ctx = None
    try:
        r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)
        if r.is_success:
            data = r.json()
            for m in data.get("data") or []:
                mid = m.get("id", "")
                if mid == model or mid.split("/")[-1] == model.split("/")[-1]:
                    api_ctx = _entry_context_length(m)
                    # Only the first matching entry is considered.
                    break
    except Exception as e:
        logger.debug(f"Failed to query context length for {model}: {e}")

    if api_ctx and known:
        if is_local and api_ctx < known:
            logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
            return api_ctx
        if api_ctx < known:
            logger.info(f"API reported {api_ctx} for {model}, using known {known} instead")
        return max(api_ctx, known)
    if api_ctx:
        return api_ctx
    if known:
        logger.info(f"Using known context window for {model}: {known}")
        return known

    return DEFAULT_CONTEXT
|
|
|
|
|
def estimate_tokens(messages: List[Dict]) -> int:
    """Rough token estimate for a list of messages.

    Uses chars * 0.3 which is closer to real BPE tokenizer output
    than the commonly-cited chars/4 (which underestimates by ~20-30%).
    Also adds ~4 tokens per message for role/formatting overhead, and counts
    assistant tool_calls (name + arguments) — a tool-only turn carries
    content=None with the real payload in tool_calls, so ignoring them would
    leave the estimate blind to large tool arguments.

    Args:
        messages: Chat messages as dicts. ``content`` may be a string, a
            list of parts (only ``{"type": "text"}`` parts are counted), or
            anything else (ignored). ``tool_calls`` may be a list of dicts,
            either nested under ``"function"`` or flat.

    Returns:
        The estimated token count; ``0`` for an empty list.
    """
    total = 0
    for msg in messages:
        total += 4  # per-message role/formatting overhead
        content = msg.get("content", "")
        if isinstance(content, str):
            total += int(len(content) * 0.3)
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    # A part may carry text=None or a non-string value;
                    # treat those like tool-call arguments below instead
                    # of letting len() raise.
                    text = item.get("text") or ""
                    if not isinstance(text, str):
                        text = str(text)
                    total += int(len(text) * 0.3)

        tool_calls = msg.get("tool_calls")
        if isinstance(tool_calls, list):
            for tc in tool_calls:
                if not isinstance(tc, dict):
                    continue
                # Accept both {"function": {...}} and a flat call dict.
                fn = tc.get("function") if isinstance(tc.get("function"), dict) else tc
                name = fn.get("name", "") or ""
                args = fn.get("arguments", "") or ""
                if not isinstance(args, str):
                    args = str(args)
                total += 4  # per-call overhead
                total += int((len(str(name)) + len(args)) * 0.3)
    return total
|
|
|