Stem_Copilot / graph.py
Krishna111111's picture
Still Trying...
0930342
Raw
History Blame Contribute Delete
14.1 kB
"""
LangGraph β€” Intelligent Model Router for STEM Copilot.
Routes queries to the best free OpenRouter model based on intent:
- Math / derivation / proofs β†’ top reasoning models (rotated)
- Physics / chemistry concepts β†’ strong general models (rotated)
- Casual greetings ("hi", "ok") β†’ lightweight models
- Image understanding β†’ vision-capable models (verified)
Model pools are fetched from the OpenRouter API, cached for 10 min,
and rotated per-category to avoid per-model rate limits.
On 429 errors, the failed model is skipped and the next model in the
pool is tried automatically (up to 2 retries).
"""
import base64
import logging
import re
import sqlite3
import time
from typing import Annotated, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter  # type:ignore
from langgraph.checkpoint.sqlite import SqliteSaver  # type:ignore
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

from config import CONTEXT_WINDOW, DB_PATH, LLM_TIMEOUT, OPENROUTER_API_KEY
import prompts
logger = logging.getLogger(__name__)
# ── LangSmith document attachment (surfaces the raw uploaded file) ──
# Optional: only active if the installed langsmith SDK supports attachments.
try:
from langsmith import traceable
try:
from langsmith import Attachment # type:ignore
except ImportError:
from langsmith.schemas import Attachment # type:ignore
_LANGSMITH_ATTACH = True
except Exception:
traceable = None # type:ignore
Attachment = None # type:ignore
_LANGSMITH_ATTACH = False
_MIME_BY_EXT = {
"pdf": "application/pdf",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"doc": "application/msword",
"txt": "text/plain",
"md": "text/markdown",
"csv": "text/csv",
"json": "application/json",
}
def _guess_mime(filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
return _MIME_BY_EXT.get(ext, "application/octet-stream")
if _LANGSMITH_ATTACH:

    def _trace_document(filename: str, document) -> str:  # noqa: ANN001
        """Nested LangSmith run; its ``document`` input carries the raw uploaded file as an attachment."""
        return f"Attached document: {filename}"

    # Same as decorating the function with @traceable(name="uploaded_document").
    _trace_document = traceable(name="uploaded_document")(_trace_document)
def _attach_document(doc_name: str, doc_bytes_b64: str) -> None:
    """Attach the original uploaded file to the active LangSmith trace.

    Args:
        doc_name: Original filename; "uploaded_document" is used when it is empty.
        doc_bytes_b64: Base64-encoded file contents; nothing happens when empty.

    Best-effort: a no-op when LangSmith attachments are unavailable. Any failure
    (e.g. old SDK, no API key, decode error) is logged at DEBUG level and
    swallowed, so tracing can never break the chat response.
    """
    if not _LANGSMITH_ATTACH or not doc_bytes_b64:
        return
    label = doc_name if doc_name else "uploaded_document"
    try:
        raw_bytes = base64.b64decode(doc_bytes_b64)
        attachment = Attachment(mime_type=_guess_mime(label), data=raw_bytes)
        _trace_document(label, attachment)
    except Exception as exc:
        logger.debug("langsmith attach skipped: %s", exc)
# ── Checkpointer ──────────────────────────────────────────────
# One module-level SQLite connection backs the LangGraph checkpointer.
# check_same_thread=False allows threads other than the creating one to use it.
_conn = sqlite3.connect(
    DB_PATH,
    check_same_thread=False,
)
checkpointer = SqliteSaver(conn=_conn)
# ── Constants ─────────────────────────────────────────────────
# Routing is fully static — only these pinned models are ever used. No model is
# fetched/discovered from the OpenRouter API at runtime, so a query can never be
# routed to an unknown, unreliable, or guardrail model.
DEFAULT_MODEL = "openai/gpt-oss-120b:free" # casual / general text
COMPLEX_MODEL = "nvidia/nemotron-3-ultra-550b-a55b:free" # hard physics/chem/maths
DEFAULT_VISION_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free" # image queries
def _is_guardrail(model_id: str) -> bool:
"""True for safety-classifier / guardrail models β€” we never route to these."""
base = model_id.lower().split(":")[0]
return any(kw in base for kw in ("guard", "shield", "safety", "moderation"))
# Safety net: guarantee no pinned model is a guardrail classifier.
assert not _is_guardrail(DEFAULT_MODEL), "DEFAULT_MODEL must not be a guardrail model"
assert not _is_guardrail(COMPLEX_MODEL), "COMPLEX_MODEL must not be a guardrail model"
assert not _is_guardrail(DEFAULT_VISION_MODEL), "DEFAULT_VISION_MODEL must not be a guardrail model"
# ── Query Classification ─────────────────────────────────────
_CASUAL_EXACT = frozenset([
"hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola",
"thanks", "thank you", "thankyou", "thx", "ty",
"ok", "okay", "k", "kk", "fine", "alright",
"bye", "goodbye", "see you", "later",
"good morning", "good night", "good evening", "gm", "gn",
"welcome", "cool", "nice", "great", "awesome", "wow",
"got it", "understood", "sure", "yes", "no", "yeah", "nah", "yep", "nope",
"hm", "hmm", "oh", "ah", "lol", "haha", "hehe", "xd",
"what", "nothing", "nvm", "nevermind", "nm",
])
_REASONING_KW = [
"derive", "derivation", "prove", "proof", "solve", "equation",
"integral", "integrate", "differentiate", "differentiation",
"formula", "calculate", "calculation", "compute",
"theorem", "lemma", "corollary",
"limit", "matrix", "determinant", "vector", "eigen",
"trigonometry", "trigonometric",
"quadratic", "polynomial", "logarithm",
"calculus", "algebra", "geometry", "coordinate",
"probability", "permutation", "combination",
"step by step", "show steps", "show working", "work out",
"simplify", "factorize", "factorise", "expand",
"numerator", "denominator", "fraction",
"dx", "dy", "dz", " lim ",
]
_SCIENCE_KW = [
"physics", "chemistry", "biology",
"molecule", "atom", "electron", "proton", "neutron",
"ion", "isotope", "nucleus",
"force", "energy", "momentum", "velocity", "acceleration",
"wave", "wavelength", "frequency", "amplitude",
"light", "optics", "lens", "mirror", "refraction", "reflection",
"electric", "magnetic", "electromagnetic", "circuit", "resistance",
"gravity", "gravitational", "newton", "coulomb", "charge",
"reaction", "compound", "element", "bond", "orbital",
"hybridisation", "hybridization", "valence",
"thermodynamics", "entropy", "enthalpy",
"kinematics", "dynamics",
"quantum", "relativity", "nuclear",
"acid", "base", "salt", "buffer",
"oxidation", "reduction", "redox",
"mole", "molarity", "avogadro",
"ncert", "class 11", "class 12", "class xi", "class xii",
]
def _classify(text: str) -> str:
"""Return one of: casual, reasoning, science, general."""
t = text.strip().lower()
if t in _CASUAL_EXACT or len(t) < 6:
return "casual"
if any(kw in t for kw in _REASONING_KW):
return "reasoning"
if any(kw in t for kw in _SCIENCE_KW):
return "science"
return "general"
# ── Model Picker (static) ─────────────────────────────────────
def _pick(category: str = "", has_image: bool = False, skip_models: set | None = None) -> tuple[str, str]:
"""
Static routing β€” pinned models chosen by query type:
β€’ image attached β†’ DEFAULT_VISION_MODEL (actual = "vision")
β€’ hard physics/chem/maths (reasoning,
science) β†’ COMPLEX_MODEL (actual = "text")
β€’ casual / general β†’ DEFAULT_MODEL (actual = "text")
`skip_models` is accepted for call-site compatibility but intentionally
ignored: there is no model pool to rotate or fall back through.
"""
if has_image:
return DEFAULT_VISION_MODEL, "vision"
if category in ("reasoning", "science"):
return COMPLEX_MODEL, "text"
return DEFAULT_MODEL, "text"
# ── Message helpers ────────────────────────────────────────────
def _extract_text(messages: list) -> str:
"""Get raw text from the last user message (handles multimodal)."""
if not messages:
return ""
content = messages[-1].content if hasattr(messages[-1], "content") else ""
if isinstance(content, list):
return " ".join(
p.get("text", "") for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
return str(content)
def _strip_images(messages: list) -> list:
"""Remove all image_url content from messages. Preserves message types."""
out = []
for msg in messages:
if isinstance(msg.content, list):
text_parts = [
p.get("text", "")
for p in msg.content
if isinstance(p, dict) and p.get("type") == "text"
]
text = " ".join(t for t in text_parts if t).strip()
if not text:
text = "(user sent an image)"
out.append(msg.__class__(content=text))
else:
out.append(msg)
return out
# ── LLM factory ───────────────────────────────────────────────
def _make_llm(api_key: str, model_id: str):
    """Build a streaming ChatOpenRouter client for *model_id*.

    The per-request *api_key* takes precedence; OPENROUTER_API_KEY is the
    fallback. ``max_retries=0`` leaves retrying to the caller (chat_node
    retries once on rate limits).

    Raises:
        ValueError: when neither key is set.
    """
    # NOTE(review): LLM_TIMEOUT is imported at the top of this file but never
    # used; consider passing it here if ChatOpenRouter accepts a timeout.
    resolved_key = api_key if api_key else OPENROUTER_API_KEY
    if not resolved_key:
        raise ValueError("No API key. Add your OpenRouter key in Settings.")
    settings = {
        "model": model_id,
        "openrouter_api_key": resolved_key,
        "temperature": 0.5,
        "max_tokens": 2048,
        "max_retries": 0,
        "streaming": True,
    }
    return ChatOpenRouter(**settings)
# ── LangGraph state & node ────────────────────────────────────
class ChatState(TypedDict):
    """LangGraph state for the chat: the running conversation."""

    # Annotated with add_messages (LangGraph's reducer convention), so each
    # node's returned messages are merged into this list rather than replacing it.
    messages: Annotated[list[BaseMessage], add_messages]
# Prepended to the system prompt when Web Search / YouTube mode is on.
_SEARCH_MODE_PREAMBLE = (
    "## SEARCH MODE ACTIVE — READ THIS FIRST, BEFORE ANY OTHER INSTRUCTION BELOW:\n"
    "The student has turned on Web Search or YouTube mode. In this mode, you are not "
    "restricted to Physics, Chemistry, and Mathematics. The student may ask about anything — "
    "general knowledge, current events, how-to questions, other subjects, casual curiosity — "
    "and you must answer it directly and helpfully using the web search results or YouTube "
    "transcript provided below. Do NOT say 'I'm here to help you with Physics, Chemistry, and "
    "Maths' or any variation of that redirect — it is forbidden while this mode is active.\n\n"
    "PRIORITY RULE: If the student's question is itself a Class XI/XII Physics, Chemistry, or "
    "Maths doubt, and the NCERT context below contains relevant material, treat that NCERT "
    "context as the primary source and the web/YouTube results as a secondary supplement. "
    "For every other kind of question, rely on the web/YouTube results as the primary source.\n\n"
)

# Appended to the system prompt when the current turn carries an image.
_IMAGE_HANDLING_NOTE = (
    "\n\nIMAGE HANDLING:\n"
    "The student attached an image. Analyze it and respond educationally. "
    "Do NOT output safety classifications or moderation labels.\n"
)

# Lower-cased markers of a rate-limit failure, matched against the exception's
# type name and message. Deliberately narrower than a bare "rate", which also
# matched unrelated words such as "generate", "moderate" or "accurate".
_RATE_LIMIT_MARKERS = (
    "429", "toomany", "too many requests",
    "rate limit", "rate-limit", "rate_limit", "ratelimit",
)

# Seconds to wait before the single rate-limit retry.
_RATE_LIMIT_BACKOFF_S = 2


def _window_messages(messages: list) -> list:
    """Keep only the last CONTEXT_WINDOW messages for the LLM call.

    Full history stays in the checkpoint for the /history view; this only caps
    input-token growth on long conversations. A non-positive CONTEXT_WINDOW
    means "no limit". Explicit, because ``messages[-0:]`` is the whole list and
    a negative window would otherwise slice from the front.
    """
    if CONTEXT_WINDOW > 0 and len(messages) > CONTEXT_WINDOW:
        return messages[-CONTEXT_WINDOW:]
    return messages


def _build_system_prompt(cfg: dict) -> str:
    """Build the system prompt text from the request's ``configurable`` options."""
    prompt = prompts.build(
        cfg.get("persona", "nerd"),
        cfg.get("context", ""),
        cfg.get("language", "auto"),
        cfg.get("username", ""),
        cfg.get("student_profile", ""),
    )
    if cfg.get("search_enabled", False):
        # Must precede the base prompt: it overrides the PCM-only scope rules.
        prompt = _SEARCH_MODE_PREAMBLE + prompt
    if cfg.get("has_image", False):
        prompt += _IMAGE_HANDLING_NOTE
    return prompt


def _is_rate_limit_error(exc: BaseException) -> bool:
    """Return True when *exc* looks like an HTTP 429 / rate-limit failure."""
    if getattr(exc, "status_code", None) == 429:
        return True
    haystack = f"{type(exc).__name__} {exc}".lower()
    return any(marker in haystack for marker in _RATE_LIMIT_MARKERS)


def chat_node(state: ChatState, config: RunnableConfig):
    """Single graph node: route the latest turn to a pinned model and answer it.

    Per-request options are read from ``config["configurable"]``: persona,
    context, language, username, student_profile, user_api_key, search_enabled,
    has_image, doc_bytes and doc_name.

    Returns:
        ``{"messages": [reply]}``, merged into the state by add_messages.

    A free-tier rate limit on the first call is usually a transient per-minute
    cap, so the same model is retried once after a short backoff. Any other
    error, or a second failure, propagates.
    """
    cfg = config.get("configurable", {})
    api_key = cfg.get("user_api_key", "")
    has_image = cfg.get("has_image", False)
    doc_bytes = cfg.get("doc_bytes", "")

    # Surface the raw uploaded document in the LangSmith trace (not just its text).
    if doc_bytes:
        _attach_document(cfg.get("doc_name", ""), doc_bytes)

    messages = _window_messages(state["messages"])
    system_msg = SystemMessage(content=_build_system_prompt(cfg))

    # Static routing → one pinned model per query type.
    category = _classify(_extract_text(messages))
    model_id, actual = _pick(category=category, has_image=has_image)
    # Only the vision model gets image parts; text models receive text only.
    send_messages = messages if actual == "vision" else _strip_images(messages)
    payload = [system_msg] + send_messages
    # Built once (outside the retry) so a missing-key ValueError surfaces directly.
    llm = _make_llm(api_key, model_id)

    def _attempt(number: int):
        logger.debug("router attempt=%d model=%s vision=%s", number, model_id, actual == "vision")
        return {"messages": [llm.invoke(payload, config=config)]}

    try:
        return _attempt(1)
    except Exception as exc:
        if not _is_rate_limit_error(exc):
            raise
        logger.info("429 on %s, retrying once after backoff", model_id)
    time.sleep(_RATE_LIMIT_BACKOFF_S)
    return _attempt(2)
# ── Compile graph ─────────────────────────────────────────────
# Single-node graph: START → chat_node → END.
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
for _src, _dst in ((START, "chat_node"), ("chat_node", END)):
    _g.add_edge(_src, _dst)
del _src, _dst
# Compiled with the SQLite checkpointer so conversation state persists between turns.
chatbot = _g.compile(checkpointer=checkpointer)