KOkeke94
Fix: Update deprecated imports, add tiktoken, migrate to langchain_community
3d87318
raw
history blame
3.52 kB
import os
import gradio as gr
import torch
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from transformers.pipelines import pipeline
# Shared chat model used for routing, RAG answering and answer rewriting.
# The OpenAI key comes from the environment (configured as a Hugging Face
# Space secret); it is None when the variable is unset.
openai_key = os.getenv("OPENAI_API_KEY")
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
    api_key=openai_key,
)
# Build a retrieval-augmented QA agent over one PDF.
def build_rag_agent(pdf_path, *, chunk_size=500, chunk_overlap=50):
    """Index a PDF and return a RetrievalQA chain that answers from it.

    The PDF is loaded, split into overlapping character chunks, embedded with
    OpenAI embeddings and indexed in a FAISS vector store. Questions are
    answered by the module-level ``llm`` with the "stuff" chain type, meaning
    the retrieved chunks are placed together into a single prompt.

    Args:
        pdf_path: Path of the PDF file to index.
        chunk_size: Maximum number of characters per chunk (default 500).
        chunk_overlap: Characters shared by consecutive chunks (default 50).

    Returns:
        A ``RetrievalQA`` chain. Its input key is ``"query"`` and its output
        key is ``"result"``.
    """
    # The langchain_community OpenAIEmbeddings class is deprecated. The
    # maintained one lives in langchain_openai, which this app already uses
    # for ChatOpenAI.
    from langchain_openai import OpenAIEmbeddings

    pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = splitter.split_documents(pages)
    vectorstore = FAISS.from_documents(chunks, OpenAIEmbeddings(api_key=openai_key))
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        chain_type="stuff",
    )
# One RAG agent per course syllabus. PDF paths are resolved relative to this
# file rather than the current working directory, so the app also starts
# when it is launched from a different directory.
_PDF_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "PDFs")
stat6371_agent = build_rag_agent(os.path.join(_PDF_DIR, "DS 6371 Syllabus Ver 6.pdf"))
ds7333_agent = build_rag_agent(os.path.join(_PDF_DIR, "ds-7333_syllabus.pdf"))
# Fine-tuned Llama 3 model for general statistics questions.
# Llama 3 is a decoder-only (causal) model, so it has to be served by the
# "text-generation" pipeline. "text2text-generation" only loads
# sequence-to-sequence model classes and cannot load a Llama checkpoint.
# NOTE(review): this assumes the repo holds a causal-LM checkpoint, as its
# name suggests. Confirm against the model card.
# return_full_text=False keeps the echoed question out of "generated_text",
# so only the model's continuation is returned.
general_stat_agent = pipeline(
    "text-generation",
    model="BivinSadler/llama3-finetuned-Statistics",
    return_full_text=False,
)
# βœ… Routing logic
def route_question_llm(question):
prompt = f"""
You are a classification agent that helps route questions to the appropriate expert.
There are three possible categories:
A. Stat 6371 (Theoretical statistics course)
B. DS 7333 (Decision Analytics Course)
C. General statistics (any other statistics question)
Classify the following question into one of those three categories by answering only with a single letter: A, B, or C.
Question: "{question}"
Answer:"""
response = llm.invoke(prompt).content.strip().upper()
if response.startswith("A"):
return "stat6371"
elif response.startswith("B"):
return "ds7333"
else:
return "general"
# Rewrites a raw (possibly technical) answer for a lay audience.
def writer_agent(raw_answer, audience="high school students"):
    """Turn *raw_answer* into a short, jargon-free explanation for *audience*.

    Sends one prompt to the module-level ``llm`` and returns the text content
    of its reply.
    """
    request_lines = [
        "",
        "You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for "
        f"{audience}.",
        "Answer:",
        f"{raw_answer}",
        "Write your response in 2–3 sentences. Avoid technical jargon.",
        "",
    ]
    reply = llm.invoke("\n".join(request_lines))
    return reply.content
# Main app logic: route, answer, then simplify.
def multiagent_system(question):
    """Answer a statistics question end to end.

    The question is routed with ``route_question_llm``. Course questions go to
    the matching syllabus RAG agent; anything else goes to the general
    Hugging Face model. The raw answer is then rewritten by ``writer_agent``.

    Args:
        question: The user's question from the Gradio text box.

    Returns:
        The simplified answer text. For an empty or whitespace-only question,
        a short prompt asking the user to enter one.
    """
    # Skip the paid routing and generation calls when there is nothing to answer.
    if not question or not question.strip():
        return "Please enter a statistics question."

    print(f"🧭 Routing: {question}")
    route = route_question_llm(question)
    if route == "stat6371":
        print("🔎 Stat 6371 RAG")
        # Chain.run is deprecated. RetrievalQA takes "query" and returns "result".
        raw_answer = stat6371_agent.invoke({"query": question})["result"]
    elif route == "ds7333":
        print("🔎 DS 7333 RAG")
        raw_answer = ds7333_agent.invoke({"query": question})["result"]
    else:
        print("🧠 General Stats HF Agent")
        result = general_stat_agent(question, max_new_tokens=200, do_sample=False)
        raw_answer = result[0]["generated_text"]
        # Text-generation pipelines may echo the prompt before the
        # continuation. Drop it so the writer only sees the model's answer.
        if raw_answer.startswith(question):
            raw_answer = raw_answer[len(question):].strip()
    print("✍️ Simplifying...")
    return writer_agent(raw_answer)
# Gradio UI: a question text box in, the simplified answer out.
iface = gr.Interface(
    fn=multiagent_system,
    inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
    outputs=gr.Textbox(label="Answer"),
    title="📊 Multi-Agent Statistics Assistant",
    description="Routes your stats question to the right syllabus (Stat 6371, DS 7333) or uses a general statistics model (Llama3).",
)
iface.launch()