KnockoutNed commited on
Commit
d767e18
Β·
verified Β·
1 Parent(s): 7245371

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +92 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from transformers.pipelines import pipeline
4
+ from langchain.document_loaders import PyPDFLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.vectorstores import FAISS
8
+ from langchain.chains import RetrievalQA
9
+ from langchain_community.llms import HuggingFacePipeline
10
+
11
+ # βœ… Load Hugging Face LLM (LLama 3 fine-tuned model)
12
+ llm_pipeline = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics", max_length=512)
13
+ llm = HuggingFacePipeline(pipeline=llm_pipeline)
14
+
15
+ # βœ… Create embeddings for RAG
16
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
17
+
18
+ # βœ… Build RAG Agent Function
19
+ def build_rag_agent(pdf_path):
20
+ loader = PyPDFLoader(pdf_path)
21
+ docs = loader.load()
22
+ splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
23
+ chunks = splitter.split_documents(docs)
24
+ vectorstore = FAISS.from_documents(chunks, embedding_model)
25
+ retriever = vectorstore.as_retriever()
26
+ return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
27
+
28
+ # βœ… Create RAG agents for both syllabi
29
+ stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
30
+ ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
31
+
32
+ # βœ… Writer Agent (makes answers easier)
33
+ def writer_agent(raw_answer, audience="high school students"):
34
+ prompt = f"""
35
+ You are a skilled teacher explaining to {audience}. Simplify the following answer in 2–3 short, clear sentences:
36
+
37
+ Answer:
38
+ {raw_answer}
39
+ """
40
+ result = llm_pipeline(prompt, max_length=200, do_sample=False)
41
+ return result[0]['generated_text']
42
+
43
+ # βœ… Question Routing Agent (classifies the question)
44
+ def route_question(question):
45
+ routing_prompt = f"""
46
+ You are a routing agent. Classify the question into one of:
47
+ A. Stat 6371
48
+ B. DS 7333
49
+ C. General statistics
50
+
51
+ Question: "{question}"
52
+ Answer with only A, B, or C.
53
+ """
54
+ result = llm_pipeline(routing_prompt, max_length=10, do_sample=False)
55
+ route = result[0]['generated_text'].strip().upper()
56
+ if route.startswith("A"):
57
+ return "stat6371"
58
+ elif route.startswith("B"):
59
+ return "ds7333"
60
+ else:
61
+ return "general"
62
+
63
+ # βœ… Multi-Agent Pipeline
64
+ def multiagent_system(question):
65
+ print(f"\n🧭 Routing: {question}")
66
+ route = route_question(question)
67
+
68
+ if route == "stat6371":
69
+ print("πŸ”Ž Stat 6371 Agent")
70
+ raw_answer = stat6371_agent.run(question)
71
+ elif route == "ds7333":
72
+ print("πŸ”Ž DS 7333 Agent")
73
+ raw_answer = ds7333_agent.run(question)
74
+ else:
75
+ print("🧠 General Stats HF Agent")
76
+ result = llm_pipeline(question, max_length=200, do_sample=False)
77
+ raw_answer = result[0]['generated_text']
78
+
79
+ print("✍️ Simplifying...")
80
+ return writer_agent(raw_answer)
81
+
82
+ # βœ… Gradio UI
83
+ iface = gr.Interface(
84
+ fn=multiagent_system,
85
+ inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
86
+ outputs=gr.Textbox(label="Answer"),
87
+ title="πŸ“Š Multi-Agent Statistics Assistant (HuggingFace)",
88
+ description="Routes your stats question to the right syllabus (Stat 6371, DS 7333) or uses a general statistics model (LLama3)."
89
+ )
90
+
91
+ if __name__ == "__main__":
92
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ faiss-cpu
3
+ PyPDF2
4
+ pypdf
5
+ transformers
6
+ sentence-transformers
7
+ huggingface-hub
8
+ langchain
9
+ langchain-community
10
+
11
+