# NOTE: "Spaces: Sleeping" is a Hugging Face Spaces status banner captured by
# the page scrape — it is not part of the module source.
# backend/chains/router.py
import os
import io
import shutil
import zipfile
import requests
from typing import Dict
from dotenv import load_dotenv

# -------------------------
# Boot-time configuration
# -------------------------
# Load .env before anything reads os.environ below.
load_dotenv()

# Hugging Face cache dirs -> tmp (HF Spaces cannot write to /root).
# setdefault keeps any values already provided by the deployment environment.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/sentence_transformers")
for _p in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["SENTENCE_TRANSFORMERS_HOME"]]:
    os.makedirs(_p, exist_ok=True)

# -------------------------
# GitHub repo configuration
# -------------------------
# Source repo holding the prebuilt Chroma vector store under REPO_VECTOR_SUBFOLDER.
REPO_URL = "https://github.com/rayymaxx/E-Learning-App"
BRANCH = os.environ.get("GIT_BRANCH", "main")
REPO_VECTOR_SUBFOLDER = os.environ.get("REPO_VECTOR_SUBFOLDER", "direct-ed-ai-assistant/app/vector_store")
VECTOR_STORE_PATH = os.environ.get("VECTOR_STORE_PATH", "/tmp/vector_store")  # extracted Chroma folder
CHROMA_PERSIST_DIR = os.environ.get("CHROMA_PERSIST_DIR", "/tmp/chroma_db")  # working Chroma instance dir
os.makedirs("/tmp", exist_ok=True)
| # ------------------------- | |
| # Helpers | |
| # ------------------------- | |
| def _repo_zip_prefix(repo_url: str, branch: str) -> str: | |
| """GitHub zips extract to <repo-name>-<branch>/...""" | |
| repo_name = repo_url.rstrip("/").split("/")[-1] | |
| return f"{repo_name}-{branch}" | |
def fetch_vector_store_if_needed(repo_url: str, branch: str, subfolder: str, dest_dir: str):
    """Download ``subfolder`` from a GitHub branch ZIP into ``dest_dir``.

    If ``dest_dir`` already exists and is non-empty it is reused as-is
    (idempotent across restarts as long as /tmp survives).

    Args:
        repo_url: GitHub repository URL (no trailing ``.git``).
        branch: Branch whose ZIP archive is fetched.
        subfolder: Path inside the repo to copy out.
        dest_dir: Destination directory for the subfolder's contents.

    Returns:
        ``dest_dir``.

    Raises:
        RuntimeError: if ``subfolder`` is not present in the archive.
        requests.HTTPError: if the ZIP download fails.
    """
    if os.path.isdir(dest_dir) and os.listdir(dest_dir):
        print(f"✅ Vector store already present at {dest_dir}")
        return dest_dir
    zip_url = f"{repo_url}/archive/refs/heads/{branch}.zip"
    print(f"⬇️ Downloading vector store ZIP from {zip_url} ...")
    r = requests.get(zip_url, timeout=60)
    r.raise_for_status()
    extract_root = "/tmp/repo_zip_extract"
    shutil.rmtree(extract_root, ignore_errors=True)
    os.makedirs(extract_root, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(r.content)) as zf:
        zf.extractall(extract_root)
    root_prefix = _repo_zip_prefix(repo_url, branch)
    source_folder = os.path.join(extract_root, root_prefix, subfolder)
    if not os.path.exists(source_folder):
        raise RuntimeError(f"❌ Could not find {subfolder} in {repo_url}@{branch}")
    # BUG FIX: dest_dir may not exist yet on first run. shutil.copy2 (unlike
    # copytree with dirs_exist_ok=True) does not create missing directories,
    # so any top-level *file* in the subfolder would raise FileNotFoundError.
    os.makedirs(dest_dir, exist_ok=True)
    for name in os.listdir(source_folder):
        src, dst = os.path.join(source_folder, name), os.path.join(dest_dir, name)
        if os.path.isdir(src):
            shutil.copytree(src, dst, dirs_exist_ok=True)
        else:
            shutil.copy2(src, dst)
    # Free /tmp space: the extracted archive is no longer needed.
    shutil.rmtree(extract_root, ignore_errors=True)
    print(f"✅ Vector store ready at {dest_dir}")
    return dest_dir
| # ------------------------- | |
| # Vector store initialization | |
| # ------------------------- | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_chroma import Chroma | |
def init_vector_store():
    """Open the prebuilt Chroma store, or fall back to an empty one.

    Tries to fetch the prebuilt vector store from GitHub first; on ANY
    failure (download, extraction, or Chroma open) it logs the error and
    returns a fresh persistent Chroma instance at CHROMA_PERSIST_DIR.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    try:
        store_path = fetch_vector_store_if_needed(REPO_URL, BRANCH, REPO_VECTOR_SUBFOLDER, VECTOR_STORE_PATH)
        print(f"📂 Using vector store from {store_path}")
        return Chroma(persist_directory=store_path, embedding_function=embeddings)
    except Exception as exc:
        print(f"⚠️ Could not load prebuilt Chroma DB: {exc}")
        print("➡️ Falling back to empty Chroma at", CHROMA_PERSIST_DIR)
        os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)
        return Chroma(persist_directory=CHROMA_PERSIST_DIR, embedding_function=embeddings)
# Build the shared vector store once at import time.
vector_store = init_vector_store()
# -------------------------
# LLM wiring
# -------------------------
from backend.llms.custom import CustomChatModel
from langchain_openai import ChatOpenAI

# General-purpose OpenAI model; also used as the fallback target below.
openai_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1)
# Fine-tuned model served from HF Spaces; any failure falls back to OpenAI.
finetuned_llm = CustomChatModel(
    api_url="https://Nutnell-E-Learning-platform.hf.space/generate"
).with_fallbacks([openai_llm])
| # ------------------------- | |
| # Chains | |
| # ------------------------- | |
| from langchain_core.runnables import ( | |
| RunnableBranch, RunnableLambda, RunnableParallel, RunnablePassthrough | |
| ) | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.messages import get_buffer_string | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from langchain_community.chat_message_histories import ChatMessageHistory | |
| from backend.schemas.api_models import ChatInput, ChatOutput | |
| from backend.prompts.templates import ( | |
| rag_prompt, | |
| quiz_generator_prompt, | |
| flashcard_generator_prompt, | |
| explanation_prompt, | |
| definition_prompt, | |
| condense_question_prompt | |
| ) | |
def format_docs(docs):
    """Join retrieved document contents with a ``---`` separator for prompts."""
    parts = [doc.page_content for doc in docs]
    return "\n---\n".join(parts)
def get_sources_from_docs(docs):
    """Extract citation dicts (url + display name) from document metadata."""
    sources = []
    for doc in docs:
        meta = doc.metadata
        sources.append({
            "source": meta.get("source_url", ""),
            "name": meta.get("source_name", ""),
        })
    return sources
# Per-session conversation memory, keyed by session id (process-local only).
memories: Dict[str, ChatMessageHistory] = {}

def get_memory_for_session(session_id: str):
    """Return the chat history for ``session_id``, creating it on first use."""
    history = memories.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        memories[session_id] = history
    return history
# Retriever
def EducationalRetriever():
    """Return a retriever over the shared vector store (top-2 matches)."""
    return vector_store.as_retriever(search_kwargs={"k": 2})
| # ------------------------- | |
| # Conversation chain | |
| # ------------------------- | |
def AdaptiveConversationChain():
    """RAG tutoring chain: condense the follow-up, retrieve, then answer.

    Expects a dict with ``input`` and ``chat_history`` (plus optional
    ``subject`` / ``difficulty_level``); emits ``{"answer", "sources"}``.
    """
    retriever = EducationalRetriever()
    # Rewrite the latest user turn into a standalone question using the chat
    # history; uses the cheap OpenAI model, not the fine-tuned one.
    condense = (
        RunnableLambda(lambda x: {"question": x["input"], "chat_history": get_buffer_string(x["chat_history"])})
        | condense_question_prompt | openai_llm | StrOutputParser()
    )
    return (
        RunnablePassthrough.assign(standalone_question=condense)
        # Retrieval uses the condensed, history-aware question.
        .assign(context=(RunnableLambda(lambda x: x["standalone_question"]) | retriever))
        | RunnableParallel(
            # NOTE: the RAG prompt sees the ORIGINAL user input ("question"),
            # only retrieval uses the condensed form.
            answer=(RunnableLambda(lambda x: {
                "context": format_docs(x["context"]),
                "question": x["input"],
                "subject": x.get("subject", "the topic"),
                "difficulty_level": x.get("difficulty_level", "beginner"),
            }) | rag_prompt | finetuned_llm | StrOutputParser()),
            # Sources are extracted from the same retrieved documents.
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
| # ------------------------- | |
| # Generalized Content Chains | |
| # ------------------------- | |
def QuizChain():
    """Retrieve context for the request's ``input`` and generate a quiz."""
    def _payload(x):
        # Variables expected by quiz_generator_prompt.
        return {
            "context": format_docs(x["context"]),
            "subject": x.get("subject", "topic"),
            "difficulty_level": x.get("difficulty_level", "intermediate"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_payload) | quiz_generator_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
def FlashcardChain():
    """Retrieve context for the request's ``input`` and generate flashcards."""
    def _payload(x):
        # Variables expected by flashcard_generator_prompt.
        return {
            "context": format_docs(x["context"]),
            "difficulty_level": x.get("difficulty_level", "beginner"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_payload) | flashcard_generator_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
def ExplanationChain():
    """Retrieve context for the request's ``input`` and explain the topic."""
    def _payload(x):
        # Variables expected by explanation_prompt ("topic" maps from subject).
        return {
            "context": format_docs(x["context"]),
            "topic": x.get("subject", "topic"),
            "difficulty_level": x.get("difficulty_level", "beginner"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_payload) | explanation_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
def DefinitionChain():
    """Retrieve context for the request's ``input`` and define the term."""
    def _payload(x):
        # Variables expected by definition_prompt ("term" maps from subject).
        return {
            "context": format_docs(x["context"]),
            "term": x.get("subject", "term"),
            "difficulty_level": x.get("difficulty_level", "beginner"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_payload) | definition_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
| # ------------------------- | |
| # Dispatcher for content generation | |
| # ------------------------- | |
def ContentGenerator():
    """Dispatch to a content chain based on the request's ``request_type``."""
    chains_by_type = {
        "quiz_generation": QuizChain(),
        "flashcard_creation": FlashcardChain(),
        "explanation": ExplanationChain(),
        "definition": DefinitionChain(),
    }
    # Bind rtype as a default argument to avoid the late-binding closure trap.
    branches = [
        (lambda x, t=rtype: x.get("request_type") == t, chain)
        for rtype, chain in chains_by_type.items()
    ]
    fallback = RunnableLambda(lambda _: {"answer": "Unknown request.", "sources": []})
    return RunnableBranch(*branches, fallback)
| # ------------------------- | |
| # Analyzer stub | |
| # ------------------------- | |
def LearningAnalyzer():
    """Pass-through logging stage: print the payload's ``input`` and return it.

    Replaces the original ``(print(...), x)[1]`` tuple-indexing lambda — a
    side effect hidden inside an expression — with a named function.
    """
    def _log_and_passthrough(payload):
        # NOTE(review): upstream chains emit {"answer", "sources"}, so
        # payload.get("input") is usually None here — confirm intent.
        print("LOG: LearningAnalyzer", payload.get("input"))
        return payload
    return RunnableLambda(_log_and_passthrough)
| # ------------------------- | |
| # Main assistant chain | |
| # ------------------------- | |
def run_educational_assistant():
    """Route tutoring requests to the conversation chain, all others to content generation."""
    def _is_tutoring(request):
        return request.get("request_type") == "tutoring"
    return RunnableBranch(
        (_is_tutoring, AdaptiveConversationChain()),
        ContentGenerator(),
    )
# -------------------------
# Exposed pipelines
# -------------------------
# Full assistant pipeline: route the request, then pass through the logger.
educational_assistant_chain = run_educational_assistant() | LearningAnalyzer()
# Chat pipeline with per-session message history (see get_memory_for_session).
chat_chain_with_history = RunnableWithMessageHistory(
    educational_assistant_chain, get_memory_for_session,
    input_messages_key="input", history_messages_key="chat_history", output_messages_key="answer"
).with_types(input_type=ChatInput, output_type=ChatOutput)
# Content-only pipeline (quiz/flashcards/explanation/definition), no history.
content_generation_chain = (ContentGenerator() | LearningAnalyzer()).with_types(
    input_type=ChatInput, output_type=ChatOutput
)