# backend/chains/router.py
import os
import io
import shutil
import zipfile
import requests
from typing import Dict
from dotenv import load_dotenv

# -------------------------
# Boot-time configuration
# -------------------------
load_dotenv()

# Hugging Face cache dirs -> tmp (HF Spaces cannot write to /root)
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/sentence_transformers")
for _p in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["SENTENCE_TRANSFORMERS_HOME"]]:
    os.makedirs(_p, exist_ok=True)

# -------------------------
# GitHub repo configuration
# -------------------------
REPO_URL = "https://github.com/rayymaxx/E-Learning-App"
BRANCH = os.environ.get("GIT_BRANCH", "main")
REPO_VECTOR_SUBFOLDER = os.environ.get("REPO_VECTOR_SUBFOLDER", "direct-ed-ai-assistant/app/vector_store")
VECTOR_STORE_PATH = os.environ.get("VECTOR_STORE_PATH", "/tmp/vector_store")  # extracted Chroma folder
CHROMA_PERSIST_DIR = os.environ.get("CHROMA_PERSIST_DIR", "/tmp/chroma_db")   # working Chroma instance dir
os.makedirs("/tmp", exist_ok=True)

# -------------------------
# Helpers
# -------------------------
def _repo_zip_prefix(repo_url: str, branch: str) -> str:
    """GitHub ZIPs extract to <repo-name>-<branch>/..."""
    repo_name = repo_url.rstrip("/").split("/")[-1]
    return f"{repo_name}-{branch}"


def fetch_vector_store_if_needed(repo_url: str, branch: str, subfolder: str, dest_dir: str):
    """Download a subfolder from the GitHub repo ZIP into dest_dir (if missing)."""
    if os.path.exists(dest_dir) and os.listdir(dest_dir):
        print(f"✅ Vector store already present at {dest_dir}")
        return dest_dir

    zip_url = f"{repo_url}/archive/refs/heads/{branch}.zip"
    print(f"⬇️ Downloading vector store ZIP from {zip_url} ...")
    r = requests.get(zip_url, timeout=60)
    r.raise_for_status()

    extract_root = "/tmp/repo_zip_extract"
    shutil.rmtree(extract_root, ignore_errors=True)
    os.makedirs(extract_root, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(r.content)) as zf:
        zf.extractall(extract_root)

    root_prefix = _repo_zip_prefix(repo_url, branch)
    source_folder = os.path.join(extract_root, root_prefix, subfolder)
    if not os.path.exists(source_folder):
        raise RuntimeError(f"❌ Could not find {subfolder} in {repo_url}@{branch}")

    # Ensure the destination exists before copying (shutil.copy2 fails otherwise).
    os.makedirs(dest_dir, exist_ok=True)
    for name in os.listdir(source_folder):
        src, dst = os.path.join(source_folder, name), os.path.join(dest_dir, name)
        if os.path.isdir(src):
            shutil.copytree(src, dst, dirs_exist_ok=True)
        else:
            shutil.copy2(src, dst)

    print(f"✅ Vector store ready at {dest_dir}")
    return dest_dir

# -------------------------
# Vector store initialization
# -------------------------
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma


def init_vector_store():
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    try:
        path = fetch_vector_store_if_needed(REPO_URL, BRANCH, REPO_VECTOR_SUBFOLDER, VECTOR_STORE_PATH)
        print(f"📂 Using vector store from {path}")
        return Chroma(persist_directory=path, embedding_function=embeddings)
    except Exception as e:
        print(f"⚠️ Could not load prebuilt Chroma DB: {e}")
        print("➡️ Falling back to empty Chroma at", CHROMA_PERSIST_DIR)
        os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)
        return Chroma(persist_directory=CHROMA_PERSIST_DIR, embedding_function=embeddings)


vector_store = init_vector_store()
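
# --- Optional sanity check (illustrative sketch, not wired into the app) ---
# Quick way to confirm the loaded Chroma store actually returns documents.
# `similarity_search` is the standard langchain_chroma API; the query string
# and k below are arbitrary example values.
def _debug_vector_store(query: str = "photosynthesis", k: int = 1) -> None:
    """Print how many documents the vector store returns for a sample query."""
    hits = vector_store.similarity_search(query, k=k)
    print(f"🔎 '{query}' -> {len(hits)} doc(s) from the vector store")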

# -------------------------
# LLM wiring
# -------------------------
from backend.llms.custom import CustomChatModel
from langchain_openai import ChatOpenAI

openai_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1)
finetuned_llm = CustomChatModel(
    api_url="https://Nutnell-E-Learning-platform.hf.space/generate"
).with_fallbacks([openai_llm])

# -------------------------
# Chains
# -------------------------
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import get_buffer_string
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory

from backend.schemas.api_models import ChatInput, ChatOutput
from backend.prompts.templates import (
    rag_prompt,
    quiz_generator_prompt,
    flashcard_generator_prompt,
    explanation_prompt,
    definition_prompt,
    condense_question_prompt,
)


def format_docs(docs):
    return "\n---\n".join(doc.page_content for doc in docs)


def get_sources_from_docs(docs):
    return [
        {"source": doc.metadata.get("source_url", ""), "name": doc.metadata.get("source_name", "")}
        for doc in docs
    ]


# Memory
memories: Dict[str, ChatMessageHistory] = {}


def get_memory_for_session(session_id: str):
    if session_id not in memories:
        memories[session_id] = ChatMessageHistory()
    return memories[session_id]


# Retriever
def EducationalRetriever():
    return vector_store.as_retriever(search_kwargs={"k": 2})


# -------------------------
# Conversation chain
# -------------------------
def AdaptiveConversationChain():
    retriever = EducationalRetriever()
    condense = (
        RunnableLambda(
            lambda x: {
                "question": x["input"],
                "chat_history": get_buffer_string(x["chat_history"]),
            }
        )
        | condense_question_prompt
        | openai_llm
        | StrOutputParser()
    )
    return (
        RunnablePassthrough.assign(standalone_question=condense)
        .assign(context=(RunnableLambda(lambda x: x["standalone_question"]) | retriever))
        | RunnableParallel(
            answer=(
                RunnableLambda(
                    lambda x: {
                        "context": format_docs(x["context"]),
                        "question": x["input"],
                        "subject": x.get("subject", "the topic"),
                        "difficulty_level": x.get("difficulty_level", "beginner"),
                    }
                )
                | rag_prompt
                | finetuned_llm
                | StrOutputParser()
            ),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )


# -------------------------
# Generalized Content Chains
# -------------------------
def QuizChain():
    return (
        RunnablePassthrough.assign(context=(RunnableLambda(lambda x: x["input"]) | EducationalRetriever()))
        | RunnableParallel(
            answer=(
                RunnableLambda(
                    lambda x: {
                        "context": format_docs(x["context"]),
                        "subject": x.get("subject", "topic"),
                        "difficulty_level": x.get("difficulty_level", "intermediate"),
                    }
                )
                | quiz_generator_prompt
                | finetuned_llm
                | StrOutputParser()
            ),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )


def FlashcardChain():
    return (
        RunnablePassthrough.assign(context=(RunnableLambda(lambda x: x["input"]) | EducationalRetriever()))
        | RunnableParallel(
            answer=(
                RunnableLambda(
                    lambda x: {
                        "context": format_docs(x["context"]),
                        "difficulty_level": x.get("difficulty_level", "beginner"),
                    }
                )
                | flashcard_generator_prompt
                | finetuned_llm
                | StrOutputParser()
            ),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
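
# --- Illustrative sketch: invoking a content chain (example values) --------
# The payload keys mirror what the lambdas above read ("input", "subject",
# "difficulty_level"); the topic is arbitrary, and a real call needs the
# fine-tuned endpoint (or the OpenAI fallback) to be reachable.
def _demo_quiz_invocation():
    """Run QuizChain on a sample payload; returns {"answer": ..., "sources": [...]}."""
    return QuizChain().invoke(
        {
            "input": "Create a short quiz about photosynthesis",
            "subject": "biology",
            "difficulty_level": "beginner",
        }
    )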

def ExplanationChain():
    return (
        RunnablePassthrough.assign(context=(RunnableLambda(lambda x: x["input"]) | EducationalRetriever()))
        | RunnableParallel(
            answer=(
                RunnableLambda(
                    lambda x: {
                        "context": format_docs(x["context"]),
                        "topic": x.get("subject", "topic"),
                        "difficulty_level": x.get("difficulty_level", "beginner"),
                    }
                )
                | explanation_prompt
                | finetuned_llm
                | StrOutputParser()
            ),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )


def DefinitionChain():
    return (
        RunnablePassthrough.assign(context=(RunnableLambda(lambda x: x["input"]) | EducationalRetriever()))
        | RunnableParallel(
            answer=(
                RunnableLambda(
                    lambda x: {
                        "context": format_docs(x["context"]),
                        "term": x.get("subject", "term"),
                        "difficulty_level": x.get("difficulty_level", "beginner"),
                    }
                )
                | definition_prompt
                | finetuned_llm
                | StrOutputParser()
            ),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )


# -------------------------
# Dispatcher for content generation
# -------------------------
def ContentGenerator():
    return RunnableBranch(
        (lambda x: x.get("request_type") == "quiz_generation", QuizChain()),
        (lambda x: x.get("request_type") == "flashcard_creation", FlashcardChain()),
        (lambda x: x.get("request_type") == "explanation", ExplanationChain()),
        (lambda x: x.get("request_type") == "definition", DefinitionChain()),
        RunnableLambda(lambda _: {"answer": "Unknown request.", "sources": []}),
    )


# -------------------------
# Analyzer stub
# -------------------------
def LearningAnalyzer():
    return RunnableLambda(lambda x: (print("LOG: LearningAnalyzer", x.get("input")), x)[1])


# -------------------------
# Main assistant chain
# -------------------------
def run_educational_assistant():
    return RunnableBranch(
        (lambda x: x.get("request_type") == "tutoring", AdaptiveConversationChain()),
        ContentGenerator(),
    )


# -------------------------
# Exposed pipelines
# -------------------------
educational_assistant_chain = run_educational_assistant() | LearningAnalyzer()

chat_chain_with_history = RunnableWithMessageHistory(
    educational_assistant_chain,
    get_memory_for_session,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
).with_types(input_type=ChatInput, output_type=ChatOutput)

content_generation_chain = (ContentGenerator() | LearningAnalyzer()).with_types(
    input_type=ChatInput, output_type=ChatOutput
)
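
# --- Local smoke test (illustrative sketch, not part of the served app) ----
# The payload fields follow the lambdas above; `session_id` in the config is
# the standard RunnableWithMessageHistory convention. A real run needs the
# fine-tuned endpoint or an OPENAI_API_KEY for the fallback model.
if __name__ == "__main__":
    result = chat_chain_with_history.invoke(
        {
            "input": "Explain photosynthesis in simple terms",
            "request_type": "tutoring",
            "subject": "biology",
            "difficulty_level": "beginner",
        },
        config={"configurable": {"session_id": "demo-session"}},
    )
    print(result)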