# Snapshot of upstream source (author: rayymaxx) — commit 441b35f, "Fixed fault".
# backend/chains/router.py
import os
import io
import shutil
import zipfile
import requests
from typing import Dict
from dotenv import load_dotenv
# -------------------------
# Boot-time configuration
# -------------------------
# Load .env variables before anything below reads os.environ.
load_dotenv()
# Hugging Face cache dirs -> tmp (HF Spaces cannot write to /root).
# setdefault keeps any values already supplied by the environment.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/sentence_transformers")
# Make sure every cache directory actually exists before model downloads start.
for _p in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["SENTENCE_TRANSFORMERS_HOME"]]:
    os.makedirs(_p, exist_ok=True)
# -------------------------
# GitHub repo configuration
# -------------------------
# Repo that hosts the prebuilt Chroma vector store as a subfolder.
REPO_URL = "https://github.com/rayymaxx/E-Learning-App"
BRANCH = os.environ.get("GIT_BRANCH", "main")
# Path inside the repo ZIP that contains the persisted vector store.
REPO_VECTOR_SUBFOLDER = os.environ.get("REPO_VECTOR_SUBFOLDER", "direct-ed-ai-assistant/app/vector_store")
VECTOR_STORE_PATH = os.environ.get("VECTOR_STORE_PATH", "/tmp/vector_store")  # extracted Chroma folder
CHROMA_PERSIST_DIR = os.environ.get("CHROMA_PERSIST_DIR", "/tmp/chroma_db")  # working Chroma instance dir
os.makedirs("/tmp", exist_ok=True)  # no-op on Linux, kept as cheap insurance
# -------------------------
# Helpers
# -------------------------
def _repo_zip_prefix(repo_url: str, branch: str) -> str:
"""GitHub zips extract to <repo-name>-<branch>/..."""
repo_name = repo_url.rstrip("/").split("/")[-1]
return f"{repo_name}-{branch}"
def fetch_vector_store_if_needed(repo_url: str, branch: str, subfolder: str, dest_dir: str) -> str:
    """Ensure the vector store exists locally, downloading it from GitHub if missing.

    If ``dest_dir`` already exists and is non-empty it is reused as-is (no
    network access). Otherwise the branch ZIP archive is downloaded, the
    ``subfolder`` inside it is located, and its contents are copied into
    ``dest_dir``.

    Args:
        repo_url: Base GitHub repository URL (no trailing ".git").
        branch: Branch whose ZIP archive should be fetched.
        subfolder: Repo-relative path holding the persisted vector store.
        dest_dir: Local directory to populate.

    Returns:
        ``dest_dir``, now containing the vector store files.

    Raises:
        RuntimeError: If ``subfolder`` is not present in the archive.
        requests.HTTPError: If the ZIP download fails.
    """
    if os.path.exists(dest_dir) and os.listdir(dest_dir):
        print(f"✅ Vector store already present at {dest_dir}")
        return dest_dir
    zip_url = f"{repo_url}/archive/refs/heads/{branch}.zip"
    print(f"⬇️ Downloading vector store ZIP from {zip_url} ...")
    r = requests.get(zip_url, timeout=60)
    r.raise_for_status()
    extract_root = "/tmp/repo_zip_extract"
    shutil.rmtree(extract_root, ignore_errors=True)
    os.makedirs(extract_root, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(r.content)) as zf:
        zf.extractall(extract_root)
    root_prefix = _repo_zip_prefix(repo_url, branch)
    source_folder = os.path.join(extract_root, root_prefix, subfolder)
    if not os.path.exists(source_folder):
        raise RuntimeError(f"❌ Could not find {subfolder} in {repo_url}@{branch}")
    # BUG FIX: dest_dir was never created before copying, so shutil.copy2
    # below raised FileNotFoundError for any top-level *file* in the subfolder
    # (copytree creates its own destination, plain files do not).
    os.makedirs(dest_dir, exist_ok=True)
    for name in os.listdir(source_folder):
        src, dst = os.path.join(source_folder, name), os.path.join(dest_dir, name)
        if os.path.isdir(src):
            shutil.copytree(src, dst, dirs_exist_ok=True)
        else:
            shutil.copy2(src, dst)
    # Best-effort cleanup of the temporary extraction area.
    shutil.rmtree(extract_root, ignore_errors=True)
    print(f"✅ Vector store ready at {dest_dir}")
    return dest_dir
# -------------------------
# Vector store initialization
# -------------------------
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
def init_vector_store():
    """Open the Chroma vector store, preferring the prebuilt DB from GitHub.

    Any failure while fetching or opening the prebuilt store falls back to
    an empty persistent Chroma instance at CHROMA_PERSIST_DIR.
    """
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    try:
        store_path = fetch_vector_store_if_needed(
            REPO_URL, BRANCH, REPO_VECTOR_SUBFOLDER, VECTOR_STORE_PATH
        )
        print(f"📂 Using vector store from {store_path}")
        return Chroma(persist_directory=store_path, embedding_function=embedder)
    except Exception as exc:
        print(f"⚠️ Could not load prebuilt Chroma DB: {exc}")
        print("➡️ Falling back to empty Chroma at", CHROMA_PERSIST_DIR)
        os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)
        return Chroma(persist_directory=CHROMA_PERSIST_DIR, embedding_function=embedder)
# Built once at import time; shared by every retriever and chain below.
vector_store = init_vector_store()
# -------------------------
# LLM wiring
# -------------------------
from backend.llms.custom import CustomChatModel
from langchain_openai import ChatOpenAI
openai_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1)
finetuned_llm = CustomChatModel(
api_url="https://Nutnell-E-Learning-platform.hf.space/generate"
).with_fallbacks([openai_llm])
# -------------------------
# Chains
# -------------------------
from langchain_core.runnables import (
RunnableBranch, RunnableLambda, RunnableParallel, RunnablePassthrough
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import get_buffer_string
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from backend.schemas.api_models import ChatInput, ChatOutput
from backend.prompts.templates import (
rag_prompt,
quiz_generator_prompt,
flashcard_generator_prompt,
explanation_prompt,
definition_prompt,
condense_question_prompt
)
def format_docs(docs):
    """Join the text of retrieved documents with a "---" divider line."""
    texts = [doc.page_content for doc in docs]
    return "\n---\n".join(texts)
def get_sources_from_docs(docs):
    """Map retrieved docs to [{"source": url, "name": name}] citation dicts.

    Missing metadata keys default to empty strings.
    """
    citations = []
    for doc in docs:
        meta = doc.metadata
        citations.append({
            "source": meta.get("source_url", ""),
            "name": meta.get("source_name", ""),
        })
    return citations
# Per-session chat transcripts, keyed by session id (process-local, in-memory only).
memories: Dict[str, ChatMessageHistory] = {}

def get_memory_for_session(session_id: str):
    """Return the ChatMessageHistory for a session, creating it on first use."""
    history = memories.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        memories[session_id] = history
    return history
# Retriever
def EducationalRetriever():
    """Return a similarity retriever over the shared vector store (top-2 hits)."""
    search_config = {"k": 2}
    return vector_store.as_retriever(search_kwargs=search_config)
# -------------------------
# Conversation chain
# -------------------------
def AdaptiveConversationChain():
    """Tutoring chain: condense the follow-up question, retrieve, then answer.

    Pipeline:
      1. Rewrite ("input" + "chat_history") into a standalone question using
         the OpenAI model with condense_question_prompt.
      2. Retrieve documents for that standalone question (not the raw input).
      3. In parallel: generate the answer via rag_prompt + the fine-tuned
         model, and collect source citations from the retrieved docs.

    Expects a dict with "input" and "chat_history"; "subject" and
    "difficulty_level" are optional. Returns {"answer": str, "sources": list}.
    """
    retriever = EducationalRetriever()
    # Fold prior turns into a single standalone question string.
    condense = (
        RunnableLambda(lambda x: {"question": x["input"], "chat_history": get_buffer_string(x["chat_history"])})
        | condense_question_prompt | openai_llm | StrOutputParser()
    )
    return (
        RunnablePassthrough.assign(standalone_question=condense)
        # Retrieval is keyed on the condensed question for follow-up accuracy.
        .assign(context=(RunnableLambda(lambda x: x["standalone_question"]) | retriever))
        | RunnableParallel(
            answer=(RunnableLambda(lambda x: {
                "context": format_docs(x["context"]),
                "question": x["input"],
                "subject": x.get("subject", "the topic"),
                "difficulty_level": x.get("difficulty_level", "beginner"),
            }) | rag_prompt | finetuned_llm | StrOutputParser()),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
# -------------------------
# Generalized Content Chains
# -------------------------
def QuizChain():
    """Generate a quiz from retrieved context for the requested subject."""
    def _quiz_vars(x):
        # Template variables for quiz_generator_prompt.
        return {
            "context": format_docs(x["context"]),
            "subject": x.get("subject", "topic"),
            "difficulty_level": x.get("difficulty_level", "intermediate"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_quiz_vars) | quiz_generator_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
def FlashcardChain():
    """Generate flashcards from retrieved context."""
    def _flashcard_vars(x):
        # Template variables for flashcard_generator_prompt.
        return {
            "context": format_docs(x["context"]),
            "difficulty_level": x.get("difficulty_level", "beginner"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_flashcard_vars) | flashcard_generator_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
def ExplanationChain():
    """Generate an explanation of a topic from retrieved context."""
    def _explanation_vars(x):
        # Template variables for explanation_prompt ("subject" feeds "topic").
        return {
            "context": format_docs(x["context"]),
            "topic": x.get("subject", "topic"),
            "difficulty_level": x.get("difficulty_level", "beginner"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_explanation_vars) | explanation_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
def DefinitionChain():
    """Generate a definition of a term from retrieved context."""
    def _definition_vars(x):
        # Template variables for definition_prompt ("subject" feeds "term").
        return {
            "context": format_docs(x["context"]),
            "term": x.get("subject", "term"),
            "difficulty_level": x.get("difficulty_level", "beginner"),
        }
    retrieve = RunnableLambda(lambda x: x["input"]) | EducationalRetriever()
    return (
        RunnablePassthrough.assign(context=retrieve)
        | RunnableParallel(
            answer=RunnableLambda(_definition_vars) | definition_prompt | finetuned_llm | StrOutputParser(),
            sources=RunnableLambda(lambda x: get_sources_from_docs(x["context"])),
        )
    )
# -------------------------
# Dispatcher for content generation
# -------------------------
def ContentGenerator():
    """Route a content request to the matching chain via its request_type."""
    routes = [
        ("quiz_generation", QuizChain()),
        ("flashcard_creation", FlashcardChain()),
        ("explanation", ExplanationChain()),
        ("definition", DefinitionChain()),
    ]
    def _matches(expected):
        # Bind expected as a default arg to dodge late-binding closures.
        return lambda x, expected=expected: x.get("request_type") == expected
    branches = [(_matches(request_type), chain) for request_type, chain in routes]
    default = RunnableLambda(lambda _: {"answer": "Unknown request.", "sources": []})
    return RunnableBranch(*branches, default)
# -------------------------
# Analyzer stub
# -------------------------
def LearningAnalyzer():
    """Pass-through stage that logs the incoming input (analysis stub)."""
    def _log_and_forward(x):
        print("LOG: LearningAnalyzer", x.get("input"))
        return x
    return RunnableLambda(_log_and_forward)
# -------------------------
# Main assistant chain
# -------------------------
def run_educational_assistant():
    """Top-level router: tutoring goes to conversation, all else to content."""
    def _is_tutoring(x):
        return x.get("request_type") == "tutoring"
    return RunnableBranch(
        (_is_tutoring, AdaptiveConversationChain()),
        ContentGenerator(),
    )
# -------------------------
# Exposed pipelines
# -------------------------
# Full assistant pipeline: route the request, then pass it through the analyzer stub.
educational_assistant_chain = run_educational_assistant() | LearningAnalyzer()
# Chat entry point with per-session message history (see get_memory_for_session).
chat_chain_with_history = RunnableWithMessageHistory(
    educational_assistant_chain, get_memory_for_session,
    input_messages_key="input", history_messages_key="chat_history", output_messages_key="answer"
).with_types(input_type=ChatInput, output_type=ChatOutput)
# History-free entry point for one-shot content generation (quiz/flashcards/etc.).
content_generation_chain = (ContentGenerator() | LearningAnalyzer()).with_types(
    input_type=ChatInput, output_type=ChatOutput
)