# KOkeke94
# Final version: multi-agent stats assistant using RAG + OpenAI router
# d41aa32
import os
import gradio as gr
import torch
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from transformers.pipelines import pipeline
# ✅ Load writer model and wrap it for LangChain
# Text-generation pipeline shared by the whole app. It is called directly for
# general questions and for the final rewrite step, and it also backs the
# retrieval chains through the LangChain wrapper created below.
writer_model = pipeline(
    task="text-generation",
    model="BivinSadler/llama3-finetuned-Statistics",
    return_full_text=False,
)

# LangChain-compatible handle on the same pipeline object.
writer_llm = HuggingFacePipeline(pipeline=writer_model)
# ✅ RAG Agent Builder
# Embedding models that have already been loaded, keyed by model name, so
# that building several agents does not reload the same model each time.
_EMBEDDINGS_BY_MODEL = {}


def _get_embeddings(model_name):
    """Return the shared HuggingFaceEmbeddings instance for *model_name*."""
    if model_name not in _EMBEDDINGS_BY_MODEL:
        _EMBEDDINGS_BY_MODEL[model_name] = HuggingFaceEmbeddings(model_name=model_name)
    return _EMBEDDINGS_BY_MODEL[model_name]


def build_rag_agent(
    pdf_path,
    chunk_size=500,
    chunk_overlap=50,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
):
    """Build a retrieval-augmented QA chain over a single PDF.

    The PDF is loaded, split into overlapping chunks, embedded, indexed in
    an in-memory FAISS store, and wrapped in a "stuff" RetrievalQA chain
    that answers with the module-level ``writer_llm``.

    Args:
        pdf_path: Path of the PDF to index.
        chunk_size: Maximum size of each text chunk handed to the splitter.
        chunk_overlap: Overlap between consecutive chunks.
        embedding_model: Name of the sentence-embedding model used to
            vectorize the chunks.

    Returns:
        A RetrievalQA chain whose retriever searches the indexed PDF.

    Raises:
        ValueError: If splitting the PDF yields no chunks (for example a
            scanned document without a text layer).
    """
    pages = PyPDFLoader(pdf_path).load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = splitter.split_documents(pages)
    if not chunks:
        # Fail with a readable message instead of letting the vector store
        # trip over an empty embedding list.
        raise ValueError(
            f"No extractable text found in {pdf_path!r}; cannot build a vector index."
        )

    vectorstore = FAISS.from_documents(chunks, _get_embeddings(embedding_model))
    return RetrievalQA.from_chain_type(
        llm=writer_llm,
        retriever=vectorstore.as_retriever(),
        chain_type="stuff",
    )
# ✅ Load RAG agents
# One retrieval-QA agent per course syllabus, built once at import time
# in the order listed.
stat6371_agent, ds7333_agent = (
    build_rag_agent(syllabus_pdf)
    for syllabus_pdf in (
        "PDFs/DS 6371 Syllabus Ver 6.pdf",
        "PDFs/ds-7333_syllabus.pdf",
    )
)
# ✅ Load OpenAI LLM for routing
# The API key is taken from the environment; it is None when the variable
# is not set.
openai_key = os.getenv("OPENAI_API_KEY")

# Chat model used only to classify incoming questions; temperature 0 keeps
# the classification as repeatable as the model allows.
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
    api_key=openai_key,
)
# ✅ Routing logic
# Maps the classifier's choice letter to the route name used by the app.
# Anything not listed here (including "C") falls back to "general".
_ROUTE_BY_LETTER = {"A": "stat6371", "B": "ds7333"}


def _extract_choice_letter(reply_text):
    """Return the standalone leading choice letter of a reply, or ""."""
    text = str(reply_text).strip().upper()
    # The prompt ends with "Answer:", and models sometimes echo that label.
    if text.startswith("ANSWER"):
        text = text[len("ANSWER"):]
    text = text.lstrip(" \t\r\n:-*\"'([")
    letter, follower = text[:1], text[1:2]
    # Require the letter to stand alone so that words such as "Based" or
    # "Cannot" are not mistaken for a choice.
    if letter in ("A", "B", "C") and not follower.isalnum():
        return letter
    return ""


def route_question(question, router_llm=None):
    """Classify a question and return the name of the agent to handle it.

    Args:
        question: The user's question, inserted verbatim into the prompt.
        router_llm: Optional chat model exposing ``invoke(prompt)`` and
            returning an object with a ``content`` attribute. Defaults to
            the module-level ``llm``.

    Returns:
        ``"stat6371"`` for choice A, ``"ds7333"`` for choice B, and
        ``"general"`` for choice C or any reply that cannot be parsed.
    """
    routing_prompt = "\n".join(
        (
            "",
            "You are a classification agent that helps route questions to the appropriate expert.",
            "There are three possible categories:",
            "A. Stat 6371 (Theoretical statistics course)",
            "B. DS 7333 (Decision Analytics Course)",
            "C. General statistics (any other statistics question)",
            "Classify the following question into one of those three categories by answering only with a single letter: A, B, or C.",
            f'Question: "{question}"',
            "Answer:",
        )
    )
    classifier = llm if router_llm is None else router_llm
    reply = classifier.invoke(routing_prompt).content
    return _ROUTE_BY_LETTER.get(_extract_choice_letter(reply), "general")
# ✅ Explanation agent
def writer_agent(raw_answer, audience="high school students"):
    """Rewrite an answer as a short, jargon-free explanation.

    Args:
        raw_answer: The answer text to simplify.
        audience: Description of the intended readers, inserted into the
            prompt.

    Returns:
        The generated explanation with surrounding whitespace removed, or a
        fixed message when ``raw_answer`` is empty or blank.
    """
    if raw_answer is None or not str(raw_answer).strip():
        # Asking the model to explain nothing would only produce made-up
        # text, so skip the generation entirely.
        return "No answer was produced, so there is nothing to explain."

    prompt = "\n".join(
        (
            "",
            f"You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for {audience}.",
            "Answer:",
            f"{raw_answer}",
            "Write your response in 2–3 sentences. Avoid technical jargon.",
            "",
        )
    )
    generations = writer_model(prompt, max_new_tokens=200)
    return generations[0]["generated_text"].strip()
# ✅ Main logic
def multiagent_system(question):
    """Answer a statistics question by routing, answering, then simplifying.

    The question is classified with ``route_question``. Syllabus questions
    go to the matching retrieval chain, everything else goes straight to the
    text-generation pipeline. The raw answer is then rewritten by
    ``writer_agent``.

    Args:
        question: The question typed by the user.

    Returns:
        The simplified answer text, or a short prompt to enter a question
        when the input is empty or blank.
    """
    if question is None or not question.strip():
        # Nothing to route: avoid a router API call and a generation.
        return "Please enter a statistics question."

    print(f"🧭 Routing: {question}")
    route = route_question(question)

    if route == "stat6371":
        print("🔎 Stat 6371 RAG")
        # RetrievalQA takes its input under "query" and returns the answer
        # under "result".
        raw_answer = stat6371_agent.invoke({"query": question})["result"]
    elif route == "ds7333":
        print("🔎 DS 7333 RAG")
        raw_answer = ds7333_agent.invoke({"query": question})["result"]
    else:
        print("🧠 General Stats HF Agent")
        generations = writer_model(question, max_new_tokens=200, do_sample=False)
        raw_answer = generations[0]["generated_text"]

    print("✍️ Simplifying...")
    return writer_agent(raw_answer)
# ✅ Gradio UI
# Single-textbox web UI: the submitted question is passed to
# multiagent_system and its return value is shown in the answer box.
iface = gr.Interface(
    fn=multiagent_system,
    inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
    outputs=gr.Textbox(label="Answer"),
    title="📊 Multi-Agent Statistics Assistant",
    description="Routes your stats question to the right syllabus (Stat 6371, DS 7333) or uses a general statistics model (Llama3).",
)

# Start the web server as soon as the module is executed.
iface.launch()