StemGraph_AI / graph.py
Krishna111111's picture
round-robin removed; static models; fix; ux changes..
bc7a752
Raw
History Blame Contribute Delete
17.8 kB
"""
LangGraph β€” Intelligent Model Router for STEM Copilot.
Routes queries to the best free OpenRouter model based on intent:
- Math / derivation / proofs β†’ top reasoning models (rotated)
- Physics / chemistry concepts β†’ strong general models (rotated)
- Casual greetings ("hi", "ok") β†’ lightweight models
- Image understanding β†’ vision-capable models (verified)
Model pools are fetched from the OpenRouter API, cached for 10 min,
and rotated per-category to avoid per-model rate limits.
On 429 errors, the failed model is skipped and the next model in the
pool is tried automatically (up to 2 retries).
"""
import base64
import json
import re
import sqlite3
import threading
import time
import urllib.request
from typing import Annotated, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter # type:ignore
from langgraph.checkpoint.sqlite import SqliteSaver # type:ignore
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

import prompts
from config import DB_PATH, LLM_TIMEOUT, OPENROUTER_API_KEY
# ── LangSmith document attachment (surfaces the raw uploaded file) ──
# Optional: only active if the installed langsmith SDK supports attachments.
try:
    from langsmith import traceable
    try:
        # Try the package root first, then fall back to the schemas module.
        from langsmith import Attachment # type:ignore
    except ImportError:
        from langsmith.schemas import Attachment # type:ignore
    _LANGSMITH_ATTACH = True
except Exception:
    # Any failure while importing disables attachment support. The names are
    # still bound so later references to them do not raise NameError.
    traceable = None # type:ignore
    Attachment = None # type:ignore
    _LANGSMITH_ATTACH = False
_MIME_BY_EXT = {
"pdf": "application/pdf",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"doc": "application/msword",
"txt": "text/plain",
"md": "text/markdown",
"csv": "text/csv",
"json": "application/json",
}
def _guess_mime(filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
return _MIME_BY_EXT.get(ext, "application/octet-stream")
if _LANGSMITH_ATTACH:
    # Defined only when the langsmith import succeeded; callers must check
    # _LANGSMITH_ATTACH before calling it.
    @traceable(name="uploaded_document")
    def _trace_document(filename: str, document) -> str: # noqa: ANN001
        """Nested run that carries the raw user-uploaded file as a LangSmith attachment."""
        # `document` is never read here; presumably @traceable records it as a
        # run input — confirm against the langsmith documentation.
        return f"Attached document: {filename}"
def _attach_document(doc_name: str, doc_bytes_b64: str) -> None:
    """Attach the uploaded file to the current LangSmith trace, best-effort.

    Does nothing when attachment support is unavailable or when no payload is
    given. Any error raised while decoding or tracing is printed and
    suppressed, so a tracing problem never reaches the chat caller.

    Args:
        doc_name: Original filename; "uploaded_document" is used when empty.
        doc_bytes_b64: Base64-encoded file contents.
    """
    if not _LANGSMITH_ATTACH or not doc_bytes_b64:
        return
    try:
        payload = base64.b64decode(doc_bytes_b64)
        label = doc_name if doc_name else "uploaded_document"
        attachment = Attachment(mime_type=_guess_mime(label), data=payload)
        _trace_document(label, attachment)
    except Exception as exc:
        # Deliberately broad: tracing is optional and must not break the response.
        print(f"[LANGSMITH ATTACH] skipped: {exc}", flush=True)
# ── Checkpointer ──────────────────────────────────────────────
# One module-level SQLite connection backs the LangGraph checkpointer.
# check_same_thread=False lifts sqlite3's same-thread restriction on it.
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=_conn)
# ── Constants ─────────────────────────────────────────────────
DEFAULT_MODEL = "openai/gpt-oss-120b:free" # pinned default for text
DEFAULT_VISION_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free" # pinned default for vision
MODELS_TTL = 10 * 60 # seconds; how long _fetch_pools reuses its cached pools
def _is_guardrail(model_id: str) -> bool:
"""Auto-filter safety classifiers / guardrail models."""
base = model_id.lower().split(":")[0]
return any(kw in base for kw in ("guard", "shield", "safety", "moderation"))
# ── Query Classification ─────────────────────────────────────
_CASUAL_EXACT = frozenset([
"hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola",
"thanks", "thank you", "thankyou", "thx", "ty",
"ok", "okay", "k", "kk", "fine", "alright",
"bye", "goodbye", "see you", "later",
"good morning", "good night", "good evening", "gm", "gn",
"welcome", "cool", "nice", "great", "awesome", "wow",
"got it", "understood", "sure", "yes", "no", "yeah", "nah", "yep", "nope",
"hm", "hmm", "oh", "ah", "lol", "haha", "hehe", "xd",
"what", "nothing", "nvm", "nevermind", "nm",
])
_REASONING_KW = [
"derive", "derivation", "prove", "proof", "solve", "equation",
"integral", "integrate", "differentiate", "differentiation",
"formula", "calculate", "calculation", "compute",
"theorem", "lemma", "corollary",
"limit", "matrix", "determinant", "vector", "eigen",
"trigonometry", "trigonometric",
"quadratic", "polynomial", "logarithm",
"calculus", "algebra", "geometry", "coordinate",
"probability", "permutation", "combination",
"step by step", "show steps", "show working", "work out",
"simplify", "factorize", "factorise", "expand",
"numerator", "denominator", "fraction",
"dx", "dy", "dz", " lim ",
]
_SCIENCE_KW = [
"physics", "chemistry", "biology",
"molecule", "atom", "electron", "proton", "neutron",
"ion", "isotope", "nucleus",
"force", "energy", "momentum", "velocity", "acceleration",
"wave", "wavelength", "frequency", "amplitude",
"light", "optics", "lens", "mirror", "refraction", "reflection",
"electric", "magnetic", "electromagnetic", "circuit", "resistance",
"gravity", "gravitational", "newton", "coulomb", "charge",
"reaction", "compound", "element", "bond", "orbital",
"hybridisation", "hybridization", "valence",
"thermodynamics", "entropy", "enthalpy",
"kinematics", "dynamics",
"quantum", "relativity", "nuclear",
"acid", "base", "salt", "buffer",
"oxidation", "reduction", "redox",
"mole", "molarity", "avogadro",
"ncert", "class 11", "class 12", "class xi", "class xii",
]
def _classify(text: str) -> str:
"""Return one of: casual, reasoning, science, general."""
t = text.strip().lower()
if t in _CASUAL_EXACT or len(t) < 6:
return "casual"
if any(kw in t for kw in _REASONING_KW):
return "reasoning"
if any(kw in t for kw in _SCIENCE_KW):
return "science"
return "general"
# ── Model Scoring ─────────────────────────────────────────────
def _score(model_id: str) -> int:
"""Higher score β‰ˆ better for hard reasoning tasks."""
m = model_id.lower()
s = 50
for tag, pts in [
("253b", 100), ("250b", 100), ("200b", 95),
("120b", 90), ("110b", 88), ("100b", 85),
("70b", 70), ("72b", 70), ("65b", 68),
("49b", 55), ("46b", 55), ("40b", 55),
("27b", 40), ("32b", 42), ("34b", 42),
("13b", 20), ("14b", 20), ("8b", 15),
("3b", -20), ("1b", -40), ("0.5b", -50),
]:
if tag in m:
s += pts
break
if "nemotron" in m: s += 15
if "gpt" in m: s += 10
if "llama" in m: s += 5
if "qwen" in m: s += 5
if "deepseek" in m: s += 8
return s
# ── Model Pool Fetching & Caching ─────────────────────────────
# Last successfully built pools (category -> list of model entries); None until
# the first successful fetch.
_cache: dict | None = None
# time.time() value of the last successful fetch; compared against MODELS_TTL.
_cache_at = 0.0
# Held by _fetch_pools while checking cache freshness and while storing new pools.
_lock = threading.Lock()
# Per-category rotation counters read and advanced by _pick.
_counters: dict[str, int] = {}
def _ensure_pinned(pool: list, pinned_id: str, vision: bool = False) -> list:
"""Guarantee the pinned default model sits at the front of a pool."""
rest = [e for e in pool if e["id"] != pinned_id]
return [{"id": pinned_id, "score": 999, "vision": vision}] + rest
def _fetch_pools() -> dict:
    """Fetch free models, detect vision support, build sorted pools.

    Returns a dict mapping each category ("reasoning", "science", "general",
    "casual", "vision") to a list of ``{"id", "score", "vision"}`` entries.

    Cached pools younger than ``MODELS_TTL`` seconds are returned without any
    network access. Otherwise the OpenRouter model list is downloaded, reduced
    to ``:free`` non-guardrail models, scored and split into pools. If that
    fails for any reason, the previously cached pools are returned when they
    exist; otherwise single-entry pools holding the pinned defaults are used.
    This function does not raise.
    """
    global _cache, _cache_at
    with _lock:
        if _cache and (time.time() - _cache_at) < MODELS_TTL:
            return _cache
    try:
        req = urllib.request.Request(
            "https://openrouter.ai/api/v1/models",
            headers={
                "HTTP-Referer": "https://stemcopilot.app",
                "User-Agent": "STEMCopilot/1.0",
            },
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            raw = json.loads(resp.read().decode())
        text_all, vision_all = [], []
        # `or` fallbacks: a JSON null in any of these fields must not abort
        # the whole refresh with an AttributeError / TypeError.
        for m in raw.get("data") or []:
            mid = m.get("id") or ""
            if not mid.endswith(":free"):
                continue
            if _is_guardrail(mid):
                continue
            arch = m.get("architecture") or {}
            modality = arch.get("modality") or ""
            input_mods = arch.get("input_modalities") or []
            has_vision = ("image" in modality) or ("image" in input_mods)
            entry = {"id": mid, "score": _score(mid), "vision": has_vision}
            # Every free model is a text candidate; vision models are also
            # listed separately.
            text_all.append(entry)
            if has_vision:
                vision_all.append(entry)
        text_all.sort(key=lambda x: x["score"], reverse=True)
        vision_all.sort(key=lambda x: x["score"], reverse=True)
        pools = {
            "reasoning": _ensure_pinned(text_all[:6], DEFAULT_MODEL),
            "science": _ensure_pinned(text_all[:8], DEFAULT_MODEL),
            "general": _ensure_pinned(text_all[:10], DEFAULT_MODEL),
            # Casual traffic gets the lowest-scored models, lowest first; this
            # pool is not pinned to DEFAULT_MODEL.
            "casual": text_all[-5:][::-1] if len(text_all) >= 5 else text_all[:3],
            "vision": _ensure_pinned(vision_all, DEFAULT_VISION_MODEL, vision=True),
        }
        with _lock:
            _cache = pools
            _cache_at = time.time()
        print(f"[ROUTER] {len(text_all)} free models, {len(vision_all)} vision-capable")
        for cat in ("reasoning", "casual", "vision"):
            ids = [e["id"] for e in pools[cat][:3]]
            if ids:
                print(f"[ROUTER] {cat}: {ids}")
        return pools
    except Exception as exc:
        # Deliberately broad: routing must keep working when the catalogue
        # cannot be fetched or parsed.
        print(f"[ROUTER] Fetch failed: {exc}")
        with _lock:
            stale = _cache
        if stale:
            # Expired pools are still better than the bare defaults.
            return stale
        fallback_text = {"id": DEFAULT_MODEL, "score": 999, "vision": False}
        fallback_vision = {"id": DEFAULT_VISION_MODEL, "score": 999, "vision": True}
        return {
            "reasoning": [fallback_text],
            "science": [fallback_text],
            "general": [fallback_text],
            "casual": [fallback_text],
            "vision": [fallback_vision],
        }
# ── Model Picker (round-robin) ────────────────────────────────
def _pick(category: str, has_image: bool = False, skip_models: set | None = None) -> tuple[str, str]:
    """Choose a model for a query and report which pool it came from.

    Args:
        category: Text category of the query, as returned by ``_classify``.
        has_image: Whether the query carries an image.
        skip_models: Model IDs to exclude (e.g. models that just returned 429).

    Returns:
        ``(model_id, actual_category)``. ``actual_category`` is "vision" only
        when a vision model was chosen; if ``has_image`` is set but every
        vision model is excluded, selection falls back to the text category.
    """
    pools = _fetch_pools()
    excluded = skip_models or set()

    if has_image:
        # First non-excluded vision entry wins; the pinned vision default is
        # first in that pool, so it is only passed over once it is excluded.
        for entry in pools.get("vision", []):
            if entry["id"] not in excluded:
                return entry["id"], "vision"

    source = pools.get(category) or pools.get("general") or []
    candidates = [entry for entry in source if entry["id"] not in excluded]
    if not candidates:
        return DEFAULT_MODEL, category

    if category not in ("reasoning", "science"):
        # Deterministic choice: always the first remaining entry of the pool.
        return candidates[0]["id"], category

    # Reasoning / science rotate through the remaining candidates to spread
    # load across models.
    # NOTE(review): _counters is updated without holding _lock — confirm this
    # is acceptable under concurrent requests.
    position = _counters.get(category, 0) % len(candidates)
    _counters[category] = position + 1
    return candidates[position]["id"], category
# ── Message helpers ────────────────────────────────────────────
def _extract_text(messages: list) -> str:
"""Get raw text from the last user message (handles multimodal)."""
if not messages:
return ""
content = messages[-1].content if hasattr(messages[-1], "content") else ""
if isinstance(content, list):
return " ".join(
p.get("text", "") for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
return str(content)
def _strip_images(messages: list) -> list:
"""Remove all image_url content from messages. Preserves message types."""
out = []
for msg in messages:
if isinstance(msg.content, list):
text_parts = [
p.get("text", "")
for p in msg.content
if isinstance(p, dict) and p.get("type") == "text"
]
text = " ".join(t for t in text_parts if t).strip()
if not text:
text = "(user sent an image)"
out.append(msg.__class__(content=text))
else:
out.append(msg)
return out
# ── LLM factory ───────────────────────────────────────────────
def _make_llm(api_key: str, model_id: str):
    """Build a streaming ChatOpenRouter client for ``model_id``.

    Args:
        api_key: The user's own OpenRouter key; when empty, the configured
            ``OPENROUTER_API_KEY`` is used instead.
        model_id: OpenRouter model identifier.

    Raises:
        ValueError: If neither key is available.
    """
    resolved_key = api_key if api_key else OPENROUTER_API_KEY
    if not resolved_key:
        raise ValueError("No API key. Add your OpenRouter key in Settings.")
    # NOTE(review): LLM_TIMEOUT is imported from config but is not passed here —
    # confirm whether a request timeout was intended.
    settings = {
        "model": model_id,
        "openrouter_api_key": resolved_key,
        "temperature": 0.5,
        "max_tokens": 4096,
        # Client-side retries are off; chat_node retries by switching models.
        "max_retries": 0,
        "streaming": True,
    }
    return ChatOpenRouter(**settings)
# ── LangGraph state & node ────────────────────────────────────
class ChatState(TypedDict):
    """Graph state: the conversation's message list."""
    # add_messages is attached as the reducer for this key; presumably it merges
    # messages returned by a node into the existing list — see LangGraph docs.
    messages: Annotated[list[BaseMessage], add_messages]
# Lower-case fragments that identify a rate-limit failure in an error message.
_RATE_LIMIT_MARKERS = (
    "429",
    "toomanyrequests",
    "too many requests",
    "rate limit",
    "rate-limit",
    "rate_limit",
    "ratelimit",
)


def _is_rate_limit_error(exc: Exception) -> bool:
    """Tell whether ``exc`` looks like an HTTP 429 / rate-limit failure.

    Checks a ``status_code`` attribute when the exception has one, then looks
    for rate-limit markers in the message. A bare "rate" substring is not
    used, because it also matches unrelated words such as "generate".
    """
    if getattr(exc, "status_code", None) == 429:
        return True
    message = str(exc).lower()
    return any(marker in message for marker in _RATE_LIMIT_MARKERS)


def chat_node(state: ChatState, config: RunnableConfig):
    """Answer the latest message with a routed OpenRouter model.

    Reads per-request options from ``config["configurable"]``, builds the
    system prompt, picks a model for the query category and invokes it. When a
    model fails with a rate-limit error it is skipped and another model is
    tried, for at most 3 attempts in total.

    Returns:
        ``{"messages": [response]}`` with the model's reply.

    Raises:
        Exception: The first non-rate-limit error is re-raised immediately; if
            all attempts were rate-limited, the last such error is raised.
    """
    cfg = config.get("configurable", {})
    persona = cfg.get("persona", "nerd")
    context = cfg.get("context", "")
    language = cfg.get("language", "auto")
    username = cfg.get("username", "")
    profile = cfg.get("student_profile", "")
    api_key = cfg.get("user_api_key", "")
    override = cfg.get("model", "")
    search_enabled = cfg.get("search_enabled", False)
    has_image = cfg.get("has_image", False)
    doc_bytes = cfg.get("doc_bytes", "")
    doc_name = cfg.get("doc_name", "")
    # Surface the raw uploaded document in the LangSmith trace (not just its text).
    if doc_bytes:
        _attach_document(doc_name, doc_bytes)
    messages = state["messages"]
    user_text = _extract_text(messages)
    category = _classify(user_text)
    base_prompt = prompts.build(persona, context, language, username, profile)
    if search_enabled:
        # Prepended so it is read before the subject restrictions in the base prompt.
        scope_override = (
            "## SEARCH MODE ACTIVE — READ THIS FIRST, BEFORE ANY OTHER INSTRUCTION BELOW:\n"
            "The student has turned on Web Search or YouTube mode. In this mode, you are not "
            "restricted to Physics, Chemistry, and Mathematics. The student may ask about anything — "
            "general knowledge, current events, how-to questions, other subjects, casual curiosity — "
            "and you must answer it directly and helpfully using the web search results or YouTube "
            "transcript provided below. Do NOT say 'I'm here to help you with Physics, Chemistry, and "
            "Maths' or any variation of that redirect — it is forbidden while this mode is active.\n\n"
            "PRIORITY RULE: If the student's question is itself a Class XI/XII Physics, Chemistry, or "
            "Maths doubt, and the NCERT context below contains relevant material, treat that NCERT "
            "context as the primary source and the web/YouTube results as a secondary supplement. "
            "For every other kind of question, rely on the web/YouTube results as the primary source.\n\n"
        )
        base_prompt = scope_override + base_prompt
    if has_image:
        base_prompt += (
            "\n\nIMAGE HANDLING:\n"
            "The student attached an image. Analyze it and respond educationally. "
            "Do NOT output safety classifications or moderation labels.\n"
        )
    system_msg = SystemMessage(content=base_prompt)
    # Try up to 3 models on 429 errors
    skip_models: set[str] = set()
    last_error = None
    for attempt in range(3):
        if override and attempt == 0:
            # NOTE(review): an explicit override is treated as a text model, so
            # images are stripped even if the chosen model is vision-capable —
            # confirm this is intended.
            model_id, actual = override, category
        else:
            model_id, actual = _pick(category, has_image=has_image, skip_models=skip_models)
        # Only a model picked from the vision pool receives image parts.
        send_messages = messages if actual == "vision" else _strip_images(messages)
        print(f"[ROUTER] attempt={attempt+1} category={category} model={model_id} vision={actual == 'vision'}")
        try:
            llm = _make_llm(api_key, model_id)
            resp = llm.invoke([system_msg] + send_messages, config=config)
            return {"messages": [resp]}
        except Exception as e:
            last_error = e
            if _is_rate_limit_error(e):
                print(f"[ROUTER] 429 on {model_id}, rotating...")
                skip_models.add(model_id)
                continue
            raise
    # Reached only when every attempt was rate-limited.
    raise last_error #type:ignore
# ── Compile graph ─────────────────────────────────────────────
# Single-node graph: START -> chat_node -> END.
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
_g.add_edge(START, "chat_node")
_g.add_edge("chat_node", END)
# Compiled with the module-level SQLite checkpointer; `chatbot` is the runnable graph.
chatbot = _g.compile(checkpointer=checkpointer)