KOkeke94 commited on
Commit
4c9413b
Β·
1 Parent(s): 1b7cf63

Fix: Remove nonexistent routing model, restore OpenAI-based routing

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -7,10 +7,10 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
  from langchain_community.llms import HuggingFacePipeline
 
10
  from transformers import pipeline
11
 
12
- # βœ… Hugging Face pipelines
13
- routing_agent = pipeline("text-classification", model="BivinSadler/statistics-routing-agent")
14
  writer_model = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics")
15
  writer_llm = HuggingFacePipeline(pipeline=writer_model)
16
 
@@ -25,19 +25,37 @@ def build_rag_agent(pdf_path):
25
  retriever = vectorstore.as_retriever()
26
  return RetrievalQA.from_chain_type(llm=writer_llm, retriever=retriever, chain_type="stuff")
27
 
28
- # βœ… Load agents
29
  stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
30
  ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
31
 
32
- # βœ… Routing
 
 
 
 
33
  def route_question(question):
34
- label = routing_agent(question)[0]["label"]
35
- return {
36
- "LABEL_0": "stat6371",
37
- "LABEL_1": "ds7333"
38
- }.get(label, "general")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # βœ… Writing
41
  def writer_agent(raw_answer, audience="high school students"):
42
  prompt = f"""
43
  You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for {audience}.
@@ -50,7 +68,7 @@ Write your response in 2–3 sentences. Avoid technical jargon.
50
  result = writer_model(prompt, max_new_tokens=200, do_sample=False)
51
  return result[0]['generated_text']
52
 
53
- # βœ… Core Logic
54
  def multiagent_system(question):
55
  print(f"🧭 Routing: {question}")
56
  route = route_question(question)
@@ -69,7 +87,7 @@ def multiagent_system(question):
69
  print("✍️ Simplifying...")
70
  return writer_agent(raw_answer)
71
 
72
- # βœ… Gradio
73
  iface = gr.Interface(
74
  fn=multiagent_system,
75
  inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
 
7
  from langchain_community.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
  from langchain_community.llms import HuggingFacePipeline
10
+ from langchain_openai import ChatOpenAI
11
  from transformers import pipeline
12
 
13
+ # βœ… Load writer model and wrap it for LangChain
 
14
  writer_model = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics")
15
  writer_llm = HuggingFacePipeline(pipeline=writer_model)
16
 
 
25
  retriever = vectorstore.as_retriever()
26
  return RetrievalQA.from_chain_type(llm=writer_llm, retriever=retriever, chain_type="stuff")
27
 
28
+ # βœ… Load RAG agents
29
  stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
30
  ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
31
 
32
+ # βœ… Load OpenAI LLM for routing
33
+ openai_key = os.environ.get("OPENAI_API_KEY")
34
+ llm = ChatOpenAI(api_key=openai_key, model="gpt-3.5-turbo", temperature=0)
35
+
36
+ # βœ… Routing logic
37
  def route_question(question):
38
+ routing_prompt = f"""
39
+ You are a classification agent that helps route questions to the appropriate expert.
40
+
41
+ There are three possible categories:
42
+ A. Stat 6371 (Theoretical statistics course)
43
+ B. DS 7333 (Decision Analytics Course)
44
+ C. General statistics (any other statistics question)
45
+
46
+ Classify the following question into one of those three categories by answering only with a single letter: A, B, or C.
47
+
48
+ Question: "{question}"
49
+ Answer:"""
50
+ response = llm.invoke(routing_prompt).content.strip().upper()
51
+ if response.startswith("A"):
52
+ return "stat6371"
53
+ elif response.startswith("B"):
54
+ return "ds7333"
55
+ else:
56
+ return "general"
57
 
58
+ # βœ… Explanation agent
59
  def writer_agent(raw_answer, audience="high school students"):
60
  prompt = f"""
61
  You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for {audience}.
 
68
  result = writer_model(prompt, max_new_tokens=200, do_sample=False)
69
  return result[0]['generated_text']
70
 
71
+ # βœ… Main logic
72
  def multiagent_system(question):
73
  print(f"🧭 Routing: {question}")
74
  route = route_question(question)
 
87
  print("✍️ Simplifying...")
88
  return writer_agent(raw_answer)
89
 
90
+ # βœ… Gradio UI
91
  iface = gr.Interface(
92
  fn=multiagent_system,
93
  inputs=gr.Textbox(lines=2, label="Ask a statistics question"),