Spaces:
Sleeping
Sleeping
| """ | |
LangGraph — Intelligent Model Router for STEM Copilot.

Routes queries to one of three pinned free OpenRouter models based on intent:
- Math / derivation / proofs → COMPLEX_MODEL
- Physics / chemistry concepts → COMPLEX_MODEL
- Casual greetings ("hi", "ok") and other general text → DEFAULT_MODEL
- Image understanding → DEFAULT_VISION_MODEL (vision-capable)

Routing is fully static: the pinned models are never fetched from the
OpenRouter API, cached, or rotated at runtime.
On 429 errors, the same model is retried once after a short backoff.
| """ | |
import base64
import logging
import re
import sqlite3
import time
from typing import TypedDict, Annotated

from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter  # type:ignore
from langgraph.checkpoint.sqlite import SqliteSaver  # type:ignore
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

from config import OPENROUTER_API_KEY, DB_PATH, LLM_TIMEOUT, CONTEXT_WINDOW
import prompts
logger = logging.getLogger(__name__)

# ── LangSmith document attachment (surfaces the raw uploaded file) ──
# Optional: only active if the installed langsmith SDK supports attachments.
# `Attachment` is looked up at the package top level first and, if that import
# fails, in `langsmith.schemas`. If langsmith cannot be imported at all (or
# anything else goes wrong here), `traceable` and `Attachment` are set to None
# and `_LANGSMITH_ATTACH` is False, which switches the feature off.
try:
    from langsmith import traceable
    try:
        from langsmith import Attachment  # type:ignore
    except ImportError:
        from langsmith.schemas import Attachment  # type:ignore
    _LANGSMITH_ATTACH = True
except Exception:
    # Deliberately broad: tracing is best-effort and must never block startup.
    traceable = None  # type:ignore
    Attachment = None  # type:ignore
    _LANGSMITH_ATTACH = False
| _MIME_BY_EXT = { | |
| "pdf": "application/pdf", | |
| "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", | |
| "doc": "application/msword", | |
| "txt": "text/plain", | |
| "md": "text/markdown", | |
| "csv": "text/csv", | |
| "json": "application/json", | |
| } | |
| def _guess_mime(filename: str) -> str: | |
| ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else "" | |
| return _MIME_BY_EXT.get(ext, "application/octet-stream") | |
if _LANGSMITH_ATTACH:
    # The function only becomes a LangSmith run when wrapped in `traceable`;
    # without the decorator the call is a plain no-op and the Attachment
    # passed as `document` is never recorded anywhere.
    @traceable(name="uploaded_document")
    def _trace_document(filename: str, document) -> str:  # noqa: ANN001
        """Nested run that carries the raw user-uploaded file as a LangSmith attachment.

        Args:
            filename: Display name of the uploaded file.
            document: The ``Attachment`` holding the file's MIME type and raw
                bytes. It is not used by the body; it exists so the tracing
                wrapper sees it among the run inputs.

        Returns:
            A short human-readable marker string naming the attached file.
        """
        return f"Attached document: {filename}"
def _attach_document(doc_name: str, doc_bytes_b64: str) -> None:
    """Record the originally uploaded document on the active LangSmith trace.

    Does nothing when attachment support is unavailable or no payload was
    supplied. Otherwise the base64 payload is decoded, wrapped in an
    ``Attachment`` (MIME type guessed from the file name) and handed to
    ``_trace_document``.

    Best-effort by design: every failure (old SDK, no API key, decode error)
    is logged at debug level and swallowed so tracing can never break the
    chat response.
    """
    if not _LANGSMITH_ATTACH or not doc_bytes_b64:
        return
    try:
        payload = base64.b64decode(doc_bytes_b64)
        # Fall back to a generic label when the caller supplied no file name.
        label = doc_name if doc_name else "uploaded_document"
        attachment = Attachment(mime_type=_guess_mime(label), data=payload)
        _trace_document(label, attachment)
    except Exception as exc:
        logger.debug("langsmith attach skipped: %s", exc)
| # ββ Checkpointer ββββββββββββββββββββββββββββββββββββββββββββββ | |
# One module-level SQLite connection backs the LangGraph checkpointer.
# check_same_thread=False lifts sqlite3's default rule that a connection may
# only be used by the thread that created it.
# NOTE(review): presumably required because graph invocations run on threads
# other than the importing one — confirm against the serving setup.
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=_conn)
| # ββ Constants βββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Routing is fully static — only these pinned models are ever used. No model is
# fetched/discovered from the OpenRouter API at runtime, so a query can never be
# routed to an unknown, unreliable, or guardrail model.
DEFAULT_MODEL = "openai/gpt-oss-120b:free"  # casual / general text
COMPLEX_MODEL = "nvidia/nemotron-3-ultra-550b-a55b:free"  # hard physics/chem/maths
DEFAULT_VISION_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free"  # image queries
| def _is_guardrail(model_id: str) -> bool: | |
| """True for safety-classifier / guardrail models β we never route to these.""" | |
| base = model_id.lower().split(":")[0] | |
| return any(kw in base for kw in ("guard", "shield", "safety", "moderation")) | |
# Safety net: guarantee no pinned model is a guardrail classifier.
# Raised explicitly instead of via `assert` because assert statements are
# removed when Python runs with -O, which would silently disable this check.
# The exception type and messages match what the assert statements produced.
if _is_guardrail(DEFAULT_MODEL):
    raise AssertionError("DEFAULT_MODEL must not be a guardrail model")
if _is_guardrail(COMPLEX_MODEL):
    raise AssertionError("COMPLEX_MODEL must not be a guardrail model")
if _is_guardrail(DEFAULT_VISION_MODEL):
    raise AssertionError("DEFAULT_VISION_MODEL must not be a guardrail model")
| # ββ Query Classification βββββββββββββββββββββββββββββββββββββ | |
| _CASUAL_EXACT = frozenset([ | |
| "hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola", | |
| "thanks", "thank you", "thankyou", "thx", "ty", | |
| "ok", "okay", "k", "kk", "fine", "alright", | |
| "bye", "goodbye", "see you", "later", | |
| "good morning", "good night", "good evening", "gm", "gn", | |
| "welcome", "cool", "nice", "great", "awesome", "wow", | |
| "got it", "understood", "sure", "yes", "no", "yeah", "nah", "yep", "nope", | |
| "hm", "hmm", "oh", "ah", "lol", "haha", "hehe", "xd", | |
| "what", "nothing", "nvm", "nevermind", "nm", | |
| ]) | |
| _REASONING_KW = [ | |
| "derive", "derivation", "prove", "proof", "solve", "equation", | |
| "integral", "integrate", "differentiate", "differentiation", | |
| "formula", "calculate", "calculation", "compute", | |
| "theorem", "lemma", "corollary", | |
| "limit", "matrix", "determinant", "vector", "eigen", | |
| "trigonometry", "trigonometric", | |
| "quadratic", "polynomial", "logarithm", | |
| "calculus", "algebra", "geometry", "coordinate", | |
| "probability", "permutation", "combination", | |
| "step by step", "show steps", "show working", "work out", | |
| "simplify", "factorize", "factorise", "expand", | |
| "numerator", "denominator", "fraction", | |
| "dx", "dy", "dz", " lim ", | |
| ] | |
| _SCIENCE_KW = [ | |
| "physics", "chemistry", "biology", | |
| "molecule", "atom", "electron", "proton", "neutron", | |
| "ion", "isotope", "nucleus", | |
| "force", "energy", "momentum", "velocity", "acceleration", | |
| "wave", "wavelength", "frequency", "amplitude", | |
| "light", "optics", "lens", "mirror", "refraction", "reflection", | |
| "electric", "magnetic", "electromagnetic", "circuit", "resistance", | |
| "gravity", "gravitational", "newton", "coulomb", "charge", | |
| "reaction", "compound", "element", "bond", "orbital", | |
| "hybridisation", "hybridization", "valence", | |
| "thermodynamics", "entropy", "enthalpy", | |
| "kinematics", "dynamics", | |
| "quantum", "relativity", "nuclear", | |
| "acid", "base", "salt", "buffer", | |
| "oxidation", "reduction", "redox", | |
| "mole", "molarity", "avogadro", | |
| "ncert", "class 11", "class 12", "class xi", "class xii", | |
| ] | |
| def _classify(text: str) -> str: | |
| """Return one of: casual, reasoning, science, general.""" | |
| t = text.strip().lower() | |
| if t in _CASUAL_EXACT or len(t) < 6: | |
| return "casual" | |
| if any(kw in t for kw in _REASONING_KW): | |
| return "reasoning" | |
| if any(kw in t for kw in _SCIENCE_KW): | |
| return "science" | |
| return "general" | |
| # ββ Model Picker (static) βββββββββββββββββββββββββββββββββββββ | |
def _pick(category: str = "", has_image: bool = False, skip_models: set | None = None) -> tuple[str, str]:
    """Choose the pinned model for a query and say which kind of model it is.

    Static routing, decided in this order:
      * image attached                         -> DEFAULT_VISION_MODEL, "vision"
      * category "reasoning" or "science"      -> COMPLEX_MODEL, "text"
      * everything else (casual / general)     -> DEFAULT_MODEL, "text"

    `skip_models` is accepted for call-site compatibility but intentionally
    ignored: there is no model pool to rotate or fall back through.

    Returns:
        A ``(model_id, kind)`` tuple where ``kind`` is "vision" or "text".
    """
    if has_image:
        # An attached image always wins, whatever the text category is.
        return DEFAULT_VISION_MODEL, "vision"
    needs_complex_model = category in ("reasoning", "science")
    text_model = COMPLEX_MODEL if needs_complex_model else DEFAULT_MODEL
    return text_model, "text"
| # ββ Message helpers ββββββββββββββββββββββββββββββββββββββββββββ | |
| def _extract_text(messages: list) -> str: | |
| """Get raw text from the last user message (handles multimodal).""" | |
| if not messages: | |
| return "" | |
| content = messages[-1].content if hasattr(messages[-1], "content") else "" | |
| if isinstance(content, list): | |
| return " ".join( | |
| p.get("text", "") for p in content | |
| if isinstance(p, dict) and p.get("type") == "text" | |
| ) | |
| return str(content) | |
| def _strip_images(messages: list) -> list: | |
| """Remove all image_url content from messages. Preserves message types.""" | |
| out = [] | |
| for msg in messages: | |
| if isinstance(msg.content, list): | |
| text_parts = [ | |
| p.get("text", "") | |
| for p in msg.content | |
| if isinstance(p, dict) and p.get("type") == "text" | |
| ] | |
| text = " ".join(t for t in text_parts if t).strip() | |
| if not text: | |
| text = "(user sent an image)" | |
| out.append(msg.__class__(content=text)) | |
| else: | |
| out.append(msg) | |
| return out | |
| # ββ LLM factory βββββββββββββββββββββββββββββββββββββββββββββββ | |
def _make_llm(api_key: str, model_id: str):
    """Build a streaming ChatOpenRouter client for *model_id*.

    Args:
        api_key: The user's own OpenRouter key; when empty, the server-side
            ``OPENROUTER_API_KEY`` from config is used instead.
        model_id: OpenRouter model identifier to call.

    Returns:
        A configured ``ChatOpenRouter`` instance.

    Raises:
        ValueError: If neither the user key nor the configured key is set.
    """
    # NOTE(review): LLM_TIMEOUT is imported from config at module level but is
    # not applied here — confirm whether a request timeout should be passed.
    effective_key = api_key if api_key else OPENROUTER_API_KEY
    if effective_key:
        return ChatOpenRouter(
            model=model_id,
            openrouter_api_key=effective_key,
            temperature=0.5,
            max_tokens=2048,
            max_retries=0,  # client-side retries are off; chat_node retries itself
            streaming=True,
        )
    raise ValueError("No API key. Add your OpenRouter key in Settings.")
| # ββ LangGraph state & node ββββββββββββββββββββββββββββββββββββ | |
class ChatState(TypedDict):
    """State carried through the chat graph: the conversation messages."""

    # `add_messages` is attached as the reducer for this key, so message lists
    # returned by a node are merged into the existing history by LangGraph
    # rather than overwriting it.
    messages: Annotated[list[BaseMessage], add_messages]
# Prepended to the system prompt when web / YouTube search is enabled.
_SEARCH_SCOPE_OVERRIDE = (
    "## SEARCH MODE ACTIVE — READ THIS FIRST, BEFORE ANY OTHER INSTRUCTION BELOW:\n"
    "The student has turned on Web Search or YouTube mode. In this mode, you are not "
    "restricted to Physics, Chemistry, and Mathematics. The student may ask about anything — "
    "general knowledge, current events, how-to questions, other subjects, casual curiosity — "
    "and you must answer it directly and helpfully using the web search results or YouTube "
    "transcript provided below. Do NOT say 'I'm here to help you with Physics, Chemistry, and "
    "Maths' or any variation of that redirect — it is forbidden while this mode is active.\n\n"
    "PRIORITY RULE: If the student's question is itself a Class XI/XII Physics, Chemistry, or "
    "Maths doubt, and the NCERT context below contains relevant material, treat that NCERT "
    "context as the primary source and the web/YouTube results as a secondary supplement. "
    "For every other kind of question, rely on the web/YouTube results as the primary source.\n\n"
)

# Appended to the system prompt when the current turn carries an image.
_IMAGE_HANDLING_NOTE = (
    "\n\nIMAGE HANDLING:\n"
    "The student attached an image. Analyze it and respond educationally. "
    "Do NOT output safety classifications or moderation labels.\n"
)

# Lower-case fragments that identify a rate-limit failure. A bare "rate" is
# deliberately absent: it also occurs in "generate", "moderate", "integrate",
# and would turn unrelated errors into a pointless sleep-and-retry.
_RATE_LIMIT_MARKERS = (
    "429",
    "toomany",
    "too many requests",
    "rate limit",
    "rate-limit",
    "rate_limit",
    "ratelimit",
)


def _is_rate_limit_error(exc: BaseException) -> bool:
    """Return True when *exc* looks like an HTTP 429 / rate-limit failure."""
    if getattr(exc, "status_code", None) == 429:
        return True
    # Include the class name so errors such as "TooManyRequestsError" are
    # recognised even when their message does not mention the status code.
    haystack = f"{type(exc).__name__} {exc}".lower()
    return any(marker in haystack for marker in _RATE_LIMIT_MARKERS)


def chat_node(state: ChatState, config: RunnableConfig):
    """Answer the latest turn: build the prompt, pick a pinned model, call it.

    Per-request options are read from ``config["configurable"]``: persona,
    context, language, username, student_profile, user_api_key,
    search_enabled, has_image, doc_bytes and doc_name.

    Args:
        state: Graph state holding the conversation ``messages``.
        config: Runnable config; also forwarded to the model call.

    Returns:
        ``{"messages": [reply]}`` with the model's response message.

    Raises:
        ValueError: From ``_make_llm`` when no API key is available.
        Exception: Any model-call failure, re-raised unchanged, except a first
            rate-limit failure, which is retried once after a 2-second pause.
    """
    cfg = config.get("configurable", {})
    persona = cfg.get("persona", "nerd")
    context = cfg.get("context", "")
    language = cfg.get("language", "auto")
    username = cfg.get("username", "")
    profile = cfg.get("student_profile", "")
    api_key = cfg.get("user_api_key", "")
    search_enabled = cfg.get("search_enabled", False)
    has_image = cfg.get("has_image", False)
    doc_bytes = cfg.get("doc_bytes", "")
    doc_name = cfg.get("doc_name", "")

    # Surface the raw uploaded document in the LangSmith trace (not just its text).
    if doc_bytes:
        _attach_document(doc_name, doc_bytes)

    # Sliding window: only the last CONTEXT_WINDOW messages are sent to the LLM.
    # Full history stays in the checkpoint for the /history view — this just caps
    # input-token growth on long conversations. A non-positive window disables
    # trimming instead of producing a surprising slice.
    messages = state["messages"]
    if CONTEXT_WINDOW > 0 and len(messages) > CONTEXT_WINDOW:
        messages = messages[-CONTEXT_WINDOW:]

    base_prompt = prompts.build(persona, context, language, username, profile)
    if search_enabled:
        base_prompt = _SEARCH_SCOPE_OVERRIDE + base_prompt
    if has_image:
        base_prompt += _IMAGE_HANDLING_NOTE
    system_message = SystemMessage(content=base_prompt)

    # Static routing — one pinned model per query type.
    category = _classify(_extract_text(messages))
    model_id, actual = _pick(category=category, has_image=has_image)
    # Text-only models must not receive image parts.
    send_messages = messages if actual == "vision" else _strip_images(messages)

    # The client does not depend on the attempt, so build it once. A missing
    # API key fails here, before any request is made.
    llm = _make_llm(api_key, model_id)

    # A free-tier 429 is usually a transient per-minute cap, so retry the same
    # model once after a short backoff; everything else propagates immediately.
    for attempt in range(2):
        logger.debug("router attempt=%d model=%s vision=%s", attempt + 1, model_id, actual == "vision")
        try:
            reply = llm.invoke([system_message] + send_messages, config=config)
        except Exception as exc:
            if attempt == 0 and _is_rate_limit_error(exc):
                logger.info("429 on %s, retrying once after backoff", model_id)
                time.sleep(2)
                continue
            raise
        return {"messages": [reply]}

    # Unreachable: the second attempt always returns or raises.
    raise RuntimeError("chat_node: retry loop ended without a response")
| # ββ Compile graph βββββββββββββββββββββββββββββββββββββββββββββ | |
# Single-node graph: START -> chat_node -> END.
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
_g.add_edge(START, "chat_node")
_g.add_edge("chat_node", END)
# Compiled with the module's SqliteSaver checkpointer so graph state is saved
# to the SQLite database at DB_PATH. `chatbot` is the compiled graph object.
chatbot = _g.compile(checkpointer=checkpointer)