"""
LangGraph — Intelligent Model Router for STEM Copilot.

Routes queries to the best free OpenRouter model based on intent:
  - Math / derivation / proofs      → top reasoning models (rotated)
  - Physics / chemistry concepts    → strong general models (rotated)
  - Casual greetings ("hi", "ok")   → lightweight models
  - Image understanding             → vision-capable models (verified)
  - Everything else                 → the pinned default text model

Model pools are fetched from the OpenRouter API and cached for 10 min. The
reasoning and science pools are rotated round-robin to spread load and avoid
per-model rate limits; the other categories use the first model of their pool.
On rate-limit (429) errors, the failed model is skipped and the next model in
the pool is tried automatically (up to 2 retries).
"""

import base64
import json
import re
import sqlite3
import threading
import time
import urllib.request
from typing import TypedDict, Annotated

from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter  # type:ignore
from langgraph.checkpoint.sqlite import SqliteSaver  # type:ignore
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

from config import OPENROUTER_API_KEY, DB_PATH, LLM_TIMEOUT
import prompts

# ── LangSmith document attachment (surfaces the raw uploaded file) ──
# Optional: only active if the installed langsmith SDK supports attachments.
try:
    from langsmith import traceable

    try:
        from langsmith import Attachment  # type:ignore
    except ImportError:
        from langsmith.schemas import Attachment  # type:ignore
    _LANGSMITH_ATTACH = True
except Exception:
    traceable = None  # type:ignore
    Attachment = None  # type:ignore
    _LANGSMITH_ATTACH = False

# Lower-case file extension → MIME type used for the LangSmith attachment.
_MIME_BY_EXT = {
    "pdf": "application/pdf",
    "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "doc": "application/msword",
    "txt": "text/plain",
    "md": "text/markdown",
    "csv": "text/csv",
    "json": "application/json",
}


def _guess_mime(filename: str) -> str:
    """Return the MIME type for *filename* from its (case-insensitive) extension.

    Names without a dot, or with an unknown extension, map to
    "application/octet-stream".
    """
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    return _MIME_BY_EXT.get(ext, "application/octet-stream")


if _LANGSMITH_ATTACH:

    @traceable(name="uploaded_document")
    def _trace_document(filename: str, document) -> str:  # noqa: ANN001
        """Nested run that carries the raw user-uploaded file as a LangSmith attachment."""
        return f"Attached document: {filename}"


def _attach_document(doc_name: str, doc_bytes_b64: str) -> None:
    """Surface the original uploaded document in the active LangSmith trace.

    Args:
        doc_name: Original file name (used for the MIME type); may be empty.
        doc_bytes_b64: The file contents, base64-encoded.

    Best-effort: any failure (old SDK, no API key, decode error) is swallowed
    and logged so tracing can never break the chat response.
    """
    if not (_LANGSMITH_ATTACH and doc_bytes_b64):
        return
    try:
        raw = base64.b64decode(doc_bytes_b64)
        name = doc_name or "uploaded_document"
        _trace_document(name, Attachment(mime_type=_guess_mime(name), data=raw))
    except Exception as exc:
        print(f"[LANGSMITH ATTACH] skipped: {exc}", flush=True)


# ── Checkpointer ──────────────────────────────────────────────
# One module-level connection; check_same_thread=False lets threads other than
# the importing one use it.
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=_conn)

# ── Constants ─────────────────────────────────────────────────
DEFAULT_MODEL = "openai/gpt-oss-120b:free"  # pinned default for text
DEFAULT_VISION_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"  # pinned default for vision
MODELS_TTL = 10 * 60  # seconds
# After a failed model-list fetch, reuse the stale/fallback pools for this many
# seconds before hitting the API again (avoids a 10 s timeout on every _pick()).
_FETCH_RETRY_SECS = 60


def _is_guardrail(model_id: str) -> bool:
    """Auto-filter safety classifiers / guardrail models by keywords in the base id."""
    base = model_id.lower().split(":")[0]
    return any(kw in base for kw in ("guard", "shield", "safety", "moderation"))


# ── Query Classification ─────────────────────────────────────
# Whole-message small talk (compared after strip + lower-case).
_CASUAL_EXACT = frozenset([
    "hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola",
    "thanks", "thank you", "thankyou", "thx", "ty",
    "ok", "okay", "k", "kk", "fine", "alright",
    "bye", "goodbye", "see you", "later",
    "good morning", "good night", "good evening", "gm", "gn",
    "welcome", "cool", "nice", "great", "awesome", "wow",
    "got it", "understood", "sure", "yes", "no", "yeah", "nah", "yep", "nope",
    "hm", "hmm", "oh", "ah", "lol", "haha", "hehe", "xd",
    "what", "nothing", "nvm", "nevermind", "nm",
])

_REASONING_KW = [
    "derive", "derivation", "prove", "proof", "solve", "equation",
    "integral", "integrate", "differentiate", "differentiation",
    "formula", "calculate", "calculation", "compute",
    "theorem", "lemma", "corollary", "limit", "matrix", "determinant",
    "vector", "eigen", "trigonometry", "trigonometric",
    "quadratic", "polynomial", "logarithm", "calculus", "algebra",
    "geometry", "coordinate", "probability", "permutation", "combination",
    "step by step", "show steps", "show working", "work out",
    "simplify", "factorize", "factorise", "expand",
    "numerator", "denominator", "fraction",
    "dx", "dy", "dz", " lim ",
]

_SCIENCE_KW = [
    "physics", "chemistry", "biology",
    "molecule", "atom", "electron", "proton", "neutron", "ion", "isotope", "nucleus",
    "force", "energy", "momentum", "velocity", "acceleration",
    "wave", "wavelength", "frequency", "amplitude",
    "light", "optics", "lens", "mirror", "refraction", "reflection",
    "electric", "magnetic", "electromagnetic", "circuit", "resistance",
    "gravity", "gravitational", "newton", "coulomb", "charge",
    "reaction", "compound", "element", "bond", "orbital",
    "hybridisation", "hybridization", "valence",
    "thermodynamics", "entropy", "enthalpy", "kinematics", "dynamics",
    "quantum", "relativity", "nuclear",
    "acid", "base", "salt", "buffer", "oxidation", "reduction", "redox",
    "mole", "molarity", "avogadro",
    "ncert", "class 11", "class 12", "class xi", "class xii",
    # Compound terms whose science keyword is not at the start of the word
    # (word-start matching below would otherwise miss them).
    "anion", "cation", "diatomic", "photoelectric", "dielectric",
    "diamagnetic", "paramagnetic", "ferromagnetic",
]


def _keyword_regex(keywords: list[str]) -> "re.Pattern[str]":
    """Compile *keywords* into one alternation to search lower-cased text.

    - A keyword starting with a letter must begin a word (not be preceded by a
      letter): "ion" matches "ions"/"ionic" but not "question"/"function";
      "prove" does not match "improve".
    - Keywords of one or two characters (dx, dy, dz) must also not be followed
      by a letter: "dy" matches "dy/dx" but not "study" or "dynamics".
    - Keywords starting with a non-letter (e.g. " lim ") are matched literally.
    """
    parts = []
    for kw in keywords:
        pat = re.escape(kw)
        if kw[:1].isalpha():
            pat = r"(?<![a-z])" + pat
            if len(kw) <= 2:
                pat += r"(?![a-z])"
        parts.append(pat)
    # "(?!)" never matches, so an empty keyword list matches nothing.
    return re.compile("|".join(parts) or r"(?!)")


_REASONING_RE = _keyword_regex(_REASONING_KW)
_SCIENCE_RE = _keyword_regex(_SCIENCE_KW)


def _classify(text: str) -> str:
    """Return one of: casual, reasoning, science, general.

    Exact small-talk phrases and any text shorter than 6 characters (after
    stripping) are "casual". Otherwise reasoning keywords win over science
    keywords; with neither the query is "general".
    """
    t = text.strip().lower()
    if t in _CASUAL_EXACT or len(t) < 6:
        return "casual"
    if _REASONING_RE.search(t):
        return "reasoning"
    if _SCIENCE_RE.search(t):
        return "science"
    return "general"


# ── Model Scoring ─────────────────────────────────────────────
# Parameter-size tag → score bonus. The first matching tag (in this order) wins.
_SIZE_POINTS = [
    ("253b", 100), ("250b", 100), ("200b", 95), ("120b", 90),
    ("110b", 88), ("100b", 85), ("70b", 70), ("72b", 70),
    ("65b", 68), ("49b", 55), ("46b", 55), ("40b", 55),
    ("27b", 40), ("32b", 42), ("34b", 42), ("13b", 20),
    ("14b", 20), ("8b", 15), ("3b", -20), ("1b", -40), ("0.5b", -50),
]
# A tag only counts as its own id token: it must follow the start of the id or
# one of "/-_:" and must not be followed by a letter/digit. This keeps "1b"
# from matching "11b", and "3b" from matching "33b" or the MoE "a3b" suffix.
_SIZE_TAGS = [
    (re.compile(r"(?:^|[/\-_:])" + re.escape(tag) + r"(?![a-z0-9])"), pts)
    for tag, pts in _SIZE_POINTS
]


def _score(model_id: str) -> int:
    """Higher score ≈ better for hard reasoning tasks.

    Base 50, plus a bonus for the first parameter-size tag found in the id,
    plus small per-family bonuses.
    """
    m = model_id.lower()
    s = 50
    for pattern, pts in _SIZE_TAGS:
        if pattern.search(m):
            s += pts
            break
    if "nemotron" in m:
        s += 15
    if "gpt" in m:
        s += 10
    if "llama" in m:
        s += 5
    if "qwen" in m:
        s += 5
    if "deepseek" in m:
        s += 8
    return s


# ── Model Pool Fetching & Caching ─────────────────────────────
_cache: dict | None = None  # category → list of {"id", "score", "vision"}
_cache_at = 0.0  # time.time() when _cache was last (re)validated
_lock = threading.Lock()  # guards _cache / _cache_at
_counters: dict[str, int] = {}  # per-category round-robin position for _pick()


def _ensure_pinned(pool: list, pinned_id: str, vision: bool = False) -> list:
    """Return a new pool with *pinned_id* first (score 999) and any duplicate removed."""
    rest = [e for e in pool if e["id"] != pinned_id]
    return [{"id": pinned_id, "score": 999, "vision": vision}] + rest


def _fetch_pools() -> dict:
    """Return per-category model pools, refreshing them from OpenRouter when stale.

    Pools are cached for MODELS_TTL seconds. Each pool is a list of
    {"id", "score", "vision"} dicts, best first. Only ":free" models that are
    not guardrail classifiers are considered.

    On any fetch/parse failure the previous (stale) pools are returned if there
    are any, otherwise a fallback made of the pinned defaults. Either way the
    result is reused for _FETCH_RETRY_SECS seconds before the API is retried.
    """
    global _cache, _cache_at
    with _lock:
        if _cache and (time.time() - _cache_at) < MODELS_TTL:
            return _cache

    # The network fetch runs outside the lock: concurrent callers that miss
    # the cache together may each fetch, and the last writer wins.
    try:
        req = urllib.request.Request(
            "https://openrouter.ai/api/v1/models",
            headers={
                "HTTP-Referer": "https://stemcopilot.app",
                "User-Agent": "STEMCopilot/1.0",
            },
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            raw = json.loads(resp.read().decode())

        text_all, vision_all = [], []
        for m in raw.get("data") or []:
            if not isinstance(m, dict):
                continue
            mid = m.get("id") or ""
            if not mid.endswith(":free") or _is_guardrail(mid):
                continue
            # Fields may be missing or JSON null; one odd entry must not sink
            # the whole fetch.
            arch = m.get("architecture") or {}
            # Only the input side of "text+image->text" counts: a model that
            # outputs images ("text->image") cannot necessarily read them.
            modality_in = str(arch.get("modality") or "").split("->", 1)[0]
            input_mods = arch.get("input_modalities") or []
            has_vision = ("image" in modality_in) or ("image" in input_mods)
            entry = {"id": mid, "score": _score(mid), "vision": has_vision}
            text_all.append(entry)
            if has_vision:
                vision_all.append(entry)

        text_all.sort(key=lambda x: x["score"], reverse=True)
        vision_all.sort(key=lambda x: x["score"], reverse=True)

        pools = {
            "reasoning": _ensure_pinned(text_all[:6], DEFAULT_MODEL),
            "science": _ensure_pinned(text_all[:8], DEFAULT_MODEL),
            "general": _ensure_pinned(text_all[:10], DEFAULT_MODEL),
            # Lightweight end of the ranking, lowest score first; with fewer
            # than 5 free models, fall back to the top 3.
            "casual": text_all[-5:][::-1] if len(text_all) >= 5 else text_all[:3],
            "vision": _ensure_pinned(vision_all, DEFAULT_VISION_MODEL, vision=True),
        }

        with _lock:
            _cache = pools
            _cache_at = time.time()

        print(f"[ROUTER] {len(text_all)} free models, {len(vision_all)} vision-capable")
        for cat in ("reasoning", "casual", "vision"):
            ids = [e["id"] for e in pools[cat][:3]]
            if ids:
                print(f"[ROUTER] {cat}: {ids}")
        return pools

    except Exception as exc:
        print(f"[ROUTER] Fetch failed: {exc}")
        fallback_text = {"id": DEFAULT_MODEL, "score": 999, "vision": False}
        fallback_vision = {"id": DEFAULT_VISION_MODEL, "score": 999, "vision": True}
        with _lock:
            if not _cache:
                _cache = {
                    "reasoning": [fallback_text],
                    "science": [fallback_text],
                    "general": [fallback_text],
                    "casual": [fallback_text],
                    "vision": [fallback_vision],
                }
            # Back off: treat what we have as fresh for _FETCH_RETRY_SECS more
            # seconds so an unreachable API is not re-queried (10 s timeout)
            # on every call.
            _cache_at = time.time() - MODELS_TTL + _FETCH_RETRY_SECS
            return _cache


# ── Model Picker (round-robin) ────────────────────────────────
def _pick(category: str, has_image: bool = False, skip_models: set | None = None) -> tuple[str, str]:
    """Return (model_id, actual_category).

    If has_image and a non-skipped vision model exists, the first one is used
    and actual_category is "vision"; otherwise the text pool for *category*
    (or "general" when that pool is missing/empty) is used.

    skip_models: set of model IDs to skip (e.g. after a 429).
    """
    pools = _fetch_pools()
    skip = skip_models or set()

    if has_image:
        v = [m for m in pools.get("vision", []) if m["id"] not in skip]
        if v:
            # Pinned default vision model sits at v[0]; only move past it on a 429.
            return v[0]["id"], "vision"

    pool = [m for m in (pools.get(category) or pools.get("general") or []) if m["id"] not in skip]
    if not pool:
        # Everything was skipped: fall back to the pinned default text model.
        return DEFAULT_MODEL, category

    # Reasoning / science queries rotate round-robin across their pool to
    # spread load and dodge per-model rate limits. Every other category always
    # takes the first model of its pool — DEFAULT_MODEL for "general" (pinned
    # by _ensure_pinned), the lightweight end of the ranking for "casual" —
    # and only a 429 (via skip_models) moves it past pool[0].
    if category in ("reasoning", "science"):
        # Not lock-protected: concurrent calls may occasionally get the same index.
        idx = _counters.get(category, 0) % len(pool)
        _counters[category] = idx + 1
        return pool[idx]["id"], category

    return pool[0]["id"], category


# ── Message helpers ────────────────────────────────────────────
def _extract_text(messages: list) -> str:
    """Get raw text from the last message (handles multimodal list content).

    For list content, the "text" of every {"type": "text"} part is joined with
    spaces; other parts are ignored. Returns "" for an empty message list.
    """
    if not messages:
        return ""
    content = messages[-1].content if hasattr(messages[-1], "content") else ""
    if isinstance(content, list):
        return " ".join(
            p.get("text", "") for p in content
            if isinstance(p, dict) and p.get("type") == "text"
        )
    return str(content)


def _strip_images(messages: list) -> list:
    """Flatten list-style (multimodal) content to plain text, dropping non-text parts.

    Messages whose content is already a string pass through unchanged. A
    message left with no text becomes "(user sent an image)". Rebuilt messages
    keep their class but are constructed from ``content`` only (other fields
    such as ids are not copied).
    """
    out = []
    for msg in messages:
        if isinstance(msg.content, list):
            text_parts = [
                p.get("text", "") for p in msg.content
                if isinstance(p, dict) and p.get("type") == "text"
            ]
            text = " ".join(t for t in text_parts if t).strip()
            if not text:
                text = "(user sent an image)"
            out.append(msg.__class__(content=text))
        else:
            out.append(msg)
    return out


# ── LLM factory ───────────────────────────────────────────────
def _make_llm(api_key: str, model_id: str):
    """Build a ChatOpenRouter client for *model_id*.

    Uses the per-user *api_key* if given, else OPENROUTER_API_KEY from config.
    Client-side retries are disabled (max_retries=0) because chat_node does its
    own model rotation on rate limits.

    Raises:
        ValueError: if no API key is available.
    """
    key = api_key or OPENROUTER_API_KEY
    if not key:
        raise ValueError("No API key. Add your OpenRouter key in Settings.")
    # NOTE(review): LLM_TIMEOUT is imported from config but not passed to the
    # client — confirm which timeout kwarg ChatOpenRouter accepts and wire it in.
    return ChatOpenRouter(
        model=model_id,
        openrouter_api_key=key,
        temperature=0.5,
        max_tokens=4096,
        max_retries=0,
        streaming=True,
    )


# ── LangGraph state & node ────────────────────────────────────
class ChatState(TypedDict):
    # Conversation history; new messages are merged in by the add_messages reducer.
    messages: Annotated[list[BaseMessage], add_messages]


def _is_rate_limited(exc: BaseException) -> bool:
    """Return True if *exc* looks like an HTTP 429 / rate-limit error.

    Checks a ``status_code`` on the exception (or on its ``response``) first,
    then the exception's class name and message for specific rate-limit
    phrases. Unlike a bare ``"rate" in message`` test, this does not fire on
    words such as "generate" or "moderate".
    """
    status = getattr(exc, "status_code", None)
    if status is None:
        status = getattr(getattr(exc, "response", None), "status_code", None)
    if status == 429:
        return True
    text = f"{type(exc).__name__} {exc}".lower()
    return any(
        marker in text
        for marker in (
            "429", "toomanyrequests", "too many requests",
            "ratelimit", "rate limit", "rate-limit", "rate_limit",
        )
    )


def chat_node(state: ChatState, config: RunnableConfig):
    """Single graph node: build the system prompt, route to a model, and call it.

    Reads per-request options from ``config["configurable"]``: persona,
    context, language, username, student_profile, user_api_key, model
    (explicit override), search_enabled, has_image, doc_bytes (base64) and
    doc_name.

    The override model (if any) is tried first, then models from _pick().
    Rate-limit errors move on to another model, up to 3 attempts in total;
    any other error is re-raised immediately.

    Returns:
        {"messages": [reply]} — merged into the state by add_messages.
    """
    cfg = config.get("configurable", {})
    persona = cfg.get("persona", "nerd")
    context = cfg.get("context", "")
    language = cfg.get("language", "auto")
    username = cfg.get("username", "")
    profile = cfg.get("student_profile", "")
    api_key = cfg.get("user_api_key", "")
    override = cfg.get("model", "")
    search_enabled = cfg.get("search_enabled", False)
    has_image = cfg.get("has_image", False)
    doc_bytes = cfg.get("doc_bytes", "")
    doc_name = cfg.get("doc_name", "")

    # Surface the raw uploaded document in the LangSmith trace (not just its text).
    if doc_bytes:
        _attach_document(doc_name, doc_bytes)

    user_text = _extract_text(state["messages"])
    category = _classify(user_text)
    messages = state["messages"]

    base_prompt = prompts.build(persona, context, language, username, profile)

    if search_enabled:
        scope_override = (
            "## SEARCH MODE ACTIVE — READ THIS FIRST, BEFORE ANY OTHER INSTRUCTION BELOW:\n"
            "The student has turned on Web Search or YouTube mode. In this mode, you are not "
            "restricted to Physics, Chemistry, and Mathematics. The student may ask about anything — "
            "general knowledge, current events, how-to questions, other subjects, casual curiosity — "
            "and you must answer it directly and helpfully using the web search results or YouTube "
            "transcript provided below. Do NOT say 'I'm here to help you with Physics, Chemistry, and "
            "Maths' or any variation of that redirect — it is forbidden while this mode is active.\n\n"
            "PRIORITY RULE: If the student's question is itself a Class XI/XII Physics, Chemistry, or "
            "Maths doubt, and the NCERT context below contains relevant material, treat that NCERT "
            "context as the primary source and the web/YouTube results as a secondary supplement. "
            "For every other kind of question, rely on the web/YouTube results as the primary source.\n\n"
        )
        base_prompt = scope_override + base_prompt

    if has_image:
        base_prompt += (
            "\n\nIMAGE HANDLING:\n"
            "The student attached an image. Analyze it and respond educationally. "
            "Do NOT output safety classifications or moderation labels.\n"
        )

    system_msg = SystemMessage(content=base_prompt)

    # Try up to 3 models on 429 errors
    skip_models: set[str] = set()
    last_error: Exception | None = None
    for attempt in range(3):
        if override and attempt == 0:
            model_id = override
            # Forward the image to a user-chosen model only when the fetched
            # pools list it as vision-capable; otherwise strip images so a
            # text-only override never receives image parts.
            is_vision_override = has_image and any(
                m["id"] == override for m in _fetch_pools().get("vision", [])
            )
            actual = "vision" if is_vision_override else category
        else:
            model_id, actual = _pick(category, has_image=has_image, skip_models=skip_models)

        send_messages = messages if actual == "vision" else _strip_images(messages)
        print(f"[ROUTER] attempt={attempt+1} category={category} model={model_id} vision={actual == 'vision'}")

        try:
            llm = _make_llm(api_key, model_id)
            resp = llm.invoke([system_msg] + send_messages, config=config)
            return {"messages": [resp]}
        except Exception as e:
            last_error = e
            if _is_rate_limited(e):
                print(f"[ROUTER] 429 on {model_id}, rotating...")
                skip_models.add(model_id)
                continue
            raise

    raise last_error  # type:ignore


# ── Compile graph ─────────────────────────────────────────────
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
_g.add_edge(START, "chat_node")
_g.add_edge("chat_node", END)
chatbot = _g.compile(checkpointer=checkpointer)