KOkeke94 committed on
Commit
1b7cf63
·
1 Parent(s): 06a77e0

Fix: Wrap HF pipeline with LangChain, correct imports, remove OpenAI deps

Browse files
Files changed (2) hide show
  1. app.py +25 -40
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,57 +3,41 @@ import gradio as gr
3
  import torch
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.embeddings import OpenAIEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
- from langchain_openai import ChatOpenAI
10
- from transformers.pipelines import pipeline
11
 
12
- # ✅ Load API key from Hugging Face secret
13
- openai_key = os.environ.get("OPENAI_API_KEY")
14
- llm = ChatOpenAI(api_key=openai_key, model="gpt-3.5-turbo", temperature=0)
 
15
 
16
- # ✅ Build RAG agent
17
  def build_rag_agent(pdf_path):
18
  loader = PyPDFLoader(pdf_path)
19
  docs = loader.load()
20
  splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
21
  chunks = splitter.split_documents(docs)
22
- embeddings = OpenAIEmbeddings(api_key=openai_key)
23
  vectorstore = FAISS.from_documents(chunks, embeddings)
24
  retriever = vectorstore.as_retriever()
25
- return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
26
 
27
- # ✅ Load RAG agents
28
  stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
29
  ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
30
 
31
- # ✅ Load Hugging Face fine-tuned model
32
- general_stat_agent = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics")
 
 
 
 
 
33
 
34
- # ✅ Routing logic
35
- def route_question_llm(question):
36
- prompt = f"""
37
- You are a classification agent that helps route questions to the appropriate expert.
38
-
39
- There are three possible categories:
40
- A. Stat 6371 (Theoretical statistics course)
41
- B. DS 7333 (Decision Analytics Course)
42
- C. General statistics (any other statistics question)
43
-
44
- Classify the following question into one of those three categories by answering only with a single letter: A, B, or C.
45
-
46
- Question: "{question}"
47
- Answer:"""
48
- response = llm.invoke(prompt).content.strip().upper()
49
- if response.startswith("A"):
50
- return "stat6371"
51
- elif response.startswith("B"):
52
- return "ds7333"
53
- else:
54
- return "general"
55
-
56
- # ✅ Writer agent
57
  def writer_agent(raw_answer, audience="high school students"):
58
  prompt = f"""
59
  You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for {audience}.
@@ -63,12 +47,13 @@ Answer:
63
 
64
  Write your response in 2–3 sentences. Avoid technical jargon.
65
  """
66
- return llm.invoke(prompt).content
 
67
 
68
- # ✅ Main app logic
69
  def multiagent_system(question):
70
  print(f"🧭 Routing: {question}")
71
- route = route_question_llm(question)
72
 
73
  if route == "stat6371":
74
  print("πŸ”Ž Stat 6371 RAG")
@@ -78,13 +63,13 @@ def multiagent_system(question):
78
  raw_answer = ds7333_agent.run(question)
79
  else:
80
  print("🧠 General Stats HF Agent")
81
- result = general_stat_agent(question, max_new_tokens=200, do_sample=False)
82
  raw_answer = result[0]['generated_text']
83
 
84
  print("✍️ Simplifying...")
85
  return writer_agent(raw_answer)
86
 
87
- # ✅ Gradio UI
88
  iface = gr.Interface(
89
  fn=multiagent_system,
90
  inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
 
3
  import torch
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
+ from langchain_community.llms import HuggingFacePipeline
10
+ from transformers import pipeline
11
 
12
+ # ✅ Hugging Face pipelines
13
+ routing_agent = pipeline("text-classification", model="BivinSadler/statistics-routing-agent")
14
+ writer_model = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics")
15
+ writer_llm = HuggingFacePipeline(pipeline=writer_model)
16
 
17
+ # ✅ RAG Agent Builder
18
  def build_rag_agent(pdf_path):
19
  loader = PyPDFLoader(pdf_path)
20
  docs = loader.load()
21
  splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
22
  chunks = splitter.split_documents(docs)
23
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
24
  vectorstore = FAISS.from_documents(chunks, embeddings)
25
  retriever = vectorstore.as_retriever()
26
+ return RetrievalQA.from_chain_type(llm=writer_llm, retriever=retriever, chain_type="stuff")
27
 
28
+ # ✅ Load agents
29
  stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
30
  ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
31
 
32
+ # ✅ Routing
33
+ def route_question(question):
34
+ label = routing_agent(question)[0]["label"]
35
+ return {
36
+ "LABEL_0": "stat6371",
37
+ "LABEL_1": "ds7333"
38
+ }.get(label, "general")
39
 
40
+ # ✅ Writing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def writer_agent(raw_answer, audience="high school students"):
42
  prompt = f"""
43
  You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for {audience}.
 
47
 
48
  Write your response in 2–3 sentences. Avoid technical jargon.
49
  """
50
+ result = writer_model(prompt, max_new_tokens=200, do_sample=False)
51
+ return result[0]['generated_text']
52
 
53
+ # ✅ Core Logic
54
  def multiagent_system(question):
55
  print(f"🧭 Routing: {question}")
56
+ route = route_question(question)
57
 
58
  if route == "stat6371":
59
  print("πŸ”Ž Stat 6371 RAG")
 
63
  raw_answer = ds7333_agent.run(question)
64
  else:
65
  print("🧠 General Stats HF Agent")
66
+ result = writer_model(question, max_new_tokens=200, do_sample=False)
67
  raw_answer = result[0]['generated_text']
68
 
69
  print("✍️ Simplifying...")
70
  return writer_agent(raw_answer)
71
 
72
+ # ✅ Gradio
73
  iface = gr.Interface(
74
  fn=multiagent_system,
75
  inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
requirements.txt CHANGED
@@ -5,6 +5,7 @@ faiss-cpu
5
  PyPDF2
6
  pypdf
7
  transformers
 
8
  gradio
9
  torch
10
  tiktoken
 
5
  PyPDF2
6
  pypdf
7
  transformers
8
+ sentence-transformers
9
  gradio
10
  torch
11
  tiktoken