KOkeke94 commited on
Commit
297f3ae
Β·
1 Parent(s): a2af92f

Fix: add missing torch import for HF pipeline

Browse files
Files changed (1) hide show
  1. app.py +48 -45
app.py CHANGED
@@ -1,79 +1,83 @@
1
  import os
2
  import gradio as gr
3
- from transformers.pipelines import pipeline
4
- from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.embeddings import HuggingFaceEmbeddings
7
- from langchain_community.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
- from langchain_community.llms import HuggingFacePipeline
10
-
11
- # βœ… Load Hugging Face LLM (LLama 3 fine-tuned model)
12
- llm_pipeline = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics", max_length=512)
13
- llm = HuggingFacePipeline(pipeline=llm_pipeline)
14
 
15
- # βœ… Create embeddings for RAG
16
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
17
 
18
- # βœ… Build RAG Agent Function
19
  def build_rag_agent(pdf_path):
20
  loader = PyPDFLoader(pdf_path)
21
  docs = loader.load()
22
  splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
23
  chunks = splitter.split_documents(docs)
24
- vectorstore = FAISS.from_documents(chunks, embedding_model)
25
  retriever = vectorstore.as_retriever()
26
  return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
27
 
28
- # βœ… Create RAG agents for both syllabi
29
  stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
30
  ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
31
 
32
- # βœ… Writer Agent (makes answers easier)
33
- def writer_agent(raw_answer, audience="high school students"):
 
 
 
34
  prompt = f"""
35
- You are a skilled teacher explaining to {audience}. Simplify the following answer in 2–3 short, clear sentences:
36
 
37
- Answer:
38
- {raw_answer}
39
- """
40
- result = llm_pipeline(prompt, max_length=200, do_sample=False)
41
- return result[0]['generated_text']
42
 
43
- # βœ… Question Routing Agent (classifies the question)
44
- def route_question(question):
45
- routing_prompt = f"""
46
- You are a routing agent. Classify the question into one of:
47
- A. Stat 6371
48
- B. DS 7333
49
- C. General statistics
50
 
51
  Question: "{question}"
52
- Answer with only A, B, or C.
53
- """
54
- result = llm_pipeline(routing_prompt, max_new_tokens=30, do_sample=False)
55
- route = result[0]['generated_text'].strip().upper()
56
- if route.startswith("A"):
57
  return "stat6371"
58
- elif route.startswith("B"):
59
  return "ds7333"
60
  else:
61
  return "general"
62
 
63
- # βœ… Multi-Agent Pipeline
 
 
 
 
 
 
 
 
 
 
 
 
64
  def multiagent_system(question):
65
- print(f"\n🧭 Routing: {question}")
66
- route = route_question(question)
67
 
68
  if route == "stat6371":
69
- print("πŸ”Ž Stat 6371 Agent")
70
  raw_answer = stat6371_agent.run(question)
71
  elif route == "ds7333":
72
- print("πŸ”Ž DS 7333 Agent")
73
  raw_answer = ds7333_agent.run(question)
74
  else:
75
  print("🧠 General Stats HF Agent")
76
- result = llm_pipeline(question, max_length=200, do_sample=False)
77
  raw_answer = result[0]['generated_text']
78
 
79
  print("✍️ Simplifying...")
@@ -84,9 +88,8 @@ iface = gr.Interface(
84
  fn=multiagent_system,
85
  inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
86
  outputs=gr.Textbox(label="Answer"),
87
- title="πŸ“Š Multi-Agent Statistics Assistant (HuggingFace)",
88
- description="Routes your stats question to the right syllabus (Stat 6371, DS 7333) or uses a general statistics model (LLama3)."
89
  )
90
 
91
- if __name__ == "__main__":
92
- iface.launch()
 
1
  import os
2
  import gradio as gr
3
+ import torch
4
+ from langchain.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain.vectorstores import FAISS
8
  from langchain.chains import RetrievalQA
9
+ from langchain.chat_models import ChatOpenAI
10
+ from transformers.pipelines import pipeline
 
 
 
11
 
12
+ # βœ… Load API key from environment variable (set in Hugging Face Secrets)
13
+ openai_key = os.environ.get("OPENAI_API_KEY")
14
+ llm = ChatOpenAI(openai_api_key=openai_key, model_name="gpt-3.5-turbo", temperature=0)
15
 
16
+ # βœ… Build RAG agent
17
  def build_rag_agent(pdf_path):
18
  loader = PyPDFLoader(pdf_path)
19
  docs = loader.load()
20
  splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
21
  chunks = splitter.split_documents(docs)
22
+ vectorstore = FAISS.from_documents(chunks, OpenAIEmbeddings(openai_api_key=openai_key))
23
  retriever = vectorstore.as_retriever()
24
  return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
25
 
26
+ # βœ… Load RAG agents
27
  stat6371_agent = build_rag_agent("PDFs/DS 6371 Syllabus Ver 6.pdf")
28
  ds7333_agent = build_rag_agent("PDFs/ds-7333_syllabus.pdf")
29
 
30
+ # βœ… Load HF fine-tuned model for general stats
31
+ general_stat_agent = pipeline("text2text-generation", model="BivinSadler/llama3-finetuned-Statistics")
32
+
33
+ # βœ… Routing agent
34
+ def route_question_llm(question):
35
  prompt = f"""
36
+ You are a classification agent that helps route questions to the appropriate expert.
37
 
38
+ There are three possible categories:
39
+ A. Stat 6371 (Theoretical statistics course)
40
+ B. DS 7333 (Decision Analytics Course)
41
+ C. General statistics (any other statistics question)
 
42
 
43
+ Classify the following question into one of those three categories by answering only with a single letter: A, B, or C.
 
 
 
 
 
 
44
 
45
  Question: "{question}"
46
+ Answer:"""
47
+ route_response = llm.invoke(prompt).content.strip().upper()
48
+ if route_response.startswith("A"):
 
 
49
  return "stat6371"
50
+ elif route_response.startswith("B"):
51
  return "ds7333"
52
  else:
53
  return "general"
54
 
55
+ # βœ… Writer agent
56
+ def writer_agent(raw_answer, audience="high school students"):
57
+ prompt = f"""
58
+ You are a talented science communicator. Your job is to explain the following answer in a way that is clear, short, and engaging for {audience}.
59
+
60
+ Answer:
61
+ {raw_answer}
62
+
63
+ Write your response in 2–3 sentences. Avoid technical jargon.
64
+ """
65
+ return llm.invoke(prompt).content
66
+
67
+ # βœ… Multi-agent logic
68
  def multiagent_system(question):
69
+ print(f"🧭 Routing: {question}")
70
+ route = route_question_llm(question)
71
 
72
  if route == "stat6371":
73
+ print("πŸ”Ž Stat 6371 RAG")
74
  raw_answer = stat6371_agent.run(question)
75
  elif route == "ds7333":
76
+ print("πŸ”Ž DS 7333 RAG")
77
  raw_answer = ds7333_agent.run(question)
78
  else:
79
  print("🧠 General Stats HF Agent")
80
+ result = general_stat_agent(question, max_new_tokens=200, do_sample=False)
81
  raw_answer = result[0]['generated_text']
82
 
83
  print("✍️ Simplifying...")
 
88
  fn=multiagent_system,
89
  inputs=gr.Textbox(lines=2, label="Ask a statistics question"),
90
  outputs=gr.Textbox(label="Answer"),
91
+ title="πŸ“Š Multi-Agent Statistics Assistant",
92
+ description="Routes your stats question to the right syllabus (Stat 6371, DS 7333) or uses a general statistics model (Llama3)."
93
  )
94
 
95
+ iface.launch()